// burn_central_core/experiment/cancellation.rs
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};

/// Something that can be asked to stop.
///
/// Implementors must be `Send + Sync`: a [`CancelToken`] stores linked
/// children behind an `Arc` and calls [`Cancellable::cancel`] from whichever
/// thread calls [`CancelToken::cancel`] (or [`CancelToken::link`] on an
/// already-cancelled token).
pub trait Cancellable: Send + Sync {
    /// Requests cancellation.
    ///
    /// A [`CancelToken`] invokes this once per `link` call, but the same
    /// object may be linked more than once or cancelled directly, so
    /// implementations should tolerate repeated calls.
    fn cancel(&self);

    /// Returns `true` once cancellation has been requested.
    fn is_cancelled(&self) -> bool;
}

/// Type-erased, shared handle to a linked child.
type CancellableRef = Arc<dyn Cancellable>;

/// A cloneable cancellation flag that propagates to linked children.
///
/// Clones share the same state: cancelling any clone cancels all of them and
/// every [`Cancellable`] linked through any of them.
#[derive(Clone, Default)]
pub struct CancelToken {
    inner: Arc<Inner>,
}

/// State shared by every clone of a [`CancelToken`].
#[derive(Default)]
struct Inner {
    /// Becomes `true` on the first `CancelToken::cancel` call and never resets.
    cancelled: AtomicBool,
    /// Children awaiting cancellation; drained by `CancelToken::cancel`.
    children: Mutex<Vec<CancellableRef>>,
}

// `Inner` holds `dyn Cancellable` trait objects, which are not `Debug`, so the
// impl is written by hand. It reports only the flag (no locking, so it can
// never block or panic on a poisoned mutex).
impl fmt::Debug for CancelToken {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CancelToken")
            .field("cancelled", &self.inner.cancelled.load(Ordering::Acquire))
            .finish_non_exhaustive()
    }
}

26impl CancelToken {
27 pub fn new() -> Self {
29 Self {
30 inner: Arc::new(Inner {
31 cancelled: AtomicBool::new(false),
32 children: Mutex::new(Vec::new()),
33 }),
34 }
35 }
36
37 pub fn is_cancelled(&self) -> bool {
39 self.inner.cancelled.load(Ordering::Acquire)
40 }
41
42 pub fn link<T: Cancellable + 'static>(&self, child: T) {
45 let child = Arc::new(child);
46 if self.is_cancelled() {
47 child.cancel();
48 return;
49 }
50
51 let mut children = self.inner.children.lock().unwrap();
52
53 if self.is_cancelled() {
54 drop(children);
55 child.cancel();
56 return;
57 }
58
59 children.push(child);
60 }
61
62 pub fn cancel(&self) {
64 if self.inner.cancelled.swap(true, Ordering::AcqRel) {
65 return;
66 }
67
68 let children = {
69 let mut children = self.inner.children.lock().unwrap();
70 std::mem::take(&mut *children)
71 };
72
73 for c in children {
74 c.cancel();
75 }
76 }
77
78 pub fn into_linked<T: Cancellable + Default + Clone + 'static>(&self) -> T {
81 let merged = T::default();
82 self.link(merged.clone());
83 merged
84 }
85
86 pub fn linked<T: Cancellable + Clone + 'static>(&self, child: T) -> T {
89 self.link(child.clone());
90 child
91 }
92}
93
94impl Cancellable for CancelToken {
95 fn cancel(&self) {
96 CancelToken::cancel(self)
97 }
98 fn is_cancelled(&self) -> bool {
99 self.is_cancelled()
100 }
101}
102
103impl<T: Cancellable> Cancellable for Arc<T> {
104 fn cancel(&self) {
105 self.as_ref().cancel()
106 }
107 fn is_cancelled(&self) -> bool {
108 self.as_ref().is_cancelled()
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Test double that counts *every* `cancel` call, so a token notifying the
    /// same child more than once shows up as a count above 1.
    #[derive(Default)]
    struct TestCancellable {
        cancel_count: AtomicUsize,
    }

    impl TestCancellable {
        fn new() -> Self {
            Self::default()
        }

        fn cancel_count(&self) -> usize {
            self.cancel_count.load(Ordering::SeqCst)
        }
    }

    impl Cancellable for TestCancellable {
        fn cancel(&self) {
            self.cancel_count.fetch_add(1, Ordering::SeqCst);
        }

        fn is_cancelled(&self) -> bool {
            self.cancel_count() > 0
        }
    }

    #[test]
    fn test_cancel_token() {
        let token = CancelToken::new();
        assert!(!token.is_cancelled());

        token.cancel();

        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancel_children() {
        let token = CancelToken::new();
        let child1 = Arc::new(TestCancellable::new());
        let child2 = Arc::new(TestCancellable::new());
        token.link(child1.clone());
        token.link(child2.clone());
        assert!(!child1.is_cancelled());

        token.cancel();

        assert_eq!(child1.cancel_count(), 1);
        assert_eq!(child2.cancel_count(), 1);
    }

    #[test]
    fn test_idempotent_cancel() {
        let token = CancelToken::new();
        let child = Arc::new(TestCancellable::new());
        token.link(child.clone());
        token.cancel();

        token.cancel();

        assert!(token.is_cancelled());
        assert_eq!(child.cancel_count(), 1);
    }

    #[test]
    fn test_concurrent_cancel() {
        const THREADS: usize = 8;
        let token = CancelToken::new();
        let child = Arc::new(TestCancellable::new());
        token.link(child.clone());

        // The barrier releases all threads together so their `cancel` calls race.
        let barrier = Arc::new(Barrier::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let token = token.clone();
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    token.cancel();
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        assert!(token.is_cancelled());
        assert_eq!(child.cancel_count(), 1);
    }

    #[test]
    fn test_link_after_cancel() {
        let token = CancelToken::new();
        token.cancel();

        let child = Arc::new(TestCancellable::new());
        token.link(child.clone());

        assert_eq!(child.cancel_count(), 1);
    }

    #[test]
    fn test_link_racing_cancel() {
        const LINKERS: usize = 4;
        const PER_LINKER: usize = 50;
        let token = CancelToken::new();
        let barrier = Arc::new(Barrier::new(LINKERS + 1));

        let linkers: Vec<_> = (0..LINKERS)
            .map(|_| {
                let token = token.clone();
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    (0..PER_LINKER)
                        .map(|_| {
                            let child = Arc::new(TestCancellable::new());
                            token.link(child.clone());
                            child
                        })
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        barrier.wait();
        token.cancel();

        // Whether linked before or after the cancel, every child must be
        // cancelled exactly once.
        for handle in linkers {
            for child in handle.join().unwrap() {
                assert_eq!(child.cancel_count(), 1);
            }
        }
    }

    #[test]
    fn test_linked_child_tokens() {
        let parent = CancelToken::new();
        let defaulted: CancelToken = parent.into_linked();
        let explicit = parent.linked(CancelToken::new());
        let grandchild = Arc::new(TestCancellable::new());
        explicit.link(grandchild.clone());

        // Cancellation flows down the tree, never up.
        defaulted.cancel();
        assert!(!parent.is_cancelled());
        assert!(!explicit.is_cancelled());

        parent.cancel();
        assert!(explicit.is_cancelled());
        assert_eq!(grandchild.cancel_count(), 1);
    }
}