/// Error returned by [`CancellationToken::error_if_cancelled`] once the token
/// it was called on reports cancellation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CancellationError;

impl std::fmt::Display for CancellationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation was cancelled")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for CancellationError {}
4
5pub struct CancellationGuardVtable {
6 drop: unsafe fn(data: *const (), func: &mut CancellationFunc),
7}
8
9pub struct CancellationGuard<'a> {
13 func: &'a mut CancellationFunc<'a>,
14 data: *const (),
15 vtable: &'static CancellationGuardVtable,
16}
17
18impl<'a> Drop for CancellationGuard<'a> {
19 fn drop(&mut self) {
20 unsafe { (self.vtable.drop)(self.data, self.func) };
21 }
22}
23
/// A callback node that can be registered with a [`CancellationToken`].
///
/// Besides the callback itself, the node carries `prev`/`next` links. The
/// `std_impl` token uses them to chain registered nodes into an intrusive
/// doubly linked list; both are null while the node is unlinked.
///
/// NOTE(review): `std_impl` runs the callback on whichever thread calls
/// `cancel`, through a `&mut`. Sending `&mut F` to another thread normally
/// requires `F: Send`, so the `Sync` bound here may not be the right one.
/// Confirm before relying on it for thread-bound state.
pub struct CancellationFunc<'a> {
    /// The callback; taken (left as `None`) when it is run.
    inner: Option<&'a mut (dyn FnMut() + Sync)>,
    /// Previous node in the owning token's list, or null.
    prev: *const (),
    /// Next node in the owning token's list, or null.
    next: *const (),
}
33
34impl<'a> CancellationFunc<'a> {
35 unsafe fn from_raw(raw: *const ()) -> &'a mut Self {
36 let a = raw as *const Self;
37 std::mem::transmute(a)
38 }
39
40 fn into_raw(&mut self) -> *const () {
41 self as *const Self as _
42 }
43
44 pub fn new(func: &'a mut (dyn FnMut() + Sync)) -> Self {
45 Self {
46 inner: Some(func),
47 prev: std::ptr::null(),
48 next: std::ptr::null(),
49 }
50 }
51}
52
53pub trait CancellationToken {
60 fn error_if_cancelled(&self) -> Result<(), CancellationError>;
64
65 fn on_cancellation<'a>(&self, func: &'a mut CancellationFunc<'a>) -> CancellationGuard<'a>;
73}
74
75thread_local! {
76 pub static CURRENT_CANCELLATION_TOKEN: std::cell::RefCell<Option<&'static dyn CancellationToken>> = std::cell::RefCell::new(None);
77}
78
79pub fn with_current_cancellation_token<R>(func: impl FnOnce(&dyn CancellationToken) -> R) -> R {
81 CURRENT_CANCELLATION_TOKEN.with(|token| {
82 let x = &*token.borrow();
83 match x {
84 Some(token) => func(*token),
85 None => func(&UncancellableToken::default()),
86 }
87 })
88}
89
90pub fn with_cancellation_token<'a, R>(
94 token: &'a dyn CancellationToken,
95 func: impl FnOnce() -> R,
96) -> R {
97 struct RevertToOldTokenGuard<'a> {
101 prev: Option<&'static dyn CancellationToken>,
102 storage: &'a std::cell::RefCell<Option<&'static dyn CancellationToken>>,
103 }
104
105 impl<'a> Drop for RevertToOldTokenGuard<'a> {
106 fn drop(&mut self) {
107 let mut guard = self.storage.borrow_mut();
108 *guard = self.prev;
109 }
110 }
111
112 CURRENT_CANCELLATION_TOKEN.with(|storage| {
113 let mut guard = storage.borrow_mut();
114 let static_token: &'static dyn CancellationToken = unsafe { std::mem::transmute(token) };
115
116 let prev = std::mem::replace(&mut *guard, Some(static_token));
117 drop(guard);
118
119 let _revert_guard = RevertToOldTokenGuard { prev, storage };
122
123 func()
124 })
125}
126
/// A token that is never cancelled.
///
/// `with_current_cancellation_token` hands one of these out when no token is
/// installed on the current thread.
#[derive(Default, Debug)]
pub struct UncancellableToken {}
131
132fn noop_drop(_data: *const (), _func: &mut CancellationFunc) {}
133
134fn noop_vtable() -> &'static CancellationGuardVtable {
135 &CancellationGuardVtable { drop: noop_drop }
136}
137
138impl CancellationToken for UncancellableToken {
139 fn error_if_cancelled(&self) -> Result<(), CancellationError> {
140 Ok(())
141 }
142
143 fn on_cancellation<'a>(&self, func: &'a mut CancellationFunc<'a>) -> CancellationGuard<'a> {
144 CancellationGuard {
145 func,
146 data: std::ptr::null(),
147 vtable: noop_vtable(),
148 }
149 }
150}
151
152pub mod std_impl {
153 use super::*;
154 use std::sync::{Arc, Mutex};
155
156 struct State {
158 pub cancelled: bool,
160 pub first_func: *const (),
162 pub last_func: *const (),
164 }
165
166 fn std_cancellation_token_drop(data: *const (), func: &mut CancellationFunc) {
167 let state: Arc<Mutex<State>> = unsafe { Arc::from_raw(data as _) };
168
169 let mut guard = state.lock().unwrap();
170 if guard.cancelled {
171 assert!(func.prev.is_null());
175 assert!(func.next.is_null());
176 return;
177 }
178
179 unsafe {
180 if func.prev.is_null() {
181 guard.first_func = func.next;
183 if !guard.first_func.is_null() {
184 let mut first = CancellationFunc::from_raw(guard.first_func);
185 first.prev = std::ptr::null();
186 } else {
187 guard.last_func = std::ptr::null();
189 }
190 func.next = std::ptr::null();
191 } else {
192 let mut prev = CancellationFunc::from_raw(func.prev);
194 prev.next = func.next;
195 if !func.next.is_null() {
196 let mut next = CancellationFunc::from_raw(func.next);
197 next.prev = func.prev;
198 }
199
200 func.next = std::ptr::null();
201 func.prev = std::ptr::null();
202 }
203 }
204
205 std::mem::drop(data);
206 }
207
208 fn std_cancellation_token_vtable() -> &'static CancellationGuardVtable {
209 &CancellationGuardVtable {
210 drop: std_cancellation_token_drop,
211 }
212 }
213
214 pub struct StdCancellationToken {
215 state: Arc<Mutex<State>>,
216 }
217
218 impl StdCancellationToken {}
219
220 impl CancellationToken for StdCancellationToken {
221 fn error_if_cancelled(&self) -> Result<(), crate::CancellationError> {
222 if self.state.lock().unwrap().cancelled {
223 Err(CancellationError)
224 } else {
225 Ok(())
226 }
227 }
228
229 fn on_cancellation<'a>(&self, func: &'a mut CancellationFunc<'a>) -> CancellationGuard<'a> {
230 let mut guard = self.state.lock().unwrap();
231 if guard.cancelled {
232 if let Some(func) = (&mut func.inner).take() {
233 (func)();
234 }
235 return CancellationGuard {
236 data: std::ptr::null(),
237 vtable: noop_vtable(),
238 func,
239 };
240 }
241
242 func.next = std::ptr::null();
243 func.prev = std::ptr::null();
244 if guard.first_func.is_null() {
245 guard.first_func = func.into_raw();
247 guard.last_func = func.into_raw();
248 } else {
249 unsafe {
250 let mut last = CancellationFunc::from_raw(guard.last_func);
252 last.next = func.into_raw();
253 func.prev = last.into_raw();
254 guard.last_func = func.into_raw();
255 }
256 }
257
258 CancellationGuard {
259 data: Arc::into_raw(self.state.clone()) as _,
260 vtable: std_cancellation_token_vtable(),
261 func,
262 }
263 }
264 }
265
266 pub struct StdCancellationTokenSource {
267 state: Arc<Mutex<State>>,
268 }
269
270 unsafe impl Send for StdCancellationTokenSource {}
271 unsafe impl Sync for StdCancellationTokenSource {}
272
273 impl StdCancellationTokenSource {
274 pub fn new() -> StdCancellationTokenSource {
275 StdCancellationTokenSource {
276 state: Arc::new(Mutex::new(State {
277 cancelled: false,
278 first_func: std::ptr::null(),
279 last_func: std::ptr::null(),
280 })),
281 }
282 }
283
284 pub fn token(&self) -> StdCancellationToken {
285 StdCancellationToken {
286 state: self.state.clone(),
287 }
288 }
289
290 pub fn cancel(&self) {
291 let mut guard = self.state.lock().unwrap();
292 if guard.cancelled {
293 return;
294 }
295 guard.cancelled = true;
296
297 while !guard.first_func.is_null() {
298 unsafe {
299 let mut first = CancellationFunc::from_raw(guard.first_func);
300 guard.first_func = first.next;
301 first.prev = std::ptr::null();
302 first.next = std::ptr::null();
303 if let Some(func) = first.inner.take() {
304 (func)();
305 }
306 }
307 }
308 guard.last_func = std::ptr::null();
309 }
310 }
311}
312
313pub mod utils {
315 use super::*;
316
317 pub fn wait_cancelled(token: &dyn CancellationToken) {
318 let mtx = std::sync::Mutex::new(false);
319 let cv = std::sync::Condvar::new();
320
321 let func = &mut || {
322 let mut guard = mtx.lock().unwrap();
323 *guard = true;
324 drop(guard);
325 cv.notify_all();
326 };
327 let mut wait_func = CancellationFunc::new(func);
328 let _guard = token.on_cancellation(&mut wait_func);
329
330 let mut cancelled = mtx.lock().unwrap();
331 while !*cancelled {
332 cancelled = cv.wait(cancelled).unwrap();
333 }
334 }
335
336 pub fn wait_cancelled_polled(token: &dyn CancellationToken) {
337 let is_cancelled = std::sync::atomic::AtomicBool::new(false);
338
339 let func = &mut || {
340 is_cancelled.store(true, std::sync::atomic::Ordering::Release);
341 };
342 let mut wait_func = CancellationFunc::new(func);
343 let _guard = token.on_cancellation(&mut wait_func);
344
345 while !is_cancelled.load(std::sync::atomic::Ordering::Acquire) {
346 std::thread::sleep(std::time::Duration::from_millis(1));
347 }
348 }
349
350 pub async fn await_cancelled(token: &dyn CancellationToken) {
351 use std::future::Future;
352 use std::pin::Pin;
353 use std::sync::Mutex;
354 use std::task::{Context, Poll, Waker};
355
356 struct CancelFut<'a> {
357 token: &'a dyn CancellationToken,
358 waker: &'a Mutex<Option<Waker>>,
359 }
360
361 impl<'a> Future for CancelFut<'a> {
362 type Output = ();
363
364 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
365 match self.token.error_if_cancelled() {
366 Ok(()) => {
367 let mut guard = self.waker.lock().unwrap();
368 *guard = Some(cx.waker().clone());
369
370 Poll::Pending
374 }
375 Err(_) => Poll::Ready(()),
376 }
377 }
378 }
379
380 let waker_store = Mutex::<Option<Waker>>::new(None);
383
384 let mut on_cancel = || {
385 let mut guard = waker_store.lock().unwrap();
386 if let Some(waker) = guard.take() {
387 waker.wake();
388 }
389 };
390 let mut wait_func = CancellationFunc::new(&mut on_cancel);
391 let _guard = token.on_cancellation(&mut wait_func);
392
393 let fut = CancelFut {
394 token,
395 waker: &waker_store,
396 };
397
398 fut.await
399 }
400}
401
#[cfg(test)]
mod tests {
    //! Tests for the std token, the blocking/async waiters and the
    //! thread-local token helpers.
    //!
    //! Most tests spawn a thread that cancels after a fixed sleep and then
    //! assert on elapsed wall-clock time, so they are timing-dependent.
    //! The `test_await_cancelled*` and `test_nested_cancellation` tests call
    //! `futures::executor::block_on`, so the `futures` crate must be
    //! available to the test build.
    use super::std_impl::*;
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    // Polling API only: the state flips from Ok to Err on `cancel`.
    #[test]
    fn simple_cancel() {
        let source = StdCancellationTokenSource::new();
        let token = source.token();

        assert!(token.error_if_cancelled().is_ok());
        source.cancel();
        assert!(token.error_if_cancelled().is_err());
    }

    // The registered callback runs on the cancelling thread and signals this
    // thread through a channel; the signal must not arrive before the
    // ~1 s sleep in the canceller.
    #[test]
    fn test_token() {
        let source = StdCancellationTokenSource::new();
        let token = source.token();

        let dyn_token: &dyn CancellationToken = &token;

        let (sender, receiver) = std::sync::mpsc::sync_channel(1);

        let start = Instant::now();

        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(1));
            source.cancel();
        });

        let mut func = || {
            sender.send(true).unwrap();
        };
        let mut cancel_func = CancellationFunc::new(&mut func);

        let _guard = dyn_token.on_cancellation(&mut cancel_func);

        let _ = receiver.recv();

        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_secs(1));
    }

    // Registering on an already-cancelled token must return promptly
    // (the 50 ms bound is an arbitrary "fast enough" threshold).
    #[test]
    fn test_wait_cancelled_immediately() {
        let source = StdCancellationTokenSource::new();
        source.cancel();
        let token = source.token();

        let dyn_token: &dyn CancellationToken = &token;

        let start = Instant::now();

        utils::wait_cancelled(dyn_token);

        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(50));
    }

    // Condvar-based waiter unblocks only after the delayed cancel.
    #[test]
    fn test_wait_cancelled() {
        let source = StdCancellationTokenSource::new();
        let token = source.token();

        let dyn_token: &dyn CancellationToken = &token;

        let start = Instant::now();

        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(1));
            source.cancel();
        });

        utils::wait_cancelled(dyn_token);

        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_secs(1));
    }

    // Same as above, for the sleep-polling waiter.
    #[test]
    fn test_wait_cancelled_polled() {
        let source = StdCancellationTokenSource::new();
        let token = source.token();

        let dyn_token: &dyn CancellationToken = &token;

        let start = Instant::now();

        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(1));
            source.cancel();
        });

        utils::wait_cancelled_polled(dyn_token);

        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_secs(1));
    }

    // Async waiter on an already-cancelled token completes on the first poll.
    #[test]
    fn test_await_cancelled_immediately() {
        futures::executor::block_on(async {
            let source = StdCancellationTokenSource::new();
            source.cancel();
            let token = source.token();
            let dyn_token: &dyn CancellationToken = &token;

            let start = Instant::now();

            utils::await_cancelled(dyn_token).await;

            let elapsed = start.elapsed();
            assert!(elapsed < Duration::from_millis(50));
        });
    }

    // Async waiter is woken by the cancellation callback from another thread.
    #[test]
    fn test_await_cancelled() {
        futures::executor::block_on(async {
            let source = StdCancellationTokenSource::new();
            let token = source.token();

            let dyn_token: &dyn CancellationToken = &token;

            let start = Instant::now();

            std::thread::spawn(move || {
                std::thread::sleep(Duration::from_secs(1));
                source.cancel();
            });

            utils::await_cancelled(dyn_token).await;

            let elapsed = start.elapsed();
            assert!(elapsed >= Duration::from_secs(1));
        });
    }

    // Registers four callbacks, drops one or two of their guards (every pair
    // of positions, including the head and tail), then checks that only the
    // callbacks whose guards are still alive ran.
    //
    // NOTE(review): 16 iterations each sleep 2 s, so this takes roughly 32 s.
    // It also assumes the spawned 1 s cancel has finished within the 2 s sleep.
    // Nothing is registered after a guard is dropped, so tail-pointer upkeep
    // on unlink is not exercised here.
    #[test]
    fn unregister_before_cancel() {
        for token1_to_drop in 0..4 {
            for token2_to_drop in 0..4 {
                let source = StdCancellationTokenSource::new();
                let tokens = (0..4).map(|_| source.token()).collect::<Vec<_>>();

                let counter = Arc::new(AtomicUsize::new(0));

                std::thread::spawn(move || {
                    std::thread::sleep(Duration::from_secs(1));
                    source.cancel();
                });

                let mut func_1 = || {
                    counter.fetch_add(1, Ordering::SeqCst);
                };
                let mut func_2 = || {
                    counter.fetch_add(1, Ordering::SeqCst);
                };
                let mut func_3 = || {
                    counter.fetch_add(1, Ordering::SeqCst);
                };
                let mut func_4 = || {
                    counter.fetch_add(1, Ordering::SeqCst);
                };
                let mut cancel_func_1 = CancellationFunc::new(&mut func_1);
                let mut cancel_func_2 = CancellationFunc::new(&mut func_2);
                let mut cancel_func_3 = CancellationFunc::new(&mut func_3);
                let mut cancel_func_4 = CancellationFunc::new(&mut func_4);

                let mut guards = vec![None, None, None, None];
                guards[0] = Some(tokens[0].on_cancellation(&mut cancel_func_1));
                guards[1] = Some(tokens[1].on_cancellation(&mut cancel_func_2));
                guards[2] = Some(tokens[2].on_cancellation(&mut cancel_func_3));
                guards[3] = Some(tokens[3].on_cancellation(&mut cancel_func_4));

                // Overwriting with `None` drops the guard, i.e. unregisters.
                guards[token1_to_drop] = None;
                guards[token2_to_drop] = None;

                std::thread::sleep(Duration::from_secs(2));
                // Same index twice means only one callback was unregistered.
                let expected = if token1_to_drop == token2_to_drop {
                    3
                } else {
                    2
                };
                assert_eq!(counter.load(Ordering::SeqCst), expected);
            }
        }
    }

    // The token installed with `with_cancellation_token` is the one handed out
    // by `with_current_cancellation_token` inside it.
    #[test]
    fn test_thread_local_cancellation() {
        let source = StdCancellationTokenSource::new();
        let token = source.token();
        let dyn_token: &dyn CancellationToken = &token;

        let start = Instant::now();

        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(1));
            source.cancel();
        });

        with_cancellation_token(dyn_token, || {
            with_current_cancellation_token(|token| {
                utils::wait_cancelled(token);
            })
        });

        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_secs(1));
    }

    // Chains three sources: cancelling the outer one cancels `next_source`
    // from its callback, which in turn cancels `third_source`. The innermost
    // async wait observes the third token. Callbacks run on the thread that
    // calls `cancel`, so the whole chain runs on the spawned thread.
    #[test]
    fn test_nested_cancellation() {
        let source = StdCancellationTokenSource::new();
        let token = source.token();
        let dyn_token: &dyn CancellationToken = &token;

        let start = Instant::now();

        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(1));
            source.cancel();
        });

        with_cancellation_token(dyn_token, || {
            let next_source = StdCancellationTokenSource::new();
            let next_token = next_source.token();

            let mut cancel_func = || {
                next_source.cancel();
            };
            let mut cancel_func = CancellationFunc::new(&mut cancel_func);
            let _guard =
                with_current_cancellation_token(|token| token.on_cancellation(&mut cancel_func));

            with_cancellation_token(&next_token, || {
                let third_source = StdCancellationTokenSource::new();
                let third_token = third_source.token();

                let mut cancel_func = || {
                    third_source.cancel();
                };
                let mut cancel_func = CancellationFunc::new(&mut cancel_func);
                let _guard = with_current_cancellation_token(|token| {
                    token.on_cancellation(&mut cancel_func)
                });

                with_cancellation_token(&third_token, || {
                    with_current_cancellation_token(|token| {
                        futures::executor::block_on(async {
                            utils::await_cancelled(token).await;
                        });
                    });
                });
            });
        });

        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_secs(1));
    }
}
667}