Skip to main content

burn_central_core/experiment/
cancellation.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex};
3
4/// A trait representing a cancellable task or operation. Implementors should ensure that `cancel` is idempotent and thread-safe, and that `is_cancelled` accurately reflects the cancellation state.
5pub trait Cancellable: Send + Sync {
6    /// Cancel the task. Should be idempotent and thread-safe.
7    fn cancel(&self);
8    /// Check if the task has been cancelled. Should be thread-safe.
9    fn is_cancelled(&self) -> bool;
10}
11
12type CancellableRef = Arc<dyn Cancellable>;
13
14/// A cancellation token that can be shared across tasks. When cancelled, it will cancel all linked child tasks/tokens. Thread-safe and can be cloned.
15#[derive(Clone, Default)]
16pub struct CancelToken {
17    inner: Arc<Inner>,
18}
19
20#[derive(Default)]
21struct Inner {
22    cancelled: AtomicBool,
23    children: Mutex<Vec<CancellableRef>>,
24}
25
26impl CancelToken {
27    /// Create a new cancellation token that is not cancelled and has no children.
28    pub fn new() -> Self {
29        Self {
30            inner: Arc::new(Inner {
31                cancelled: AtomicBool::new(false),
32                children: Mutex::new(Vec::new()),
33            }),
34        }
35    }
36
37    /// Check if this token has been cancelled. Thread-safe.
38    pub fn is_cancelled(&self) -> bool {
39        self.inner.cancelled.load(Ordering::Acquire)
40    }
41
42    /// Attach a child task/token to be cancelled when this token is cancelled.
43    /// If already cancelled, child is cancelled immediately.
44    pub fn link<T: Cancellable + 'static>(&self, child: T) {
45        let child = Arc::new(child);
46        if self.is_cancelled() {
47            child.cancel();
48            return;
49        }
50
51        let mut children = self.inner.children.lock().unwrap();
52
53        if self.is_cancelled() {
54            drop(children);
55            child.cancel();
56            return;
57        }
58
59        children.push(child);
60    }
61
62    /// Cancel this token and all currently-linked children.
63    pub fn cancel(&self) {
64        if self.inner.cancelled.swap(true, Ordering::AcqRel) {
65            return;
66        }
67
68        let children = {
69            let mut children = self.inner.children.lock().unwrap();
70            std::mem::take(&mut *children)
71        };
72
73        for c in children {
74            c.cancel();
75        }
76    }
77
78    /// Create a new cancellable task/token of type T, link it to this token, and return it.
79    /// See [Self::link] for more details.
80    pub fn into_linked<T: Cancellable + Default + Clone + 'static>(&self) -> T {
81        let merged = T::default();
82        self.link(merged.clone());
83        merged
84    }
85
86    /// Link an existing cancellable task/token of type T to this token and return it.
87    /// See [Self::link] for more details.
88    pub fn linked<T: Cancellable + Clone + 'static>(&self, child: T) -> T {
89        self.link(child.clone());
90        child
91    }
92}
93
94impl Cancellable for CancelToken {
95    fn cancel(&self) {
96        CancelToken::cancel(self)
97    }
98    fn is_cancelled(&self) -> bool {
99        self.is_cancelled()
100    }
101}
102
103impl<T: Cancellable> Cancellable for Arc<T> {
104    fn cancel(&self) {
105        self.as_ref().cancel()
106    }
107    fn is_cancelled(&self) -> bool {
108        self.as_ref().is_cancelled()
109    }
110}
111
112#[cfg(test)]
113mod tests {
114    use super::*;
115    use std::sync::Arc;
116    use std::sync::atomic::AtomicUsize;
117    use std::thread;
118    use std::time::Duration;
119
120    struct TestCancellable {
121        cancelled: AtomicBool,
122        cancel_count: AtomicUsize,
123    }
124
125    impl TestCancellable {
126        fn new() -> Self {
127            Self {
128                cancelled: AtomicBool::new(false),
129                cancel_count: AtomicUsize::new(0),
130            }
131        }
132
133        fn cancel_count(&self) -> usize {
134            self.cancel_count.load(Ordering::Relaxed)
135        }
136    }
137
138    impl Cancellable for TestCancellable {
139        fn cancel(&self) {
140            if !self.cancelled.swap(true, Ordering::AcqRel) {
141                self.cancel_count.fetch_add(1, Ordering::AcqRel);
142            }
143        }
144
145        fn is_cancelled(&self) -> bool {
146            self.cancelled.load(Ordering::Acquire)
147        }
148    }
149
150    #[test]
151    fn test_cancel_token() {
152        let token = CancelToken::new();
153
154        token.cancel();
155
156        assert!(token.is_cancelled());
157    }
158
159    #[test]
160    fn test_cancel_children() {
161        let token = CancelToken::new();
162        let child1 = Arc::new(TestCancellable::new());
163        let child2 = Arc::new(TestCancellable::new());
164        token.link(child1.clone());
165        token.link(child2.clone());
166
167        token.cancel();
168
169        assert!(child1.is_cancelled());
170        assert!(child2.is_cancelled());
171    }
172
173    #[test]
174    fn test_idempotent_cancel() {
175        let token = CancelToken::new();
176        let child = Arc::new(TestCancellable::new());
177        token.link(child.clone());
178        token.cancel();
179
180        token.cancel();
181
182        assert!(token.is_cancelled());
183        assert!(child.is_cancelled());
184        assert_eq!(child.cancel_count(), 1);
185    }
186
187    #[test]
188    fn test_concurrent_cancel() {
189        let token = CancelToken::new();
190        let child = Arc::new(TestCancellable::new());
191        token.link(child.clone());
192        let token_clone = token.clone();
193        let handle = thread::spawn(move || {
194            thread::sleep(Duration::from_millis(50));
195            token_clone.cancel();
196        });
197
198        token.cancel();
199
200        handle.join().unwrap();
201        assert!(token.is_cancelled());
202        assert!(child.is_cancelled());
203        assert_eq!(child.cancel_count(), 1); // Should only be cancelled once
204    }
205
206    #[test]
207    fn test_link_after_cancel() {
208        let token = CancelToken::new();
209        token.cancel();
210
211        let child = Arc::new(TestCancellable::new());
212        token.link(child.clone());
213
214        assert!(child.is_cancelled());
215    }
216}