1use std::future::Future;
7use std::pin::Pin;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::task::{Context, Poll, Waker};
11
12use crate::runtime::executor::{BoxedCancelToken, CancelToken, EventFuture};
13
14#[derive(Clone, Debug)]
16pub struct SimpleCancelToken {
17 inner: Arc<TokenInner>,
18}
19
/// State shared by every clone of a `SimpleCancelToken`.
#[derive(Debug)]
struct TokenInner {
    /// Becomes `true` when the token is cancelled.
    cancelled: AtomicBool,
    /// Wakers of tasks waiting for cancellation; drained when the token is cancelled.
    wakers: Mutex<Vec<Waker>>,
}
25
26impl SimpleCancelToken {
27 pub fn new() -> Self {
28 SimpleCancelToken {
29 inner: Arc::new(TokenInner {
30 cancelled: AtomicBool::new(false),
31 wakers: Mutex::new(Vec::new()),
32 }),
33 }
34 }
35
36 pub fn new_boxed() -> BoxedCancelToken {
37 Box::new(Self::new())
38 }
39
40 pub fn cancel(&self) {
41 self.inner.cancelled.store(true, Ordering::Release);
43
44 let mut wakers = self.inner.wakers.lock().unwrap();
46 for waker in wakers.drain(..) {
47 waker.wake();
48 }
49 }
50
51 pub fn is_cancelled(&self) -> bool {
52 self.inner.cancelled.load(Ordering::Acquire)
53 }
54
55 pub fn cancelled(&self) -> CancelledFuture {
57 CancelledFuture {
58 token: self.clone(),
59 }
60 }
61}
62
63pub struct CancelledFuture {
65 token: SimpleCancelToken,
66}
67
68impl Future for CancelledFuture {
69 type Output = ();
70
71 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72 if self.token.is_cancelled() {
73 return Poll::Ready(());
74 }
75
76 let mut wakers = self.token.inner.wakers.lock().unwrap();
78
79 if self.token.is_cancelled() {
81 return Poll::Ready(());
82 }
83
84 wakers.push(cx.waker().clone());
86 Poll::Pending
87 }
88}
89
90impl Default for SimpleCancelToken {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96impl CancelToken for SimpleCancelToken {
98 fn cancel(&self) {
99 self.cancel();
100 }
101
102 fn is_cancelled(&self) -> bool {
103 self.is_cancelled()
104 }
105
106 fn wait(&self) -> Pin<Box<dyn EventFuture>> {
107 Box::pin(self.cancelled())
108 }
109
110 fn clone_box(&self) -> Box<dyn CancelToken> {
111 Box::new(self.clone())
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancel_token() {
        let token = SimpleCancelToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    /// Cancelling through one clone is visible through every other clone, and
    /// repeated cancellation is harmless.
    #[test]
    fn test_cancel_shared_and_idempotent() {
        let token = SimpleCancelToken::default();
        let other = token.clone();
        assert!(!other.is_cancelled());
        other.cancel();
        other.cancel();
        assert!(token.is_cancelled());
        assert!(other.is_cancelled());
    }

    /// Waiting on a token that is already cancelled completes immediately.
    #[tokio::test]
    async fn test_wait_after_cancel() {
        let token = SimpleCancelToken::new();
        token.cancel();
        token.wait().await;
        token.cancelled().await;
    }

    #[tokio::test]
    async fn test_cancel_token_async() {
        let token = SimpleCancelToken::new();
        let h = tokio::spawn({
            let token = token.clone();
            async move {
                token.wait().await;
            }
        });
        // Give the spawned task a chance to run and park on the token. `cancel`
        // then exercises the waker path rather than only the already-cancelled
        // fast path.
        tokio::task::yield_now().await;
        token.cancel();
        assert!(token.is_cancelled());
        h.await.unwrap();
    }

    #[tokio::test]
    async fn test_cancel_token_multi() {
        let token = SimpleCancelToken::new();
        let mut join_set = tokio::task::JoinSet::new();

        for _ in 0..10 {
            let token = token.clone();
            join_set.spawn(async move {
                token.wait().await;
            });
            tokio::task::yield_now().await;
        }
        token.cancel();
        assert!(token.is_cancelled());
        join_set.join_all().await;
    }
}