use core::cell::UnsafeCell;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{
AtomicBool,
Ordering::{Acquire, Release},
};
use crate::BackOff;
/// A mutual-exclusion lock that busy-waits instead of blocking.
///
/// The protected value lives inline in the lock and is only reachable through
/// a [`SpinGuard`] (or the closure passed to `with_lock`).
pub struct SpinLock<T> {
    /// The protected value. Interior mutability is required because the lock
    /// hands out `&mut T` from a shared `&SpinLock<T>`.
    data: UnsafeCell<T>,
    /// `true` while the lock is held, `false` when it is free.
    locked: AtomicBool,
}
/// Scoped handle proving that the owning [`SpinLock`] is currently held.
///
/// The guard dereferences to the protected value, and the lock is released
/// when the guard is dropped.
#[must_use = "if unused the SpinLock will immediately unlock"]
pub struct SpinGuard<'a, T> {
    /// The lock this guard was acquired from.
    guard: &'a SpinLock<T>,
}

// SAFETY: a shared `&SpinGuard<T>` only exposes `&T` (through `Deref`), so it
// may be shared between threads exactly when `T: Sync`.
//
// Writing this impl explicitly also suppresses the auto-derived one. The
// auto impl would make the guard `Sync` whenever `SpinLock<T>: Sync`, i.e.
// for any `T: Send`, which would let `&T` to a `!Sync` payload such as
// `Cell<_>` be observed from several threads at once.
unsafe impl<T: Sync> Sync for SpinGuard<'_, T> {}
impl<T> Drop for SpinGuard<'_, T> {
    /// Releases the lock by clearing its `locked` flag.
    ///
    /// The `Release` store pairs with the `Acquire` swap performed by the
    /// next acquirer, so writes made through this guard are visible to it.
    #[inline]
    fn drop(&mut self) {
        let owner = self.guard;
        owner.locked.store(false, Release);
    }
}
impl<T> SpinLock<T> {
    /// Creates a new, unlocked spin lock that owns `data`.
    #[inline(always)]
    pub const fn new(data: T) -> Self {
        Self {
            data: UnsafeCell::new(data),
            locked: AtomicBool::new(false),
        }
    }

    /// Acquires the lock, spinning until it becomes available.
    ///
    /// Every failed attempt is followed by one `BackOff::wait` call before
    /// retrying. The returned guard releases the lock when dropped.
    #[inline]
    pub fn lock(&self) -> SpinGuard<'_, T> {
        let backoff = BackOff::new();
        loop {
            match self.try_lock() {
                Some(held) => return held,
                None => {
                    backoff.wait();
                }
            }
        }
    }

    /// Force-releases the lock by clearing the `locked` flag.
    ///
    /// # Safety
    ///
    /// This clears the flag unconditionally, without checking who holds the
    /// lock. The caller must guarantee that no [`SpinGuard`] for this lock
    /// (nor any reference obtained through one) is used afterwards;
    /// otherwise another thread can acquire the lock and alias the data.
    #[inline]
    pub unsafe fn unlock(&self) {
        self.locked.store(false, Release);
    }

    /// Makes a single attempt to acquire the lock without spinning.
    ///
    /// Returns `Some(guard)` if the lock was free and `None` if it was
    /// already held.
    #[inline]
    pub fn try_lock(&self) -> Option<SpinGuard<'_, T>> {
        // `swap` returns the previous flag value: `true` means someone else
        // already holds the lock, and writing `true` again changes nothing.
        let was_held = self.locked.swap(true, Acquire);
        if was_held {
            return None;
        }
        Some(SpinGuard { guard: self })
    }

    /// Reports whether the lock is held at the moment of the call.
    ///
    /// The answer may be stale by the time the caller inspects it.
    #[inline(always)]
    pub fn is_locked(&self) -> bool {
        self.locked.load(Acquire)
    }

    /// Tries to acquire the lock, giving up after `spins` failed attempts.
    ///
    /// Each failed attempt is followed by one `BackOff::wait` call. With
    /// `spins == 0` no attempt is made and `None` is returned.
    #[inline]
    pub fn try_lock_for(&self, spins: usize) -> Option<SpinGuard<'_, T>> {
        let backoff = BackOff::new();
        let mut attempts_left = spins;
        while attempts_left != 0 {
            if let Some(held) = self.try_lock() {
                return Some(held);
            }
            backoff.wait();
            attempts_left -= 1;
        }
        None
    }

    /// Runs `f` with exclusive access to the protected value.
    ///
    /// The lock is acquired before `f` is called and released when `f`
    /// returns (or unwinds); the closure's result is passed through.
    #[inline]
    pub fn with_lock<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        let mut held = self.lock();
        f(&mut *held)
    }
}
impl<T> Deref for SpinGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        let payload: *const T = self.guard.data.get();
        // SAFETY: the pointer comes from the lock's `UnsafeCell`, which
        // outlives the guard. The lock only creates a guard after it has
        // taken the `locked` flag, so no other guard aliases the value.
        unsafe { &*payload }
    }
}
impl<T> DerefMut for SpinGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        let payload: *mut T = self.guard.data.get();
        // SAFETY: the pointer comes from the lock's `UnsafeCell`, which
        // outlives the guard. The lock only creates a guard after it has
        // taken the `locked` flag, and `&mut self` rules out any other
        // borrow through this guard, so the access is exclusive.
        unsafe { &mut *payload }
    }
}
// SAFETY: moving the lock to another thread moves the `T` it owns, which is
// sound whenever `T` itself may be sent.
unsafe impl<T> Send for SpinLock<T> where T: Send {}

// SAFETY: the `locked` flag serialises access to `data`, so sharing the lock
// only ever hands the value to one thread at a time. That requires `T: Send`
// but not `T: Sync`.
unsafe impl<T> Sync for SpinLock<T> where T: Send {}
#[cfg(test)]
mod test {
    use crate::SpinLock;

    /// A guard gives mutable access and releases the lock when it goes out
    /// of scope.
    #[test]
    fn test_basic_lock_unlock() {
        let lock = SpinLock::new(10);
        {
            let mut value = lock.lock();
            *value += 5;
            assert_eq!(*value, 15);
        }
        assert!(!lock.is_locked(), "Lock should be released after guard drop");
    }

    /// Several threads incrementing under the lock must not lose updates.
    #[cfg(feature = "std")]
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        const THREADS: usize = 8;
        const INCREMENTS: usize = 10_000;

        let counter = Arc::new(SpinLock::new(0usize));
        let mut workers = vec![];
        for _ in 0..THREADS {
            let shared = Arc::clone(&counter);
            workers.push(thread::spawn(move || {
                for _ in 0..INCREMENTS {
                    // The temporary guard is released at the end of the statement.
                    *shared.lock() += 1;
                }
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }

        let total = *counter.lock();
        assert_eq!(total, THREADS * INCREMENTS, "Counter should match total increments");
    }

    /// `try_lock_for` fails while the lock is held and succeeds once it has
    /// been released.
    #[test]
    fn test_try_lock_for_behavior() {
        let lock = SpinLock::new(42);

        let held = lock.lock();
        assert!(lock.try_lock_for(10).is_none(), "Lock should not be acquirable while held");
        drop(held);

        let reacquired = lock.try_lock_for(1000);
        assert!(reacquired.is_some(), "Lock should succeed after previous guard drop");
    }
}