spin-lock 0.2.1

A simple spin lock implemented by the standard facilities
Documentation
use std::cell::Cell;
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread_local;

thread_local! {
    pub(crate) static CURRENT_ACQUIRED: Cell<bool> = const{ Cell::new(false) };
}

#[derive(Debug)]
pub struct SpinLock<T> {
    state: AtomicBool,
    data: UnsafeCell<T>,
}
impl<T> SpinLock<T> {
    pub fn new(val: T) -> Self {
        Self {
            state: AtomicBool::new(false),
            data: UnsafeCell::new(val),
        }
    }
    pub fn lock(&self) -> SpinLockGuard<'_, T> {
        while self
            .state
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            std::hint::spin_loop();
        }
        CURRENT_ACQUIRED.set(true);
        SpinLockGuard {
            data_mut_borrow: unsafe { &mut *self.data.get() },
            state: self,
        }
    }

    pub fn try_lock(&self) -> std::io::Result<SpinLockGuard<'_, T>> {
        if CURRENT_ACQUIRED.get() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "Re-acquire the lock that has been acquired would cause a deadlock",
            ));
        }
        Ok(self.lock())
    }

    pub(self) fn unlock(&self) {
        CURRENT_ACQUIRED.set(false);
        self.state.store(false, Ordering::Release);
    }
}

unsafe impl<T> Send for SpinLock<T> {}
unsafe impl<T> Sync for SpinLock<T> {}

#[derive(Debug)]
pub struct SpinLockGuard<'a, T: 'a> {
    data_mut_borrow: &'a mut T,
    state: &'a SpinLock<T>,
}
impl<'a, T: 'a> Drop for SpinLockGuard<'a, T> {
    fn drop(&mut self) {
        self.state.unlock()
    }
}
impl<'a, T: 'a> Deref for SpinLockGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.data_mut_borrow
    }
}
impl<'a, T: 'a> DerefMut for SpinLockGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.data_mut_borrow
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::SpinLock;

    #[test]
    fn synchronization() {
        let mut result = 1;
        for i in 1..=20000 {
            let s_lock = Arc::new(SpinLock::new(result));
            let s_lock_sub = s_lock.clone();
            let spin = s_lock.clone();
            let t = std::thread::spawn(move || {
                let mut guard = s_lock.lock();
                *guard += 2;
            });
            let t2 = std::thread::spawn(move || {
                let mut guard = s_lock_sub.lock();
                *guard += 3;
            });
            t.join().unwrap();
            t2.join().unwrap();
            result = *spin.lock();
            assert_eq!(result, (i * 5) + 1);
        }
    }
    #[test]
    fn sync_ptr() {
        let mut result = 1;
        let b = Box::new(result);
        let mut_ptr = Box::into_raw(b);
        for i in 1..=20000 {
            let s_lock = Arc::new(SpinLock::new(mut_ptr));
            let s_lock_sub = s_lock.clone();
            let spin = s_lock.clone();
            let t = std::thread::spawn(move || {
                let guard = s_lock.lock();
                let ptr = *guard;
                unsafe {
                    *ptr += 2;
                };
            });
            let t2 = std::thread::spawn(move || {
                let guard = s_lock_sub.lock();
                let ptr = *guard;
                unsafe {
                    *ptr += 3;
                };
            });
            t.join().unwrap();
            t2.join().unwrap();
            result = unsafe { **spin.lock() };
            assert_eq!(result, (i * 5) + 1);
        }
        unsafe { drop(Box::from_raw(mut_ptr)) };
    }

    #[test]
    fn avoid_deadlock() {
        let spin = Arc::new(SpinLock::new(0));
        let spin2 = spin.clone();
        let t1 = std::thread::spawn(move || {
            let _guard = spin.lock();
            let r = spin.try_lock();
            println!("{:?}", r);
            assert!(r.is_err());
            match r {
                Ok(_) => unreachable!(),
                Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::WouldBlock),
            };
        });
        let t = std::thread::spawn(move || {
            assert!(spin2.try_lock().is_ok());
        });
        t.join().unwrap();
        t1.join().unwrap();
    }
}