omango_util/
lock.rs

1// Copyright (c) 2024 Trung Tran <tqtrungse@gmail.com>
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19// SOFTWARE.
20
21use std::{
22    cell::UnsafeCell,
23    ops::{Deref, DerefMut},
24    sync::atomic::{AtomicU32, Ordering},
25};
26
27use crate::backoff::Backoff;
28use crate::hint::unlikely;
29
/// Value placed in the lock flag while a writer holds the lock (bit 30).
/// Values strictly below it count the readers currently holding the lock.
const WRITE_NUMBER: u32 = 0x4000_0000;
31
/// A reader-writer lock that busy-waits instead of parking threads.
///
/// Any number of readers may hold the lock at once, or exactly one writer.
pub struct RwSpinlock<T> {
    /// Lock state: `0` when unlocked, at or above `WRITE_NUMBER` while a writer
    /// holds the lock, otherwise the number of active readers.
    flag: AtomicU32,
    /// The protected value; only accessed through lock guards.
    value: UnsafeCell<T>,
}

// SAFETY: the lock owns its `T`, so moving the lock to another thread moves the
// value with it; that is sound whenever `T` itself is `Send`.
unsafe impl<T: Send> Send for RwSpinlock<T> {}

// SAFETY: sharing `&RwSpinlock<T>` lets several threads hold read guards and
// observe `&T` at the same time (requires `T: Sync`), and lets a writer on another
// thread mutate or replace the value (requires `T: Send`). These are the same
// bounds `std::sync::RwLock` uses; `T: Send` alone would let `!Sync` types such as
// `Cell` be read concurrently from multiple threads.
unsafe impl<T: Send + Sync> Sync for RwSpinlock<T> {}
40
41impl<T> RwSpinlock<T> {
42    #[inline(always)]
43    pub fn new(value: T) -> Self {
44        Self {
45            flag: AtomicU32::new(0),
46            value: UnsafeCell::new(value),
47        }
48    }
49
50    pub fn try_write(&self) -> Option<RwSpinlockGuard<T>> {
51        if self.flag.compare_exchange_weak(
52            0,
53            WRITE_NUMBER,
54            Ordering::Acquire,
55            Ordering::Relaxed,
56        ).is_ok() {
57            return Some(RwSpinlockGuard { parent: self });
58        }
59        None
60    }
61
62    pub fn try_read(&self) -> Option<RwSpinlockGuard<T>> {
63        let pre_value = self.flag.fetch_add(1, Ordering::Relaxed);
64        if pre_value < WRITE_NUMBER {
65            return Some(RwSpinlockGuard { parent: self });
66        }
67        None
68    }
69
70    pub fn write(&self) -> RwSpinlockGuard<T> {
71        let backoff = Backoff::default();
72        loop {
73            // "compare_exchange" performance is better than "swap".
74            // The reason for using a weak "compare_exchange" is explained here:
75            // https://github.com/Amanieu/parking_lot/pull/207#issuecomment-575869107
76            if self.flag.compare_exchange_weak(
77                0,
78                WRITE_NUMBER,
79                Ordering::Acquire,
80                Ordering::Relaxed,
81            ).is_ok() {
82                break;
83            }
84
85            while self.flag.load(Ordering::Relaxed) != 0 {
86                // Waits the lock is unlocked to reduce CPU cache coherence.
87                backoff.spin();
88            }
89        }
90        RwSpinlockGuard { parent: self }
91    }
92
93    pub fn read(&self) -> RwSpinlockGuard<T> {
94        let backoff = Backoff::default();
95        loop {
96            let pre_value = self.flag.fetch_add(1, Ordering::Relaxed);
97            if pre_value < WRITE_NUMBER {
98                break;
99            }
100
101            while self.flag.load(Ordering::Relaxed) != 0 {
102                // Waits the lock is unlocked to reduce CPU cache coherence.
103                backoff.spin();
104            }
105        }
106        RwSpinlockGuard { parent: self }
107    }
108}
109
110pub struct RwSpinlockGuard<'a, T> {
111    parent: &'a RwSpinlock<T>,
112}
113
114impl<T> Drop for RwSpinlockGuard<'_, T> {
115    #[inline(always)]
116    fn drop(&mut self) {
117        if unlikely(self.parent.flag.load(Ordering::Relaxed) >= WRITE_NUMBER) {
118            self.parent.flag.store(0, Ordering::Release);
119        } else {
120            self.parent.flag.fetch_sub(1, Ordering::Relaxed);
121        }
122    }
123}
124
125impl<T> Deref for RwSpinlockGuard<'_, T> {
126    type Target = T;
127
128    #[inline(always)]
129    fn deref(&self) -> &T {
130        unsafe { &*self.parent.value.get() }
131    }
132}
133
134impl<T> DerefMut for RwSpinlockGuard<'_, T> {
135    #[inline(always)]
136    fn deref_mut(&mut self) -> &mut T {
137        unsafe { &mut *self.parent.value.get() }
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn test_read_unlock() {
        let m = RwSpinlock::<i32>::new(0);
        {
            // `_`-prefixed bindings keep the guards alive until the end of scope.
            let _r1 = m.read();
            {
                let _r2 = m.read();
                let _r3 = m.read();
                assert!(m.try_write().is_none());
            }
            assert!(m.try_write().is_none());
        }
        assert!(m.try_write().is_some());
    }

    #[test]
    fn test_write_unlock() {
        let m = RwSpinlock::<i32>::new(0);
        {
            let _w1 = m.write();
            assert!(m.try_read().is_none());
        }
        assert!(m.try_read().is_some());
    }

    #[test]
    fn test_rw_arc() {
        let arc = Arc::new(RwSpinlock::new(0));
        let arc2 = arc.clone();
        let (tx, rx) = std::sync::mpsc::channel();

        std::thread::spawn(move || {
            let mut lock = arc2.write();
            for _ in 0..10 {
                let tmp = *lock;
                *lock = -1;
                std::thread::yield_now();
                *lock = tmp + 1;
            }
            tx.send(()).unwrap();
        });

        // Readers try to catch the writer in the act
        let mut children = Vec::new();
        for _ in 0..5 {
            let arc3 = arc.clone();
            children.push(std::thread::spawn(move || {
                let lock = arc3.read();
                assert!(*lock >= 0);
            }));
        }

        // Wait for children to pass their asserts
        for r in children {
            assert!(r.join().is_ok());
        }

        // Wait for writer to finish
        rx.recv().unwrap();
        let lock = arc.read();
        assert_eq!(*lock, 10);
    }

    #[test]
    fn test_rw_access_in_unwind() {
        let arc = Arc::new(RwSpinlock::new(1));
        let arc2 = arc.clone();
        let _ = std::thread::spawn(move || {
            struct Unwinder {
                i: Arc<RwSpinlock<isize>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    let mut lock = self.i.write();
                    *lock += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let lock = arc.read();
        assert_eq!(*lock, 2);
    }

    #[test]
    fn test_rwlock_unsized() {
        let rw: &RwSpinlock<[i32; 3]> = &RwSpinlock::new([1, 2, 3]);
        {
            let b = &mut *rw.write();
            b[0] = 4;
            b[2] = 5;
        }
        let comp: &[i32] = &[4, 2, 5];
        assert_eq!(&*rw.read(), comp);
    }

    #[test]
    fn test_rwlock_try_write() {
        let lock = RwSpinlock::new(0isize);
        let read_guard = lock.read();

        assert!(
            lock.try_write().is_none(),
            "try_write should not succeed while read_guard is in scope"
        );

        drop(read_guard);
    }

    #[test]
    fn test_rw_try_read() {
        let m = RwSpinlock::new(0);
        std::mem::forget(m.write());
        assert!(m.try_read().is_none());
    }

    #[test]
    fn test_failed_try_read_leaves_lock_usable() {
        let m = RwSpinlock::new(0);
        {
            let _w = m.write();
            for _ in 0..1_000 {
                assert!(m.try_read().is_none());
            }
        }
        {
            let _r = m
                .try_read()
                .expect("lock must be readable once the writer is gone");
        }
        assert!(m.try_write().is_some());
    }

    #[test]
    fn test_try_read_allows_multiple_readers() {
        let m = RwSpinlock::new(7);
        let r1 = m.try_read().expect("first reader");
        let r2 = m.try_read().expect("second reader");
        assert_eq!((*r1, *r2), (7, 7));
        assert!(m.try_write().is_none());
    }

    #[test]
    fn test_concurrent_writers_do_not_lose_updates() {
        const THREADS: usize = 4;
        const ITERS: usize = 1_000;

        let lock = Arc::new(RwSpinlock::new(0usize));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let lock = Arc::clone(&lock);
                std::thread::spawn(move || {
                    for _ in 0..ITERS {
                        *lock.write() += 1;
                        let _ = *lock.read();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(*lock.read(), THREADS * ITERS);
    }
}