min_cancel_token/
lib.rs

/// Error returned when a cancellation token reports that the flow of
/// execution has been cancelled.
///
/// The type carries no payload. It implements [`std::error::Error`] so that it
/// can be propagated with `?` into `Box<dyn std::error::Error>` and similar
/// error containers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CancellationError;

impl std::fmt::Display for CancellationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation was cancelled")
    }
}

impl std::error::Error for CancellationError {}
4
5pub struct CancellationGuardVtable {
6    drop: unsafe fn(data: *const (), func: &mut CancellationFunc),
7}
8
9/// Guard for a registered cancellation function
10///
11/// If the guard is dropped, the cancellation handler will be removed
12pub struct CancellationGuard<'a> {
13    func: &'a mut CancellationFunc<'a>,
14    data: *const (),
15    vtable: &'static CancellationGuardVtable,
16}
17
18impl<'a> Drop for CancellationGuard<'a> {
19    fn drop(&mut self) {
20        unsafe { (self.vtable.drop)(self.data, self.func) };
21    }
22}
23
/// A function that will be executed on cancellation.
///
/// This is a small wrapper around a `FnMut` closure which additionally carries
/// `prev`/`next` links, so that several registered closures can be chained
/// into an intrusive doubly linked list without allocating.
pub struct CancellationFunc<'a> {
    /// The wrapped closure; `None` once it has been taken out to be invoked.
    inner: Option<&'a mut (dyn FnMut() + Sync)>,
    /// Type-erased pointer to the previous list node, or null.
    prev: *const (),
    /// Type-erased pointer to the next list node, or null.
    next: *const (),
}

impl<'a> CancellationFunc<'a> {
    /// Turns a pointer produced by [`CancellationFunc::into_raw`] back into a
    /// mutable reference.
    ///
    /// # Safety
    ///
    /// `raw` must be non-null, must originate from `into_raw` on a
    /// `CancellationFunc` that is still alive, and the caller must make sure
    /// that no other reference to that node is used while the returned
    /// reference is in use.
    unsafe fn from_raw(raw: *const ()) -> &'a mut Self {
        // A plain pointer cast and reborrow replaces the previous `transmute`
        // from `*const Self` to `&mut Self`.
        &mut *(raw as *mut Self)
    }

    /// Returns a type-erased pointer to this node for storage in a list.
    fn into_raw(&mut self) -> *const () {
        // Go through `*mut Self` first: the pointer is later written through,
        // so it must not be derived from a shared (`*const`) view of `self`.
        self as *mut Self as *const ()
    }

    /// Wraps `func` into an unlinked `CancellationFunc`.
    pub fn new(func: &'a mut (dyn FnMut() + Sync)) -> Self {
        Self {
            inner: Some(func),
            prev: std::ptr::null(),
            next: std::ptr::null(),
        }
    }
}
52
53/// A `CancellationToken` provides information whether a flow of execution is expected
54/// to be cancelled.
55///
56/// There are 2 ways to interact with `CancellationToken`:
57/// 1. The token can be queried on whether the flow is cancelled
58/// 2. A callback can be registered if thhe flow is cancelled.
59pub trait CancellationToken {
60    /// Performs a one-time check whether the flow of execution is cancelled.
61    ///
62    /// Returns an error if cancellation is initiated
63    fn error_if_cancelled(&self) -> Result<(), CancellationError>;
64
65    /// Registers a cancellation handler, which will be invoked when the execution flow
66    /// is cancelled.
67    /// The cancellation handler can be called from any thread which initiates the cancellation.
68    /// If the flow is already cancelled when this function is called, the cancellation
69    /// handler will be called synchronously.
70    /// The function returns a guard which can be used to unregister the cancellation handler.
71    /// After the guard is dropped, the handler is guaranteed not be called anymore.
72    fn on_cancellation<'a>(&self, func: &'a mut CancellationFunc<'a>) -> CancellationGuard<'a>;
73}
74
75thread_local! {
76    pub static CURRENT_CANCELLATION_TOKEN: std::cell::RefCell<Option<&'static dyn CancellationToken>> = std::cell::RefCell::new(None);
77}
78
79/// Executes a function that gets access to the current cancellation token
80pub fn with_current_cancellation_token<R>(func: impl FnOnce(&dyn CancellationToken) -> R) -> R {
81    CURRENT_CANCELLATION_TOKEN.with(|token| {
82        let x = &*token.borrow();
83        match x {
84            Some(token) => func(*token),
85            None => func(&UncancellableToken::default()),
86        }
87    })
88}
89
90/// Replaces the currently active (thread-local) cancellation token with the provided one,
91/// and executes the given function.
92/// Once the scope ends, the current cancellation token will be reset to the previous one.
93pub fn with_cancellation_token<'a, R>(
94    token: &'a dyn CancellationToken,
95    func: impl FnOnce() -> R,
96) -> R {
97    // Note that `'static` is a hack here to avoid having to specify the outer (unknown) lifetimes.
98    // Since we are guaranteed to reset the token before the lifetime ends and
99    // don't copy it anywhere else, this is ok.
100    struct RevertToOldTokenGuard<'a> {
101        prev: Option<&'static dyn CancellationToken>,
102        storage: &'a std::cell::RefCell<Option<&'static dyn CancellationToken>>,
103    }
104
105    impl<'a> Drop for RevertToOldTokenGuard<'a> {
106        fn drop(&mut self) {
107            let mut guard = self.storage.borrow_mut();
108            *guard = self.prev;
109        }
110    }
111
112    CURRENT_CANCELLATION_TOKEN.with(|storage| {
113        let mut guard = storage.borrow_mut();
114        let static_token: &'static dyn CancellationToken = unsafe { std::mem::transmute(token) };
115
116        let prev = std::mem::replace(&mut *guard, Some(static_token));
117        drop(guard);
118
119        // Revert to the last token once the function ends
120        // This guard makes sure we are adhering to to the non static lifetime
121        let _revert_guard = RevertToOldTokenGuard { prev, storage };
122
123        func()
124    })
125}
126
/// A `CancellationToken` implementation that never reports the cancelled
/// state.
///
/// `with_current_cancellation_token` hands out a value of this type when no
/// token is installed.
#[derive(Debug, Default)]
pub struct UncancellableToken {}
131
132fn noop_drop(_data: *const (), _func: &mut CancellationFunc) {}
133
134fn noop_vtable() -> &'static CancellationGuardVtable {
135    &CancellationGuardVtable { drop: noop_drop }
136}
137
138impl CancellationToken for UncancellableToken {
139    fn error_if_cancelled(&self) -> Result<(), CancellationError> {
140        Ok(())
141    }
142
143    fn on_cancellation<'a>(&self, func: &'a mut CancellationFunc<'a>) -> CancellationGuard<'a> {
144        CancellationGuard {
145            func,
146            data: std::ptr::null(),
147            vtable: noop_vtable(),
148        }
149    }
150}
151
152pub mod std_impl {
153    use super::*;
154    use std::sync::{Arc, Mutex};
155
156    /// Inner state of the `std` `CancellationToken` implementation
157    struct State {
158        /// Whether cancellation was initiated
159        pub cancelled: bool,
160        /// Linked list of cancellation callbacks
161        pub first_func: *const (),
162        /// Linked list of cancellation callbacks
163        pub last_func: *const (),
164    }
165
166    fn std_cancellation_token_drop(data: *const (), func: &mut CancellationFunc) {
167        let state: Arc<Mutex<State>> = unsafe { Arc::from_raw(data as _) };
168
169        let mut guard = state.lock().unwrap();
170        if guard.cancelled {
171            // The token was already cancelled and the callback was called
172            // This also means the function should already have been removed from
173            // the linked list
174            assert!(func.prev.is_null());
175            assert!(func.next.is_null());
176            return;
177        }
178
179        unsafe {
180            if func.prev.is_null() {
181                // This must be the first function that is registered
182                guard.first_func = func.next;
183                if !guard.first_func.is_null() {
184                    let mut first = CancellationFunc::from_raw(guard.first_func);
185                    first.prev = std::ptr::null();
186                } else {
187                    // The list is drained
188                    guard.last_func = std::ptr::null();
189                }
190                func.next = std::ptr::null();
191            } else {
192                // There exists a previous function, since its not the first
193                let mut prev = CancellationFunc::from_raw(func.prev);
194                prev.next = func.next;
195                if !func.next.is_null() {
196                    let mut next = CancellationFunc::from_raw(func.next);
197                    next.prev = func.prev;
198                }
199
200                func.next = std::ptr::null();
201                func.prev = std::ptr::null();
202            }
203        }
204
205        std::mem::drop(data);
206    }
207
208    fn std_cancellation_token_vtable() -> &'static CancellationGuardVtable {
209        &CancellationGuardVtable {
210            drop: std_cancellation_token_drop,
211        }
212    }
213
214    pub struct StdCancellationToken {
215        state: Arc<Mutex<State>>,
216    }
217
218    impl StdCancellationToken {}
219
220    impl CancellationToken for StdCancellationToken {
221        fn error_if_cancelled(&self) -> Result<(), crate::CancellationError> {
222            if self.state.lock().unwrap().cancelled {
223                Err(CancellationError)
224            } else {
225                Ok(())
226            }
227        }
228
229        fn on_cancellation<'a>(&self, func: &'a mut CancellationFunc<'a>) -> CancellationGuard<'a> {
230            let mut guard = self.state.lock().unwrap();
231            if guard.cancelled {
232                if let Some(func) = (&mut func.inner).take() {
233                    (func)();
234                }
235                return CancellationGuard {
236                    data: std::ptr::null(),
237                    vtable: noop_vtable(),
238                    func,
239                };
240            }
241
242            func.next = std::ptr::null();
243            func.prev = std::ptr::null();
244            if guard.first_func.is_null() {
245                // Only function in the list
246                guard.first_func = func.into_raw();
247                guard.last_func = func.into_raw();
248            } else {
249                unsafe {
250                    // This must exist, since its not the only function in the list
251                    let mut last = CancellationFunc::from_raw(guard.last_func);
252                    last.next = func.into_raw();
253                    func.prev = last.into_raw();
254                    guard.last_func = func.into_raw();
255                }
256            }
257
258            CancellationGuard {
259                data: Arc::into_raw(self.state.clone()) as _,
260                vtable: std_cancellation_token_vtable(),
261                func,
262            }
263        }
264    }
265
266    pub struct StdCancellationTokenSource {
267        state: Arc<Mutex<State>>,
268    }
269
270    unsafe impl Send for StdCancellationTokenSource {}
271    unsafe impl Sync for StdCancellationTokenSource {}
272
273    impl StdCancellationTokenSource {
274        pub fn new() -> StdCancellationTokenSource {
275            StdCancellationTokenSource {
276                state: Arc::new(Mutex::new(State {
277                    cancelled: false,
278                    first_func: std::ptr::null(),
279                    last_func: std::ptr::null(),
280                })),
281            }
282        }
283
284        pub fn token(&self) -> StdCancellationToken {
285            StdCancellationToken {
286                state: self.state.clone(),
287            }
288        }
289
290        pub fn cancel(&self) {
291            let mut guard = self.state.lock().unwrap();
292            if guard.cancelled {
293                return;
294            }
295            guard.cancelled = true;
296
297            while !guard.first_func.is_null() {
298                unsafe {
299                    let mut first = CancellationFunc::from_raw(guard.first_func);
300                    guard.first_func = first.next;
301                    first.prev = std::ptr::null();
302                    first.next = std::ptr::null();
303                    if let Some(func) = first.inner.take() {
304                        (func)();
305                    }
306                }
307            }
308            guard.last_func = std::ptr::null();
309        }
310    }
311}
312
313/// Utilities for working with cancellation tokens
314pub mod utils {
315    use super::*;
316
317    pub fn wait_cancelled(token: &dyn CancellationToken) {
318        let mtx = std::sync::Mutex::new(false);
319        let cv = std::sync::Condvar::new();
320
321        let func = &mut || {
322            let mut guard = mtx.lock().unwrap();
323            *guard = true;
324            drop(guard);
325            cv.notify_all();
326        };
327        let mut wait_func = CancellationFunc::new(func);
328        let _guard = token.on_cancellation(&mut wait_func);
329
330        let mut cancelled = mtx.lock().unwrap();
331        while !*cancelled {
332            cancelled = cv.wait(cancelled).unwrap();
333        }
334    }
335
336    pub fn wait_cancelled_polled(token: &dyn CancellationToken) {
337        let is_cancelled = std::sync::atomic::AtomicBool::new(false);
338
339        let func = &mut || {
340            is_cancelled.store(true, std::sync::atomic::Ordering::Release);
341        };
342        let mut wait_func = CancellationFunc::new(func);
343        let _guard = token.on_cancellation(&mut wait_func);
344
345        while !is_cancelled.load(std::sync::atomic::Ordering::Acquire) {
346            std::thread::sleep(std::time::Duration::from_millis(1));
347        }
348    }
349
350    pub async fn await_cancelled(token: &dyn CancellationToken) {
351        use std::future::Future;
352        use std::pin::Pin;
353        use std::sync::Mutex;
354        use std::task::{Context, Poll, Waker};
355
356        struct CancelFut<'a> {
357            token: &'a dyn CancellationToken,
358            waker: &'a Mutex<Option<Waker>>,
359        }
360
361        impl<'a> Future for CancelFut<'a> {
362            type Output = ();
363
364            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
365                match self.token.error_if_cancelled() {
366                    Ok(()) => {
367                        let mut guard = self.waker.lock().unwrap();
368                        *guard = Some(cx.waker().clone());
369
370                        // TODO: Theres a race here, and the waker might just have been
371                        // installed after the token was cancelled
372
373                        Poll::Pending
374                    }
375                    Err(_) => Poll::Ready(()),
376                }
377            }
378        }
379
380        // TODO: A Mutex requires a heap allocation, and is probably not required
381        // here. Something like `AtomicWaker` should work.
382        let waker_store = Mutex::<Option<Waker>>::new(None);
383
384        let mut on_cancel = || {
385            let mut guard = waker_store.lock().unwrap();
386            if let Some(waker) = guard.take() {
387                waker.wake();
388            }
389        };
390        let mut wait_func = CancellationFunc::new(&mut on_cancel);
391        let _guard = token.on_cancellation(&mut wait_func);
392
393        let fut = CancelFut {
394            token,
395            waker: &waker_store,
396        };
397
398        fut.await
399    }
400}
401
402#[cfg(test)]
403mod tests {
404    use super::std_impl::*;
405    use super::*;
406    use std::sync::atomic::{AtomicUsize, Ordering};
407    use std::sync::Arc;
408    use std::time::{Duration, Instant};
409
410    #[test]
411    fn simple_cancel() {
412        let source = StdCancellationTokenSource::new();
413        let token = source.token();
414
415        assert!(token.error_if_cancelled().is_ok());
416        source.cancel();
417        assert!(token.error_if_cancelled().is_err());
418    }
419
420    #[test]
421    fn test_token() {
422        let source = StdCancellationTokenSource::new();
423        let token = source.token();
424
425        let dyn_token: &dyn CancellationToken = &token;
426
427        let (sender, receiver) = std::sync::mpsc::sync_channel(1);
428
429        let start = Instant::now();
430
431        std::thread::spawn(move || {
432            std::thread::sleep(Duration::from_secs(1));
433            source.cancel();
434        });
435
436        let mut func = || {
437            sender.send(true).unwrap();
438        };
439        let mut cancel_func = CancellationFunc::new(&mut func);
440
441        let _guard = dyn_token.on_cancellation(&mut cancel_func);
442
443        let _ = receiver.recv();
444
445        let elapsed = start.elapsed();
446        assert!(elapsed >= Duration::from_secs(1));
447    }
448
449    #[test]
450    fn test_wait_cancelled_immediately() {
451        let source = StdCancellationTokenSource::new();
452        source.cancel();
453        let token = source.token();
454
455        let dyn_token: &dyn CancellationToken = &token;
456
457        let start = Instant::now();
458
459        utils::wait_cancelled(dyn_token);
460
461        let elapsed = start.elapsed();
462        assert!(elapsed < Duration::from_millis(50));
463    }
464
465    #[test]
466    fn test_wait_cancelled() {
467        let source = StdCancellationTokenSource::new();
468        let token = source.token();
469
470        let dyn_token: &dyn CancellationToken = &token;
471
472        let start = Instant::now();
473
474        std::thread::spawn(move || {
475            std::thread::sleep(Duration::from_secs(1));
476            source.cancel();
477        });
478
479        utils::wait_cancelled(dyn_token);
480
481        let elapsed = start.elapsed();
482        assert!(elapsed >= Duration::from_secs(1));
483    }
484
485    #[test]
486    fn test_wait_cancelled_polled() {
487        let source = StdCancellationTokenSource::new();
488        let token = source.token();
489
490        let dyn_token: &dyn CancellationToken = &token;
491
492        let start = Instant::now();
493
494        std::thread::spawn(move || {
495            std::thread::sleep(Duration::from_secs(1));
496            source.cancel();
497        });
498
499        utils::wait_cancelled_polled(dyn_token);
500
501        let elapsed = start.elapsed();
502        assert!(elapsed >= Duration::from_secs(1));
503    }
504
505    #[test]
506    fn test_await_cancelled_immediately() {
507        futures::executor::block_on(async {
508            let source = StdCancellationTokenSource::new();
509            source.cancel();
510            let token = source.token();
511            let dyn_token: &dyn CancellationToken = &token;
512
513            let start = Instant::now();
514
515            utils::await_cancelled(dyn_token).await;
516
517            let elapsed = start.elapsed();
518            assert!(elapsed < Duration::from_millis(50));
519        });
520    }
521
522    #[test]
523    fn test_await_cancelled() {
524        futures::executor::block_on(async {
525            let source = StdCancellationTokenSource::new();
526            let token = source.token();
527
528            let dyn_token: &dyn CancellationToken = &token;
529
530            let start = Instant::now();
531
532            std::thread::spawn(move || {
533                std::thread::sleep(Duration::from_secs(1));
534                source.cancel();
535            });
536
537            utils::await_cancelled(dyn_token).await;
538
539            let elapsed = start.elapsed();
540            assert!(elapsed >= Duration::from_secs(1));
541        });
542    }
543
544    #[test]
545    fn unregister_before_cancel() {
546        for token1_to_drop in 0..4 {
547            for token2_to_drop in 0..4 {
548                let source = StdCancellationTokenSource::new();
549                let tokens = (0..4).map(|_| source.token()).collect::<Vec<_>>();
550
551                let counter = Arc::new(AtomicUsize::new(0));
552
553                std::thread::spawn(move || {
554                    std::thread::sleep(Duration::from_secs(1));
555                    source.cancel();
556                });
557
558                let mut func_1 = || {
559                    counter.fetch_add(1, Ordering::SeqCst);
560                };
561                let mut func_2 = || {
562                    counter.fetch_add(1, Ordering::SeqCst);
563                };
564                let mut func_3 = || {
565                    counter.fetch_add(1, Ordering::SeqCst);
566                };
567                let mut func_4 = || {
568                    counter.fetch_add(1, Ordering::SeqCst);
569                };
570                let mut cancel_func_1 = CancellationFunc::new(&mut func_1);
571                let mut cancel_func_2 = CancellationFunc::new(&mut func_2);
572                let mut cancel_func_3 = CancellationFunc::new(&mut func_3);
573                let mut cancel_func_4 = CancellationFunc::new(&mut func_4);
574
575                let mut guards = vec![None, None, None, None];
576                guards[0] = Some(tokens[0].on_cancellation(&mut cancel_func_1));
577                guards[1] = Some(tokens[1].on_cancellation(&mut cancel_func_2));
578                guards[2] = Some(tokens[2].on_cancellation(&mut cancel_func_3));
579                guards[3] = Some(tokens[3].on_cancellation(&mut cancel_func_4));
580
581                guards[token1_to_drop] = None;
582                guards[token2_to_drop] = None;
583
584                std::thread::sleep(Duration::from_secs(2));
585                let expected = if token1_to_drop == token2_to_drop {
586                    3
587                } else {
588                    2
589                };
590                assert_eq!(counter.load(Ordering::SeqCst), expected);
591            }
592        }
593    }
594
595    #[test]
596    fn test_thread_local_cancellation() {
597        let source = StdCancellationTokenSource::new();
598        let token = source.token();
599        let dyn_token: &dyn CancellationToken = &token;
600
601        let start = Instant::now();
602
603        std::thread::spawn(move || {
604            std::thread::sleep(Duration::from_secs(1));
605            source.cancel();
606        });
607
608        with_cancellation_token(dyn_token, || {
609            with_current_cancellation_token(|token| {
610                utils::wait_cancelled(token);
611            })
612        });
613
614        let elapsed = start.elapsed();
615        assert!(elapsed >= Duration::from_secs(1));
616    }
617
618    #[test]
619    fn test_nested_cancellation() {
620        let source = StdCancellationTokenSource::new();
621        let token = source.token();
622        let dyn_token: &dyn CancellationToken = &token;
623
624        let start = Instant::now();
625
626        std::thread::spawn(move || {
627            std::thread::sleep(Duration::from_secs(1));
628            source.cancel();
629        });
630
631        with_cancellation_token(dyn_token, || {
632            let next_source = StdCancellationTokenSource::new();
633            let next_token = next_source.token();
634
635            let mut cancel_func = || {
636                next_source.cancel();
637            };
638            let mut cancel_func = CancellationFunc::new(&mut cancel_func);
639            let _guard =
640                with_current_cancellation_token(|token| token.on_cancellation(&mut cancel_func));
641
642            with_cancellation_token(&next_token, || {
643                let third_source = StdCancellationTokenSource::new();
644                let third_token = third_source.token();
645
646                let mut cancel_func = || {
647                    third_source.cancel();
648                };
649                let mut cancel_func = CancellationFunc::new(&mut cancel_func);
650                let _guard = with_current_cancellation_token(|token| {
651                    token.on_cancellation(&mut cancel_func)
652                });
653
654                with_cancellation_token(&third_token, || {
655                    with_current_cancellation_token(|token| {
656                        futures::executor::block_on(async {
657                            utils::await_cancelled(token).await;
658                        });
659                    });
660                });
661            });
662        });
663
664        let elapsed = start.elapsed();
665        assert!(elapsed >= Duration::from_secs(1));
666    }
667}