Skip to main content

forge_agent/workflow/
cancellation.rs

//! Async cancellation token system with parent-child hierarchy.
//!
//! Provides a cooperative cancellation model for workflows and tasks,
//! allowing users to cancel running operations and propagate cancellation
//! to all active tasks through a parent-child token hierarchy.
//!
//! # Architecture
//!
//! The cancellation system consists of three core types:
//! - [`CancellationToken`]: Thread-safe token representing cancellation state
//! - [`CancellationTokenSource`]: Owner of the parent token with cancel() method
//! - [`ChildToken`]: Derived child token for task-level cancellation
//!
//! # Cooperative Cancellation Patterns
//!
//! This module supports two main patterns for cooperative cancellation:
//!
//! ## 1. Polling Pattern
//!
//! Poll the cancellation token in long-running loops:
//!
//! ```ignore
//! while !token.poll_cancelled() {
//!     // Do work
//!     tokio::time::sleep(Duration::from_millis(100)).await;
//! }
//! ```
//!
//! ## 2. Async Wait Pattern
//!
//! Wait for cancellation signal asynchronously:
//!
//! ```ignore
//! tokio::select! {
//!     _ = token.wait_cancelled() => {
//!         // Handle cancellation
//!     }
//!     result = do_work() => {
//!         // Handle completion
//!     }
//! }
//! ```
//!
//! # Example
//!
//! ```ignore
//! use forge_agent::workflow::{CancellationTokenSource, CancellationToken};
//!
//! // Create cancellation source for workflow
//! let source = CancellationTokenSource::new();
//! let token = source.token();
//!
//! // Pass token to tasks
//! tokio::spawn(async move {
//!     while !token.is_cancelled() {
//!         // Do work
//!     }
//! });
//!
//! // Cancel from anywhere
//! source.cancel();
//! ```
63
64use std::future::Future;
65use std::sync::atomic::{AtomicBool, Ordering};
66use std::sync::Arc;
67use tokio::sync::Notify;
68
69/// Thread-safe cancellation token.
70///
71/// Wraps an Arc<AtomicBool> for thread-safe cancellation state.
72/// Tokens can be cheaply cloned and shared across tasks.
73///
74/// # Cooperative Cancellation
75///
76/// Tasks can cooperatively respond to cancellation by:
77/// - Polling with [`poll_cancelled()`](Self::poll_cancelled) in loops
78/// - Awaiting with [`wait_cancelled()`](Self::wait_cancelled) in async contexts
79///
80/// # Examples
81///
82/// See the [`examples`](crate::workflow::examples) module for complete cancellation-aware task
83/// implementations demonstrating:
84/// - [`CancellationAwareTask`](crate::workflow::examples::CancellationAwareTask): Polling pattern
85/// - [`PollingTask`](crate::workflow::examples::PollingTask): tokio::select! pattern
86/// - [`TimeoutAndCancellationTask`](crate::workflow::examples::TimeoutAndCancellationTask): Timeout + cancellation
87///
88/// # Cloning
89///
90/// Cloning a token creates a new reference to the same cancellation state.
91/// When any token is cancelled (via CancellationTokenSource), all clones
92/// will report as cancelled.
93///
94/// # Example
95///
96/// ```ignore
97/// let source = CancellationTokenSource::new();
98/// let token1 = source.token();
99/// let token2 = token1.clone(); // Same state
100///
101/// assert!(!token1.is_cancelled());
102/// assert!(!token2.is_cancelled());
103///
104/// source.cancel();
105///
106/// assert!(token1.is_cancelled());
107/// assert!(token2.is_cancelled());
108/// ```
109#[derive(Clone, Debug)]
110pub struct CancellationToken {
111    cancelled: Arc<AtomicBool>,
112    notify: Arc<Notify>,
113}
114
115impl CancellationToken {
116    /// Creates a new non-cancelled token.
117    pub(crate) fn new() -> Self {
118        Self {
119            cancelled: Arc::new(AtomicBool::new(false)),
120            notify: Arc::new(Notify::new()),
121        }
122    }
123
124    /// Returns true if the token has been cancelled.
125    ///
126    /// Uses Ordering::SeqCst for strongest memory guarantees to ensure
127    /// cancellation is visible across all threads.
128    ///
129    /// # Example
130    ///
131    /// ```ignore
132    /// let source = CancellationTokenSource::new();
133    /// let token = source.token();
134    ///
135    /// assert!(!token.is_cancelled());
136    /// source.cancel();
137    /// assert!(token.is_cancelled());
138    /// ```
139    pub fn is_cancelled(&self) -> bool {
140        self.cancelled.load(Ordering::SeqCst)
141    }
142
143    /// Polls the cancellation state - semantic alias for [`is_cancelled()`].
144    ///
145    /// This method is intended for use in long-running loops where tasks
146    /// cooperatively check for cancellation. The naming makes the intent
147    /// clearer than [`is_cancelled()`] in polling contexts.
148    ///
149    /// # Example
150    ///
151    /// ```ignore
152    /// // Polling pattern in a loop
153    /// while !token.poll_cancelled() {
154    ///     // Do work
155    ///     tokio::time::sleep(Duration::from_millis(100)).await;
156    /// }
157    /// ```
158    pub fn poll_cancelled(&self) -> bool {
159        self.is_cancelled()
160    }
161
162    /// Async method that waits until the token is cancelled.
163    ///
164    /// This uses polling with tokio::time::sleep to avoid busy-waiting.
165    /// Multiple tasks can wait simultaneously.
166    ///
167    /// # Example
168    ///
169    /// ```ignore
170    /// // Wait for cancellation
171    /// token.wait_until_cancelled().await;
172    /// println!("Token was cancelled!");
173    /// ```
174    ///
175    /// # Use with tokio::select!
176    ///
177    /// ```ignore
178    /// tokio::select! {
179    ///     _ = token.wait_until_cancelled() => {
180    ///         println!("Cancelled!");
181    ///     }
182    ///     result = do_work() => {
183    ///         println!("Work completed: {:?}", result);
184    ///     }
185    /// }
186    /// ```
187    pub async fn wait_until_cancelled(&self) {
188        while !self.is_cancelled() {
189            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
190        }
191    }
192
193    /// Returns a Future that completes when this token is cancelled.
194    ///
195    /// This is equivalent to [`wait_until_cancelled()`] but returns a named future type
196    /// that can be stored and passed around.
197    ///
198    /// # Example
199    ///
200    /// ```ignore
201    /// let future = token.wait_cancelled();
202    /// // ... later
203    /// future.await;
204    /// ```
205    pub fn wait_cancelled(&self) -> impl Future<Output = ()> + Send + Sync + 'static {
206        let cancelled = self.cancelled.clone();
207        async move {
208            while !cancelled.load(Ordering::SeqCst) {
209                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
210            }
211        }
212    }
213}
214
215/// Source that owns a cancellation token and can trigger cancellation.
216///
217/// The CancellationTokenSource is the owner of the parent token and provides
218/// the cancel() method to set the cancellation state. When cancelled, all
219/// child tokens and clones will report as cancelled.
220///
221/// # Parent-Child Hierarchy
222///
223/// The source can create child tokens via child_token(), which allows for
224/// hierarchical cancellation. Children inherit the parent's cancellation state.
225///
226/// # Cloning
227///
228/// Cloning a source creates a new handle to the same underlying token.
229/// This allows multiple parts of the code to share cancellation control.
230///
231/// # Example
232///
233/// ```ignore
234/// let source = CancellationTokenSource::new();
235/// let token = source.token();
236///
237/// // Pass token to workflow
238/// tokio::spawn(async move {
239///     while !token.is_cancelled() {
240///         // Do work
241///     }
242/// });
243///
244/// // Cancel workflow from main thread
245/// source.cancel();
246/// ```
247#[derive(Clone)]
248pub struct CancellationTokenSource {
249    token: CancellationToken,
250}
251
252impl CancellationTokenSource {
253    /// Creates a new cancellation source with a fresh token.
254    ///
255    /// # Example
256    ///
257    /// ```ignore
258    /// let source = CancellationTokenSource::new();
259    /// let token = source.token();
260    /// assert!(!token.is_cancelled());
261    /// ```
262    pub fn new() -> Self {
263        Self {
264            token: CancellationToken::new(),
265        }
266    }
267
268    /// Returns a reference to the parent token.
269    ///
270    /// The token can be cloned and passed to tasks for cancellation checking.
271    ///
272    /// # Example
273    ///
274    /// ```ignore
275    /// let source = CancellationTokenSource::new();
276    /// let token = source.token();
277    /// let token2 = token.clone(); // Both reference same state
278    /// ```
279    pub fn token(&self) -> CancellationToken {
280        self.token.clone()
281    }
282
283    /// Cancels the token, propagating to all child tokens and clones.
284    ///
285    /// This method is idempotent - calling it multiple times has no additional effect.
286    /// All tasks waiting via [`wait_cancelled()`](CancellationToken::wait_cancelled) will be woken.
287    ///
288    /// # Example
289    ///
290    /// ```ignore
291    /// let source = CancellationTokenSource::new();
292    /// let token = source.token();
293    ///
294    /// source.cancel();
295    /// assert!(token.is_cancelled());
296    ///
297    /// source.cancel(); // Idempotent - no additional effect
298    /// assert!(token.is_cancelled());
299    /// ```
300    pub fn cancel(&self) {
301        self.token.cancelled.store(true, Ordering::SeqCst);
302        self.token.notify.notify_waiters();
303    }
304
305    /// Creates a child token that inherits parent cancellation.
306    ///
307    /// Child tokens check both their local state and the parent's state.
308    /// This allows for task-level cancellation independent of workflow cancellation.
309    ///
310    /// # Example
311    ///
312    /// ```ignore
313    /// let source = CancellationTokenSource::new();
314    /// let child = source.child_token();
315    ///
316    /// // Child inherits parent cancellation
317    /// source.cancel();
318    /// assert!(child.is_cancelled());
319    /// ```
320    pub fn child_token(&self) -> ChildToken {
321        ChildToken {
322            parent: self.token.clone(),
323            local_cancelled: Arc::new(AtomicBool::new(false)),
324        }
325    }
326}
327
328impl Default for CancellationTokenSource {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334/// Child cancellation token that inherits from a parent.
335///
336/// Child tokens check both their local cancellation state and their parent's
337/// state. This allows for hierarchical cancellation where a task can be
338/// cancelled independently or inherit cancellation from its parent workflow.
339///
340/// # Cancellation Logic
341///
342/// A child token is cancelled if:
343/// - The parent token is cancelled, OR
344/// - The child's local cancel() method was called
345///
346/// # Example
347///
348/// ```ignore
349/// let source = CancellationTokenSource::new();
350/// let child = source.child_token();
351///
352/// // Child inherits parent cancellation
353/// source.cancel();
354/// assert!(child.is_cancelled());
355///
356/// // Or child can be cancelled independently
357/// let source2 = CancellationTokenSource::new();
358/// let child2 = source2.child_token();
359/// child2.cancel();
360/// assert!(child2.is_cancelled());
361/// assert!(!source2.token().is_cancelled());
362/// ```
363#[derive(Clone)]
364pub struct ChildToken {
365    parent: CancellationToken,
366    local_cancelled: Arc<AtomicBool>,
367}
368
369impl ChildToken {
370    /// Returns true if either parent or local token is cancelled.
371    ///
372    /// Checks both the parent token and local cancellation state.
373    ///
374    /// # Example
375    ///
376    /// ```ignore
377    /// let source = CancellationTokenSource::new();
378    /// let child = source.child_token();
379    ///
380    /// assert!(!child.is_cancelled());
381    ///
382    /// // Cancel parent
383    /// source.cancel();
384    /// assert!(child.is_cancelled());
385    /// ```
386    pub fn is_cancelled(&self) -> bool {
387        self.parent.is_cancelled() || self.local_cancelled.load(Ordering::SeqCst)
388    }
389
390    /// Cancels this child token locally.
391    ///
392    /// Local cancellation does not affect the parent token or other children.
393    ///
394    /// # Example
395    ///
396    /// ```ignore
397    /// let source = CancellationTokenSource::new();
398    /// let child1 = source.child_token();
399    /// let child2 = source.child_token();
400    ///
401    /// child1.cancel();
402    /// assert!(child1.is_cancelled());
403    /// assert!(!child2.is_cancelled()); // Other children unaffected
404    /// assert!(!source.token().is_cancelled()); // Parent unaffected
405    /// ```
406    pub fn cancel(&self) {
407        self.local_cancelled.store(true, Ordering::SeqCst);
408    }
409}
410
411impl std::fmt::Debug for ChildToken {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        f.debug_struct("ChildToken")
414            .field("parent_cancelled", &self.parent.is_cancelled())
415            .field("local_cancelled", &self.local_cancelled.load(Ordering::SeqCst))
416            .field("is_cancelled", &self.is_cancelled())
417            .finish()
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// A token taken from a fresh source starts out non-cancelled.
    #[test]
    fn test_token_initially_not_cancelled() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        assert!(!tok.is_cancelled());
    }

    /// Cancelling the source is visible through its token.
    #[test]
    fn test_source_cancel_sets_token() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        src.cancel();
        assert!(tok.is_cancelled());
    }

    /// Clones of a token observe the same cancellation state.
    #[test]
    fn test_token_clone_shares_state() {
        let src = CancellationTokenSource::new();
        let original = src.token();
        let duplicate = original.clone();

        src.cancel();

        // Cancellation must show up on the original and on the clone.
        assert!(original.is_cancelled());
        assert!(duplicate.is_cancelled());
    }

    /// A child token reports cancelled once its parent is cancelled.
    #[test]
    fn test_child_token_inherits_parent_cancellation() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        assert!(!child.is_cancelled());

        // Parent cancellation propagates to the child.
        src.cancel();
        assert!(child.is_cancelled());
    }

    /// Cancelling a child leaves the parent untouched.
    #[test]
    fn test_child_token_independent_cancel() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        // Local cancellation only.
        child.cancel();
        assert!(child.is_cancelled());

        // The parent token is still live.
        assert!(!src.token().is_cancelled());
    }

    /// One parent cancellation reaches every child.
    #[test]
    fn test_multiple_children_all_cancelled() {
        let src = CancellationTokenSource::new();
        let first = src.child_token();
        let second = src.child_token();
        let third = src.child_token();

        // Cancel the shared parent.
        src.cancel();

        // Each child must see it.
        assert!(first.is_cancelled());
        assert!(second.is_cancelled());
        assert!(third.is_cancelled());
    }

    /// Cancellation issued on one thread is observed on another.
    #[test]
    fn test_cancellation_thread_safe() {
        use std::thread;
        use std::time::Duration;

        let src = CancellationTokenSource::new();
        let tok = src.token();
        let observer = tok.clone();

        // Worker spins until it observes the cancellation.
        let worker = thread::spawn(move || {
            while !observer.is_cancelled() {
                // Busy wait (acceptable in a test only)
            }
            // Falling out of the loop ends the thread.
        });

        // Let the worker get going.
        thread::sleep(Duration::from_millis(10));

        // Cancel from the main thread.
        src.cancel();

        // The worker must terminate.
        worker.join().unwrap();
    }

    /// `Debug` output of a token names the type before and after cancel.
    #[test]
    fn test_token_debug_display() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Before cancellation.
        let before = format!("{:?}", tok);
        assert!(before.contains("CancellationToken"));

        // After cancellation.
        src.cancel();
        let after = format!("{:?}", tok);
        assert!(after.contains("CancellationToken"));
    }

    /// `Debug` output of a child token exposes its cancellation state.
    #[test]
    fn test_child_token_debug_display() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        // Initial state is rendered.
        let before = format!("{:?}", child);
        assert!(before.contains("ChildToken"));
        assert!(before.contains("parent_cancelled: false"));
        assert!(before.contains("local_cancelled: false"));

        // Parent cancellation is reflected.
        src.cancel();
        let after = format!("{:?}", child);
        assert!(after.contains("parent_cancelled: true"));
    }

    /// Calling `cancel()` repeatedly keeps the token cancelled.
    #[test]
    fn test_source_cancel_idempotent() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Three cancels in a row.
        src.cancel();
        src.cancel();
        src.cancel();

        // State is unchanged: still cancelled.
        assert!(tok.is_cancelled());
    }

    /// Cancelling both parent and child still yields a cancelled child.
    #[test]
    fn test_child_token_parent_and_local_both_cancelled() {
        let src = CancellationTokenSource::new();
        let child = src.child_token();

        // Cancel at both levels.
        src.cancel();
        child.cancel();

        // The child remains cancelled.
        assert!(child.is_cancelled());
    }

    /// `Default` produces a source with a non-cancelled token.
    #[test]
    fn test_default_source() {
        let src = CancellationTokenSource::default();
        let tok = src.token();

        assert!(!tok.is_cancelled());
    }

    // Tests for cooperative cancellation

    /// `poll_cancelled()` is false on a fresh token.
    #[test]
    fn test_poll_cancelled_returns_false_initially() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        assert!(!tok.poll_cancelled());
    }

    /// `poll_cancelled()` is true after the source cancels.
    #[test]
    fn test_poll_cancelled_returns_true_after_cancel() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        src.cancel();
        assert!(tok.poll_cancelled());
    }

    /// `wait_cancelled()` resolves shortly after a delayed cancel.
    #[tokio::test]
    async fn test_wait_cancelled_completes_on_cancel() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Background task cancels after a short delay.
        let canceller = src.clone();
        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
            canceller.cancel();
        });

        // The wait must finish in well under 200ms.
        let began = std::time::Instant::now();
        tok.wait_cancelled().await;
        let waited = began.elapsed();

        assert!(waited < tokio::time::Duration::from_millis(200));
        assert!(tok.is_cancelled());
    }

    /// Several concurrent waiters all complete after one cancel.
    #[tokio::test]
    async fn test_wait_cancelled_multiple_waiters() {
        let src = CancellationTokenSource::new();
        let tok_a = src.token();
        let tok_b = src.token();
        let tok_c = src.token();

        // One waiting task per token.
        let waiter_a = tokio::spawn(async move {
            tok_a.wait_cancelled().await;
        });

        let waiter_b = tokio::spawn(async move {
            tok_b.wait_cancelled().await;
        });

        let waiter_c = tokio::spawn(async move {
            tok_c.wait_cancelled().await;
        });

        // Delay, then cancel.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        src.cancel();

        // Every waiter must finish promptly.
        let began = std::time::Instant::now();
        let (out_a, out_b, out_c) = tokio::join!(waiter_a, waiter_b, waiter_c);
        let waited = began.elapsed();

        assert!(out_a.is_ok());
        assert!(out_b.is_ok());
        assert!(out_c.is_ok());
        assert!(waited < tokio::time::Duration::from_millis(200));
    }

    /// Waiting on an already-cancelled token returns immediately, each time.
    #[tokio::test]
    async fn test_wait_cancelled_idempotent() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Cancel up front.
        src.cancel();

        // Back-to-back waits must not block.
        let began = std::time::Instant::now();
        tok.wait_cancelled().await;
        tok.wait_cancelled().await;
        tok.wait_cancelled().await;
        let waited = began.elapsed();

        assert!(waited < tokio::time::Duration::from_millis(10));
    }

    /// A polling task stops early once the source cancels.
    #[tokio::test]
    async fn test_cooperative_cancellation_pattern() {
        let src = CancellationTokenSource::new();
        let tok = src.token();

        // Task that cooperatively polls the token between units of work.
        let worker_tok = tok.clone();
        let worker = tokio::spawn(async move {
            let mut rounds = 0;
            loop {
                if worker_tok.poll_cancelled() {
                    break;
                }
                rounds += 1;
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

                // Hard cap so a broken token cannot hang the test.
                if rounds >= 100 {
                    break;
                }
            }
            rounds
        });

        // Cancel 50ms in.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        src.cancel();

        // The task must have stopped before the cap, after doing some work.
        let rounds = worker.await.unwrap();
        assert!(rounds < 100); // Did not run all 100 iterations
        assert!(rounds > 2); // Made some progress first
    }

    // Integration test with WorkflowExecutor

    /// A workflow cancelled before execution completes no tasks and records
    /// the cancellation in the audit log.
    #[tokio::test]
    async fn test_workflow_cancellation_with_executor() {
        use crate::workflow::dag::Workflow;
        use crate::workflow::executor::WorkflowExecutor;
        use crate::workflow::task::{TaskContext, TaskId, TaskResult, WorkflowTask};
        use async_trait::async_trait;

        // Minimal task used only by this test.
        struct SimpleTask {
            id: TaskId,
            name: String,
        }

        #[async_trait]
        impl WorkflowTask for SimpleTask {
            async fn execute(
                &self,
                _context: &TaskContext,
            ) -> Result<TaskResult, crate::workflow::TaskError> {
                // Stand-in for real work.
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                Ok(TaskResult::Success)
            }

            fn id(&self) -> TaskId {
                self.id.clone()
            }

            fn name(&self) -> &str {
                &self.name
            }
        }

        // Workflow made of 5 tasks.
        let mut workflow = Workflow::new();
        for n in 1..=5 {
            let task = SimpleTask {
                id: TaskId::new(format!("task-{}", n)),
                name: format!("Task {}", n),
            };
            workflow.add_task(Box::new(task));
        }

        // Attach a cancellation source to the executor.
        let src = CancellationTokenSource::new();
        let mut executor = WorkflowExecutor::new(workflow).with_cancellation_source(src);

        // Cancel before anything runs.
        executor.cancel();

        // Execution should observe the cancellation straight away.
        let outcome = executor.execute().await.unwrap();

        // Nothing completed, and the run is not a success.
        assert!(!outcome.success);
        assert_eq!(outcome.completed_tasks.len(), 0);

        // The audit log must contain a cancellation event.
        let events = executor.audit_log().replay();
        assert!(events
            .iter()
            .any(|e| matches!(e, crate::audit::AuditEvent::WorkflowCancelled { .. })));
    }
}