use std::cell::Cell;
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::thread_local;
6
thread_local! {
    /// Per-thread flag: set to `true` by `SpinLock::lock` and back to `false` by
    /// `SpinLock::unlock`.
    ///
    /// There is a single flag per thread, shared by every `SpinLock`, so it only reflects
    /// whether the most recent lock/unlock performed on this thread was a lock.
    pub(crate) static CURRENT_ACQUIRED: Cell<bool> = const { Cell::new(false) };
}
10
11#[derive(Debug)]
12pub struct SpinLock<T> {
13 state: AtomicBool,
14 data: UnsafeCell<T>,
15}
16impl<T> SpinLock<T> {
17 pub fn new(val: T) -> Self {
18 Self {
19 state: AtomicBool::new(false),
20 data: UnsafeCell::new(val),
21 }
22 }
23 pub fn lock(&self) -> SpinLockGuard<'_, T> {
24 while self
25 .state
26 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
27 .is_err()
28 {
29 std::hint::spin_loop();
30 }
31 CURRENT_ACQUIRED.set(true);
32 SpinLockGuard {
33 data_mut_borrow: unsafe { &mut *self.data.get() },
34 state: self,
35 }
36 }
37
38 pub fn try_lock(&self) -> std::io::Result<SpinLockGuard<'_, T>> {
39 if CURRENT_ACQUIRED.get() {
40 return Err(std::io::Error::new(
41 std::io::ErrorKind::WouldBlock,
42 "Re-acquire the lock that has been acquired would cause a deadlock",
43 ));
44 }
45 Ok(self.lock())
46 }
47
48 pub(self) fn unlock(&self) {
49 CURRENT_ACQUIRED.set(false);
50 self.state.store(false, Ordering::Release);
51 }
52}
53
// SAFETY: every access to `data` goes through a guard obtained after an `Acquire` exchange
// on `state` and ends with a `Release` store, so sharing `&SpinLock<T>` never produces two
// live `&mut T` at the same time.
// NOTE(review): unlike `std::sync::Mutex` (which requires `T: Send` for both impls), these
// impls put no bound on `T`, so e.g. `SpinLock<Rc<_>>` becomes `Send + Sync` and a non-`Send`
// value can be reached from several threads — that is unsound. Adding `T: Send` here would
// break callers that share `SpinLock<*mut _>` across threads (the `sync_ptr` test in this
// file does); they should first wrap the pointer in a newtype with a justified
// `unsafe impl Send`.
unsafe impl<T> Send for SpinLock<T> {}
unsafe impl<T> Sync for SpinLock<T> {}
56
57#[derive(Debug)]
58pub struct SpinLockGuard<'a, T: 'a> {
59 data_mut_borrow: &'a mut T,
60 state: &'a SpinLock<T>,
61}
62impl<'a, T: 'a> Drop for SpinLockGuard<'a, T> {
63 fn drop(&mut self) {
64 self.state.unlock()
65 }
66}
67impl<'a, T: 'a> Deref for SpinLockGuard<'a, T> {
68 type Target = T;
69
70 fn deref(&self) -> &Self::Target {
71 self.data_mut_borrow
72 }
73}
74impl<'a, T: 'a> DerefMut for SpinLockGuard<'a, T> {
75 fn deref_mut(&mut self) -> &mut Self::Target {
76 self.data_mut_borrow
77 }
78}
79
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::thread;

    use crate::SpinLock;

    /// Each round, two threads add 2 and 3 to a freshly locked counter; after both
    /// finish the counter must have grown by exactly 5.
    #[test]
    fn synchronization() {
        let mut result = 1;
        for round in 1..=20000 {
            let lock = Arc::new(SpinLock::new(result));
            let workers = [2, 3].map(|delta| {
                let worker = Arc::clone(&lock);
                thread::spawn(move || *worker.lock() += delta)
            });
            for handle in workers {
                handle.join().unwrap();
            }
            result = *lock.lock();
            assert_eq!(result, round * 5 + 1);
        }
    }

    /// Same as `synchronization`, but the lock guards a raw pointer and the
    /// increments are written through it while the guard is held.
    #[test]
    fn sync_ptr() {
        let mut result = 1;
        let counter: *mut i32 = Box::into_raw(Box::new(result));
        for round in 1..=20000 {
            let lock = Arc::new(SpinLock::new(counter));
            let workers = [2, 3].map(|delta| {
                let worker = Arc::clone(&lock);
                thread::spawn(move || {
                    // Keep the guard alive across the write so the increment is serialized.
                    let guard = worker.lock();
                    let ptr = *guard;
                    unsafe {
                        *ptr += delta;
                    }
                })
            });
            for handle in workers {
                handle.join().unwrap();
            }
            result = unsafe { **lock.lock() };
            assert_eq!(result, round * 5 + 1);
        }
        unsafe { drop(Box::from_raw(counter)) };
    }

    /// `try_lock` on a lock the current thread already holds must fail with
    /// `WouldBlock`, while another thread's `try_lock` eventually succeeds.
    #[test]
    fn avoid_deadlock() {
        let holder_lock = Arc::new(SpinLock::new(0));
        let contender_lock = Arc::clone(&holder_lock);

        let holder = thread::spawn(move || {
            let _held = holder_lock.lock();
            let attempt = holder_lock.try_lock();
            println!("{attempt:?}");
            let err = attempt.expect_err("re-locking from the owning thread must be refused");
            assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
        });
        let contender = thread::spawn(move || {
            assert!(contender_lock.try_lock().is_ok());
        });

        contender.join().unwrap();
        holder.join().unwrap();
    }
}