use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use super::CURRENT_TASK_ID;
use super::io::try_with_state;
/// Shared mutable state behind a `CancellationToken` and all of its clones.
struct State {
    /// Ids of tasks that polled a `CancelledFuture` before cancellation and
    /// should be woken when cancellation happens.
    waiters: Vec<u32>,
    /// States of child tokens; cancelling this state cancels each of them.
    children: Vec<Rc<RefCell<State>>>,
    /// Set to `true` by `cancel_state`; nothing in this module clears it.
    cancelled: bool,
}
/// A cancellation signal for tasks on a single thread.
///
/// All clones of a token share one state. Tokens form a tree through
/// [`CancellationToken::child_token`]: cancelling a token also cancels its
/// descendants, but never its ancestors.
pub struct CancellationToken {
    /// Reference-counted shared state (`Rc`/`RefCell`, so the token is `!Send`).
    state: Rc<RefCell<State>>,
}
impl CancellationToken {
    /// Creates a fresh, uncancelled token with no waiters and no children.
    pub fn new() -> Self {
        CancellationToken {
            state: Rc::new(RefCell::new(State {
                cancelled: false,
                waiters: Vec::new(),
                children: Vec::new(),
            })),
        }
    }

    /// Cancels this token and every clone of it.
    ///
    /// This also wakes the tasks waiting on it and cancels all of its child
    /// tokens. Calling it again after the first time has no effect.
    pub fn cancel(&self) {
        cancel_state(&self.state);
    }

    /// Returns `true` once this token has been cancelled.
    ///
    /// Cancellation can come directly, through a clone, or through an ancestor
    /// it was derived from via [`child_token`](Self::child_token).
    pub fn is_cancelled(&self) -> bool {
        self.state.borrow().cancelled
    }

    /// Returns a future that resolves once this token is cancelled.
    ///
    /// If the future is polled before cancellation, it registers the currently
    /// running task as a waiter.
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            state: Rc::clone(&self.state),
        }
    }

    /// Creates a child token that is cancelled whenever this token is.
    ///
    /// Cancelling the child does not affect this token. If this token is
    /// already cancelled, the returned child is already cancelled too.
    ///
    /// The parent holds a strong reference to each child's state. Without
    /// pruning, that list would grow forever on long-lived parents. So whenever
    /// the list is full, entries that can no longer matter are dropped first:
    /// - children that are already cancelled. Their cancellation has already
    ///   been pushed down to their own children, and a cancelled token never
    ///   records new children.
    /// - children held only by this list that have no children of their own.
    ///   No token, clone, or `CancelledFuture` still references them, so their
    ///   cancellation is unobservable.
    pub fn child_token(&self) -> CancellationToken {
        let child = CancellationToken::new();
        let mut s = self.state.borrow_mut();
        if s.cancelled {
            // Release the parent borrow before cancelling the child.
            drop(s);
            child.cancel();
        } else {
            if s.children.len() == s.children.capacity() {
                s.children.retain(|c| {
                    let cs = c.borrow();
                    !cs.cancelled && (Rc::strong_count(c) > 1 || !cs.children.is_empty())
                });
                // If pruning freed less than half the slots, grow now. The next
                // scan is then at least as far away as this one was, which keeps
                // the pruning cost amortised O(1) per call.
                let len = s.children.len();
                if len > s.children.capacity() / 2 {
                    s.children.reserve(len);
                }
            }
            s.children.push(Rc::clone(&child.state));
        }
        child
    }
}
impl Default for CancellationToken {
fn default() -> Self {
Self::new()
}
}
impl Clone for CancellationToken {
fn clone(&self) -> Self {
CancellationToken {
state: Rc::clone(&self.state),
}
}
}
/// Marks `state` and every descendant state as cancelled, and wakes every task
/// registered as a waiter on any of them.
///
/// States that are already cancelled are skipped, so repeated calls do nothing.
///
/// The tree is visited in depth-first pre-order: a node's waiters are woken
/// before its children are processed, and children are handled in insertion
/// order. An explicit stack replaces recursion, so arbitrarily deep chains of
/// child tokens cannot overflow the thread's stack.
fn cancel_state(state: &Rc<RefCell<State>>) {
    let mut pending = vec![Rc::clone(state)];
    while let Some(node) = pending.pop() {
        // Move everything out of the RefCell and release the borrow before
        // waking anyone. That way nothing reached through `wake_task` can find
        // this state mid-borrow.
        let (waiters, children) = {
            let mut s = node.borrow_mut();
            if s.cancelled {
                continue;
            }
            s.cancelled = true;
            let waiters = std::mem::take(&mut s.waiters);
            let children = std::mem::take(&mut s.children);
            (waiters, children)
        };
        for waiter_id in waiters {
            try_with_state(|_driver, executor| {
                executor.wake_task(waiter_id);
            });
        }
        // Push in reverse so children are popped in their original order,
        // matching what a recursive pre-order walk would do.
        pending.extend(children.into_iter().rev());
    }
}
/// Future returned by `CancellationToken::cancelled`.
///
/// It resolves once the token it was created from is cancelled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CancelledFuture {
    /// State shared with the originating token and its clones.
    state: Rc<RefCell<State>>,
}
impl Future for CancelledFuture {
    type Output = ();

    /// Returns `Ready(())` if the token has been cancelled.
    ///
    /// Otherwise it records the current task id (read from `CURRENT_TASK_ID`)
    /// as a waiter, at most once, and returns `Pending`. The `Context`'s waker
    /// is not used: wake-ups are delivered by task id in `cancel_state`.
    ///
    /// NOTE(review): the registration is not removed if this future is dropped
    /// before cancellation. That task then gets an extra wake-up on cancel.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.state.borrow_mut();
        if !state.cancelled {
            let me = CURRENT_TASK_ID.with(|id| id.get());
            if state.waiters.iter().all(|&w| w != me) {
                state.waiters.push(me);
            }
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed token starts out live.
    #[test]
    fn new_token_is_not_cancelled() {
        assert!(!CancellationToken::new().is_cancelled());
    }

    #[test]
    fn cancel_sets_flag() {
        let tok = CancellationToken::new();
        tok.cancel();
        assert!(tok.is_cancelled());
    }

    /// A second `cancel` must neither panic nor reset the flag.
    #[test]
    fn cancel_is_idempotent() {
        let tok = CancellationToken::new();
        (0..2).for_each(|_| tok.cancel());
        assert!(tok.is_cancelled());
    }

    /// Clones are handles to the same state, not independent copies.
    #[test]
    fn clone_shares_state() {
        let original = CancellationToken::new();
        let copy = original.clone();
        original.cancel();
        assert!(copy.is_cancelled());
    }

    /// Cancellation flows downward from parent to child.
    #[test]
    fn child_cancelled_with_parent() {
        let up = CancellationToken::new();
        let down = up.child_token();
        assert!(!down.is_cancelled());
        up.cancel();
        assert!(down.is_cancelled());
    }

    /// Cancellation never flows upward from child to parent.
    #[test]
    fn child_cancelled_independently() {
        let up = CancellationToken::new();
        let down = up.child_token();
        down.cancel();
        assert!(down.is_cancelled());
        assert!(!up.is_cancelled());
    }

    /// Deriving from an already-cancelled token yields a cancelled child.
    #[test]
    fn child_of_cancelled_parent_is_cancelled() {
        let up = CancellationToken::new();
        up.cancel();
        assert!(up.child_token().is_cancelled());
    }

    /// Cancellation propagates through more than one level.
    #[test]
    fn grandchild_cancelled_with_grandparent() {
        let root = CancellationToken::new();
        let middle = root.child_token();
        let leaf = middle.child_token();
        root.cancel();
        assert!(middle.is_cancelled());
        assert!(leaf.is_cancelled());
    }

    #[test]
    fn default_is_new() {
        let tok: CancellationToken = Default::default();
        assert!(!tok.is_cancelled());
    }
}