use core::cell::UnsafeCell;
use core::fmt;
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
/// A two-party lock shared between a low-priority and a high-priority half, built on a
/// Peterson-style flag/turn handshake.
///
/// Obtain the two halves with `PriorityLock::split`: the low half acquires by spinning
/// (`lock`), the high half only ever tries once (`try_lock`) and reports `Deadlock`
/// instead of waiting.
///
/// NOTE(review): no `unsafe impl Sync` is visible next to this type. Because `data` is an
/// `UnsafeCell`, `PriorityLock<T>` is `!Sync` without one, so `&PriorityLock<T>` (and thus
/// the halves and guards) cannot be sent to another thread in safe code — confirm intended.
#[derive(Debug)]
pub struct PriorityLock<T> {
    // Per-half "I want the lock" flags, indexed by `LockPriority::INDEX` (0 = low, 1 = high).
    wants_to_enter: [AtomicBool; 2],
    // Index of the half allowed to proceed when both flags are raised; each contender
    // writes the *other* half's index here to give way.
    turn: AtomicU8,
    // The protected value; dereferenced through `LockGuard`'s `Deref`/`DerefMut`.
    data: UnsafeCell<T>,
}
impl<T> PriorityLock<T> {
    /// Creates an unlocked `PriorityLock` that owns `data`.
    ///
    /// This is a `const fn`, so it can be used to initialise a `static`.
    pub const fn new(data: T) -> Self {
        Self {
            wants_to_enter: [AtomicBool::new(false), AtomicBool::new(false)],
            turn: AtomicU8::new(0),
            data: UnsafeCell::new(data),
        }
    }

    /// Splits the lock into its low-priority and high-priority halves.
    ///
    /// The exclusive borrow of `self` lasts for `'a`, the lifetime of both halves and of
    /// every guard they hand out, so no second pair of halves can exist meanwhile.
    pub fn split<'a>(&'a mut self) -> (LockHalf<'a, T, PLow>, LockHalf<'a, T, PHigh>) {
        let shared: &'a Self = self;
        let low = LockHalf {
            lock: shared,
            _p: PhantomData,
        };
        let high = LockHalf {
            lock: shared,
            _p: PhantomData,
        };
        (low, high)
    }

    /// Returns the slot index of the opposing half (0 <-> 1).
    fn other_index(index: u8) -> u8 {
        (index + 1) % 2
    }

    /// Peterson entry, first step: raise our flag, then hand the turn to the other half.
    /// Returns the other half's index.
    ///
    /// These stores and the loads in `other_has_priority` must be `SeqCst`. With only
    /// `Release` stores and `Acquire` loads, the store to our own flag may be reordered
    /// after the load of the other half's flag (release/acquire does not order a store
    /// before a later load of a different location), so both halves could read the other
    /// flag as `false` and enter the critical section together.
    fn announce(&self, index: u8) -> u8 {
        let other = Self::other_index(index);
        self.wants_to_enter[usize::from(index)].store(true, Ordering::SeqCst);
        self.turn.store(other, Ordering::SeqCst);
        other
    }

    /// Whether the half at `other` currently wins: it wants the lock and the turn is its.
    fn other_has_priority(&self, other: u8) -> bool {
        self.wants_to_enter[usize::from(other)].load(Ordering::SeqCst)
            && self.turn.load(Ordering::SeqCst) == other
    }

    /// Single attempt to acquire the lock for half `index`.
    ///
    /// Returns `Err(())` (after withdrawing our flag) if the other half currently has
    /// priority, `Ok(())` if the lock is now held by `index`.
    fn try_acquire_raw(&self, index: u8) -> Result<(), ()> {
        let other = self.announce(index);
        if self.other_has_priority(other) {
            // Withdraw the request so the other half is not held back by an abandoned attempt.
            self.wants_to_enter[usize::from(index)].store(false, Ordering::Release);
            Err(())
        } else {
            Ok(())
        }
    }

    /// Acquires the lock for half `index`, spinning while the other half has priority.
    fn block_acquire_raw(&self, index: u8) {
        let other = self.announce(index);
        while self.other_has_priority(other) {
            core::hint::spin_loop();
        }
    }

    /// Releases the lock held by half `index`.
    ///
    /// `Release` ordering publishes the critical section's writes to the next acquirer,
    /// whose `SeqCst` load of this flag also acts as an acquire.
    ///
    /// # Safety
    ///
    /// The caller must currently hold the lock through half `index` (e.g. a `LockGuard`
    /// being dropped) and must not access the protected data afterwards; clearing the flag
    /// otherwise lets the other half enter while this one is still inside.
    unsafe fn unlock(&self, index: u8) {
        self.wants_to_enter[usize::from(index)].store(false, Ordering::Release);
    }
}
/// Private home of the supertrait that seals [`LockPriority`]: since `sealed::Sealed`
/// cannot be named outside this module, nothing outside it can implement `LockPriority`.
mod sealed {
    /// Marker supertrait; implemented here only for the priority marker types.
    pub trait Sealed {}
}
/// Type-level tag selecting which half of a [`PriorityLock`] a handle or guard belongs to.
///
/// Sealed through the private `sealed::Sealed` supertrait, so it cannot be implemented
/// outside this module.
pub trait LockPriority: sealed::Sealed {
    /// Slot of this half in the lock's internal flag array.
    #[doc(hidden)]
    const INDEX: u8;
}
/// Marker for the low-priority half; that half acquires with a spinning `lock`.
///
/// Uninhabited: it exists only at the type level.
#[derive(Debug)]
pub enum PLow {}

/// Marker for the high-priority half; that half acquires with a one-shot `try_lock`.
///
/// Uninhabited: it exists only at the type level.
#[derive(Debug)]
pub enum PHigh {}

impl sealed::Sealed for PLow {}

impl LockPriority for PLow {
    const INDEX: u8 = 0;
}

impl sealed::Sealed for PHigh {}

impl LockPriority for PHigh {
    const INDEX: u8 = 1;
}
/// Error returned by the high-priority half's `try_lock` when the lock cannot be taken
/// immediately, because waiting for the low-priority half could deadlock.
///
/// Derives `Debug` so callers can `unwrap()`/`expect()` the `Result` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Deadlock {}
/// One of the two handles produced by `PriorityLock::split`.
///
/// `P` (`PLow` or `PHigh`) selects the side, and therefore which acquisition method is
/// available: `lock` for the low half, `try_lock` for the high half.
#[derive(Debug)]
pub struct LockHalf<'a, T, P: LockPriority> {
    // The shared lock this half operates on.
    lock: &'a PriorityLock<T>,
    // Zero-sized tag carrying the priority at the type level.
    _p: PhantomData<P>,
}
impl<'a, T> LockHalf<'a, T, PLow> {
pub fn lock(&mut self) -> LockGuard<'a, T, PLow> {
self.lock.block_acquire_raw(0);
LockGuard {
lock: self.lock,
_p: PhantomData,
}
}
}
impl<'a, T> LockHalf<'a, T, PHigh> {
pub fn try_lock(&mut self) -> Result<LockGuard<'a, T, PHigh>, Deadlock> {
self.lock.try_acquire_raw(1).map_err(|_| Deadlock {})?;
Ok(LockGuard {
lock: self.lock,
_p: PhantomData,
})
}
}
/// Guard giving access to a [`PriorityLock`]'s data on behalf of the half tagged `P`.
///
/// The lock is released when the guard is dropped. `#[must_use]` flags
/// `let _ = half.lock();`, which would otherwise acquire and immediately release the lock.
#[must_use = "if unused the PriorityLock will immediately unlock"]
pub struct LockGuard<'a, T, P: LockPriority> {
    // Lock whose data this guard exposes and whose flag it clears on drop.
    lock: &'a PriorityLock<T>,
    // Zero-sized tag recording which half's flag to clear on drop.
    _p: PhantomData<P>,
}
impl<'a, T, P: LockPriority> Deref for LockGuard<'a, T, P> {
    type Target = T;

    /// Shared view of the protected value for as long as `self` is borrowed.
    fn deref(&self) -> &T {
        let raw: *mut T = self.lock.data.get();
        // SAFETY: relies on the lock's invariant that a guard exists only while its half
        // holds the lock, so no other reference to the data is live.
        unsafe { &*raw }
    }
}
impl<'a, T, P: LockPriority> DerefMut for LockGuard<'a, T, P> {
    /// Exclusive view of the protected value for as long as `self` is mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.lock.data.get();
        // SAFETY: relies on the lock's invariant that a guard exists only while its half
        // holds the lock; `&mut self` additionally excludes other borrows of this guard.
        unsafe { &mut *raw }
    }
}
impl<'a, T, P: LockPriority> Drop for LockGuard<'a, T, P> {
    /// Releases the lock held on behalf of half `P`.
    fn drop(&mut self) {
        let held_slot = P::INDEX;
        // SAFETY: this guard was created only after half `P` acquired the lock, and it can
        // no longer hand out references to the data once it is being dropped.
        unsafe { self.lock.unlock(held_slot) };
    }
}
impl<'a, T: fmt::Debug, P: LockPriority> fmt::Debug for LockGuard<'a, T, P> {
    /// Formats the protected value with `T`'s own `Debug`, honouring the caller's flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}
impl<'a, T: fmt::Display, P: LockPriority> fmt::Display for LockGuard<'a, T, P> {
    /// Formats the protected value with `T`'s own `Display`, honouring the caller's flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// While the low half holds the lock the high half is refused; once the low guard is
    /// released the high half gets in and observes the low half's write.
    #[test]
    fn simple() {
        let mut lock = PriorityLock::new(0u32);
        let (mut low, mut high) = lock.split();

        {
            let mut held_low = low.lock();
            *held_low += 1;
            assert!(high.try_lock().is_err());
        } // `held_low` is dropped here, releasing the low half's claim.

        let mut held_high = high.try_lock().map_err(drop).unwrap();
        assert_eq!(*held_high, 1);
        *held_high += 1;
    }
}