#![allow(unknown_lints)]
#![deny(clippy::enum_glob_use)]
#![allow(clippy::missing_transmute_annotations)]
#![allow(clippy::doc_lazy_continuation)]
#![doc = include_str!("../README.md")]
#![no_std]
#[cfg(feature = "std")]
extern crate std;
#[cfg(not(any(feature = "std", panic = "abort")))]
compile_error!("no_std version of this crate requires panic = abort to ensure safety.");
use core::cell::UnsafeCell;
use core::future::Future;
use core::marker::{PhantomData, PhantomPinned};
use core::mem::{align_of, size_of, ManuallyDrop};
use core::pin::Pin;
use core::ptr::{null_mut, NonNull};
use core::sync::atomic::{self, AtomicBool, AtomicPtr, AtomicU8, Ordering};
use core::task::{Context, Poll};
use atomic_waker::{AtomicWaker, AtomicWakerState};
#[doc(hidden)]
pub mod mpmc;
use mpmc::{MPMCRef, MPMC};
#[doc(hidden)]
mod atomic_waker;
/// A mutex to which a lender temporarily lends a `&mut T`, handing it to queued borrow
/// requests (at most `MAX_BORROWERS` of them) one at a time.
///
/// The struct is `#[repr(C)]` and its field order is load-bearing: `BorrowMutexRef` views any
/// `BorrowMutex<M, T>` as a `BorrowMutex<0, T>` and locates `borrowers` through the byte offset
/// computed by `borrowers_offset`, so `borrowers` must stay the last field and the fields before
/// it must match that computation.
#[repr(C)]
pub struct BorrowMutex<const MAX_BORROWERS: usize, T: ?Sized> {
// Pointer to the currently lent value; written by `lend`, cleared when the `LendGuard` drops.
inner_ref: UnsafeCell<Option<NonNull<T>>>,
// Lender-side waker: woken when a borrow request is queued or a borrower guard is dropped.
lend_waiter: AtomicWaker,
lend_waiter_state: AtomicWakerState,
// Holds a `LendState` discriminant.
state: AtomicU8,
// Queue of pending borrow requests. Must remain the last field (see above).
borrowers: MPMC<MAX_BORROWERS, BorrowRef>,
}
/// Lifecycle of a `BorrowMutex`, stored as a `u8` in its `state` field.
#[derive(Debug)]
#[repr(u8)]
enum LendState {
    /// Nothing is lent; `lend` or `terminate` may begin.
    None = 0,
    /// `lend` has claimed the mutex and is about to publish its reference.
    Starting = 1,
    /// A reference is published and can be handed to a borrower.
    Lending = 2,
    /// `terminate` was called; the mutex must never be lent again.
    Terminating = 3,
}
impl<const M: usize, T: ?Sized> BorrowMutex<M, T> {
pub const MAX_BORROWERS: usize = M;
pub const fn new() -> Self {
Self {
inner_ref: UnsafeCell::new(None),
lend_waiter: AtomicWaker::new(None),
lend_waiter_state: AtomicWakerState::new(0),
state: AtomicU8::new(LendState::None as u8),
borrowers: MPMC::new(),
}
}
#[inline]
fn as_ptr(&self) -> BorrowMutexRef<'_, T> {
BorrowMutexRef(self as *const _ as *const BorrowMutex<0, T>, PhantomData)
}
pub fn request_borrow<'g, 'm: 'g>(&'m self) -> BorrowGuardUnarmed<'g, T> {
BorrowGuardUnarmed {
mutex: self.as_ptr(),
inner: AtomicPtr::new(null_mut()),
terminated: AtomicBool::new(false),
}
}
pub fn wait_to_lend<'g, 'm: 'g>(&'m self) -> LendWaiter<'g, T> {
LendWaiter {
mutex: self.as_ptr(),
}
}
pub fn lend<'g, 'm: 'g>(&'m self, value: &'g mut T) -> Option<LendGuard<'g, T>> {
if let Err(prev) = self.state.compare_exchange(
LendState::None as u8,
LendState::Starting as u8,
Ordering::Acquire,
Ordering::Acquire,
) {
if prev == LendState::Terminating as u8 {
abort("BorrowMutex lended to while terminated");
} else {
abort("multiple distinct references lended to a BorrowMutex");
}
}
unsafe { *self.inner_ref.get() = Some(NonNull::from(value)) };
self.state
.store(LendState::Lending as u8, Ordering::Release);
let borrow = self.borrowers.peek()?;
let borrow = unsafe { &*borrow.get() };
Some(LendGuard {
mutex: self.as_ptr(),
borrow,
_marker: PhantomPinned,
})
}
pub async fn terminate(&self) {
if let Err(prev) = self.state.compare_exchange(
LendState::None as u8,
LendState::Terminating as u8,
Ordering::Relaxed,
Ordering::Relaxed,
) {
if prev == LendState::Terminating as u8 {
return;
} else {
abort("BorrowMutex terminated while a reference is lended");
}
}
atomic::fence(Ordering::SeqCst);
while let Some(borrow) = self.borrowers.peek() {
let borrow = unsafe { &*borrow.get() };
let lend_guard = LendGuard {
mutex: self.as_ptr(),
borrow,
_marker: PhantomPinned,
};
lend_guard.await;
}
}
const fn borrowers_offset() -> usize {
let offset = size_of::<UnsafeCell<Option<NonNull<T>>>>()
+ size_of::<AtomicWaker>()
+ size_of::<AtomicWakerState>()
+ size_of::<AtomicU8>();
let align = align_of::<MPMC<M, BorrowRef>>();
(offset + align - 1) & !(align - 1)
}
}
// SAFETY: the mutex only stores a raw pointer to the lent `T` and hands exclusive (`&mut T`)
// access to it to a borrower, possibly on another thread, so both impls require `T: Send`.
// As with `std::sync::Mutex`, `T: Sync` is not required because the mutex itself never gives
// out shared access to the value.
// NOTE(review): these impls also vouch for the thread-safety of the `MPMC`/`AtomicWaker`
// internals — confirm against those modules.
unsafe impl<const M: usize, T: ?Sized + Send> Send for BorrowMutex<M, T> {}
unsafe impl<const M: usize, T: ?Sized + Send> Sync for BorrowMutex<M, T> {}
impl<const M: usize, T: ?Sized> Default for BorrowMutex<M, T> {
fn default() -> Self {
Self::new()
}
}
impl<const M: usize, T: ?Sized> core::fmt::Debug for BorrowMutex<M, T> {
    // Shows only the raw lend-state byte; the lent value is never dereferenced here.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let lender_state: u8 = self.state.load(Ordering::Relaxed);
        write!(f, "BorrowMutex {{ lender_state: {:?} }}", lender_state)
    }
}
/// Pointer to a `BorrowMutex<M, T>` with `M` erased, valid for `'a`.
///
/// The pointee is typed as `BorrowMutex<0, T>`. Under `#[repr(C)]` the fields before
/// `borrowers` have the same offsets for every `M`, so they can be read through `Deref`.
/// `borrowers` itself must be reached through `BorrowMutexRef::borrowers`.
struct BorrowMutexRef<'a, T: ?Sized>(*const BorrowMutex<0, T>, PhantomData<&'a T>);
impl<'a, T: ?Sized> BorrowMutexRef<'a, T> {
/// Returns a view of the pointee's `borrowers` queue.
///
/// The field is reached by adding `BorrowMutex::<0, T>::borrowers_offset()` to the base
/// pointer instead of going through `Deref`: `Deref` sees the pointee as a `BorrowMutex<0, T>`,
/// whose `borrowers` field type does not match the real `MPMC<M, _>`.
/// NOTE(review): this assumes the offset (i.e. the alignment of `MPMC<M, _>`) does not depend
/// on `M`, and that `MPMCRef::from_ptr` can handle a queue of any capacity through a
/// `*const MPMC<0, _>` — confirm in `mpmc`.
#[inline]
fn borrowers(&self) -> MPMCRef<'_, BorrowRef> {
// SAFETY: `self.0` points at a live `BorrowMutex` (see `BorrowMutex::as_ptr`), and the added
// offset lands on its `borrowers` field.
unsafe {
MPMCRef::from_ptr(
self.0
.cast::<u8>()
.add(BorrowMutex::<0, T>::borrowers_offset())
as *const MPMC<0, BorrowRef>,
)
}
}
}
impl<'a, T: ?Sized> core::ops::Deref for BorrowMutexRef<'a, T> {
type Target = BorrowMutex<0, T>;
fn deref(&self) -> &Self::Target {
unsafe { &*self.0 }
}
}
impl<'a, T: ?Sized> Clone for BorrowMutexRef<'a, T> {
fn clone(&self) -> Self {
BorrowMutexRef(self.0, PhantomData)
}
}
/// One queued borrow request, shared between the requesting borrower and the lender.
struct BorrowRef {
// Borrower-side waker; woken by the lender's `LendGuard` when it hands over the reference.
borrow_waker: AtomicWaker,
borrow_waker_state: AtomicWakerState,
// Set by the `LendGuard`'s first poll, when it wakes this borrower.
ref_acquired: AtomicBool,
// True while the borrower's guard (unarmed future or armed guard) is alive; cleared on drop.
guard_present: AtomicBool,
}
/// Future returned by `BorrowMutex::request_borrow`.
///
/// Resolves to a `BorrowGuardArmed` once a lender hands over its reference, or to an `Error`.
pub struct BorrowGuardUnarmed<'g, T: ?Sized> {
mutex: BorrowMutexRef<'g, T>,
// This request's queue entry; null until the first poll pushes it into the queue.
inner: AtomicPtr<BorrowRef>,
// Set once this future has produced its armed guard (despite the name, not only on
// mutex termination); the armed guard then takes over signalling the lender.
terminated: AtomicBool,
}
/// Reasons why awaiting a borrow request (`BorrowGuardUnarmed`) can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The pending-borrower queue rejected the request (presumably because it already holds
    /// `MAX_BORROWERS` entries).
    TooManyBorrows,
    /// The mutex was terminated, or this request already produced its guard.
    Terminated,
}
impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use Error as E;
        let msg = match self {
            E::TooManyBorrows => "Too many borrow requests that are still pending",
            E::Terminated => "The mutex was terminated and won't be ever lend-ed to again",
        };
        f.write_str(msg)
    }
}
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<'g, T: 'g + ?Sized> Future for BorrowGuardUnarmed<'g, T> {
type Output = Result<BorrowGuardArmed<'g, T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.mutex.state.load(Ordering::Acquire) == LendState::Terminating as u8 {
return Poll::Ready(Err(Error::Terminated));
}
if self.terminated.load(Ordering::Relaxed) {
return Poll::Ready(Err(Error::Terminated));
}
if self.inner.load(Ordering::Relaxed).is_null() {
let Ok(inner) = self
.mutex
.borrowers() .push(BorrowRef {
borrow_waker: AtomicWaker::new(None),
borrow_waker_state: AtomicWakerState::new(0),
ref_acquired: AtomicBool::new(false),
guard_present: AtomicBool::new(true),
})
else {
return Poll::Ready(Err(Error::TooManyBorrows));
};
atomic_waker::wake(&self.mutex.lend_waiter, &self.mutex.lend_waiter_state);
self.inner.store(inner.get(), Ordering::Relaxed);
}
atomic::fence(Ordering::SeqCst);
if self.mutex.state.load(Ordering::Relaxed) == LendState::Terminating as u8 {
return Poll::Ready(Err(Error::Terminated));
}
let inner = unsafe { &*self.inner.load(Ordering::Relaxed) };
if atomic_waker::poll_const(&inner.borrow_waker, &inner.borrow_waker_state, cx.waker())
== Poll::Pending
{
return Poll::Pending;
}
let lend_state = self.mutex.state.load(Ordering::Acquire);
assert_eq!(lend_state, LendState::Lending as u8);
let inner_ref = unsafe { *self.mutex.inner_ref.get() }.unwrap();
self.terminated.store(true, Ordering::Relaxed);
let (lend_waiter, lend_waiter_state) = unsafe {
(
std::mem::transmute(&self.mutex.lend_waiter),
std::mem::transmute(&self.mutex.lend_waiter_state),
)
};
Poll::Ready(Ok(BorrowGuardArmed {
inner_ref,
lend_waiter,
lend_waiter_state,
inner,
}))
}
}
impl<'m, T: ?Sized> Drop for BorrowGuardUnarmed<'m, T> {
    // Withdraws a queued request that never produced an armed guard, so that a lender
    // waiting on it is not left hanging.
    fn drop(&mut self) {
        // Once armed, the `BorrowGuardArmed` is responsible for signalling on its own drop.
        if self.terminated.load(Ordering::Relaxed) {
            return;
        }
        // A null pointer means the request was never queued: nothing to withdraw.
        let Some(request) = (unsafe { self.inner.load(Ordering::Relaxed).as_ref() }) else {
            return;
        };
        request.guard_present.store(false, Ordering::Relaxed);
        atomic_waker::wake(&self.mutex.lend_waiter, &self.mutex.lend_waiter_state);
    }
}
// SAFETY: the future resolves to a guard granting `&mut T`, so moving it to another thread
// requires `T: Send`.
unsafe impl<'m, T: ?Sized + Send> Send for BorrowGuardUnarmed<'m, T> {}
/// Exclusive access to a lent value (or, after `BorrowGuardArmed::map`, to part of it).
///
/// Dropping the guard ends the borrow and wakes the lender.
pub struct BorrowGuardArmed<'g, T: ?Sized> {
// Pointer to the (possibly projected) borrowed value.
inner_ref: NonNull<T>,
// The mutex's lender-side waker, woken when this guard is dropped.
lend_waiter: &'g AtomicWaker,
lend_waiter_state: &'g AtomicWakerState,
// This borrow's queue entry; its `guard_present` flag is cleared on drop.
inner: &'g BorrowRef,
}
impl<'g, T: ?Sized> BorrowGuardArmed<'g, T> {
    /// Narrows the guard to a part of the borrowed value, e.g. one field of a struct.
    ///
    /// The returned guard ends the borrow when dropped, exactly as `orig` would have. If `f`
    /// panics and unwinds, `orig` is dropped normally and the borrow is released.
    pub fn map<U: ?Sized, F>(orig: Self, f: F) -> BorrowGuardArmed<'g, U>
    where
        F: FnOnce(&mut T) -> &mut U,
    {
        // SAFETY: `orig` owns exclusive access to the lent value for as long as it lives.
        let projected: NonNull<U> = NonNull::from(f(unsafe { &mut *orig.inner_ref.as_ptr() }));
        // Suppress `orig`'s Drop: signalling the lender is now the new guard's job.
        let orig = ManuallyDrop::new(orig);
        let (lend_waiter, lend_waiter_state, inner) =
            (orig.lend_waiter, orig.lend_waiter_state, orig.inner);
        BorrowGuardArmed {
            inner_ref: projected,
            lend_waiter,
            lend_waiter_state,
            inner,
        }
    }
}
impl<'g, T: ?Sized> core::ops::Deref for BorrowGuardArmed<'g, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let ptr: *const T = self.inner_ref.as_ptr();
        // SAFETY: the guard holds exclusive access to the lent value while it is alive.
        unsafe { &*ptr }
    }
}
impl<'g, T: ?Sized> core::ops::DerefMut for BorrowGuardArmed<'g, T> {
    fn deref_mut(&mut self) -> &mut T {
        let ptr: *mut T = self.inner_ref.as_ptr();
        // SAFETY: as above; `&mut self` rules out any other access through this guard.
        unsafe { &mut *ptr }
    }
}
impl<'m, T: ?Sized> Drop for BorrowGuardArmed<'m, T> {
fn drop(&mut self) {
self.inner.guard_present.store(false, Ordering::Release);
atomic_waker::wake(self.lend_waiter, self.lend_waiter_state);
}
}
unsafe impl<'m, T: ?Sized + Send> Send for BorrowGuardArmed<'m, T> {}
unsafe impl<'m, T: ?Sized + Sync> Sync for BorrowGuardArmed<'m, T> {}
/// Future returned by [`BorrowMutex::wait_to_lend`]; resolves once the borrower queue is
/// non-empty.
pub struct LendWaiter<'m, T: ?Sized> {
    mutex: BorrowMutexRef<'m, T>,
}
impl<'m, T: ?Sized> Future for LendWaiter<'m, T> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mutex = &self.mutex;
        let has_borrowers = || !mutex.borrowers().is_empty();
        // Fast path: a request is already queued.
        if has_borrowers() {
            return Poll::Ready(());
        }
        // `Ready` from the lender waker only means "something happened" (a request was queued
        // or a guard dropped), so re-check the queue and re-register until it is non-empty or
        // the waker reports `Pending`.
        loop {
            match atomic_waker::poll_const(
                &mutex.lend_waiter,
                &mutex.lend_waiter_state,
                cx.waker(),
            ) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(()) if has_borrowers() => return Poll::Ready(()),
                Poll::Ready(()) => {}
            }
        }
    }
}
// SAFETY: the waiter only inspects the borrower queue and the lender waker and never touches
// the lent `T`, hence no bound on `T`.
// NOTE(review): relies on those internals being thread-safe.
unsafe impl<'g, T: ?Sized> Send for LendWaiter<'g, T> {}
/// Future returned by `BorrowMutex::lend` (and used internally by `BorrowMutex::terminate`).
///
/// Hands the lent reference to one queued borrow request and resolves once that borrower's
/// guard has been dropped. Once polled, it must be awaited to completion: dropping it while
/// the borrower's guard is still alive aborts the process.
pub struct LendGuard<'l, T: ?Sized> {
// The mutex whose `inner_ref` and `state` this guard resets on drop.
mutex: BorrowMutexRef<'l, T>,
// The queued request (obtained via `peek`) that this guard serves.
borrow: &'l BorrowRef,
// Makes the guard `!Unpin`.
_marker: PhantomPinned,
}
impl<'m, T: ?Sized> Future for LendGuard<'m, T> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// First poll only: mark the request as served and wake the borrower so it can pick up the
// reference. If the borrower already dropped its guard, there is nothing to wait for.
if !self.borrow.ref_acquired.swap(true, Ordering::Relaxed) {
atomic_waker::wake(&self.borrow.borrow_waker, &self.borrow.borrow_waker_state);
if !self.borrow.guard_present.load(Ordering::Acquire) {
return Poll::Ready(());
}
}
// Wait for the borrower's guard to go away. Borrower guards clear `guard_present` and then
// wake `lend_waiter` on drop, so the flag is re-checked after every wake-up; the Acquire
// load pairs with the armed guard's Release store.
while atomic_waker::poll_const(
&self.mutex.lend_waiter,
&self.mutex.lend_waiter_state,
cx.waker(),
) == Poll::Ready(())
{
if !self.borrow.guard_present.load(Ordering::Acquire) {
return Poll::Ready(());
}
}
Poll::Pending
}
}
impl<'l, T: ?Sized> Drop for LendGuard<'l, T> {
fn drop(&mut self) {
// Having handed the reference over while the borrower's guard is still alive, the
// borrower could outlive the lent value; this is unrecoverable.
let guard_present = self.borrow.guard_present.load(Ordering::Acquire);
if self.borrow.ref_acquired.load(Ordering::Relaxed) && guard_present {
abort("LendGuard dropped while the reference is still borrowed");
}
// Unpublish the lent reference.
unsafe { *self.mutex.inner_ref.get() = None };
// A request that was actually served is removed from the front of the queue; an unpolled
// guard leaves it queued for the next lender.
if self.borrow.ref_acquired.load(Ordering::Relaxed) {
let _ = self.mutex.borrowers().pop().unwrap();
}
// Return to `None` after a normal lend. During `terminate` the state is `Terminating`, so
// the exchange fails and the mutex stays terminated.
let _ = self.mutex.state.compare_exchange(
LendState::Lending as u8,
LendState::None as u8,
Ordering::Release,
Ordering::Relaxed,
);
}
}
// SAFETY: the guard gives another thread's borrower access to the lent `T`, so `T: Send`.
unsafe impl<'l, T: ?Sized + Send> Send for LendGuard<'l, T> {}
// Function that `abort` calls after printing its message (default: `std::process::abort`).
// Kept as a raw pointer in an `AtomicPtr` so it can be swapped at runtime by `set_abort_fn`.
#[cfg(feature = "std")]
static ABORT_FN: AtomicPtr<fn() -> !> = AtomicPtr::new(std::process::abort as *mut _);
/// Replaces the function called by the crate's internal `abort` (default:
/// `std::process::abort`).
///
/// # Safety
///
/// `f` must not let control return to the code that called `abort`, including by unwinding.
/// The crate calls `abort` on misuse it cannot recover from and, like the `panic = "abort"`
/// requirement for `no_std` builds, relies on execution not continuing afterwards.
#[cfg(feature = "std")]
#[doc(hidden)]
pub unsafe fn set_abort_fn(f: fn() -> !) {
ABORT_FN.store(f as *mut _, Ordering::Relaxed);
}
/// Reports `msg` and stops the program.
///
/// With the `std` feature, `msg` is written to stderr and the function installed by
/// `set_abort_fn` (by default `std::process::abort`) is called. Without `std`, this panics; the
/// crate refuses to build in that configuration unless `panic = "abort"`, so the panic cannot
/// unwind.
#[cold]
fn abort(msg: &str) -> ! {
    #[cfg(feature = "std")]
    {
        use std::io::Write;
        {
            // Lock once so the message and its newline are not interleaved with other output.
            // The lock is released before calling the abort function.
            let mut stderr = std::io::stderr().lock();
            let _ = stderr.write_all(msg.as_bytes());
            let _ = stderr.write_all(b"\n");
            let _ = stderr.flush();
        }
        // SAFETY: `ABORT_FN` only ever holds a `fn() -> !` cast to a raw pointer (see its
        // initializer and `set_abort_fn`), so converting it back yields that same function.
        let abort_fn: fn() -> ! = unsafe {
            core::mem::transmute::<*mut fn() -> !, fn() -> !>(ABORT_FN.load(Ordering::Relaxed))
        };
        abort_fn();
    }
    #[cfg(not(feature = "std"))]
    {
        panic!("{msg}");
    }
}
#[cfg(test)]
mod tests {
    use super::BorrowMutex;
    use core::any::Any;
    use core::mem::offset_of;

    /// `borrowers_offset` must agree with the compiler's layout whether `inner_ref` holds a
    /// thin pointer (sized `T`) or a fat pointer (unsized `T`).
    #[test]
    fn validate_borrowers_field_offset() {
        assert_eq!(
            BorrowMutex::<0, usize>::borrowers_offset(),
            offset_of!(BorrowMutex<0, usize>, borrowers)
        );
        assert_eq!(
            BorrowMutex::<0, &dyn Any>::borrowers_offset(),
            offset_of!(BorrowMutex<0, &dyn Any>, borrowers)
        );
        // Unsized `T`: `inner_ref` is a fat pointer.
        assert_eq!(
            BorrowMutex::<0, dyn Any>::borrowers_offset(),
            offset_of!(BorrowMutex<0, dyn Any>, borrowers)
        );
        assert_eq!(
            BorrowMutex::<0, [u8]>::borrowers_offset(),
            offset_of!(BorrowMutex<0, [u8]>, borrowers)
        );
    }

    /// `BorrowMutexRef::borrowers` always uses the `MAX_BORROWERS = 0` offset, so that offset
    /// must also be correct for mutexes with a non-zero capacity.
    #[test]
    fn borrowers_offset_is_independent_of_capacity() {
        assert_eq!(
            BorrowMutex::<0, usize>::borrowers_offset(),
            offset_of!(BorrowMutex<4, usize>, borrowers)
        );
        assert_eq!(
            BorrowMutex::<0, dyn Any>::borrowers_offset(),
            offset_of!(BorrowMutex<16, dyn Any>, borrowers)
        );
    }
}