cargo_cleaner/
notify_rw_lock.rs1use atomic_wait::{wait, wake_all, wake_one};
2use std::cell::UnsafeCell;
3use std::ops::{Deref, DerefMut};
4use std::sync::atomic::{AtomicU32, Ordering};
5
/// Value of `state` while a writer holds the lock.
///
/// It is odd, so readers (which only enter on an even `state`) treat it as
/// "a writer is involved" and wait.
const WRITE_LOCK_STATE: u32 = u32::MAX;

/// Amount each active reader adds to `state`.
///
/// Stepping by two keeps the low bit free to flag a waiting writer.
const READ_LOCK_STEP: u32 = 2;

/// Sending half of the channel that a `NotifyRwLock` pokes (best effort) whenever a
/// write lock is released.
pub type NotifySender = std::sync::mpsc::SyncSender<()>;
10
11pub struct NotifyRwLock<T> {
12 state: AtomicU32,
15 writer_wake_counter: AtomicU32,
16 value: UnsafeCell<T>,
17 notify_tx: NotifySender,
18}
19
20impl<T> NotifyRwLock<T> {
21 pub fn new(notify_tx: NotifySender, value: T) -> Self {
22 Self {
23 state: AtomicU32::new(0),
24 writer_wake_counter: AtomicU32::new(0),
25 value: UnsafeCell::new(value),
26 notify_tx,
27 }
28 }
29
30 pub fn read(&self) -> ReadGuard<'_, T> {
31 let mut s = self.state.load(Ordering::Relaxed);
32
33 loop {
34 if s % 2 == 0 {
35 assert!(s < u32::MAX - 2, "too many readers");
36 match self.state.compare_exchange_weak(
37 s,
38 s + READ_LOCK_STEP,
39 Ordering::Acquire,
40 Ordering::Relaxed,
41 ) {
42 Ok(_) => return ReadGuard { rwlock: self },
43 Err(e) => {
44 s = e;
45 }
46 }
47 }
48 if s % 2 == 1 {
49 wait(&self.state, s);
50 s = self.state.load(Ordering::Relaxed);
51 }
52 }
53 }
54
55 pub fn write(&self) -> WriteGuard<'_, T> {
56 let mut s = self.state.load(Ordering::Relaxed);
57
58 loop {
59 if s <= 1 {
60 match self.state.compare_exchange(
61 s,
62 WRITE_LOCK_STATE,
63 Ordering::Acquire,
64 Ordering::Relaxed,
65 ) {
66 Ok(_) => {
67 return WriteGuard { rwlock: self };
68 }
69 Err(e) => {
70 s = e;
71 continue;
72 }
73 }
74 }
75 if s % 2 == 0 {
76 if let Err(e) =
77 self.state
78 .compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed)
79 {
80 s = e;
81 continue;
82 }
83 }
84 let w = self.writer_wake_counter.load(Ordering::Acquire);
85 s = self.state.load(Ordering::Relaxed);
86 if s >= READ_LOCK_STEP {
87 wait(&self.writer_wake_counter, w);
88 s = self.state.load(Ordering::Relaxed);
89 }
90 }
91 }
92}
93
94unsafe impl<T: Send + Sync> Sync for NotifyRwLock<T> {}
95
96pub struct ReadGuard<'a, T> {
97 rwlock: &'a NotifyRwLock<T>,
98}
99
100impl<T> Drop for ReadGuard<'_, T> {
101 fn drop(&mut self) {
102 if self
103 .rwlock
104 .state
105 .fetch_sub(READ_LOCK_STEP, Ordering::Release)
106 == 3
107 {
108 self.rwlock
109 .writer_wake_counter
110 .fetch_add(1, Ordering::Release);
111 wake_one(&self.rwlock.writer_wake_counter);
112 }
113 }
114}
115
116impl<T> Deref for ReadGuard<'_, T> {
117 type Target = T;
118
119 fn deref(&self) -> &Self::Target {
120 unsafe { &*self.rwlock.value.get() }
121 }
122}
123
124pub struct WriteGuard<'a, T> {
125 rwlock: &'a NotifyRwLock<T>,
126}
127
128impl<T> Drop for WriteGuard<'_, T> {
129 fn drop(&mut self) {
130 self.rwlock.state.store(0, Ordering::Release);
131 self.rwlock
132 .writer_wake_counter
133 .fetch_add(1, Ordering::Release);
134 wake_one(&self.rwlock.writer_wake_counter);
135 wake_all(&self.rwlock.state);
136 let _ = self.rwlock.notify_tx.try_send(()); }
138}
139
140impl<T> Deref for WriteGuard<'_, T> {
141 type Target = T;
142
143 fn deref(&self) -> &Self::Target {
144 unsafe { &*self.rwlock.value.get() }
145 }
146}
147
148impl<T> DerefMut for WriteGuard<'_, T> {
149 fn deref_mut(&mut self) -> &mut Self::Target {
150 unsafe { &mut *self.rwlock.value.get() }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::hint::black_box;

    /// Two writers each push `0..1000` while a reader repeatedly takes read locks.
    ///
    /// Afterwards:
    /// - every value must be present exactly twice, so no update was lost;
    /// - the reader must never have seen the list shrink;
    /// - the undrained notify channel must be filled exactly to capacity.
    #[test]
    fn add_list() {
        const CHANNEL_CAPACITY: usize = 1000;
        let (tx, rx) = std::sync::mpsc::sync_channel(CHANNEL_CAPACITY);
        let waiter_list = NotifyRwLock::new(tx, Vec::new());
        black_box(&waiter_list);

        std::thread::scope(|s| {
            let t1 = s.spawn(|| {
                for i in 0..1000 {
                    let mut c = waiter_list.write();
                    black_box(&c);
                    c.push(i);
                }
            });
            let t2 = s.spawn(|| {
                for i in 0..1000 {
                    let mut c = waiter_list.write();
                    black_box(&c);
                    c.push(i);
                }
            });
            let t3 = s.spawn(|| {
                let mut last_len = 0;
                for _ in 0..1000 {
                    let c = waiter_list.read();
                    black_box(&c);
                    // Writers only ever push, so a consistent snapshot can never shrink.
                    assert!(c.len() >= last_len);
                    last_len = c.len();
                }
            });
            t1.join().unwrap();
            t2.join().unwrap();
            t3.join().unwrap();
        });

        let mut values = waiter_list.read().to_vec();
        assert_eq!(values.len(), 2_000);
        values.sort_unstable();
        let expected: Vec<i32> = (0..1000).flat_map(|i| [i, i]).collect();
        assert_eq!(values, expected);

        // Nothing drained the channel, so exactly `CHANNEL_CAPACITY` of the 2_000 write
        // notifications fit; `try_send` dropped the rest.
        assert_eq!(rx.try_iter().count(), CHANNEL_CAPACITY);
    }
}