1#![warn(missing_docs)]
59#![warn(clippy::all)]
60
61use enough::{Stop, StopReason};
62use tokio_util::sync::CancellationToken;
63
64#[derive(Clone)]
86pub struct TokioStop {
87 token: CancellationToken,
88}
89
90impl TokioStop {
91 #[inline]
93 pub fn new(token: CancellationToken) -> Self {
94 Self { token }
95 }
96
97 #[inline]
99 pub fn token(&self) -> &CancellationToken {
100 &self.token
101 }
102
103 #[inline]
105 pub fn into_token(self) -> CancellationToken {
106 self.token
107 }
108
109 #[inline]
113 pub async fn cancelled(&self) {
114 self.token.cancelled().await;
115 }
116
117 #[inline]
119 pub fn child(&self) -> TokioStop {
120 Self::new(self.token.child_token())
121 }
122
123 #[inline]
125 pub fn cancel(&self) {
126 self.token.cancel();
127 }
128}
129
130impl Stop for TokioStop {
131 #[inline]
132 fn check(&self) -> Result<(), StopReason> {
133 if self.token.is_cancelled() {
134 Err(StopReason::Cancelled)
135 } else {
136 Ok(())
137 }
138 }
139
140 #[inline]
141 fn should_stop(&self) -> bool {
142 self.token.is_cancelled()
143 }
144}
145
146impl From<CancellationToken> for TokioStop {
147 fn from(token: CancellationToken) -> Self {
148 Self::new(token)
149 }
150}
151
152impl From<TokioStop> for CancellationToken {
153 fn from(stop: TokioStop) -> Self {
154 stop.token
155 }
156}
157
158impl std::fmt::Debug for TokioStop {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("TokioStop")
161 .field("cancelled", &self.token.is_cancelled())
162 .finish()
163 }
164}
165
166pub trait CancellationTokenStopExt {
171 fn as_stop(&self) -> TokioStop;
173}
174
175impl CancellationTokenStopExt for CancellationToken {
176 fn as_stop(&self) -> TokioStop {
177 TokioStop::new(self.clone())
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tokio_stop_reflects_token() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        assert!(!stop.should_stop());
        assert!(stop.check().is_ok());

        token.cancel();

        assert!(stop.should_stop());
        assert_eq!(stop.check(), Err(StopReason::Cancelled));
    }

    #[test]
    fn tokio_stop_child() {
        let parent = TokioStop::new(CancellationToken::new());
        let child = parent.child();

        assert!(!child.should_stop());

        parent.cancel();

        assert!(child.should_stop());
    }

    #[test]
    fn tokio_stop_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<TokioStop>();
    }

    #[test]
    fn tokio_stop_clone() {
        let token = CancellationToken::new();
        let stop1 = TokioStop::new(token.clone());
        let stop2 = stop1.clone();

        token.cancel();

        assert!(stop1.should_stop());
        assert!(stop2.should_stop());
    }

    #[test]
    fn from_conversions() {
        let token = CancellationToken::new();
        let stop: TokioStop = token.clone().into();
        let round_tripped: CancellationToken = stop.into();

        // The round trip must hand back the same token, not a fresh one,
        // so cancelling the result is visible through the original.
        round_tripped.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn extension_trait() {
        let token = CancellationToken::new();
        let stop = token.as_stop();

        assert!(!stop.should_stop());
        token.cancel();
        assert!(stop.should_stop());
    }

    #[tokio::test]
    async fn cancelled_async() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        tokio::spawn(async move {
            // Cancel from another task after a short delay.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            token.cancel();
        });

        stop.cancelled().await;

        assert!(stop.should_stop());
    }

    #[tokio::test]
    async fn spawn_blocking_integration() {
        use std::time::{Duration, Instant};

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        let handle = tokio::task::spawn_blocking(move || {
            // Spin until cancellation is observed from this blocking thread.
            // The deadline turns a lost cross-thread signal into a test
            // failure instead of a hang.
            let deadline = Instant::now() + Duration::from_secs(10);
            let mut count: u64 = 0;
            while Instant::now() < deadline {
                if stop.should_stop() {
                    return Err("cancelled");
                }
                count += 1;
                std::hint::black_box(count);
            }
            Ok(count)
        });

        tokio::time::sleep(Duration::from_micros(100)).await;
        token.cancel();

        let result = handle.await.unwrap();
        assert_eq!(result, Err("cancelled"));
    }

    #[tokio::test]
    async fn select_with_cancellation() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        tokio::spawn(async move {
            // Cancel well before the 10s timeout branch could fire.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            token.cancel();
        });

        let result = tokio::select! {
            _ = stop.cancelled() => "cancelled",
            _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => "timeout",
        };

        assert_eq!(result, "cancelled");
    }

    #[tokio::test]
    async fn multiple_tasks_same_token() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        const TASKS: usize = 10;

        let token = CancellationToken::new();
        let cancelled_count = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(TASKS);

        for _ in 0..TASKS {
            let stop = TokioStop::new(token.clone());
            let cancelled_count = Arc::clone(&cancelled_count);

            handles.push(tokio::spawn(async move {
                // Poll for up to ~500ms, far longer than the 20ms before the
                // main task cancels, so every task should see the cancel.
                for _ in 0..100 {
                    if stop.should_stop() {
                        cancelled_count.fetch_add(1, Ordering::Relaxed);
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        token.cancel();

        for h in handles {
            h.await.unwrap();
        }

        // Joining the handles orders the increments before this load.
        assert_eq!(cancelled_count.load(Ordering::Relaxed), TASKS);
    }

    #[tokio::test]
    async fn child_token_cancellation() {
        let parent = TokioStop::new(CancellationToken::new());
        let child1 = parent.child();
        let child2 = parent.child();

        assert!(!child1.should_stop());
        assert!(!child2.should_stop());

        // Cancelling a child affects neither its sibling nor its parent.
        child1.cancel();
        assert!(child1.should_stop());
        assert!(!child2.should_stop());
        assert!(!parent.should_stop());

        // Cancelling the parent propagates to the remaining child.
        parent.cancel();
        assert!(child2.should_stop());
    }

    #[tokio::test]
    async fn nested_child_tokens() {
        let root = TokioStop::new(CancellationToken::new());
        let level1 = root.child();
        let level2 = level1.child();
        let level3 = level2.child();

        assert!(!level3.should_stop());

        root.cancel();

        assert!(level1.should_stop());
        assert!(level2.should_stop());
        assert!(level3.should_stop());
    }

    #[tokio::test]
    async fn check_returns_correct_reason() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        assert_eq!(stop.check(), Ok(()));

        token.cancel();

        assert_eq!(stop.check(), Err(StopReason::Cancelled));
    }

    #[tokio::test]
    async fn debug_formatting() {
        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());

        let debug = format!("{:?}", stop);
        assert!(debug.contains("TokioStop"));
        assert!(debug.contains("cancelled"));
        assert!(debug.contains("false"));

        token.cancel();

        let debug = format!("{:?}", stop);
        assert!(debug.contains("true"));
    }

    #[tokio::test]
    async fn integration_with_stop_trait() {
        fn process_sync(data: &[u8], stop: impl Stop) -> Result<usize, &'static str> {
            for (i, _chunk) in data.chunks(100).enumerate() {
                if i % 10 == 0 && stop.should_stop() {
                    return Err("cancelled");
                }
            }
            Ok(data.len())
        }

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());
        let data = vec![0u8; 10000];

        // Uncancelled: processes everything.
        let result = process_sync(&data, stop.clone());
        assert_eq!(result, Ok(10000));

        // Cancelled: bails out on the first check.
        token.cancel();
        let result = process_sync(&data, stop);
        assert_eq!(result, Err("cancelled"));
    }

    #[tokio::test]
    async fn token_accessor_methods() {
        let original_token = CancellationToken::new();
        let stop = TokioStop::new(original_token.clone());

        let token_ref = stop.token();
        assert!(!token_ref.is_cancelled());

        let recovered_token = stop.into_token();
        assert!(!recovered_token.is_cancelled());

        // The recovered token shares state with the original.
        original_token.cancel();
        assert!(recovered_token.is_cancelled());
    }

    #[test]
    fn sync_send_bounds() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<TokioStop>();
        assert_sync::<TokioStop>();
    }

    #[tokio::test]
    async fn rapid_cancel_check_cycle() {
        for _ in 0..100 {
            // A fresh token each iteration: state must not leak between them.
            let token = CancellationToken::new();
            let stop = TokioStop::new(token.clone());

            assert!(!stop.should_stop());
            token.cancel();
            assert!(stop.should_stop());
        }
    }

    #[tokio::test]
    async fn select_loop_with_pinned_cancelled() {
        use tokio::sync::mpsc;

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());
        let (tx, mut rx) = mpsc::channel::<i32>(10);

        // Queue messages up front; `tx` stays alive so `recv` does not
        // return `None` and the loop exits through cancellation.
        tx.send(1).await.unwrap();
        tx.send(2).await.unwrap();
        tx.send(3).await.unwrap();

        let token_clone = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            token_clone.cancel();
        });

        // Pin once so the same future is polled across loop iterations.
        let cancelled = stop.cancelled();
        tokio::pin!(cancelled);

        let mut received = vec![];
        let mut was_cancelled = false;

        loop {
            tokio::select! {
                _ = &mut cancelled => {
                    was_cancelled = true;
                    break;
                }
                msg = rx.recv() => {
                    match msg {
                        Some(m) => received.push(m),
                        None => break,
                    }
                }
            }
        }

        assert_eq!(received, vec![1, 2, 3]);
        assert!(was_cancelled);
    }

    #[tokio::test]
    async fn select_biased_cancellation_priority() {
        use tokio::sync::mpsc;

        let token = CancellationToken::new();
        let stop = TokioStop::new(token.clone());
        let (tx, mut rx) = mpsc::channel::<i32>(10);

        // Both branches are ready before the select runs.
        token.cancel();
        tx.send(42).await.unwrap();

        let cancelled = stop.cancelled();
        tokio::pin!(cancelled);

        let result = tokio::select! {
            // `biased` polls branches in order, so cancellation wins.
            biased;
            _ = &mut cancelled => "cancelled",
            _ = rx.recv() => "received",
        };

        assert_eq!(result, "cancelled");
    }
}