1use std::{
2 cell::{Cell, RefCell},
3 collections::HashSet,
4 fmt::Debug,
5 mem,
6 ops::DerefMut,
7 pin::Pin,
8 rc::Rc,
9 task::{Context, Poll},
10};
11
12use compio_driver::{Cancel, Key, OpCode, Proactor};
13use futures_util::{FutureExt, ready};
14use synchrony::unsync::event::{Event, EventListener};
15
16use crate::{ContextExt, Runtime};
17
/// Shared state behind every clone of a `CancelToken` (held in an `Rc`).
///
/// NOTE(review): fields drop in declaration order; keep the order unless it is
/// confirmed that drop order between these fields is irrelevant.
struct Inner {
    /// Driver cancel handles for operations registered while the token was not
    /// yet cancelled; drained and consumed by `CancelToken::cancel`.
    tokens: RefCell<HashSet<Cancel>>,
    /// Cancellation flag; in the code visible here it is only ever set to
    /// `true` (by `CancelToken::cancel`), never cleared.
    is_cancelled: Cell<bool>,
    /// Proactor of the runtime that was current when the token was created;
    /// used to register and cancel operations.
    driver: Rc<RefCell<Proactor>>,
    /// Event notified by `CancelToken::cancel` to wake listeners obtained via
    /// `CancelToken::listen` (e.g. `WaitFuture`).
    notify: Event,
}
24
25impl Debug for Inner {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("Inner")
28 .field("tokens", &self.tokens)
29 .field("is_cancelled", &self.is_cancelled)
30 .field("driver", &"...")
31 .field("notify", &self.notify)
32 .finish()
33 }
34}
35
36#[derive(Clone, Debug)]
51pub struct CancelToken(Rc<Inner>);
52
53impl PartialEq for CancelToken {
54 fn eq(&self, other: &Self) -> bool {
55 Rc::ptr_eq(&self.0, &other.0)
56 }
57}
58
59impl Eq for CancelToken {}
60
61impl CancelToken {
62 pub fn new() -> Self {
69 Self(Rc::new(Inner {
70 tokens: RefCell::new(HashSet::new()),
71 is_cancelled: Cell::new(false),
72 driver: Runtime::with_current(|r| r.driver.clone()),
73 notify: Event::new(),
74 }))
75 }
76
77 pub(crate) fn listen(&self) -> EventListener {
78 self.0.notify.listen()
79 }
80
81 pub fn cancel(self) {
83 self.0.notify.notify_all();
84 if self.0.is_cancelled.replace(true) {
85 return;
86 }
87 let tokens = mem::take(self.0.tokens.borrow_mut().deref_mut());
88 for t in tokens {
89 self.0.driver.borrow_mut().cancel_token(t);
90 }
91 }
92
93 pub fn is_cancelled(&self) -> bool {
95 self.0.is_cancelled.get()
96 }
97
98 pub fn register<T: OpCode>(&self, key: &Key<T>) {
109 if self.0.is_cancelled.get() {
110 self.0.driver.borrow_mut().cancel(key.clone());
111 } else {
112 let token = self.0.driver.borrow_mut().register_cancel(key);
113 self.0.tokens.borrow_mut().insert(token);
114 }
115 }
116
117 pub fn wait(self) -> WaitFuture {
119 WaitFuture::new(self)
120 }
121
122 pub async fn current() -> Option<Self> {
127 std::future::poll_fn(|cx| Poll::Ready(cx.get_cancel().cloned())).await
128 }
129}
130
131impl Default for CancelToken {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
/// Future returned by `CancelToken::wait`; completes once the token is
/// cancelled.
///
/// NOTE(review): fields drop in declaration order (listener before token);
/// keep this order unless that is confirmed to be irrelevant.
pub struct WaitFuture {
    // Listener on the token's cancellation event; polled to get woken up.
    listen: EventListener,
    // The token whose `is_cancelled` flag decides completion.
    token: CancelToken,
}
142
143impl WaitFuture {
144 fn new(token: CancelToken) -> WaitFuture {
145 WaitFuture {
146 listen: token.listen(),
147 token,
148 }
149 }
150}
151
152impl Future for WaitFuture {
153 type Output = ();
154
155 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
156 loop {
157 if self.token.is_cancelled() {
158 return Poll::Ready(());
159 } else {
160 ready!(self.listen.poll_unpin(cx))
161 }
162 }
163 }
164}