// cargo_cleaner/notify_rw_lock.rs

1use atomic_wait::{wait, wake_all, wake_one};
2use std::cell::UnsafeCell;
3use std::ops::{Deref, DerefMut};
4use std::sync::atomic::{AtomicU32, Ordering};
5
/// `state` value meaning "a writer holds the lock" (odd, so readers block on it).
const WRITE_LOCK_STATE: u32 = u32::MAX;
/// Amount `state` grows per active reader; the low bit is reserved as the
/// "writer waiting" flag.
const READ_LOCK_STEP: u32 = 2;

/// Bounded channel sender that is pinged with `()` when a write lock is released.
pub type NotifySender = std::sync::mpsc::SyncSender<()>;

11pub struct NotifyRwLock<T> {
12    //2刻みでカウントアップされていく、リードロックのカウント
13    // 奇数の場合は、ライトロックが待っていることを指す
14    state: AtomicU32,
15    writer_wake_counter: AtomicU32,
16    value: UnsafeCell<T>,
17    notify_tx: NotifySender,
18}
19
20impl<T> NotifyRwLock<T> {
21    pub fn new(notify_tx: NotifySender, value: T) -> Self {
22        Self {
23            state: AtomicU32::new(0),
24            writer_wake_counter: AtomicU32::new(0),
25            value: UnsafeCell::new(value),
26            notify_tx,
27        }
28    }
29
30    pub fn read(&self) -> ReadGuard<'_, T> {
31        let mut s = self.state.load(Ordering::Relaxed);
32
33        loop {
34            if s % 2 == 0 {
35                assert!(s < u32::MAX - 2, "too many readers");
36                match self.state.compare_exchange_weak(
37                    s,
38                    s + READ_LOCK_STEP,
39                    Ordering::Acquire,
40                    Ordering::Relaxed,
41                ) {
42                    Ok(_) => return ReadGuard { rwlock: self },
43                    Err(e) => {
44                        s = e;
45                    }
46                }
47            }
48            if s % 2 == 1 {
49                wait(&self.state, s);
50                s = self.state.load(Ordering::Relaxed);
51            }
52        }
53    }
54
55    pub fn write(&self) -> WriteGuard<'_, T> {
56        let mut s = self.state.load(Ordering::Relaxed);
57
58        loop {
59            if s <= 1 {
60                match self.state.compare_exchange(
61                    s,
62                    WRITE_LOCK_STATE,
63                    Ordering::Acquire,
64                    Ordering::Relaxed,
65                ) {
66                    Ok(_) => {
67                        return WriteGuard { rwlock: self };
68                    }
69                    Err(e) => {
70                        s = e;
71                        continue;
72                    }
73                }
74            }
75            if s % 2 == 0 {
76                if let Err(e) =
77                    self.state
78                        .compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed)
79                {
80                    s = e;
81                    continue;
82                }
83            }
84            let w = self.writer_wake_counter.load(Ordering::Acquire);
85            s = self.state.load(Ordering::Relaxed);
86            if s >= READ_LOCK_STEP {
87                wait(&self.writer_wake_counter, w);
88                s = self.state.load(Ordering::Relaxed);
89            }
90        }
91    }
92}
93
94unsafe impl<T: Send + Sync> Sync for NotifyRwLock<T> {}
95
96pub struct ReadGuard<'a, T> {
97    rwlock: &'a NotifyRwLock<T>,
98}
99
100impl<T> Drop for ReadGuard<'_, T> {
101    fn drop(&mut self) {
102        if self
103            .rwlock
104            .state
105            .fetch_sub(READ_LOCK_STEP, Ordering::Release)
106            == 3
107        {
108            self.rwlock
109                .writer_wake_counter
110                .fetch_add(1, Ordering::Release);
111            wake_one(&self.rwlock.writer_wake_counter);
112        }
113    }
114}
115
116impl<T> Deref for ReadGuard<'_, T> {
117    type Target = T;
118
119    fn deref(&self) -> &Self::Target {
120        unsafe { &*self.rwlock.value.get() }
121    }
122}
123
124pub struct WriteGuard<'a, T> {
125    rwlock: &'a NotifyRwLock<T>,
126}
127
128impl<T> Drop for WriteGuard<'_, T> {
129    fn drop(&mut self) {
130        self.rwlock.state.store(0, Ordering::Release);
131        self.rwlock
132            .writer_wake_counter
133            .fetch_add(1, Ordering::Release);
134        wake_one(&self.rwlock.writer_wake_counter);
135        wake_all(&self.rwlock.state);
136        let _ = self.rwlock.notify_tx.try_send(()); // 通知が一杯で送れない場合は、エラーを無視する
137    }
138}
139
140impl<T> Deref for WriteGuard<'_, T> {
141    type Target = T;
142
143    fn deref(&self) -> &Self::Target {
144        unsafe { &*self.rwlock.value.get() }
145    }
146}
147
148impl<T> DerefMut for WriteGuard<'_, T> {
149    fn deref_mut(&mut self) -> &mut Self::Target {
150        unsafe { &mut *self.rwlock.value.get() }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::hint::black_box;

    /// Two writers and one reader contend for the lock concurrently. Every
    /// push must land, and write-guard drops must have produced notifications.
    #[test]
    fn add_list() {
        let (tx, rx) = std::sync::mpsc::sync_channel(1000);
        let waiter_list = NotifyRwLock::new(tx, Vec::new());
        black_box(&waiter_list);

        // `thread::scope` joins every spawned thread and propagates panics.
        std::thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    for i in 0..1000 {
                        let mut c = waiter_list.write();
                        black_box(&c);
                        c.push(i);
                    }
                });
            }
            s.spawn(|| {
                for _ in 0..1000 {
                    let c = waiter_list.read();
                    black_box(&c);
                }
            });
        });

        assert_eq!(waiter_list.read().len(), 2_000);
        // 2,000 write-guard drops each try to notify. `rx` is never drained
        // and the channel buffers 1,000, so exactly 1,000 notifications are queued.
        assert_eq!(rx.try_iter().count(), 1_000);
    }
}