Skip to main content

kithara_platform/
cancel_group.rs

1use std::{ops::BitOr, sync::Arc};
2
3use futures::future::select_all;
4use tokio_util::sync::CancellationToken;
5
6/// OR-combinator for cancellation tokens.
7///
8/// Fires when **any** source token is cancelled. No spawn — uses
9/// sync polling for `is_cancelled()` and `select_all` for the
10/// async `cancelled()` future.
11///
12/// Supports composition via `|`:
13/// ```ignore
14/// let cancel = token_a | token_b;
15/// let cancel = group | extra_token;
16/// let cancel = group1 | group2;
17/// ```
18#[derive(Clone)]
19pub struct CancelGroup {
20    sources: Arc<[CancellationToken]>,
21}
22
23impl CancelGroup {
24    #[must_use]
25    pub fn new(sources: Vec<CancellationToken>) -> Self {
26        Self {
27            sources: sources.into(),
28        }
29    }
30
31    pub async fn cancelled(&self) {
32        if self.is_cancelled() {
33            return;
34        }
35        let futs: Vec<_> = self
36            .sources
37            .iter()
38            .map(|s| Box::pin(s.cancelled()))
39            .collect();
40        if futs.is_empty() {
41            std::future::pending::<()>().await;
42            return;
43        }
44        select_all(futs).await;
45    }
46
47    /// Returns `true` if both groups share the same underlying source array.
48    #[must_use]
49    pub fn equals_ptr(&self, other: &Self) -> bool {
50        Arc::ptr_eq(&self.sources, &other.sources)
51    }
52
53    #[must_use]
54    pub fn is_cancelled(&self) -> bool {
55        self.sources.iter().any(CancellationToken::is_cancelled)
56    }
57
58    fn tokens(&self) -> &[CancellationToken] {
59        &self.sources
60    }
61}
62
63impl From<CancellationToken> for CancelGroup {
64    fn from(token: CancellationToken) -> Self {
65        Self::new(vec![token])
66    }
67}
68
69impl From<Vec<CancellationToken>> for CancelGroup {
70    fn from(tokens: Vec<CancellationToken>) -> Self {
71        Self::new(tokens)
72    }
73}
74
75impl BitOr for CancelGroup {
76    type Output = Self;
77
78    fn bitor(self, rhs: Self) -> Self {
79        let mut tokens = self.tokens().to_vec();
80        tokens.extend_from_slice(rhs.tokens());
81        Self::new(tokens)
82    }
83}
84
85impl BitOr<CancellationToken> for CancelGroup {
86    type Output = Self;
87
88    fn bitor(self, rhs: CancellationToken) -> Self {
89        let mut tokens = self.tokens().to_vec();
90        tokens.push(rhs);
91        Self::new(tokens)
92    }
93}
94
95impl BitOr<CancelGroup> for CancellationToken {
96    type Output = CancelGroup;
97
98    fn bitor(self, rhs: CancelGroup) -> CancelGroup {
99        let mut tokens = vec![self];
100        tokens.extend_from_slice(rhs.tokens());
101        CancelGroup::new(tokens)
102    }
103}
104
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use kithara_test_utils::kithara;
    use tokio::{spawn, task, time as tokio_time};
    use tokio_util::sync::CancellationToken;

    use super::CancelGroup;

    /// How a source token in a test group is created.
    #[derive(Clone, Debug)]
    enum Src {
        /// Independent, uncancelled token.
        Fresh,
        /// Child of `parents[idx]`; cancelling that parent cancels this token.
        ChildOf(usize),
        /// Token that is already cancelled when the group is built.
        PreCancelled,
    }

    /// Which token a test cancels after the group has been built.
    #[derive(Clone, Debug)]
    enum Act {
        /// Cancel `sources[idx]` directly.
        Source(usize),
        /// Cancel `parents[idx]`, which propagates to its child tokens.
        Parent(usize),
        /// Cancel nothing.
        None,
    }

    /// A built group together with handles to every token involved in it.
    struct Setup {
        group: CancelGroup,
        parents: Vec<CancellationToken>,
        sources: Vec<CancellationToken>,
    }

    /// Materializes `spec` into tokens and a group over them.
    ///
    /// Parents are created lazily, so `ChildOf(n)` makes parents `0..=n` exist.
    fn build(spec: &[Src]) -> Setup {
        let mut parents: Vec<CancellationToken> = Vec::new();
        let mut sources: Vec<CancellationToken> = Vec::new();

        for s in spec {
            match s {
                Src::Fresh => sources.push(CancellationToken::new()),
                Src::ChildOf(idx) => {
                    while parents.len() <= *idx {
                        parents.push(CancellationToken::new());
                    }
                    sources.push(parents[*idx].child_token());
                }
                Src::PreCancelled => {
                    let tok = CancellationToken::new();
                    tok.cancel();
                    sources.push(tok);
                }
            }
        }

        let group = CancelGroup::new(sources.clone());
        Setup {
            group,
            parents,
            sources,
        }
    }

    /// Performs `act` against the tokens held in `s`.
    fn fire(act: &Act, s: &Setup) {
        match act {
            Act::Source(i) => s.sources[*i].cancel(),
            Act::Parent(i) => s.parents[*i].cancel(),
            Act::None => {}
        }
    }

    /// Generates sync tests: build `$spec`, fire `$action`, then compare
    /// `is_cancelled()` against `$expected`.
    macro_rules! sync_cancel_tests {
        ($($name:ident: $spec:expr, $action:expr, $expected:expr;)*) => {
            $(
                #[kithara::test(timeout(Duration::from_secs(5)))]
                fn $name() {
                    let s = build(&$spec);
                    fire(&$action, &s);
                    assert_eq!(s.group.is_cancelled(), $expected);
                }
            )*
        }
    }

    sync_cancel_tests! {
        two_fresh_cancel_first:
            [Src::Fresh, Src::Fresh], Act::Source(0), true;
        two_fresh_cancel_second:
            [Src::Fresh, Src::Fresh], Act::Source(1), true;
        single_cancel:
            [Src::Fresh], Act::Source(0), true;
        two_fresh_no_cancel:
            [Src::Fresh, Src::Fresh], Act::None, false;
        pre_cancelled_plus_fresh:
            [Src::PreCancelled, Src::Fresh], Act::None, true;
        fresh_and_child_cancel_fresh:
            [Src::Fresh, Src::ChildOf(0)], Act::Source(0), true;
        fresh_and_child_cancel_parent:
            [Src::Fresh, Src::ChildOf(0)], Act::Parent(0), true;
        two_children_same_parent_cancel_parent:
            [Src::ChildOf(0), Src::ChildOf(0)], Act::Parent(0), true;
        two_children_diff_parents_cancel_first:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(0), true;
        two_children_diff_parents_cancel_second:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(1), true;
        two_children_diff_parents_no_cancel:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::None, false;
        mixed_with_pre_cancelled:
            [Src::Fresh, Src::ChildOf(0), Src::PreCancelled], Act::None, true;
    }

    /// Generates async tests: a spawned task awaits `cancelled()`, then
    /// `$action` is fired and the task must finish within the timeout.
    macro_rules! async_cancel_tests {
        ($($name:ident: $spec:expr, $action:expr;)*) => {
            $(
                #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
                async fn $name() {
                    let s = build(&$spec);
                    let group2 = s.group.clone();
                    let handle = spawn(async move { group2.cancelled().await });

                    task::yield_now().await;

                    assert!(!s.group.is_cancelled(), "must not be cancelled before action");
                    fire(&$action, &s);

                    tokio_time::timeout(Duration::from_secs(2), handle)
                        .await
                        .expect("BUG: cancelled() must resolve within the test timeout")
                        .expect("BUG: spawned cancellation task must not panic");
                }
            )*
        }
    }

    async_cancel_tests! {
        async_two_fresh_cancel_first:
            [Src::Fresh, Src::Fresh], Act::Source(0);
        async_two_fresh_cancel_second:
            [Src::Fresh, Src::Fresh], Act::Source(1);
        async_single_cancel:
            [Src::Fresh], Act::Source(0);
        async_fresh_and_child_cancel_parent:
            [Src::Fresh, Src::ChildOf(0)], Act::Parent(0);
        async_two_children_same_parent:
            [Src::ChildOf(0), Src::ChildOf(0)], Act::Parent(0);
        async_two_children_diff_parents_cancel_first:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(0);
        async_two_children_diff_parents_cancel_second:
            [Src::ChildOf(0), Src::ChildOf(1)], Act::Parent(1);
        async_three_mixed_cancel_last:
            [Src::Fresh, Src::ChildOf(0), Src::Fresh], Act::Source(2);
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn cancelled_resolves_immediately_when_pre_cancelled() {
        let tok = CancellationToken::new();
        tok.cancel();
        let group = CancelGroup::new(vec![tok, CancellationToken::new()]);

        tokio_time::timeout(Duration::from_secs(1), group.cancelled())
            .await
            .expect("BUG: cancelled() must return immediately for a pre-cancelled source");
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn empty_group_is_not_cancelled() {
        let group = CancelGroup::new(vec![]);
        assert!(!group.is_cancelled());
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn empty_group_cancelled_never_resolves() {
        let group = CancelGroup::new(vec![]);
        let result = tokio_time::timeout(Duration::from_millis(50), group.cancelled()).await;
        assert!(
            result.is_err(),
            "cancelled() on empty group must not resolve"
        );
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn clone_observes_same_cancellation() {
        let tok = CancellationToken::new();
        let group = CancelGroup::new(vec![tok.clone()]);
        let cloned = group.clone();

        tok.cancel();
        assert!(group.is_cancelled());
        assert!(cloned.is_cancelled());
    }

    /// `equals_ptr` is an identity check: clones share storage, while a group
    /// rebuilt from the same token does not.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn equals_ptr_tracks_shared_storage() {
        let tok = CancellationToken::new();
        let group = CancelGroup::new(vec![tok.clone()]);
        let cloned = group.clone();
        let rebuilt = CancelGroup::new(vec![tok]);

        assert!(group.equals_ptr(&cloned), "clones must share the source array");
        assert!(
            !group.equals_ptr(&rebuilt),
            "independently built groups must not share the source array"
        );
    }

    /// `CancellationToken | CancellationToken` cannot be implemented (orphan
    /// rule), so the closest spelling lifts the left token into a group. This
    /// checks that cancelling the lifted (left) token fires the result.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn token_bitor_token() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone();

        assert!(!group.is_cancelled());
        a.cancel();
        assert!(group.is_cancelled());
    }

    /// Cancelling the appended (right-hand) token fires the group.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn group_bitor_token() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone();

        assert!(!group.is_cancelled());
        b.cancel();
        assert!(group.is_cancelled());
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn token_bitor_group() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = a.clone() | CancelGroup::from(b.clone());

        assert!(!group.is_cancelled());
        a.cancel();
        assert!(group.is_cancelled());
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn group_bitor_group() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let g1 = CancelGroup::from(a.clone());
        let g2 = CancelGroup::from(b.clone());
        let merged = g1 | g2;

        assert!(!merged.is_cancelled());
        b.cancel();
        assert!(merged.is_cancelled());
    }

    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn chained_bitor() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let c = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone() | c.clone();

        assert!(!group.is_cancelled());
        c.cancel();
        assert!(group.is_cancelled());
    }

    /// Composing produces a new group; the original keeps only its own sources.
    #[kithara::test(timeout(Duration::from_secs(5)))]
    fn bitor_leaves_operands_untouched() {
        let a = CancellationToken::new();
        let extra = CancellationToken::new();
        let base = CancelGroup::from(a.clone());
        let combined = base.clone() | extra.clone();

        extra.cancel();
        assert!(combined.is_cancelled());
        assert!(
            !base.is_cancelled(),
            "composing must not add sources to the original group"
        );
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn bitor_async_cancelled() {
        let a = CancellationToken::new();
        let b = CancellationToken::new();
        let group = CancelGroup::from(a.clone()) | b.clone();

        let g2 = group.clone();
        let handle = spawn(async move { g2.cancelled().await });
        task::yield_now().await;

        assert!(!group.is_cancelled());
        b.cancel();

        tokio_time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("BUG: cancelled() must resolve once one source has cancelled")
            .expect("BUG: spawned task awaiting cancellation must not panic");
    }
}