spin_lock/
lib.rs

1use std::cell::Cell;
2use std::cell::UnsafeCell;
3use std::ops::{Deref, DerefMut};
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::thread_local;
6
thread_local! {
    /// Per-thread flag that `SpinLock::lock` sets to `true` and `SpinLock::unlock` resets
    /// to `false` on the calling thread. There is a single flag per thread, shared by every
    /// `SpinLock`, so on its own it does not say *which* lock the thread holds.
    pub(crate) static CURRENT_ACQUIRED: Cell<bool> = const { Cell::new(false) };
}
10
11#[derive(Debug)]
12pub struct SpinLock<T> {
13    state: AtomicBool,
14    data: UnsafeCell<T>,
15}
16impl<T> SpinLock<T> {
17    pub fn new(val: T) -> Self {
18        Self {
19            state: AtomicBool::new(false),
20            data: UnsafeCell::new(val),
21        }
22    }
23    pub fn lock(&self) -> SpinLockGuard<'_, T> {
24        while self
25            .state
26            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
27            .is_err()
28        {
29            std::hint::spin_loop();
30        }
31        CURRENT_ACQUIRED.set(true);
32        SpinLockGuard {
33            data_mut_borrow: unsafe { &mut *self.data.get() },
34            state: self,
35        }
36    }
37
38    pub fn try_lock(&self) -> std::io::Result<SpinLockGuard<'_, T>> {
39        if CURRENT_ACQUIRED.get() {
40            return Err(std::io::Error::new(
41                std::io::ErrorKind::WouldBlock,
42                "Re-acquire the lock that has been acquired would cause a deadlock",
43            ));
44        }
45        Ok(self.lock())
46    }
47
48    pub(self) fn unlock(&self) {
49        CURRENT_ACQUIRED.set(false);
50        self.state.store(false, Ordering::Release);
51    }
52}
53
54unsafe impl<T> Send for SpinLock<T> {}
55unsafe impl<T> Sync for SpinLock<T> {}
56
57#[derive(Debug)]
58pub struct SpinLockGuard<'a, T: 'a> {
59    data_mut_borrow: &'a mut T,
60    state: &'a SpinLock<T>,
61}
62impl<'a, T: 'a> Drop for SpinLockGuard<'a, T> {
63    fn drop(&mut self) {
64        self.state.unlock()
65    }
66}
67impl<'a, T: 'a> Deref for SpinLockGuard<'a, T> {
68    type Target = T;
69
70    fn deref(&self) -> &Self::Target {
71        self.data_mut_borrow
72    }
73}
74impl<'a, T: 'a> DerefMut for SpinLockGuard<'a, T> {
75    fn deref_mut(&mut self) -> &mut Self::Target {
76        self.data_mut_borrow
77    }
78}
79
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::SpinLock;

    /// Each round, two threads add 2 and 3 to a fresh lock seeded with the previous total,
    /// so the total must grow by exactly 5 per round.
    #[test]
    fn synchronization() {
        let mut total = 1;
        for round in 1..=20000 {
            let shared = Arc::new(SpinLock::new(total));
            let add_two = {
                let handle = Arc::clone(&shared);
                std::thread::spawn(move || {
                    *handle.lock() += 2;
                })
            };
            let add_three = {
                let handle = Arc::clone(&shared);
                std::thread::spawn(move || {
                    *handle.lock() += 3;
                })
            };
            add_two.join().unwrap();
            add_three.join().unwrap();
            total = *shared.lock();
            assert_eq!(total, (round * 5) + 1);
        }
    }

    /// Same scheme as `synchronization`, but the lock guards a raw pointer and each thread
    /// writes through it while still holding the guard.
    #[test]
    fn sync_ptr() {
        let mut total = 1;
        let raw = Box::into_raw(Box::new(total));
        for round in 1..=20000 {
            let shared = Arc::new(SpinLock::new(raw));
            let add_two = {
                let handle = Arc::clone(&shared);
                std::thread::spawn(move || {
                    let held = handle.lock();
                    let target = *held;
                    // SAFETY: `raw` stays allocated until the end of the test and is only
                    // written while `held` keeps the lock.
                    unsafe {
                        *target += 2;
                    };
                })
            };
            let add_three = {
                let handle = Arc::clone(&shared);
                std::thread::spawn(move || {
                    let held = handle.lock();
                    let target = *held;
                    // SAFETY: as above.
                    unsafe {
                        *target += 3;
                    };
                })
            };
            add_two.join().unwrap();
            add_three.join().unwrap();
            total = unsafe { **shared.lock() };
            assert_eq!(total, (round * 5) + 1);
        }
        // SAFETY: `raw` came from `Box::into_raw` above and every thread using it has joined.
        unsafe { drop(Box::from_raw(raw)) };
    }

    /// A thread that already holds the lock gets `WouldBlock` from `try_lock`, while a
    /// different thread still manages to acquire it.
    #[test]
    fn avoid_deadlock() {
        let spin = Arc::new(SpinLock::new(0));
        let other = Arc::clone(&spin);
        let holder = std::thread::spawn(move || {
            let _guard = spin.lock();
            let attempt = spin.try_lock();
            println!("{:?}", attempt);
            let Err(e) = attempt else {
                unreachable!()
            };
            assert_eq!(e.kind(), std::io::ErrorKind::WouldBlock);
        });
        let contender = std::thread::spawn(move || {
            assert!(other.try_lock().is_ok());
        });
        contender.join().unwrap();
        holder.join().unwrap();
    }
}