Skip to main content

compio_runtime/
cancel.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::HashSet,
4    mem,
5    ops::DerefMut,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11use compio_driver::{Cancel, Key, OpCode};
12use futures_util::{FutureExt, ready};
13use synchrony::unsync::event::{Event, EventListener};
14
15use crate::{ContextExt, Runtime};
16
17#[derive(Debug)]
18struct Inner {
19    tokens: RefCell<HashSet<Cancel>>,
20    is_cancelled: Cell<bool>,
21    runtime: Runtime,
22    notify: Event,
23}
24
25/// A token that can be used to cancel multiple operations at once.
26///
27/// When [`CancelToken::cancel`] is called, all operations that have been
28/// registered with this token will be cancelled.
29///
30/// It is also possible to use [`CancelToken::wait`] to wait until the token is
31/// cancelled, which can be useful for implementing timeouts or other
32/// cancellation-based logic.
33///
34/// To associate a future with this cancel token, use the [`with_cancel`]
35/// combinator from the [`FutureExt`] trait.
36///
37/// [`with_cancel`]: crate::future::FutureExt::with_cancel
38/// [`FutureExt`]: crate::future::FutureExt
39#[derive(Clone, Debug)]
40pub struct CancelToken(Rc<Inner>);
41
42impl PartialEq for CancelToken {
43    fn eq(&self, other: &Self) -> bool {
44        Rc::ptr_eq(&self.0, &other.0)
45    }
46}
47
48impl Eq for CancelToken {}
49
50impl CancelToken {
51    /// Create a new cancel token.
52    ///
53    /// # Panics
54    ///
55    /// [`CancelToken`] can only be created within compio runtime environment.
56    /// This will panic without a runtime.
57    pub fn new() -> Self {
58        Self(Rc::new(Inner {
59            tokens: RefCell::new(HashSet::new()),
60            is_cancelled: Cell::new(false),
61            runtime: Runtime::current(),
62            notify: Event::new(),
63        }))
64    }
65
66    pub(crate) fn listen(&self) -> EventListener {
67        self.0.notify.listen()
68    }
69
70    /// Cancel all operations registered with this token.
71    pub fn cancel(self) {
72        self.0.notify.notify_all();
73        if self.0.is_cancelled.replace(true) {
74            return;
75        }
76        let tokens = mem::take(self.0.tokens.borrow_mut().deref_mut());
77        for t in tokens {
78            self.0.runtime.cancel_token(t);
79        }
80    }
81
82    /// Check if this token has been cancelled.
83    pub fn is_cancelled(&self) -> bool {
84        self.0.is_cancelled.get()
85    }
86
87    /// Register an operation with this token.
88    ///
89    /// If the token has already been cancelled, the operation will be cancelled
90    /// immediately. Usually this method should not be used directly, but rather
91    /// through the [`with_cancel`] combinator.
92    ///
93    /// Multiple registrations of the same key does nothing, and the key will
94    /// only be cancelled once.
95    ///
96    /// [`with_cancel`]: crate::FutureExt::with_cancel
97    pub fn register<T: OpCode>(&self, key: &Key<T>) {
98        if self.0.is_cancelled.get() {
99            self.0.runtime.cancel(key.clone());
100        } else {
101            let token = self.0.runtime.register_cancel(key);
102            self.0.tokens.borrow_mut().insert(token);
103        }
104    }
105
106    /// Wait until this token is cancelled.
107    pub fn wait(self) -> WaitFuture {
108        WaitFuture::new(self)
109    }
110
111    /// Try to get the current cancel token associated with the future.
112    ///
113    /// This is done by checking if the current context has a cancel token
114    /// associated with it.
115    pub async fn current() -> Option<Self> {
116        std::future::poll_fn(|cx| Poll::Ready(cx.get_cancel().cloned())).await
117    }
118}
119
120impl Default for CancelToken {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126/// Future returned by [`CancelToken::wait`].
127pub struct WaitFuture {
128    listen: EventListener,
129    token: CancelToken,
130}
131
132impl WaitFuture {
133    fn new(token: CancelToken) -> WaitFuture {
134        WaitFuture {
135            listen: token.listen(),
136            token,
137        }
138    }
139}
140
141impl Future for WaitFuture {
142    type Output = ();
143
144    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
145        loop {
146            if self.token.is_cancelled() {
147                return Poll::Ready(());
148            } else {
149                ready!(self.listen.poll_unpin(cx))
150            }
151        }
152    }
153}