ferrous_di/
cancellation.rs

1//! Cancellation token support for workflow engines.
2//!
3//! This module provides cancellation token primitives that are essential
4//! for n8n-style workflow engines where nodes need abort capabilities.
5
6use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
7use std::time::{Duration, Instant};
8
/// A token that can be used to signal cancellation across async operations.
///
/// Essential for workflow engines where nodes need abort capabilities.
/// The token is designed to be DI-visible and propagate through scopes.
///
/// Cloning is cheap: every clone shares the same underlying state, so
/// cancelling one clone is observed by all of them.
///
/// # Examples
///
/// ```
/// use ferrous_di::{ServiceCollection, CancellationToken, Resolver};
/// use std::sync::Arc;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let mut services = ServiceCollection::new();
///
/// // Register cancellation token as scoped resource
/// services.add_scoped_factory::<CancellationToken, _>(|_| {
///     CancellationToken::new()
/// });
///
/// let provider = services.build();
/// let scope = provider.create_scope();
///
/// let cancel_token = scope.get_required::<CancellationToken>();
///
/// // Check if cancelled
/// if cancel_token.is_cancelled() {
///     return Err("Operation cancelled".into());
/// }
///
/// # #[cfg(feature = "async")]
/// // Use in async operation with tokio
/// tokio::select! {
///     result = some_long_operation() => {
///         // Operation completed
///     }
///     _ = cancel_token.cancelled() => {
///         // Operation was cancelled
///         return Err("Operation cancelled".into());
///     }
/// }
/// # Ok(())
/// # }
///
/// # async fn some_long_operation() {}
/// ```
#[derive(Debug, Clone)]
pub struct CancellationToken {
    /// Shared state; all clones of a token point at the same allocation.
    inner: Arc<CancellationTokenInner>,
}

/// Shared state behind a [`CancellationToken`].
#[derive(Debug)]
struct CancellationTokenInner {
    /// Set once this token itself has been cancelled.
    cancelled: AtomicBool,
    /// The token this one was derived from, if any.
    parent: Option<CancellationToken>,
    /// Moment the token was created.
    created_at: Instant,
}
64
65impl CancellationToken {
66    /// Creates a new cancellation token.
67    pub fn new() -> Self {
68        Self {
69            inner: Arc::new(CancellationTokenInner {
70                cancelled: AtomicBool::new(false),
71                parent: None,
72                created_at: Instant::now(),
73            })
74        }
75    }
76
77    /// Creates a child token that will be cancelled when either this token
78    /// or the parent token is cancelled.
79    ///
80    /// Perfect for hierarchical workflow cancellation (flow → run → node).
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use ferrous_di::CancellationToken;
86    ///
87    /// let parent_token = CancellationToken::new();
88    /// let child_token = parent_token.child_token();
89    ///
90    /// parent_token.cancel();
91    /// assert!(child_token.is_cancelled());
92    /// ```
93    pub fn child_token(&self) -> Self {
94        Self {
95            inner: Arc::new(CancellationTokenInner {
96                cancelled: AtomicBool::new(false),
97                parent: Some(self.clone()),
98                created_at: Instant::now(),
99            })
100        }
101    }
102
103    /// Cancels the token, signaling that associated operations should stop.
104    pub fn cancel(&self) {
105        self.inner.cancelled.store(true, Ordering::Release);
106    }
107
108    /// Returns true if cancellation has been requested.
109    pub fn is_cancelled(&self) -> bool {
110        // Check self first
111        if self.inner.cancelled.load(Ordering::Acquire) {
112            return true;
113        }
114        
115        // Check parent chain
116        if let Some(ref parent) = self.inner.parent {
117            return parent.is_cancelled();
118        }
119        
120        false
121    }
122
123    /// Throws a cancellation error if the token is cancelled.
124    ///
125    /// # Errors
126    ///
127    /// Returns `Err` with a cancellation message if the token is cancelled.
128    pub fn throw_if_cancelled(&self) -> Result<(), CancellationError> {
129        if self.is_cancelled() {
130            Err(CancellationError::new("Operation was cancelled"))
131        } else {
132            Ok(())
133        }
134    }
135
136    /// Returns a future that completes when cancellation is requested.
137    ///
138    /// Perfect for use with `tokio::select!` to race against operations.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// use ferrous_di::CancellationToken;
144    ///
145    /// # async fn example() {
146    /// let token = CancellationToken::new();
147    /// 
148    /// # #[cfg(feature = "async")]
149    /// tokio::select! {
150    ///     result = some_operation() => {
151    ///         // Operation completed normally
152    ///     }
153    ///     _ = token.cancelled() => {
154    ///         // Operation was cancelled
155    ///     }
156    /// }
157    /// # }
158    /// 
159    /// # async fn some_operation() {}
160    /// ```
161    #[cfg(feature = "async")]
162    pub async fn cancelled(&self) {
163        loop {
164            if self.is_cancelled() {
165                return;
166            }
167            
168            // Small delay to avoid busy waiting
169            tokio::time::sleep(Duration::from_millis(1)).await;
170        }
171    }
172
173    /// Returns the elapsed time since this token was created.
174    ///
175    /// Useful for timeout-based cancellation in workflow engines.
176    pub fn elapsed(&self) -> Duration {
177        self.inner.created_at.elapsed()
178    }
179
180    /// Creates a token that will automatically cancel after the specified duration.
181    ///
182    /// Perfect for implementing timeouts in workflow nodes.
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// use ferrous_di::CancellationToken;
188    /// use std::time::Duration;
189    ///
190    /// # async fn example() {
191    /// let token = CancellationToken::with_timeout(Duration::from_secs(30));
192    /// 
193    /// // Token will automatically cancel after 30 seconds
194    /// # #[cfg(feature = "async")]
195    /// tokio::select! {
196    ///     result = long_running_operation() => {
197    ///         // Completed within timeout
198    ///     }
199    ///     _ = token.cancelled() => {
200    ///         // Timed out after 30 seconds
201    ///     }
202    /// }
203    /// # }
204    /// 
205    /// # async fn long_running_operation() {}
206    /// ```
207    #[cfg(feature = "async")]
208    pub fn with_timeout(timeout: Duration) -> Self {
209        let token = Self::new();
210        let token_clone = token.clone();
211        
212        tokio::spawn(async move {
213            tokio::time::sleep(timeout).await;
214            token_clone.cancel();
215        });
216        
217        token
218    }
219}
220
221impl Default for CancellationToken {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
/// Error type for cancellation operations.
///
/// Carries a human-readable message. Equality is implemented so that results
/// holding this error can be compared directly (for example with
/// `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CancellationError {
    /// Text appended to the `Display` output.
    message: String,
}
232
233impl CancellationError {
234    /// Creates a new cancellation error with the given message.
235    pub fn new(message: impl Into<String>) -> Self {
236        Self {
237            message: message.into(),
238        }
239    }
240}
241
242impl std::fmt::Display for CancellationError {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        write!(f, "Cancellation error: {}", self.message)
245    }
246}
247
// Empty impl: the provided `Error` methods are used as-is, so no source error is reported.
impl std::error::Error for CancellationError {}
249
/// Extension trait for Scope to easily create child scopes with cancellation.
///
/// Implemented in this module for `crate::provider::Scope`.
pub trait ScopeCancellationExt {
    /// Creates a child scope with a cancellation token derived from the parent scope.
    ///
    /// The parent (`self`) is only borrowed; the newly created child scope is
    /// returned by value.
    ///
    /// Perfect for n8n-style hierarchical cancellation (workflow → run → node).
    ///
    /// # Examples
    ///
    /// ```
    /// use ferrous_di::{ServiceCollection, CancellationToken, ScopeCancellationExt, Resolver};
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut services = ServiceCollection::new();
    /// services.add_scoped_factory::<CancellationToken, _>(|_| CancellationToken::new());
    ///
    /// let provider = services.build();
    /// let parent_scope = provider.create_scope();
    ///
    /// // Child scope inherits cancellation from parent
    /// let child_scope = parent_scope.with_cancellation_from_parent();
    ///
    /// let parent_token = parent_scope.get_required::<CancellationToken>();
    /// let child_token = child_scope.get_required::<CancellationToken>();
    ///
    /// parent_token.cancel();
    /// assert!(child_token.is_cancelled()); // Child is cancelled when parent is
    /// # Ok(())
    /// # }
    /// ```
    fn with_cancellation_from_parent(&self) -> Self;
}
281
282impl ScopeCancellationExt for crate::provider::Scope {
283    fn with_cancellation_from_parent(&self) -> Self {
284        use std::sync::Arc;
285        use crate::traits::Resolver;
286        
287        // Create a child scope
288        let child_scope = self.create_child();
289        
290        // Get parent cancellation token if it exists
291        let parent_token = self.get::<CancellationToken>().unwrap_or_else(|_| {
292            // No parent token, create a new root token
293            Arc::new(CancellationToken::new())
294        });
295        
296        // Create child token that will be cancelled when parent is cancelled
297        let child_token = parent_token.child_token();
298        
299        // We need to inject the child token into the child scope
300        // Since we can't modify the service registration after the provider is built,
301        // we'll use a different approach: manually cache the token in the scope
302        
303        // Store the child token in the child scope's scoped storage
304        // This leverages the existing scoped caching mechanism
305        let token_key = crate::key::key_of_type::<CancellationToken>();
306        
307        #[cfg(feature = "once-cell")]
308        {
309            // Find the slot for CancellationToken in the registry
310            if let Some(reg) = child_scope.root.inner().registry.registrations.get(&token_key) {
311                if let Some(slot) = reg.scoped_slot {
312                    // Initialize the slot with our child token
313                    let _ = child_scope.scoped_cells[slot].set(Arc::new(child_token) as crate::registration::AnyArc);
314                }
315            }
316        }
317        
318        #[cfg(not(feature = "once-cell"))]
319        {
320            // Use HashMap-based scoped storage
321            let mut scoped = child_scope.scoped.lock().unwrap();
322            scoped.insert(token_key, Arc::new(child_token) as crate::registration::AnyArc);
323        }
324        
325        child_scope
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A fresh token is live until `cancel` is called on it.
    #[test]
    fn test_cancellation_token_basic() {
        let tok = CancellationToken::new();
        assert!(!tok.is_cancelled());

        tok.cancel();
        assert!(tok.is_cancelled());
    }

    /// Cancelling the parent is visible through the child.
    #[test]
    fn test_child_token_cancellation() {
        let root = CancellationToken::new();
        let leaf = root.child_token();

        for tok in [&root, &leaf] {
            assert!(!tok.is_cancelled());
        }

        root.cancel();

        for tok in [&root, &leaf] {
            assert!(tok.is_cancelled());
        }
    }

    /// Cancelling the child leaves the parent untouched.
    #[test]
    fn test_child_token_independent_cancellation() {
        let root = CancellationToken::new();
        let leaf = root.child_token();

        leaf.cancel();

        assert!(!root.is_cancelled());
        assert!(leaf.is_cancelled());
    }

    /// `throw_if_cancelled` is `Ok` before cancellation and `Err` afterwards.
    #[test]
    fn test_throw_if_cancelled() {
        let tok = CancellationToken::new();
        assert!(tok.throw_if_cancelled().is_ok());

        tok.cancel();
        assert!(tok.throw_if_cancelled().is_err());
    }

    /// A timeout token flips to cancelled once its duration has passed.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_timeout_cancellation() {
        let deadline = Duration::from_millis(10);
        let tok = CancellationToken::with_timeout(deadline);
        assert!(!tok.is_cancelled());

        // Sleep past the deadline before checking again.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(tok.is_cancelled());
    }

    /// The `cancelled` future resolves after another task cancels the token.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_cancelled_future() {
        let tok = CancellationToken::new();
        let remote = tok.clone();

        // Cancel from a background task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            remote.cancel();
        });

        tok.cancelled().await;
        assert!(tok.is_cancelled());
    }

    /// A scope created via the extension trait gets a token tied to its parent's.
    #[test]
    fn test_scope_cancellation_ext() {
        use crate::{ServiceCollection, traits::Resolver};

        let mut services = ServiceCollection::new();
        services.add_scoped_factory::<CancellationToken, _>(|_| CancellationToken::new());
        let provider = services.build();

        let outer = provider.create_scope();
        let inner = outer.with_cancellation_from_parent();

        let outer_token = outer.get_required::<CancellationToken>();
        let inner_token = inner.get_required::<CancellationToken>();

        // Nothing has been cancelled yet.
        assert!(!outer_token.is_cancelled());
        assert!(!inner_token.is_cancelled());

        // Cancelling the outer token must be visible through the inner one.
        outer_token.cancel();
        assert!(outer_token.is_cancelled());
        assert!(inner_token.is_cancelled());
    }
}