dynamo_runtime/utils/tasks/tracker.rs
1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! # Task Tracker - Hierarchical Task Management System
5//!
6//! A composable task management system with configurable scheduling and error handling policies.
7//! The TaskTracker enables controlled concurrent execution with proper resource management,
8//! cancellation semantics, and retry support.
9//!
10//! ## Architecture Overview
11//!
12//! The TaskTracker system is built around three core abstractions that compose together:
13//!
14//! ### 1. **TaskScheduler** - Resource Management
15//!
16//! Controls when and how tasks acquire execution resources (permits, slots, etc.).
17//! Schedulers implement resource acquisition with cancellation support:
18//!
19//! ```text
20//! TaskScheduler::acquire_execution_slot(cancel_token) -> SchedulingResult<ResourceGuard>
21//! ```
22//!
23//! - **Resource Acquisition**: Can be cancelled to avoid unnecessary allocation
24//! - **RAII Guards**: Resources are automatically released when guards are dropped
25//! - **Pluggable**: Different scheduling policies (unlimited, semaphore, rate-limited, etc.)
26//!
27//! ### 2. **OnErrorPolicy** - Error Handling
28//!
29//! Defines how the system responds to task failures:
30//!
31//! ```text
32//! OnErrorPolicy::on_error(error, task_id) -> ErrorResponse
33//! ```
34//!
35//! - **ErrorResponse::Fail**: Log error, fail this task
36//! - **ErrorResponse::Shutdown**: Shutdown tracker and all children
37//! - **ErrorResponse::Custom(action)**: Execute custom logic that can return:
38//! - `ActionResult::Fail`: Handle error and fail the task
39//! - `ActionResult::Shutdown`: Shutdown tracker
40//! - `ActionResult::Continue { continuation }`: Continue with provided task
41//!
42//! ### 3. **Execution Pipeline** - Task Orchestration
43//!
44//! The execution pipeline coordinates scheduling, execution, and error handling:
45//!
46//! ```text
47//! 1. Acquire resources (scheduler.acquire_execution_slot)
48//! 2. Create task future (only after resources acquired)
49//! 3. Execute task while holding guard (RAII pattern)
50//! 4. Handle errors through policy (with retry support for cancellable tasks)
51//! 5. Update metrics and release resources
52//! ```
53//!
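//! For example, with a one-permit `SemaphoreScheduler` the second task is only admitted
//! once the first task finishes and its guard is released (a minimal sketch reusing the
//! public API shown in the examples below):
//!
//! ```rust
//! # use dynamo_runtime::utils::tasks::tracker::*;
//! # #[tokio::main]
//! # async fn main() -> anyhow::Result<()> {
//! let tracker = TaskTracker::new(SemaphoreScheduler::with_permits(1), LogOnlyPolicy::new())?;
//! let first = tracker.spawn(async { Ok(1) });
//! let second = tracker.spawn(async { Ok(2) }); // queued until the first guard is dropped
//! assert_eq!(first.await??, 1);
//! assert_eq!(second.await??, 2);
//! # Ok(())
//! # }
//! ```
//!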
54//! ## Key Design Principles
55//!
56//! ### **Separation of Concerns**
57//! - **Scheduling**: When/how to allocate resources
58//! - **Execution**: Running tasks with proper resource management
59//! - **Error Handling**: Responding to failures with configurable policies
60//!
61//! ### **Composability**
62//! - Schedulers and error policies are independent and can be mixed/matched
63//! - Custom policies can be implemented via traits
64//! - Execution pipeline handles the coordination automatically
65//!
66//! ### **Resource Safety**
67//! - Resources are acquired before task creation (prevents early execution)
68//! - RAII pattern ensures resources are always released
69//! - Cancellation is supported during resource acquisition, not during execution
70//!
71//! ### **Retry Support**
72//! - Regular tasks (`spawn`): Cannot be retried (future is consumed)
73//! - Cancellable tasks (`spawn_cancellable`): Support retry via `FnMut` closures
//! - Error policies can provide continuations via `ActionResult::Continue`
75//!
76//! ## Task Types
77//!
78//! ### Regular Tasks
79//! ```rust
80//! # use dynamo_runtime::utils::tasks::tracker::*;
81//! # #[tokio::main]
82//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
83//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
84//! let handle = tracker.spawn(async { Ok(42) });
85//! # let _result = handle.await?;
86//! # Ok(())
87//! # }
88//! ```
89//! - Simple futures that run to completion
90//! - Cannot be retried (future is consumed on first execution)
91//! - Suitable for one-shot operations
92//!
93//! ### Cancellable Tasks
94//! ```rust
95//! # use dynamo_runtime::utils::tasks::tracker::*;
96//! # #[tokio::main]
97//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
98//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
99//! let handle = tracker.spawn_cancellable(|cancel_token| async move {
100//! // Task can check cancel_token.is_cancelled() or use tokio::select!
101//! CancellableTaskResult::Ok(42)
102//! });
103//! # let _result = handle.await?;
104//! # Ok(())
105//! # }
106//! ```
107//! - Receive a `CancellationToken` for cooperative cancellation
108//! - Support retry via `FnMut` closures (can be called multiple times)
109//! - Return `CancellableTaskResult` to indicate success/cancellation/error
110//!
111//! ## Hierarchical Structure
112//!
113//! TaskTrackers form parent-child relationships:
114//! - **Metrics**: Child metrics aggregate to parents
115//! - **Cancellation**: Parent cancellation propagates to children
116//! - **Independence**: Child cancellation doesn't affect parents
117//! - **Cleanup**: `join()` waits for all descendants bottom-up
118//!
119//! ## Metrics and Observability
120//!
121//! Built-in metrics track task lifecycle:
122//! - `issued`: Tasks submitted via spawn methods
123//! - `active`: Currently executing tasks
124//! - `success/failed/cancelled/rejected`: Final outcomes
125//! - `pending`: Issued but not completed (issued - completed)
126//! - `queued`: Waiting for resources (pending - active)
127//!
128//! Optional Prometheus integration available via `PrometheusTaskMetrics`.
129//!
130//! ## Usage Examples
131//!
132//! ### Basic Task Execution
133//! ```rust
134//! use dynamo_runtime::utils::tasks::tracker::*;
135//! use std::sync::Arc;
136//!
137//! # #[tokio::main]
138//! # async fn main() -> anyhow::Result<()> {
139//! let scheduler = SemaphoreScheduler::with_permits(10);
140//! let error_policy = LogOnlyPolicy::new();
141//! let tracker = TaskTracker::new(scheduler, error_policy)?;
142//!
143//! let handle = tracker.spawn(async { Ok(42) });
144//! let result = handle.await??;
145//! assert_eq!(result, 42);
146//! # Ok(())
147//! # }
148//! ```
149//!
150//! ### Cancellable Tasks with Retry
151//! ```rust
152//! # use dynamo_runtime::utils::tasks::tracker::*;
153//! # use std::sync::Arc;
154//! # #[tokio::main]
155//! # async fn main() -> anyhow::Result<()> {
156//! # let scheduler = SemaphoreScheduler::with_permits(10);
157//! # let error_policy = LogOnlyPolicy::new();
158//! # let tracker = TaskTracker::new(scheduler, error_policy)?;
159//! let handle = tracker.spawn_cancellable(|cancel_token| async move {
160//! tokio::select! {
161//! result = do_work() => CancellableTaskResult::Ok(result),
162//! _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
163//! }
164//! });
165//! # Ok(())
166//! # }
167//! # async fn do_work() -> i32 { 42 }
168//! ```
169//!
170//! ### Task-Driven Retry with Continuations
171//! ```rust
172//! # use dynamo_runtime::utils::tasks::tracker::*;
173//! # use anyhow::anyhow;
174//! # #[tokio::main]
175//! # async fn main() -> anyhow::Result<()> {
176//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
177//! let handle = tracker.spawn(async {
178//! // Simulate initial failure with retry logic
179//! let error = FailedWithContinuation::from_fn(
180//! anyhow!("Network timeout"),
181//! || async {
182//! println!("Retrying with exponential backoff...");
183//! tokio::time::sleep(std::time::Duration::from_millis(100)).await;
184//! Ok("Success after retry".to_string())
185//! }
186//! );
187//! let result: Result<String, anyhow::Error> = Err(error);
188//! result
189//! });
190//!
191//! let result = handle.await?;
192//! assert!(result.is_ok());
193//! # Ok(())
194//! # }
195//! ```
196//!
197//! ### Custom Error Policy with Continuation
198//! ```rust
199//! # use dynamo_runtime::utils::tasks::tracker::*;
200//! # use std::sync::Arc;
201//! # use async_trait::async_trait;
202//! # #[derive(Debug)]
203//! struct RetryPolicy {
204//! max_attempts: u32,
205//! }
206//!
207//! impl OnErrorPolicy for RetryPolicy {
208//! fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
209//! Arc::new(RetryPolicy { max_attempts: self.max_attempts })
210//! }
211//!
212//! fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
213//! None // Stateless policy
214//! }
215//!
216//! fn on_error(&self, _error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
217//! if context.attempt_count < self.max_attempts {
218//! ErrorResponse::Custom(Box::new(RetryAction))
219//! } else {
220//! ErrorResponse::Fail
221//! }
222//! }
223//! }
224//!
225//! # #[derive(Debug)]
226//! struct RetryAction;
227//!
228//! #[async_trait]
229//! impl OnErrorAction for RetryAction {
230//! async fn execute(
231//! &self,
232//! _error: &anyhow::Error,
233//! _task_id: TaskId,
234//! _attempt_count: u32,
235//! _context: &TaskExecutionContext,
236//! ) -> ActionResult {
237//! // In practice, you would create a continuation here
238//! ActionResult::Fail
239//! }
240//! }
241//! ```
242//!
243//! ## Future Extensibility
244//!
245//! The system is designed for extensibility. See the source code for detailed TODO comments
246//! describing additional policies that can be implemented:
247//! - **Scheduling**: Token bucket rate limiting, adaptive concurrency, memory-aware scheduling
248//! - **Error Handling**: Retry with backoff, circuit breakers, dead letter queues
249//!
250//! Each TODO comment includes complete implementation guidance with data structures,
251//! algorithms, and dependencies needed for future contributors.
252//!
253//! ## Hierarchical Organization
254//!
255//! ```rust
256//! use dynamo_runtime::utils::tasks::tracker::{
257//! TaskTracker, UnlimitedScheduler, ThresholdCancelPolicy, SemaphoreScheduler
258//! };
259//!
260//! # async fn example() -> anyhow::Result<()> {
261//! // Create root tracker with failure threshold policy
262//! let error_policy = ThresholdCancelPolicy::with_threshold(5);
263//! let root = TaskTracker::builder()
264//! .scheduler(UnlimitedScheduler::new())
265//! .error_policy(error_policy)
266//! .build()?;
267//!
268//! // Create child trackers for different components
269//! let api_handler = root.child_tracker()?; // Inherits policies
270//! let background_jobs = root.child_tracker()?;
271//!
272//! // Children can have custom policies
273//! let rate_limited = root.child_tracker_builder()
274//! .scheduler(SemaphoreScheduler::with_permits(2)) // Custom concurrency limit
275//! .build()?;
276//!
277//! // Tasks run independently but metrics roll up
278//! api_handler.spawn(async { Ok(()) });
279//! background_jobs.spawn(async { Ok(()) });
280//! rate_limited.spawn(async { Ok(()) });
281//!
282//! // Join all children hierarchically
283//! root.join().await;
284//! assert_eq!(root.metrics().success(), 3); // Sees all successes
285//! # Ok(())
286//! # }
287//! ```
288//!
289//! ## Policy Examples
290//!
291//! ```rust
292//! use dynamo_runtime::utils::tasks::tracker::{
293//! TaskTracker, CancelOnError, SemaphoreScheduler, ThresholdCancelPolicy
294//! };
295//!
296//! # async fn example() -> anyhow::Result<()> {
297//! // Pattern-based error cancellation
298//! let (error_policy, token) = CancelOnError::with_patterns(
299//! vec!["OutOfMemory".to_string(), "DeviceError".to_string()]
300//! );
301//! let simple = TaskTracker::builder()
302//! .scheduler(SemaphoreScheduler::with_permits(5))
303//! .error_policy(error_policy)
304//! .build()?;
305//!
306//! // Threshold-based cancellation with monitoring
307//! let scheduler = SemaphoreScheduler::with_permits(10); // Returns Arc<SemaphoreScheduler>
308//! let error_policy = ThresholdCancelPolicy::with_threshold(3); // Returns Arc<Policy>
309//!
310//! let advanced = TaskTracker::builder()
311//! .scheduler(scheduler)
312//! .error_policy(error_policy)
313//! .build()?;
314//!
315//! // Monitor cancellation externally
316//! if token.is_cancelled() {
317//! println!("Tracker cancelled due to failures");
318//! }
319//! # Ok(())
320//! # }
321//! ```
322//!
323//! ## Metrics and Observability
324//!
325//! ```rust
326//! use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
327//!
328//! # async fn example() -> anyhow::Result<()> {
329//! let tracker = std::sync::Arc::new(TaskTracker::builder()
330//! .scheduler(SemaphoreScheduler::with_permits(2)) // Only 2 concurrent tasks
331//! .error_policy(LogOnlyPolicy::new())
332//! .build()?);
333//!
334//! // Spawn multiple tasks
335//! for i in 0..5 {
336//! tracker.spawn(async move {
337//! tokio::time::sleep(std::time::Duration::from_millis(100)).await;
338//! Ok(i)
339//! });
340//! }
341//!
342//! // Check metrics
343//! let metrics = tracker.metrics();
344//! println!("Issued: {}", metrics.issued()); // 5 tasks issued
345//! println!("Active: {}", metrics.active()); // 2 tasks running (semaphore limit)
346//! println!("Queued: {}", metrics.queued()); // 3 tasks waiting in scheduler queue
347//! println!("Pending: {}", metrics.pending()); // 5 tasks not yet completed
348//!
349//! tracker.join().await;
350//! assert_eq!(metrics.success(), 5);
351//! assert_eq!(metrics.pending(), 0);
352//! # Ok(())
353//! # }
354//! ```
355//!
356//! ## Prometheus Integration
357//!
358//! ```rust
359//! use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
360//! use dynamo_runtime::metrics::MetricsRegistry;
361//!
362//! # async fn example(registry: &dyn MetricsRegistry) -> anyhow::Result<()> {
363//! // Root tracker with Prometheus metrics
364//! let tracker = TaskTracker::new_with_prometheus(
365//! SemaphoreScheduler::with_permits(10),
366//! LogOnlyPolicy::new(),
367//! registry,
368//! "my_component"
369//! )?;
370//!
371//! // Metrics automatically exported to Prometheus:
372//! // - my_component_tasks_issued_total
373//! // - my_component_tasks_success_total
374//! // - my_component_tasks_failed_total
375//! // - my_component_tasks_active
376//! // - my_component_tasks_queued
377//! # Ok(())
378//! # }
379//! ```
380
381use std::future::Future;
382use std::pin::Pin;
383use std::sync::Arc;
384use std::sync::atomic::{AtomicU64, Ordering};
385
386use crate::metrics::MetricsRegistry;
387use crate::metrics::prometheus_names::task_tracker;
388use anyhow::Result;
389use async_trait::async_trait;
390use derive_builder::Builder;
391use std::collections::HashSet;
392use std::sync::{Mutex, RwLock, Weak};
393use std::time::Duration;
394use thiserror::Error;
395use tokio::sync::Semaphore;
396use tokio::task::JoinHandle;
397use tokio_util::sync::CancellationToken;
398use tokio_util::task::TaskTracker as TokioTaskTracker;
399use tracing::{Instrument, debug, error, warn};
400use uuid::Uuid;
401
402/// Error type for task execution results
403///
404/// This enum distinguishes between task cancellation and actual failures,
405/// enabling proper metrics tracking and error handling.
406#[derive(Error, Debug)]
407pub enum TaskError {
408 /// Task was cancelled (either via cancellation token or tracker shutdown)
409 #[error("Task was cancelled")]
410 Cancelled,
411
412 /// Task failed with an error
413 #[error(transparent)]
414 Failed(#[from] anyhow::Error),
415
416 /// Cannot spawn task on a closed tracker
417 #[error("Cannot spawn task on a closed tracker")]
418 TrackerClosed,
419}
420
421impl TaskError {
422 /// Check if this error represents a cancellation
423 ///
424 /// This is a convenience method for compatibility and readability.
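    ///
    /// A small illustration using the variants defined above:
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskError;
    /// let err = TaskError::Cancelled;
    /// assert!(err.is_cancellation());
    /// assert!(!err.is_failure());
    /// ```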
425 pub fn is_cancellation(&self) -> bool {
426 matches!(self, TaskError::Cancelled)
427 }
428
429 /// Check if this error represents a failure
430 pub fn is_failure(&self) -> bool {
431 matches!(self, TaskError::Failed(_))
432 }
433
    /// Convert into the underlying `anyhow::Error` for failures, or a synthesized error
    /// for the `Cancelled` and `TrackerClosed` variants
435 ///
436 /// This is provided for compatibility with existing code that expects anyhow::Error.
437 pub fn into_anyhow(self) -> anyhow::Error {
438 match self {
439 TaskError::Failed(err) => err,
440 TaskError::Cancelled => anyhow::anyhow!("Task was cancelled"),
441 TaskError::TrackerClosed => anyhow::anyhow!("Cannot spawn task on a closed tracker"),
442 }
443 }
444}
445
446/// A handle to a spawned task that provides both join functionality and cancellation control
447///
448/// `TaskHandle` wraps a `JoinHandle` and provides access to the task's individual cancellation token.
449/// This allows fine-grained control over individual tasks while maintaining the familiar `JoinHandle` API.
450///
451/// # Example
452/// ```rust
453/// # use dynamo_runtime::utils::tasks::tracker::*;
454/// # #[tokio::main]
455/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
456/// # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
457/// let handle = tracker.spawn(async {
458/// tokio::time::sleep(std::time::Duration::from_millis(100)).await;
459/// Ok(42)
460/// });
461///
462/// // Access the task's cancellation token
463/// let cancel_token = handle.cancellation_token();
464///
465/// // Can cancel the specific task
466/// // cancel_token.cancel();
467///
468/// // Await the task like a normal JoinHandle
469/// let result = handle.await?;
470/// assert_eq!(result?, 42);
471/// # Ok(())
472/// # }
473/// ```
474pub struct TaskHandle<T> {
475 join_handle: JoinHandle<Result<T, TaskError>>,
476 cancel_token: CancellationToken,
477}
478
479impl<T> TaskHandle<T> {
480 /// Create a new TaskHandle wrapping a JoinHandle and cancellation token
481 pub(crate) fn new(
482 join_handle: JoinHandle<Result<T, TaskError>>,
483 cancel_token: CancellationToken,
484 ) -> Self {
485 Self {
486 join_handle,
487 cancel_token,
488 }
489 }
490
491 /// Get the cancellation token for this specific task
492 ///
493 /// This token is a child of the tracker's cancellation token and can be used
494 /// to cancel just this individual task without affecting other tasks.
495 ///
496 /// # Example
497 /// ```rust
498 /// # use dynamo_runtime::utils::tasks::tracker::*;
499 /// # #[tokio::main]
500 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
501 /// # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
502 /// let handle = tracker.spawn(async {
503 /// tokio::time::sleep(std::time::Duration::from_secs(10)).await;
504 /// Ok("completed")
505 /// });
506 ///
507 /// // Cancel this specific task
508 /// handle.cancellation_token().cancel();
509 ///
510 /// // Task will be cancelled
511 /// let result = handle.await?;
512 /// assert!(result.is_err());
513 /// # Ok(())
514 /// # }
515 /// ```
516 pub fn cancellation_token(&self) -> &CancellationToken {
517 &self.cancel_token
518 }
519
520 /// Abort the task associated with this handle
521 ///
522 /// This is equivalent to calling `JoinHandle::abort()` and will cause the task
523 /// to be cancelled immediately without running any cleanup code.
524 pub fn abort(&self) {
525 self.join_handle.abort();
526 }
527
528 /// Check if the task associated with this handle has finished
529 ///
530 /// This is equivalent to calling `JoinHandle::is_finished()`.
531 pub fn is_finished(&self) -> bool {
532 self.join_handle.is_finished()
533 }
534}
535
536impl<T> std::future::Future for TaskHandle<T> {
537 type Output = Result<Result<T, TaskError>, tokio::task::JoinError>;
538
539 fn poll(
540 mut self: std::pin::Pin<&mut Self>,
541 cx: &mut std::task::Context<'_>,
542 ) -> std::task::Poll<Self::Output> {
543 std::pin::Pin::new(&mut self.join_handle).poll(cx)
544 }
545}
546
547impl<T> std::fmt::Debug for TaskHandle<T> {
548 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549 f.debug_struct("TaskHandle")
550 .field("join_handle", &"<JoinHandle>")
551 .field("cancel_token", &self.cancel_token)
552 .finish()
553 }
554}
555
556/// Trait for continuation tasks that execute after a failure
557///
558/// This trait allows tasks to define what should happen next after a failure,
559/// eliminating the need for complex type erasure and executor management.
560/// Tasks implement this trait to provide clean continuation logic.
561#[async_trait]
562pub trait Continuation: Send + Sync + std::fmt::Debug + std::any::Any {
563 /// Execute the continuation task after a failure
564 ///
565 /// This method is called when a task fails and a continuation is provided.
566 /// The implementation can perform retry logic, fallback operations,
567 /// transformations, or any other follow-up action.
568 /// Returns the result in a type-erased Box<dyn Any> for flexibility.
569 async fn execute(
570 &self,
571 cancel_token: CancellationToken,
572 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>;
573}
574
575/// Error type that signals a task failed but provided a continuation
576///
577/// This error type contains a continuation task that can be executed as a follow-up.
578/// The task defines its own continuation logic through the Continuation trait.
579#[derive(Error, Debug)]
580#[error("Task failed with continuation: {source}")]
581pub struct FailedWithContinuation {
582 /// The underlying error that caused the task to fail
583 #[source]
584 pub source: anyhow::Error,
585 /// The continuation task for follow-up execution
586 pub continuation: Arc<dyn Continuation + Send + Sync + 'static>,
587}
588
589impl FailedWithContinuation {
590 /// Create a new FailedWithContinuation with a continuation task
591 ///
592 /// The continuation task defines its own execution logic through the Continuation trait.
593 pub fn new(
594 source: anyhow::Error,
595 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
596 ) -> Self {
597 Self {
598 source,
599 continuation,
600 }
601 }
602
603 /// Create a FailedWithContinuation and convert it to anyhow::Error
604 ///
605 /// This is a convenience method for tasks to easily return continuation errors.
606 pub fn into_anyhow(
607 source: anyhow::Error,
608 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
609 ) -> anyhow::Error {
610 anyhow::Error::new(Self::new(source, continuation))
611 }
612
613 /// Create a FailedWithContinuation from a simple async function (no cancellation support)
614 ///
615 /// This is a convenience method for creating continuation errors from simple async closures
616 /// that don't need to handle cancellation. The function will be executed when the
617 /// continuation is triggered.
618 ///
619 /// # Example
620 /// ```rust
621 /// # use dynamo_runtime::utils::tasks::tracker::*;
622 /// # use anyhow::anyhow;
623 /// # #[tokio::main]
624 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
625 /// let error = FailedWithContinuation::from_fn(
626 /// anyhow!("Initial task failed"),
627 /// || async {
628 /// println!("Retrying operation...");
629 /// Ok("retry_result".to_string())
630 /// }
631 /// );
632 /// # Ok(())
633 /// # }
634 /// ```
635 pub fn from_fn<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
636 where
637 F: Fn() -> Fut + Send + Sync + 'static,
638 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
639 T: Send + 'static,
640 {
641 let continuation = Arc::new(FnContinuation { f: Box::new(f) });
642 Self::into_anyhow(source, continuation)
643 }
644
645 /// Create a FailedWithContinuation from a cancellable async function
646 ///
647 /// This is a convenience method for creating continuation errors from async closures
648 /// that can handle cancellation. The function receives a CancellationToken
649 /// and should check it periodically for cooperative cancellation.
650 ///
651 /// # Example
652 /// ```rust
653 /// # use dynamo_runtime::utils::tasks::tracker::*;
654 /// # use anyhow::anyhow;
655 /// # #[tokio::main]
656 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
657 /// let error = FailedWithContinuation::from_cancellable(
658 /// anyhow!("Initial task failed"),
659 /// |cancel_token| async move {
660 /// if cancel_token.is_cancelled() {
661 /// return Err(anyhow!("Cancelled"));
662 /// }
663 /// println!("Retrying operation with cancellation support...");
664 /// Ok("retry_result".to_string())
665 /// }
666 /// );
667 /// # Ok(())
668 /// # }
669 /// ```
670 pub fn from_cancellable<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
671 where
672 F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
673 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
674 T: Send + 'static,
675 {
676 let continuation = Arc::new(CancellableFnContinuation { f: Box::new(f) });
677 Self::into_anyhow(source, continuation)
678 }
679}
680
681/// Extension trait for extracting FailedWithContinuation from anyhow::Error
682///
683/// This trait provides methods to detect and extract continuation tasks
684/// from the type-erased anyhow::Error system.
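///
/// A small sketch using the `FailedWithContinuation::from_fn` helper defined above:
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::*;
/// # use anyhow::anyhow;
/// let plain = anyhow!("no continuation attached");
/// assert!(!plain.has_continuation());
///
/// let with_continuation = FailedWithContinuation::from_fn(
///     anyhow!("transient failure"),
///     || async { Ok(42u32) },
/// );
/// assert!(with_continuation.has_continuation());
/// assert!(with_continuation.extract_continuation().is_some());
/// ```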
685pub trait FailedWithContinuationExt {
686 /// Extract a continuation task if this error contains one
687 ///
688 /// Returns the continuation task if the error is a FailedWithContinuation,
689 /// None otherwise.
690 fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>>;
691
692 /// Check if this error has a continuation
693 fn has_continuation(&self) -> bool;
694}
695
696impl FailedWithContinuationExt for anyhow::Error {
697 fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>> {
698 // Try to downcast to FailedWithContinuation
699 if let Some(continuation_err) = self.downcast_ref::<FailedWithContinuation>() {
700 Some(continuation_err.continuation.clone())
701 } else {
702 None
703 }
704 }
705
706 fn has_continuation(&self) -> bool {
707 self.downcast_ref::<FailedWithContinuation>().is_some()
708 }
709}
710
711/// Implementation of Continuation for simple async functions (no cancellation support)
712struct FnContinuation<F> {
713 f: Box<F>,
714}
715
716impl<F> std::fmt::Debug for FnContinuation<F> {
717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718 f.debug_struct("FnContinuation")
719 .field("f", &"<closure>")
720 .finish()
721 }
722}
723
724#[async_trait]
725impl<F, Fut, T> Continuation for FnContinuation<F>
726where
727 F: Fn() -> Fut + Send + Sync + 'static,
728 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
729 T: Send + 'static,
730{
731 async fn execute(
732 &self,
733 _cancel_token: CancellationToken,
734 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
735 match (self.f)().await {
736 Ok(result) => TaskExecutionResult::Success(Box::new(result)),
737 Err(error) => TaskExecutionResult::Error(error),
738 }
739 }
740}
741
742/// Implementation of Continuation for cancellable async functions
743struct CancellableFnContinuation<F> {
744 f: Box<F>,
745}
746
747impl<F> std::fmt::Debug for CancellableFnContinuation<F> {
748 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
749 f.debug_struct("CancellableFnContinuation")
750 .field("f", &"<closure>")
751 .finish()
752 }
753}
754
755#[async_trait]
756impl<F, Fut, T> Continuation for CancellableFnContinuation<F>
757where
758 F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
759 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
760 T: Send + 'static,
761{
762 async fn execute(
763 &self,
764 cancel_token: CancellationToken,
765 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
766 match (self.f)(cancel_token).await {
767 Ok(result) => TaskExecutionResult::Success(Box::new(result)),
768 Err(error) => TaskExecutionResult::Error(error),
769 }
770 }
771}
772
773/// Common scheduling policies for task execution
774///
775/// These enums provide convenient access to built-in scheduling policies
776/// without requiring manual construction of policy objects.
777///
778/// ## Cancellation Semantics
779///
780/// All schedulers follow the same cancellation behavior:
781/// - Respect cancellation tokens before resource allocation (permits, etc.)
782/// - Once task execution begins, always await completion
783/// - Let tasks handle their own cancellation internally
784#[derive(Debug, Clone)]
785pub enum SchedulingPolicy {
786 /// No concurrency limits - execute all tasks immediately
787 Unlimited,
788 /// Semaphore-based concurrency limiting
789 Semaphore(usize),
790 // TODO: Future scheduling policies to implement
791 //
792 // /// Token bucket rate limiting with burst capacity
793 // /// Implementation: Use tokio::time::interval for refill, AtomicU64 for tokens.
794 // /// acquire() decrements tokens, schedule() waits for refill if empty.
795 // /// Burst allows temporary spikes above steady rate.
796 // /// struct: { rate: f64, burst: usize, tokens: AtomicU64, last_refill: Mutex<Instant> }
797 // /// Example: TokenBucket { rate: 10.0, burst: 5 } = 10 tasks/sec, burst up to 5
798 // TokenBucket { rate: f64, burst: usize },
799 //
800 // /// Weighted fair scheduling across multiple priority classes
801 // /// Implementation: Maintain separate VecDeque for each priority class.
802 // /// Use weighted round-robin: serve N tasks from high, M from normal, etc.
803 // /// Track deficit counters to ensure fairness over time.
804 // /// struct: { queues: HashMap<String, VecDeque<Task>>, weights: Vec<(String, u32)> }
805 // /// Example: WeightedFair { weights: vec![("high", 70), ("normal", 20), ("low", 10)] }
806 // WeightedFair { weights: Vec<(String, u32)> },
807 //
808 // /// Memory-aware scheduling that limits tasks based on available memory
809 // /// Implementation: Monitor system memory via /proc/meminfo or sysinfo crate.
810 // /// Pause scheduling when available memory < threshold, resume when memory freed.
811 // /// Use exponential backoff for memory checks to avoid overhead.
812 // /// struct: { max_memory_mb: usize, check_interval: Duration, semaphore: Semaphore }
813 // MemoryAware { max_memory_mb: usize },
814 //
815 // /// CPU-aware scheduling that adjusts concurrency based on CPU load
816 // /// Implementation: Sample system load via sysinfo crate every N seconds.
817 // /// Dynamically resize internal semaphore permits based on load average.
818 // /// Use PID controller for smooth adjustments, avoid oscillation.
819 // /// struct: { max_cpu_percent: f32, permits: Arc<Semaphore>, sampler: tokio::task }
820 // CpuAware { max_cpu_percent: f32 },
821 //
822 // /// Adaptive scheduler that automatically adjusts concurrency based on performance
823 // /// Implementation: Track task latency and throughput in sliding windows.
824 // /// Increase permits if latency low & throughput stable, decrease if latency spikes.
825 // /// Use additive increase, multiplicative decrease (AIMD) algorithm.
826 // /// struct: { permits: AtomicUsize, latency_tracker: RingBuffer, throughput_tracker: RingBuffer }
827 // Adaptive { initial_permits: usize },
828 //
829 // /// Throttling scheduler that enforces minimum time between task starts
830 // /// Implementation: Store last_execution time in AtomicU64 (unix timestamp).
831 // /// Before scheduling, check elapsed time and tokio::time::sleep if needed.
832 // /// Useful for rate-limiting API calls to external services.
833 // /// struct: { min_interval: Duration, last_execution: AtomicU64 }
834 // Throttling { min_interval_ms: u64 },
835 //
836 // /// Batch scheduler that groups tasks and executes them together
837 // /// Implementation: Collect tasks in Vec<Task>, use tokio::time::timeout for max_wait.
838 // /// Execute batch when size reached OR timeout expires, whichever first.
839 // /// Use futures::future::join_all for parallel execution within batch.
840 // /// struct: { batch_size: usize, max_wait: Duration, pending: Mutex<Vec<Task>> }
841 // Batch { batch_size: usize, max_wait_ms: u64 },
842 //
843 // /// Priority-based scheduler with separate queues for different priority levels
844 // /// Implementation: Three separate semaphores for high/normal/low priorities.
845 // /// Always serve high before normal, normal before low (strict priority).
846 // /// Add starvation protection: promote normal->high after timeout.
847 // /// struct: { high_sem: Semaphore, normal_sem: Semaphore, low_sem: Semaphore }
848 // Priority { high: usize, normal: usize, low: usize },
849 //
850 // /// Backpressure-aware scheduler that monitors downstream capacity
851 // /// Implementation: Track external queue depth via provided callback/metric.
852 // /// Pause scheduling when queue_threshold exceeded, resume after pause_duration.
853 // /// Use exponential backoff for repeated backpressure events.
854 // /// struct: { queue_checker: Arc<dyn Fn() -> usize>, threshold: usize, pause_duration: Duration }
855 // Backpressure { queue_threshold: usize, pause_duration_ms: u64 },
856}
857
/// Per-task error handling context
///
/// Provides context information and state management for error policies.
/// The state field allows policies to maintain per-task state across multiple error attempts.
874pub struct OnErrorContext {
875 /// Number of times this task has been attempted (starts at 1)
876 pub attempt_count: u32,
877 /// Unique identifier of the failed task
878 pub task_id: TaskId,
879 /// Full execution context with access to scheduler, metrics, etc.
880 pub execution_context: TaskExecutionContext,
881 /// Optional per-task state managed by the policy (None for stateless policies)
882 pub state: Option<Box<dyn std::any::Any + Send + 'static>>,
883}
884
/// Error handling policy trait for task failures
///
/// Error policies are lightweight, synchronous decision-makers that analyze task failures
/// and return an `ErrorResponse` telling the TaskTracker what action to take. The TaskTracker
/// handles all the actual work (cancellation, metrics, etc.) based on the policy's response.
/// Policies can be stateless (like `LogOnlyPolicy`) or maintain per-task state
/// (like `ThresholdCancelPolicy` with per-task failure counters).
///
/// ## Key Design Principles
/// - **Synchronous**: Policies make fast decisions without async operations
/// - **Stateless where possible**: TaskTracker manages cancellation tokens and state
/// - **Composable**: Policies can be combined and nested in hierarchies
/// - **Focused**: Each policy handles one specific error pattern or strategy
890pub trait OnErrorPolicy: Send + Sync + std::fmt::Debug {
891 /// Create a child policy for a child tracker
892 ///
893 /// This allows policies to maintain hierarchical relationships,
894 /// such as child cancellation tokens or shared circuit breaker state.
895 fn create_child(&self) -> Arc<dyn OnErrorPolicy>;
896
897 /// Create per-task context state (None if policy is stateless)
898 ///
899 /// This method is called once per task when the first error occurs.
900 /// Stateless policies should return None to avoid unnecessary heap allocations.
901 /// Stateful policies should return Some(Box::new(initial_state)).
902 ///
903 /// # Returns
904 /// * `None` - Policy doesn't need per-task state (no heap allocation)
905 /// * `Some(state)` - Initial state for this task (heap allocated when needed)
906 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>>;
907
908 /// Handle a task failure and return the desired response
909 ///
910 /// # Arguments
911 /// * `error` - The error that occurred
912 /// * `context` - Mutable context with attempt count, task info, and optional state
913 ///
914 /// # Returns
915 /// ErrorResponse indicating how the TaskTracker should handle this failure
916 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse;
917
918 /// Should continuations be allowed for this error?
919 ///
920 /// This method is called before checking if a task provided a continuation to determine
921 /// whether the policy allows continuation-based retries at all. If this returns `false`,
922 /// any `FailedWithContinuation` will be ignored and the error will be handled through
923 /// the normal policy response.
924 ///
925 /// # Arguments
926 /// * `error` - The error that occurred
927 /// * `context` - Per-task context with attempt count and state
928 ///
929 /// # Returns
930 /// * `true` - Allow continuations, check for `FailedWithContinuation` (default)
931 /// * `false` - Reject continuations, handle through normal policy response
932 fn allow_continuation(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
933 true // Default: allow continuations
934 }
935
936 /// Should this continuation be rescheduled through the scheduler?
937 ///
938 /// This method is called when a continuation is about to be executed to determine
939 /// whether it should go through the scheduler's acquisition process again or execute
940 /// immediately with the current execution permission.
941 ///
942 /// **What this means:**
943 /// - **Don't reschedule (`false`)**: Execute continuation immediately with current permission
944 /// - **Reschedule (`true`)**: Release current permission, go through scheduler again
945 ///
946 /// Rescheduling means the continuation will be subject to the scheduler's policies
947 /// again (rate limiting, concurrency limits, backoff delays, etc.).
948 ///
949 /// # Arguments
950 /// * `error` - The error that triggered this retry decision
951 /// * `context` - Per-task context with attempt count and state
952 ///
953 /// # Returns
954 /// * `false` - Execute continuation immediately (default, efficient)
955 /// * `true` - Reschedule through scheduler (for delays, rate limiting, backoff)
956 fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
957 false // Default: immediate execution
958 }
959}
960
961/// Common error handling policies for task failure management
962///
963/// These enums provide convenient access to built-in error handling policies
964/// without requiring manual construction of policy objects.
965#[derive(Debug, Clone)]
966pub enum ErrorPolicy {
967 /// Log errors but continue execution - no cancellation
968 LogOnly,
969 /// Cancel all tasks on any error (using default error patterns)
970 CancelOnError,
971 /// Cancel all tasks when specific error patterns are encountered
972 CancelOnPatterns(Vec<String>),
973 /// Cancel after a threshold number of failures
974 CancelOnThreshold { max_failures: usize },
975 /// Cancel when failure rate exceeds threshold within time window
976 CancelOnRate {
977 max_failure_rate: f32,
978 window_secs: u64,
979 },
980 // TODO: Future error policies to implement
981 //
982 // /// Retry failed tasks with exponential backoff
983 // /// Implementation: Store original task in retry queue with attempt count.
984 // /// Use tokio::time::sleep for delays: backoff_ms * 2^attempt.
985 // /// Spawn retry as new task, preserve original task_id for tracing.
986 // /// Need task cloning support in scheduler interface.
987 // /// struct: { max_attempts: usize, backoff_ms: u64, retry_queue: VecDeque<(Task, u32)> }
988 // Retry { max_attempts: usize, backoff_ms: u64 },
989 //
990 // /// Send failed tasks to a dead letter queue for later processing
991 // /// Implementation: Use tokio::sync::mpsc::channel for queue.
992 // /// Serialize task info (id, error, payload) for persistence.
993 // /// Background worker drains queue to external storage (Redis/DB).
994 // /// Include retry count and timestamps for debugging.
995 // /// struct: { queue: mpsc::Sender<DeadLetterItem>, storage: Arc<dyn DeadLetterStorage> }
996 // DeadLetter { queue_name: String },
997 //
998 // /// Execute fallback logic when tasks fail
999 // /// Implementation: Store fallback closure in Arc for thread-safety.
1000 // /// Execute fallback in same context as failed task (inherit cancel token).
1001 // /// Track fallback success/failure separately from original task metrics.
1002 // /// Consider using enum for common fallback patterns (default value, noop, etc).
1003 // /// struct: { fallback_fn: Arc<dyn Fn(TaskId, Error) -> BoxFuture<Result<()>>> }
1004 // Fallback { fallback_fn: Arc<dyn Fn() -> BoxFuture<'static, Result<()>>> },
1005 //
1006 // /// Circuit breaker pattern - stop executing after threshold failures
1007 // /// Implementation: Track state (Closed/Open/HalfOpen) with AtomicU8.
1008 // /// Use failure window (last N tasks) or time window for threshold.
1009 // /// In Open state, reject tasks immediately, use timer for recovery.
1010 // /// In HalfOpen, allow one test task to check if issues resolved.
1011 // /// struct: { state: AtomicU8, failure_count: AtomicU64, last_failure: AtomicU64 }
1012 // CircuitBreaker { failure_threshold: usize, timeout_secs: u64 },
1013 //
1014 // /// Resource protection policy that monitors memory/CPU usage
1015 // /// Implementation: Background task samples system resources via sysinfo.
1016 // /// Cancel tracker when memory > threshold, use process-level monitoring.
1017 // /// Implement graceful degradation: warn at 80%, cancel at 90%.
1018 // /// Include both system-wide and process-specific thresholds.
1019 // /// struct: { monitor_task: JoinHandle, thresholds: ResourceThresholds, cancel_token: CancellationToken }
1020 // ResourceProtection { max_memory_mb: usize },
1021 //
1022 // /// Timeout policy that cancels tasks exceeding maximum duration
1023 // /// Implementation: Wrap each task with tokio::time::timeout.
1024 // /// Store task start time, check duration in on_error callback.
1025 // /// Distinguish timeout errors from other task failures in metrics.
1026 // /// Consider per-task or global timeout strategies.
1027 // /// struct: { max_duration: Duration, timeout_tracker: HashMap<TaskId, Instant> }
1028 // Timeout { max_duration_secs: u64 },
1029 //
1030 // /// Sampling policy that only logs a percentage of errors
1031 // /// Implementation: Use thread-local RNG for sampling decisions.
1032 // /// Hash task_id for deterministic sampling (same task always sampled).
1033 // /// Store sample rate as f32, compare with rand::random::<f32>().
1034 // /// Include rate in log messages for context.
1035 // /// struct: { sample_rate: f32, rng: ThreadLocal<RefCell<SmallRng>> }
1036 // Sampling { sample_rate: f32 },
1037 //
1038 // /// Aggregating policy that batches error reports
1039 // /// Implementation: Collect errors in Vec, flush on size or time trigger.
1040 // /// Use tokio::time::interval for periodic flushing.
1041 // /// Group errors by type/pattern for better insights.
1042 // /// Include error frequency and rate statistics in reports.
1043 // /// struct: { window: Duration, batch: Mutex<Vec<ErrorEntry>>, flush_task: JoinHandle }
1044 // Aggregating { window_secs: u64, max_batch_size: usize },
1045 //
1046 // /// Alerting policy that sends notifications on error patterns
1047 // /// Implementation: Use reqwest for webhook HTTP calls.
1048 // /// Rate-limit alerts to prevent spam (max N per minute).
1049 // /// Include error context, task info, and system metrics in payload.
1050 // /// Support multiple notification channels (webhook, email, slack).
1051 // /// struct: { client: reqwest::Client, rate_limiter: RateLimiter, alert_config: AlertConfig }
1052 // Alerting { webhook_url: String, severity_threshold: String },
1053}
1054
1055/// Response type for error handling policies
1056///
1057/// This enum defines how the TaskTracker should respond to task failures.
1058/// Currently provides minimal functionality with planned extensions for common patterns.
1059#[derive(Debug)]
1060pub enum ErrorResponse {
1061 /// Just fail this task - error will be logged/counted, but tracker continues
1062 Fail,
1063
1064 /// Shutdown this tracker and all child trackers
1065 Shutdown,
1066
1067 /// Execute custom error handling logic with full context access
1068 Custom(Box<dyn OnErrorAction>),
1069 // TODO: Future specialized error responses to implement:
1070 //
1071 // /// Retry the failed task with configurable strategy
1072 // /// Implementation: Add RetryStrategy trait with delay(), should_continue(attempt_count),
1073 // /// release_and_reacquire_resources() methods. TaskTracker handles retry loop with
1074 // /// attempt counting and resource management. Supports exponential backoff, jitter.
1075 // /// Usage: ErrorResponse::Retry(Box::new(ExponentialBackoff { max_attempts: 3, base_delay: 100ms }))
1076 // Retry(Box<dyn RetryStrategy>),
1077 //
1078 // /// Execute fallback logic, then follow secondary action
1079 // /// Implementation: Add FallbackAction trait with execute(error, task_id) -> Result<(), Error>.
1080 // /// Execute fallback first, then recursively handle the 'then' response based on fallback result.
1081 // /// Enables patterns like: try fallback, if it works continue, if it fails retry original task.
1082 // /// Usage: ErrorResponse::Fallback { fallback: Box::new(DefaultValue(42)), then: Box::new(ErrorResponse::Continue) }
1083 // Fallback { fallback: Box<dyn FallbackAction>, then: Box<ErrorResponse> },
1084 //
1085 // /// Restart task with preserved state (for long-running/stateful tasks)
1086 // /// Implementation: Add TaskState trait for serialize/deserialize state, RestartStrategy trait
1087 // /// with create_continuation_task(state) -> Future. Task saves checkpoints during execution,
1088 // /// on error returns StatefulTaskError containing preserved state. Policy can restart from checkpoint.
1089 // /// Usage: ErrorResponse::RestartWithState { state: checkpointed_state, strategy: Box::new(CheckpointRestart { ... }) }
1090 // RestartWithState { state: Box<dyn TaskState>, strategy: Box<dyn RestartStrategy> },
1091}
1092
1093/// Trait for implementing custom error handling actions
1094///
1095/// This provides full access to the task execution context for complex error handling
1096/// scenarios that don't fit into the built-in response patterns.
1097#[async_trait]
1098pub trait OnErrorAction: Send + Sync + std::fmt::Debug {
1099 /// Execute custom error handling logic
1100 ///
1101 /// # Arguments
1102 /// * `error` - The error that caused the task to fail
1103 /// * `task_id` - Unique identifier of the failed task
1104 /// * `attempt_count` - Number of times this task has been attempted (starts at 1)
1105 /// * `context` - Full execution context with access to scheduler, metrics, etc.
1106 ///
1107 /// # Returns
1108 /// ActionResult indicating what the TaskTracker should do next
1109 async fn execute(
1110 &self,
1111 error: &anyhow::Error,
1112 task_id: TaskId,
1113 attempt_count: u32,
1114 context: &TaskExecutionContext,
1115 ) -> ActionResult;
1116}
1117
1118/// Scheduler execution guard state for conditional re-acquisition during task retry loops
1119///
1120/// This controls whether a continuation should reuse the current scheduler execution permission
1121/// or go through the scheduler's acquisition process again.
1122#[derive(Debug, Clone, PartialEq, Eq)]
1123enum GuardState {
1124 /// Keep the current scheduler execution permission for immediate continuation
1125 ///
1126 /// The continuation will execute immediately without going through the scheduler again.
1127 /// This is efficient for simple retries that don't need delays or rate limiting.
1128 Keep,
1129
1130 /// Release current permission and re-acquire through scheduler before continuation
1131 ///
1132 /// The continuation will be subject to the scheduler's policies again (concurrency limits,
1133 /// rate limiting, backoff delays, etc.). Use this for implementing retry delays or
1134 /// when the scheduler needs to apply its policies to the retry attempt.
1135 Reschedule,
1136}
1137
1138/// Result of a custom error action execution
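///
/// A hedged sketch of building the `Continue` variant by reusing the continuation
/// helpers defined in this module (the tracker normally drives the retry itself):
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::*;
/// # use anyhow::anyhow;
/// let err = FailedWithContinuation::from_fn(anyhow!("transient"), || async { Ok(7u32) });
/// let continuation = err.extract_continuation().expect("constructed with a continuation");
/// let next_step = ActionResult::Continue { continuation };
/// # let _ = next_step;
/// ```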
1139#[derive(Debug)]
1140pub enum ActionResult {
1141 /// Just fail this task (error was logged/handled by policy)
1142 ///
1143 /// This means the policy has handled the error appropriately (e.g., logged it,
1144 /// updated metrics, etc.) and the task should fail with this error.
1145 /// The task execution terminates here.
1146 Fail,
1147
1148 /// Continue execution with the provided task
1149 ///
1150 /// This provides a new executable to continue the retry loop with.
1151 /// The task execution continues with the provided continuation.
1152 Continue {
1153 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
1154 },
1155
1156 /// Shutdown this tracker and all child trackers
1157 ///
1158 /// This triggers shutdown of the entire tracker hierarchy.
1159 /// All running and pending tasks will be cancelled.
1160 Shutdown,
1161}
1162
1163/// Execution context provided to custom error actions
1164///
1165/// This gives custom actions full access to the task execution environment
1166/// for implementing complex error handling scenarios.
1167pub struct TaskExecutionContext {
1168 /// Scheduler for reacquiring resources or checking state
1169 pub scheduler: Arc<dyn TaskScheduler>,
1170
1171 /// Metrics for custom tracking
1172 pub metrics: Arc<dyn HierarchicalTaskMetrics>,
1173 // TODO: Future context additions:
1174 // pub guard: Box<dyn ResourceGuard>, // Current resource guard (needs Debug impl)
1175 // pub cancel_token: CancellationToken, // For implementing custom cancellation
1176 // pub task_recreation: Box<dyn TaskRecreator>, // For implementing retry/restart
1177}
1178
1179/// Result of task execution - unified for both regular and cancellable tasks
1180#[derive(Debug)]
1181pub enum TaskExecutionResult<T> {
1182 /// Task completed successfully
1183 Success(T),
1184 /// Task was cancelled (only possible for cancellable tasks)
1185 Cancelled,
1186 /// Task failed with an error
1187 Error(anyhow::Error),
1188}
1189
1190/// Trait for executing different types of tasks in a unified way
1191#[async_trait]
1192trait TaskExecutor<T>: Send {
1193 /// Execute the task with the given cancellation token
1194 async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T>;
1195}
1196
1197/// Task executor for regular (non-cancellable) tasks
1198struct RegularTaskExecutor<F, T>
1199where
1200 F: Future<Output = Result<T>> + Send + 'static,
1201 T: Send + 'static,
1202{
1203 future: Option<F>,
1204 _phantom: std::marker::PhantomData<T>,
1205}
1206
1207impl<F, T> RegularTaskExecutor<F, T>
1208where
1209 F: Future<Output = Result<T>> + Send + 'static,
1210 T: Send + 'static,
1211{
1212 fn new(future: F) -> Self {
1213 Self {
1214 future: Some(future),
1215 _phantom: std::marker::PhantomData,
1216 }
1217 }
1218}
1219
1220#[async_trait]
1221impl<F, T> TaskExecutor<T> for RegularTaskExecutor<F, T>
1222where
1223 F: Future<Output = Result<T>> + Send + 'static,
1224 T: Send + 'static,
1225{
1226 async fn execute(&mut self, _cancel_token: CancellationToken) -> TaskExecutionResult<T> {
1227 if let Some(future) = self.future.take() {
1228 match future.await {
1229 Ok(value) => TaskExecutionResult::Success(value),
1230 Err(error) => TaskExecutionResult::Error(error),
1231 }
1232 } else {
1233 // This should never happen since regular tasks don't support retry
1234 TaskExecutionResult::Error(anyhow::anyhow!("Regular task already consumed"))
1235 }
1236 }
1237}
1238
1239/// Task executor for cancellable tasks
1240struct CancellableTaskExecutor<F, Fut, T>
1241where
1242 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1243 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1244 T: Send + 'static,
1245{
1246 task_fn: F,
1247}
1248
1249impl<F, Fut, T> CancellableTaskExecutor<F, Fut, T>
1250where
1251 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1252 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1253 T: Send + 'static,
1254{
1255 fn new(task_fn: F) -> Self {
1256 Self { task_fn }
1257 }
1258}
1259
1260#[async_trait]
1261impl<F, Fut, T> TaskExecutor<T> for CancellableTaskExecutor<F, Fut, T>
1262where
1263 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1264 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1265 T: Send + 'static,
1266{
1267 async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T> {
1268 let future = (self.task_fn)(cancel_token);
1269 match future.await {
1270 CancellableTaskResult::Ok(value) => TaskExecutionResult::Success(value),
1271 CancellableTaskResult::Cancelled => TaskExecutionResult::Cancelled,
1272 CancellableTaskResult::Err(error) => TaskExecutionResult::Error(error),
1273 }
1274 }
1275}
1276
1277/// Common functionality for policy Arc construction
1278///
1279/// This trait provides a standardized `new_arc()` method for all policy types,
1280/// eliminating the need for manual `Arc::new()` calls in client code.
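///
/// A minimal sketch with a stand-in policy type (`MyPolicy` is hypothetical):
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::ArcPolicy;
/// #[derive(Debug)]
/// struct MyPolicy;
/// impl ArcPolicy for MyPolicy {}
///
/// let policy = MyPolicy.new_arc(); // Arc<MyPolicy>
/// assert_eq!(std::sync::Arc::strong_count(&policy), 1);
/// ```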
1281pub trait ArcPolicy: Sized + Send + Sync + 'static {
1282 /// Create an Arc-wrapped instance of this policy
1283 fn new_arc(self) -> Arc<Self> {
1284 Arc::new(self)
1285 }
1286}
1287
1288/// Unique identifier for a task
1289#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1290pub struct TaskId(Uuid);
1291
1292impl TaskId {
1293 fn new() -> Self {
1294 Self(Uuid::new_v4())
1295 }
1296}
1297
1298impl std::fmt::Display for TaskId {
1299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1300 write!(f, "task-{}", self.0)
1301 }
1302}
1303
1304/// Result of task execution
1305#[derive(Debug, Clone, PartialEq)]
1306pub enum CompletionStatus {
1307 /// Task completed successfully
1308 Ok,
1309 /// Task was cancelled before or during execution
1310 Cancelled,
1311 /// Task failed with an error
1312 Failed(String),
1313}
1314
1315/// Result type for cancellable tasks that explicitly tracks cancellation
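///
/// Typically returned from `spawn_cancellable` closures, for example:
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::*;
/// # #[tokio::main]
/// # async fn main() -> anyhow::Result<()> {
/// # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
/// let handle = tracker.spawn_cancellable(|cancel_token| async move {
///     if cancel_token.is_cancelled() {
///         return CancellableTaskResult::Cancelled;
///     }
///     CancellableTaskResult::Ok("done")
/// });
/// assert_eq!(handle.await??, "done");
/// # Ok(())
/// # }
/// ```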
1316#[derive(Debug)]
1317pub enum CancellableTaskResult<T> {
1318 /// Task completed successfully
1319 Ok(T),
1320 /// Task was cancelled (either via token or shutdown)
1321 Cancelled,
1322 /// Task failed with an error
1323 Err(anyhow::Error),
1324}
1325
1326/// Result of scheduling a task
1327#[derive(Debug)]
1328pub enum SchedulingResult<T> {
1329 /// Task was executed and completed
1330 Execute(T),
1331 /// Task was cancelled before execution
1332 Cancelled,
1333 /// Task was rejected due to scheduling policy
1334 Rejected(String),
1335}
1336
/// Resource guard for task execution
///
/// This trait represents resources (permits, slots, etc.) acquired from a scheduler
/// that must be held during task execution. The guard automatically releases
/// resources when dropped, implementing proper RAII semantics.
///
/// Guards are returned by `TaskScheduler::acquire_execution_slot()` and must be
/// held in scope while the task executes to ensure resources remain allocated.
/// This split enforces proper cancellation semantics: acquisition can be cancelled,
/// but once a guard is held, task execution always runs to completion.
1350pub trait ResourceGuard: Send + 'static {
1351 // Marker trait - resources are released via Drop on the concrete type
1352}
1353
1354/// Trait for implementing task scheduling policies
1355///
1356/// This trait enforces proper cancellation semantics by splitting resource
1357/// acquisition (which can be cancelled) from task execution (which cannot).
1358///
1359/// ## Design Philosophy
1360///
1361/// Tasks may or may not support cancellation (depending on whether they were
1362/// created with `spawn_cancellable` or regular `spawn`). This split design ensures:
1363///
1364/// - **Resource acquisition**: Can respect cancellation tokens to avoid unnecessary allocation
1365/// - **Task execution**: Always runs to completion; tasks handle their own cancellation
1366///
1367/// This makes it impossible to accidentally interrupt task execution with `tokio::select!`.
1368#[async_trait]
1369pub trait TaskScheduler: Send + Sync + std::fmt::Debug {
1370 /// Acquire resources needed for task execution and return a guard
1371 ///
1372 /// This method handles resource allocation (permits, queue slots, etc.) and
1373 /// can respect cancellation tokens to avoid unnecessary resource consumption.
1374 ///
1375 /// ## Cancellation Behavior
1376 ///
1377 /// The `cancel_token` is used for scheduler-level cancellation (e.g., "don't start new work").
1378 /// If cancellation is requested before or during resource acquisition, this method
1379 /// should return `SchedulingResult::Cancelled`.
1380 ///
1381 /// # Arguments
1382 /// * `cancel_token` - [`CancellationToken`] for scheduler-level cancellation
1383 ///
1384 /// # Returns
1385 /// * `SchedulingResult::Execute(guard)` - Resources acquired, ready to execute
1386 /// * `SchedulingResult::Cancelled` - Cancelled before or during resource acquisition
1387 /// * `SchedulingResult::Rejected(reason)` - Resources unavailable or policy violation
1388 async fn acquire_execution_slot(
1389 &self,
1390 cancel_token: CancellationToken,
1391 ) -> SchedulingResult<Box<dyn ResourceGuard>>;
1392}
1393
1394/// Trait for hierarchical task metrics that supports aggregation up the tracker tree
1395///
/// Implementations of this trait differ between root and child trackers:
1397/// - Root trackers integrate with Prometheus metrics for observability
1398/// - Child trackers chain metric updates up to their parents for aggregation
1399/// - All implementations maintain thread-safe atomic operations
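///
/// # Example
///
/// A compile-only sketch showing how the derived counts relate to the raw counters,
/// using a `TaskTracker` built elsewhere:
///
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::{HierarchicalTaskMetrics, TaskTracker};
/// # fn example(tracker: &TaskTracker) {
/// let metrics = tracker.metrics();
/// // queued = issued - started; pending = issued - (success + cancelled + failed + rejected)
/// println!(
///     "issued={} started={} queued={} pending={} active={}",
///     metrics.issued(),
///     metrics.started(),
///     metrics.queued(),
///     metrics.pending(),
///     metrics.active(),
/// );
/// # }
/// ```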
1400pub trait HierarchicalTaskMetrics: Send + Sync + std::fmt::Debug {
1401 /// Increment issued task counter
1402 fn increment_issued(&self);
1403
1404 /// Increment started task counter
1405 fn increment_started(&self);
1406
1407 /// Increment success counter
1408 fn increment_success(&self);
1409
1410 /// Increment cancelled counter
1411 fn increment_cancelled(&self);
1412
1413 /// Increment failed counter
1414 fn increment_failed(&self);
1415
1416 /// Increment rejected counter
1417 fn increment_rejected(&self);
1418
1419 /// Get current issued count (local to this tracker)
1420 fn issued(&self) -> u64;
1421
1422 /// Get current started count (local to this tracker)
1423 fn started(&self) -> u64;
1424
1425 /// Get current success count (local to this tracker)
1426 fn success(&self) -> u64;
1427
1428 /// Get current cancelled count (local to this tracker)
1429 fn cancelled(&self) -> u64;
1430
1431 /// Get current failed count (local to this tracker)
1432 fn failed(&self) -> u64;
1433
1434 /// Get current rejected count (local to this tracker)
1435 fn rejected(&self) -> u64;
1436
1437 /// Get total completed tasks (success + cancelled + failed + rejected)
1438 fn total_completed(&self) -> u64 {
1439 self.success() + self.cancelled() + self.failed() + self.rejected()
1440 }
1441
1442 /// Get number of pending tasks (issued - completed)
1443 fn pending(&self) -> u64 {
1444 self.issued().saturating_sub(self.total_completed())
1445 }
1446
1447 /// Get the number of tasks that are currently active (started - completed)
1448 fn active(&self) -> u64 {
1449 self.started().saturating_sub(self.total_completed())
1450 }
1451
1452 /// Get number of tasks queued in scheduler (issued - started)
1453 fn queued(&self) -> u64 {
1454 self.issued().saturating_sub(self.started())
1455 }
1456}
1457
1458/// Task execution metrics for a tracker
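///
/// `TaskMetrics` is the default metrics backend used when no custom metrics are supplied.
/// A compile-only sketch of passing it explicitly through the builder:
///
/// ```rust
/// # use std::sync::Arc;
/// # use dynamo_runtime::utils::tasks::tracker::{
/// #     LogOnlyPolicy, TaskMetrics, TaskTracker, UnlimitedScheduler,
/// # };
/// # fn example() -> anyhow::Result<()> {
/// let tracker = TaskTracker::builder()
///     .scheduler(UnlimitedScheduler::new())
///     .error_policy(LogOnlyPolicy::new())
///     .metrics(Arc::new(TaskMetrics::new()))
///     .build()?;
/// # Ok(())
/// # }
/// ```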
1459#[derive(Debug, Default)]
1460pub struct TaskMetrics {
1461 /// Number of tasks issued/submitted (via spawn methods)
1462 pub issued_count: AtomicU64,
1463 /// Number of tasks that have started execution
1464 pub started_count: AtomicU64,
1465 /// Number of successfully completed tasks
1466 pub success_count: AtomicU64,
1467 /// Number of cancelled tasks
1468 pub cancelled_count: AtomicU64,
1469 /// Number of failed tasks
1470 pub failed_count: AtomicU64,
1471 /// Number of rejected tasks (by scheduler)
1472 pub rejected_count: AtomicU64,
1473}
1474
1475impl TaskMetrics {
1476 /// Create new metrics instance
1477 pub fn new() -> Self {
1478 Self::default()
1479 }
1480}
1481
1482impl HierarchicalTaskMetrics for TaskMetrics {
1483 /// Increment issued task counter
1484 fn increment_issued(&self) {
1485 self.issued_count.fetch_add(1, Ordering::Relaxed);
1486 }
1487
1488 /// Increment started task counter
1489 fn increment_started(&self) {
1490 self.started_count.fetch_add(1, Ordering::Relaxed);
1491 }
1492
1493 /// Increment success counter
1494 fn increment_success(&self) {
1495 self.success_count.fetch_add(1, Ordering::Relaxed);
1496 }
1497
1498 /// Increment cancelled counter
1499 fn increment_cancelled(&self) {
1500 self.cancelled_count.fetch_add(1, Ordering::Relaxed);
1501 }
1502
1503 /// Increment failed counter
1504 fn increment_failed(&self) {
1505 self.failed_count.fetch_add(1, Ordering::Relaxed);
1506 }
1507
1508 /// Increment rejected counter
1509 fn increment_rejected(&self) {
1510 self.rejected_count.fetch_add(1, Ordering::Relaxed);
1511 }
1512
1513 /// Get current issued count
1514 fn issued(&self) -> u64 {
1515 self.issued_count.load(Ordering::Relaxed)
1516 }
1517
1518 /// Get current started count
1519 fn started(&self) -> u64 {
1520 self.started_count.load(Ordering::Relaxed)
1521 }
1522
1523 /// Get current success count
1524 fn success(&self) -> u64 {
1525 self.success_count.load(Ordering::Relaxed)
1526 }
1527
1528 /// Get current cancelled count
1529 fn cancelled(&self) -> u64 {
1530 self.cancelled_count.load(Ordering::Relaxed)
1531 }
1532
1533 /// Get current failed count
1534 fn failed(&self) -> u64 {
1535 self.failed_count.load(Ordering::Relaxed)
1536 }
1537
1538 /// Get current rejected count
1539 fn rejected(&self) -> u64 {
1540 self.rejected_count.load(Ordering::Relaxed)
1541 }
1542}
1543
1544/// Root tracker metrics with Prometheus integration
1545///
/// This implementation stores its counters directly as Prometheus `IntCounter`s created
/// through the provided MetricsRegistry, so reads and increments operate on the same counters.
1548#[derive(Debug)]
1549pub struct PrometheusTaskMetrics {
1550 /// Prometheus metrics integration
1551 prometheus_issued: prometheus::IntCounter,
1552 prometheus_started: prometheus::IntCounter,
1553 prometheus_success: prometheus::IntCounter,
1554 prometheus_cancelled: prometheus::IntCounter,
1555 prometheus_failed: prometheus::IntCounter,
1556 prometheus_rejected: prometheus::IntCounter,
1557}
1558
1559impl PrometheusTaskMetrics {
1560 /// Create new root metrics with Prometheus integration
1561 ///
1562 /// # Arguments
1563 /// * `registry` - MetricsRegistry for creating Prometheus metrics
1564 /// * `component_name` - Name for the component/tracker (used in metric names)
1565 ///
1566 /// # Example
1567 /// ```rust
1568 /// # use std::sync::Arc;
1569 /// # use dynamo_runtime::utils::tasks::tracker::PrometheusTaskMetrics;
1570 /// # use dynamo_runtime::metrics::MetricsRegistry;
1571 /// # fn example(registry: Arc<dyn MetricsRegistry>) -> anyhow::Result<()> {
1572 /// let metrics = PrometheusTaskMetrics::new(registry.as_ref(), "main_tracker")?;
1573 /// # Ok(())
1574 /// # }
1575 /// ```
1576 pub fn new<R: MetricsRegistry + ?Sized>(
1577 registry: &R,
1578 component_name: &str,
1579 ) -> anyhow::Result<Self> {
1580 let issued_counter = registry.create_intcounter(
1581 &format!("{}_{}", component_name, task_tracker::TASKS_ISSUED_TOTAL),
1582 "Total number of tasks issued/submitted",
1583 &[],
1584 )?;
1585
1586 let started_counter = registry.create_intcounter(
1587 &format!("{}_{}", component_name, task_tracker::TASKS_STARTED_TOTAL),
1588 "Total number of tasks started",
1589 &[],
1590 )?;
1591
1592 let success_counter = registry.create_intcounter(
1593 &format!("{}_{}", component_name, task_tracker::TASKS_SUCCESS_TOTAL),
1594 "Total number of successfully completed tasks",
1595 &[],
1596 )?;
1597
1598 let cancelled_counter = registry.create_intcounter(
1599 &format!("{}_{}", component_name, task_tracker::TASKS_CANCELLED_TOTAL),
1600 "Total number of cancelled tasks",
1601 &[],
1602 )?;
1603
1604 let failed_counter = registry.create_intcounter(
1605 &format!("{}_{}", component_name, task_tracker::TASKS_FAILED_TOTAL),
1606 "Total number of failed tasks",
1607 &[],
1608 )?;
1609
1610 let rejected_counter = registry.create_intcounter(
1611 &format!("{}_{}", component_name, task_tracker::TASKS_REJECTED_TOTAL),
1612 "Total number of rejected tasks",
1613 &[],
1614 )?;
1615
1616 Ok(Self {
1617 prometheus_issued: issued_counter,
1618 prometheus_started: started_counter,
1619 prometheus_success: success_counter,
1620 prometheus_cancelled: cancelled_counter,
1621 prometheus_failed: failed_counter,
1622 prometheus_rejected: rejected_counter,
1623 })
1624 }
1625}
1626
1627impl HierarchicalTaskMetrics for PrometheusTaskMetrics {
1628 fn increment_issued(&self) {
1629 self.prometheus_issued.inc();
1630 }
1631
1632 fn increment_started(&self) {
1633 self.prometheus_started.inc();
1634 }
1635
1636 fn increment_success(&self) {
1637 self.prometheus_success.inc();
1638 }
1639
1640 fn increment_cancelled(&self) {
1641 self.prometheus_cancelled.inc();
1642 }
1643
1644 fn increment_failed(&self) {
1645 self.prometheus_failed.inc();
1646 }
1647
1648 fn increment_rejected(&self) {
1649 self.prometheus_rejected.inc();
1650 }
1651
1652 fn issued(&self) -> u64 {
1653 self.prometheus_issued.get()
1654 }
1655
1656 fn started(&self) -> u64 {
1657 self.prometheus_started.get()
1658 }
1659
1660 fn success(&self) -> u64 {
1661 self.prometheus_success.get()
1662 }
1663
1664 fn cancelled(&self) -> u64 {
1665 self.prometheus_cancelled.get()
1666 }
1667
1668 fn failed(&self) -> u64 {
1669 self.prometheus_failed.get()
1670 }
1671
1672 fn rejected(&self) -> u64 {
1673 self.prometheus_rejected.get()
1674 }
1675}
1676
1677/// Child tracker metrics that chain updates to parent
1678///
1679/// This implementation maintains local counters and automatically forwards
1680/// all metric updates to the parent tracker for hierarchical aggregation.
1681/// Holds a strong reference to parent metrics for optimal performance.
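///
/// A compile-only sketch of the observable effect through the public API: counters
/// recorded on a child tracker are also visible on its parent.
///
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::{HierarchicalTaskMetrics, TaskTracker};
/// # async fn example(parent: TaskTracker) -> anyhow::Result<()> {
/// let child = parent.child_tracker()?;
/// let _ = child.spawn(async { Ok(()) }).await;
/// // Every increment on the child is forwarded upward, so the parent sees it too.
/// assert!(parent.metrics().issued() >= child.metrics().issued());
/// # Ok(())
/// # }
/// ```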
1682#[derive(Debug)]
1683struct ChildTaskMetrics {
1684 /// Local metrics for this tracker
1685 local_metrics: TaskMetrics,
1686 /// Strong reference to parent metrics for fast chaining
1687 /// Safe to hold since metrics don't own trackers - no circular references
1688 parent_metrics: Arc<dyn HierarchicalTaskMetrics>,
1689}
1690
1691impl ChildTaskMetrics {
1692 fn new(parent_metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
1693 Self {
1694 local_metrics: TaskMetrics::new(),
1695 parent_metrics,
1696 }
1697 }
1698}
1699
1700impl HierarchicalTaskMetrics for ChildTaskMetrics {
1701 fn increment_issued(&self) {
1702 self.local_metrics.increment_issued();
1703 self.parent_metrics.increment_issued();
1704 }
1705
1706 fn increment_started(&self) {
1707 self.local_metrics.increment_started();
1708 self.parent_metrics.increment_started();
1709 }
1710
1711 fn increment_success(&self) {
1712 self.local_metrics.increment_success();
1713 self.parent_metrics.increment_success();
1714 }
1715
1716 fn increment_cancelled(&self) {
1717 self.local_metrics.increment_cancelled();
1718 self.parent_metrics.increment_cancelled();
1719 }
1720
1721 fn increment_failed(&self) {
1722 self.local_metrics.increment_failed();
1723 self.parent_metrics.increment_failed();
1724 }
1725
1726 fn increment_rejected(&self) {
1727 self.local_metrics.increment_rejected();
1728 self.parent_metrics.increment_rejected();
1729 }
1730
1731 fn issued(&self) -> u64 {
1732 self.local_metrics.issued()
1733 }
1734
1735 fn started(&self) -> u64 {
1736 self.local_metrics.started()
1737 }
1738
1739 fn success(&self) -> u64 {
1740 self.local_metrics.success()
1741 }
1742
1743 fn cancelled(&self) -> u64 {
1744 self.local_metrics.cancelled()
1745 }
1746
1747 fn failed(&self) -> u64 {
1748 self.local_metrics.failed()
1749 }
1750
1751 fn rejected(&self) -> u64 {
1752 self.local_metrics.rejected()
1753 }
1754}
1755
1756/// Builder for creating child trackers with custom policies
1757///
1758/// Allows flexible customization of scheduling and error handling policies
1759/// for child trackers while maintaining parent-child relationships.
1760pub struct ChildTrackerBuilder<'parent> {
1761 parent: &'parent TaskTracker,
1762 scheduler: Option<Arc<dyn TaskScheduler>>,
1763 error_policy: Option<Arc<dyn OnErrorPolicy>>,
1764}
1765
1766impl<'parent> ChildTrackerBuilder<'parent> {
1767 /// Create a new ChildTrackerBuilder
1768 pub fn new(parent: &'parent TaskTracker) -> Self {
1769 Self {
1770 parent,
1771 scheduler: None,
1772 error_policy: None,
1773 }
1774 }
1775
1776 /// Set custom scheduler for the child tracker
1777 ///
1778 /// If not set, the child will inherit the parent's scheduler.
1779 ///
1780 /// # Arguments
1781 /// * `scheduler` - The scheduler to use for this child tracker
1782 ///
1783 /// # Example
1784 /// ```rust
1785 /// # use std::sync::Arc;
1786 /// # use tokio::sync::Semaphore;
1787 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler};
1788 /// # fn example(parent: &TaskTracker) {
1789 /// let child = parent.child_tracker_builder()
1790 /// .scheduler(SemaphoreScheduler::with_permits(5))
1791 /// .build().unwrap();
1792 /// # }
1793 /// ```
1794 pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
1795 self.scheduler = Some(scheduler);
1796 self
1797 }
1798
1799 /// Set custom error policy for the child tracker
1800 ///
1801 /// If not set, the child will get a child policy from the parent's error policy
1802 /// (via `OnErrorPolicy::create_child()`).
1803 ///
1804 /// # Arguments
1805 /// * `error_policy` - The error policy to use for this child tracker
1806 ///
1807 /// # Example
1808 /// ```rust
1809 /// # use std::sync::Arc;
1810 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, LogOnlyPolicy};
1811 /// # fn example(parent: &TaskTracker) {
1812 /// let child = parent.child_tracker_builder()
1813 /// .error_policy(LogOnlyPolicy::new())
1814 /// .build().unwrap();
1815 /// # }
1816 /// ```
1817 pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
1818 self.error_policy = Some(error_policy);
1819 self
1820 }
1821
1822 /// Build the child tracker with the specified configuration
1823 ///
1824 /// Creates a new child tracker with:
1825 /// - Custom or inherited scheduler
1826 /// - Custom or child error policy
1827 /// - Hierarchical metrics that chain to parent
1828 /// - Child cancellation token from the parent
1829 /// - Independent lifecycle from parent
1830 ///
1831 /// # Returns
    /// A new `TaskTracker` configured as a child of the parent
1833 ///
1834 /// # Errors
1835 /// Returns an error if the parent tracker is already closed
1836 pub fn build(self) -> anyhow::Result<TaskTracker> {
1837 // Validate that parent tracker is still active
1838 if self.parent.is_closed() {
1839 return Err(anyhow::anyhow!(
1840 "Cannot create child tracker from closed parent tracker"
1841 ));
1842 }
1843
1844 let parent = self.parent.0.clone();
1845
1846 let child_cancel_token = parent.cancel_token.child_token();
1847 let child_metrics = Arc::new(ChildTaskMetrics::new(parent.metrics.clone()));
1848
1849 // Use provided scheduler or inherit from parent
1850 let scheduler = self.scheduler.unwrap_or_else(|| parent.scheduler.clone());
1851
1852 // Use provided error policy or create child from parent's
1853 let error_policy = self
1854 .error_policy
1855 .unwrap_or_else(|| parent.error_policy.create_child());
1856
1857 let child = Arc::new(TaskTrackerInner {
1858 tokio_tracker: TokioTaskTracker::new(),
1859 parent: None, // No parent reference needed for hierarchical operations
1860 scheduler,
1861 error_policy,
1862 metrics: child_metrics,
1863 cancel_token: child_cancel_token,
1864 children: RwLock::new(Vec::new()),
1865 });
1866
1867 // Register this child with the parent for hierarchical operations
1868 parent
1869 .children
1870 .write()
1871 .unwrap()
1872 .push(Arc::downgrade(&child));
1873
1874 // Periodically clean up dead children to prevent unbounded growth
1875 parent.cleanup_dead_children();
1876
1877 Ok(TaskTracker(child))
1878 }
1879}
1880
1881/// Internal data for TaskTracker
1882///
1883/// This struct contains all the actual state and functionality of a TaskTracker.
1884/// TaskTracker itself is just a wrapper around Arc<TaskTrackerInner>.
1885struct TaskTrackerInner {
1886 /// Tokio's task tracker for lifecycle management
1887 tokio_tracker: TokioTaskTracker,
1888 /// Parent tracker (None for root)
1889 parent: Option<Arc<TaskTrackerInner>>,
1890 /// Scheduling policy (shared with children by default)
1891 scheduler: Arc<dyn TaskScheduler>,
1892 /// Error handling policy (child-specific via create_child)
1893 error_policy: Arc<dyn OnErrorPolicy>,
1894 /// Metrics for this tracker
1895 metrics: Arc<dyn HierarchicalTaskMetrics>,
1896 /// Cancellation token for this tracker (always present)
1897 cancel_token: CancellationToken,
1898 /// List of child trackers for hierarchical operations
1899 children: RwLock<Vec<Weak<TaskTrackerInner>>>,
1900}
1901
1902/// Hierarchical task tracker with pluggable scheduling and error policies
1903///
1904/// TaskTracker provides a composable system for managing background tasks with:
1905/// - Configurable scheduling via [`TaskScheduler`] implementations
1906/// - Flexible error handling via [`OnErrorPolicy`] implementations
1907/// - Parent-child relationships with independent metrics
1908/// - Cancellation propagation and isolation
1909/// - Built-in cancellation token support
1910///
1911/// Built on top of `tokio_util::task::TaskTracker` for robust task lifecycle management.
1912///
1913/// # Example
1914///
1915/// ```rust
1916/// # use std::sync::Arc;
1917/// # use tokio::sync::Semaphore;
1918/// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy, CancellableTaskResult};
1919/// # async fn example() -> anyhow::Result<()> {
1920/// // Create a task tracker with semaphore-based scheduling
1921/// let scheduler = SemaphoreScheduler::with_permits(3);
1922/// let policy = LogOnlyPolicy::new();
1923/// let root = TaskTracker::builder()
1924/// .scheduler(scheduler)
1925/// .error_policy(policy)
1926/// .build()?;
1927///
1928/// // Spawn some tasks
1929/// let handle1 = root.spawn(async { Ok(1) });
1930/// let handle2 = root.spawn(async { Ok(2) });
1931///
1932/// // Get results and join all tasks
1933/// let result1 = handle1.await.unwrap().unwrap();
1934/// let result2 = handle2.await.unwrap().unwrap();
1935/// assert_eq!(result1, 1);
1936/// assert_eq!(result2, 2);
1937/// # Ok(())
1938/// # }
1939/// ```
1940#[derive(Clone)]
1941pub struct TaskTracker(Arc<TaskTrackerInner>);
1942
1943/// Builder for TaskTracker
1944#[derive(Default)]
1945pub struct TaskTrackerBuilder {
1946 scheduler: Option<Arc<dyn TaskScheduler>>,
1947 error_policy: Option<Arc<dyn OnErrorPolicy>>,
1948 metrics: Option<Arc<dyn HierarchicalTaskMetrics>>,
1949 cancel_token: Option<CancellationToken>,
1950}
1951
1952impl TaskTrackerBuilder {
1953 /// Set the scheduler for this TaskTracker
1954 pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
1955 self.scheduler = Some(scheduler);
1956 self
1957 }
1958
1959 /// Set the error policy for this TaskTracker
1960 pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
1961 self.error_policy = Some(error_policy);
1962 self
1963 }
1964
1965 /// Set custom metrics for this TaskTracker
1966 pub fn metrics(mut self, metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
1967 self.metrics = Some(metrics);
1968 self
1969 }
1970
1971 /// Set the cancellation token for this TaskTracker
1972 pub fn cancel_token(mut self, cancel_token: CancellationToken) -> Self {
1973 self.cancel_token = Some(cancel_token);
1974 self
1975 }
1976
1977 /// Build the TaskTracker
1978 pub fn build(self) -> anyhow::Result<TaskTracker> {
1979 let scheduler = self
1980 .scheduler
1981 .ok_or_else(|| anyhow::anyhow!("TaskTracker requires a scheduler"))?;
1982
1983 let error_policy = self
1984 .error_policy
1985 .ok_or_else(|| anyhow::anyhow!("TaskTracker requires an error policy"))?;
1986
1987 let metrics = self.metrics.unwrap_or_else(|| Arc::new(TaskMetrics::new()));
1988
1989 let cancel_token = self.cancel_token.unwrap_or_default();
1990
1991 let inner = TaskTrackerInner {
1992 tokio_tracker: TokioTaskTracker::new(),
1993 parent: None,
1994 scheduler,
1995 error_policy,
1996 metrics,
1997 cancel_token,
1998 children: RwLock::new(Vec::new()),
1999 };
2000
2001 Ok(TaskTracker(Arc::new(inner)))
2002 }
2003}
2004
2005impl TaskTracker {
2006 /// Create a new root task tracker using the builder pattern
2007 ///
2008 /// This is the preferred way to create new task trackers.
2009 ///
2010 /// # Example
2011 /// ```rust
2012 /// # use std::sync::Arc;
2013 /// # use tokio::sync::Semaphore;
2014 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2015 /// # fn main() -> anyhow::Result<()> {
2016 /// let scheduler = SemaphoreScheduler::with_permits(10);
2017 /// let error_policy = LogOnlyPolicy::new();
2018 /// let tracker = TaskTracker::builder()
2019 /// .scheduler(scheduler)
2020 /// .error_policy(error_policy)
2021 /// .build()?;
2022 /// # Ok(())
2023 /// # }
2024 /// ```
2025 pub fn builder() -> TaskTrackerBuilder {
2026 TaskTrackerBuilder::default()
2027 }
2028
2029 /// Create a new root task tracker with simple parameters (legacy)
2030 ///
2031 /// This method is kept for backward compatibility. Use `builder()` for new code.
2032 /// Uses default metrics (no Prometheus integration).
2033 ///
2034 /// # Arguments
2035 /// * `scheduler` - Scheduling policy to use for all tasks
2036 /// * `error_policy` - Error handling policy for this tracker
2037 ///
2038 /// # Example
2039 /// ```rust
2040 /// # use std::sync::Arc;
2041 /// # use tokio::sync::Semaphore;
2042 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2043 /// # fn main() -> anyhow::Result<()> {
2044 /// let scheduler = SemaphoreScheduler::with_permits(10);
2045 /// let error_policy = LogOnlyPolicy::new();
2046 /// let tracker = TaskTracker::new(scheduler, error_policy)?;
2047 /// # Ok(())
2048 /// # }
2049 /// ```
2050 pub fn new(
2051 scheduler: Arc<dyn TaskScheduler>,
2052 error_policy: Arc<dyn OnErrorPolicy>,
2053 ) -> anyhow::Result<Self> {
2054 Self::builder()
2055 .scheduler(scheduler)
2056 .error_policy(error_policy)
2057 .build()
2058 }
2059
2060 /// Create a new root task tracker with Prometheus metrics integration
2061 ///
2062 /// # Arguments
2063 /// * `scheduler` - Scheduling policy to use for all tasks
2064 /// * `error_policy` - Error handling policy for this tracker
2065 /// * `registry` - MetricsRegistry for Prometheus integration
2066 /// * `component_name` - Name for this tracker component
2067 ///
2068 /// # Example
2069 /// ```rust
2070 /// # use std::sync::Arc;
2071 /// # use tokio::sync::Semaphore;
2072 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2073 /// # use dynamo_runtime::metrics::MetricsRegistry;
2074 /// # fn example(registry: Arc<dyn MetricsRegistry>) -> anyhow::Result<()> {
2075 /// let scheduler = SemaphoreScheduler::with_permits(10);
2076 /// let error_policy = LogOnlyPolicy::new();
2077 /// let tracker = TaskTracker::new_with_prometheus(
2078 /// scheduler,
2079 /// error_policy,
2080 /// registry.as_ref(),
2081 /// "main_tracker"
2082 /// )?;
2083 /// # Ok(())
2084 /// # }
2085 /// ```
2086 pub fn new_with_prometheus<R: MetricsRegistry + ?Sized>(
2087 scheduler: Arc<dyn TaskScheduler>,
2088 error_policy: Arc<dyn OnErrorPolicy>,
2089 registry: &R,
2090 component_name: &str,
2091 ) -> anyhow::Result<Self> {
2092 let prometheus_metrics = Arc::new(PrometheusTaskMetrics::new(registry, component_name)?);
2093
2094 Self::builder()
2095 .scheduler(scheduler)
2096 .error_policy(error_policy)
2097 .metrics(prometheus_metrics)
2098 .build()
2099 }
2100
2101 /// Create a child tracker that inherits scheduling policy
2102 ///
2103 /// The child tracker:
2104 /// - Gets its own independent tokio TaskTracker
2105 /// - Inherits the parent's scheduler
2106 /// - Gets a child error policy via `create_child()`
2107 /// - Has hierarchical metrics that chain to parent
2108 /// - Gets a child cancellation token from the parent
2109 /// - Is independent for cancellation (child cancellation doesn't affect parent)
2110 ///
2111 /// # Errors
2112 /// Returns an error if the parent tracker is already closed
2113 ///
2114 /// # Example
2115 /// ```rust
2116 /// # use std::sync::Arc;
2117 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2118 /// # fn example(root_tracker: TaskTracker) -> anyhow::Result<()> {
2119 /// let child_tracker = root_tracker.child_tracker()?;
2120 /// // Child inherits parent's policies but has separate metrics and lifecycle
2121 /// # Ok(())
2122 /// # }
2123 /// ```
2124 pub fn child_tracker(&self) -> anyhow::Result<TaskTracker> {
2125 Ok(TaskTracker(self.0.child_tracker()?))
2126 }
2127
2156 /// Spawn a new task
2157 ///
2158 /// The task will be wrapped with scheduling and error handling logic,
2159 /// then executed according to the configured policies. For tasks that
2160 /// need to inspect cancellation tokens, use [`spawn_cancellable`] instead.
2161 ///
2162 /// # Arguments
2163 /// * `future` - The async task to execute
2164 ///
2165 /// # Returns
2166 /// A [`TaskHandle`] that can be used to await completion and access the task's cancellation token
2167 ///
2168 /// # Panics
2169 /// Panics if the tracker has been closed. This indicates a programming error
2170 /// where tasks are being spawned after the tracker lifecycle has ended.
2171 ///
2172 /// # Example
2173 /// ```rust
2174 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2175 /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
2176 /// let handle = tracker.spawn(async {
2177 /// // Your async work here
2178 /// tokio::time::sleep(std::time::Duration::from_millis(100)).await;
2179 /// Ok(42)
2180 /// });
2181 ///
2182 /// // Access the task's cancellation token
2183 /// let cancel_token = handle.cancellation_token();
2184 ///
2185 /// let result = handle.await?;
2186 /// # Ok(())
2187 /// # }
2188 /// ```
2189 pub fn spawn<F, T>(&self, future: F) -> TaskHandle<T>
2190 where
2191 F: Future<Output = Result<T>> + Send + 'static,
2192 T: Send + 'static,
2193 {
2194 self.0
2195 .spawn(future)
2196 .expect("TaskTracker must not be closed when spawning tasks")
2197 }
2198
2199 /// Spawn a cancellable task that receives a cancellation token
2200 ///
2201 /// This is useful for tasks that need to inspect the cancellation token
2202 /// and gracefully handle cancellation within their logic. The task function
2203 /// must return a `CancellableTaskResult` to properly track cancellation vs errors.
2204 ///
2205 /// # Arguments
2206 ///
2207 /// * `task_fn` - Function that takes a cancellation token and returns a future that resolves to `CancellableTaskResult<T>`
2208 ///
2209 /// # Returns
2210 /// A [`TaskHandle`] that can be used to await completion and access the task's cancellation token
2211 ///
2212 /// # Panics
2213 /// Panics if the tracker has been closed. This indicates a programming error
2214 /// where tasks are being spawned after the tracker lifecycle has ended.
2215 ///
2216 /// # Example
2217 /// ```rust
2218 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, CancellableTaskResult};
2219 /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
2220 /// let handle = tracker.spawn_cancellable(|cancel_token| async move {
2221 /// tokio::select! {
2222 /// _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
2223 /// CancellableTaskResult::Ok(42)
2224 /// },
2225 /// _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
2226 /// }
2227 /// });
2228 ///
2229 /// // Access the task's individual cancellation token
2230 /// let task_cancel_token = handle.cancellation_token();
2231 ///
2232 /// let result = handle.await?;
2233 /// # Ok(())
2234 /// # }
2235 /// ```
2236 pub fn spawn_cancellable<F, Fut, T>(&self, task_fn: F) -> TaskHandle<T>
2237 where
2238 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2239 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2240 T: Send + 'static,
2241 {
2242 self.0
2243 .spawn_cancellable(task_fn)
2244 .expect("TaskTracker must not be closed when spawning tasks")
2245 }
2246
2247 /// Get metrics for this tracker
2248 ///
2249 /// Metrics are specific to this tracker and do not include
2250 /// metrics from parent or child trackers.
2251 ///
2252 /// # Example
2253 /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::{HierarchicalTaskMetrics, TaskTracker};
2255 /// # fn example(tracker: &TaskTracker) {
2256 /// let metrics = tracker.metrics();
2257 /// println!("Success: {}, Failed: {}", metrics.success(), metrics.failed());
2258 /// # }
2259 /// ```
2260 pub fn metrics(&self) -> &dyn HierarchicalTaskMetrics {
2261 self.0.metrics.as_ref()
2262 }
2263
2264 /// Cancel this tracker and all its tasks
2265 ///
2266 /// This will signal cancellation to all currently running tasks and prevent new tasks from being spawned.
2267 /// The cancellation is immediate and forceful.
2268 ///
2269 /// # Example
2270 /// ```rust
2271 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2272 /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
2273 /// // Spawn a long-running task
2274 /// let handle = tracker.spawn_cancellable(|cancel_token| async move {
2275 /// tokio::select! {
2276 /// _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
2277 /// dynamo_runtime::utils::tasks::tracker::CancellableTaskResult::Ok(42)
2278 /// }
2279 /// _ = cancel_token.cancelled() => {
2280 /// dynamo_runtime::utils::tasks::tracker::CancellableTaskResult::Cancelled
2281 /// }
2282 /// }
    /// });
    ///
    /// // Cancel the tracker (and thus the task)
    /// tracker.cancel();
    ///
    /// // The task observes its cancellation token and finishes promptly
    /// let _ = handle.await;
2287 /// # Ok(())
2288 /// # }
2289 /// ```
2290 pub fn cancel(&self) {
2291 self.0.cancel();
2292 }
2293
2294 /// Check if this tracker is closed
2295 pub fn is_closed(&self) -> bool {
2296 self.0.is_closed()
2297 }
2298
2299 /// Get the cancellation token for this tracker
2300 ///
2301 /// This allows external code to observe or trigger cancellation of this tracker.
2302 ///
2303 /// # Example
2304 /// ```rust
2305 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2306 /// # fn example(tracker: &TaskTracker) {
2307 /// let token = tracker.cancellation_token();
2308 /// // Can check cancellation state or cancel manually
2309 /// if !token.is_cancelled() {
2310 /// token.cancel();
2311 /// }
2312 /// # }
2313 /// ```
2314 pub fn cancellation_token(&self) -> CancellationToken {
2315 self.0.cancellation_token()
2316 }
2317
2318 /// Get the number of active child trackers
2319 ///
2320 /// This counts only child trackers that are still alive (not dropped).
2321 /// Dropped child trackers are automatically cleaned up.
2322 ///
2323 /// # Example
2324 /// ```rust
2325 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2326 /// # fn example(tracker: &TaskTracker) {
2327 /// let child_count = tracker.child_count();
2328 /// println!("This tracker has {} active children", child_count);
2329 /// # }
2330 /// ```
2331 pub fn child_count(&self) -> usize {
2332 self.0.child_count()
2333 }
2334
2335 /// Create a child tracker builder with custom configuration
2336 ///
2337 /// This provides fine-grained control over child tracker creation,
2338 /// allowing you to override the scheduler or error policy while
2339 /// maintaining the parent-child relationship.
2340 ///
2341 /// # Example
2342 /// ```rust
2343 /// # use std::sync::Arc;
2344 /// # use tokio::sync::Semaphore;
2345 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2346 /// # fn example(parent: &TaskTracker) {
2347 /// // Custom scheduler, inherit error policy
2348 /// let child1 = parent.child_tracker_builder()
2349 /// .scheduler(SemaphoreScheduler::with_permits(5))
2350 /// .build().unwrap();
2351 ///
2352 /// // Custom error policy, inherit scheduler
2353 /// let child2 = parent.child_tracker_builder()
2354 /// .error_policy(LogOnlyPolicy::new())
2355 /// .build().unwrap();
2356 ///
2357 /// // Inherit both policies from parent
2358 /// let child3 = parent.child_tracker_builder()
2359 /// .build().unwrap();
2360 /// # }
2361 /// ```
2362 pub fn child_tracker_builder(&self) -> ChildTrackerBuilder<'_> {
2363 ChildTrackerBuilder::new(self)
2364 }
2365
2366 /// Join this tracker and all child trackers
2367 ///
2368 /// This method gracefully shuts down the entire tracker hierarchy by:
2369 /// 1. Closing all trackers (preventing new task spawning)
2370 /// 2. Waiting for all existing tasks to complete
2371 ///
    /// Uses stack-safe traversal to prevent stack overflow in deep hierarchies.
    ///
    /// **Hierarchical Behavior:**
    /// - Children are processed before parents to ensure proper shutdown order
    /// - Each tracker is closed before waiting (Tokio requirement)
    /// - Leaf trackers simply close and wait for their own tasks
2379 ///
2380 /// # Example
2381 /// ```rust
2382 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2383 /// # async fn example(tracker: TaskTracker) {
2384 /// tracker.join().await;
2385 /// # }
2386 /// ```
2387 pub async fn join(&self) {
2388 self.0.join().await
2389 }
2390}
2391
2392impl TaskTrackerInner {
2393 /// Creates child tracker with inherited scheduler/policy, independent metrics, and hierarchical cancellation
2394 fn child_tracker(self: &Arc<Self>) -> anyhow::Result<Arc<TaskTrackerInner>> {
2395 // Validate that parent tracker is still active
2396 if self.is_closed() {
2397 return Err(anyhow::anyhow!(
2398 "Cannot create child tracker from closed parent tracker"
2399 ));
2400 }
2401
2402 let child_cancel_token = self.cancel_token.child_token();
2403 let child_metrics = Arc::new(ChildTaskMetrics::new(self.metrics.clone()));
2404
2405 let child = Arc::new(TaskTrackerInner {
2406 tokio_tracker: TokioTaskTracker::new(),
2407 parent: Some(self.clone()),
2408 scheduler: self.scheduler.clone(),
2409 error_policy: self.error_policy.create_child(),
2410 metrics: child_metrics,
2411 cancel_token: child_cancel_token,
2412 children: RwLock::new(Vec::new()),
2413 });
2414
2415 // Register this child with the parent for hierarchical operations
2416 self.children.write().unwrap().push(Arc::downgrade(&child));
2417
2418 // Periodically clean up dead children to prevent unbounded growth
2419 self.cleanup_dead_children();
2420
2421 Ok(child)
2422 }
2423
2424 /// Spawn implementation - validates tracker state, generates task ID, applies policies, and tracks execution
2425 fn spawn<F, T>(self: &Arc<Self>, future: F) -> Result<TaskHandle<T>, TaskError>
2426 where
2427 F: Future<Output = Result<T>> + Send + 'static,
2428 T: Send + 'static,
2429 {
2430 // Validate tracker is not closed
2431 if self.tokio_tracker.is_closed() {
2432 return Err(TaskError::TrackerClosed);
2433 }
2434
2435 // Generate a unique task ID
2436 let task_id = self.generate_task_id();
2437
2438 // Increment issued counter immediately when task is submitted
2439 self.metrics.increment_issued();
2440
2441 // Create a child cancellation token for this specific task
2442 let task_cancel_token = self.cancel_token.child_token();
2443 let cancel_token = task_cancel_token.clone();
2444
2445 // Clone the inner Arc to move into the task
2446 let inner = self.clone();
2447
2448 // Wrap the user's future with our scheduling and error handling
2449 let wrapped_future =
2450 async move { Self::execute_with_policies(task_id, future, cancel_token, inner).await };
2451
2452 // Let tokio handle the actual task tracking
2453 let join_handle = self.tokio_tracker.spawn(wrapped_future);
2454
2455 // Wrap in TaskHandle with the child cancellation token
2456 Ok(TaskHandle::new(join_handle, task_cancel_token))
2457 }
2458
2459 /// Spawn cancellable implementation - validates state, provides cancellation token, handles CancellableTaskResult
2460 fn spawn_cancellable<F, Fut, T>(
2461 self: &Arc<Self>,
2462 task_fn: F,
2463 ) -> Result<TaskHandle<T>, TaskError>
2464 where
2465 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2466 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2467 T: Send + 'static,
2468 {
2469 // Validate tracker is not closed
2470 if self.tokio_tracker.is_closed() {
2471 return Err(TaskError::TrackerClosed);
2472 }
2473
2474 // Generate a unique task ID
2475 let task_id = self.generate_task_id();
2476
2477 // Increment issued counter immediately when task is submitted
2478 self.metrics.increment_issued();
2479
2480 // Create a child cancellation token for this specific task
2481 let task_cancel_token = self.cancel_token.child_token();
2482 let cancel_token = task_cancel_token.clone();
2483
2484 // Clone the inner Arc to move into the task
2485 let inner = self.clone();
2486
2487 // Use the new execution pipeline that defers task creation until after guard acquisition
2488 let wrapped_future = async move {
2489 Self::execute_cancellable_with_policies(task_id, task_fn, cancel_token, inner).await
2490 };
2491
2492 // Let tokio handle the actual task tracking
2493 let join_handle = self.tokio_tracker.spawn(wrapped_future);
2494
2495 // Wrap in TaskHandle with the child cancellation token
2496 Ok(TaskHandle::new(join_handle, task_cancel_token))
2497 }
2498
2499 /// Cancel this tracker and all its tasks - implementation
2500 fn cancel(&self) {
2501 // Close the tracker to prevent new tasks
2502 self.tokio_tracker.close();
2503
2504 // Cancel our own token
2505 self.cancel_token.cancel();
2506 }
2507
2508 /// Returns true if the underlying tokio tracker is closed
2509 fn is_closed(&self) -> bool {
2510 self.tokio_tracker.is_closed()
2511 }
2512
2513 /// Generates a unique task ID using TaskId::new()
2514 fn generate_task_id(&self) -> TaskId {
2515 TaskId::new()
2516 }
2517
2518 /// Removes dead weak references from children list to prevent memory leaks
2519 fn cleanup_dead_children(&self) {
2520 let mut children_guard = self.children.write().unwrap();
2521 children_guard.retain(|weak| weak.upgrade().is_some());
2522 }
2523
2524 /// Returns a clone of the cancellation token
2525 fn cancellation_token(&self) -> CancellationToken {
2526 self.cancel_token.clone()
2527 }
2528
2529 /// Counts active child trackers (filters out dead weak references)
2530 fn child_count(&self) -> usize {
2531 let children_guard = self.children.read().unwrap();
2532 children_guard
2533 .iter()
2534 .filter(|weak| weak.upgrade().is_some())
2535 .count()
2536 }
2537
2538 /// Join implementation - closes all trackers in hierarchy then waits for task completion using stack-safe traversal
2539 async fn join(self: &Arc<Self>) {
2540 // Fast path for leaf trackers (no children)
2541 let is_leaf = {
2542 let children_guard = self.children.read().unwrap();
2543 children_guard.is_empty()
2544 };
2545
2546 if is_leaf {
2547 self.tokio_tracker.close();
2548 self.tokio_tracker.wait().await;
2549 return;
2550 }
2551
2552 // Stack-safe traversal for deep hierarchies
2553 // Processes children before parents to ensure proper shutdown order
2554 let trackers = self.collect_hierarchy();
2555 for t in trackers {
2556 t.tokio_tracker.close();
2557 t.tokio_tracker.wait().await;
2558 }
2559 }
2560
2561 /// Collects hierarchy using iterative DFS, returns Vec in post-order (children before parents) for safe shutdown
2562 fn collect_hierarchy(self: &Arc<TaskTrackerInner>) -> Vec<Arc<TaskTrackerInner>> {
2563 let mut result = Vec::new();
2564 let mut stack = vec![self.clone()];
2565 let mut visited = HashSet::new();
2566
2567 // Collect all trackers using depth-first search
2568 while let Some(tracker) = stack.pop() {
2569 let tracker_ptr = Arc::as_ptr(&tracker) as usize;
2570 if visited.contains(&tracker_ptr) {
2571 continue;
2572 }
2573 visited.insert(tracker_ptr);
2574
2575 // Add current tracker to result
2576 result.push(tracker.clone());
2577
2578 // Add children to stack for processing
2579 if let Ok(children_guard) = tracker.children.read() {
2580 for weak_child in children_guard.iter() {
2581 if let Some(child) = weak_child.upgrade() {
2582 let child_ptr = Arc::as_ptr(&child) as usize;
2583 if !visited.contains(&child_ptr) {
2584 stack.push(child);
2585 }
2586 }
2587 }
2588 }
2589 }
2590
2591 // Reverse to get bottom-up order (children before parents)
2592 result.reverse();
2593 result
2594 }
2595
2596 /// Execute a regular task with scheduling and error handling policies
2597 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2598 async fn execute_with_policies<F, T>(
2599 task_id: TaskId,
2600 future: F,
2601 task_cancel_token: CancellationToken,
2602 inner: Arc<TaskTrackerInner>,
2603 ) -> Result<T, TaskError>
2604 where
2605 F: Future<Output = Result<T>> + Send + 'static,
2606 T: Send + 'static,
2607 {
2608 // Wrap regular future in a task executor that doesn't support retry
2609 let task_executor = RegularTaskExecutor::new(future);
2610 Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
2611 }
2612
2613 /// Execute a cancellable task with scheduling and error handling policies
2614 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2615 async fn execute_cancellable_with_policies<F, Fut, T>(
2616 task_id: TaskId,
2617 task_fn: F,
2618 task_cancel_token: CancellationToken,
2619 inner: Arc<TaskTrackerInner>,
2620 ) -> Result<T, TaskError>
2621 where
2622 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2623 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2624 T: Send + 'static,
2625 {
2626 // Wrap cancellable task function in a task executor that supports retry
2627 let task_executor = CancellableTaskExecutor::new(task_fn);
2628 Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
2629 }
2630
2631 /// Core execution loop with retry support - unified for both task types
2632 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2633 async fn execute_with_retry_loop<E, T>(
2634 task_id: TaskId,
2635 initial_executor: E,
2636 task_cancellation_token: CancellationToken,
2637 inner: Arc<TaskTrackerInner>,
2638 ) -> Result<T, TaskError>
2639 where
2640 E: TaskExecutor<T> + Send + 'static,
2641 T: Send + 'static,
2642 {
2643 debug!("Starting task execution");
2644
        // Guard that increments the started counter at most once; the derived active
        // count (started - completed) falls once a completion counter is recorded
2646 struct ActiveCountGuard {
2647 metrics: Arc<dyn HierarchicalTaskMetrics>,
2648 is_active: bool,
2649 }
2650
2651 impl ActiveCountGuard {
2652 fn new(metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
2653 Self {
2654 metrics,
2655 is_active: false,
2656 }
2657 }
2658
2659 fn activate(&mut self) {
2660 if !self.is_active {
2661 self.metrics.increment_started();
2662 self.is_active = true;
2663 }
2664 }
2665 }
2666
2667 // Current executable - either the original TaskExecutor or a Continuation
2668 enum CurrentExecutable<E>
2669 where
2670 E: Send + 'static,
2671 {
2672 TaskExecutor(E),
2673 Continuation(Arc<dyn Continuation + Send + Sync + 'static>),
2674 }
2675
2676 let mut current_executable = CurrentExecutable::TaskExecutor(initial_executor);
2677 let mut active_guard = ActiveCountGuard::new(inner.metrics.clone());
2678 let mut error_context: Option<OnErrorContext> = None;
2679 let mut scheduler_guard_state = self::GuardState::Keep;
2680 let mut guard_result = async {
2681 inner
2682 .scheduler
2683 .acquire_execution_slot(task_cancellation_token.child_token())
2684 .await
2685 }
        .instrument(tracing::debug_span!("scheduler_resource_acquisition"))
2687 .await;
2688
2689 loop {
2690 if scheduler_guard_state == self::GuardState::Reschedule {
2691 guard_result = async {
2692 inner
2693 .scheduler
2694 .acquire_execution_slot(inner.cancel_token.child_token())
2695 .await
2696 }
2697 .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
2698 .await;
2699 }
2700
2701 match &guard_result {
2702 SchedulingResult::Execute(_guard) => {
2703 // Activate the RAII guard only once when we successfully acquire resources
2704 active_guard.activate();
2705
2706 // Execute the current executable while holding the guard (RAII pattern)
2707 let execution_result = async {
2708 debug!("Executing task with acquired resources");
2709 match &mut current_executable {
2710 CurrentExecutable::TaskExecutor(executor) => {
2711 executor.execute(inner.cancel_token.child_token()).await
2712 }
2713 CurrentExecutable::Continuation(continuation) => {
2714 // Execute continuation and handle type erasure
2715 match continuation.execute(inner.cancel_token.child_token()).await {
2716 TaskExecutionResult::Success(result) => {
2717 // Try to downcast the result to the expected type T
2718 if let Ok(typed_result) = result.downcast::<T>() {
2719 TaskExecutionResult::Success(*typed_result)
2720 } else {
2721 // Type mismatch - this shouldn't happen with proper usage
2722 let type_error = anyhow::anyhow!(
2723 "Continuation task returned wrong type"
2724 );
2725 error!(
2726 ?type_error,
2727 "Type mismatch in continuation task result"
2728 );
2729 TaskExecutionResult::Error(type_error)
2730 }
2731 }
2732 TaskExecutionResult::Cancelled => {
2733 TaskExecutionResult::Cancelled
2734 }
2735 TaskExecutionResult::Error(error) => {
2736 TaskExecutionResult::Error(error)
2737 }
2738 }
2739 }
2740 }
2741 }
2742 .instrument(tracing::debug_span!("task_execution"))
2743 .await;
2744
                    // The derived active count (started - completed) drops once a completion
                    // counter is incremented below
2746
2747 match execution_result {
2748 TaskExecutionResult::Success(value) => {
2749 inner.metrics.increment_success();
2750 debug!("Task completed successfully");
2751 return Ok(value);
2752 }
2753 TaskExecutionResult::Cancelled => {
2754 inner.metrics.increment_cancelled();
2755 debug!("Task was cancelled during execution");
2756 return Err(TaskError::Cancelled);
2757 }
2758 TaskExecutionResult::Error(error) => {
2759 debug!("Task failed - handling error through policy - {error:?}");
2760
2761 // Handle the error through the policy system
2762 let (action_result, guard_state) = Self::handle_task_error(
2763 &error,
2764 &mut error_context,
2765 task_id,
2766 &inner,
2767 )
2768 .await;
2769
2770 // Update the scheduler guard state for evaluation after the match
2771 scheduler_guard_state = guard_state;
2772
2773 match action_result {
2774 ActionResult::Fail => {
2775 inner.metrics.increment_failed();
2776 debug!("Policy accepted error - task failed {error:?}");
2777 return Err(TaskError::Failed(error));
2778 }
2779 ActionResult::Shutdown => {
2780 inner.metrics.increment_failed();
2781 warn!("Policy triggered shutdown - {error:?}");
2782 inner.cancel();
2783 return Err(TaskError::Failed(error));
2784 }
2785 ActionResult::Continue { continuation } => {
2786 debug!(
2787 "Policy provided next executable - continuing loop - {error:?}"
2788 );
2789
2790 // Update current executable
2791 current_executable =
2792 CurrentExecutable::Continuation(continuation);
2793
2794 continue; // Continue the main loop with the new executable
2795 }
2796 }
2797 }
2798 }
2799 }
2800 SchedulingResult::Cancelled => {
2801 inner.metrics.increment_cancelled();
2802 debug!("Task was cancelled during resource acquisition");
2803 return Err(TaskError::Cancelled);
2804 }
2805 SchedulingResult::Rejected(reason) => {
2806 inner.metrics.increment_rejected();
2807 debug!(reason, "Task was rejected by scheduler");
2808 return Err(TaskError::Failed(anyhow::anyhow!(
2809 "Task rejected: {}",
2810 reason
2811 )));
2812 }
2813 }
2814 }
2815 }
2816
2817 /// Handle task errors through the error policy and return the action to take
2818 async fn handle_task_error(
2819 error: &anyhow::Error,
2820 error_context: &mut Option<OnErrorContext>,
2821 task_id: TaskId,
2822 inner: &Arc<TaskTrackerInner>,
2823 ) -> (ActionResult, self::GuardState) {
2824 // Create or update the error context (lazy initialization)
2825 let context = error_context.get_or_insert_with(|| OnErrorContext {
2826 attempt_count: 0, // Will be incremented below
2827 task_id,
2828 execution_context: TaskExecutionContext {
2829 scheduler: inner.scheduler.clone(),
2830 metrics: inner.metrics.clone(),
2831 },
2832 state: inner.error_policy.create_context(),
2833 });
2834
2835 // Increment attempt count for this error
2836 context.attempt_count += 1;
2837 let current_attempt = context.attempt_count;
2838
2839 // First, check if the policy allows continuations for this error
2840 if inner.error_policy.allow_continuation(error, context) {
2841 // Policy allows continuations, check if this is a FailedWithContinuation (task-driven continuation)
2842 if let Some(continuation_err) = error.downcast_ref::<FailedWithContinuation>() {
2843 debug!(
2844 task_id = %task_id,
2845 attempt_count = current_attempt,
2846 "Task provided FailedWithContinuation and policy allows continuations - {error:?}"
2847 );
2848
2849 // Task has provided a continuation implementation for the next attempt
2850 // Clone the Arc to return it in ActionResult::Continue
2851 let continuation = continuation_err.continuation.clone();
2852
2853 // Ask policy whether to reschedule task-driven continuation
2854 let should_reschedule = inner.error_policy.should_reschedule(error, context);
2855
2856 let guard_state = if should_reschedule {
2857 self::GuardState::Reschedule
2858 } else {
2859 self::GuardState::Keep
2860 };
2861
2862 return (ActionResult::Continue { continuation }, guard_state);
2863 }
2864 } else {
2865 debug!(
2866 task_id = %task_id,
2867 attempt_count = current_attempt,
2868 "Policy rejected continuations, ignoring any FailedWithContinuation - {error:?}"
2869 );
2870 }
2871
2872 let response = inner.error_policy.on_error(error, context);
2873
2874 match response {
2875 ErrorResponse::Fail => (ActionResult::Fail, self::GuardState::Keep),
2876 ErrorResponse::Shutdown => (ActionResult::Shutdown, self::GuardState::Keep),
2877 ErrorResponse::Custom(action) => {
2878 debug!("Task failed - executing custom action - {error:?}");
2879
2880 // Execute the custom action asynchronously
2881 let action_result = action
2882 .execute(error, task_id, current_attempt, &context.execution_context)
2883 .await;
2884 debug!(?action_result, "Custom action completed");
2885
2886 // If the custom action returned Continue, ask policy about rescheduling
2887 let guard_state = match &action_result {
2888 ActionResult::Continue { .. } => {
2889 let should_reschedule =
2890 inner.error_policy.should_reschedule(error, context);
2891 if should_reschedule {
2892 self::GuardState::Reschedule
2893 } else {
2894 self::GuardState::Keep
2895 }
2896 }
2897 _ => self::GuardState::Keep, // Fail/Shutdown don't need guard state
2898 };
2899
2900 (action_result, guard_state)
2901 }
2902 }
2903 }
2904}
2905
// ArcPolicy marker implementations for the built-in schedulers
2907impl ArcPolicy for UnlimitedScheduler {}
2908impl ArcPolicy for SemaphoreScheduler {}
2909
// ArcPolicy marker implementations for the built-in error policies
2911impl ArcPolicy for LogOnlyPolicy {}
2912impl ArcPolicy for CancelOnError {}
2913impl ArcPolicy for ThresholdCancelPolicy {}
2914impl ArcPolicy for RateCancelPolicy {}
2915
2916/// Resource guard for unlimited scheduling
2917///
2918/// This guard represents "unlimited" resources - no actual resource constraints.
2919/// Since there are no resources to manage, this guard is essentially a no-op.
2920#[derive(Debug)]
2921pub struct UnlimitedGuard;
2922
2923impl ResourceGuard for UnlimitedGuard {
2924 // No resources to manage - marker trait implementation only
2925}
2926
2927/// Unlimited task scheduler that executes all tasks immediately
2928///
2929/// This scheduler provides no concurrency limits and executes all submitted tasks
2930/// immediately. Useful for testing, high-throughput scenarios, or when external
2931/// systems provide the concurrency control.
2932///
2933/// ## Cancellation Behavior
2934///
2935/// - Respects cancellation tokens before resource acquisition
2936/// - Once execution begins (via ResourceGuard), always awaits task completion
2937/// - Tasks handle their own cancellation internally (if created with `spawn_cancellable`)
2938///
2939/// # Example
2940/// ```rust
2941/// # use dynamo_runtime::utils::tasks::tracker::UnlimitedScheduler;
2942/// let scheduler = UnlimitedScheduler::new();
2943/// ```
2944#[derive(Debug)]
2945pub struct UnlimitedScheduler;
2946
2947impl UnlimitedScheduler {
2948 /// Create a new unlimited scheduler returning Arc
2949 pub fn new() -> Arc<Self> {
2950 Arc::new(Self)
2951 }
2952}
2953
2954impl Default for UnlimitedScheduler {
2955 fn default() -> Self {
2956 UnlimitedScheduler
2957 }
2958}
2959
2960#[async_trait]
2961impl TaskScheduler for UnlimitedScheduler {
2962 async fn acquire_execution_slot(
2963 &self,
2964 cancel_token: CancellationToken,
2965 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
2966 debug!("Acquiring execution slot (unlimited scheduler)");
2967
2968 // Check for cancellation before allocating resources
2969 if cancel_token.is_cancelled() {
2970 debug!("Task cancelled before acquiring execution slot");
2971 return SchedulingResult::Cancelled;
2972 }
2973
2974 // No resource constraints for unlimited scheduler
2975 debug!("Execution slot acquired immediately");
2976 SchedulingResult::Execute(Box::new(UnlimitedGuard))
2977 }
2978}
2979
2980/// Resource guard for semaphore-based scheduling
2981///
2982/// This guard holds a semaphore permit and enforces that task execution
2983/// always runs to completion. The permit is automatically released when
2984/// the guard is dropped.
2985#[derive(Debug)]
2986pub struct SemaphoreGuard {
2987 _permit: tokio::sync::OwnedSemaphorePermit,
2988}
2989
2990impl ResourceGuard for SemaphoreGuard {
2991 // Permit is automatically released when the guard is dropped
2992}
2993
2994/// Semaphore-based task scheduler
2995///
2996/// Limits concurrent task execution using a [`tokio::sync::Semaphore`].
2997/// Tasks will wait for an available permit before executing.
2998///
2999/// ## Cancellation Behavior
3000///
3001/// - Respects cancellation tokens before and during permit acquisition
3002/// - Once a permit is acquired (via ResourceGuard), always awaits task completion
3003/// - Holds the permit until the task completes (regardless of cancellation)
3004/// - Tasks handle their own cancellation internally (if created with `spawn_cancellable`)
3005///
3006/// This ensures that permits are not leaked when tasks are cancelled, while still
3007/// allowing cancellable tasks to terminate gracefully on their own.
3008///
3009/// # Example
3010/// ```rust
3011/// # use std::sync::Arc;
3012/// # use tokio::sync::Semaphore;
3013/// # use dynamo_runtime::utils::tasks::tracker::SemaphoreScheduler;
3014/// // Allow up to 5 concurrent tasks
3015/// let semaphore = Arc::new(Semaphore::new(5));
3016/// let scheduler = SemaphoreScheduler::new(semaphore);
3017/// ```
3018#[derive(Debug)]
3019pub struct SemaphoreScheduler {
3020 semaphore: Arc<Semaphore>,
3021}
3022
3023impl SemaphoreScheduler {
3024 /// Create a new semaphore scheduler
3025 ///
3026 /// # Arguments
3027 /// * `semaphore` - Semaphore to use for concurrency control
3028 pub fn new(semaphore: Arc<Semaphore>) -> Self {
3029 Self { semaphore }
3030 }
3031
3032 /// Create a semaphore scheduler with the specified number of permits, returning Arc
3033 pub fn with_permits(permits: usize) -> Arc<Self> {
3034 Arc::new(Self::new(Arc::new(Semaphore::new(permits))))
3035 }
3036
3037 /// Get the number of available permits
3038 pub fn available_permits(&self) -> usize {
3039 self.semaphore.available_permits()
3040 }
3041}
3042
3043#[async_trait]
3044impl TaskScheduler for SemaphoreScheduler {
3045 async fn acquire_execution_slot(
3046 &self,
3047 cancel_token: CancellationToken,
3048 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
3049 debug!("Acquiring semaphore permit");
3050
3051 // Check for cancellation before attempting to acquire semaphore
3052 if cancel_token.is_cancelled() {
3053 debug!("Task cancelled before acquiring semaphore permit");
3054 return SchedulingResult::Cancelled;
3055 }
3056
3057 // Try to acquire a permit, with cancellation support
3058 let permit = {
3059 tokio::select! {
3060 result = self.semaphore.clone().acquire_owned() => {
3061 match result {
3062 Ok(permit) => permit,
3063 Err(_) => return SchedulingResult::Cancelled,
3064 }
3065 }
3066 _ = cancel_token.cancelled() => {
3067 debug!("Task cancelled while waiting for semaphore permit");
3068 return SchedulingResult::Cancelled;
3069 }
3070 }
3071 };
3072
3073 debug!("Acquired semaphore permit");
3074 SchedulingResult::Execute(Box::new(SemaphoreGuard { _permit: permit }))
3075 }
3076}
3077
3078/// Error policy that triggers cancellation based on error patterns
3079///
3080/// This policy analyzes error messages and returns `ErrorResponse::Shutdown` when:
3081/// - No patterns are specified (cancels on any error)
3082/// - Error message matches one of the specified patterns
3083///
3084/// The TaskTracker handles the actual cancellation - this policy just makes the decision.
3085///
3086/// # Example
3087/// ```rust
3088/// # use dynamo_runtime::utils::tasks::tracker::CancelOnError;
3089/// // Cancel on any error
3090/// let policy = CancelOnError::new();
3091///
3092/// // Cancel only on specific error patterns
3093/// let (policy, _token) = CancelOnError::with_patterns(
3094/// vec!["OutOfMemory".to_string(), "DeviceError".to_string()]
3095/// );
3096/// ```
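///
/// A minimal sketch of the policy wired into a tracker (the scheduler choice, permit count,
/// and error message are illustrative only):
///
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::{CancelOnError, SemaphoreScheduler, TaskTracker};
/// # #[tokio::main]
/// # async fn main() {
/// let (policy, _token) = CancelOnError::with_patterns(vec!["OutOfMemory".to_string()]);
/// let tracker = TaskTracker::new(SemaphoreScheduler::with_permits(4), policy).unwrap();
///
/// // The error message matches a pattern, so the policy asks the tracker to shut down;
/// // the task itself is reported back to the caller as failed.
/// let handle = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("OutOfMemory on device 0")) });
/// assert!(handle.await.unwrap().is_err());
/// # }
/// ```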
3097#[derive(Debug)]
3098pub struct CancelOnError {
3099 error_patterns: Vec<String>,
3100}
3101
3102impl CancelOnError {
3103 /// Create a new cancel-on-error policy that cancels on any error
3104 ///
3105 /// Returns a policy with no error patterns, meaning it will cancel the TaskTracker
3106 /// on any task failure.
3107 pub fn new() -> Arc<Self> {
3108 Arc::new(Self {
3109 error_patterns: vec![], // Empty patterns = cancel on any error
3110 })
3111 }
3112
3113 /// Create a new cancel-on-error policy with custom error patterns, returning Arc and token
3114 ///
3115 /// # Arguments
3116 /// * `error_patterns` - List of error message patterns that trigger cancellation
3117 pub fn with_patterns(error_patterns: Vec<String>) -> (Arc<Self>, CancellationToken) {
3118 let token = CancellationToken::new();
3119 let policy = Arc::new(Self { error_patterns });
3120 (policy, token)
3121 }
3122}
3123
3124#[async_trait]
3125impl OnErrorPolicy for CancelOnError {
3126 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3127 // This policy holds no cancellation state of its own; the child simply clones
3128 // the error patterns. Cancellation propagation is handled by the TaskTracker hierarchy.
3129 Arc::new(CancelOnError {
3130 error_patterns: self.error_patterns.clone(),
3131 })
3132 }
3133
3134 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3135 None // Stateless policy - no heap allocation
3136 }
3137
3138 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3139 error!(?context.task_id, "Task failed - {error:?}");
3140
3141 if self.error_patterns.is_empty() {
3142 return ErrorResponse::Shutdown;
3143 }
3144
3145 // Check if this error should trigger cancellation
3146 let error_str = error.to_string();
3147 let should_cancel = self
3148 .error_patterns
3149 .iter()
3150 .any(|pattern| error_str.contains(pattern));
3151
3152 if should_cancel {
3153 ErrorResponse::Shutdown
3154 } else {
3155 ErrorResponse::Fail
3156 }
3157 }
3158}
3159
3160/// Simple error policy that only logs errors
3161///
3162/// This policy does not trigger cancellation and is useful for
3163/// non-critical tasks or when you want to handle errors externally.
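///
/// # Example
/// A minimal usage sketch, mirroring this module's tests (the scheduler choice and the
/// task bodies are illustrative only):
///
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::{LogOnlyPolicy, SemaphoreScheduler, TaskTracker};
/// # #[tokio::main]
/// # async fn main() {
/// let tracker = TaskTracker::new(SemaphoreScheduler::with_permits(4), LogOnlyPolicy::new()).unwrap();
///
/// // The failure is logged and surfaced to the caller, but the tracker keeps accepting work.
/// let failed = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("non-critical failure")) });
/// assert!(failed.await.unwrap().is_err());
///
/// let ok = tracker.spawn(async { Ok(7) });
/// assert_eq!(ok.await.unwrap().unwrap(), 7);
/// # }
/// ```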
3164#[derive(Debug)]
3165pub struct LogOnlyPolicy;
3166
3167impl LogOnlyPolicy {
3168 /// Create a new log-only policy returning Arc
3169 pub fn new() -> Arc<Self> {
3170 Arc::new(Self)
3171 }
3172}
3173
3174impl Default for LogOnlyPolicy {
3175 fn default() -> Self {
3176 LogOnlyPolicy
3177 }
3178}
3179
3180impl OnErrorPolicy for LogOnlyPolicy {
3181 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3182 // Simple policies can just clone themselves
3183 Arc::new(LogOnlyPolicy)
3184 }
3185
3186 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3187 None // Stateless policy - no heap allocation
3188 }
3189
3190 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3191 error!(?context.task_id, "Task failed - logging only - {error:?}");
3192 ErrorResponse::Fail
3193 }
3194}
3195
3196/// Error policy that cancels tasks after a threshold number of failures
3197///
3198/// This policy counts failures per task (via per-task context) and requests tracker shutdown
3199/// once a single task's failure count reaches the configured threshold. A global failure
3200/// counter is also kept and exposed via `failure_count()`. Useful for preventing cascading failures in distributed systems.
3201///
3202/// # Example
3203/// ```rust
3204/// # use dynamo_runtime::utils::tasks::tracker::ThresholdCancelPolicy;
3205/// // Shut down after a single task has failed 5 times
3206/// let policy = ThresholdCancelPolicy::with_threshold(5);
3207/// ```
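///
/// A sketch of the policy attached to a tracker (the threshold and error message are
/// illustrative; a single failure stays below the threshold, so nothing is cancelled):
///
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::{SemaphoreScheduler, TaskTracker, ThresholdCancelPolicy};
/// # #[tokio::main]
/// # async fn main() {
/// let policy = ThresholdCancelPolicy::with_threshold(3);
/// let tracker = TaskTracker::new(SemaphoreScheduler::with_permits(4), policy).unwrap();
///
/// let handle = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("transient failure")) });
/// assert!(handle.await.unwrap().is_err());
///
/// // One failure is below the per-task threshold of 3, so the tracker keeps running.
/// assert!(!tracker.cancellation_token().is_cancelled());
/// # }
/// ```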
3208#[derive(Debug)]
3209pub struct ThresholdCancelPolicy {
3210 max_failures: usize,
3211 failure_count: AtomicU64,
3212}
3213
3214impl ThresholdCancelPolicy {
3215 /// Create a new threshold cancel policy with the specified failure threshold, returning Arc
3216 ///
3217 /// # Arguments
3218 /// * `max_failures` - Maximum number of failures before cancellation
3219 pub fn with_threshold(max_failures: usize) -> Arc<Self> {
3220 Arc::new(Self {
3221 max_failures,
3222 failure_count: AtomicU64::new(0),
3223 })
3224 }
3225
3226 /// Get the current failure count
3227 pub fn failure_count(&self) -> u64 {
3228 self.failure_count.load(Ordering::Relaxed)
3229 }
3230
3231 /// Reset the failure count to zero
3232 ///
3233 /// This is primarily useful for testing scenarios where you want to reset
3234 /// the policy state between test cases.
3235 pub fn reset_failure_count(&self) {
3236 self.failure_count.store(0, Ordering::Relaxed);
3237 }
3238}
3239
3240/// Per-task state for ThresholdCancelPolicy
3241#[derive(Debug)]
3242struct ThresholdState {
3243 failure_count: u32,
3244}
3245
3246impl OnErrorPolicy for ThresholdCancelPolicy {
3247 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3248 // Child inherits the same failure threshold but tracks failures independently
3249 Arc::new(ThresholdCancelPolicy {
3250 max_failures: self.max_failures,
3251 failure_count: AtomicU64::new(0), // Child starts with fresh count
3252 })
3253 }
3254
3255 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3256 Some(Box::new(ThresholdState { failure_count: 0 }))
3257 }
3258
3259 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3260 error!(?context.task_id, "Task failed - {error:?}");
3261
3262 // Increment global counter for backwards compatibility
3263 let global_failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
3264
3265 // Get per-task state for the actual decision logic
3266 let state = context
3267 .state
3268 .as_mut()
3269 .expect("ThresholdCancelPolicy requires state")
3270 .downcast_mut::<ThresholdState>()
3271 .expect("Context type mismatch");
3272
3273 state.failure_count += 1;
3274 let current_failures = state.failure_count;
3275
3276 if current_failures >= self.max_failures as u32 {
3277 warn!(
3278 ?context.task_id,
3279 current_failures,
3280 global_failures,
3281 max_failures = self.max_failures,
3282 "Per-task failure threshold exceeded, triggering cancellation"
3283 );
3284 ErrorResponse::Shutdown
3285 } else {
3286 debug!(
3287 ?context.task_id,
3288 current_failures,
3289 global_failures,
3290 max_failures = self.max_failures,
3291 "Task failed, tracking per-task failure count"
3292 );
3293 ErrorResponse::Fail
3294 }
3295 }
3296}
3297
3298/// Error policy that cancels tasks when failure rate exceeds threshold within time window
3299///
3300/// This policy tracks failures over a rolling time window and triggers cancellation
3301/// when the failure rate exceeds the specified threshold. More sophisticated than
3302/// simple count-based thresholds as it considers the time dimension.
3303///
3304/// # Example
3305/// ```rust
3306/// # use dynamo_runtime::utils::tasks::tracker::RateCancelPolicy;
3307/// // Cancel if more than 50% of tasks fail within any 60-second window
3308/// let (policy, token) = RateCancelPolicy::builder()
3309/// .rate(0.5)
3310/// .window_secs(60)
3311/// .build();
3312/// ```
3313#[derive(Debug)]
3314pub struct RateCancelPolicy {
3315 cancel_token: CancellationToken,
3316 max_failure_rate: f32,
3317 window_secs: u64,
3318 // TODO: Implement time-window tracking when needed
3319 // For now, this is a placeholder structure with the interface defined
3320}
3321
3322impl RateCancelPolicy {
3323 /// Create a builder for rate-based cancel policy
3324 pub fn builder() -> RateCancelPolicyBuilder {
3325 RateCancelPolicyBuilder::new()
3326 }
3327}
3328
3329/// Builder for RateCancelPolicy
3330pub struct RateCancelPolicyBuilder {
3331 max_failure_rate: Option<f32>,
3332 window_secs: Option<u64>,
3333}
3334
3335impl RateCancelPolicyBuilder {
3336 fn new() -> Self {
3337 Self {
3338 max_failure_rate: None,
3339 window_secs: None,
3340 }
3341 }
3342
3343 /// Set the maximum failure rate (0.0 to 1.0) before cancellation
3344 pub fn rate(mut self, max_failure_rate: f32) -> Self {
3345 self.max_failure_rate = Some(max_failure_rate);
3346 self
3347 }
3348
3349 /// Set the time window in seconds for rate calculation
3350 pub fn window_secs(mut self, window_secs: u64) -> Self {
3351 self.window_secs = Some(window_secs);
3352 self
3353 }
3354
3355 /// Build the policy, returning Arc and cancellation token
3356 pub fn build(self) -> (Arc<RateCancelPolicy>, CancellationToken) {
3357 let max_failure_rate = self.max_failure_rate.expect("rate must be set");
3358 let window_secs = self.window_secs.expect("window_secs must be set");
3359
3360 let token = CancellationToken::new();
3361 let policy = Arc::new(RateCancelPolicy {
3362 cancel_token: token.clone(),
3363 max_failure_rate,
3364 window_secs,
3365 });
3366 (policy, token)
3367 }
3368}
3369
3370#[async_trait]
3371impl OnErrorPolicy for RateCancelPolicy {
3372 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3373 Arc::new(RateCancelPolicy {
3374 cancel_token: self.cancel_token.child_token(),
3375 max_failure_rate: self.max_failure_rate,
3376 window_secs: self.window_secs,
3377 })
3378 }
3379
3380 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3381 None // Stateless policy for now (TODO: add time-window state)
3382 }
3383
3384 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3385 error!(?context.task_id, "Task failed - {error:?}");
3386
3387 // TODO: Implement time-window failure rate calculation
3388 // For now, just log the error and continue
3389 warn!(
3390 ?context.task_id,
3391 max_failure_rate = self.max_failure_rate,
3392 window_secs = self.window_secs,
3393 "Rate-based error policy - time window tracking not yet implemented"
3394 );
3395
3396 ErrorResponse::Fail
3397 }
3398}
3399
3400/// Custom action that triggers a cancellation token when executed
3401///
3402/// This action demonstrates the ErrorResponse::Custom behavior by capturing
3403/// an external cancellation token and triggering it when executed.
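///
/// # Example
/// A minimal sketch of constructing the action and handing it back from an error policy
/// via `ErrorResponse::Custom` (the token here is created purely for illustration):
///
/// ```rust
/// # use tokio_util::sync::CancellationToken;
/// # use dynamo_runtime::utils::tasks::tracker::{ErrorResponse, TriggerCancellationTokenAction};
/// let token = CancellationToken::new();
/// let action = TriggerCancellationTokenAction::new(token.clone());
///
/// // An `OnErrorPolicy::on_error` implementation could return this to run the action:
/// let _response = ErrorResponse::Custom(Box::new(action));
/// ```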
3404#[derive(Debug)]
3405pub struct TriggerCancellationTokenAction {
3406 cancel_token: CancellationToken,
3407}
3408
3409impl TriggerCancellationTokenAction {
/// Create a new action that cancels the given token when executed
3410 pub fn new(cancel_token: CancellationToken) -> Self {
3411 Self { cancel_token }
3412 }
3413}
3414
3415#[async_trait]
3416impl OnErrorAction for TriggerCancellationTokenAction {
3417 async fn execute(
3418 &self,
3419 error: &anyhow::Error,
3420 task_id: TaskId,
3421 _attempt_count: u32,
3422 _context: &TaskExecutionContext,
3423 ) -> ActionResult {
3424 warn!(
3425 ?task_id,
3426 "Executing custom action: triggering cancellation token - {error:?}"
3427 );
3428
3429 // Trigger the custom cancellation token
3430 self.cancel_token.cancel();
3431
3432 // The cancellation token has been triggered; tell the tracker to shut down
3433 ActionResult::Shutdown
3434 }
3435}
3436
3437/// Test error policy that triggers a custom cancellation token on any error
3438///
3439/// This policy demonstrates the ErrorResponse::Custom behavior by capturing
3440/// an external cancellation token and triggering it when any error occurs.
3441/// Used for testing custom error handling actions.
3442///
3443/// # Example
3444/// ```rust
3445/// # use tokio_util::sync::CancellationToken;
3446/// # use dynamo_runtime::utils::tasks::tracker::TriggerCancellationTokenOnError;
3447/// let cancel_token = CancellationToken::new();
3448/// let policy = TriggerCancellationTokenOnError::new(cancel_token.clone());
3449///
3450/// // Policy will trigger the token on any error via ErrorResponse::Custom
3451/// ```
3452#[derive(Debug)]
3453pub struct TriggerCancellationTokenOnError {
3454 cancel_token: CancellationToken,
3455}
3456
3457impl TriggerCancellationTokenOnError {
3458 /// Create a new policy that triggers the given cancellation token on errors
3459 pub fn new(cancel_token: CancellationToken) -> Arc<Self> {
3460 Arc::new(Self { cancel_token })
3461 }
3462}
3463
3464impl OnErrorPolicy for TriggerCancellationTokenOnError {
3465 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3466 // Child shares the same external token, so child task errors trigger it as well
3467 Arc::new(TriggerCancellationTokenOnError {
3468 cancel_token: self.cancel_token.clone(),
3469 })
3470 }
3471
3472 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3473 None // Stateless policy - no heap allocation
3474 }
3475
3476 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3477 error!(
3478 ?context.task_id,
3479 "Task failed - triggering custom cancellation token - {error:?}"
3480 );
3481
3482 // Create the custom action that will trigger our token
3483 let action = TriggerCancellationTokenAction::new(self.cancel_token.clone());
3484
3485 // Return Custom response with our action
3486 ErrorResponse::Custom(Box::new(action))
3487 }
3488}
3489
3490#[cfg(test)]
3491mod tests {
3492 use super::*;
3493 use rstest::*;
3494 use std::sync::atomic::AtomicU32;
3495 use std::time::Duration;
3496
3497 // Test fixtures using rstest
3498 #[fixture]
3499 fn semaphore_scheduler() -> Arc<SemaphoreScheduler> {
3500 Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))))
3501 }
3502
3503 #[fixture]
3504 fn unlimited_scheduler() -> Arc<UnlimitedScheduler> {
3505 UnlimitedScheduler::new()
3506 }
3507
3508 #[fixture]
3509 fn log_policy() -> Arc<LogOnlyPolicy> {
3510 LogOnlyPolicy::new()
3511 }
3512
3513 #[fixture]
3514 fn cancel_policy() -> Arc<CancelOnError> {
3515 CancelOnError::new()
3516 }
3517
3518 #[fixture]
3519 fn basic_tracker(
3520 unlimited_scheduler: Arc<UnlimitedScheduler>,
3521 log_policy: Arc<LogOnlyPolicy>,
3522 ) -> TaskTracker {
3523 TaskTracker::new(unlimited_scheduler, log_policy).unwrap()
3524 }
3525
3526 #[rstest]
3527 #[tokio::test]
3528 async fn test_basic_task_execution(basic_tracker: TaskTracker) {
3529 // Test successful task execution
3530 let (tx, rx) = tokio::sync::oneshot::channel();
3531 let handle = basic_tracker.spawn(async {
3532 // Wait for signal to complete instead of sleep
3533 rx.await.ok();
3534 Ok(42)
3535 });
3536
3537 // Signal task to complete
3538 tx.send(()).ok();
3539
3540 // Verify task completes successfully
3541 let result = handle
3542 .await
3543 .expect("Task should complete")
3544 .expect("Task should succeed");
3545 assert_eq!(result, 42);
3546
3547 // Verify metrics
3548 assert_eq!(basic_tracker.metrics().success(), 1);
3549 assert_eq!(basic_tracker.metrics().failed(), 0);
3550 assert_eq!(basic_tracker.metrics().cancelled(), 0);
3551 assert_eq!(basic_tracker.metrics().active(), 0);
3552 }
3553
3554 #[rstest]
3555 #[tokio::test]
3556 async fn test_task_failure(
3557 semaphore_scheduler: Arc<SemaphoreScheduler>,
3558 log_policy: Arc<LogOnlyPolicy>,
3559 ) {
3560 // Test task failure handling
3561 let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3562
3563 let handle = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
3564
3565 let result = handle.await.unwrap();
3566 assert!(result.is_err());
3567 assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
3568
3569 // Verify metrics
3570 assert_eq!(tracker.metrics().success(), 0);
3571 assert_eq!(tracker.metrics().failed(), 1);
3572 assert_eq!(tracker.metrics().cancelled(), 0);
3573 }
3574
3575 #[rstest]
3576 #[tokio::test]
3577 async fn test_semaphore_concurrency_limit(log_policy: Arc<LogOnlyPolicy>) {
3578 // Test that semaphore limits concurrent execution
3579 let limited_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2)))); // Only 2 concurrent tasks
3580 let tracker = TaskTracker::new(limited_scheduler, log_policy).unwrap();
3581
3582 let counter = Arc::new(AtomicU32::new(0));
3583 let max_concurrent = Arc::new(AtomicU32::new(0));
3584
3585 // Use broadcast channel to coordinate all tasks
3586 let (tx, _) = tokio::sync::broadcast::channel(1);
3587 let mut handles = Vec::new();
3588
3589 // Spawn 5 tasks that will track concurrency
3590 for _ in 0..5 {
3591 let counter_clone = counter.clone();
3592 let max_clone = max_concurrent.clone();
3593 let mut rx = tx.subscribe();
3594
3595 let handle = tracker.spawn(async move {
3596 // Increment active counter
3597 let current = counter_clone.fetch_add(1, Ordering::Relaxed) + 1;
3598
3599 // Track max concurrent
3600 max_clone.fetch_max(current, Ordering::Relaxed);
3601
3602 // Wait for signal to complete instead of sleep
3603 rx.recv().await.ok();
3604
3605 // Decrement when done
3606 counter_clone.fetch_sub(1, Ordering::Relaxed);
3607
3608 Ok(())
3609 });
3610 handles.push(handle);
3611 }
3612
3613 // Give tasks time to start and register concurrency
3614 tokio::task::yield_now().await;
3615 tokio::task::yield_now().await;
3616
3617 // Signal all tasks to complete
3618 tx.send(()).ok();
3619
3620 // Wait for all tasks to complete
3621 for handle in handles {
3622 handle.await.unwrap().unwrap();
3623 }
3624
3625 // Verify that no more than 2 tasks ran concurrently
3626 assert!(max_concurrent.load(Ordering::Relaxed) <= 2);
3627
3628 // Verify all tasks completed successfully
3629 assert_eq!(tracker.metrics().success(), 5);
3630 assert_eq!(tracker.metrics().failed(), 0);
3631 }
3632
3633 #[rstest]
3634 #[tokio::test]
3635 async fn test_cancel_on_error_policy() {
3636 // Test that CancelOnError policy works correctly
3637 let error_policy = cancel_policy();
3638 let scheduler = semaphore_scheduler();
3639 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
3640
3641 // Spawn a task that will trigger cancellation
3642 let handle =
3643 tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("OutOfMemory error occurred")) });
3644
3645 // Wait for the error to occur
3646 let result = handle.await.unwrap();
3647 assert!(result.is_err());
3648
3649 // Give cancellation time to propagate
3650 tokio::time::sleep(Duration::from_millis(10)).await;
3651
3652 // Verify the cancel token was triggered
3653 assert!(tracker.cancellation_token().is_cancelled());
3654 }
3655
3656 #[rstest]
3657 #[tokio::test]
3658 async fn test_tracker_cancellation() {
3659 // Test manual cancellation of tracker with CancelOnError policy
3660 let error_policy = cancel_policy();
3661 let scheduler = semaphore_scheduler();
3662 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
3663 let cancel_token = tracker.cancellation_token().child_token();
3664
3665 // Use oneshot channel instead of sleep for deterministic timing
3666 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
3667
3668 // Spawn a task that respects cancellation
3669 let handle = tracker.spawn({
3670 let cancel_token = cancel_token.clone();
3671 async move {
3672 tokio::select! {
3673 _ = rx => Ok(()),
3674 _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Task was cancelled")),
3675 }
3676 }
3677 });
3678
3679 // Cancel the tracker
3680 tracker.cancel();
3681
3682 // Task should be cancelled
3683 let result = handle.await.unwrap();
3684 assert!(result.is_err());
3685 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
3686 }
3687
3688 #[rstest]
3689 #[tokio::test]
3690 async fn test_child_tracker_independence(
3691 semaphore_scheduler: Arc<SemaphoreScheduler>,
3692 log_policy: Arc<LogOnlyPolicy>,
3693 ) {
3694 // Test that child tracker has independent lifecycle
3695 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3696
3697 let child = parent.child_tracker().unwrap();
3698
3699 // Both should be operational initially
3700 assert!(!parent.is_closed());
3701 assert!(!child.is_closed());
3702
3703 // Cancel child only
3704 child.cancel();
3705
3706 // Parent should remain operational
3707 assert!(!parent.is_closed());
3708
3709 // Parent can still spawn tasks
3710 let handle = parent.spawn(async { Ok(42) });
3711 let result = handle.await.unwrap().unwrap();
3712 assert_eq!(result, 42);
3713 }
3714
3715 #[rstest]
3716 #[tokio::test]
3717 async fn test_independent_metrics(
3718 semaphore_scheduler: Arc<SemaphoreScheduler>,
3719 log_policy: Arc<LogOnlyPolicy>,
3720 ) {
3721 // Test that parent and child have independent metrics
3722 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3723 let child = parent.child_tracker().unwrap();
3724
3725 // Run tasks in parent
3726 let handle1 = parent.spawn(async { Ok(1) });
3727 handle1.await.unwrap().unwrap();
3728
3729 // Run tasks in child
3730 let handle2 = child.spawn(async { Ok(2) });
3731 handle2.await.unwrap().unwrap();
3732
3733 // Each should have their own metrics, but parent sees aggregated
3734 assert_eq!(parent.metrics().success(), 2); // Parent sees its own + child's
3735 assert_eq!(child.metrics().success(), 1); // Child sees only its own
3736 assert_eq!(parent.metrics().total_completed(), 2); // Parent sees aggregated total
3737 assert_eq!(child.metrics().total_completed(), 1); // Child sees only its own
3738 }
3739
3740 #[rstest]
3741 #[tokio::test]
3742 async fn test_cancel_on_error_hierarchy() {
3743 // Test that child error policy cancellation doesn't affect parent
3744 let parent_error_policy = cancel_policy();
3745 let scheduler = semaphore_scheduler();
3746 let parent = TaskTracker::new(scheduler, parent_error_policy).unwrap();
3747 let parent_policy_token = parent.cancellation_token().child_token();
3748 let child = parent.child_tracker().unwrap();
3749
3750 // Initially nothing should be cancelled
3751 assert!(!parent_policy_token.is_cancelled());
3752
3753 // Use explicit synchronization instead of sleep
3754 let (error_tx, error_rx) = tokio::sync::oneshot::channel();
3755 let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
3756
3757 // Spawn a monitoring task to watch for the parent policy token cancellation
3758 let parent_token_monitor = parent_policy_token.clone();
3759 let monitor_handle = tokio::spawn(async move {
3760 tokio::select! {
3761 _ = parent_token_monitor.cancelled() => {
3762 cancel_tx.send(true).ok();
3763 }
3764 _ = tokio::time::sleep(Duration::from_millis(100)) => {
3765 cancel_tx.send(false).ok();
3766 }
3767 }
3768 });
3769
3770 // Spawn a task in the child that will trigger cancellation
3771 let handle = child.spawn(async move {
3772 let result = Err::<(), _>(anyhow::anyhow!("OutOfMemory in child"));
3773 error_tx.send(()).ok(); // Signal that the error has occurred
3774 result
3775 });
3776
3777 // Wait for the error to occur
3778 let error_result = handle.await.unwrap();
3779 assert!(error_result.is_err());
3780
3781 // Wait for our error signal
3782 error_rx.await.ok();
3783
3784 // Check if parent policy token was cancelled within timeout
3785 let was_cancelled = cancel_rx.await.unwrap_or(false);
3786 monitor_handle.await.ok();
3787
3788 // Based on hierarchical design: child errors should NOT affect parent
3789 // The child gets its own policy with a child token, and child cancellation
3790 // should not propagate up to the parent policy token
3791 assert!(
3792 !was_cancelled,
3793 "Parent policy token should not be cancelled by child errors"
3794 );
3795 assert!(
3796 !parent_policy_token.is_cancelled(),
3797 "Parent policy token should remain active"
3798 );
3799 }
3800
3801 #[rstest]
3802 #[tokio::test]
3803 async fn test_graceful_shutdown(
3804 semaphore_scheduler: Arc<SemaphoreScheduler>,
3805 log_policy: Arc<LogOnlyPolicy>,
3806 ) {
3807 // Test graceful shutdown with close()
3808 let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3809
3810 // Use broadcast channel to coordinate task completion
3811 let (tx, _) = tokio::sync::broadcast::channel(1);
3812 let mut handles = Vec::new();
3813
3814 // Spawn some tasks
3815 for i in 0..3 {
3816 let mut rx = tx.subscribe();
3817 let handle = tracker.spawn(async move {
3818 // Wait for signal instead of sleep
3819 rx.recv().await.ok();
3820 Ok(i)
3821 });
3822 handles.push(handle);
3823 }
3824
3825 // Signal all tasks to complete before closing
3826 tx.send(()).ok();
3827
3828 // Close tracker and wait for completion
3829 tracker.join().await;
3830
3831 // All tasks should complete successfully
3832 for handle in handles {
3833 let result = handle.await.unwrap().unwrap();
3834 assert!(result < 3);
3835 }
3836
3837 // Tracker should be closed
3838 assert!(tracker.is_closed());
3839 }
3840
3841 #[rstest]
3842 #[tokio::test]
3843 async fn test_semaphore_scheduler_permit_tracking(log_policy: Arc<LogOnlyPolicy>) {
3844 // Test that SemaphoreScheduler properly tracks permits
3845 let semaphore = Arc::new(Semaphore::new(3));
3846 let scheduler = Arc::new(SemaphoreScheduler::new(semaphore.clone()));
3847 let tracker = TaskTracker::new(scheduler.clone(), log_policy).unwrap();
3848
3849 // Initially all permits should be available
3850 assert_eq!(scheduler.available_permits(), 3);
3851
3852 // Use broadcast channel to coordinate task completion
3853 let (tx, _) = tokio::sync::broadcast::channel(1);
3854 let mut handles = Vec::new();
3855
3856 // Spawn 3 tasks that will hold permits
3857 for _ in 0..3 {
3858 let mut rx = tx.subscribe();
3859 let handle = tracker.spawn(async move {
3860 // Wait for signal to complete
3861 rx.recv().await.ok();
3862 Ok(())
3863 });
3864 handles.push(handle);
3865 }
3866
3867 // Give tasks time to acquire permits
3868 tokio::task::yield_now().await;
3869 tokio::task::yield_now().await;
3870
3871 // All permits should be taken
3872 assert_eq!(scheduler.available_permits(), 0);
3873
3874 // Signal all tasks to complete
3875 tx.send(()).ok();
3876
3877 // Wait for tasks to complete
3878 for handle in handles {
3879 handle.await.unwrap().unwrap();
3880 }
3881
3882 // All permits should be available again
3883 assert_eq!(scheduler.available_permits(), 3);
3884 }
3885
3886 #[rstest]
3887 #[tokio::test]
3888 async fn test_builder_pattern(log_policy: Arc<LogOnlyPolicy>) {
3889 // Test that TaskTracker builder works correctly
3890 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3891 let error_policy = log_policy;
3892
3893 let tracker = TaskTracker::builder()
3894 .scheduler(scheduler)
3895 .error_policy(error_policy)
3896 .build()
3897 .unwrap();
3898
3899 // Tracker should have a cancellation token
3900 let token = tracker.cancellation_token();
3901 assert!(!token.is_cancelled());
3902
3903 // Should be able to spawn tasks
3904 let handle = tracker.spawn(async { Ok(42) });
3905 let result = handle.await.unwrap().unwrap();
3906 assert_eq!(result, 42);
3907 }
3908
3909 #[rstest]
3910 #[tokio::test]
3911 async fn test_all_trackers_have_cancellation_tokens(log_policy: Arc<LogOnlyPolicy>) {
3912 // Test that all trackers (root and children) have cancellation tokens
3913 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3914 let root = TaskTracker::new(scheduler, log_policy).unwrap();
3915 let child = root.child_tracker().unwrap();
3916 let grandchild = child.child_tracker().unwrap();
3917
3918 // All should have cancellation tokens
3919 let root_token = root.cancellation_token();
3920 let child_token = child.cancellation_token();
3921 let grandchild_token = grandchild.cancellation_token();
3922
3923 assert!(!root_token.is_cancelled());
3924 assert!(!child_token.is_cancelled());
3925 assert!(!grandchild_token.is_cancelled());
3926
3927 // Child tokens should be different from parent
3928 // (We can't directly compare tokens, but we can test behavior)
3929 root_token.cancel();
3930
3931 // Give cancellation time to propagate
3932 tokio::time::sleep(Duration::from_millis(10)).await;
3933
3934 // Root should be cancelled
3935 assert!(root_token.is_cancelled());
3936 // Children should also be cancelled (because they are child tokens)
3937 assert!(child_token.is_cancelled());
3938 assert!(grandchild_token.is_cancelled());
3939 }
3940
3941 #[rstest]
3942 #[tokio::test]
3943 async fn test_spawn_cancellable_task(log_policy: Arc<LogOnlyPolicy>) {
3944 // Test cancellable task spawning with proper result handling
3945 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3946 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
3947
3948 // Test successful completion
3949 let (tx, rx) = tokio::sync::oneshot::channel();
3950 let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
3951 let handle = tracker.spawn_cancellable(move |_cancel_token| {
3952 let rx = rx.clone();
3953 async move {
3954 // Wait for signal instead of sleep
3955 if let Some(rx) = rx.lock().await.take() {
3956 rx.await.ok();
3957 }
3958 CancellableTaskResult::Ok(42)
3959 }
3960 });
3961
3962 // Signal task to complete
3963 tx.send(()).ok();
3964
3965 let result = handle.await.unwrap().unwrap();
3966 assert_eq!(result, 42);
3967 assert_eq!(tracker.metrics().success(), 1);
3968
3969 // Test cancellation handling
3970 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
3971 let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
3972 let handle = tracker.spawn_cancellable(move |cancel_token| {
3973 let rx = rx.clone();
3974 async move {
3975 tokio::select! {
3976 _ = async {
3977 if let Some(rx) = rx.lock().await.take() {
3978 rx.await.ok();
3979 }
3980 } => CancellableTaskResult::Ok("should not complete"),
3981 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
3982 }
3983 }
3984 });
3985
3986 // Cancel the tracker
3987 tracker.cancel();
3988
3989 let result = handle.await.unwrap();
3990 assert!(result.is_err());
3991 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
3992 }
3993
3994 #[rstest]
3995 #[tokio::test]
3996 async fn test_cancellable_task_metrics_tracking(log_policy: Arc<LogOnlyPolicy>) {
3997 // Test that properly cancelled tasks increment cancelled metrics, not failed metrics
3998 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3999 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4000
4001 // Baseline metrics
4002 assert_eq!(tracker.metrics().cancelled(), 0);
4003 assert_eq!(tracker.metrics().failed(), 0);
4004 assert_eq!(tracker.metrics().success(), 0);
4005
4006 // Test 1: Task that executes and THEN gets cancelled during execution
4007 let (start_tx, start_rx) = tokio::sync::oneshot::channel::<()>();
4008 let (_continue_tx, continue_rx) = tokio::sync::oneshot::channel::<()>();
4009
4010 let start_tx_shared = Arc::new(tokio::sync::Mutex::new(Some(start_tx)));
4011 let continue_rx_shared = Arc::new(tokio::sync::Mutex::new(Some(continue_rx)));
4012
4013 let start_tx_for_task = start_tx_shared.clone();
4014 let continue_rx_for_task = continue_rx_shared.clone();
4015
4016 let handle = tracker.spawn_cancellable(move |cancel_token| {
4017 let start_tx = start_tx_for_task.clone();
4018 let continue_rx = continue_rx_for_task.clone();
4019 async move {
4020 // Signal that we've started executing
4021 if let Some(tx) = start_tx.lock().await.take() {
4022 tx.send(()).ok();
4023 }
4024
4025 // Wait for either continuation signal or cancellation
4026 tokio::select! {
4027 _ = async {
4028 if let Some(rx) = continue_rx.lock().await.take() {
4029 rx.await.ok();
4030 }
4031 } => CancellableTaskResult::Ok("completed normally"),
4032 _ = cancel_token.cancelled() => {
4033 println!("Task detected cancellation and is returning Cancelled");
4034 CancellableTaskResult::Cancelled
4035 },
4036 }
4037 }
4038 });
4039
4040 // Wait for task to start executing
4041 start_rx.await.ok();
4042
4043 // Now cancel while the task is running
4044 println!("Cancelling tracker while task is executing...");
4045 tracker.cancel();
4046
4047 // Wait for the task to complete
4048 let result = handle.await.unwrap();
4049
4050 // Debug output
4051 println!("Task result: {:?}", result);
4052 println!(
4053 "Cancelled: {}, Failed: {}, Success: {}",
4054 tracker.metrics().cancelled(),
4055 tracker.metrics().failed(),
4056 tracker.metrics().success()
4057 );
4058
4059 // The task should be properly cancelled and counted correctly
4060 assert!(result.is_err());
4061 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
4062
4063 // Verify proper metrics: should be counted as cancelled, not failed
4064 assert_eq!(
4065 tracker.metrics().cancelled(),
4066 1,
4067 "Properly cancelled task should increment cancelled count"
4068 );
4069 assert_eq!(
4070 tracker.metrics().failed(),
4071 0,
4072 "Properly cancelled task should NOT increment failed count"
4073 );
4074 }
4075
4076 #[rstest]
4077 #[tokio::test]
4078 async fn test_cancellable_vs_error_metrics_distinction(log_policy: Arc<LogOnlyPolicy>) {
4079 // Test that we properly distinguish between cancellation and actual errors
4080 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
4081 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4082
4083 // Test 1: Actual error should increment failed count
4084 let handle1 = tracker.spawn_cancellable(|_cancel_token| async move {
4085 CancellableTaskResult::<i32>::Err(anyhow::anyhow!("This is a real error"))
4086 });
4087
4088 let result1 = handle1.await.unwrap();
4089 assert!(result1.is_err());
4090 assert!(matches!(result1.unwrap_err(), TaskError::Failed(_)));
4091 assert_eq!(tracker.metrics().failed(), 1);
4092 assert_eq!(tracker.metrics().cancelled(), 0);
4093
4094 // Test 2: Cancellation should increment cancelled count
4095 let handle2 = tracker.spawn_cancellable(|_cancel_token| async move {
4096 CancellableTaskResult::<i32>::Cancelled
4097 });
4098
4099 let result2 = handle2.await.unwrap();
4100 assert!(result2.is_err());
4101 assert!(matches!(result2.unwrap_err(), TaskError::Cancelled));
4102 assert_eq!(tracker.metrics().failed(), 1); // Still 1 from before
4103 assert_eq!(tracker.metrics().cancelled(), 1); // Now 1 from cancellation
4104 }
4105
4106 #[rstest]
4107 #[tokio::test]
4108 async fn test_spawn_cancellable_error_handling(log_policy: Arc<LogOnlyPolicy>) {
4109 // Test error handling in cancellable tasks
4110 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
4111 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4112
4113 // Test error result
4114 let handle = tracker.spawn_cancellable(|_cancel_token| async move {
4115 CancellableTaskResult::<i32>::Err(anyhow::anyhow!("test error"))
4116 });
4117
4118 let result = handle.await.unwrap();
4119 assert!(result.is_err());
4120 assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
4121 assert_eq!(tracker.metrics().failed(), 1);
4122 }
4123
4124 #[rstest]
4125 #[tokio::test]
4126 async fn test_cancellation_before_execution(log_policy: Arc<LogOnlyPolicy>) {
4127 // Test that spawning on a cancelled tracker panics (new behavior)
4128 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
4129 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4130
4131 // Cancel the tracker first
4132 tracker.cancel();
4133
4134 // Give cancellation time to propagate to the inner tracker
4135 tokio::time::sleep(Duration::from_millis(5)).await;
4136
4137 // Now try to spawn a task - it should panic since tracker is closed
4138 let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
4139 tracker.spawn(async { Ok(42) })
4140 }));
4141
4142 // Should panic with our new API
4143 assert!(
4144 panic_result.is_err(),
4145 "spawn() should panic when tracker is closed"
4146 );
4147
4148 // Verify the panic message contains expected text
4149 if let Err(panic_payload) = panic_result {
4150 if let Some(panic_msg) = panic_payload.downcast_ref::<String>() {
4151 assert!(
4152 panic_msg.contains("TaskTracker must not be closed"),
4153 "Panic message should indicate tracker is closed: {}",
4154 panic_msg
4155 );
4156 } else if let Some(panic_msg) = panic_payload.downcast_ref::<&str>() {
4157 assert!(
4158 panic_msg.contains("TaskTracker must not be closed"),
4159 "Panic message should indicate tracker is closed: {}",
4160 panic_msg
4161 );
4162 }
4163 }
4164 }
4165
4166 #[rstest]
4167 #[tokio::test]
4168 async fn test_semaphore_scheduler_with_cancellation(log_policy: Arc<LogOnlyPolicy>) {
4169 // Test that SemaphoreScheduler respects cancellation tokens
4170 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
4171 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4172
4173 // Start a long-running task to occupy the semaphore
4174 let blocker_token = tracker.cancellation_token();
4175 let _blocker_handle = tracker.spawn(async move {
4176 // Wait for cancellation
4177 blocker_token.cancelled().await;
4178 Ok(())
4179 });
4180
4181 // Give the blocker time to acquire the permit
4182 tokio::task::yield_now().await;
4183
4184 // Use oneshot channel for the second task
4185 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
4186
4187 // Spawn another task that will wait for semaphore
4188 let handle = tracker.spawn(async {
4189 rx.await.ok();
4190 Ok(42)
4191 });
4192
4193 // Cancel the tracker while second task is waiting for permit
4194 tracker.cancel();
4195
4196 // The waiting task should be cancelled
4197 let result = handle.await.unwrap();
4198 assert!(result.is_err());
4199 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
4200 }
4201
4202 #[rstest]
4203 #[tokio::test]
4204 async fn test_child_tracker_cancellation_independence(
4205 semaphore_scheduler: Arc<SemaphoreScheduler>,
4206 log_policy: Arc<LogOnlyPolicy>,
4207 ) {
4208 // Test that child tracker cancellation doesn't affect parent
4209 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4210 let child = parent.child_tracker().unwrap();
4211
4212 // Cancel only the child
4213 child.cancel();
4214
4215 // Parent should still be operational
4216 let parent_token = parent.cancellation_token();
4217 assert!(!parent_token.is_cancelled());
4218
4219 // Parent can still spawn tasks
4220 let handle = parent.spawn(async { Ok(42) });
4221 let result = handle.await.unwrap().unwrap();
4222 assert_eq!(result, 42);
4223
4224 // Child should be cancelled
4225 let child_token = child.cancellation_token();
4226 assert!(child_token.is_cancelled());
4227 }
4228
4229 #[rstest]
4230 #[tokio::test]
4231 async fn test_parent_cancellation_propagates_to_children(
4232 semaphore_scheduler: Arc<SemaphoreScheduler>,
4233 log_policy: Arc<LogOnlyPolicy>,
4234 ) {
4235 // Test that parent cancellation propagates to all children
4236 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4237 let child1 = parent.child_tracker().unwrap();
4238 let child2 = parent.child_tracker().unwrap();
4239 let grandchild = child1.child_tracker().unwrap();
4240
4241 // Cancel the parent
4242 parent.cancel();
4243
4244 // Give cancellation time to propagate
4245 tokio::time::sleep(Duration::from_millis(10)).await;
4246
4247 // All should be cancelled
4248 assert!(parent.cancellation_token().is_cancelled());
4249 assert!(child1.cancellation_token().is_cancelled());
4250 assert!(child2.cancellation_token().is_cancelled());
4251 assert!(grandchild.cancellation_token().is_cancelled());
4252 }
4253
4254 #[rstest]
4255 #[tokio::test]
4256 async fn test_issued_counter_tracking(log_policy: Arc<LogOnlyPolicy>) {
4257 // Test that issued counter is incremented when tasks are spawned
4258 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2))));
4259 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4260
4261 // Initially no tasks issued
4262 assert_eq!(tracker.metrics().issued(), 0);
4263 assert_eq!(tracker.metrics().pending(), 0);
4264
4265 // Spawn some tasks
4266 let handle1 = tracker.spawn(async { Ok(1) });
4267 let handle2 = tracker.spawn(async { Ok(2) });
4268 let handle3 = tracker.spawn_cancellable(|_| async { CancellableTaskResult::Ok(3) });
4269
4270 // Issued counter should be incremented immediately
4271 assert_eq!(tracker.metrics().issued(), 3);
4272 assert_eq!(tracker.metrics().pending(), 3); // None completed yet
4273
4274 // Complete the tasks
4275 assert_eq!(handle1.await.unwrap().unwrap(), 1);
4276 assert_eq!(handle2.await.unwrap().unwrap(), 2);
4277 assert_eq!(handle3.await.unwrap().unwrap(), 3);
4278
4279 // Check final accounting
4280 assert_eq!(tracker.metrics().issued(), 3);
4281 assert_eq!(tracker.metrics().success(), 3);
4282 assert_eq!(tracker.metrics().total_completed(), 3);
4283 assert_eq!(tracker.metrics().pending(), 0); // All completed
4284
4285 // Test hierarchical accounting
4286 let child = tracker.child_tracker().unwrap();
4287 let child_handle = child.spawn(async { Ok(42) });
4288
4289 // Both parent and child should see the issued task
4290 assert_eq!(child.metrics().issued(), 1);
4291 assert_eq!(tracker.metrics().issued(), 4); // Parent sees all
4292
4293 child_handle.await.unwrap().unwrap();
4294
4295 // Final hierarchical check
4296 assert_eq!(child.metrics().pending(), 0);
4297 assert_eq!(tracker.metrics().pending(), 0);
4298 assert_eq!(tracker.metrics().success(), 4); // Parent sees all successes
4299 }
4300
4301 #[rstest]
4302 #[tokio::test]
4303 async fn test_child_tracker_builder(log_policy: Arc<LogOnlyPolicy>) {
4304 // Test that child tracker builder allows custom policies
4305 let parent_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4306 let parent = TaskTracker::new(parent_scheduler, log_policy).unwrap();
4307
4308 // Create child with custom error policy
4309 let child_error_policy = CancelOnError::new();
4310 let child = parent
4311 .child_tracker_builder()
4312 .error_policy(child_error_policy)
4313 .build()
4314 .unwrap();
4315
4316 // Test that child works
4317 let handle = child.spawn(async { Ok(42) });
4318 let result = handle.await.unwrap().unwrap();
4319 assert_eq!(result, 42);
4320
4321 // Child should have its own metrics
4322 assert_eq!(child.metrics().success(), 1);
4323 assert_eq!(parent.metrics().total_completed(), 1); // Parent sees aggregated
4324 }
4325
4326 #[rstest]
4327 #[tokio::test]
4328 async fn test_hierarchical_metrics_aggregation(log_policy: Arc<LogOnlyPolicy>) {
4329 // Test that child metrics aggregate up to parent
4330 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4331 let parent = TaskTracker::new(scheduler, log_policy.clone()).unwrap();
4332
4333 // Create child1 with default settings
4334 let child1 = parent.child_tracker().unwrap();
4335
4336 // Create child2 with custom error policy
4337 let child_error_policy = CancelOnError::new();
4338 let child2 = parent
4339 .child_tracker_builder()
4340 .error_policy(child_error_policy)
4341 .build()
4342 .unwrap();
4343
4344 // Test both custom schedulers and policies
4345 let another_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(3))));
4346 let another_error_policy = CancelOnError::new();
4347 let child3 = parent
4348 .child_tracker_builder()
4349 .scheduler(another_scheduler)
4350 .error_policy(another_error_policy)
4351 .build()
4352 .unwrap();
4353
4354 // Test that all children are properly registered
4355 assert_eq!(parent.child_count(), 3);
4356
4357 // Test that custom schedulers work
4358 let handle1 = child1.spawn(async { Ok(1) });
4359 let handle2 = child2.spawn(async { Ok(2) });
4360 let handle3 = child3.spawn(async { Ok(3) });
4361
4362 assert_eq!(handle1.await.unwrap().unwrap(), 1);
4363 assert_eq!(handle2.await.unwrap().unwrap(), 2);
4364 assert_eq!(handle3.await.unwrap().unwrap(), 3);
4365
4366 // Verify metrics still work
4367 assert_eq!(parent.metrics().success(), 3); // All child successes roll up
4368 assert_eq!(child1.metrics().success(), 1);
4369 assert_eq!(child2.metrics().success(), 1);
4370 assert_eq!(child3.metrics().success(), 1);
4371 }
4372
4373 #[rstest]
4374 #[tokio::test]
4375 async fn test_scheduler_queue_depth_calculation(log_policy: Arc<LogOnlyPolicy>) {
4376 // Test that we can calculate tasks queued in scheduler
4377 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2)))); // Only 2 concurrent tasks
4378 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4379
4380 // Initially no tasks
4381 assert_eq!(tracker.metrics().issued(), 0);
4382 assert_eq!(tracker.metrics().active(), 0);
4383 assert_eq!(tracker.metrics().queued(), 0);
4384 assert_eq!(tracker.metrics().pending(), 0);
4385
4386 // Use a channel to control when tasks complete
4387 let (complete_tx, _complete_rx) = tokio::sync::broadcast::channel(1);
4388
4389 // Spawn 2 tasks that will hold semaphore permits
4390 let handle1 = tracker.spawn({
4391 let mut rx = complete_tx.subscribe();
4392 async move {
4393 // Wait for completion signal
4394 rx.recv().await.ok();
4395 Ok(1)
4396 }
4397 });
4398 let handle2 = tracker.spawn({
4399 let mut rx = complete_tx.subscribe();
4400 async move {
4401 // Wait for completion signal
4402 rx.recv().await.ok();
4403 Ok(2)
4404 }
4405 });
4406
4407 // Give tasks time to start and acquire permits
4408 tokio::task::yield_now().await;
4409 tokio::task::yield_now().await;
4410
4411 // Should have 2 active tasks, 0 queued
4412 assert_eq!(tracker.metrics().issued(), 2);
4413 assert_eq!(tracker.metrics().active(), 2);
4414 assert_eq!(tracker.metrics().queued(), 0);
4415 assert_eq!(tracker.metrics().pending(), 2);
4416
4417 // Spawn a third task - should be queued since semaphore is full
4418 let handle3 = tracker.spawn(async move { Ok(3) });
4419
4420 // Give time for task to be queued
4421 tokio::task::yield_now().await;
4422
4423 // Should have 2 active, 1 queued
4424 assert_eq!(tracker.metrics().issued(), 3);
4425 assert_eq!(tracker.metrics().active(), 2);
4426 assert_eq!(
4427 tracker.metrics().queued(),
4428 tracker.metrics().pending() - tracker.metrics().active()
4429 );
4430 assert_eq!(tracker.metrics().pending(), 3);
4431
4432 // Complete all tasks by sending the signal
4433 complete_tx.send(()).ok();
4434
4435 let result1 = handle1.await.unwrap().unwrap();
4436 let result2 = handle2.await.unwrap().unwrap();
4437 let result3 = handle3.await.unwrap().unwrap();
4438
4439 assert_eq!(result1, 1);
4440 assert_eq!(result2, 2);
4441 assert_eq!(result3, 3);
4442
4443 // All tasks should be completed
4444 assert_eq!(tracker.metrics().success(), 3);
4445 assert_eq!(tracker.metrics().active(), 0);
4446 assert_eq!(tracker.metrics().queued(), 0);
4447 assert_eq!(tracker.metrics().pending(), 0);
4448 }
4449
4450 #[rstest]
4451 #[tokio::test]
4452 async fn test_hierarchical_metrics_failure_aggregation(
4453 semaphore_scheduler: Arc<SemaphoreScheduler>,
4454 log_policy: Arc<LogOnlyPolicy>,
4455 ) {
4456 // Test that failed task metrics aggregate up to parent
4457 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4458 let child = parent.child_tracker().unwrap();
4459
4460 // Run some successful and some failed tasks
4461 let success_handle = child.spawn(async { Ok(42) });
4462 let failure_handle = child.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
4463
4464 // Wait for tasks to complete
4465 let _success_result = success_handle.await.unwrap().unwrap();
4466 let _failure_result = failure_handle.await.unwrap().unwrap_err();
4467
4468 // Check child metrics
4469 assert_eq!(child.metrics().success(), 1, "Child should have 1 success");
4470 assert_eq!(child.metrics().failed(), 1, "Child should have 1 failure");
4471
4472 // Parent should see the aggregated metrics
4473 // Note: Due to hierarchical aggregation, these metrics propagate up
4474 }
4475
4476 #[rstest]
4477 #[tokio::test]
4478 async fn test_metrics_independence_between_tracker_instances(
4479 semaphore_scheduler: Arc<SemaphoreScheduler>,
4480 log_policy: Arc<LogOnlyPolicy>,
4481 ) {
4482 // Test that different tracker instances have independent metrics
4483 let tracker1 = TaskTracker::new(semaphore_scheduler.clone(), log_policy.clone()).unwrap();
4484 let tracker2 = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4485
4486 // Run tasks in both trackers
4487 let handle1 = tracker1.spawn(async { Ok(1) });
4488 let handle2 = tracker2.spawn(async { Ok(2) });
4489
4490 handle1.await.unwrap().unwrap();
4491 handle2.await.unwrap().unwrap();
4492
4493 // Each tracker should only see its own metrics
4494 assert_eq!(tracker1.metrics().success(), 1);
4495 assert_eq!(tracker2.metrics().success(), 1);
4496 assert_eq!(tracker1.metrics().total_completed(), 1);
4497 assert_eq!(tracker2.metrics().total_completed(), 1);
4498 }
4499
4500 #[rstest]
4501 #[tokio::test]
4502 async fn test_hierarchical_join_waits_for_all(log_policy: Arc<LogOnlyPolicy>) {
4503 // Test that parent.join() waits for child tasks too
4504 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4505 let parent = TaskTracker::new(scheduler, log_policy).unwrap();
4506 let child1 = parent.child_tracker().unwrap();
4507 let child2 = parent.child_tracker().unwrap();
4508 let grandchild = child1.child_tracker().unwrap();
4509
4510 // Verify parent tracks children
4511 assert_eq!(parent.child_count(), 2);
4512 assert_eq!(child1.child_count(), 1);
4513 assert_eq!(child2.child_count(), 0);
4514 assert_eq!(grandchild.child_count(), 0);
4515
4516 // Track completion order
4517 let completion_order = Arc::new(Mutex::new(Vec::new()));
4518
4519 // Spawn tasks with different durations
4520 let order_clone = completion_order.clone();
4521 let parent_handle = parent.spawn(async move {
4522 tokio::time::sleep(Duration::from_millis(50)).await;
4523 order_clone.lock().unwrap().push("parent");
4524 Ok(())
4525 });
4526
4527 let order_clone = completion_order.clone();
4528 let child1_handle = child1.spawn(async move {
4529 tokio::time::sleep(Duration::from_millis(100)).await;
4530 order_clone.lock().unwrap().push("child1");
4531 Ok(())
4532 });
4533
4534 let order_clone = completion_order.clone();
4535 let child2_handle = child2.spawn(async move {
4536 tokio::time::sleep(Duration::from_millis(75)).await;
4537 order_clone.lock().unwrap().push("child2");
4538 Ok(())
4539 });
4540
4541 let order_clone = completion_order.clone();
4542 let grandchild_handle = grandchild.spawn(async move {
4543 tokio::time::sleep(Duration::from_millis(125)).await;
4544 order_clone.lock().unwrap().push("grandchild");
4545 Ok(())
4546 });
4547
4548 // Test hierarchical join - should wait for ALL tasks in hierarchy
4549 println!("[TEST] About to call parent.join()");
4550 let start = std::time::Instant::now();
4551 parent.join().await; // This should wait for ALL tasks
4552 let elapsed = start.elapsed();
4553 println!("[TEST] parent.join() completed in {:?}", elapsed);
4554
4555 // Should have waited for the longest task (grandchild at 125ms)
4556 assert!(
4557 elapsed >= Duration::from_millis(120),
4558 "Hierarchical join should wait for longest task"
4559 );
4560
4561 // All tasks should be complete
4562 assert!(parent_handle.is_finished());
4563 assert!(child1_handle.is_finished());
4564 assert!(child2_handle.is_finished());
4565 assert!(grandchild_handle.is_finished());
4566
4567 // Verify all tasks completed
4568 let final_order = completion_order.lock().unwrap();
4569 assert_eq!(final_order.len(), 4);
4570 assert!(final_order.contains(&"parent"));
4571 assert!(final_order.contains(&"child1"));
4572 assert!(final_order.contains(&"child2"));
4573 assert!(final_order.contains(&"grandchild"));
4574 }
4575
4576 #[rstest]
4577 #[tokio::test]
4578 async fn test_hierarchical_join_waits_for_children(
4579 semaphore_scheduler: Arc<SemaphoreScheduler>,
4580 log_policy: Arc<LogOnlyPolicy>,
4581 ) {
4582 // Test that join() waits for child tasks (hierarchical behavior)
4583 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4584 let child = parent.child_tracker().unwrap();
4585
4586 // Spawn a quick parent task and slow child task
4587 let _parent_handle = parent.spawn(async {
4588 tokio::time::sleep(Duration::from_millis(20)).await;
4589 Ok(())
4590 });
4591
4592 let _child_handle = child.spawn(async {
4593 tokio::time::sleep(Duration::from_millis(100)).await;
4594 Ok(())
4595 });
4596
4597 // Hierarchical join should wait for both parent and child tasks
4598 let start = std::time::Instant::now();
4599 parent.join().await; // Should wait for both (hierarchical by default)
4600 let elapsed = start.elapsed();
4601
4602 // Should have waited for the longer child task (100ms)
4603 assert!(
4604 elapsed >= Duration::from_millis(90),
4605 "Hierarchical join should wait for all child tasks"
4606 );
4607 }
4608
4609 #[rstest]
4610 #[tokio::test]
4611 async fn test_hierarchical_join_operations(
4612 semaphore_scheduler: Arc<SemaphoreScheduler>,
4613 log_policy: Arc<LogOnlyPolicy>,
4614 ) {
4615 // Test that parent.join() closes and waits for child trackers too
4616 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4617 let child = parent.child_tracker().unwrap();
4618 let grandchild = child.child_tracker().unwrap();
4619
4620 // Verify trackers start as open
4621 assert!(!parent.is_closed());
4622 assert!(!child.is_closed());
4623 assert!(!grandchild.is_closed());
4624
4625 // Join parent (hierarchical by default - closes and waits for all)
4626 parent.join().await;
4627
4628 // Children should be closed too (hierarchical join closes the whole tree)
4629 assert!(child.is_closed());
4630 assert!(grandchild.is_closed());
4631 }
4632
4633 #[rstest]
4634 #[tokio::test]
4635 async fn test_unlimited_scheduler() {
4636 // Test that UnlimitedScheduler executes tasks immediately
4637 let scheduler = UnlimitedScheduler::new();
4638 let error_policy = LogOnlyPolicy::new();
4639 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
4640
4641 let (tx, rx) = tokio::sync::oneshot::channel();
4642 let handle = tracker.spawn(async {
4643 rx.await.ok();
4644 Ok(42)
4645 });
4646
4647 // Task should be ready to execute immediately (no concurrency limit)
4648 tx.send(()).ok();
4649 let result = handle.await.unwrap().unwrap();
4650 assert_eq!(result, 42);
4651
4652 assert_eq!(tracker.metrics().success(), 1);
4653 }
4654
4655 #[rstest]
4656 #[tokio::test]
4657 async fn test_threshold_cancel_policy(semaphore_scheduler: Arc<SemaphoreScheduler>) {
4658 // Test that ThresholdCancelPolicy now uses per-task failure counting
4659 let error_policy = ThresholdCancelPolicy::with_threshold(2); // Cancel after 2 failures per task
4660 let tracker = TaskTracker::new(semaphore_scheduler, error_policy.clone()).unwrap();
4661 let cancel_token = tracker.cancellation_token().child_token();
4662
4663 // With per-task context, individual task failures don't accumulate
4664 // Each task starts with failure_count = 0, so single failures won't trigger cancellation
4665 let _handle1 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("First failure")) });
4666 tokio::task::yield_now().await;
4667 assert!(!cancel_token.is_cancelled());
4668 assert_eq!(error_policy.failure_count(), 1); // Global counter still increments
4669
4670 // Second failure from different task - still won't trigger cancellation
4671 let _handle2 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("Second failure")) });
4672 tokio::task::yield_now().await;
4673 assert!(!cancel_token.is_cancelled()); // Per-task context prevents cancellation
4674 assert_eq!(error_policy.failure_count(), 2); // Global counter increments
4675
4676 // For cancellation to occur, a single task would need to fail multiple times
4677 // through continuations (which would require a more complex test setup)
4678 }
4679
4680 #[tokio::test]
4681 async fn test_policy_constructors() {
4682 // Test that all constructors follow the new clean API patterns
4683 let _unlimited = UnlimitedScheduler::new();
4684 let _semaphore = SemaphoreScheduler::with_permits(5);
4685 let _log_only = LogOnlyPolicy::new();
4686 let _cancel_policy = CancelOnError::new();
4687 let _threshold_policy = ThresholdCancelPolicy::with_threshold(3);
4688 let _rate_policy = RateCancelPolicy::builder()
4689 .rate(0.5)
4690 .window_secs(60)
4691 .build();
4692
4693 // All constructors return Arc directly - no more ugly ::new_arc patterns
4694 // This test ensures the clean API reduces boilerplate
4695 }
4696
4697 #[rstest]
4698 #[tokio::test]
4699 async fn test_child_creation_fails_after_join(
4700 semaphore_scheduler: Arc<SemaphoreScheduler>,
4701 log_policy: Arc<LogOnlyPolicy>,
4702 ) {
4703 // Test that child tracker creation fails from closed parent
4704 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4705
4706 // Initially, creating a child should work
4707 let _child = parent.child_tracker().unwrap();
4708
4709 // Close the parent tracker
4710 let parent_clone = parent.clone();
4711 parent.join().await;
4712 assert!(parent_clone.is_closed());
4713
4714 // Now, trying to create a child should fail
4715 let result = parent_clone.child_tracker();
4716 assert!(result.is_err());
4717 assert!(
4718 result
4719 .err()
4720 .unwrap()
4721 .to_string()
4722 .contains("closed parent tracker")
4723 );
4724 }
4725
4726 #[rstest]
4727 #[tokio::test]
4728 async fn test_child_builder_fails_after_join(
4729 semaphore_scheduler: Arc<SemaphoreScheduler>,
4730 log_policy: Arc<LogOnlyPolicy>,
4731 ) {
4732 // Test that child tracker builder creation fails from closed parent
4733 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4734
4735 // Initially, creating a child with builder should work
4736 let _child = parent.child_tracker_builder().build().unwrap();
4737
4738 // Close the parent tracker
4739 let parent_clone = parent.clone();
4740 parent.join().await;
4741 assert!(parent_clone.is_closed());
4742
4743 // Now, trying to create a child with builder should fail
4744 let result = parent_clone.child_tracker_builder().build();
4745 assert!(result.is_err());
4746 assert!(
4747 result
4748 .err()
4749 .unwrap()
4750 .to_string()
4751 .contains("closed parent tracker")
4752 );
4753 }
4754
4755 #[rstest]
4756 #[tokio::test]
4757 async fn test_child_creation_succeeds_before_join(
4758 semaphore_scheduler: Arc<SemaphoreScheduler>,
4759 log_policy: Arc<LogOnlyPolicy>,
4760 ) {
4761 // Test that child creation works normally before parent is joined
4762 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4763
4764 // Should be able to create multiple children before closing
4765 let child1 = parent.child_tracker().unwrap();
4766 let child2 = parent.child_tracker_builder().build().unwrap();
4767
4768 // Verify children can spawn tasks
4769 let handle1 = child1.spawn(async { Ok(42) });
4770 let handle2 = child2.spawn(async { Ok(24) });
4771
4772 let result1 = handle1.await.unwrap().unwrap();
4773 let result2 = handle2.await.unwrap().unwrap();
4774
4775 assert_eq!(result1, 42);
4776 assert_eq!(result2, 24);
4777 assert_eq!(parent.metrics().success(), 2); // Parent sees all successes
4778 }
4779
4780 #[rstest]
4781 #[tokio::test]
4782 async fn test_custom_error_response_with_cancellation_token(
4783 semaphore_scheduler: Arc<SemaphoreScheduler>,
4784 ) {
4785 // Test ErrorResponse::Custom behavior with TriggerCancellationTokenOnError
4786
4787 // Create a custom cancellation token
4788 let custom_cancel_token = CancellationToken::new();
4789
4790 // Create the policy that will trigger our custom token
4791 let error_policy = TriggerCancellationTokenOnError::new(custom_cancel_token.clone());
4792
4793 // Create tracker using builder with the custom policy
4794 let tracker = TaskTracker::builder()
4795 .scheduler(semaphore_scheduler)
4796 .error_policy(error_policy)
4797 .cancel_token(custom_cancel_token.clone())
4798 .build()
4799 .unwrap();
4800
4801 let child = tracker.child_tracker().unwrap();
4802
4803 // Initially, the custom token should not be cancelled
4804 assert!(!custom_cancel_token.is_cancelled());
4805
4806 // Spawn a task that will fail
4807 let handle = child.spawn(async {
4808 Err::<(), _>(anyhow::anyhow!("Test error to trigger custom response"))
4809 });
4810
4811 // Wait for the task to complete (it will fail)
4812 let result = handle.await.unwrap();
4813 assert!(result.is_err());
4814
        // Wait for either the custom cancellation token to fire or a deadline to expire.
        // The failed task should cause the policy to trigger the token; hitting the
        // deadline instead means the policy never ran, which fails the test.
4818 tokio::select! {
4819 _ = tokio::time::sleep(Duration::from_secs(1)) => {
4820 panic!("Task should have failed, but hit the deadline");
4821 }
4822 _ = custom_cancel_token.cancelled() => {
4823 // Task should have failed, and the cancellation token should be triggered
4824 }
4825 }
4826
4827 // The custom cancellation token should now be triggered by our policy
4828 assert!(
4829 custom_cancel_token.is_cancelled(),
4830 "Custom cancellation token should be triggered by ErrorResponse::Custom"
4831 );
4832
4833 assert!(tracker.cancellation_token().is_cancelled());
4834 assert!(child.cancellation_token().is_cancelled());
4835
4836 // Verify the error was counted
4837 assert_eq!(tracker.metrics().failed(), 1);
4838 }
4839
4840 #[test]
4841 fn test_action_result_variants() {
4842 // Test that ActionResult variants can be created and pattern matched
4843
4844 // Test Fail variant
4845 let fail_result = ActionResult::Fail;
4846 match fail_result {
4847 ActionResult::Fail => {} // Expected
4848 _ => panic!("Expected Fail variant"),
4849 }
4850
4851 // Test Shutdown variant
4852 let shutdown_result = ActionResult::Shutdown;
4853 match shutdown_result {
4854 ActionResult::Shutdown => {} // Expected
4855 _ => panic!("Expected Shutdown variant"),
4856 }
4857
4858 // Test Continue variant with Continuation
4859 #[derive(Debug)]
4860 struct TestRestartable;
4861
4862 #[async_trait]
4863 impl Continuation for TestRestartable {
4864 async fn execute(
4865 &self,
4866 _cancel_token: CancellationToken,
4867 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4868 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4869 }
4870 }
4871
4872 let test_restartable = Arc::new(TestRestartable);
4873 let continue_result = ActionResult::Continue {
4874 continuation: test_restartable,
4875 };
4876
4877 match continue_result {
4878 ActionResult::Continue { continuation } => {
4879 // Verify we have a valid Continuation
4880 assert!(format!("{:?}", continuation).contains("TestRestartable"));
4881 }
4882 _ => panic!("Expected Continue variant"),
4883 }
4884 }
4885
4886 #[test]
4887 fn test_continuation_error_creation() {
        // Test FailedWithContinuation creation and conversion to anyhow::Error

        // Create a dummy continuation for testing
4891 #[derive(Debug)]
4892 struct DummyRestartable;
4893
4894 #[async_trait]
4895 impl Continuation for DummyRestartable {
4896 async fn execute(
4897 &self,
4898 _cancel_token: CancellationToken,
4899 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4900 TaskExecutionResult::Success(Box::new("restarted_result".to_string()))
4901 }
4902 }
4903
4904 let dummy_restartable = Arc::new(DummyRestartable);
4905 let source_error = anyhow::anyhow!("Original task failed");
4906
4907 // Test FailedWithContinuation::new
4908 let continuation_error = FailedWithContinuation::new(source_error, dummy_restartable);
4909
4910 // Verify the error displays correctly
4911 let error_string = format!("{}", continuation_error);
4912 assert!(error_string.contains("Task failed with continuation"));
4913 assert!(error_string.contains("Original task failed"));
4914
4915 // Test conversion to anyhow::Error
4916 let anyhow_error = anyhow::Error::new(continuation_error);
4917 assert!(
4918 anyhow_error
4919 .to_string()
4920 .contains("Task failed with continuation")
4921 );
4922 }
4923
4924 #[test]
4925 fn test_continuation_error_ext_trait() {
        // Test the continuation extension trait methods (has_continuation / extract_continuation)
4927
        // Test with a regular anyhow::Error (no continuation attached)
4929 let regular_error = anyhow::anyhow!("Regular error");
4930 assert!(!regular_error.has_continuation());
4931 let extracted = regular_error.extract_continuation();
4932 assert!(extracted.is_none());
4933
        // Test with FailedWithContinuation
4935 #[derive(Debug)]
4936 struct TestRestartable;
4937
4938 #[async_trait]
4939 impl Continuation for TestRestartable {
4940 async fn execute(
4941 &self,
4942 _cancel_token: CancellationToken,
4943 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4944 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4945 }
4946 }
4947
4948 let test_restartable = Arc::new(TestRestartable);
4949 let source_error = anyhow::anyhow!("Source error");
4950 let continuation_error = FailedWithContinuation::new(source_error, test_restartable);
4951
4952 let anyhow_error = anyhow::Error::new(continuation_error);
4953 assert!(anyhow_error.has_continuation());
4954
4955 // Test extraction of restartable task
4956 let extracted = anyhow_error.extract_continuation();
4957 assert!(extracted.is_some());
4958 }
4959
4960 #[test]
4961 fn test_continuation_error_into_anyhow_helper() {
        // Test manual construction of a continuation error and its conversion to anyhow::Error.
        // The into_anyhow convenience helper is covered separately in
        // test_continuation_error_into_anyhow_convenience below.

        // Placeholder type illustrating the type-erasure concept; it is intentionally unused.
        #[allow(dead_code)]
        struct MockExecutor;

        let _source_error = anyhow::anyhow!("Mock task failed");

        // Verify the concept works with manual construction of FailedWithContinuation.
4973
4974 #[derive(Debug)]
4975 struct MockRestartable;
4976
4977 #[async_trait]
4978 impl Continuation for MockRestartable {
4979 async fn execute(
4980 &self,
4981 _cancel_token: CancellationToken,
4982 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4983 TaskExecutionResult::Success(Box::new("mock_result".to_string()))
4984 }
4985 }
4986
4987 let mock_restartable = Arc::new(MockRestartable);
4988 let continuation_error =
4989 FailedWithContinuation::new(anyhow::anyhow!("Mock task failed"), mock_restartable);
4990
4991 let anyhow_error = anyhow::Error::new(continuation_error);
4992 assert!(anyhow_error.has_continuation());
4993 }
4994
4995 #[test]
4996 fn test_continuation_error_with_task_executor() {
        // Test FailedWithContinuation creation with a Continuation implementation
4998
4999 #[derive(Debug)]
5000 struct TestRestartableTask;
5001
5002 #[async_trait]
5003 impl Continuation for TestRestartableTask {
5004 async fn execute(
5005 &self,
5006 _cancel_token: CancellationToken,
5007 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5008 TaskExecutionResult::Success(Box::new("test_result".to_string()))
5009 }
5010 }
5011
5012 let restartable_task = Arc::new(TestRestartableTask);
5013 let source_error = anyhow::anyhow!("Task failed");
5014
        // Test FailedWithContinuation::new with a Continuation
5016 let continuation_error = FailedWithContinuation::new(source_error, restartable_task);
5017
5018 // Verify the error displays correctly
5019 let error_string = format!("{}", continuation_error);
5020 assert!(error_string.contains("Task failed with continuation"));
5021 assert!(error_string.contains("Task failed"));
5022
5023 // Test conversion to anyhow::Error
5024 let anyhow_error = anyhow::Error::new(continuation_error);
5025 assert!(anyhow_error.has_continuation());
5026
        // Test extraction (should recover the boxed Continuation)
5028 let extracted = anyhow_error.extract_continuation();
        assert!(extracted.is_some()); // Should successfully extract the Continuation
5030 }
5031
5032 #[test]
5033 fn test_continuation_error_into_anyhow_convenience() {
        // Test the into_anyhow convenience method for creating continuation errors
5035
5036 #[derive(Debug)]
5037 struct ConvenienceRestartable;
5038
5039 #[async_trait]
5040 impl Continuation for ConvenienceRestartable {
5041 async fn execute(
5042 &self,
5043 _cancel_token: CancellationToken,
5044 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5045 TaskExecutionResult::Success(Box::new(42u32))
5046 }
5047 }
5048
5049 let restartable_task = Arc::new(ConvenienceRestartable);
5050 let source_error = anyhow::anyhow!("Computation failed");
5051
5052 // Test FailedWithContinuation::into_anyhow convenience method
5053 let anyhow_error = FailedWithContinuation::into_anyhow(source_error, restartable_task);
5054
5055 assert!(anyhow_error.has_continuation());
5056 assert!(
5057 anyhow_error
5058 .to_string()
5059 .contains("Task failed with continuation")
5060 );
5061 assert!(anyhow_error.to_string().contains("Computation failed"));
5062 }
5063
5064 #[test]
5065 fn test_handle_task_error_with_continuation_error() {
        // Test that handle_task_error properly detects FailedWithContinuation

        // Create a mock Continuation
5069 #[derive(Debug)]
5070 struct MockRestartableTask;
5071
5072 #[async_trait]
5073 impl Continuation for MockRestartableTask {
5074 async fn execute(
5075 &self,
5076 _cancel_token: CancellationToken,
5077 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5078 TaskExecutionResult::Success(Box::new("retry_result".to_string()))
5079 }
5080 }
5081
5082 let restartable_task = Arc::new(MockRestartableTask);
5083
        // Create a FailedWithContinuation error
5085 let source_error = anyhow::anyhow!("Task failed, but can retry");
5086 let continuation_error = FailedWithContinuation::new(source_error, restartable_task);
5087 let anyhow_error = anyhow::Error::new(continuation_error);
5088
        // Verify it is detected as carrying a continuation
5090 assert!(anyhow_error.has_continuation());
5091
5092 // Verify we can downcast to FailedWithContinuation
5093 let continuation_ref = anyhow_error.downcast_ref::<FailedWithContinuation>();
5094 assert!(continuation_ref.is_some());
5095
5096 // Verify the continuation task is present
5097 let continuation = continuation_ref.unwrap();
5098 // Note: We can verify the Arc is valid by checking that Arc::strong_count > 0
5099 assert!(Arc::strong_count(&continuation.continuation) > 0);
5100 }
5101
5102 #[test]
5103 fn test_handle_task_error_with_regular_error() {
5104 // Test that handle_task_error properly handles regular errors
5105
5106 let regular_error = anyhow::anyhow!("Regular task failure");
5107
        // Verify it is not detected as carrying a continuation
5109 assert!(!regular_error.has_continuation());
5110
5111 // Verify we cannot downcast to FailedWithContinuation
5112 let continuation_ref = regular_error.downcast_ref::<FailedWithContinuation>();
5113 assert!(continuation_ref.is_none());
5114 }
5115
5116 // ========================================
5117 // END-TO-END ACTIONRESULT TESTS
5118 // ========================================
5119
5120 #[rstest]
5121 #[tokio::test]
5122 async fn test_end_to_end_continuation_execution(
5123 unlimited_scheduler: Arc<UnlimitedScheduler>,
5124 log_policy: Arc<LogOnlyPolicy>,
5125 ) {
5126 // Test that a task returning FailedWithContinuation actually executes the continuation
5127 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
5128
5129 // Shared state to track execution
5130 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5131 let log_clone = execution_log.clone();
5132
5133 // Create a continuation that logs its execution
5134 #[derive(Debug)]
5135 struct LoggingContinuation {
5136 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5137 result: String,
5138 }
5139
5140 #[async_trait]
5141 impl Continuation for LoggingContinuation {
5142 async fn execute(
5143 &self,
5144 _cancel_token: CancellationToken,
5145 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5146 self.log
5147 .lock()
5148 .await
5149 .push("continuation_executed".to_string());
5150 TaskExecutionResult::Success(Box::new(self.result.clone()))
5151 }
5152 }
5153
5154 let continuation = Arc::new(LoggingContinuation {
5155 log: log_clone,
5156 result: "continuation_result".to_string(),
5157 });
5158
5159 // Task that fails with continuation
5160 let log_for_task = execution_log.clone();
5161 let handle = tracker.spawn(async move {
5162 log_for_task
5163 .lock()
5164 .await
5165 .push("original_task_executed".to_string());
5166
5167 // Return FailedWithContinuation
5168 let error = anyhow::anyhow!("Original task failed");
5169 let result: Result<String, anyhow::Error> =
5170 Err(FailedWithContinuation::into_anyhow(error, continuation));
5171 result
5172 });
5173
5174 // Execute and verify the continuation was called
5175 let result = handle.await.expect("Task should complete");
5176 assert!(result.is_ok(), "Continuation should succeed");
5177
5178 // Verify execution order
5179 let log = execution_log.lock().await;
5180 assert_eq!(log.len(), 2);
5181 assert_eq!(log[0], "original_task_executed");
5182 assert_eq!(log[1], "continuation_executed");
5183
5184 // Verify metrics - should show 1 success (from continuation)
5185 assert_eq!(tracker.metrics().success(), 1);
5186 assert_eq!(tracker.metrics().failed(), 0); // Continuation succeeded
5187 assert_eq!(tracker.metrics().cancelled(), 0);
5188 }
5189
5190 #[rstest]
5191 #[tokio::test]
5192 async fn test_end_to_end_multiple_continuations(
5193 unlimited_scheduler: Arc<UnlimitedScheduler>,
5194 log_policy: Arc<LogOnlyPolicy>,
5195 ) {
5196 // Test multiple continuation attempts
5197 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
5198
5199 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5200 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
5201
5202 // Continuation that fails twice, then succeeds
5203 #[derive(Debug)]
5204 struct RetryingContinuation {
5205 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5206 attempt_count: Arc<std::sync::atomic::AtomicU32>,
5207 }
5208
5209 #[async_trait]
5210 impl Continuation for RetryingContinuation {
5211 async fn execute(
5212 &self,
5213 _cancel_token: CancellationToken,
5214 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5215 let attempt = self
5216 .attempt_count
5217 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
5218 + 1;
5219 self.log
5220 .lock()
5221 .await
5222 .push(format!("continuation_attempt_{}", attempt));
5223
5224 if attempt < 3 {
5225 // Fail with another continuation
5226 let next_continuation = Arc::new(RetryingContinuation {
5227 log: self.log.clone(),
5228 attempt_count: self.attempt_count.clone(),
5229 });
5230 let error = anyhow::anyhow!("Continuation attempt {} failed", attempt);
5231 TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
5232 error,
5233 next_continuation,
5234 ))
5235 } else {
5236 // Succeed on third attempt
5237 TaskExecutionResult::Success(Box::new(format!(
5238 "success_on_attempt_{}",
5239 attempt
5240 )))
5241 }
5242 }
5243 }
5244
5245 let initial_continuation = Arc::new(RetryingContinuation {
5246 log: execution_log.clone(),
5247 attempt_count: attempt_count.clone(),
5248 });
5249
5250 // Task that immediately fails with continuation
5251 let handle = tracker.spawn(async move {
5252 let error = anyhow::anyhow!("Original task failed");
5253 let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
5254 error,
5255 initial_continuation,
5256 ));
5257 result
5258 });
5259
5260 // Execute and verify multiple continuations
5261 let result = handle.await.expect("Task should complete");
5262 assert!(result.is_ok(), "Final continuation should succeed");
5263
5264 // Verify all attempts were made
5265 let log = execution_log.lock().await;
5266 assert_eq!(log.len(), 3);
5267 assert_eq!(log[0], "continuation_attempt_1");
5268 assert_eq!(log[1], "continuation_attempt_2");
5269 assert_eq!(log[2], "continuation_attempt_3");
5270
5271 // Verify final attempt count
5272 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
5273
5274 // Verify metrics - should show 1 success (final continuation)
5275 assert_eq!(tracker.metrics().success(), 1);
5276 assert_eq!(tracker.metrics().failed(), 0);
5277 }
5278
5279 #[rstest]
5280 #[tokio::test]
5281 async fn test_end_to_end_continuation_failure(
5282 unlimited_scheduler: Arc<UnlimitedScheduler>,
5283 log_policy: Arc<LogOnlyPolicy>,
5284 ) {
5285 // Test continuation that ultimately fails without providing another continuation
5286 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
5287
5288 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5289 let log_clone = execution_log.clone();
5290
5291 // Continuation that fails without providing another continuation
5292 #[derive(Debug)]
5293 struct FailingContinuation {
5294 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5295 }
5296
5297 #[async_trait]
5298 impl Continuation for FailingContinuation {
5299 async fn execute(
5300 &self,
5301 _cancel_token: CancellationToken,
5302 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5303 self.log
5304 .lock()
5305 .await
5306 .push("continuation_failed".to_string());
5307 TaskExecutionResult::Error(anyhow::anyhow!("Continuation failed permanently"))
5308 }
5309 }
5310
5311 let continuation = Arc::new(FailingContinuation { log: log_clone });
5312
5313 // Task that fails with continuation
5314 let log_for_task = execution_log.clone();
5315 let handle = tracker.spawn(async move {
5316 log_for_task
5317 .lock()
5318 .await
5319 .push("original_task_executed".to_string());
5320
5321 let error = anyhow::anyhow!("Original task failed");
5322 let result: Result<String, anyhow::Error> =
5323 Err(FailedWithContinuation::into_anyhow(error, continuation));
5324 result
5325 });
5326
5327 // Execute and verify the continuation failed
5328 let result = handle.await.expect("Task should complete");
5329 assert!(result.is_err(), "Continuation should fail");
5330
5331 // Verify execution order
5332 let log = execution_log.lock().await;
5333 assert_eq!(log.len(), 2);
5334 assert_eq!(log[0], "original_task_executed");
5335 assert_eq!(log[1], "continuation_failed");
5336
5337 // Verify metrics - should show 1 failure (from continuation)
5338 assert_eq!(tracker.metrics().success(), 0);
5339 assert_eq!(tracker.metrics().failed(), 1);
5340 assert_eq!(tracker.metrics().cancelled(), 0);
5341 }
5342
5343 #[rstest]
5344 #[tokio::test]
5345 async fn test_end_to_end_all_action_result_variants(
5346 unlimited_scheduler: Arc<UnlimitedScheduler>,
5347 ) {
5348 // Comprehensive test of Fail, Shutdown, and Continue paths
5349
5350 // Test 1: ActionResult::Fail (via LogOnlyPolicy)
5351 {
5352 let tracker =
5353 TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
5354 let handle = tracker.spawn(async {
5355 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
5356 result
5357 });
5358 let result = handle.await.expect("Task should complete");
5359 assert!(result.is_err(), "LogOnly should let error through");
5360 assert_eq!(tracker.metrics().failed(), 1);
5361 }
5362
5363 // Test 2: ActionResult::Shutdown (via CancelOnError)
5364 {
5365 let tracker =
5366 TaskTracker::new(unlimited_scheduler.clone(), CancelOnError::new()).unwrap();
5367 let handle = tracker.spawn(async {
5368 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
5369 result
5370 });
5371 let result = handle.await.expect("Task should complete");
5372 assert!(result.is_err(), "CancelOnError should fail task");
5373 assert!(
5374 tracker.cancellation_token().is_cancelled(),
5375 "Should cancel tracker"
5376 );
5377 assert_eq!(tracker.metrics().failed(), 1);
5378 }
5379
5380 // Test 3: ActionResult::Continue (via FailedWithContinuation)
5381 {
5382 let tracker =
5383 TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
5384
5385 #[derive(Debug)]
5386 struct TestContinuation;
5387
5388 #[async_trait]
5389 impl Continuation for TestContinuation {
5390 async fn execute(
5391 &self,
5392 _cancel_token: CancellationToken,
5393 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5394 TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
5395 }
5396 }
5397
5398 let continuation = Arc::new(TestContinuation);
5399 let handle = tracker.spawn(async move {
5400 let error = anyhow::anyhow!("Original failure");
5401 let result: Result<String, anyhow::Error> =
5402 Err(FailedWithContinuation::into_anyhow(error, continuation));
5403 result
5404 });
5405
5406 let result = handle.await.expect("Task should complete");
5407 assert!(result.is_ok(), "Continuation should succeed");
5408 assert_eq!(tracker.metrics().success(), 1);
5409 assert_eq!(tracker.metrics().failed(), 0);
5410 }
5411 }
5412
5413 // ========================================
5414 // LOOP BEHAVIOR AND POLICY INTERACTION TESTS
5415 // ========================================
5416 //
5417 // These tests demonstrate the current ActionResult system and identify
5418 // areas for future improvement:
5419 //
5420 // ✅ WHAT WORKS:
    // - All ActionResult variants (Fail, Shutdown, Continue) are tested
5422 // - Task-driven continuations work correctly
5423 // - Policy-driven continuations work correctly
5424 // - Mixed continuation sources work correctly
5425 // - Loop behavior with resource management works correctly
5426 //
    // 🔄 REMAINING LIMITATIONS:
    // - ThresholdCancelPolicy still exposes a global failure counter (kept for
    //   backwards compatibility) alongside its per-task counting
    // - Stateful retry policies must manage their own state via create_context()
    //
    // 🚀 FUTURE IMPROVEMENTS IDENTIFIED:
    // - Richer per-task context built on OnErrorContext (e.g. backoff timers,
    //   per-task retry budgets)
5436 //
5437 // The tests below demonstrate both current capabilities and limitations.
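    //
    // A minimal sketch (not a test) of the task-driven continuation flow exercised below,
    // using only APIs that appear in this module; `flaky_call` is a hypothetical
    // `async fn flaky_call() -> anyhow::Result<String>`:
    //
    // ```rust,ignore
    // let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
    // let handle = tracker.spawn(async {
    //     match flaky_call().await {
    //         Ok(v) => Ok(v),
    //         // Attach a continuation: the execution loop runs it in place of consulting
    //         // the error policy for this particular failure.
    //         Err(e) => Err(FailedWithContinuation::from_fn(e, || async { flaky_call().await })),
    //     }
    // });
    // let result = handle.await.expect("join"); // Ok if the task or its continuation succeeded
    // ```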
5438
5439 /// Test retry loop behavior with different policies and continuation counts
5440 ///
5441 /// This test verifies that:
5442 /// 1. Tasks can provide multiple continuations in sequence
5443 /// 2. Different error policies can limit the number of continuation attempts
5444 /// 3. The retry loop correctly handles policy decisions about when to stop
5445 ///
5446 /// Key insight: Policies are only consulted for regular errors, not FailedWithContinuation.
5447 /// So we need continuations that eventually fail with regular errors to test policy limits.
5448 ///
    /// HISTORICAL NOTE: ThresholdCancelPolicy originally tracked failures globally across all
    /// tasks, which made per-task retry loops awkward to exercise. Failure counting is now
    /// per-task: `on_error` and `allow_continuation` receive an `OnErrorContext` carrying
    /// `attempt_count` plus any per-task state the policy creates via `create_context()`.
    /// The policy still exposes a global failure counter for backwards compatibility.
    ///
    /// The earlier design sketch that motivated this change proposed an associated context type:
    /// ```rust
    /// trait OnErrorPolicy {
    ///     type Context: Default + Send + Sync;
    ///     fn on_error(&self, error: &anyhow::Error, task_id: TaskId,
    ///                 attempt_count: u32, context: &mut Self::Context) -> ErrorResponse;
    /// }
    /// ```
    /// The shipped design passes a concrete `OnErrorContext` instead, which still enables
    /// per-task failure tracking, backoff timers, etc.
    ///
    /// NOTE: Uses a fresh policy instance for each test case so the global failure counter
    /// cannot leak between cases.
5463 #[rstest]
5464 #[case(
5465 1,
5466 false,
5467 "Global policy with max_failures=1 should stop after first regular error"
5468 )]
5469 #[case(
5470 2,
5471 false, // Actually fails - ActionResult::Fail accepts the error and fails the task
5472 "Global policy with max_failures=2 allows error but ActionResult::Fail still fails the task"
5473 )]
5474 #[tokio::test]
5475 async fn test_continuation_loop_with_global_threshold_policy(
5476 unlimited_scheduler: Arc<UnlimitedScheduler>,
5477 #[case] max_failures: usize,
5478 #[case] should_succeed: bool,
5479 #[case] description: &str,
5480 ) {
5481 // Task that provides continuations, but continuations fail with regular errors
5482 // so the policy gets consulted and can limit retries
5483
5484 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5485 let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
5486
5487 // Create a continuation that fails with regular errors (not FailedWithContinuation)
5488 // This allows the policy to be consulted and potentially stop the retries
5489 #[derive(Debug)]
5490 struct PolicyTestContinuation {
5491 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5492 attempt_counter: Arc<std::sync::atomic::AtomicU32>,
5493 max_attempts_before_success: u32,
5494 }
5495
5496 #[async_trait]
5497 impl Continuation for PolicyTestContinuation {
5498 async fn execute(
5499 &self,
5500 _cancel_token: CancellationToken,
5501 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5502 let attempt = self
5503 .attempt_counter
5504 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
5505 + 1;
5506 self.log
5507 .lock()
5508 .await
5509 .push(format!("continuation_attempt_{}", attempt));
5510
5511 if attempt < self.max_attempts_before_success {
5512 // Fail with regular error - this will be seen by the policy
5513 TaskExecutionResult::Error(anyhow::anyhow!(
5514 "Continuation attempt {} failed (regular error)",
5515 attempt
5516 ))
5517 } else {
5518 // Succeed after enough attempts
5519 TaskExecutionResult::Success(Box::new(format!(
5520 "success_on_attempt_{}",
5521 attempt
5522 )))
5523 }
5524 }
5525 }
5526
5527 // Create fresh policy instance for each test case to avoid global state interference
5528 let policy = ThresholdCancelPolicy::with_threshold(max_failures);
5529 let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
5530
5531 // Original task that fails with continuation
5532 let log_for_task = execution_log.clone();
        // Set max_attempts_before_success = 2 so the continuation always fails attempt 1 with
        // a regular error (which is what forces the policy to be consulted):
        // - For max_failures=1: the per-task threshold is hit and the policy cancels the tracker
        // - For max_failures=2: the threshold is not hit, but ErrorResponse::Fail still fails
        //   the task, so attempt 2 (which would have succeeded) never runs
5536 let continuation = Arc::new(PolicyTestContinuation {
5537 log: execution_log.clone(),
5538 attempt_counter: attempt_counter.clone(),
5539 max_attempts_before_success: 2, // Always fail on attempt 1, succeed on attempt 2
5540 });
5541
5542 let handle = tracker.spawn(async move {
5543 log_for_task
5544 .lock()
5545 .await
5546 .push("original_task_executed".to_string());
5547 let error = anyhow::anyhow!("Original task failed");
5548 let result: Result<String, anyhow::Error> =
5549 Err(FailedWithContinuation::into_anyhow(error, continuation));
5550 result
5551 });
5552
5553 // Execute and check result based on policy
5554 let result = handle.await.expect("Task should complete");
5555
5556 // Debug: Print actual results
5557 let log = execution_log.lock().await;
5558 let final_attempt_count = attempt_counter.load(std::sync::atomic::Ordering::Relaxed);
5559 println!(
5560 "Test case: max_failures={}, should_succeed={}",
5561 max_failures, should_succeed
5562 );
5563 println!("Result: {:?}", result.is_ok());
5564 println!("Log entries: {:?}", log);
5565 println!("Attempt count: {}", final_attempt_count);
5566 println!(
5567 "Metrics: success={}, failed={}",
5568 tracker.metrics().success(),
5569 tracker.metrics().failed()
5570 );
5571 drop(log); // Release the lock
5572
5573 // Both test cases should fail because ActionResult::Fail accepts the error and fails the task
5574 assert!(result.is_err(), "{}: Task should fail", description);
5575 assert_eq!(
5576 tracker.metrics().success(),
5577 0,
5578 "{}: Should have 0 successes",
5579 description
5580 );
5581 assert_eq!(
5582 tracker.metrics().failed(),
5583 1,
5584 "{}: Should have 1 failure",
5585 description
5586 );
5587
5588 // Should have stopped after 1 continuation attempt because ActionResult::Fail fails the task
5589 let log = execution_log.lock().await;
5590 assert_eq!(
5591 log.len(),
5592 2,
5593 "{}: Should have 2 log entries (original + 1 continuation attempt)",
5594 description
5595 );
5596 assert_eq!(log[0], "original_task_executed");
5597 assert_eq!(log[1], "continuation_attempt_1");
5598
5599 assert_eq!(
5600 attempt_counter.load(std::sync::atomic::Ordering::Relaxed),
5601 1,
5602 "{}: Should have made 1 continuation attempt",
5603 description
5604 );
5605
5606 // The key difference is whether the tracker gets cancelled
5607 if max_failures == 1 {
5608 assert!(
5609 tracker.cancellation_token().is_cancelled(),
5610 "Tracker should be cancelled with max_failures=1"
5611 );
5612 } else {
5613 assert!(
5614 !tracker.cancellation_token().is_cancelled(),
5615 "Tracker should NOT be cancelled with max_failures=2 (policy allows the error)"
5616 );
5617 }
5618 }
5619
5620 /// Simple test to understand ThresholdCancelPolicy behavior with per-task context
5621 #[rstest]
5622 #[tokio::test]
5623 async fn test_simple_threshold_policy_behavior(unlimited_scheduler: Arc<UnlimitedScheduler>) {
5624 // Test with max_failures=2 - now uses per-task failure counting
5625 let policy = ThresholdCancelPolicy::with_threshold(2);
5626 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
5627
5628 // Task 1: Should fail but not trigger cancellation (per-task failure count = 1)
5629 let handle1 = tracker.spawn(async {
5630 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("First failure"));
5631 result
5632 });
5633 let result1 = handle1.await.expect("Task should complete");
5634 assert!(result1.is_err(), "First task should fail");
5635 assert!(
5636 !tracker.cancellation_token().is_cancelled(),
5637 "Should not be cancelled after 1 failure"
5638 );
5639
5640 // Task 2: Should fail but not trigger cancellation (different task, per-task failure count = 1)
5641 let handle2 = tracker.spawn(async {
5642 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Second failure"));
5643 result
5644 });
5645 let result2 = handle2.await.expect("Task should complete");
5646 assert!(result2.is_err(), "Second task should fail");
5647 assert!(
5648 !tracker.cancellation_token().is_cancelled(),
5649 "Should NOT be cancelled - per-task context prevents global accumulation"
5650 );
5651
5652 println!("Policy global failure count: {}", policy.failure_count());
5653 assert_eq!(
5654 policy.failure_count(),
5655 2,
5656 "Policy should have counted 2 failures globally (for backwards compatibility)"
5657 );
5658 }
5659
5660 /// Test demonstrating that per-task error context solves the global failure tracking problem
5661 ///
5662 /// This test shows that with OnErrorContext, each task has independent failure tracking.
5663 #[rstest]
5664 #[tokio::test]
5665 async fn test_per_task_context_limitation_demo(unlimited_scheduler: Arc<UnlimitedScheduler>) {
5666 // Create a policy that should allow 2 failures per task
5667 let policy = ThresholdCancelPolicy::with_threshold(2);
5668 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
5669
5670 // Task 1: Fails once (per-task failure count = 1, below threshold)
5671 let handle1 = tracker.spawn(async {
5672 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 1 failure"));
5673 result
5674 });
5675 let result1 = handle1.await.expect("Task should complete");
5676 assert!(result1.is_err(), "Task 1 should fail");
5677
5678 // Task 2: Also fails once (per-task failure count = 1, below threshold)
5679 // With per-task context, this doesn't interfere with Task 1's failure budget
5680 let handle2 = tracker.spawn(async {
5681 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 2 failure"));
5682 result
5683 });
5684 let result2 = handle2.await.expect("Task should complete");
5685 assert!(result2.is_err(), "Task 2 should fail");
5686
5687 // With per-task context, tracker should NOT be cancelled
5688 // Each task failed only once, which is below the threshold of 2
5689 assert!(
5690 !tracker.cancellation_token().is_cancelled(),
5691 "Tracker should NOT be cancelled - per-task context prevents premature cancellation"
5692 );
5693
5694 println!("Global failure count: {}", policy.failure_count());
5695 assert_eq!(
5696 policy.failure_count(),
5697 2,
5698 "Global policy counted 2 failures across different tasks"
5699 );
5700
        // This confirms the fix: failures from different tasks no longer share a retry budget,
        // so each task's failures are judged against the threshold independently.
5703 }
5704
5705 /// Test allow_continuation() policy method with attempt-based logic
5706 ///
5707 /// This test verifies that:
5708 /// 1. Policies can conditionally allow/reject continuations based on context
5709 /// 2. When allow_continuation() returns false, FailedWithContinuation is ignored
5710 /// 3. When allow_continuation() returns true, FailedWithContinuation is processed normally
5711 /// 4. The policy's decision takes precedence over task-provided continuations
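    ///
    /// A minimal sketch of the hooks involved, mirroring the AttemptLimitPolicy defined inside
    /// this test (`FailFastAfter3` is a hypothetical name):
    ///
    /// ```rust,ignore
    /// #[derive(Debug)]
    /// struct FailFastAfter3;
    ///
    /// impl OnErrorPolicy for FailFastAfter3 {
    ///     fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
    ///         Arc::new(FailFastAfter3)
    ///     }
    ///     fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
    ///         None // stateless policy
    ///     }
    ///     fn allow_continuation(&self, _error: &anyhow::Error, ctx: &OnErrorContext) -> bool {
    ///         ctx.attempt_count <= 3 // veto continuations after the third attempt
    ///     }
    ///     fn on_error(&self, _error: &anyhow::Error, _ctx: &mut OnErrorContext) -> ErrorResponse {
    ///         ErrorResponse::Fail
    ///     }
    /// }
    /// ```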
5712 #[rstest]
5713 #[case(
5714 3,
5715 true,
5716 "Policy allows continuations up to 3 attempts - should succeed"
5717 )]
5718 #[case(
5719 2,
5720 true,
5721 "Policy allows continuations up to 2 attempts - should succeed"
5722 )]
5723 #[case(0, false, "Policy allows 0 attempts - should fail immediately")]
5724 #[tokio::test]
5725 async fn test_allow_continuation_policy_control(
5726 unlimited_scheduler: Arc<UnlimitedScheduler>,
5727 #[case] max_attempts: u32,
5728 #[case] should_succeed: bool,
5729 #[case] description: &str,
5730 ) {
5731 // Policy that allows continuations only up to max_attempts
5732 #[derive(Debug)]
5733 struct AttemptLimitPolicy {
5734 max_attempts: u32,
5735 }
5736
5737 impl OnErrorPolicy for AttemptLimitPolicy {
5738 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
5739 Arc::new(AttemptLimitPolicy {
5740 max_attempts: self.max_attempts,
5741 })
5742 }
5743
5744 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
5745 None // Stateless policy
5746 }
5747
5748 fn allow_continuation(&self, _error: &anyhow::Error, context: &OnErrorContext) -> bool {
5749 context.attempt_count <= self.max_attempts
5750 }
5751
5752 fn on_error(
5753 &self,
5754 _error: &anyhow::Error,
5755 _context: &mut OnErrorContext,
5756 ) -> ErrorResponse {
5757 ErrorResponse::Fail // Just fail when continuations are not allowed
5758 }
5759 }
5760
5761 let policy = Arc::new(AttemptLimitPolicy { max_attempts });
5762 let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
5763 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5764
5765 // Continuation that always tries to retry
5766 #[derive(Debug)]
5767 struct AlwaysRetryContinuation {
5768 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5769 attempt: u32,
5770 }
5771
5772 #[async_trait]
5773 impl Continuation for AlwaysRetryContinuation {
5774 async fn execute(
5775 &self,
5776 _cancel_token: CancellationToken,
5777 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5778 self.log
5779 .lock()
5780 .await
5781 .push(format!("continuation_attempt_{}", self.attempt));
5782
5783 if self.attempt >= 2 {
5784 // Success after 2 attempts
5785 TaskExecutionResult::Success(Box::new("final_success".to_string()))
5786 } else {
5787 // Try to continue with another continuation
5788 let next_continuation = Arc::new(AlwaysRetryContinuation {
5789 log: self.log.clone(),
5790 attempt: self.attempt + 1,
5791 });
5792 let error = anyhow::anyhow!("Continuation attempt {} failed", self.attempt);
5793 TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
5794 error,
5795 next_continuation,
5796 ))
5797 }
5798 }
5799 }
5800
5801 // Task that immediately fails with a continuation
5802 let initial_continuation = Arc::new(AlwaysRetryContinuation {
5803 log: execution_log.clone(),
5804 attempt: 1,
5805 });
5806
5807 let log_for_task = execution_log.clone();
5808 let handle = tracker.spawn(async move {
5809 log_for_task
5810 .lock()
5811 .await
5812 .push("initial_task_failure".to_string());
5813 let error = anyhow::anyhow!("Initial task failure");
5814 let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
5815 error,
5816 initial_continuation,
5817 ));
5818 result
5819 });
5820
5821 let result = handle.await.expect("Task should complete");
5822
5823 if should_succeed {
5824 assert!(result.is_ok(), "{}: Task should succeed", description);
5825 assert_eq!(
5826 tracker.metrics().success(),
5827 1,
5828 "{}: Should have 1 success",
5829 description
5830 );
5831
5832 // Should have executed multiple continuations
5833 let log = execution_log.lock().await;
5834 assert!(
5835 log.len() > 2,
5836 "{}: Should have multiple log entries",
5837 description
5838 );
5839 assert!(log.contains(&"continuation_attempt_1".to_string()));
5840 } else {
5841 assert!(result.is_err(), "{}: Task should fail", description);
5842 assert_eq!(
5843 tracker.metrics().failed(),
5844 1,
5845 "{}: Should have 1 failure",
5846 description
5847 );
5848
5849 // Should have stopped early due to policy rejection
5850 let log = execution_log.lock().await;
5851 assert_eq!(
5852 log.len(),
5853 1,
5854 "{}: Should only have initial task entry",
5855 description
5856 );
5857 assert_eq!(log[0], "initial_task_failure");
5858 // Should NOT contain continuation attempts because policy rejected them
5859 assert!(
5860 !log.iter()
5861 .any(|entry| entry.contains("continuation_attempt")),
5862 "{}: Should not have continuation attempts, but got: {:?}",
5863 description,
5864 *log
5865 );
5866 }
5867 }
5868
5869 /// Test TaskHandle functionality
5870 ///
5871 /// This test verifies that:
5872 /// 1. TaskHandle can be awaited like a JoinHandle
5873 /// 2. TaskHandle provides access to the task's cancellation token
5874 /// 3. Individual task cancellation works correctly
5875 /// 4. TaskHandle methods (abort, is_finished) work as expected
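    ///
    /// Minimal usage sketch (every call below also appears in the test body; `do_work` is a
    /// hypothetical async fn):
    ///
    /// ```rust,ignore
    /// let handle = tracker.spawn_cancellable(|cancel| async move {
    ///     tokio::select! {
    ///         _ = do_work() => CancellableTaskResult::Ok(()),
    ///         _ = cancel.cancelled() => CancellableTaskResult::Cancelled,
    ///     }
    /// });
    /// assert!(!handle.is_finished());
    /// handle.cancellation_token().cancel();     // cancel just this task
    /// let result = handle.await.expect("join"); // awaits like a JoinHandle
    /// assert!(result.unwrap_err().is_cancellation());
    /// ```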
5876 #[tokio::test]
5877 async fn test_task_handle_functionality() {
5878 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
5879
5880 // Test basic functionality - TaskHandle can be awaited
5881 let handle1 = tracker.spawn(async {
5882 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
5883 Ok("completed".to_string())
5884 });
5885
5886 // Verify we can access the cancellation token
5887 let cancel_token = handle1.cancellation_token();
5888 assert!(
5889 !cancel_token.is_cancelled(),
5890 "Token should not be cancelled initially"
5891 );
5892
5893 // Await the task
5894 let result1 = handle1.await.expect("Task should complete");
5895 assert!(result1.is_ok(), "Task should succeed");
5896 assert_eq!(result1.unwrap(), "completed");
5897
5898 // Test individual task cancellation
5899 let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
5900 tokio::select! {
5901 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
5902 CancellableTaskResult::Ok("task_was_not_cancelled".to_string())
5903 },
5904 _ = cancel_token.cancelled() => {
5905 CancellableTaskResult::Cancelled
                },
            }
5909 });
5910
5911 let cancel_token2 = handle2.cancellation_token();
5912
5913 // Cancel this specific task
5914 cancel_token2.cancel();
5915
5916 // The task should be cancelled
5917 let result2 = handle2.await.expect("Task should complete");
5918 assert!(result2.is_err(), "Task should be cancelled");
5919 assert!(
5920 result2.unwrap_err().is_cancellation(),
5921 "Should be a cancellation error"
5922 );
5923
5924 // Test that other tasks are not affected
5925 let handle3 = tracker.spawn(async { Ok("not_cancelled".to_string()) });
5926
5927 let result3 = handle3.await.expect("Task should complete");
5928 assert!(result3.is_ok(), "Other tasks should not be affected");
5929 assert_eq!(result3.unwrap(), "not_cancelled");
5930
5931 // Test abort functionality
5932 let handle4 = tracker.spawn(async {
5933 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
5934 Ok("should_be_aborted".to_string())
5935 });
5936
5937 // Check is_finished before abort
5938 assert!(!handle4.is_finished(), "Task should not be finished yet");
5939
5940 // Abort the task
5941 handle4.abort();
5942
5943 // Task should be aborted (JoinError)
5944 let result4 = handle4.await;
5945 assert!(result4.is_err(), "Aborted task should return JoinError");
5946
5947 // Verify metrics
5948 assert_eq!(
5949 tracker.metrics().success(),
5950 2,
5951 "Should have 2 successful tasks"
5952 );
5953 assert_eq!(
5954 tracker.metrics().cancelled(),
5955 1,
5956 "Should have 1 cancelled task"
5957 );
5958 // Note: aborted tasks don't count as cancelled in our metrics
5959 }
5960
5961 /// Test TaskHandle with cancellable tasks
5962 #[tokio::test]
5963 async fn test_task_handle_with_cancellable_tasks() {
5964 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
5965
5966 // Test cancellable task with TaskHandle
5967 let handle = tracker.spawn_cancellable(|cancel_token| async move {
5968 tokio::select! {
5969 _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
5970 CancellableTaskResult::Ok("completed".to_string())
5971 },
5972 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
5973 }
5974 });
5975
5976 // Verify we can access the task's individual cancellation token
5977 let task_cancel_token = handle.cancellation_token();
5978 assert!(
5979 !task_cancel_token.is_cancelled(),
5980 "Task token should not be cancelled initially"
5981 );
5982
5983 // Let the task complete normally
5984 let result = handle.await.expect("Task should complete");
5985 assert!(result.is_ok(), "Task should succeed");
5986 assert_eq!(result.unwrap(), "completed");
5987
5988 // Test cancellation of cancellable task
5989 let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
5990 tokio::select! {
5991 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
5992 CancellableTaskResult::Ok("should_not_complete".to_string())
5993 },
5994 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
5995 }
5996 });
5997
5998 // Cancel the specific task
5999 handle2.cancellation_token().cancel();
6000
6001 let result2 = handle2.await.expect("Task should complete");
6002 assert!(result2.is_err(), "Task should be cancelled");
6003 assert!(
6004 result2.unwrap_err().is_cancellation(),
6005 "Should be a cancellation error"
6006 );
6007
6008 // Verify metrics
6009 assert_eq!(
6010 tracker.metrics().success(),
6011 1,
6012 "Should have 1 successful task"
6013 );
6014 assert_eq!(
6015 tracker.metrics().cancelled(),
6016 1,
6017 "Should have 1 cancelled task"
6018 );
6019 }
6020
6021 /// Test FailedWithContinuation helper methods
6022 ///
6023 /// This test verifies that:
6024 /// 1. from_fn creates working continuations from simple closures
6025 /// 2. from_cancellable creates working continuations from cancellable closures
6026 /// 3. Both helpers integrate correctly with the task execution system
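    ///
    /// Sketch of the difference between the two helpers (closures are illustrative only):
    ///
    /// ```rust,ignore
    /// // from_fn: the retry closure ignores cancellation
    /// let err = FailedWithContinuation::from_fn(anyhow::anyhow!("boom"), || async {
    ///     Ok("recovered".to_string())
    /// });
    ///
    /// // from_cancellable: the retry closure receives a CancellationToken
    /// let err = FailedWithContinuation::from_cancellable(anyhow::anyhow!("boom"), |cancel| async move {
    ///     if cancel.is_cancelled() {
    ///         Err(anyhow::anyhow!("cancelled before retry"))
    ///     } else {
    ///         Ok("recovered".to_string())
    ///     }
    /// });
    /// ```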
6027 #[tokio::test]
6028 async fn test_continuation_helpers() {
6029 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
6030
6031 // Test from_fn helper
6032 let handle1 = tracker.spawn(async {
6033 let error =
6034 FailedWithContinuation::from_fn(anyhow::anyhow!("Initial failure"), || async {
6035 Ok("Success from from_fn".to_string())
6036 });
6037 let result: Result<String, anyhow::Error> = Err(error);
6038 result
6039 });
6040
6041 let result1 = handle1.await.expect("Task should complete");
6042 assert!(
6043 result1.is_ok(),
6044 "Task with from_fn continuation should succeed"
6045 );
6046 assert_eq!(result1.unwrap(), "Success from from_fn");
6047
6048 // Test from_cancellable helper
6049 let handle2 = tracker.spawn(async {
6050 let error = FailedWithContinuation::from_cancellable(
6051 anyhow::anyhow!("Initial failure"),
6052 |_cancel_token| async move { Ok("Success from from_cancellable".to_string()) },
6053 );
6054 let result: Result<String, anyhow::Error> = Err(error);
6055 result
6056 });
6057
6058 let result2 = handle2.await.expect("Task should complete");
6059 assert!(
6060 result2.is_ok(),
6061 "Task with from_cancellable continuation should succeed"
6062 );
6063 assert_eq!(result2.unwrap(), "Success from from_cancellable");
6064
6065 // Verify metrics
6066 assert_eq!(
6067 tracker.metrics().success(),
6068 2,
6069 "Should have 2 successful tasks"
6070 );
6071 assert_eq!(tracker.metrics().failed(), 0, "Should have 0 failed tasks");
6072 }
6073
6074 /// Test should_reschedule() policy method with mock scheduler tracking
6075 ///
6076 /// This test verifies that:
6077 /// 1. When should_reschedule() returns false, the guard is reused (efficient)
6078 /// 2. When should_reschedule() returns true, the guard is re-acquired through scheduler
6079 /// 3. The scheduler's acquire_execution_slot is called the expected number of times
6080 /// 4. Rescheduling works for both task-driven and policy-driven continuations
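    ///
    /// Sketch of the hook under test (mirrors RescheduleTestPolicy below):
    ///
    /// ```rust,ignore
    /// fn should_reschedule(&self, _error: &anyhow::Error, _ctx: &OnErrorContext) -> bool {
    ///     // true  => drop the current guard and go back through the scheduler
    ///     // false => reuse the already-held execution slot for the continuation
    ///     true
    /// }
    /// ```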
6081 #[rstest]
6082 #[case(false, 1, "Policy requests no rescheduling - should reuse guard")]
6083 #[case(true, 2, "Policy requests rescheduling - should re-acquire guard")]
6084 #[tokio::test]
6085 async fn test_should_reschedule_policy_control(
6086 #[case] should_reschedule: bool,
6087 #[case] expected_acquisitions: u32,
6088 #[case] description: &str,
6089 ) {
6090 // Mock scheduler that tracks acquisition calls
6091 #[derive(Debug)]
6092 struct MockScheduler {
6093 acquisition_count: Arc<AtomicU32>,
6094 }
6095
6096 impl MockScheduler {
6097 fn new() -> Self {
6098 Self {
6099 acquisition_count: Arc::new(AtomicU32::new(0)),
6100 }
6101 }
6102
6103 fn acquisition_count(&self) -> u32 {
6104 self.acquisition_count.load(Ordering::Relaxed)
6105 }
6106 }
6107
6108 #[async_trait]
6109 impl TaskScheduler for MockScheduler {
6110 async fn acquire_execution_slot(
6111 &self,
6112 _cancel_token: CancellationToken,
6113 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
6114 self.acquisition_count.fetch_add(1, Ordering::Relaxed);
6115 SchedulingResult::Execute(Box::new(UnlimitedGuard))
6116 }
6117 }
6118
6119 // Policy that controls rescheduling behavior
6120 #[derive(Debug)]
6121 struct RescheduleTestPolicy {
6122 should_reschedule: bool,
6123 }
6124
6125 impl OnErrorPolicy for RescheduleTestPolicy {
6126 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
6127 Arc::new(RescheduleTestPolicy {
6128 should_reschedule: self.should_reschedule,
6129 })
6130 }
6131
6132 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
6133 None // Stateless policy
6134 }
6135
6136 fn allow_continuation(
6137 &self,
6138 _error: &anyhow::Error,
6139 _context: &OnErrorContext,
6140 ) -> bool {
6141 true // Always allow continuations for this test
6142 }
6143
6144 fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
6145 self.should_reschedule
6146 }
6147
6148 fn on_error(
6149 &self,
6150 _error: &anyhow::Error,
6151 _context: &mut OnErrorContext,
6152 ) -> ErrorResponse {
6153 ErrorResponse::Fail // Just fail when continuations are not allowed
6154 }
6155 }
6156
6157 let mock_scheduler = Arc::new(MockScheduler::new());
6158 let policy = Arc::new(RescheduleTestPolicy { should_reschedule });
6159 let tracker = TaskTracker::new(mock_scheduler.clone(), policy).unwrap();
6160 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
6161
6162 // Simple continuation that succeeds on second attempt
6163 #[derive(Debug)]
6164 struct SimpleRetryContinuation {
6165 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6166 }
6167
6168 #[async_trait]
6169 impl Continuation for SimpleRetryContinuation {
6170 async fn execute(
6171 &self,
6172 _cancel_token: CancellationToken,
6173 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
6174 self.log
6175 .lock()
6176 .await
6177 .push("continuation_executed".to_string());
6178
6179 // Succeed immediately
6180 TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
6181 }
6182 }
6183
6184 // Task that fails with a continuation
6185 let continuation = Arc::new(SimpleRetryContinuation {
6186 log: execution_log.clone(),
6187 });
6188
6189 let log_for_task = execution_log.clone();
6190 let handle = tracker.spawn(async move {
6191 log_for_task
6192 .lock()
6193 .await
6194 .push("initial_task_failure".to_string());
6195 let error = anyhow::anyhow!("Initial task failure");
6196 let result: Result<String, anyhow::Error> =
6197 Err(FailedWithContinuation::into_anyhow(error, continuation));
6198 result
6199 });
6200
6201 let result = handle.await.expect("Task should complete");
6202
6203 // Task should succeed regardless of rescheduling behavior
6204 assert!(result.is_ok(), "{}: Task should succeed", description);
6205 assert_eq!(
6206 tracker.metrics().success(),
6207 1,
6208 "{}: Should have 1 success",
6209 description
6210 );
6211
6212 // Verify the execution log
6213 let log = execution_log.lock().await;
6214 assert_eq!(
6215 log.len(),
6216 2,
6217 "{}: Should have initial task + continuation",
6218 description
6219 );
6220 assert_eq!(log[0], "initial_task_failure");
6221 assert_eq!(log[1], "continuation_executed");
6222
6223 // Most importantly: verify the scheduler acquisition count
6224 let actual_acquisitions = mock_scheduler.acquisition_count();
6225 assert_eq!(
6226 actual_acquisitions, expected_acquisitions,
6227 "{}: Expected {} scheduler acquisitions, got {}",
6228 description, expected_acquisitions, actual_acquisitions
6229 );
6230 }
6231
6232 /// Test continuation loop with custom action policies
6233 ///
6234 /// This tests that custom error actions can also provide continuations
6235 /// and that the loop behavior works correctly with policy-provided continuations
6236 ///
6237 /// NOTE: Uses fresh policy/action instances to avoid global state interference.
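    ///
    /// Sketch of the policy-driven path (mirrors RetryAction below): the policy answers
    /// ErrorResponse::Custom(action), and the action may hand back a continuation
    /// (`MyAction` / `MyContinuation` are hypothetical names):
    ///
    /// ```rust,ignore
    /// #[async_trait]
    /// impl OnErrorAction for MyAction {
    ///     async fn execute(
    ///         &self,
    ///         _error: &anyhow::Error,
    ///         _task_id: TaskId,
    ///         _attempt_count: u32,
    ///         _context: &TaskExecutionContext,
    ///     ) -> ActionResult {
    ///         ActionResult::Continue { continuation: Arc::new(MyContinuation) }
    ///     }
    /// }
    /// ```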
6238 #[rstest]
6239 #[case(1, true, "Custom action with 1 retry should succeed")]
6240 #[case(3, true, "Custom action with 3 retries should succeed")]
6241 #[tokio::test]
6242 async fn test_continuation_loop_with_custom_action_policy(
6243 unlimited_scheduler: Arc<UnlimitedScheduler>,
6244 #[case] max_retries: u32,
6245 #[case] should_succeed: bool,
6246 #[case] description: &str,
6247 ) {
6248 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
6249 let retry_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
6250
6251 // Custom action that provides continuations up to max_retries
6252 #[derive(Debug)]
6253 struct RetryAction {
6254 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6255 retry_count: Arc<std::sync::atomic::AtomicU32>,
6256 max_retries: u32,
6257 }
6258
6259 #[async_trait]
6260 impl OnErrorAction for RetryAction {
6261 async fn execute(
6262 &self,
6263 _error: &anyhow::Error,
6264 _task_id: TaskId,
6265 _attempt_count: u32,
6266 _context: &TaskExecutionContext,
6267 ) -> ActionResult {
6268 let current_retry = self
6269 .retry_count
6270 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
6271 + 1;
6272 self.log
6273 .lock()
6274 .await
6275 .push(format!("custom_action_retry_{}", current_retry));
6276
6277 if current_retry <= self.max_retries {
6278 // Provide a continuation that succeeds if this is the final retry
6279 #[derive(Debug)]
6280 struct RetryContinuation {
6281 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6282 retry_number: u32,
6283 max_retries: u32,
6284 }
6285
6286 #[async_trait]
6287 impl Continuation for RetryContinuation {
6288 async fn execute(
6289 &self,
6290 _cancel_token: CancellationToken,
6291 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>
6292 {
6293 self.log
6294 .lock()
6295 .await
6296 .push(format!("retry_continuation_{}", self.retry_number));
6297
6298 if self.retry_number >= self.max_retries {
6299 // Final retry succeeds
6300 TaskExecutionResult::Success(Box::new(format!(
6301 "success_after_{}_retries",
6302 self.retry_number
6303 )))
6304 } else {
6305 // Still need more retries, fail with regular error (not FailedWithContinuation)
6306 // This will trigger the custom action again
6307 TaskExecutionResult::Error(anyhow::anyhow!(
6308 "Retry {} still failing",
6309 self.retry_number
6310 ))
6311 }
6312 }
6313 }
6314
6315 let continuation = Arc::new(RetryContinuation {
6316 log: self.log.clone(),
6317 retry_number: current_retry,
6318 max_retries: self.max_retries,
6319 });
6320
6321 ActionResult::Continue { continuation }
6322 } else {
6323 // Exceeded max retries, cancel
6324 ActionResult::Shutdown
6325 }
6326 }
6327 }
6328
6329 // Custom policy that uses the retry action
6330 #[derive(Debug)]
6331 struct CustomRetryPolicy {
6332 action: Arc<RetryAction>,
6333 }
6334
6335 impl OnErrorPolicy for CustomRetryPolicy {
6336 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
6337 Arc::new(CustomRetryPolicy {
6338 action: self.action.clone(),
6339 })
6340 }
6341
6342 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
6343 None // Stateless policy - no heap allocation
6344 }
6345
6346 fn on_error(
6347 &self,
6348 _error: &anyhow::Error,
6349 _context: &mut OnErrorContext,
6350 ) -> ErrorResponse {
6351 ErrorResponse::Custom(Box::new(RetryAction {
6352 log: self.action.log.clone(),
6353 retry_count: self.action.retry_count.clone(),
6354 max_retries: self.action.max_retries,
6355 }))
6356 }
6357 }
6358
6359 let action = Arc::new(RetryAction {
6360 log: execution_log.clone(),
6361 retry_count: retry_count.clone(),
6362 max_retries,
6363 });
6364 let policy = Arc::new(CustomRetryPolicy { action });
6365 let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
6366
6367 // Task that always fails with regular error (not FailedWithContinuation)
        let log_for_task = execution_log.clone();
        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("original_task_failed".to_string());
            let result: Result<String, anyhow::Error> =
                Err(anyhow::anyhow!("Original task failure"));
            result
        });

        // Execute and verify results
        let result = handle.await.expect("Task should complete");

        if should_succeed {
            assert!(result.is_ok(), "{}: Task should succeed", description);
            assert_eq!(
                tracker.metrics().success(),
                1,
                "{}: Should have 1 success",
                description
            );

            // Verify the retry sequence
            let log = execution_log.lock().await;
            let expected_entries = 1 + (max_retries * 2); // original + (action + continuation) per retry
            assert_eq!(
                log.len(),
                expected_entries as usize,
                "{}: Should have {} log entries",
                description,
                expected_entries
            );

            assert_eq!(
                retry_count.load(std::sync::atomic::Ordering::Relaxed),
                max_retries,
                "{}: Should have made {} retry attempts",
                description,
                max_retries
            );
        } else {
            assert!(result.is_err(), "{}: Task should fail", description);
            assert!(
                tracker.cancellation_token().is_cancelled(),
                "{}: Should be cancelled",
                description
            );

            // The action fires once more past max_retries before shutting down,
            // so the observed count exceeds max_retries.
            let final_retry_count = retry_count.load(std::sync::atomic::Ordering::Relaxed);
            assert!(
                final_retry_count > max_retries,
                "{}: Should have exceeded max_retries ({}), got {}",
                description,
                max_retries,
                final_retry_count
            );
        }
    }
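
    // Illustrative sketch, not used by any test in this module: the smallest
    // useful `Continuation` is one that unconditionally recovers with a canned
    // value. It mirrors the `RetryContinuation` above minus the retry
    // bookkeeping; the name `AlwaysRecoverContinuation` is illustrative only.
    #[allow(dead_code)]
    #[derive(Debug)]
    struct AlwaysRecoverContinuation;

    #[async_trait]
    impl Continuation for AlwaysRecoverContinuation {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            // Returning Success ends the retry flow: a policy handing this out
            // via `ActionResult::Continue { continuation }` would turn the
            // original failure into a recovered result.
            TaskExecutionResult::Success(Box::new("recovered".to_string()))
        }
    }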

    /// Test mixed continuation sources (task-driven + policy-driven)
    ///
    /// This test verifies that both task-provided continuations and policy-provided
    /// continuations can work together in the same execution flow. (A minimal,
    /// non-test sketch of attaching a recovering continuation to an error follows
    /// this test.)
    #[rstest]
    #[tokio::test]
    async fn test_mixed_continuation_sources(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        log_policy: Arc<LogOnlyPolicy>,
    ) {
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
        let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();

        // Task that provides a continuation, which then fails with a regular error
        let log_for_task = execution_log.clone();
        let log_for_continuation = execution_log.clone();

        #[derive(Debug)]
        struct MixedContinuation {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
        }

        #[async_trait]
        impl Continuation for MixedContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                self.log
                    .lock()
                    .await
                    .push("task_continuation_executed".to_string());
                // This continuation fails with a regular error (not FailedWithContinuation),
                // so it is handled by the policy (LogOnlyPolicy only logs; the task still fails)
                TaskExecutionResult::Error(anyhow::anyhow!("Task continuation failed"))
            }
        }

        let continuation = Arc::new(MixedContinuation {
            log: log_for_continuation,
        });

        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("original_task_executed".to_string());

            // Task provides continuation
            let error = anyhow::anyhow!("Original task failed");
            let result: Result<String, anyhow::Error> =
                Err(FailedWithContinuation::into_anyhow(error, continuation));
            result
        });

        // Execute - should fail because continuation fails and LogOnlyPolicy just logs
        let result = handle.await.expect("Task should complete");
        assert!(
            result.is_err(),
            "Should fail because continuation fails and policy just logs"
        );

        // Verify execution sequence
        let log = execution_log.lock().await;
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "original_task_executed");
        assert_eq!(log[1], "task_continuation_executed");

        // Verify metrics - should show failure from continuation
        assert_eq!(tracker.metrics().success(), 0);
        assert_eq!(tracker.metrics().failed(), 1);
    }
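
    // Illustrative sketch, not a test: the task-driven half of the mixed flow
    // above, condensed into a helper that attaches a recovery path to an error
    // with `FailedWithContinuation::into_anyhow`, exactly as the closure passed
    // to `tracker.spawn` does in the test above. It reuses the
    // `AlwaysRecoverContinuation` sketched earlier, so a task returning this
    // error under `LogOnlyPolicy` would be expected to end in success rather
    // than the failure asserted above. The helper name `fail_with_recovery` is
    // hypothetical.
    #[allow(dead_code)]
    fn fail_with_recovery() -> Result<String, anyhow::Error> {
        let error = anyhow::anyhow!("work failed; recovery attached");
        Err(FailedWithContinuation::into_anyhow(
            error,
            Arc::new(AlwaysRecoverContinuation),
        ))
    }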

    /// Debug test to understand the threshold policy behavior in the retry loop.
    /// It prints diagnostics only and makes no assertions.
    #[rstest]
    #[tokio::test]
    async fn debug_threshold_policy_in_retry_loop(unlimited_scheduler: Arc<UnlimitedScheduler>) {
        let policy = ThresholdCancelPolicy::with_threshold(2);
        let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();

        // Simple continuation that always fails with a regular error
        #[derive(Debug)]
        struct AlwaysFailContinuation {
            attempt: Arc<std::sync::atomic::AtomicU32>,
        }

        #[async_trait]
        impl Continuation for AlwaysFailContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                let attempt_num = self
                    .attempt
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                    + 1;
                println!("Continuation attempt {}", attempt_num);
                TaskExecutionResult::Error(anyhow::anyhow!(
                    "Continuation attempt {} failed",
                    attempt_num
                ))
            }
        }

        let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let continuation = Arc::new(AlwaysFailContinuation {
            attempt: attempt_counter.clone(),
        });

        let handle = tracker.spawn(async move {
            println!("Original task executing");
            let error = anyhow::anyhow!("Original task failed");
            let result: Result<String, anyhow::Error> =
                Err(FailedWithContinuation::into_anyhow(error, continuation));
            result
        });

        let result = handle.await.expect("Task should complete");
        println!("Final result: {:?}", result.is_ok());
        println!("Policy failure count: {}", policy.failure_count());
        println!(
            "Continuation attempts: {}",
            attempt_counter.load(std::sync::atomic::Ordering::Relaxed)
        );
        println!(
            "Tracker cancelled: {}",
            tracker.cancellation_token().is_cancelled()
        );
        println!(
            "Metrics: success={}, failed={}",
            tracker.metrics().success(),
            tracker.metrics().failed()
        );

        // No assertions here: this debug test exists to print the retry-loop
        // diagnostics above for manual inspection.
    }
}