// lightning/spin.rs

use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{
    AtomicU8,
    Ordering::{AcqRel, Acquire, Relaxed, Release},
};

/// A minimal mutual-exclusion lock that busy-waits on an atomic flag
/// (see `SpinLock::lock`) rather than blocking in the OS.
pub struct SpinLock<T> {
    /// Lock word: `0` while the lock is free, `1` while a guard holds it.
    mark: AtomicU8,
    /// The protected value; only reached through a live `SpinLockGuard`.
    obj: UnsafeCell<T>,
}

13pub struct SpinLockGuard<'a, T> {
14    lock: &'a SpinLock<T>,
15}
16
17impl<T> SpinLock<T> {
18    pub fn new(obj: T) -> Self {
19        Self {
20            mark: AtomicU8::new(0),
21            obj: UnsafeCell::new(obj),
22        }
23    }
24
25    pub fn lock(&self) -> SpinLockGuard<T> {
26        let backoff = crossbeam_utils::Backoff::new();
27        while self.mark.compare_and_swap(0, 1, AcqRel) != 0 {
28            backoff.spin();
29        }
30        SpinLockGuard { lock: self }
31    }
32}
33
34impl<'a, T> Deref for SpinLockGuard<'a, T> {
35    type Target = T;
36    fn deref(&self) -> &'a Self::Target {
37        unsafe { &*self.lock.obj.get() }
38    }
39}
40
41impl<'a, T> DerefMut for SpinLockGuard<'a, T> {
42    fn deref_mut(&mut self) -> &'a mut T {
43        unsafe { &mut *self.lock.obj.get() }
44    }
45}
46
47impl<'a, T> Drop for SpinLockGuard<'a, T> {
48    fn drop(&mut self) {
49        self.lock.mark.store(0, Release);
50    }
51}
52
53unsafe impl<T> Sync for SpinLock<T> {}
54
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Hammers one lock from many threads at once. If mutual exclusion were broken,
    /// some read-modify-write increments would be lost and the total would be short.
    #[test]
    fn lot_load_of_lock() {
        let (num_threads, thread_turns) = (32, 2048);
        let counter = Arc::new(SpinLock::new(0));

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    (0..thread_turns).for_each(|_| *counter.lock() += 1);
                })
            })
            .collect();

        handles
            .into_iter()
            .for_each(|handle| handle.join().unwrap());

        let total = *counter.lock();
        assert_eq!(total, num_threads * thread_turns);
    }
}