1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
use std::cell::UnsafeCell; use std::ops::{Deref, DerefMut}; use std::sync::atomic::{ AtomicU8, Ordering::{AcqRel, Acquire, Release}, }; pub struct SpinLock<T> { mark: AtomicU8, obj: UnsafeCell<T>, } pub struct SpinLockGuard<'a, T> { lock: &'a SpinLock<T>, } impl<T> SpinLock<T> { pub fn new(obj: T) -> Self { Self { mark: AtomicU8::new(0), obj: UnsafeCell::new(obj), } } pub fn lock(&self) -> SpinLockGuard<T> { let backoff = crossbeam_utils::Backoff::new(); while self.mark.compare_and_swap(0, 1, AcqRel) != 0 { backoff.spin(); } SpinLockGuard { lock: self } } } impl<'a, T> Deref for SpinLockGuard<'a, T> { type Target = T; fn deref(&self) -> &'a Self::Target { unsafe { &*self.lock.obj.get() } } } impl<'a, T> DerefMut for SpinLockGuard<'a, T> { fn deref_mut(&mut self) -> &'a mut T { unsafe { &mut *self.lock.obj.get() } } } impl<'a, T> Drop for SpinLockGuard<'a, T> { fn drop(&mut self) { self.lock.mark.store(0, Release); } } unsafe impl<T> Sync for SpinLock<T> {} #[cfg(test)] mod test { use super::*; use std::sync::Arc; use std::thread; #[test] fn lot_load_of_lock() { let lock = Arc::new(SpinLock::new(0)); let num_threads = 32; let thread_turns = 2048; let mut threads = vec![]; for _ in 0..num_threads { let lock = lock.clone(); threads.push(thread::spawn(move || { for _ in 0..thread_turns { *lock.lock() += 1; } })); } for t in threads { t.join().unwrap(); } assert_eq!(*lock.lock(), num_threads * thread_turns); } }