1use std::{
2 cell::{Cell, RefCell},
3 collections::HashSet,
4 mem,
5 ops::DerefMut,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11use compio_driver::{Cancel, Key, OpCode};
12use futures_util::{FutureExt, ready};
13use synchrony::unsync::event::{Event, EventListener};
14
15use crate::{ContextExt, Runtime};
16
/// Shared state behind every clone of a [`CancelToken`].
#[derive(Debug)]
struct Inner {
    /// Runtime cancel handles for operations registered through
    /// `CancelToken::register`; drained by `CancelToken::cancel`.
    tokens: RefCell<HashSet<Cancel>>,
    /// Flipped to `true` by `CancelToken::cancel`.
    is_cancelled: Cell<bool>,
    /// Runtime captured in `CancelToken::new`; used to register and cancel
    /// operations.
    runtime: Runtime,
    /// Event whose listeners (e.g. `WaitFuture`) are notified on cancellation.
    notify: Event,
}
24
/// A cloneable, single-threaded (`Rc`/`Cell`-based) cancellation token.
///
/// All clones share the same state: cancelling any clone cancels them all.
/// Operations can be attached with [`CancelToken::register`] so the runtime
/// cancels them when the token is cancelled, and tasks can wait for
/// cancellation with [`CancelToken::wait`].
#[derive(Clone, Debug)]
pub struct CancelToken(Rc<Inner>);
41
42impl PartialEq for CancelToken {
43 fn eq(&self, other: &Self) -> bool {
44 Rc::ptr_eq(&self.0, &other.0)
45 }
46}
47
48impl Eq for CancelToken {}
49
50impl CancelToken {
51 pub fn new() -> Self {
58 Self(Rc::new(Inner {
59 tokens: RefCell::new(HashSet::new()),
60 is_cancelled: Cell::new(false),
61 runtime: Runtime::current(),
62 notify: Event::new(),
63 }))
64 }
65
66 pub(crate) fn listen(&self) -> EventListener {
67 self.0.notify.listen()
68 }
69
70 pub fn cancel(self) {
72 self.0.notify.notify_all();
73 if self.0.is_cancelled.replace(true) {
74 return;
75 }
76 let tokens = mem::take(self.0.tokens.borrow_mut().deref_mut());
77 for t in tokens {
78 self.0.runtime.cancel_token(t);
79 }
80 }
81
82 pub fn is_cancelled(&self) -> bool {
84 self.0.is_cancelled.get()
85 }
86
87 pub fn register<T: OpCode>(&self, key: &Key<T>) {
98 if self.0.is_cancelled.get() {
99 self.0.runtime.cancel(key.clone());
100 } else {
101 let token = self.0.runtime.register_cancel(key);
102 self.0.tokens.borrow_mut().insert(token);
103 }
104 }
105
106 pub fn wait(self) -> WaitFuture {
108 WaitFuture::new(self)
109 }
110
111 pub async fn current() -> Option<Self> {
116 std::future::poll_fn(|cx| Poll::Ready(cx.get_cancel().cloned())).await
117 }
118}
119
120impl Default for CancelToken {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126pub struct WaitFuture {
128 listen: EventListener,
129 token: CancelToken,
130}
131
132impl WaitFuture {
133 fn new(token: CancelToken) -> WaitFuture {
134 WaitFuture {
135 listen: token.listen(),
136 token,
137 }
138 }
139}
140
141impl Future for WaitFuture {
142 type Output = ();
143
144 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
145 loop {
146 if self.token.is_cancelled() {
147 return Poll::Ready(());
148 } else {
149 ready!(self.listen.poll_unpin(cx))
150 }
151 }
152 }
153}