use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Mutex;
use crate::util::intrusive_double_linked_list::{LinkedList, ListNode};
use core::future::Future;
use core::pin::Pin;
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, Waker};
/// A handle to a shared, heap-allocated `CancellationTokenState`.
///
/// The handle stores only a raw pointer. `Clone` increments the reference
/// count kept inside the state and `Drop` decrements it again.
pub struct CancellationToken {
// Pointer to the boxed state created by `new` or `child_token`.
inner: NonNull<CancellationTokenState>,
}
// SAFETY — NOTE(review): presumably sound because the shared state is only
// mutated through its atomic `state` word and its `synchronized` mutex;
// confirm that the `from_parent` links are likewise never touched without
// synchronization.
unsafe impl Send for CancellationToken {}
unsafe impl Sync for CancellationToken {}
/// Future returned by `CancellationToken::cancelled`.
///
/// Resolves to `()` once the token reports cancellation.
#[must_use = "futures do nothing unless polled"]
pub struct WaitForCancellationFuture<'a> {
// `Some` until `poll` has returned `Poll::Ready`; `poll` then clears it, so
// a later poll hits the `expect` and panics.
cancellation_token: Option<&'a CancellationToken>,
// Intrusive list node that is linked into the token's waiter list and
// carries the `Waker` to notify.
wait_node: ListNode<WaitQueueEntry>,
// Set after a poll returned `Pending`. `poll` uses it to choose between
// registering and re-checking; `drop` uses it to decide whether to unregister.
is_registered: bool,
}
// SAFETY — NOTE(review): presumably relies on `wait_node` only being accessed
// by the list while the token's mutex is held; confirm against `ListNode`.
unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
impl core::fmt::Debug for CancellationToken {
    /// Formats the token as `CancellationToken { is_cancelled: .. }`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let cancelled = self.is_cancelled();
        let mut builder = f.debug_struct("CancellationToken");
        builder.field("is_cancelled", &cancelled);
        builder.finish()
    }
}
impl Clone for CancellationToken {
fn clone(&self) -> Self {
let inner = self.state();
let current_state = inner.snapshot();
inner.increment_refcount(current_state);
CancellationToken { inner: self.inner }
}
}
impl Drop for CancellationToken {
/// Releases this handle's reference on the shared state and, when it was
/// the last handle, asks the parent (if any) to unregister the state.
fn drop(&mut self) {
let token_state_pointer = self.inner;
// SAFETY: this handle still holds one reference, so the state is alive here.
let inner = unsafe { &mut *self.inner.as_ptr() };
let mut current_state = inner.snapshot();
// Copy the parent pointer out first: `decrement_refcount` frees the state
// once neither handles nor a parent reference remain, after which `inner`
// must not be read any more.
let parent = inner.parent;
current_state = inner.decrement_refcount(current_state);
// `refcount == 0` means no handle is left. The state itself may still be
// alive if `has_parent_ref` is set.
if current_state.refcount == 0 {
if let Some(mut parent) = parent {
// SAFETY: `child_token` took a reference on the parent on behalf of this
// state; it is only released at the end of `unregister_child`.
let parent = unsafe { parent.as_mut() };
// Unlinks this state from the parent's child list (unless the parent has
// already been cancelled) and drops the reference held on the parent.
parent.unregister_child(token_state_pointer, current_state);
}
}
}
}
impl CancellationToken {
/// Creates a new token that is not cancelled and has no parent.
///
/// The backing state starts with a reference count of one, owned by the
/// returned handle.
pub fn new() -> CancellationToken {
    let initial = StateSnapshot {
        refcount: 1,
        has_parent_ref: false,
        cancel_state: CancellationState::NotCancelled,
    };
    let raw = Box::into_raw(Box::new(CancellationTokenState::new(None, initial)));
    // SAFETY: `Box::into_raw` never yields a null pointer.
    let inner = unsafe { NonNull::new_unchecked(raw) };
    CancellationToken { inner }
}
/// Borrows the shared state this handle points to.
fn state(&self) -> &CancellationTokenState {
    // SAFETY: the handle owns a reference on the state for as long as it
    // lives, so the pointer is valid for the duration of the borrow of `self`.
    unsafe { self.inner.as_ref() }
}
/// Creates a token whose state is registered as a child of this token.
///
/// Cancelling this token also cancels the child (`cancel` walks the child
/// list). If this token is already cancelled, the child is created in the
/// cancelled state and is not linked into the list.
pub fn child_token(&self) -> CancellationToken {
let inner = self.state();
// The child state keeps the parent alive: take one reference on the parent
// on its behalf. It is released again in `unregister_child`.
let _current_state = inner.increment_refcount(inner.snapshot());
let mut unpacked_child_state = StateSnapshot {
has_parent_ref: true,
refcount: 1,
cancel_state: CancellationState::NotCancelled,
};
let mut child_token_state = Box::new(CancellationTokenState::new(
Some(self.inner),
unpacked_child_state,
));
{
// The cancelled check and the list insertion both happen under the
// parent's mutex.
let mut guard = inner.synchronized.lock().unwrap();
if guard.is_cancelled {
// Parent already cancelled: mark the child cancelled right away, both in
// its mutex-protected flag and in its packed atomic state.
(*child_token_state.synchronized.lock().unwrap()).is_cancelled = true;
unpacked_child_state.cancel_state = CancellationState::Cancelled;
// The child is not linked into the parent's list, so the parent holds no
// reference to it.
unpacked_child_state.has_parent_ref = false;
child_token_state
.state
.store(unpacked_child_state.pack(), Ordering::SeqCst);
} else {
// Push the child at the front of the parent's doubly linked child list.
if let Some(mut first_child) = guard.first_child {
child_token_state.from_parent.next_peer = Some(first_child);
// SAFETY — NOTE(review): presumably the current head cannot be freed
// while the parent's mutex is held; confirm.
unsafe {
first_child.as_mut().from_parent.prev_peer =
Some((&mut *child_token_state).into())
};
}
guard.first_child = Some((&mut *child_token_state).into());
}
};
// Hand ownership of the allocation over to the reference count.
let child_token_ptr = Box::into_raw(child_token_state);
CancellationToken {
// SAFETY: `Box::into_raw` never returns a null pointer.
inner: unsafe { NonNull::new_unchecked(child_token_ptr) },
}
}
/// Cancels this token by delegating to the shared state.
///
/// The state's `cancel` wakes registered waiters and cancels linked child
/// states; it returns early if cancellation was already started.
pub fn cancel(&self) {
    let shared = self.state();
    shared.cancel()
}
/// Returns `true` once cancellation of this token has been started.
pub fn is_cancelled(&self) -> bool {
    let shared = self.state();
    shared.is_cancelled()
}
/// Returns a future that completes when this token is cancelled.
///
/// The future borrows the token and starts out unregistered; it is only
/// added to the waiter list when polled.
pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
    let wait_node = ListNode::new(WaitQueueEntry::new());
    WaitForCancellationFuture {
        cancellation_token: Some(self),
        wait_node,
        is_registered: false,
    }
}
/// Forwards a first-time registration of `wait_node` to the shared state.
///
/// # Safety
///
/// NOTE(review): the state may link `wait_node` into its intrusive waiter
/// list, so the node presumably must stay at a fixed address and be
/// unregistered before it is dropped — confirm against `ListNode`.
unsafe fn register(
    &self,
    wait_node: &mut ListNode<WaitQueueEntry>,
    cx: &mut Context<'_>,
) -> Poll<()> {
    let shared = self.state();
    shared.register(wait_node, cx)
}
/// Forwards a repeated poll of an already registered `wait_node` to the
/// shared state.
fn check_for_cancellation(
    &self,
    wait_node: &mut ListNode<WaitQueueEntry>,
    cx: &mut Context<'_>,
) -> Poll<()> {
    let shared = self.state();
    shared.check_for_cancellation(wait_node, cx)
}
/// Forwards removal of `wait_node` from the waiter list to the shared state.
fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
    let shared = self.state();
    shared.unregister(wait_node)
}
}
impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
    /// Formats the future by name only; no fields are printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("WaitForCancellationFuture");
        builder.finish()
    }
}
impl<'a> Future for WaitForCancellationFuture<'a> {
    type Output = ();

    /// Polls for cancellation of the borrowed token.
    ///
    /// The first poll registers the wait node with the token; later polls only
    /// re-check. Once `Ready` is returned the future forgets its token.
    ///
    /// # Panics
    ///
    /// Panics when polled again after having returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // SAFETY: nothing is moved out of the pinned future below.
        let this: &mut WaitForCancellationFuture<'_> = unsafe { self.get_unchecked_mut() };
        let token = this
            .cancellation_token
            .expect("polled WaitForCancellationFuture after completion");

        let outcome = if this.is_registered {
            token.check_for_cancellation(&mut this.wait_node, cx)
        } else {
            // SAFETY — NOTE(review): the node lives inside this pinned future
            // and `Drop` unregisters it; confirm this is the full contract of
            // `register`.
            unsafe { token.register(&mut this.wait_node, cx) }
        };

        match outcome {
            Poll::Ready(()) => {
                // Completed: detach from the token and release the waker.
                this.cancellation_token = None;
                this.is_registered = false;
                this.wait_node.task = None;
            }
            Poll::Pending => this.is_registered = true,
        }
        outcome
    }
}
impl<'a> Drop for WaitForCancellationFuture<'a> {
    /// Removes the wait node from the token's waiter list if this future is
    /// still registered and has not completed.
    fn drop(&mut self) {
        if !self.is_registered {
            return;
        }
        if let Some(token) = self.cancellation_token {
            token.unregister(&mut self.wait_node);
        }
    }
}
/// Tracks where a `WaitQueueEntry` stands relative to the token's waiter list.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum PollState {
// Initial value from `WaitQueueEntry::new`; `register` asserts this state
// on entry.
New,
// Set by `register` when the node is added to the waiter list.
Waiting,
// Set when `cancel` drains the node, when `unregister` removes it, or when
// `register` finds the token already cancelled under the lock.
Done,
}
/// Payload stored in each node of the token's waiter list.
struct WaitQueueEntry {
// Waker that `cancel` invokes via `wake_by_ref`; `None` before registration
// and after completion or unregistration.
task: Option<Waker>,
// Position of this entry relative to the waiter list.
state: PollState,
}
impl WaitQueueEntry {
    /// Creates an entry with no waker, in the `PollState::New` state.
    fn new() -> WaitQueueEntry {
        let state = PollState::New;
        Self { task: None, state }
    }
}
/// The part of a token's state that is kept inside its `synchronized` mutex.
struct SynchronizedState {
// Futures currently waiting for cancellation.
waiters: LinkedList<WaitQueueEntry>,
// Head of the doubly linked list of child states (linked via `from_parent`).
first_child: Option<NonNull<CancellationTokenState>>,
// Set by `cancel` under the lock; once true, children are no longer linked
// or unlinked through this state.
is_cancelled: bool,
}
impl SynchronizedState {
fn new() -> Self {
Self {
waiters: LinkedList::new(),
first_child: None,
is_cancelled: false,
}
}
}
/// Sibling links of a child state within its parent's child list.
///
/// `child_token` and `unregister_child` modify these links while holding the
/// parent's mutex; `cancel` resets them after detaching the whole list.
struct SynchronizedThroughParent {
// Next sibling in the parent's child list.
next_peer: Option<NonNull<CancellationTokenState>>,
// Previous sibling in the parent's child list.
prev_peer: Option<NonNull<CancellationTokenState>>,
}
/// Cancellation progress of a token, stored in the two lowest bits of the
/// packed state word.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CancellationState {
    NotCancelled = 0,
    Cancelling = 1,
    Cancelled = 2,
}

impl CancellationState {
    /// Returns the numeric encoding of this state (0, 1 or 2).
    fn pack(self) -> usize {
        match self {
            CancellationState::NotCancelled => 0,
            CancellationState::Cancelling => 1,
            CancellationState::Cancelled => 2,
        }
    }

    /// Decodes a value produced by `pack`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not 0, 1 or 2.
    fn unpack(value: usize) -> Self {
        // Indexed by encoding, so the position must match each discriminant.
        const STATES: [CancellationState; 3] = [
            CancellationState::NotCancelled,
            CancellationState::Cancelling,
            CancellationState::Cancelled,
        ];
        match STATES.get(value) {
            Some(&state) => state,
            None => unreachable!("Invalid value"),
        }
    }
}
/// Decoded view of the packed atomic state word of a token.
///
/// Layout of the packed word: bits 0-1 hold the cancellation state, bit 2 the
/// parent-reference flag, and the remaining upper bits the reference count.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct StateSnapshot {
    /// Number of `CancellationToken` handles (plus references taken by child
    /// states) that point at the state.
    refcount: usize,
    /// Whether the parent's child list still references the state.
    has_parent_ref: bool,
    /// Cancellation progress.
    cancel_state: CancellationState,
}

impl StateSnapshot {
    /// Encodes the snapshot into a single word.
    fn pack(self) -> usize {
        let parent_bit = if self.has_parent_ref { 4 } else { 0 };
        (self.refcount << 3) | parent_bit | self.cancel_state.pack()
    }

    /// Decodes a word produced by `pack`.
    ///
    /// # Panics
    ///
    /// Panics if the two lowest bits do not encode a valid
    /// `CancellationState`.
    fn unpack(value: usize) -> Self {
        StateSnapshot {
            refcount: value >> 3,
            has_parent_ref: (value & 4) != 0,
            cancel_state: CancellationState::unpack(value & 0x03),
        }
    }

    /// Returns `true` while anything still references the state, either a
    /// counted reference or the parent's list.
    fn has_refs(&self) -> bool {
        !(self.refcount == 0 && !self.has_parent_ref)
    }
}
// Upper bound on the reference count, enforced in `increment_refcount`.
// `StateSnapshot::pack` shifts the count left by 3 and uses the low 3 bits for
// flags, so this value is the largest count whose packed word fits in 32 bits.
const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3;
/// Heap-allocated state shared by all handles of one cancellation token.
///
/// The allocation is freed by `decrement_refcount` / `remove_parent_ref` once
/// the packed state reports no remaining references.
struct CancellationTokenState {
// Packed `StateSnapshot`: reference count, parent-reference flag and
// cancellation state in one atomic word.
state: AtomicUsize,
// State of the token this one was derived from via `child_token`; `None`
// for tokens created with `CancellationToken::new`.
parent: Option<NonNull<CancellationTokenState>>,
// Links to sibling states in the parent's child list.
from_parent: SynchronizedThroughParent,
// Waiter list, child list head and cancelled flag, guarded by a mutex.
synchronized: Mutex<SynchronizedState>,
}
impl CancellationTokenState {
/// Builds a state with the given parent pointer and initial snapshot.
///
/// The state starts unlinked (no sibling pointers) and with an empty
/// mutex-protected part.
fn new(
    parent: Option<NonNull<CancellationTokenState>>,
    state: StateSnapshot,
) -> CancellationTokenState {
    let packed = state.pack();
    let from_parent = SynchronizedThroughParent {
        next_peer: None,
        prev_peer: None,
    };
    CancellationTokenState {
        state: AtomicUsize::new(packed),
        parent,
        from_parent,
        synchronized: Mutex::new(SynchronizedState::new()),
    }
}
/// Loads the packed atomic state and decodes it.
fn snapshot(&self) -> StateSnapshot {
    let packed = self.state.load(Ordering::SeqCst);
    StateSnapshot::unpack(packed)
}
/// Applies `func` to the state with a compare-and-swap retry loop.
///
/// `current_state` is the caller's guess of the present state. If the swap
/// fails, the value actually observed is decoded and `func` is applied again.
/// Returns the snapshot that was successfully stored.
fn atomic_update_state<F>(&self, current_state: StateSnapshot, func: F) -> StateSnapshot
where
    F: Fn(StateSnapshot) -> StateSnapshot,
{
    // `unpack(x).pack() == x` for every valid word, so re-packing the
    // snapshot yields exactly the word that was observed.
    let mut observed = current_state;
    loop {
        let desired = func(observed);
        let swapped = self.state.compare_exchange(
            observed.pack(),
            desired.pack(),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        match swapped {
            Ok(_) => return desired,
            Err(actual) => observed = StateSnapshot::unpack(actual),
        }
    }
}
fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
self.atomic_update_state(current_state, |mut state: StateSnapshot| {
if state.refcount >= MAX_REFS as usize {
eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded");
std::process::abort();
}
state.refcount += 1;
state
})
}
fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
state.refcount -= 1;
state
});
if !current_state.has_refs() {
let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
}
current_state
}
fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot {
let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
state.has_parent_ref = false;
state
});
if !current_state.has_refs() {
let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
}
current_state
}
/// Called on the parent state when the last handle of a child was dropped.
///
/// Unlinks the child from this state's child list unless this state has
/// already been cancelled, then releases the reference the child held on
/// this state. `current_child_state` is the child's snapshot taken after its
/// reference count was decremented.
fn unregister_child(
&mut self,
mut child_state: NonNull<CancellationTokenState>,
current_child_state: StateSnapshot,
) {
let removed_child = {
let mut guard = self.synchronized.lock().unwrap();
if !guard.is_cancelled {
// SAFETY — NOTE(review): presumably the child is still linked and alive
// because `cancel` has not detached the list yet; confirm.
let mut child_state = unsafe { child_state.as_mut() };
debug_assert!(child_state.snapshot().has_parent_ref);
// If the child is the list head, advance the head to its successor.
if guard.first_child == Some(child_state.into()) {
guard.first_child = child_state.from_parent.next_peer;
}
// Bridge the neighbours over the removed node.
unsafe {
if let Some(mut prev_peer) = child_state.from_parent.prev_peer {
prev_peer.as_mut().from_parent.next_peer =
child_state.from_parent.next_peer;
}
if let Some(mut next_peer) = child_state.from_parent.next_peer {
next_peer.as_mut().from_parent.prev_peer =
child_state.from_parent.prev_peer;
}
}
child_state.from_parent.prev_peer = None;
child_state.from_parent.next_peer = None;
true
} else {
// Already cancelled: `cancel` took the child list out of the mutex and
// handles the links (and the parent reference) itself, so leave the
// child untouched here.
false
}
};
if removed_child {
// The list no longer references the child; clearing the flag may free it.
unsafe { child_state.as_mut().remove_parent_ref(current_child_state) };
}
// Release the reference `child_token` took on this state for the child.
// This may free `self`, so it must stay the last statement.
self.decrement_refcount(self.snapshot());
}
/// Cancels this state, wakes all registered waiters and cancels every linked
/// child state.
///
/// Returns immediately if cancellation was already started elsewhere.
fn cancel(&self) {
let mut current_state = self.snapshot();
// Move the state from `NotCancelled` to `Cancelling`. Only the caller whose
// compare-exchange succeeds goes on to do the cancellation work.
let state_after_cancellation = loop {
if current_state.cancel_state != CancellationState::NotCancelled {
return;
}
let mut next_state = current_state;
next_state.cancel_state = CancellationState::Cancelling;
match self.state.compare_exchange(
current_state.pack(),
next_state.pack(),
Ordering::SeqCst,
Ordering::SeqCst,
) {
Ok(_) => break next_state,
Err(actual) => current_state = StateSnapshot::unpack(actual),
}
};
let mut first_child = {
let mut guard = self.synchronized.lock().unwrap();
// Record the cancellation under the lock as well; `register`,
// `child_token` and `unregister_child` check this flag while holding it.
guard.is_cancelled = true;
// Wake every waiter and mark its node as removed from the list.
// NOTE(review): `register` adds at the front, so the reverse drain
// presumably visits the oldest waiter first — confirm against `LinkedList`.
guard.waiters.reverse_drain(|waiter| {
// The waker is woken by reference and left in the node.
if let Some(handle) = &mut waiter.task {
handle.wake_by_ref();
}
waiter.state = PollState::Done;
});
// Detach the whole child list so children are cancelled outside the lock.
guard.first_child.take()
};
while let Some(mut child) = first_child {
// SAFETY — NOTE(review): presumably valid because each linked child still
// has `has_parent_ref` set, which keeps its allocation alive; confirm.
let mut_child = unsafe { child.as_mut() };
// Read the successor before the child may be freed below.
first_child = mut_child.from_parent.next_peer;
mut_child.from_parent.prev_peer = None;
mut_child.from_parent.next_peer = None;
mut_child.cancel();
// Drop the list's reference on the child; this may free it, so the child
// is not touched afterwards.
mut_child.remove_parent_ref(mut_child.snapshot());
}
// Publish that cancellation has fully completed.
self.atomic_update_state(state_after_cancellation, |mut state| {
state.cancel_state = CancellationState::Cancelled;
state
});
}
/// Returns `true` once cancellation has at least been started, i.e. the
/// state is `Cancelling` or `Cancelled`.
fn is_cancelled(&self) -> bool {
    match self.snapshot().cancel_state {
        CancellationState::NotCancelled => false,
        CancellationState::Cancelling | CancellationState::Cancelled => true,
    }
}
/// Registers a fresh wait node, or reports that the token is already
/// cancelled.
///
/// Returns `Poll::Ready(())` if cancellation has been started; otherwise
/// stores the current waker in the node, links the node into the waiter list
/// and returns `Poll::Pending`.
///
/// # Safety
///
/// NOTE(review): the node is linked into an intrusive list, so it presumably
/// must not move or be dropped before it is removed again — confirm against
/// `LinkedList::add_front`.
unsafe fn register(
    &self,
    wait_node: &mut ListNode<WaitQueueEntry>,
    cx: &mut Context<'_>,
) -> Poll<()> {
    debug_assert_eq!(PollState::New, wait_node.state);

    // Cheap lock-free check first.
    if self.snapshot().cancel_state != CancellationState::NotCancelled {
        return Poll::Ready(());
    }

    // Cancellation may have started since the check above; decide under the
    // lock.
    let mut guard = self.synchronized.lock().unwrap();
    if guard.is_cancelled {
        wait_node.state = PollState::Done;
        return Poll::Ready(());
    }

    wait_node.task = Some(cx.waker().clone());
    wait_node.state = PollState::Waiting;
    guard.waiters.add_front(wait_node);
    Poll::Pending
}
/// Re-checks cancellation for a wait node that an earlier `register` call
/// left pending, refreshing the stored waker when necessary.
///
/// Returns `Poll::Ready(())` once cancellation has been started, otherwise
/// `Poll::Pending`.
///
/// The stored waker is replaced only when it would *not* wake the task that
/// is polling now. The previous code had this test inverted: it re-cloned the
/// waker under the mutex on every poll from the same task, and never replaced
/// it when the future was polled from a different task, so the new task would
/// not be woken on cancellation.
fn check_for_cancellation(
    &self,
    wait_node: &mut ListNode<WaitQueueEntry>,
    cx: &mut Context<'_>,
) -> Poll<()> {
    debug_assert!(
        wait_node.task.is_some(),
        "Method can only be called after task had been registered"
    );

    let current_state = self.snapshot();
    if current_state.cancel_state != CancellationState::NotCancelled {
        // In the `Cancelling` phase `cancel` may not have drained the waiter
        // list yet, so make sure the node is unlinked before completing.
        // Once `Cancelled` is published the list has already been drained.
        if current_state.cancel_state != CancellationState::Cancelled {
            self.unregister(wait_node);
        }
        return Poll::Ready(());
    }

    // An update is needed when there is no stored waker or when the stored
    // waker would not wake the currently polling task.
    let need_waker_update = wait_node
        .task
        .as_ref()
        .map(|waker| !waker.will_wake(cx.waker()))
        .unwrap_or(true);
    if !need_waker_update {
        // `cancel` will wake the stored waker, which targets this task.
        return Poll::Pending;
    }

    // `cancel` reads the waker while holding this mutex, so swap it under
    // the same lock.
    let guard = self.synchronized.lock().unwrap();
    if guard.is_cancelled {
        // `cancel` sets the flag and drains the waiters in one critical
        // section, so the node has already been marked `Done`.
        debug_assert_eq!(PollState::Done, wait_node.state);
        wait_node.task = None;
        Poll::Ready(())
    } else {
        wait_node.task = Some(cx.waker().clone());
        Poll::Pending
    }
}
/// Removes `wait_node` from the waiter list if it is still linked and clears
/// its stored waker.
///
/// # Panics
///
/// Panics if the node claims to be `Waiting` but the list cannot find it.
fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
    debug_assert!(
        wait_node.task.is_some(),
        "waiter can not be active without task"
    );

    let mut guard = self.synchronized.lock().unwrap();
    // Only a `Waiting` node is linked into the list.
    if wait_node.state == PollState::Waiting {
        // SAFETY — NOTE(review): the `Waiting` state implies the node was
        // added by `register` and not yet drained; confirm the contract of
        // `LinkedList::remove`.
        let removed = unsafe { guard.waiters.remove(wait_node) };
        if !removed {
            panic!("Future could not be removed from wait queue");
        }
        wait_node.state = PollState::Done;
    }
    wait_node.task = None;
}
}