enough_tokio/
lib.rs

1//! # enough-tokio
2//!
3//! Bridge tokio's `CancellationToken` to the [`Stop`] trait.
4//!
5//! ## When to Use
6//!
7//! Use this crate when you have:
8//! - Tokio async code that needs to cancel CPU-intensive sync work in `spawn_blocking`
//! - Libraries that accept `impl Stop` that you want to drive with tokio's cancellation
10//!
11//! ## Complete Example
12//!
13//! ```rust,no_run
14//! use enough_tokio::TokioStop;
15//! use enough::Stop;
16//! use tokio_util::sync::CancellationToken;
17//!
18//! #[tokio::main]
19//! async fn main() {
20//!     let token = CancellationToken::new();
21//!     let stop = TokioStop::new(token.clone());
22//!
23//!     // Spawn CPU-intensive work
24//!     let handle = tokio::task::spawn_blocking(move || {
25//!         for i in 0..1_000_000 {
26//!             if i % 1000 == 0 && stop.should_stop() {
27//!                 return Err("cancelled");
28//!             }
29//!             // ... do work ...
30//!         }
31//!         Ok("done")
32//!     });
33//!
34//!     // Cancel after timeout
35//!     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
36//!     token.cancel();
37//!
38//!     let result = handle.await.unwrap();
39//!     println!("{:?}", result);
40//! }
41//! ```
42//!
43//! ## Quick Reference
44//!
45//! ```rust,no_run
46//! # use enough_tokio::TokioStop;
47//! # use enough::Stop;
48//! # use tokio_util::sync::CancellationToken;
49//! let token = CancellationToken::new();
50//! let stop = TokioStop::new(token.clone());
51//!
52//! stop.should_stop();         // Check if cancelled (sync)
53//! stop.cancel();              // Trigger cancellation
54//! // stop.cancelled().await;  // Wait for cancellation (async)
55//! let child = stop.child();   // Create child token
56//! ```
57
58#![warn(missing_docs)]
59#![warn(clippy::all)]
60
61use enough::{Stop, StopReason};
62use tokio_util::sync::CancellationToken;
63
64/// Wrapper around tokio's [`CancellationToken`] that implements [`Stop`].
65///
66/// This allows using tokio's cancellation system with libraries that
67/// accept `impl Stop`.
68///
69/// # Example
70///
71/// ```rust
72/// use enough_tokio::TokioStop;
73/// use enough::Stop;
74/// use tokio_util::sync::CancellationToken;
75///
76/// let token = CancellationToken::new();
77/// let stop = TokioStop::new(token.clone());
78///
79/// assert!(!stop.should_stop());
80///
81/// token.cancel();
82///
83/// assert!(stop.should_stop());
84/// ```
85#[derive(Clone)]
86pub struct TokioStop {
87    token: CancellationToken,
88}
89
90impl TokioStop {
91    /// Create a new TokioStop from a CancellationToken.
92    #[inline]
93    pub fn new(token: CancellationToken) -> Self {
94        Self { token }
95    }
96
97    /// Get the underlying CancellationToken.
98    #[inline]
99    pub fn token(&self) -> &CancellationToken {
100        &self.token
101    }
102
103    /// Get a clone of the underlying CancellationToken.
104    #[inline]
105    pub fn into_token(self) -> CancellationToken {
106        self.token
107    }
108
109    /// Wait for cancellation.
110    ///
111    /// This is an async method for use in async contexts.
112    #[inline]
113    pub async fn cancelled(&self) {
114        self.token.cancelled().await;
115    }
116
117    /// Create a child token that is cancelled when this one is.
118    #[inline]
119    pub fn child(&self) -> TokioStop {
120        Self::new(self.token.child_token())
121    }
122
123    /// Cancel the token.
124    #[inline]
125    pub fn cancel(&self) {
126        self.token.cancel();
127    }
128}
129
130impl Stop for TokioStop {
131    #[inline]
132    fn check(&self) -> Result<(), StopReason> {
133        if self.token.is_cancelled() {
134            Err(StopReason::Cancelled)
135        } else {
136            Ok(())
137        }
138    }
139
140    #[inline]
141    fn should_stop(&self) -> bool {
142        self.token.is_cancelled()
143    }
144}
145
146impl From<CancellationToken> for TokioStop {
147    fn from(token: CancellationToken) -> Self {
148        Self::new(token)
149    }
150}
151
152impl From<TokioStop> for CancellationToken {
153    fn from(stop: TokioStop) -> Self {
154        stop.token
155    }
156}
157
158impl std::fmt::Debug for TokioStop {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("TokioStop")
161            .field("cancelled", &self.token.is_cancelled())
162            .finish()
163    }
164}
165
166/// Extension trait for CancellationToken to easily convert to Stop.
167///
168/// Named `CancellationTokenStopExt` to avoid potential conflicts if
169/// `tokio_util` ever adds a `CancellationTokenExt` trait.
170pub trait CancellationTokenStopExt {
171    /// Convert to a TokioStop for use with `impl Stop` APIs.
172    fn as_stop(&self) -> TokioStop;
173}
174
175impl CancellationTokenStopExt for CancellationToken {
176    fn as_stop(&self) -> TokioStop {
177        TokioStop::new(self.clone())
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tokio_stop_reflects_token() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        assert!(!stop.should_stop());
        assert!(stop.check().is_ok());

        token.cancel();

        assert!(stop.should_stop());
        assert_eq!(stop.check(), Err(StopReason::Cancelled));
    }

    #[test]
    fn tokio_stop_child() {
        let parent = TokioStop::new(CancellationToken::new());
        let child = parent.child();

        assert!(!child.should_stop());

        parent.cancel();

        assert!(child.should_stop());
    }

    #[test]
    fn tokio_stop_is_send_sync() {
        // Compile-time check: `TokioStop` must be movable into `spawn_blocking`
        // closures and shareable across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TokioStop>();
    }

    #[test]
    fn tokio_stop_clone() {
        let token = CancellationToken::new();
        let stop1 = TokioStop::new(token.clone());
        let stop2 = stop1.clone();

        token.cancel();

        assert!(stop1.should_stop());
        assert!(stop2.should_stop());
    }

    #[test]
    fn from_conversions() {
        let token = CancellationToken::new();
        let stop: TokioStop = token.clone().into();
        let round_tripped: CancellationToken = stop.into();

        // The round trip must return a handle to the same token, not a fresh one.
        assert!(!round_tripped.is_cancelled());
        token.cancel();
        assert!(round_tripped.is_cancelled());
    }

    #[test]
    fn extension_trait() {
        let token = CancellationToken::new();
        let stop = token.as_stop();

        assert!(!stop.should_stop());
        token.cancel();
        assert!(stop.should_stop());
    }

    #[tokio::test]
    async fn cancelled_async() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        // Spawn a task that cancels after a delay
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            token.cancel();
        });

        // Wait for cancellation
        stop.cancelled().await;

        assert!(stop.should_stop());
    }

    #[tokio::test]
    async fn spawn_blocking_integration() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        let handle = tokio::task::spawn_blocking(move || {
            let mut count = 0;
            for i in 0..1_000_000 {
                if i % 1000 == 0 && stop.should_stop() {
                    return Err("cancelled");
                }
                count += 1;
                // Simulate work
                std::hint::black_box(count);
            }
            Ok(count)
        });

        // Cancel quickly
        tokio::time::sleep(std::time::Duration::from_micros(100)).await;
        token.cancel();

        let result = handle.await.unwrap();
        // Whether the worker finishes before the cancel lands is a deliberate
        // race, so both outcomes are accepted. Each outcome must still carry
        // the expected payload.
        match result {
            Ok(count) => assert_eq!(count, 1_000_000),
            Err(reason) => assert_eq!(reason, "cancelled"),
        }
    }

    #[tokio::test]
    async fn select_with_cancellation() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        // Spawn cancellation
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            token.cancel();
        });

        let result = tokio::select! {
            _ = stop.cancelled() => "cancelled",
            _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => "timeout",
        };

        assert_eq!(result, "cancelled");
    }

    #[tokio::test]
    async fn multiple_tasks_same_token() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        const TASKS: usize = 10;

        let token = CancellationToken::new();
        let cancelled_count = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];

        for _ in 0..TASKS {
            let stop = TokioStop::new(token.clone());
            let cancelled_count = Arc::clone(&cancelled_count);

            handles.push(tokio::spawn(async move {
                for _ in 0..100 {
                    if stop.should_stop() {
                        cancelled_count.fetch_add(1, Ordering::Relaxed);
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        // Cancel after some tasks have started
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        token.cancel();

        for h in handles {
            h.await.unwrap();
        }

        // An uncancelled task needs at least 100 x 5 ms = 500 ms, far longer
        // than the 20 ms before cancellation, so every task must observe it.
        // Joining the handles orders these increments before the load.
        assert_eq!(cancelled_count.load(Ordering::Relaxed), TASKS);
    }

    #[test]
    fn child_token_cancellation() {
        let parent = TokioStop::new(CancellationToken::new());
        let child1 = parent.child();
        let child2 = parent.child();

        assert!(!child1.should_stop());
        assert!(!child2.should_stop());

        // Cancel one child doesn't affect others
        child1.cancel();
        assert!(child1.should_stop());
        assert!(!child2.should_stop());
        assert!(!parent.should_stop());

        // Cancel parent affects remaining children
        parent.cancel();
        assert!(child2.should_stop());
    }

    #[test]
    fn nested_child_tokens() {
        let root = TokioStop::new(CancellationToken::new());
        let level1 = root.child();
        let level2 = level1.child();
        let level3 = level2.child();

        assert!(!level3.should_stop());

        root.cancel();

        assert!(level1.should_stop());
        assert!(level2.should_stop());
        assert!(level3.should_stop());
    }

    #[test]
    fn check_returns_correct_reason() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        assert_eq!(stop.check(), Ok(()));

        token.cancel();

        assert_eq!(stop.check(), Err(StopReason::Cancelled));
    }

    #[test]
    fn debug_formatting() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        assert_eq!(format!("{:?}", stop), "TokioStop { cancelled: false }");

        token.cancel();

        assert_eq!(format!("{:?}", stop), "TokioStop { cancelled: true }");
    }

    #[test]
    fn integration_with_stop_trait() {
        fn process_sync(data: &[u8], stop: impl Stop) -> Result<usize, &'static str> {
            for (i, _chunk) in data.chunks(100).enumerate() {
                if i % 10 == 0 && stop.should_stop() {
                    return Err("cancelled");
                }
            }
            Ok(data.len())
        }

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());
        let data = vec![0u8; 10000];

        // Not cancelled - completes
        let result = process_sync(&data, stop.clone());
        assert_eq!(result, Ok(10000));

        // Cancel and retry
        token.cancel();
        let result = process_sync(&data, stop);
        assert_eq!(result, Err("cancelled"));
    }

    #[test]
    fn token_accessor_methods() {
        let original_token = CancellationToken::new();
        let stop = TokioStop::new(original_token.clone());

        // token() borrows the wrapped token
        let token_ref = stop.token();
        assert!(!token_ref.is_cancelled());

        // into_token() consumes the wrapper and hands back the owned token
        let recovered_token = stop.into_token();
        assert!(!recovered_token.is_cancelled());

        // The recovered token shares state with the original handle
        original_token.cancel();
        assert!(recovered_token.is_cancelled());
    }

    #[test]
    fn rapid_cancel_check_cycle() {
        // Repeated create/check/cancel/check cycles
        for _ in 0..100 {
            let token = CancellationToken::new();
            let stop = TokioStop::new(token.clone());

            assert!(!stop.should_stop());
            token.cancel();
            assert!(stop.should_stop());
        }
    }

    #[tokio::test]
    async fn select_loop_with_pinned_cancelled() {
        use tokio::sync::mpsc;

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());
        // `tx` stays alive for the whole test, so `recv()` never yields `None`.
        // The loop can only exit through the cancellation branch.
        let (tx, mut rx) = mpsc::channel::<i32>(10);

        // Send some messages
        tx.send(1).await.unwrap();
        tx.send(2).await.unwrap();
        tx.send(3).await.unwrap();

        // Spawn cancellation after messages
        let token_clone = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            token_clone.cancel();
        });

        // Correct pattern: pin the future outside the loop
        let cancelled = stop.cancelled();
        tokio::pin!(cancelled);

        let mut received = vec![];
        let mut was_cancelled = false;

        loop {
            tokio::select! {
                _ = &mut cancelled => {
                    was_cancelled = true;
                    break;
                }
                msg = rx.recv() => {
                    match msg {
                        Some(m) => received.push(m),
                        None => break,
                    }
                }
            }
        }

        assert_eq!(received, vec![1, 2, 3]);
        assert!(was_cancelled);
    }

    #[tokio::test]
    async fn select_biased_cancellation_priority() {
        use tokio::sync::mpsc;

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());
        let (tx, mut rx) = mpsc::channel::<i32>(10);

        // Pre-cancel before loop
        token.cancel();

        // Send a message (channel should still have it)
        tx.send(42).await.unwrap();

        let cancelled = stop.cancelled();
        tokio::pin!(cancelled);

        // With biased, cancellation should win since it's first
        let result = tokio::select! {
            biased;
            _ = &mut cancelled => "cancelled",
            _ = rx.recv() => "received",
        };

        assert_eq!(result, "cancelled");
    }
}