1use std::{ops::BitOr, sync::Arc};
2
3use futures::future::select_all;
4use tokio_util::sync::CancellationToken;
5
6#[derive(Clone)]
19pub struct CancelGroup {
20 sources: Arc<[CancellationToken]>,
21}
22
23impl CancelGroup {
24 #[must_use]
25 pub fn new(sources: Vec<CancellationToken>) -> Self {
26 Self {
27 sources: sources.into(),
28 }
29 }
30
31 pub async fn cancelled(&self) {
32 if self.is_cancelled() {
33 return;
34 }
35 let futs: Vec<_> = self
36 .sources
37 .iter()
38 .map(|s| Box::pin(s.cancelled()))
39 .collect();
40 if futs.is_empty() {
41 std::future::pending::<()>().await;
42 return;
43 }
44 select_all(futs).await;
45 }
46
47 #[must_use]
49 pub fn equals_ptr(&self, other: &Self) -> bool {
50 Arc::ptr_eq(&self.sources, &other.sources)
51 }
52
53 #[must_use]
54 pub fn is_cancelled(&self) -> bool {
55 self.sources.iter().any(CancellationToken::is_cancelled)
56 }
57
58 fn tokens(&self) -> &[CancellationToken] {
59 &self.sources
60 }
61}
62
63impl From<CancellationToken> for CancelGroup {
64 fn from(token: CancellationToken) -> Self {
65 Self::new(vec![token])
66 }
67}
68
69impl From<Vec<CancellationToken>> for CancelGroup {
70 fn from(tokens: Vec<CancellationToken>) -> Self {
71 Self::new(tokens)
72 }
73}
74
75impl BitOr for CancelGroup {
76 type Output = Self;
77
78 fn bitor(self, rhs: Self) -> Self {
79 let mut tokens = self.tokens().to_vec();
80 tokens.extend_from_slice(rhs.tokens());
81 Self::new(tokens)
82 }
83}
84
85impl BitOr<CancellationToken> for CancelGroup {
86 type Output = Self;
87
88 fn bitor(self, rhs: CancellationToken) -> Self {
89 let mut tokens = self.tokens().to_vec();
90 tokens.push(rhs);
91 Self::new(tokens)
92 }
93}
94
95impl BitOr<CancelGroup> for CancellationToken {
96 type Output = CancelGroup;
97
98 fn bitor(self, rhs: CancelGroup) -> CancelGroup {
99 let mut tokens = vec![self];
100 tokens.extend_from_slice(rhs.tokens());
101 CancelGroup::new(tokens)
102 }
103}
104
#[cfg(test)]
mod tests {
    //! Unit tests for `CancelGroup`: synchronous `is_cancelled` checks,
    //! asynchronous `cancelled()` resolution, pointer identity, and the
    //! `From` / `|` constructors.

    use std::time::Duration;

    use kithara_test_utils::kithara;
    use tokio::{spawn, task, time as tokio_time};
    use tokio_util::sync::CancellationToken;

    use super::CancelGroup;

    /// How each source token of a test group is created.
    #[derive(Clone, Debug)]
    enum Src {
        /// A brand-new, uncancelled token.
        Fresh,
        /// A child of the parent token at this index (parents are created on
        /// demand by `build`).
        ChildOf(usize),
        /// A token that is cancelled before the group is built.
        PreCancelled,
    }

    /// Which token, if any, a test cancels after the group is built.
    #[derive(Clone, Debug)]
    enum Act {
        /// Cancel the source token at this index.
        Source(usize),
        /// Cancel the parent token at this index.
        Parent(usize),
        /// Cancel nothing.
        None,
    }

    /// A group plus handles to every token used to build it.
    struct Setup {
        group: CancelGroup,
        parents: Vec<CancellationToken>,
        sources: Vec<CancellationToken>,
    }

    /// Creates the tokens described by `spec` and a group over all sources.
    fn build(spec: &[Src]) -> Setup {
        let mut parents: Vec<CancellationToken> = Vec::new();
        let mut sources: Vec<CancellationToken> = Vec::new();

        for s in spec {
            match s {
                Src::Fresh => sources.push(CancellationToken::new()),
                Src::ChildOf(idx) => {
                    while parents.len() <= *idx {
                        parents.push(CancellationToken::new());
                    }
                    sources.push(parents[*idx].child_token());
                }
                Src::PreCancelled => {
                    let tok = CancellationToken::new();
                    tok.cancel();
                    sources.push(tok);
                }
            }
        }

        let group = CancelGroup::new(sources.clone());
        Setup {
            group,
            parents,
            sources,
        }
    }

    /// Performs `act` against the tokens held in `s`.
    fn fire(act: &Act, s: &Setup) {
        match act {
            Act::Source(i) => s.sources[*i].cancel(),
            Act::Parent(i) => s.parents[*i].cancel(),
            Act::None => {}
        }
    }

    // Each case: build the group, fire the action, check `is_cancelled()`.
    macro_rules! sync_cancel_tests {
        ($($name:ident: $spec:expr, $action:expr, $expected:expr;)*) => {
            $(
                #[kithara::test(timeout(Duration::from_secs(5)))]
                fn $name() {
                    let s = build(&$spec);
                    fire(&$action, &s);
                    assert_eq!(s.group.is_cancelled(), $expected);
                }
            )*
        }
    }

    sync_cancel_tests! {
        two_fresh_cancel_first:
            [Src::Fresh, Src::Fresh], Act::Source(0), true;
        two_fresh_cancel_second:
            [Src::Fresh, Src::Fresh], Act::Source(1), true;
        single_cancel:
            [Src::Fresh], Act::Source(0), true;
        two_fresh_no_cancel:
            [Src::Fresh, Src::Fresh], Act::None, false;
        pre_cancelled_plus_fresh:
            [Src::PreCancelled, Src::Fresh], Act::None, true;
        fresh_and_child_cancel_fresh:
            [Src::Fresh, Src::ChildOf(0)], Act::Source(0), true;
        fresh_and_child_cancel_parent:
            [Src::Fresh, Src::ChildOf(0)], Act::Parent(0), true;
        two_children_same_parent_cancel_parent:
            [Src::ChildOf(0), Src::ChildOf(0)], Act::Parent(0), true;
        two_children_diff_parents_cancel_first:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(0), true;
        two_children_diff_parents_cancel_second:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(1), true;
        two_children_diff_parents_no_cancel:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::None, false;
        mixed_with_pre_cancelled:
            [Src::Fresh, Src::ChildOf(0), Src::PreCancelled], Act::None, true;
    }

    // Each case: start awaiting `cancelled()` on a spawned task, fire the
    // action, then require the task to finish within the timeout.
    macro_rules! async_cancel_tests {
        ($($name:ident: $spec:expr, $action:expr;)*) => {
            $(
                #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
                async fn $name() {
                    let s = build(&$spec);
                    let group2 = s.group.clone();
                    let handle = spawn(async move { group2.cancelled().await });

                    task::yield_now().await;

                    assert!(!s.group.is_cancelled(), "must not be cancelled before action");
                    fire(&$action, &s);

                    tokio_time::timeout(Duration::from_secs(2), handle)
                        .await
                        .expect("BUG: cancelled() must resolve within the test timeout")
                        .expect("BUG: spawned cancellation task must not panic");
                }
            )*
        }
    }

    async_cancel_tests! {
        async_two_fresh_cancel_first:
            [Src::Fresh, Src::Fresh], Act::Source(0);
        async_two_fresh_cancel_second:
            [Src::Fresh, Src::Fresh], Act::Source(1);
        async_single_cancel:
            [Src::Fresh], Act::Source(0);
        async_fresh_and_child_cancel_parent:
            [Src::Fresh, Src::ChildOf(0)], Act::Parent(0);
        async_two_children_same_parent:
            [Src::ChildOf(0), Src::ChildOf(0)], Act::Parent(0);
        async_two_children_diff_parents_cancel_first:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(0);
        async_two_children_diff_parents_cancel_second:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(1);
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn cancelled_resolves_immediately_when_pre_cancelled() {
        let tok = CancellationToken::new();
        tok.cancel();
        let group = CancelGroup::new(vec![tok, CancellationToken::new()]);

        tokio_time::timeout(Duration::from_secs(1), group.cancelled())
            .await
            .expect("BUG: cancelled() must return immediately for a pre-cancelled source");
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn empty_group_is_not_cancelled() {
        let group = CancelGroup::new(vec![]);
        assert!(!group.is_cancelled());
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn empty_group_cancelled_never_resolves() {
        let group = CancelGroup::new(vec![]);
        let result = tokio_time::timeout(Duration::from_millis(50), group.cancelled()).await;
        assert!(
            result.is_err(),
            "cancelled() on empty group must not resolve"
        );
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn clone_observes_same_cancellation() {
        let tok = CancellationToken::new();
        let group = CancelGroup::new(vec![tok.clone()]);
        let cloned = group.clone();

        tok.cancel();
        assert!(group.is_cancelled());
        assert!(cloned.is_cancelled());
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn clone_shares_ptr() {
        let group = CancelGroup::new(vec![CancellationToken::new()]);
        let cloned = group.clone();

        assert!(group.equals_ptr(&cloned));
        assert!(cloned.equals_ptr(&group));
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn separately_built_groups_do_not_share_ptr() {
        let tok = CancellationToken::new();
        let g1 = CancelGroup::new(vec![tok.clone()]);
        let g2 = CancelGroup::new(vec![tok]);

        assert!(!g1.equals_ptr(&g2));
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn from_vec_observes_any_token() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(vec![a.clone(), b.clone()]);

        assert!(!group.is_cancelled());
        b.cancel();
        assert!(group.is_cancelled());
    }

    // `CancelGroup | CancellationToken`, cancelling the token that seeded the
    // group (the left-hand side).
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn group_bitor_token_cancel_first() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone();

        assert!(!group.is_cancelled());
        a.cancel();
        assert!(group.is_cancelled());
    }

    // `CancelGroup | CancellationToken`, cancelling the appended token.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn group_bitor_token() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone();

        assert!(!group.is_cancelled());
        b.cancel();
        assert!(group.is_cancelled());
    }

    // `CancellationToken | CancelGroup`, cancelling the bare token.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn token_bitor_group() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = a.clone() | CancelGroup::from(b.clone());

        assert!(!group.is_cancelled());
        a.cancel();
        assert!(group.is_cancelled());
    }

    // `CancellationToken | CancelGroup`, cancelling a token inside the group.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn token_bitor_group_cancel_group_side() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = a.clone() | CancelGroup::from(b.clone());

        assert!(!group.is_cancelled());
        b.cancel();
        assert!(group.is_cancelled());
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn group_bitor_group() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let g1 = CancelGroup::from(a.clone());
        let g2 = CancelGroup::from(b.clone());
        let merged = g1 | g2;

        assert!(!merged.is_cancelled());
        b.cancel();
        assert!(merged.is_cancelled());
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn chained_bitor() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let c = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone() | c.clone();

        assert!(!group.is_cancelled());
        c.cancel();
        assert!(group.is_cancelled());
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn bitor_async_cancelled() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone();

        let g2 = group.clone();
        let handle = spawn(async move { g2.cancelled().await });
        task::yield_now().await;

        assert!(!group.is_cancelled());
        b.cancel();

        tokio_time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("BUG: cancelled() must resolve once one source has cancelled")
            .expect("BUG: spawned task awaiting cancellation must not panic");
    }
}