#![warn(
clippy::pedantic,
missing_debug_implementations,
missing_docs,
noop_method_call,
trivial_casts,
trivial_numeric_casts,
unsafe_op_in_unsafe_fn,
unused_lifetimes,
unused_qualifications
)]
#![allow(
clippy::items_after_statements,
// `ǃ` (latin letter retroflex click) is used in the tests for a never type
uncommon_codepoints,
)]
#![no_std]
#![cfg_attr(doc_nightly, feature(doc_cfg))]
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
#[cfg(feature = "lock_api_04")]
pub extern crate lock_api_04_crate as lock_api_04;
#[cfg(feature = "loom_05")]
pub extern crate loom_05_crate as loom_05;
use core::cell::UnsafeCell;
use core::fmt;
use core::fmt::Debug;
use core::fmt::Formatter;
use core::future::Future;
use core::mem;
use core::ops::Deref;
use core::pin::Pin;
use core::ptr;
use core::ptr::NonNull;
use core::task;
use core::task::Poll;
use pin_project_lite::pin_project;
use pinned_aliasable::Aliasable;
pub mod lock;
#[doc(no_inline)]
pub use lock::Lock;
/// An intrusive linked list of waiters, each of which can be woken with an output value.
///
/// Every waiter stores an input value of type `I` while it waits, and receives an output value
/// of type `O` when it is woken. The list nodes live inside the waiting futures themselves; all
/// access to the list is synchronized through `lock`.
pub struct WaitList<L: Lock, I, O> {
    /// The lock that guards access to the list's contents.
    pub lock: L,
    // The head/tail pointers of the list. Only read while `lock` is held, and only mutated while
    // it is held exclusively.
    inner: UnsafeCell<Inner<I, O>>,
}
// SAFETY: NOTE(review): the original justification is not visible here; reconstructed from the
// bounds, confirm before changing them.
// - `L` is owned by the list, so it must be `Send`.
// - `I: Send` because inputs stored in waiters that are still linked into the list can be moved
//   out by `pop` on whichever thread now owns the list.
// - `O` is deliberately left unbounded (`O:,`); presumably because the list can only move while
//   no live `Wait` future borrows it, so an output written into a still-linked (e.g. leaked)
//   waiter is never read back on the original thread.
unsafe impl<L: Lock, I, O> Send for WaitList<L, I, O>
where
    L: Send,
    I: Send,
    O:,
{
}
// SAFETY: through a shared `&WaitList`, any thread may take the lock and then:
// - read `&I` via `head_input` under a shared lock, so `I: Sync`;
// - move an `I` out of, and an `O` into, a waiter owned by another thread via `pop`, so
//   `I: Send` and `O: Send`;
// - use `&L` to lock, so `L: Sync`.
unsafe impl<L: Lock, I, O> Sync for WaitList<L, I, O>
where
    L: Sync,
    I: Send + Sync,
    O: Send,
{
}
/// The linked list proper: pointers to the first and last waiter (both `None` when empty).
struct Inner<I, O> {
    /// The first waiter, which `pop` removes.
    head: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
    /// The last waiter, after which `enqueue` appends new ones.
    tail: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
}
/// A single list node, stored inside a `Wait` future.
struct Waiter<I, O> {
    /// The waiter after this one, or `None` if this is the tail.
    next: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
    /// The waiter before this one, or `None` if this is the head.
    prev: Option<NonNull<UnsafeCell<Waiter<I, O>>>>,
    /// Whether the waiter is still waiting or has been woken.
    state: WaiterState<I, O>,
}
/// The state of a [`Waiter`].
enum WaiterState<I, O> {
    /// Linked into the list, holding its input and the waker to notify.
    Waiting { input: I, waker: task::Waker },
    /// Unlinked from the list by `pop`, holding the output it was woken with.
    Woken { output: O },
}
impl<I, O> Inner<I, O> {
    /// Appends `waiter` to the end of the list.
    ///
    /// # Safety
    ///
    /// NOTE(review): contract inferred from the pointer manipulation below — confirm.
    /// - The caller must have exclusive access to the list and to every waiter linked into it.
    /// - `waiter` must not already be linked into a list, and its `next` field must be `None`
    ///   (this function never writes it).
    /// - `waiter` must not be moved or freed until it has been removed with [`Self::dequeue`].
    unsafe fn enqueue(&mut self, waiter: &UnsafeCell<Waiter<I, O>>) {
        // Link the new waiter back to the current tail (`None` if the list is empty).
        unsafe { &mut *waiter.get() }.prev = self.tail;
        let waiter_ptr = NonNull::from(waiter);
        // Link the old tail forward to the new waiter.
        if let Some(prev) = self.tail {
            let prev = unsafe { &mut *prev.as_ref().get() };
            debug_assert_eq!(prev.next, None);
            prev.next = Some(waiter_ptr);
        }
        self.tail = Some(waiter_ptr);
        // If the list was empty, the new waiter is also its head.
        self.head.get_or_insert(waiter_ptr);
    }
    /// Unlinks `waiter` from the list. Its neighbours (or `head`/`tail`, at the ends) are
    /// repointed past it.
    ///
    /// The removed waiter's own `prev` and `next` fields are left unchanged.
    ///
    /// # Safety
    ///
    /// NOTE(review): inferred contract. The caller must have exclusive access to the list and
    /// its waiters, and `waiter` must currently be linked into this list.
    unsafe fn dequeue(&mut self, waiter: &UnsafeCell<Waiter<I, O>>) {
        let waiter_ptr = Some(NonNull::from(waiter));
        let waiter = unsafe { &mut *waiter.get() };
        let prev = waiter.prev;
        let next = waiter.next;
        // The pointer that refers forward to `waiter`: the previous waiter's `next`, or `head`.
        let prev_next_pointer = match waiter.prev {
            Some(prev) => &mut unsafe { &mut *prev.as_ref().get() }.next,
            None => &mut self.head,
        };
        debug_assert_eq!(*prev_next_pointer, waiter_ptr);
        *prev_next_pointer = next;
        // The pointer that refers back to `waiter`: the next waiter's `prev`, or `tail`.
        let next_prev_pointer = match waiter.next {
            Some(next) => &mut unsafe { &mut *next.as_ref().get() }.prev,
            None => &mut self.tail,
        };
        debug_assert_eq!(*next_prev_pointer, waiter_ptr);
        *next_prev_pointer = prev;
    }
}
impl<L, I, O> WaitList<L, I, O>
where
    // NOTE(review): `<Empty<L> as Iterator>::Item` is simply `L`. The bound is presumably
    // spelled this way so that `new` can remain a `const fn` on compilers where ordinary trait
    // bounds on `const fn` generics were not yet stable.
    <core::iter::Empty<L> as Iterator>::Item: Lock,
{
    /// Creates an empty wait list protected by `lock`.
    #[must_use]
    pub const fn new(lock: L) -> Self {
        let inner = UnsafeCell::new(Inner {
            head: None,
            tail: None,
        });
        Self { lock, inner }
    }

    /// Locks the list exclusively, allowing waiters to be added and woken.
    #[must_use]
    pub fn lock_exclusive(&self) -> LockedExclusive<'_, L, I, O> {
        let guard = self.lock.lock_exclusive();
        let common = LockedCommon { wait_list: self };
        LockedExclusive { guard, common }
    }

    /// Locks the list in shared mode, allowing its contents to be inspected.
    #[must_use]
    pub fn lock_shared(&self) -> LockedShared<'_, L, I, O> {
        let guard = self.lock.lock_shared();
        let common = LockedCommon { wait_list: self };
        LockedShared { guard, common }
    }
}
impl<L: Lock + Default, I, O> Default for WaitList<L, I, O> {
    /// Creates an empty wait list guarded by the lock type's default value.
    fn default() -> Self {
        let lock = L::default();
        Self::new(lock)
    }
}
impl<L: Lock + Debug, I, O> Debug for WaitList<L, I, O> {
    /// Formats the list as a struct showing only its lock; the waiters are not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WaitList");
        builder.field("lock", &self.lock);
        builder.finish()
    }
}
/// An exclusive lock on a [`WaitList`], allowing waiters to be added and woken.
///
/// Dereferences to [`LockedCommon`] for the operations shared with [`LockedShared`].
pub struct LockedExclusive<'wait_list, L: Lock, I, O> {
    /// The exclusive guard of the list's lock.
    pub guard: <L as lock::Lifetime<'wait_list>>::ExclusiveGuard,
    common: LockedCommon<'wait_list, L, I, O>,
}
impl<'wait_list, L: Lock, I, O> Deref for LockedExclusive<'wait_list, L, I, O> {
    type Target = LockedCommon<'wait_list, L, I, O>;

    /// Exposes the operations available under any kind of lock.
    fn deref(&self) -> &LockedCommon<'wait_list, L, I, O> {
        &self.common
    }
}
impl<'wait_list, L: Lock, I, O> LockedExclusive<'wait_list, L, I, O> {
    /// Returns exclusive access to the list's head/tail pointers.
    fn inner_mut(&mut self) -> &mut Inner<I, O> {
        let cell = self.wait_list.inner.get();
        // SAFETY: this guard holds the exclusive lock. `&mut self` rules out any other reference
        // to the internals obtained through this guard while the result is alive.
        unsafe { &mut *cell }
    }

    /// Returns a mutable reference to the input of the waiter at the front of the list, or
    /// `None` if the list is empty.
    #[must_use]
    pub fn head_input_mut(&mut self) -> Option<&mut I> {
        let head = self.head()?;
        // SAFETY: the exclusive lock is held, and `self` stays mutably borrowed for as long as
        // the returned reference lives.
        let state = unsafe { &mut (*head.get()).state };
        match state {
            WaiterState::Waiting { input, .. } => Some(input),
            // `pop` unlinks waiters as it wakes them, so the list never holds a woken one.
            WaiterState::Woken { .. } => unreachable!(),
        }
    }

    /// Consumes the lock and returns a future that waits in the list.
    ///
    /// On its first poll, the future appends a waiter holding `input` to the end of the list.
    /// It then waits until that waiter is woken. The exclusive lock stays held by the returned
    /// future until that first poll. `on_cancel` runs if the future is dropped after being
    /// woken but before completing.
    pub fn init_and_wait<OnCancel>(
        self,
        input: I,
        on_cancel: OnCancel,
    ) -> InitAndWait<'wait_list, L, I, O, OnCancel>
    where
        OnCancel: CancelCallback<'wait_list, L, I, O>,
    {
        let pending = InitAndWaitInput {
            lock: self,
            input,
            on_cancel,
        };
        InitAndWait {
            input: Some(pending),
            inner: Wait::new(),
        }
    }

    /// Removes the waiter at the front of the list and hands it `output`.
    ///
    /// Returns the removed waiter's input and waker. The waker is *not* woken; that is left to
    /// the caller.
    ///
    /// # Errors
    ///
    /// Returns `Err(output)` unchanged if the list is empty.
    pub fn pop(&mut self, output: O) -> Result<(I, task::Waker), O> {
        let head = match self.inner_mut().head {
            Some(head) => head,
            None => return Err(output),
        };
        // SAFETY: `head` is linked into this list, so it is alive, and the exclusive lock is
        // held.
        let state = unsafe { &mut (*head.as_ref().get()).state };
        let previous = mem::replace(state, WaiterState::Woken { output });
        let (input, waker) = match previous {
            WaiterState::Waiting { input, waker } => (input, waker),
            WaiterState::Woken { .. } => unreachable!(),
        };
        // SAFETY: `head` is still linked into this list, and the exclusive lock is held.
        unsafe { self.inner_mut().dequeue(head.as_ref()) };
        Ok((input, waker))
    }

    /// Wakes the waiter at the front of the list with `output`, returning its input.
    ///
    /// The lock is released before the waker is invoked.
    ///
    /// # Errors
    ///
    /// Returns `Err(output)` unchanged if the list is empty.
    pub fn wake_one(mut self, output: O) -> Result<I, O> {
        match self.pop(output) {
            Ok((input, waker)) => {
                // Release the lock first; the woken future's `poll` acquires it again.
                drop(self);
                waker.wake();
                Ok(input)
            }
            Err(output) => Err(output),
        }
    }
}
impl<'wait_list, L: Lock + Debug, I, O> Debug for LockedExclusive<'wait_list, L, I, O>
where
    <L as lock::Lifetime<'wait_list>>::ExclusiveGuard: Debug,
{
    /// Formats the guard together with the locked list.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { guard, common } = self;
        f.debug_struct("LockedExclusive")
            .field("guard", guard)
            .field("common", common)
            .finish()
    }
}
/// A shared lock on a [`WaitList`], allowing its contents to be inspected.
///
/// Dereferences to [`LockedCommon`].
pub struct LockedShared<'wait_list, L: Lock, I, O> {
    /// The shared guard of the list's lock.
    pub guard: <L as lock::Lifetime<'wait_list>>::SharedGuard,
    common: LockedCommon<'wait_list, L, I, O>,
}
impl<'wait_list, L: Lock, I, O> Deref for LockedShared<'wait_list, L, I, O> {
    type Target = LockedCommon<'wait_list, L, I, O>;

    /// Exposes the operations available under any kind of lock.
    fn deref(&self) -> &LockedCommon<'wait_list, L, I, O> {
        &self.common
    }
}
impl<'wait_list, L: Lock + Debug, I, O> Debug for LockedShared<'wait_list, L, I, O>
where
    <L as lock::Lifetime<'wait_list>>::SharedGuard: Debug,
{
    /// Formats the guard together with the locked list.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { guard, common } = self;
        f.debug_struct("LockedShared")
            .field("guard", guard)
            .field("common", common)
            .finish()
    }
}
/// Operations available under either a shared or an exclusive lock on a [`WaitList`].
///
/// Both [`LockedExclusive`] and [`LockedShared`] dereference to this type.
#[non_exhaustive]
pub struct LockedCommon<'wait_list, L: Lock, I, O> {
    /// The wait list that is locked.
    pub wait_list: &'wait_list WaitList<L, I, O>,
}
impl<'wait_list, L: Lock, I, O> LockedCommon<'wait_list, L, I, O> {
    /// Returns shared access to the list's head/tail pointers.
    fn inner(&self) -> &Inner<I, O> {
        let cell = self.wait_list.inner.get();
        // SAFETY: a `LockedCommon` is only created alongside a lock guard and is only reachable
        // through that guard. Mutation requires `&mut` on an exclusive guard, which cannot
        // coexist with this shared borrow.
        unsafe { &*cell }
    }

    /// Returns the waiter at the front of the list, if any.
    fn head(&self) -> Option<&UnsafeCell<Waiter<I, O>>> {
        let head = self.inner().head?;
        // SAFETY: waiters are unlinked before their memory is invalidated, so a linked pointer
        // is always valid.
        Some(unsafe { head.as_ref() })
    }

    /// Returns `true` if no waiters are currently in the list.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.head().is_none()
    }

    /// Returns a reference to the input of the waiter at the front of the list, or `None` if
    /// the list is empty.
    #[must_use]
    pub fn head_input(&self) -> Option<&I> {
        let head = self.head()?;
        // SAFETY: the lock is held, and waiters are only mutated under the exclusive lock
        // through a `&mut` guard.
        match unsafe { &(*head.get()).state } {
            WaiterState::Waiting { input, .. } => Some(input),
            // `pop` unlinks waiters as it wakes them, so the list never holds a woken one.
            WaiterState::Woken { .. } => unreachable!(),
        }
    }
}
impl<'wait_list, L: Lock + Debug, I, O> Debug for LockedCommon<'wait_list, L, I, O> {
    /// Formats the view as a struct containing the locked wait list.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { wait_list } = self;
        f.debug_struct("LockedCommon")
            .field("wait_list", wait_list)
            .finish()
    }
}
/// A future that waits in a [`WaitList`] until it is woken. It then resolves to the list's
/// exclusive lock and the output it was woken with.
///
/// A `Wait` starts out, and ends up, in the completed state. [`Wait::init`] or
/// [`Wait::init_without_waker`] puts it into the list.
///
/// Dropping it while it is still in the list unlinks it. Dropping it after it was woken, but
/// before it resolved, passes the output to its [`CancelCallback`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Wait<'wait_list, L: Lock, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    // `None` in the completed state; `Some` from `init` until the future resolves or is dropped.
    inner: Option<WaitInner<'wait_list, L, I, O, OnCancel>>,
}
pin_project! {
    // The state of an initialized `Wait` future.
    struct WaitInner<'wait_list, L: Lock, I, O, OnCancel> {
        // The list the waiter is (or was) linked into.
        wait_list: &'wait_list WaitList<L, I, O>,
        // The list node itself. It is pinned because the list stores raw pointers to it.
        // `Aliasable` allows those pointers to alias it while this future owns it.
        #[pin]
        waiter: Aliasable<UnsafeCell<Waiter<I, O>>>,
        // Called if the future is dropped after being woken but before resolving.
        on_cancel: OnCancel,
    }
}
// SAFETY: sending the future to another thread moves its waiter (its `I`/`O` values and waker)
// and its `on_cancel`, and gives that thread a `&WaitList`.
// - `WaitList<L, I, O>: Sync` is required for the shared reference. Given the `Sync` impl for
//   `WaitList`, it also implies `I: Send` and `O: Send`, covering the waiter's contents.
// - `OnCancel` is owned outright, so it must be `Send`.
unsafe impl<'wait_list, L: Lock, I, O, OnCancel> Send for WaitInner<'wait_list, L, I, O, OnCancel>
where
    WaitList<L, I, O>: Sync,
    OnCancel: Send,
{
}
// SAFETY: a shared `&Wait` only exposes the `&WaitList` (through `wait_list` and `Debug`), hence
// `WaitList<L, I, O>: Sync`.
// NOTE(review): the `O: Sync` bound looks stricter than anything reachable through `&self`
// requires; it is kept as-is.
unsafe impl<'wait_list, L: Lock, I, O, OnCancel> Sync for WaitInner<'wait_list, L, I, O, OnCancel>
where
    WaitList<L, I, O>: Sync,
    O: Sync,
{
}
impl<'wait_list, L: Lock, I, O, OnCancel> Wait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    /// Projects a pinned `Wait` to its pinned `inner` field.
    fn project(self: Pin<&mut Self>) -> Pin<&mut Option<WaitInner<'wait_list, L, I, O, OnCancel>>> {
        // SAFETY: `inner` is treated as structurally pinned. Its contents are only moved out,
        // via `Option::take` in `poll` and `drop`, after the waiter has been unlinked from the
        // list.
        let this = unsafe { Pin::into_inner_unchecked(self) };
        unsafe { Pin::new_unchecked(&mut this.inner) }
    }
}
impl<'wait_list, L: Lock, I, O, OnCancel> Wait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    /// Creates a new `Wait` future in the completed state.
    ///
    /// It does nothing until it is initialized with [`Self::init`] or
    /// [`Self::init_without_waker`].
    pub fn new() -> Self {
        Self { inner: None }
    }
    /// Returns `true` if the future is in the completed state: either never initialized, or
    /// already resolved.
    #[must_use]
    pub fn is_completed(&self) -> bool {
        self.inner.is_none()
    }
    /// Initializes the future and appends it to the list that `guard` has locked.
    ///
    /// The new waiter holds `input` and `waker`. `on_cancel` is called if the future is dropped
    /// after being woken but before it resolved.
    ///
    /// # Panics
    ///
    /// Panics if the future is not in the completed state.
    pub fn init(
        self: Pin<&mut Self>,
        waker: task::Waker,
        guard: &mut LockedExclusive<'wait_list, L, I, O>,
        input: I,
        on_cancel: OnCancel,
    ) {
        assert!(
            self.as_ref().is_completed(),
            "called `Wait::init` on an incomplete future"
        );
        let mut inner = self.project();
        // A fresh node: `enqueue` relies on `next` starting out as `None`.
        let waiter = Aliasable::new(UnsafeCell::new(Waiter {
            next: None,
            prev: None,
            state: WaiterState::Waiting { input, waker },
        }));
        inner.set(Some(WaitInner {
            wait_list: guard.wait_list,
            waiter,
            on_cancel,
        }));
        // Borrow the node at its final, pinned address before linking it into the list.
        let inner = inner.as_ref().as_pin_ref().unwrap();
        let waiter = inner.project_ref().waiter.get();
        // SAFETY: `guard` holds the exclusive lock. The node is pinned inside this future, whose
        // `Drop` unlinks it before its memory can be invalidated.
        unsafe { guard.inner_mut().enqueue(waiter) };
    }
    /// Like [`Self::init`], but stores a waker that does nothing.
    ///
    /// The first poll replaces it with the context's waker, unless `will_wake` reports the two
    /// as equivalent.
    ///
    /// # Panics
    ///
    /// Panics if the future is not in the completed state.
    pub fn init_without_waker(
        self: Pin<&mut Self>,
        guard: &mut LockedExclusive<'wait_list, L, I, O>,
        input: I,
        on_cancel: OnCancel,
    ) {
        self.init(noop_waker(), guard, input, on_cancel);
    }
    /// Returns the pinned state of an initialized future.
    ///
    /// Panics if the future is in the completed state.
    fn inner(&self) -> Pin<&WaitInner<'wait_list, L, I, O, OnCancel>> {
        let inner = self.inner.as_ref().expect("`Wait` is in completed state");
        // SAFETY: `inner` only becomes `Some` in `init`, which requires `Pin<&mut Self>`.
        // `Wait` contains an `Aliasable` (presumably `!Unpin`), so it stays pinned from then
        // until it is dropped.
        unsafe { Pin::new_unchecked(inner) }
    }
    /// Returns the wait list this future was added to.
    ///
    /// # Panics
    ///
    /// Panics if the future is in the completed state.
    #[must_use]
    pub fn wait_list(&self) -> &'wait_list WaitList<L, I, O> {
        self.inner().wait_list
    }
}
impl<'wait_list, L: Lock, I, O, OnCancel> Future for Wait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    /// The still-held exclusive lock on the list, and the output the waiter was woken with.
    type Output = (LockedExclusive<'wait_list, L, I, O>, O);
    /// Checks, under the list's exclusive lock, whether the waiter has been woken.
    ///
    /// While the waiter is still waiting, the stored waker is replaced with the context's waker,
    /// unless `will_wake` reports that they wake the same task. Once the waiter has been woken,
    /// the future returns to the completed state and resolves to the lock and the output.
    ///
    /// # Panics
    ///
    /// Panics if the future is in the completed state.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let inner = self.inner().project_ref();
        // The exclusive lock serializes access to the waiter with `pop` and other lock holders.
        let lock = inner.wait_list.lock_exclusive();
        let waiter = inner.waiter.get();
        // SAFETY: the exclusive lock is held, so nothing else is accessing the waiter.
        let waiter = unsafe { &mut *waiter.get() };
        match &mut waiter.state {
            WaiterState::Waiting { waker, .. } => {
                // Only clone the new waker if the stored one would not wake the same task.
                if !waker.will_wake(cx.waker()) {
                    *waker = cx.waker().clone();
                }
                Poll::Pending
            }
            WaiterState::Woken { .. } => {
                // SAFETY: `pop` already unlinked the waiter, so nothing points into it and it
                // may be moved out of the future.
                let inner = unsafe { Pin::into_inner_unchecked(self.project()) };
                let old_inner = inner.take().unwrap();
                let output = match old_inner.waiter.into_inner().into_inner().state {
                    WaiterState::Woken { output } => output,
                    WaiterState::Waiting { .. } => unreachable!(),
                };
                Poll::Ready((lock, output))
            }
        }
    }
}
impl<'wait_list, L: Lock, I, O, OnCancel> Drop for Wait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    /// Cleans up an initialized future.
    ///
    /// A waiter that is still waiting is unlinked from the list. A waiter that was woken but
    /// whose output was never collected by `poll` passes that output, together with the
    /// exclusive lock, to the cancellation callback.
    fn drop(&mut self) {
        // SAFETY: `drop` is the last use of the value, so it is never moved after this point.
        // Treating `&mut self` as pinned is therefore sound (see the `Drop` notes in the
        // `core::pin` documentation).
        let this = unsafe { Pin::new_unchecked(self) };
        // Nothing to clean up for a future that was never initialized or already resolved.
        if this.is_completed() {
            return;
        }
        let inner = this.inner().project_ref();
        // If anything between here and the `mem::forget` below panics, dropping this guard
        // panics again during unwinding, which aborts the process. Without the abort, the
        // waiter's memory could be freed while the list still points to it.
        let abort_on_panic = PanicOnDrop;
        let mut list = inner.wait_list.lock_exclusive();
        let waiter = inner.waiter.as_ref().get();
        // A waiting waiter must be unlinked here; a woken one was already unlinked by `pop`.
        if let WaiterState::Waiting { .. } = unsafe { &(*waiter.get()).state } {
            unsafe { list.inner_mut().dequeue(waiter) };
        }
        mem::forget(abort_on_panic);
        // SAFETY: the waiter is no longer linked into the list, so it may be moved out.
        let inner = unsafe { Pin::into_inner_unchecked(this.project()) };
        let old_inner = inner.take().unwrap();
        let waiter = old_inner.waiter.into_inner().into_inner();
        // Woken but never polled to completion: hand the output and the lock to the
        // cancellation callback instead of silently dropping them.
        if let WaiterState::Woken { output } = waiter.state {
            old_inner.on_cancel.on_cancel(list, output);
        }
    }
}
impl<'wait_list, L: Lock, I, O, OnCancel> Default for Wait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    /// Equivalent to [`Wait::new`]: a future in the completed state.
    fn default() -> Self {
        Self { inner: None }
    }
}
/// Formats a waiting future as `Wait::Waiting { wait_list: .. }` and a completed one as
/// `Wait::Done`.
///
/// Only the wait list is printed, so only `L: Debug` is required. The input type and the lock
/// guard type do not need to implement `Debug`.
impl<'wait_list, L: Lock, I, O, OnCancel> Debug for Wait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
    L: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match &self.inner {
            Some(inner) => f
                .debug_struct("Wait::Waiting")
                .field("wait_list", inner.wait_list)
                .finish(),
            None => f.pad("Wait::Done"),
        }
    }
}
pin_project! {
    /// A future that adds a waiter to a [`WaitList`] on its first poll and then behaves like
    /// [`Wait`].
    ///
    /// Created by [`LockedExclusive::init_and_wait`]. The list stays exclusively locked from
    /// creation until the first poll.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct InitAndWait<'wait_list, L: Lock, I, O, OnCancel>
    where
        OnCancel: CancelCallback<'wait_list, L, I, O>,
    {
        // The lock, input and callback, consumed by the first poll.
        input: Option<InitAndWaitInput<'wait_list, L, I, O, OnCancel>>,
        // The underlying future, initialized during the first poll.
        #[pin]
        inner: Wait<'wait_list, L, I, O, OnCancel>,
    }
}
impl<'wait_list, L: Lock, I, O, OnCancel> Future for InitAndWait<'wait_list, L, I, O, OnCancel>
where
    OnCancel: CancelCallback<'wait_list, L, I, O>,
{
    type Output = (LockedExclusive<'wait_list, L, I, O>, O);

    /// Enqueues the waiter on the first poll, then delegates to the inner [`Wait`].
    ///
    /// The first poll enqueues the waiter with the context's waker, using the lock taken at
    /// construction. It then releases that lock and returns `Poll::Pending`. Every later poll
    /// delegates to the inner future.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match projected.input.take() {
            Some(InitAndWaitInput {
                mut lock,
                input,
                on_cancel,
            }) => {
                let waker = cx.waker().clone();
                projected.inner.init(waker, &mut lock, input, on_cancel);
                // The waiter cannot have been woken yet: the lock was held the whole time.
                Poll::Pending
            }
            None => projected.inner.poll(cx),
        }
    }
}
impl<'wait_list, L: Lock, I, O, OnCancel> Debug for InitAndWait<'wait_list, L, I, O, OnCancel>
where
OnCancel: CancelCallback<'wait_list, L, I, O>,
<L as lock::Lifetime<'wait_list>>::ExclusiveGuard: Debug,
I: Debug,
L: Debug,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
if let Some(input) = &self.input {
f.debug_struct("InitAndWait::Initial")
.field("lock", &input.lock)
.field("input", &input.input)
.finish()
} else {
f.debug_struct("InitAndWait::Waiting")
.field("inner", &self.inner)
.finish()
}
}
}
/// The values an [`InitAndWait`] holds until its first poll.
struct InitAndWaitInput<'wait_list, L: Lock, I, O, OnCancel> {
    /// The exclusive lock taken at creation; it is released at the end of the first poll.
    lock: LockedExclusive<'wait_list, L, I, O>,
    /// The input stored in the new waiter.
    input: I,
    /// The cancellation callback passed on to the inner [`Wait`].
    on_cancel: OnCancel,
}
/// A callback invoked when a [`Wait`] is dropped after being woken but before it resolved.
///
/// It receives the list's exclusive lock and the output the waiter was woken with. It can, for
/// example, pass that output on to another waiter. It is implemented for every
/// `FnOnce(LockedExclusive<'wait_list, L, I, O>, O)`.
pub trait CancelCallback<'wait_list, L: Lock, I, O>: Sized {
    /// Handles the cancellation, consuming the callback.
    fn on_cancel(self, list: LockedExclusive<'wait_list, L, I, O>, output: O);
}
impl<'wait_list, L: Lock, I, O, F> CancelCallback<'wait_list, L, I, O> for F
where
    F: FnOnce(LockedExclusive<'wait_list, L, I, O>, O),
    L: 'wait_list,
    I: 'wait_list,
    O: 'wait_list,
{
    /// Invokes the closure with the lock and the uncollected output.
    fn on_cancel(self, list: LockedExclusive<'wait_list, L, I, O>, output: O) {
        let callback = self;
        callback(list, output);
    }
}
/// A guard that panics when dropped.
///
/// It is created before a critical section and `mem::forget`-ed after it. If the section panics,
/// the guard is dropped during unwinding, and the resulting double panic aborts the process.
struct PanicOnDrop;
impl Drop for PanicOnDrop {
    fn drop(&mut self) {
        panic!();
    }
}
/// Returns a [`task::Waker`] that does nothing when woken or dropped. Cloning it yields another
/// such waker.
///
/// NOTE(review): the `mem::transmute` from `RawWaker` to `Waker` is what lets this be a
/// `const fn` (the tests store its result in a `static`). It relies on `Waker` having the same
/// layout as `RawWaker`, which is a standard-library implementation detail rather than a
/// documented guarantee. On toolchains where `Waker::from_raw` is usable in `const` contexts,
/// that should replace the transmute — confirm the crate's MSRV first.
const fn noop_waker() -> task::Waker {
    // Entries, in order: clone, wake, wake_by_ref, drop.
    const VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(
        |_| RAW,
        |_| {},
        |_| {},
        |_| {},
    );
    // The data pointer is never read by the vtable functions above, so null is fine.
    const RAW: task::RawWaker = task::RawWaker::new(ptr::null(), &VTABLE);
    unsafe { mem::transmute::<task::RawWaker, task::Waker>(RAW) }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::WaitList;
    use crate::lock;
    use crate::lock::Lock;
    use alloc::boxed::Box;
    use core::future::Future;
    use core::task;
    use core::task::Poll;
    // An uninhabited stand-in for the never type, for input/output types that can never have a
    // value in a given test.
    #[derive(Debug, PartialEq)]
    enum ǃ {}
    // Waking an empty list hands the output straight back, and the list reports itself empty.
    #[test]
    fn wake_empty() {
        let list = <WaitList<lock::Local<()>, ǃ, Box<u32>>>::default();
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(1)).unwrap_err(), 1);
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(2)).unwrap_err(), 2);
        assert_eq!(list.lock_exclusive().head_input(), None);
        assert_eq!(list.lock_exclusive().head_input_mut(), None);
        assert!(list.lock_shared().is_empty());
    }
    // Dropping a future that is still waiting unlinks it without calling its cancel callback.
    #[test]
    fn cancel() {
        let cx = &mut noop_cx();
        let list = <WaitList<lock::Local<()>, Box<u32>, ǃ>>::default();
        let mut future = Box::pin(list.lock_exclusive().init_and_wait(Box::new(5), no_cancel));
        // Repeated polls while still waiting stay pending.
        for _ in 0..10 {
            assert!(future.as_mut().poll(cx).is_pending());
        }
        assert_eq!(**list.lock_exclusive().head_input().unwrap(), 5);
        assert!(!list.lock_shared().is_empty());
        drop(future);
        assert_eq!(list.lock_exclusive().head_input(), None);
        assert!(list.lock_shared().is_empty());
    }
    // A woken future resolves to the output it was woken with, and leaves the list empty.
    #[test]
    fn wake_single() {
        let cx = &mut noop_cx();
        let list = <WaitList<lock::Local<()>, Box<u8>, Box<u32>>>::default();
        let mut future = Box::pin(list.lock_exclusive().init_and_wait(Box::new(5), no_cancel));
        assert!(future.as_mut().poll(cx).is_pending());
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(6)).unwrap(), 5);
        assert_eq!(
            future.as_mut().poll(cx).map(|(_, output)| output),
            Poll::Ready(Box::new(6))
        );
        assert!(list.lock_shared().is_empty());
    }
    // Waiters are woken in the order they were added. A future dropped after being woken
    // passes its output to its cancel callback.
    #[test]
    fn wake_multiple() {
        let cx = &mut noop_cx();
        let list = <WaitList<lock::Local<()>, Box<u8>, Box<u32>>>::default();
        let mut f1 = Box::pin(list.lock_exclusive().init_and_wait(Box::new(1), no_cancel));
        assert!(f1.as_mut().poll(cx).is_pending());
        let mut f2 = Box::pin(list.lock_exclusive().init_and_wait(Box::new(2), no_cancel));
        assert!(f2.as_mut().poll(cx).is_pending());
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(11)).unwrap(), 1);
        let mut f3_out = None;
        let mut f3 = Box::pin(
            list.lock_exclusive()
                .init_and_wait(Box::new(3), |_, out| f3_out = Some(out)),
        );
        assert!(f3.as_mut().poll(cx).is_pending());
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(12)).unwrap(), 2);
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(13)).unwrap(), 3);
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(9)).unwrap_err(), 9);
        assert_eq!(
            f2.as_mut().poll(cx).map(|(_, output)| output),
            Poll::Ready(Box::new(12))
        );
        assert_eq!(
            f1.as_mut().poll(cx).map(|(_, output)| output),
            Poll::Ready(Box::new(11))
        );
        // `f3` was woken but never polled again, so dropping it runs its cancel callback.
        drop(f3);
        assert_eq!(f3_out, Some(Box::new(13)));
    }
    // Unlinking waiters from the middle and from both ends leaves the list consistent and empty.
    #[test]
    fn drop_in_middle() {
        let cx = &mut noop_cx();
        let list = <WaitList<lock::Local<()>, Box<u32>, ǃ>>::default();
        let mut f1 = Box::pin(list.lock_exclusive().init_and_wait(Box::new(1), no_cancel));
        assert!(f1.as_mut().poll(cx).is_pending());
        let mut f2 = Box::pin(list.lock_exclusive().init_and_wait(Box::new(2), no_cancel));
        assert!(f2.as_mut().poll(cx).is_pending());
        let mut f3 = Box::pin(list.lock_exclusive().init_and_wait(Box::new(3), no_cancel));
        assert!(f3.as_mut().poll(cx).is_pending());
        drop(f2);
        drop(f3);
        drop(f1);
        assert!(list.lock_shared().is_empty());
    }
    // Cancel callbacks can forward a (modified) output to the next waiter, forming a chain.
    #[test]
    fn cancellation_waking_chain() {
        let cx = &mut noop_cx();
        let list = <WaitList<lock::Local<()>, Box<u8>, Box<u32>>>::default();
        let mut f1 = Box::pin(list.lock_exclusive().init_and_wait(
            Box::new(1),
            |list: crate::LockedExclusive<_, Box<u8>, _>, mut output: Box<u32>| {
                *output += 1;
                assert_eq!(*list.wake_one(output).unwrap(), 2);
            },
        ));
        assert!(f1.as_mut().poll(cx).is_pending());
        let mut f2 = Box::pin(list.lock_exclusive().init_and_wait(
            Box::new(2),
            |list: crate::LockedExclusive<_, Box<u8>, _>, mut output: Box<u32>| {
                *output += 1;
                assert_eq!(*list.wake_one(output).unwrap(), 3);
            },
        ));
        assert!(f2.as_mut().poll(cx).is_pending());
        let mut final_output = None;
        let mut f3 = Box::pin(list.lock_exclusive().init_and_wait(
            Box::new(3),
            |list: crate::LockedExclusive<_, Box<u8>, _>, output| {
                assert!(list.is_empty());
                final_output = Some(output);
            },
        ));
        assert!(f3.as_mut().poll(cx).is_pending());
        assert_eq!(*list.lock_exclusive().wake_one(Box::new(12)).unwrap(), 1);
        // 12 -> f1 (+1) -> f2 (+1) -> f3 receives 14.
        drop(f1);
        drop(f2);
        drop(f3);
        assert_eq!(final_output, Some(Box::new(14)));
    }
    // Cancel callback for futures that the test never drops between waking and completion.
    fn no_cancel<L: Lock, I, O>(_: crate::LockedExclusive<'_, L, I, O>, _: O) {
        panic!("did not expect cancellation")
    }
    // A task context whose waker does nothing.
    fn noop_cx() -> task::Context<'static> {
        static WAKER: task::Waker = crate::noop_waker();
        task::Context::from_waker(&WAKER)
    }
}
#[cfg(test)]
mod test_util {
    // Compile-time check: `x.assert_send()` only compiles if the type of `x` is `Send`.
    pub(crate) trait AssertSend {
        fn assert_send(&self) {}
    }
    impl<T: ?Sized + Send> AssertSend for T {}
    // Compile-time check: `x.assert_not_send()` only compiles if the type of `x` is NOT `Send`.
    // For a `Send` type both impls below apply, so `A` cannot be inferred and the call is
    // ambiguous. For a non-`Send` type only the `()` impl applies.
    pub(crate) trait AssertNotSend<A> {
        fn assert_not_send(&self) {}
    }
    impl<T: ?Sized> AssertNotSend<()> for T {}
    impl<T: ?Sized + Send> AssertNotSend<u8> for T {}
}