forge_agent/workflow/cancellation.rs
1//! Async cancellation token system with parent-child hierarchy.
2//!
3//! Provides a cooperative cancellation model for workflows and tasks,
4//! allowing users to cancel running operations and propagate cancellation
5//! to all active tasks through a parent-child token hierarchy.
6//!
7//! # Architecture
8//!
9//! The cancellation system consists of three core types:
10//! - [`CancellationToken`]: Thread-safe token representing cancellation state
11//! - [`CancellationTokenSource`]: Owner of the parent token with cancel() method
12//! - [`ChildToken`]: Derived child token for task-level cancellation
13//!
14//! # Cooperative Cancellation Patterns
15//!
16//! This module supports two main patterns for cooperative cancellation:
17//!
18//! ## 1. Polling Pattern
19//!
20//! Poll the cancellation token in long-running loops:
21//!
22//! ```ignore
23//! while !token.poll_cancelled() {
24//! // Do work
25//! tokio::time::sleep(Duration::from_millis(100)).await;
26//! }
27//! ```
28//!
29//! ## 2. Async Wait Pattern
30//!
31//! Wait for cancellation signal asynchronously:
32//!
33//! ```ignore
34//! tokio::select! {
35//! _ = token.wait_cancelled() => {
36//! // Handle cancellation
37//! }
38//! result = do_work() => {
39//! // Handle completion
40//! }
41//! }
42//! ```
43//!
44//! # Example
45//!
46//! ```ignore
47//! use forge_agent::workflow::{CancellationTokenSource, CancellationToken};
48//!
49//! // Create cancellation source for workflow
50//! let source = CancellationTokenSource::new();
51//! let token = source.token();
52//!
53//! // Pass token to tasks
54//! tokio::spawn(async move {
55//! while !token.is_cancelled() {
56//! // Do work
57//! }
58//! });
59//!
60//! // Cancel from anywhere
61//! source.cancel();
62//! ```
63
64use std::future::Future;
65use std::sync::atomic::{AtomicBool, Ordering};
66use std::sync::Arc;
67use tokio::sync::Notify;
68
69/// Thread-safe cancellation token.
70///
71/// Wraps an Arc<AtomicBool> for thread-safe cancellation state.
72/// Tokens can be cheaply cloned and shared across tasks.
73///
74/// # Cooperative Cancellation
75///
76/// Tasks can cooperatively respond to cancellation by:
77/// - Polling with [`poll_cancelled()`](Self::poll_cancelled) in loops
78/// - Awaiting with [`wait_cancelled()`](Self::wait_cancelled) in async contexts
79///
80/// # Examples
81///
82/// See the [`examples`](crate::workflow::examples) module for complete cancellation-aware task
83/// implementations demonstrating:
84/// - [`CancellationAwareTask`](crate::workflow::examples::CancellationAwareTask): Polling pattern
85/// - [`PollingTask`](crate::workflow::examples::PollingTask): tokio::select! pattern
86/// - [`TimeoutAndCancellationTask`](crate::workflow::examples::TimeoutAndCancellationTask): Timeout + cancellation
87///
88/// # Cloning
89///
90/// Cloning a token creates a new reference to the same cancellation state.
91/// When any token is cancelled (via CancellationTokenSource), all clones
92/// will report as cancelled.
93///
94/// # Example
95///
96/// ```ignore
97/// let source = CancellationTokenSource::new();
98/// let token1 = source.token();
99/// let token2 = token1.clone(); // Same state
100///
101/// assert!(!token1.is_cancelled());
102/// assert!(!token2.is_cancelled());
103///
104/// source.cancel();
105///
106/// assert!(token1.is_cancelled());
107/// assert!(token2.is_cancelled());
108/// ```
109#[derive(Clone, Debug)]
110pub struct CancellationToken {
111 cancelled: Arc<AtomicBool>,
112 notify: Arc<Notify>,
113}
114
115impl CancellationToken {
116 /// Creates a new non-cancelled token.
117 pub(crate) fn new() -> Self {
118 Self {
119 cancelled: Arc::new(AtomicBool::new(false)),
120 notify: Arc::new(Notify::new()),
121 }
122 }
123
124 /// Returns true if the token has been cancelled.
125 ///
126 /// Uses Ordering::SeqCst for strongest memory guarantees to ensure
127 /// cancellation is visible across all threads.
128 ///
129 /// # Example
130 ///
131 /// ```ignore
132 /// let source = CancellationTokenSource::new();
133 /// let token = source.token();
134 ///
135 /// assert!(!token.is_cancelled());
136 /// source.cancel();
137 /// assert!(token.is_cancelled());
138 /// ```
139 pub fn is_cancelled(&self) -> bool {
140 self.cancelled.load(Ordering::SeqCst)
141 }
142
143 /// Polls the cancellation state - semantic alias for [`is_cancelled()`].
144 ///
145 /// This method is intended for use in long-running loops where tasks
146 /// cooperatively check for cancellation. The naming makes the intent
147 /// clearer than [`is_cancelled()`] in polling contexts.
148 ///
149 /// # Example
150 ///
151 /// ```ignore
152 /// // Polling pattern in a loop
153 /// while !token.poll_cancelled() {
154 /// // Do work
155 /// tokio::time::sleep(Duration::from_millis(100)).await;
156 /// }
157 /// ```
158 pub fn poll_cancelled(&self) -> bool {
159 self.is_cancelled()
160 }
161
162 /// Async method that waits until the token is cancelled.
163 ///
164 /// This uses polling with tokio::time::sleep to avoid busy-waiting.
165 /// Multiple tasks can wait simultaneously.
166 ///
167 /// # Example
168 ///
169 /// ```ignore
170 /// // Wait for cancellation
171 /// token.wait_until_cancelled().await;
172 /// println!("Token was cancelled!");
173 /// ```
174 ///
175 /// # Use with tokio::select!
176 ///
177 /// ```ignore
178 /// tokio::select! {
179 /// _ = token.wait_until_cancelled() => {
180 /// println!("Cancelled!");
181 /// }
182 /// result = do_work() => {
183 /// println!("Work completed: {:?}", result);
184 /// }
185 /// }
186 /// ```
187 pub async fn wait_until_cancelled(&self) {
188 while !self.is_cancelled() {
189 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
190 }
191 }
192
193 /// Returns a Future that completes when this token is cancelled.
194 ///
195 /// This is equivalent to [`wait_until_cancelled()`] but returns a named future type
196 /// that can be stored and passed around.
197 ///
198 /// # Example
199 ///
200 /// ```ignore
201 /// let future = token.wait_cancelled();
202 /// // ... later
203 /// future.await;
204 /// ```
205 pub fn wait_cancelled(&self) -> impl Future<Output = ()> + Send + Sync + 'static {
206 let cancelled = self.cancelled.clone();
207 async move {
208 while !cancelled.load(Ordering::SeqCst) {
209 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
210 }
211 }
212 }
213}
214
215/// Source that owns a cancellation token and can trigger cancellation.
216///
217/// The CancellationTokenSource is the owner of the parent token and provides
218/// the cancel() method to set the cancellation state. When cancelled, all
219/// child tokens and clones will report as cancelled.
220///
221/// # Parent-Child Hierarchy
222///
223/// The source can create child tokens via child_token(), which allows for
224/// hierarchical cancellation. Children inherit the parent's cancellation state.
225///
226/// # Cloning
227///
228/// Cloning a source creates a new handle to the same underlying token.
229/// This allows multiple parts of the code to share cancellation control.
230///
231/// # Example
232///
233/// ```ignore
234/// let source = CancellationTokenSource::new();
235/// let token = source.token();
236///
237/// // Pass token to workflow
238/// tokio::spawn(async move {
239/// while !token.is_cancelled() {
240/// // Do work
241/// }
242/// });
243///
244/// // Cancel workflow from main thread
245/// source.cancel();
246/// ```
247#[derive(Clone)]
248pub struct CancellationTokenSource {
249 token: CancellationToken,
250}
251
252impl CancellationTokenSource {
253 /// Creates a new cancellation source with a fresh token.
254 ///
255 /// # Example
256 ///
257 /// ```ignore
258 /// let source = CancellationTokenSource::new();
259 /// let token = source.token();
260 /// assert!(!token.is_cancelled());
261 /// ```
262 pub fn new() -> Self {
263 Self {
264 token: CancellationToken::new(),
265 }
266 }
267
268 /// Returns a reference to the parent token.
269 ///
270 /// The token can be cloned and passed to tasks for cancellation checking.
271 ///
272 /// # Example
273 ///
274 /// ```ignore
275 /// let source = CancellationTokenSource::new();
276 /// let token = source.token();
277 /// let token2 = token.clone(); // Both reference same state
278 /// ```
279 pub fn token(&self) -> CancellationToken {
280 self.token.clone()
281 }
282
283 /// Cancels the token, propagating to all child tokens and clones.
284 ///
285 /// This method is idempotent - calling it multiple times has no additional effect.
286 /// All tasks waiting via [`wait_cancelled()`](CancellationToken::wait_cancelled) will be woken.
287 ///
288 /// # Example
289 ///
290 /// ```ignore
291 /// let source = CancellationTokenSource::new();
292 /// let token = source.token();
293 ///
294 /// source.cancel();
295 /// assert!(token.is_cancelled());
296 ///
297 /// source.cancel(); // Idempotent - no additional effect
298 /// assert!(token.is_cancelled());
299 /// ```
300 pub fn cancel(&self) {
301 self.token.cancelled.store(true, Ordering::SeqCst);
302 self.token.notify.notify_waiters();
303 }
304
305 /// Creates a child token that inherits parent cancellation.
306 ///
307 /// Child tokens check both their local state and the parent's state.
308 /// This allows for task-level cancellation independent of workflow cancellation.
309 ///
310 /// # Example
311 ///
312 /// ```ignore
313 /// let source = CancellationTokenSource::new();
314 /// let child = source.child_token();
315 ///
316 /// // Child inherits parent cancellation
317 /// source.cancel();
318 /// assert!(child.is_cancelled());
319 /// ```
320 pub fn child_token(&self) -> ChildToken {
321 ChildToken {
322 parent: self.token.clone(),
323 local_cancelled: Arc::new(AtomicBool::new(false)),
324 }
325 }
326}
327
328impl Default for CancellationTokenSource {
329 fn default() -> Self {
330 Self::new()
331 }
332}
333
334/// Child cancellation token that inherits from a parent.
335///
336/// Child tokens check both their local cancellation state and their parent's
337/// state. This allows for hierarchical cancellation where a task can be
338/// cancelled independently or inherit cancellation from its parent workflow.
339///
340/// # Cancellation Logic
341///
342/// A child token is cancelled if:
343/// - The parent token is cancelled, OR
344/// - The child's local cancel() method was called
345///
346/// # Example
347///
348/// ```ignore
349/// let source = CancellationTokenSource::new();
350/// let child = source.child_token();
351///
352/// // Child inherits parent cancellation
353/// source.cancel();
354/// assert!(child.is_cancelled());
355///
356/// // Or child can be cancelled independently
357/// let source2 = CancellationTokenSource::new();
358/// let child2 = source2.child_token();
359/// child2.cancel();
360/// assert!(child2.is_cancelled());
361/// assert!(!source2.token().is_cancelled());
362/// ```
363#[derive(Clone)]
364pub struct ChildToken {
365 parent: CancellationToken,
366 local_cancelled: Arc<AtomicBool>,
367}
368
369impl ChildToken {
370 /// Returns true if either parent or local token is cancelled.
371 ///
372 /// Checks both the parent token and local cancellation state.
373 ///
374 /// # Example
375 ///
376 /// ```ignore
377 /// let source = CancellationTokenSource::new();
378 /// let child = source.child_token();
379 ///
380 /// assert!(!child.is_cancelled());
381 ///
382 /// // Cancel parent
383 /// source.cancel();
384 /// assert!(child.is_cancelled());
385 /// ```
386 pub fn is_cancelled(&self) -> bool {
387 self.parent.is_cancelled() || self.local_cancelled.load(Ordering::SeqCst)
388 }
389
390 /// Cancels this child token locally.
391 ///
392 /// Local cancellation does not affect the parent token or other children.
393 ///
394 /// # Example
395 ///
396 /// ```ignore
397 /// let source = CancellationTokenSource::new();
398 /// let child1 = source.child_token();
399 /// let child2 = source.child_token();
400 ///
401 /// child1.cancel();
402 /// assert!(child1.is_cancelled());
403 /// assert!(!child2.is_cancelled()); // Other children unaffected
404 /// assert!(!source.token().is_cancelled()); // Parent unaffected
405 /// ```
406 pub fn cancel(&self) {
407 self.local_cancelled.store(true, Ordering::SeqCst);
408 }
409}
410
411impl std::fmt::Debug for ChildToken {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 f.debug_struct("ChildToken")
414 .field("parent_cancelled", &self.parent.is_cancelled())
415 .field("local_cancelled", &self.local_cancelled.load(Ordering::SeqCst))
416 .field("is_cancelled", &self.is_cancelled())
417 .finish()
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::{Duration, Instant};

    // ---- Basic token / source behavior ----

    /// A token from a fresh source starts out not cancelled.
    #[test]
    fn test_token_initially_not_cancelled() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        assert!(!tok.is_cancelled());
    }

    /// Cancelling the source is visible through a token obtained earlier.
    #[test]
    fn test_source_cancel_sets_token() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        src.cancel();

        assert!(tok.is_cancelled());
    }

    /// Clones of a token observe the same cancellation state.
    #[test]
    fn test_token_clone_shares_state() {
        let src = CancellationTokenSource::new();
        let original = src.token();
        let duplicate = original.clone();

        src.cancel();

        assert!(original.is_cancelled());
        assert!(duplicate.is_cancelled());
    }

    // ---- Child tokens ----

    /// A child becomes cancelled when its parent source is cancelled.
    #[test]
    fn test_child_token_inherits_parent_cancellation() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();
        assert!(!child.is_cancelled());

        src.cancel();

        assert!(child.is_cancelled());
    }

    /// Cancelling a child locally leaves the parent untouched.
    #[test]
    fn test_child_token_independent_cancel() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        child.cancel();

        assert!(child.is_cancelled());
        assert!(!src.token().is_cancelled());
    }

    /// Every child of a source sees the parent's cancellation.
    #[test]
    fn test_multiple_children_all_cancelled() {
        let src = CancellationTokenSource::new();
        let children = [src.child_token(), src.child_token(), src.child_token()];

        src.cancel();

        for child in &children {
            assert!(child.is_cancelled());
        }
    }

    /// Cancellation issued on one thread is observed by another.
    #[test]
    fn test_cancellation_thread_safe() {
        let src = CancellationTokenSource::new();
        let observer = src.token();

        // The worker spins until it sees the flag (busy wait, test only).
        let worker = thread::spawn(move || loop {
            if observer.is_cancelled() {
                break;
            }
        });

        // Let the worker get going before cancelling from this thread.
        thread::sleep(Duration::from_millis(10));
        src.cancel();

        // The join only returns once the worker has left its loop.
        worker.join().unwrap();
    }

    // ---- Debug output ----

    /// `Debug` output names the type both before and after cancellation.
    #[test]
    fn test_token_debug_display() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        let before = format!("{:?}", tok);
        assert!(before.contains("CancellationToken"));

        src.cancel();

        let after = format!("{:?}", tok);
        assert!(after.contains("CancellationToken"));
    }

    /// `Debug` output of a child reflects parent and local state.
    #[test]
    fn test_child_token_debug_display() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        let before = format!("{:?}", child);
        assert!(before.contains("ChildToken"));
        assert!(before.contains("parent_cancelled: false"));
        assert!(before.contains("local_cancelled: false"));

        src.cancel();

        let after = format!("{:?}", child);
        assert!(after.contains("parent_cancelled: true"));
    }

    // ---- Idempotence and defaults ----

    /// Repeated `cancel()` calls leave the token cancelled.
    #[test]
    fn test_source_cancel_idempotent() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        for _ in 0..3 {
            src.cancel();
        }

        assert!(tok.is_cancelled());
    }

    /// Cancelling both the parent and the child still reports cancelled.
    #[test]
    fn test_child_token_parent_and_local_both_cancelled() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        src.cancel();
        child.cancel();

        assert!(child.is_cancelled());
    }

    /// `Default` yields a source whose token is not cancelled.
    #[test]
    fn test_default_source() {
        let src = CancellationTokenSource::default();
        let tok = src.token();

        assert!(!tok.is_cancelled());
    }

    // ---- Cooperative cancellation ----

    /// `poll_cancelled` is false before any cancellation.
    #[test]
    fn test_poll_cancelled_returns_false_initially() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        assert!(!tok.poll_cancelled());
    }

    /// `poll_cancelled` is true once the source has been cancelled.
    #[test]
    fn test_poll_cancelled_returns_true_after_cancel() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        src.cancel();

        assert!(tok.poll_cancelled());
    }

    /// A waiter completes shortly after a delayed cancel.
    #[tokio::test]
    async fn test_wait_cancelled_completes_on_cancel() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Background task cancels after a short delay.
        let canceller = src.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            canceller.cancel();
        });

        // The wait should finish within 200ms.
        let started = Instant::now();
        tok.wait_cancelled().await;
        let waited = started.elapsed();

        assert!(waited < Duration::from_millis(200));
        assert!(tok.is_cancelled());
    }

    /// Several concurrent waiters all complete after a single cancel.
    #[tokio::test]
    async fn test_wait_cancelled_multiple_waiters() {
        let src = CancellationTokenSource::new();
        let tok_a = src.token();
        let tok_b = src.token();
        let tok_c = src.token();

        let waiter_a = tokio::spawn(async move {
            tok_a.wait_cancelled().await;
        });
        let waiter_b = tokio::spawn(async move {
            tok_b.wait_cancelled().await;
        });
        let waiter_c = tokio::spawn(async move {
            tok_c.wait_cancelled().await;
        });

        // Cancel after a delay.
        tokio::time::sleep(Duration::from_millis(50)).await;
        src.cancel();

        // All waiters should now finish promptly.
        let started = Instant::now();
        let (res_a, res_b, res_c) = tokio::join!(waiter_a, waiter_b, waiter_c);
        let waited = started.elapsed();

        assert!(res_a.is_ok());
        assert!(res_b.is_ok());
        assert!(res_c.is_ok());
        assert!(waited < Duration::from_millis(200));
    }

    /// Waiting on an already-cancelled token returns immediately, repeatedly.
    #[tokio::test]
    async fn test_wait_cancelled_idempotent() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        src.cancel();

        let started = Instant::now();
        for _ in 0..3 {
            tok.wait_cancelled().await;
        }
        let waited = started.elapsed();

        assert!(waited < Duration::from_millis(10));
    }

    /// A polling task stops early once the source is cancelled.
    #[tokio::test]
    async fn test_cooperative_cancellation_pattern() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Task that cooperatively polls between 10ms units of work.
        let polled = tok.clone();
        let worker = tokio::spawn(async move {
            let mut rounds = 0;
            loop {
                if polled.poll_cancelled() {
                    break rounds;
                }
                rounds += 1;
                tokio::time::sleep(Duration::from_millis(10)).await;

                // Safety limit so the test cannot loop forever.
                if rounds >= 100 {
                    break rounds;
                }
            }
        });

        // Cancel after 50ms.
        tokio::time::sleep(Duration::from_millis(50)).await;
        src.cancel();

        let rounds = worker.await.unwrap();
        assert!(rounds < 100); // Stopped before the safety limit
        assert!(rounds > 2); // But did some work first
    }

    // ---- Integration with WorkflowExecutor ----

    /// Cancelling before execution yields a failed run with no completed
    /// tasks and a `WorkflowCancelled` audit event.
    #[tokio::test]
    async fn test_workflow_cancellation_with_executor() {
        use crate::workflow::dag::Workflow;
        use crate::workflow::executor::WorkflowExecutor;
        use crate::workflow::task::{TaskContext, TaskId, TaskResult, WorkflowTask};
        use async_trait::async_trait;

        // Minimal task: sleeps briefly, then succeeds.
        struct SimpleTask {
            id: TaskId,
            name: String,
        }

        #[async_trait]
        impl WorkflowTask for SimpleTask {
            async fn execute(
                &self,
                _context: &TaskContext,
            ) -> Result<TaskResult, crate::workflow::TaskError> {
                // Simulate some work
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(TaskResult::Success)
            }

            fn id(&self) -> TaskId {
                self.id.clone()
            }

            fn name(&self) -> &str {
                &self.name
            }
        }

        // Workflow with 5 tasks.
        let mut workflow = Workflow::new();
        for index in 1..=5 {
            let task = SimpleTask {
                id: TaskId::new(format!("task-{}", index)),
                name: format!("Task {}", index),
            };
            workflow.add_task(Box::new(task));
        }

        // Attach a cancellation source and cancel before running.
        let src = CancellationTokenSource::new();
        let mut executor = WorkflowExecutor::new(workflow).with_cancellation_source(src);
        executor.cancel();

        // Execution should report cancellation straight away.
        let outcome = executor.execute().await.unwrap();
        assert!(!outcome.success);
        assert_eq!(outcome.completed_tasks.len(), 0);

        // The cancellation must appear in the audit log.
        let events = executor.audit_log().replay();
        let saw_cancellation = events
            .iter()
            .any(|e| matches!(e, crate::audit::AuditEvent::WorkflowCancelled { .. }));
        assert!(saw_cancellation);
    }
}