use core:: {
sync::atomic::{AtomicBool, Ordering},
cell::UnsafeCell,
marker::Sync,
ops::Drop,
};
/// A minimal spin-lock mutex protecting a value of type `T`.
///
/// Acquiring the lock busy-waits instead of blocking the thread, so it is only
/// suitable for very short critical sections. The lock is released when the
/// [`SpinMutexGuard`] returned by [`SpinMutex::lock`] is dropped.
pub struct SpinMutex<T> {
    /// `true` while the lock is held.
    lock_obj: AtomicBool,
    /// The protected value; handed out only through a guard.
    data: UnsafeCell<T>,
}

impl<T> SpinMutex<T> {
    /// Creates a new, unlocked mutex wrapping `data`.
    ///
    /// This is a `const fn`, so it can be used in `static` initializers.
    pub const fn new(data: T) -> Self {
        SpinMutex {
            lock_obj: AtomicBool::new(false),
            data: UnsafeCell::new(data),
        }
    }

    /// Acquires the lock, spinning until it becomes available.
    ///
    /// The returned guard dereferences to the protected value and releases the
    /// lock when dropped. The lock is not reentrant: calling `lock` again on the
    /// same thread while a guard is still alive spins forever.
    pub fn lock(&self) -> SpinMutexGuard<T> {
        // Test-and-test-and-set: only attempt the acquiring read-modify-write
        // when the lock looks free, and otherwise wait on a plain load. Waiters
        // then share the cache line read-only instead of bouncing it between
        // cores with a write on every iteration. `Acquire` on success pairs with
        // the `Release` store that unlocks.
        while self
            .lock_obj
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.lock_obj.load(Ordering::Relaxed) {
                core::hint::spin_loop();
            }
        }
        SpinMutexGuard {
            lock_obj: &self.lock_obj,
            locked_data: self.data.get(),
        }
    }

    /// Clears the lock flag unconditionally, whether or not a guard is alive.
    ///
    /// This does not invalidate any outstanding guard. If a guard is still in
    /// use, another thread can then `lock` and obtain a second `&mut T` to the
    /// same value, which is undefined behavior. Only call this when no guard
    /// exists, e.g. after one was leaked with `core::mem::forget`.
    ///
    /// NOTE(review): this should arguably be an `unsafe fn`. It is kept safe
    /// here to preserve the existing signature.
    pub fn force_unlock(&self) {
        self.lock_obj.store(false, Ordering::Release);
    }
}

// SAFETY: a shared `&SpinMutex<T>` lets any thread obtain `&mut T` through
// `lock`, and so move or drop `T` values on a thread other than the one that
// created them. That is only sound when `T: Send`, which is the same bound
// `std::sync::Mutex` uses. Access to `data` is serialized by `lock_obj`
// (subject to the `force_unlock` caveat above).
unsafe impl<T: Send> Sync for SpinMutex<T> {}
/// Guard granting exclusive access to the value inside a [`SpinMutex`].
///
/// Returned by [`SpinMutex::lock`]. It dereferences to the protected value and
/// clears the mutex's lock flag when dropped.
#[must_use = "if unused the SpinMutex will immediately unlock"]
pub struct SpinMutexGuard<'a, T> {
    /// Lock flag of the owning mutex; reset to `false` on drop.
    lock_obj: &'a AtomicBool,
    /// Pointer to the value stored in the owning mutex.
    locked_data: *mut T,
}

impl<'a, T> SpinMutexGuard<'a, T> {
    /// Returns the raw pointer to the protected value.
    ///
    /// Dereferencing it requires `unsafe`. Exclusive access through it is only
    /// guaranteed while this guard is alive, so no reference derived from it
    /// may outlive the guard.
    pub fn get_raw_locked_data(&self) -> *mut T {
        self.locked_data
    }
}

impl<'a, T> core::ops::Deref for SpinMutexGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the guard is built while the lock is held, and `locked_data`
        // points at the value owned by the same mutex that `lock_obj` borrows
        // for `'a`. No other guard can alias it until this one is dropped,
        // unless `force_unlock` is misused.
        unsafe { &*self.locked_data }
    }
}

impl<'a, T> core::ops::DerefMut for SpinMutexGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: same invariant as `deref`. Additionally, `&mut self` ensures
        // this is the only live reference derived from this guard.
        unsafe { &mut *self.locked_data }
    }
}

impl<'a, T> Drop for SpinMutexGuard<'a, T> {
    /// Releases the lock. `Release` publishes this guard's writes to the next
    /// thread that acquires the lock.
    fn drop(&mut self) {
        self.lock_obj.store(false, Ordering::Release);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// A value written under the lock by one thread is seen by a second thread
    /// that acquires the lock afterwards.
    #[test]
    fn test_spin_mutex() {
        let shared = Arc::new(SpinMutex::new(AtomicUsize::new(0)));
        let rendezvous = Arc::new(Barrier::new(2));

        let writer = {
            let shared = Arc::clone(&shared);
            let rendezvous = Arc::clone(&rendezvous);
            thread::spawn(move || {
                // The guard is held across the barrier, so the reader has to
                // wait for this closure to finish (and unlock) before it can
                // take the lock.
                let held = shared.lock();
                // SAFETY: `held` is alive for the rest of this closure, so the
                // pointer refers to the locked value for the whole borrow.
                let counter = unsafe { &*held.get_raw_locked_data() };
                counter.store(1, Ordering::Relaxed);
                rendezvous.wait();
            })
        };

        let reader = thread::spawn(move || {
            rendezvous.wait();
            let held = shared.lock();
            // SAFETY: as above, the borrow does not outlive `held`.
            let counter = unsafe { &*held.get_raw_locked_data() };
            assert_eq!(counter.load(Ordering::Relaxed), 1);
        });

        writer.join().unwrap();
        reader.join().unwrap();
    }
}