use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{
    AtomicU8,
    Ordering::{AcqRel, Acquire, Relaxed, Release},
};
7
/// A test-and-set spin lock protecting a value of type `T`.
///
/// Mutual exclusion is provided by `mark`; the protected value lives in an
/// `UnsafeCell` so it can be reached through a shared `&SpinLock<T>` once
/// the lock is held.
pub struct SpinLock<T> {
    // Lock word: 0 = unlocked, 1 = held.
    mark: AtomicU8,
    // Interior mutability; all access is serialized by `mark`.
    obj: UnsafeCell<T>,
}
12
13pub struct SpinLockGuard<'a, T> {
14 lock: &'a SpinLock<T>,
15}
16
17impl<T> SpinLock<T> {
18 pub fn new(obj: T) -> Self {
19 Self {
20 mark: AtomicU8::new(0),
21 obj: UnsafeCell::new(obj),
22 }
23 }
24
25 pub fn lock(&self) -> SpinLockGuard<T> {
26 let backoff = crossbeam_utils::Backoff::new();
27 while self.mark.compare_and_swap(0, 1, AcqRel) != 0 {
28 backoff.spin();
29 }
30 SpinLockGuard { lock: self }
31 }
32}
33
34impl<'a, T> Deref for SpinLockGuard<'a, T> {
35 type Target = T;
36 fn deref(&self) -> &'a Self::Target {
37 unsafe { &*self.lock.obj.get() }
38 }
39}
40
41impl<'a, T> DerefMut for SpinLockGuard<'a, T> {
42 fn deref_mut(&mut self) -> &'a mut T {
43 unsafe { &mut *self.lock.obj.get() }
44 }
45}
46
47impl<'a, T> Drop for SpinLockGuard<'a, T> {
48 fn drop(&mut self) {
49 self.lock.mark.store(0, Release);
50 }
51}
52
53unsafe impl<T> Sync for SpinLock<T> {}
54
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Hammers the lock from many threads. The final count is exact only if
    /// every `+= 1` was mutually exclusive, so a lost update (i.e. a broken
    /// lock) makes the assertion fail.
    #[test]
    fn lot_load_of_lock() {
        let lock = Arc::new(SpinLock::new(0));
        let num_threads = 32;
        let thread_turns = 2048;
        let mut threads = vec![];
        for _ in 0..num_threads {
            let lock = Arc::clone(&lock);
            threads.push(thread::spawn(move || {
                for _ in 0..thread_turns {
                    *lock.lock() += 1;
                }
            }));
        }
        for t in threads {
            t.join().unwrap();
        }
        assert_eq!(*lock.lock(), num_threads * thread_turns);
    }
}