#![cfg_attr(not(feature = "std"), no_std)]
use core::cell::UnsafeCell;
use core::fmt::{self, Debug, Display, Formatter};
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
// Layout of the atomic state word used by `OptionLock`:
//   bit 0  (`SOME`)  - set when the lock records a stored value,
//   bit 1  (`FREE`)  - set when the lock is not exclusively held,
//   bits 2.. (`>> SHIFT`) - number of read guards.

// Flag bit: a value is stored (`From<T>` initialises the state with it set).
const SOME: usize = 0x1;
// Flag bit: cleared by `try_lock` via `fetch_and(!FREE)`; every state written
// when an exclusive guard is dropped has it set again.
const FREE: usize = 0x2;
// State of an unlocked lock that holds a value and has no readers.
const AVAILABLE: usize = FREE | SOME;
// Number of low flag bits; the reader count is `state >> SHIFT`.
const SHIFT: usize = 2;
/// A lock around an `Option<T>` whose locking state is tracked in a single
/// atomic word (see the `SOME` / `FREE` / `SHIFT` constants).
pub struct OptionLock<T> {
// The protected value; accessed through raw pointers obtained from the cell.
data: UnsafeCell<Option<T>>,
// Flag bits plus reader count describing who currently holds the lock.
state: AtomicUsize,
}
/// The default lock is empty and unlocked, exactly like [`OptionLock::new`].
///
/// No `T: Default` bound is required: the lock starts out holding `None`, so
/// no value of type `T` is ever constructed here.
impl<T> Default for OptionLock<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// Error returned when a value cannot be read or taken from an `OptionLock`.
///
/// The type is a fieldless enum, so it is `Copy` and can be stored or passed
/// around freely after a failed attempt.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReadError {
    /// The lock does not currently hold a value.
    Empty,
    /// The lock is held exclusively.
    Locked,
}

impl ReadError {
    /// Returns `true` if the error is [`ReadError::Empty`].
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Empty)
    }

    /// Returns `true` if the error is [`ReadError::Locked`].
    pub fn is_locked(&self) -> bool {
        matches!(self, Self::Locked)
    }
}

impl Display for ReadError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // The messages contain no format arguments, so write them directly.
        f.write_str(match self {
            Self::Empty => "OptionLock is empty",
            Self::Locked => "OptionLock is locked exclusively",
        })
    }
}

// `std::error::Error` is only implemented when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for ReadError {}
/// Error returned when an exclusive lock cannot be acquired.
///
/// The type is a unit struct, so it is `Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TryLockError;

impl Display for TryLockError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // No format arguments are involved, so write the literal directly.
        f.write_str("OptionLock cannot be locked exclusively")
    }
}

// `std::error::Error` is only implemented when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for TryLockError {}
/// A snapshot of the state of an `OptionLock`.
///
/// The snapshot is plain data (at most a reader count), so it is `Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
    /// A value is stored and the lock is neither read- nor write-locked.
    Available,
    /// No value is stored and the lock is not exclusively held.
    Empty,
    /// The lock is held exclusively.
    ExclusiveLock,
    /// The lock is held by the given number of readers.
    ReadLock(usize),
}
impl Status {
    /// Decodes a raw state word into a `Status`.
    ///
    /// The two low bits are the `SOME` and `FREE` flags; the remaining bits
    /// are the reader count.
    #[inline]
    pub(crate) fn new(val: usize) -> Self {
        let readers = val >> SHIFT;
        let has_value = val & SOME != 0;
        let free = val & FREE != 0;

        if val == AVAILABLE {
            // Value present, not locked, no readers.
            Self::Available
        } else if free && !has_value {
            // `FREE` without `SOME`, whatever the upper bits are.
            Self::Empty
        } else if !free && (!has_value || readers == 0) {
            Self::ExclusiveLock
        } else {
            Self::ReadLock(readers)
        }
    }

    /// Returns `true` for `Available` and `ReadLock`.
    pub fn can_read(&self) -> bool {
        match self {
            Self::Available | Self::ReadLock(_) => true,
            Self::Empty | Self::ExclusiveLock => false,
        }
    }

    /// Returns `true` only for `Available`.
    pub fn can_take(&self) -> bool {
        match self {
            Self::Available => true,
            Self::Empty | Self::ExclusiveLock | Self::ReadLock(_) => false,
        }
    }

    /// Returns `true` only for `Empty`.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Empty => true,
            Self::Available | Self::ExclusiveLock | Self::ReadLock(_) => false,
        }
    }

    /// Returns `true` only for `ExclusiveLock`.
    pub fn is_locked(&self) -> bool {
        match self {
            Self::ExclusiveLock => true,
            Self::Available | Self::Empty | Self::ReadLock(_) => false,
        }
    }

    /// Returns the reader count of a `ReadLock`, or `0` for any other status.
    pub fn readers(&self) -> usize {
        match self {
            Self::ReadLock(count) => *count,
            Self::Available | Self::Empty | Self::ExclusiveLock => 0,
        }
    }
}
// The lock moves values of `T` between threads (e.g. via `try_take`), hence
// the `T: Send` bound on both impls.
// NOTE(review): `try_read` can hand out several `OptionRead` guards at once and
// each of them derefs to `&T`, so a `&T` may be shared between threads. The
// `Sync` impl below only requires `T: Send`; confirm whether it should require
// `T: Send + Sync` (as `std::sync::RwLock` does) before relying on it with
// non-`Sync` payloads.
unsafe impl<T: Send> Send for OptionLock<T> {}
unsafe impl<T: Send> Sync for OptionLock<T> {}
impl<T> OptionLock<T> {
    /// Creates an empty, unlocked lock.
    pub const fn new() -> Self {
        Self {
            data: UnsafeCell::new(None),
            state: AtomicUsize::new(FREE),
        }
    }

    /// Returns a mutable reference to the stored `Option` without locking.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that no other reference to the contents
    /// (including one reachable through an [`OptionGuard`] or
    /// [`OptionRead`]) is used while the returned reference is alive. The
    /// atomic state is not consulted or updated.
    pub unsafe fn get_unchecked(&self) -> &mut Option<T> {
        &mut *self.data.get()
    }

    /// Returns a mutable reference to the stored `Option`.
    ///
    /// Note that the atomic state is not updated: replacing `Some` with
    /// `None` (or the reverse) through this reference leaves the state
    /// describing the previous contents.
    pub fn get_mut(&mut self) -> &mut Option<T> {
        // SAFETY: `&mut self` proves exclusive access; guards borrow the lock
        // and therefore cannot be alive at the same time.
        unsafe { self.get_unchecked() }
    }

    /// Consumes the lock and returns the stored `Option`.
    pub fn into_inner(self) -> Option<T> {
        self.data.into_inner()
    }

    /// Spins until an exclusive lock is acquired.
    pub fn spin_lock(&self) -> OptionGuard<'_, T> {
        loop {
            if let Ok(guard) = self.try_lock() {
                break guard;
            }
            // Wait on the cheap relaxed load rather than hammering the
            // read-modify-write in `try_lock`.
            while self.status().is_locked() {
                spin_loop_hint();
            }
        }
    }

    /// Spins until a value can be taken out of the lock, then returns it.
    ///
    /// This keeps spinning while the lock is empty, so it only returns once
    /// some other party has stored a value.
    pub fn spin_take(&self) -> T {
        loop {
            if let Ok(result) = self.try_take() {
                break result;
            }
            while !self.status().can_take() {
                spin_loop_hint();
            }
        }
    }

    /// Returns a snapshot of the lock state (relaxed load; it may be stale by
    /// the time the caller inspects it).
    pub fn status(&self) -> Status {
        Status::new(self.state.load(Ordering::Relaxed))
    }

    /// Tries to acquire an exclusive lock.
    ///
    /// The `FREE` flag is cleared unconditionally. The lock is acquired when
    /// the previous state had `FREE` set and either no value or no readers.
    ///
    /// # Errors
    ///
    /// Returns [`TryLockError`] when the lock is already held exclusively or
    /// by readers.
    pub fn try_lock(&self) -> Result<OptionGuard<'_, T>, TryLockError> {
        let state = self.state.fetch_and(!FREE, Ordering::AcqRel);
        if state & FREE == FREE && (state & SOME == 0 || state >> SHIFT == 0) {
            Ok(OptionGuard { lock: self })
        } else {
            Err(TryLockError)
        }
    }

    /// Tries to acquire a shared read lock on the stored value.
    ///
    /// The reader count is only incremented when the read lock is actually
    /// granted. A failed attempt leaves the state untouched, so it can never
    /// make an exclusively held lock look as if it already had a reader.
    ///
    /// # Errors
    ///
    /// Returns [`ReadError::Locked`] when the lock is held exclusively and
    /// [`ReadError::Empty`] when it holds no value.
    pub fn try_read(&self) -> Result<OptionRead<'_, T>, ReadError> {
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // Readable states: value present and unlocked, or value present
            // with at least one existing reader.
            let readable =
                state == AVAILABLE || (state & SOME == SOME && state >> SHIFT != 0);
            if !readable {
                break Err(if state & FREE == 0 {
                    ReadError::Locked
                } else {
                    ReadError::Empty
                });
            }
            match self.state.compare_exchange_weak(
                state,
                state + (1 << SHIFT),
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => break Ok(OptionRead { lock: self }),
                // Lost a race or failed spuriously: re-evaluate the new state.
                Err(actual) => state = actual,
            }
        }
    }

    /// Tries to take the stored value, leaving the lock empty.
    ///
    /// # Errors
    ///
    /// Returns [`ReadError::Empty`] when no value is stored and
    /// [`ReadError::Locked`] when the lock is held exclusively or by readers.
    pub fn try_take(&self) -> Result<T, ReadError> {
        loop {
            match self.state.compare_exchange_weak(
                AVAILABLE,
                SOME,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                // The temporary guard is dropped at the end of the statement
                // and publishes the now-empty state.
                Ok(_) => break Ok(OptionGuard { lock: self }.take().unwrap()),
                Err(AVAILABLE) => {
                    // Spurious failure of the weak exchange: retry.
                }
                Err(state) if state & AVAILABLE == FREE => break Err(ReadError::Empty),
                Err(_) => break Err(ReadError::Locked),
            }
        }
    }
}
impl<T> From<T> for OptionLock<T> {
fn from(data: T) -> Self {
Self {
data: UnsafeCell::new(Some(data)),
state: AtomicUsize::new(SOME | FREE),
}
}
}
/// Wraps an `Option` in an unlocked lock; the initial state records whether a
/// value is present.
impl<T> From<Option<T>> for OptionLock<T> {
    fn from(data: Option<T>) -> Self {
        let initial = match &data {
            Some(_) => AVAILABLE,
            None => FREE,
        };
        Self {
            state: AtomicUsize::new(initial),
            data: UnsafeCell::new(data),
        }
    }
}
/// Formats the lock as `OptionLock { status: .. }`; the stored value is not
/// inspected, so no `T: Debug` bound is needed.
impl<T> Debug for OptionLock<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let status = self.status();
        let mut out = f.debug_struct("OptionLock");
        out.field("status", &status);
        out.finish()
    }
}
/// A shared read guard for an `OptionLock`; it dereferences to the stored
/// value and adjusts the lock's reader count when dropped.
pub struct OptionRead<'a, T> {
// The lock this guard reads from.
lock: &'a OptionLock<T>,
}
impl<'a, T: 'a> OptionRead<'a, T> {
    /// Tries to upgrade a read guard to an exclusive guard.
    ///
    /// The upgrade only succeeds while `read` is the sole reader.
    ///
    /// # Errors
    ///
    /// Returns the read guard unchanged when other readers exist.
    pub fn try_lock(read: OptionRead<'a, T>) -> Result<OptionGuard<'a, T>, OptionRead<'a, T>> {
        // Expected state: exactly one reader. The `FREE` flag may have been
        // cleared in the meantime, which the retry arm below accounts for.
        let mut state = (1 << SHIFT) | AVAILABLE;
        loop {
            match read.lock.state.compare_exchange_weak(
                state,
                SOME,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    let guard = OptionGuard { lock: read.lock };
                    // The state no longer counts this reader, so its `Drop`
                    // must not run. `core::mem` keeps this usable when the
                    // crate is built without `std`.
                    core::mem::forget(read);
                    break Ok(guard);
                }
                // Still the only reader (with or without `FREE`), or a
                // spurious failure: retry with the observed state.
                Err(s) if s & !FREE == (1 << SHIFT) | SOME => {
                    state = s;
                }
                Err(_) => break Err(read),
            }
        }
    }

    /// Tries to upgrade the read guard and take the value out of the lock.
    ///
    /// # Errors
    ///
    /// Returns the read guard unchanged when the upgrade fails.
    pub fn try_take(read: OptionRead<'a, T>) -> Result<T, OptionRead<'a, T>> {
        // A read guard only exists while a value is stored, so `take` yields
        // `Some`; the guard is dropped afterwards and publishes the empty state.
        Self::try_lock(read).map(|mut guard| guard.take().unwrap())
    }
}
/// Gives shared access to the stored value.
impl<T> Deref for OptionRead<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the guard holds a read lock, so no exclusive guard can be
        // handed out while it is alive. Only a shared reference is created
        // here: several read guards may be dereferenced at the same time, and
        // materialising a `&mut` for each of them would alias.
        let slot: &Option<T> = unsafe { &*self.lock.data.get() };
        // A read guard is only created while a value is stored.
        slot.as_ref().unwrap()
    }
}
/// Formats the guarded value; with the alternate flag (`{:#?}`) the output is
/// wrapped in an `OptionRead(..)` tuple.
impl<T: Debug> Debug for OptionRead<'_, T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value: &T = &**self;
        if !f.alternate() {
            return Debug::fmt(value, f);
        }
        f.debug_tuple("OptionRead").field(value).finish()
    }
}
/// Displays the guarded value exactly as the value itself would be displayed.
impl<T: Display> Display for OptionRead<'_, T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value: &T = &**self;
        <T as Display>::fmt(value, f)
    }
}
impl<T> Drop for OptionRead<'_, T> {
fn drop(&mut self) {
let prev = self.lock.state.fetch_sub(1 << SHIFT, Ordering::AcqRel);
if prev == 1 << SHIFT | SOME {
self.lock
.state
.compare_exchange(SOME, AVAILABLE, Ordering::AcqRel, Ordering::Relaxed)
.unwrap_or_default();
}
}
}
/// Read guards compare equal when they refer to the same lock (pointer
/// identity); the stored values are not compared.
impl<T> PartialEq for OptionRead<'_, T> {
    fn eq(&self, other: &Self) -> bool {
        // `core::ptr` keeps this usable when the crate is built without `std`.
        core::ptr::eq(self.lock, other.lock)
    }
}

impl<T> Eq for OptionRead<'_, T> {}
/// An exclusive guard for an `OptionLock`; it dereferences to the stored
/// `Option<T>` and publishes the resulting state when dropped.
pub struct OptionGuard<'a, T> {
// The lock this guard has exclusive access to.
lock: &'a OptionLock<T>,
}
impl<T> OptionGuard<'_, T> {
pub fn downgrade<'a>(guard: OptionGuard<'a, T>) -> Option<OptionRead<'_, T>> {
if guard.is_some() {
guard
.lock
.state
.store((1 << SHIFT) | AVAILABLE, Ordering::Release);
let read = OptionRead { lock: guard.lock };
std::mem::forget(guard);
Some(read)
} else {
None
}
}
}
/// Formats the guarded `Option`; with the alternate flag (`{:#?}`) the output
/// is wrapped in an `OptionGuard(..)` tuple.
impl<T: Debug> Debug for OptionGuard<'_, T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value: &Option<T> = &**self;
        if !f.alternate() {
            return Debug::fmt(value, f);
        }
        f.debug_tuple("OptionGuard").field(value).finish()
    }
}
/// Gives shared access to the guarded `Option`.
impl<T> Deref for OptionGuard<'_, T> {
    type Target = Option<T>;

    fn deref(&self) -> &Self::Target {
        // SAFETY: the guard holds the exclusive lock, so nothing else touches
        // the contents. Only a shared reference is created: `&self` allows
        // several shared borrows of the guard to coexist, and materialising a
        // `&mut` for each of them would alias.
        unsafe { &*self.lock.data.get() }
    }
}
/// Gives mutable access to the guarded `Option`.
impl<T> DerefMut for OptionGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let slot = self.lock.data.get();
        // SAFETY: the guard holds the exclusive lock and is itself borrowed
        // mutably, so this is the only live reference to the contents.
        unsafe { &mut *slot }
    }
}
/// Releases the exclusive lock, publishing whether a value is now stored.
impl<'a, T> Drop for OptionGuard<'a, T> {
    fn drop(&mut self) {
        // The whole state word is overwritten: `AVAILABLE` when a value was
        // left behind, `FREE` when the lock is now empty.
        let released = match &**self {
            Some(_) => AVAILABLE,
            None => FREE,
        };
        self.lock.state.store(released, Ordering::Release);
    }
}
/// Exclusive guards compare equal when they refer to the same lock (pointer
/// identity); the stored values are not compared.
impl<T> PartialEq for OptionGuard<'_, T> {
    fn eq(&self, other: &Self) -> bool {
        // `core::ptr` keeps this usable when the crate is built without `std`.
        core::ptr::eq(self.lock, other.lock)
    }
}

impl<T> Eq for OptionGuard<'_, T> {}