// dynamo_runtime/utils/tasks/tracker.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! # Task Tracker - Hierarchical Task Management System
5//!
6//! A composable task management system with configurable scheduling and error handling policies.
7//! The TaskTracker enables controlled concurrent execution with proper resource management,
8//! cancellation semantics, and retry support.
9//!
10//! ## Architecture Overview
11//!
12//! The TaskTracker system is built around three core abstractions that compose together:
13//!
14//! ### 1. **TaskScheduler** - Resource Management
15//!
16//! Controls when and how tasks acquire execution resources (permits, slots, etc.).
17//! Schedulers implement resource acquisition with cancellation support:
18//!
19//! ```text
20//! TaskScheduler::acquire_execution_slot(cancel_token) -> SchedulingResult<ResourceGuard>
21//! ```
22//!
23//! - **Resource Acquisition**: Can be cancelled to avoid unnecessary allocation
24//! - **RAII Guards**: Resources are automatically released when guards are dropped
25//! - **Pluggable**: Different scheduling policies (unlimited, semaphore, rate-limited, etc.)
26//!
27//! ### 2. **OnErrorPolicy** - Error Handling
28//!
29//! Defines how the system responds to task failures:
30//!
31//! ```text
32//! OnErrorPolicy::on_error(error, task_id) -> ErrorResponse
33//! ```
34//!
35//! - **ErrorResponse::Fail**: Log error, fail this task
36//! - **ErrorResponse::Shutdown**: Shutdown tracker and all children
37//! - **ErrorResponse::Custom(action)**: Execute custom logic that can return:
38//! - `ActionResult::Fail`: Handle error and fail the task
39//! - `ActionResult::Shutdown`: Shutdown tracker
40//! - `ActionResult::Continue { continuation }`: Continue with provided task
41//!
42//! ### 3. **Execution Pipeline** - Task Orchestration
43//!
44//! The execution pipeline coordinates scheduling, execution, and error handling:
45//!
46//! ```text
47//! 1. Acquire resources (scheduler.acquire_execution_slot)
48//! 2. Create task future (only after resources acquired)
49//! 3. Execute task while holding guard (RAII pattern)
50//! 4. Handle errors through policy (with retry support for cancellable tasks)
51//! 5. Update metrics and release resources
52//! ```
53//!
54//! ## Key Design Principles
55//!
56//! ### **Separation of Concerns**
57//! - **Scheduling**: When/how to allocate resources
58//! - **Execution**: Running tasks with proper resource management
59//! - **Error Handling**: Responding to failures with configurable policies
60//!
61//! ### **Composability**
62//! - Schedulers and error policies are independent and can be mixed/matched
63//! - Custom policies can be implemented via traits
64//! - Execution pipeline handles the coordination automatically
65//!
66//! ### **Resource Safety**
67//! - Resources are acquired before task creation (prevents early execution)
68//! - RAII pattern ensures resources are always released
69//! - Cancellation is supported during resource acquisition, not during execution
70//!
71//! ### **Retry Support**
72//! - Regular tasks (`spawn`): Cannot be retried (future is consumed)
73//! - Cancellable tasks (`spawn_cancellable`): Support retry via `FnMut` closures
74//! - Error policies can provide next executors via `ActionResult::Continue`
75//!
76//! ## Task Types
77//!
78//! ### Regular Tasks
79//! ```rust
80//! # use dynamo_runtime::utils::tasks::tracker::*;
81//! # #[tokio::main]
82//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
83//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
84//! let handle = tracker.spawn(async { Ok(42) });
85//! # let _result = handle.await?;
86//! # Ok(())
87//! # }
88//! ```
89//! - Simple futures that run to completion
90//! - Cannot be retried (future is consumed on first execution)
91//! - Suitable for one-shot operations
92//!
93//! ### Cancellable Tasks
94//! ```rust
95//! # use dynamo_runtime::utils::tasks::tracker::*;
96//! # #[tokio::main]
97//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
98//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
99//! let handle = tracker.spawn_cancellable(|cancel_token| async move {
100//! // Task can check cancel_token.is_cancelled() or use tokio::select!
101//! CancellableTaskResult::Ok(42)
102//! });
103//! # let _result = handle.await?;
104//! # Ok(())
105//! # }
106//! ```
107//! - Receive a `CancellationToken` for cooperative cancellation
108//! - Support retry via `FnMut` closures (can be called multiple times)
109//! - Return `CancellableTaskResult` to indicate success/cancellation/error
110//!
111//! ## Hierarchical Structure
112//!
113//! TaskTrackers form parent-child relationships:
114//! - **Metrics**: Child metrics aggregate to parents
115//! - **Cancellation**: Parent cancellation propagates to children
116//! - **Independence**: Child cancellation doesn't affect parents
117//! - **Cleanup**: `join()` waits for all descendants bottom-up
118//!
119//! ## Metrics and Observability
120//!
121//! Built-in metrics track task lifecycle:
122//! - `issued`: Tasks submitted via spawn methods
123//! - `active`: Currently executing tasks
124//! - `success/failed/cancelled/rejected`: Final outcomes
125//! - `pending`: Issued but not completed (issued - completed)
126//! - `queued`: Waiting for resources (pending - active)
127//!
128//! Optional Prometheus integration available via `PrometheusTaskMetrics`.
129//!
130//! ## Usage Examples
131//!
132//! ### Basic Task Execution
133//! ```rust
134//! use dynamo_runtime::utils::tasks::tracker::*;
135//! use std::sync::Arc;
136//!
137//! # #[tokio::main]
138//! # async fn main() -> anyhow::Result<()> {
139//! let scheduler = SemaphoreScheduler::with_permits(10);
140//! let error_policy = LogOnlyPolicy::new();
141//! let tracker = TaskTracker::new(scheduler, error_policy)?;
142//!
143//! let handle = tracker.spawn(async { Ok(42) });
144//! let result = handle.await??;
145//! assert_eq!(result, 42);
146//! # Ok(())
147//! # }
148//! ```
149//!
150//! ### Cancellable Tasks with Retry
151//! ```rust
152//! # use dynamo_runtime::utils::tasks::tracker::*;
153//! # use std::sync::Arc;
154//! # #[tokio::main]
155//! # async fn main() -> anyhow::Result<()> {
156//! # let scheduler = SemaphoreScheduler::with_permits(10);
157//! # let error_policy = LogOnlyPolicy::new();
158//! # let tracker = TaskTracker::new(scheduler, error_policy)?;
159//! let handle = tracker.spawn_cancellable(|cancel_token| async move {
160//! tokio::select! {
161//! result = do_work() => CancellableTaskResult::Ok(result),
162//! _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
163//! }
164//! });
165//! # Ok(())
166//! # }
167//! # async fn do_work() -> i32 { 42 }
168//! ```
169//!
170//! ### Task-Driven Retry with Continuations
171//! ```rust
172//! # use dynamo_runtime::utils::tasks::tracker::*;
173//! # use anyhow::anyhow;
174//! # #[tokio::main]
175//! # async fn main() -> anyhow::Result<()> {
176//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
177//! let handle = tracker.spawn(async {
178//! // Simulate initial failure with retry logic
179//! let error = FailedWithContinuation::from_fn(
180//! anyhow!("Network timeout"),
181//! || async {
182//! println!("Retrying with exponential backoff...");
183//! tokio::time::sleep(std::time::Duration::from_millis(100)).await;
184//! Ok("Success after retry".to_string())
185//! }
186//! );
187//! let result: Result<String, anyhow::Error> = Err(error);
188//! result
189//! });
190//!
191//! let result = handle.await?;
192//! assert!(result.is_ok());
193//! # Ok(())
194//! # }
195//! ```
196//!
197//! ### Custom Error Policy with Continuation
198//! ```rust
199//! # use dynamo_runtime::utils::tasks::tracker::*;
200//! # use std::sync::Arc;
201//! # use async_trait::async_trait;
202//! # #[derive(Debug)]
203//! struct RetryPolicy {
204//! max_attempts: u32,
205//! }
206//!
207//! impl OnErrorPolicy for RetryPolicy {
208//! fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
209//! Arc::new(RetryPolicy { max_attempts: self.max_attempts })
210//! }
211//!
212//! fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
213//! None // Stateless policy
214//! }
215//!
216//! fn on_error(&self, _error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
217//! if context.attempt_count < self.max_attempts {
218//! ErrorResponse::Custom(Box::new(RetryAction))
219//! } else {
220//! ErrorResponse::Fail
221//! }
222//! }
223//! }
224//!
225//! # #[derive(Debug)]
226//! struct RetryAction;
227//!
228//! #[async_trait]
229//! impl OnErrorAction for RetryAction {
230//! async fn execute(
231//! &self,
232//! _error: &anyhow::Error,
233//! _task_id: TaskId,
234//! _attempt_count: u32,
235//! _context: &TaskExecutionContext,
236//! ) -> ActionResult {
237//! // In practice, you would create a continuation here
238//! ActionResult::Fail
239//! }
240//! }
241//! ```
242//!
243//! ## Future Extensibility
244//!
245//! The system is designed for extensibility. See the source code for detailed TODO comments
246//! describing additional policies that can be implemented:
247//! - **Scheduling**: Token bucket rate limiting, adaptive concurrency, memory-aware scheduling
248//! - **Error Handling**: Retry with backoff, circuit breakers, dead letter queues
249//!
250//! Each TODO comment includes complete implementation guidance with data structures,
251//! algorithms, and dependencies needed for future contributors.
252//!
253//! ## Hierarchical Organization
254//!
255//! ```rust
256//! use dynamo_runtime::utils::tasks::tracker::{
257//! TaskTracker, UnlimitedScheduler, ThresholdCancelPolicy, SemaphoreScheduler
258//! };
259//!
260//! # async fn example() -> anyhow::Result<()> {
261//! // Create root tracker with failure threshold policy
262//! let error_policy = ThresholdCancelPolicy::with_threshold(5);
263//! let root = TaskTracker::builder()
264//! .scheduler(UnlimitedScheduler::new())
265//! .error_policy(error_policy)
266//! .build()?;
267//!
268//! // Create child trackers for different components
269//! let api_handler = root.child_tracker()?; // Inherits policies
270//! let background_jobs = root.child_tracker()?;
271//!
272//! // Children can have custom policies
273//! let rate_limited = root.child_tracker_builder()
274//! .scheduler(SemaphoreScheduler::with_permits(2)) // Custom concurrency limit
275//! .build()?;
276//!
277//! // Tasks run independently but metrics roll up
278//! api_handler.spawn(async { Ok(()) });
279//! background_jobs.spawn(async { Ok(()) });
280//! rate_limited.spawn(async { Ok(()) });
281//!
282//! // Join all children hierarchically
283//! root.join().await;
284//! assert_eq!(root.metrics().success(), 3); // Sees all successes
285//! # Ok(())
286//! # }
287//! ```
288//!
289//! ## Policy Examples
290//!
291//! ```rust
292//! use dynamo_runtime::utils::tasks::tracker::{
293//! TaskTracker, CancelOnError, SemaphoreScheduler, ThresholdCancelPolicy
294//! };
295//!
296//! # async fn example() -> anyhow::Result<()> {
297//! // Pattern-based error cancellation
298//! let (error_policy, token) = CancelOnError::with_patterns(
299//! vec!["OutOfMemory".to_string(), "DeviceError".to_string()]
300//! );
301//! let simple = TaskTracker::builder()
302//! .scheduler(SemaphoreScheduler::with_permits(5))
303//! .error_policy(error_policy)
304//! .build()?;
305//!
306//! // Threshold-based cancellation with monitoring
307//! let scheduler = SemaphoreScheduler::with_permits(10); // Returns Arc<SemaphoreScheduler>
308//! let error_policy = ThresholdCancelPolicy::with_threshold(3); // Returns Arc<Policy>
309//!
310//! let advanced = TaskTracker::builder()
311//! .scheduler(scheduler)
312//! .error_policy(error_policy)
313//! .build()?;
314//!
315//! // Monitor cancellation externally
316//! if token.is_cancelled() {
317//! println!("Tracker cancelled due to failures");
318//! }
319//! # Ok(())
320//! # }
321//! ```
322//!
323//! ## Metrics and Observability
324//!
325//! ```rust
326//! use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
327//!
328//! # async fn example() -> anyhow::Result<()> {
329//! let tracker = std::sync::Arc::new(TaskTracker::builder()
330//! .scheduler(SemaphoreScheduler::with_permits(2)) // Only 2 concurrent tasks
331//! .error_policy(LogOnlyPolicy::new())
332//! .build()?);
333//!
334//! // Spawn multiple tasks
335//! for i in 0..5 {
336//! tracker.spawn(async move {
337//! tokio::time::sleep(std::time::Duration::from_millis(100)).await;
338//! Ok(i)
339//! });
340//! }
341//!
342//! // Check metrics
343//! let metrics = tracker.metrics();
344//! println!("Issued: {}", metrics.issued()); // 5 tasks issued
345//! println!("Active: {}", metrics.active()); // 2 tasks running (semaphore limit)
346//! println!("Queued: {}", metrics.queued()); // 3 tasks waiting in scheduler queue
347//! println!("Pending: {}", metrics.pending()); // 5 tasks not yet completed
348//!
349//! tracker.join().await;
350//! assert_eq!(metrics.success(), 5);
351//! assert_eq!(metrics.pending(), 0);
352//! # Ok(())
353//! # }
354//! ```
355//!
356//! ## Prometheus Integration
357//!
358//! ```rust
359//! use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
360//! use dynamo_runtime::metrics::MetricsRegistry;
361//!
362//! # async fn example(registry: &dyn MetricsRegistry) -> anyhow::Result<()> {
363//! // Root tracker with Prometheus metrics
364//! let tracker = TaskTracker::new_with_prometheus(
365//! SemaphoreScheduler::with_permits(10),
366//! LogOnlyPolicy::new(),
367//! registry,
368//! "my_component"
369//! )?;
370//!
371//! // Metrics automatically exported to Prometheus:
372//! // - my_component_tasks_issued_total
373//! // - my_component_tasks_success_total
374//! // - my_component_tasks_failed_total
375//! // - my_component_tasks_active
376//! // - my_component_tasks_queued
377//! # Ok(())
378//! # }
379//! ```
380
381use std::future::Future;
382use std::pin::Pin;
383use std::sync::Arc;
384use std::sync::atomic::{AtomicU64, Ordering};
385
386use crate::metrics::MetricsRegistry;
387use crate::metrics::prometheus_names::task_tracker;
388use anyhow::Result;
389use async_trait::async_trait;
390use derive_builder::Builder;
391use std::collections::HashSet;
392use std::sync::{Mutex, RwLock, Weak};
393use std::time::Duration;
394use thiserror::Error;
395use tokio::sync::Semaphore;
396use tokio::task::JoinHandle;
397use tokio_util::sync::CancellationToken;
398use tokio_util::task::TaskTracker as TokioTaskTracker;
399use tracing::{Instrument, debug, error, warn};
400use uuid::Uuid;
401
/// Error type for task execution results
///
/// This enum distinguishes between task cancellation and actual failures,
/// enabling proper metrics tracking and error handling.
///
/// `Failed` wraps the source [`anyhow::Error`] transparently, so its
/// `Display` output is that of the underlying error.
#[derive(Error, Debug)]
pub enum TaskError {
    /// Task was cancelled (either via cancellation token or tracker shutdown)
    #[error("Task was cancelled")]
    Cancelled,

    /// Task failed with an error
    #[error(transparent)]
    Failed(#[from] anyhow::Error),

    /// Cannot spawn task on a closed tracker
    #[error("Cannot spawn task on a closed tracker")]
    TrackerClosed,
}
420
421impl TaskError {
422 /// Check if this error represents a cancellation
423 ///
424 /// This is a convenience method for compatibility and readability.
425 pub fn is_cancellation(&self) -> bool {
426 matches!(self, TaskError::Cancelled)
427 }
428
429 /// Check if this error represents a failure
430 pub fn is_failure(&self) -> bool {
431 matches!(self, TaskError::Failed(_))
432 }
433
434 /// Get the underlying anyhow::Error for failures, or a cancellation error for cancellations
435 ///
436 /// This is provided for compatibility with existing code that expects anyhow::Error.
437 pub fn into_anyhow(self) -> anyhow::Error {
438 match self {
439 TaskError::Failed(err) => err,
440 TaskError::Cancelled => anyhow::anyhow!("Task was cancelled"),
441 TaskError::TrackerClosed => anyhow::anyhow!("Cannot spawn task on a closed tracker"),
442 }
443 }
444}
445
/// A handle to a spawned task that provides both join functionality and cancellation control
///
/// `TaskHandle` wraps a `JoinHandle` and provides access to the task's individual cancellation token.
/// This allows fine-grained control over individual tasks while maintaining the familiar `JoinHandle` API.
///
/// # Example
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::*;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
/// let handle = tracker.spawn(async {
///     tokio::time::sleep(std::time::Duration::from_millis(100)).await;
///     Ok(42)
/// });
///
/// // Access the task's cancellation token
/// let cancel_token = handle.cancellation_token();
///
/// // Can cancel the specific task
/// // cancel_token.cancel();
///
/// // Await the task like a normal JoinHandle
/// let result = handle.await?;
/// assert_eq!(result?, 42);
/// # Ok(())
/// # }
/// ```
pub struct TaskHandle<T> {
    // The wrapped tokio handle; resolves to the task's `Result<T, TaskError>`.
    join_handle: JoinHandle<Result<T, TaskError>>,
    // Per-task token (a child of the tracker's token, per `cancellation_token` docs),
    // so cancelling it affects only this task.
    cancel_token: CancellationToken,
}
478
479impl<T> TaskHandle<T> {
480 /// Create a new TaskHandle wrapping a JoinHandle and cancellation token
481 pub(crate) fn new(
482 join_handle: JoinHandle<Result<T, TaskError>>,
483 cancel_token: CancellationToken,
484 ) -> Self {
485 Self {
486 join_handle,
487 cancel_token,
488 }
489 }
490
491 /// Get the cancellation token for this specific task
492 ///
493 /// This token is a child of the tracker's cancellation token and can be used
494 /// to cancel just this individual task without affecting other tasks.
495 // FIXME: The doctest previously here failed intermittently and may
496 // indicate a bug in either the doctest example or the implementation.
497 pub fn cancellation_token(&self) -> &CancellationToken {
498 &self.cancel_token
499 }
500
501 /// Abort the task associated with this handle
502 ///
503 /// This is equivalent to calling `JoinHandle::abort()` and will cause the task
504 /// to be cancelled immediately without running any cleanup code.
505 pub fn abort(&self) {
506 self.join_handle.abort();
507 }
508
509 /// Check if the task associated with this handle has finished
510 ///
511 /// This is equivalent to calling `JoinHandle::is_finished()`.
512 pub fn is_finished(&self) -> bool {
513 self.join_handle.is_finished()
514 }
515}
516
517impl<T> std::future::Future for TaskHandle<T> {
518 type Output = Result<Result<T, TaskError>, tokio::task::JoinError>;
519
520 fn poll(
521 mut self: std::pin::Pin<&mut Self>,
522 cx: &mut std::task::Context<'_>,
523 ) -> std::task::Poll<Self::Output> {
524 std::pin::Pin::new(&mut self.join_handle).poll(cx)
525 }
526}
527
528impl<T> std::fmt::Debug for TaskHandle<T> {
529 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530 f.debug_struct("TaskHandle")
531 .field("join_handle", &"<JoinHandle>")
532 .field("cancel_token", &self.cancel_token)
533 .finish()
534 }
535}
536
/// Trait for continuation tasks that execute after a failure
///
/// This trait allows tasks to define what should happen next after a failure,
/// eliminating the need for complex type erasure and executor management.
/// Tasks implement this trait to provide clean continuation logic.
///
/// Implementors are carried behind `Arc<dyn Continuation>` inside
/// [`FailedWithContinuation`] errors.
#[async_trait]
pub trait Continuation: Send + Sync + std::fmt::Debug + std::any::Any {
    /// Execute the continuation task after a failure
    ///
    /// This method is called when a task fails and a continuation is provided.
    /// The implementation can perform retry logic, fallback operations,
    /// transformations, or any other follow-up action.
    ///
    /// The result is type-erased as `Box<dyn Any>` so continuations with any
    /// output type can share a single trait-object signature.
    async fn execute(
        &self,
        cancel_token: CancellationToken,
    ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>;
}
555
/// Error type that signals a task failed but provided a continuation
///
/// This error type contains a continuation task that can be executed as a follow-up.
/// The task defines its own continuation logic through the Continuation trait.
///
/// Usually created via [`FailedWithContinuation::from_fn`] or
/// [`FailedWithContinuation::from_cancellable`] and transported inside an
/// `anyhow::Error`; see [`FailedWithContinuationExt`] for extraction.
#[derive(Error, Debug)]
#[error("Task failed with continuation: {source}")]
pub struct FailedWithContinuation {
    /// The underlying error that caused the task to fail
    #[source]
    pub source: anyhow::Error,
    /// The continuation task for follow-up execution
    pub continuation: Arc<dyn Continuation + Send + Sync + 'static>,
}
569
570impl FailedWithContinuation {
571 /// Create a new FailedWithContinuation with a continuation task
572 ///
573 /// The continuation task defines its own execution logic through the Continuation trait.
574 pub fn new(
575 source: anyhow::Error,
576 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
577 ) -> Self {
578 Self {
579 source,
580 continuation,
581 }
582 }
583
584 /// Create a FailedWithContinuation and convert it to anyhow::Error
585 ///
586 /// This is a convenience method for tasks to easily return continuation errors.
587 pub fn into_anyhow(
588 source: anyhow::Error,
589 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
590 ) -> anyhow::Error {
591 anyhow::Error::new(Self::new(source, continuation))
592 }
593
594 /// Create a FailedWithContinuation from a simple async function (no cancellation support)
595 ///
596 /// This is a convenience method for creating continuation errors from simple async closures
597 /// that don't need to handle cancellation. The function will be executed when the
598 /// continuation is triggered.
599 ///
600 /// # Example
601 /// ```rust
602 /// # use dynamo_runtime::utils::tasks::tracker::*;
603 /// # use anyhow::anyhow;
604 /// # #[tokio::main]
605 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
606 /// let error = FailedWithContinuation::from_fn(
607 /// anyhow!("Initial task failed"),
608 /// || async {
609 /// println!("Retrying operation...");
610 /// Ok("retry_result".to_string())
611 /// }
612 /// );
613 /// # Ok(())
614 /// # }
615 /// ```
616 pub fn from_fn<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
617 where
618 F: Fn() -> Fut + Send + Sync + 'static,
619 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
620 T: Send + 'static,
621 {
622 let continuation = Arc::new(FnContinuation { f: Box::new(f) });
623 Self::into_anyhow(source, continuation)
624 }
625
626 /// Create a FailedWithContinuation from a cancellable async function
627 ///
628 /// This is a convenience method for creating continuation errors from async closures
629 /// that can handle cancellation. The function receives a CancellationToken
630 /// and should check it periodically for cooperative cancellation.
631 ///
632 /// # Example
633 /// ```rust
634 /// # use dynamo_runtime::utils::tasks::tracker::*;
635 /// # use anyhow::anyhow;
636 /// # #[tokio::main]
637 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
638 /// let error = FailedWithContinuation::from_cancellable(
639 /// anyhow!("Initial task failed"),
640 /// |cancel_token| async move {
641 /// if cancel_token.is_cancelled() {
642 /// return Err(anyhow!("Cancelled"));
643 /// }
644 /// println!("Retrying operation with cancellation support...");
645 /// Ok("retry_result".to_string())
646 /// }
647 /// );
648 /// # Ok(())
649 /// # }
650 /// ```
651 pub fn from_cancellable<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
652 where
653 F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
654 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
655 T: Send + 'static,
656 {
657 let continuation = Arc::new(CancellableFnContinuation { f: Box::new(f) });
658 Self::into_anyhow(source, continuation)
659 }
660}
661
/// Extension trait for extracting FailedWithContinuation from anyhow::Error
///
/// This trait provides methods to detect and extract continuation tasks
/// from the type-erased anyhow::Error system.
pub trait FailedWithContinuationExt {
    /// Extract a continuation task if this error contains one
    ///
    /// Returns a clone of the continuation `Arc` if the error is a
    /// FailedWithContinuation, None otherwise.
    fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>>;

    /// Check if this error has a continuation
    fn has_continuation(&self) -> bool;
}
676
677impl FailedWithContinuationExt for anyhow::Error {
678 fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>> {
679 // Try to downcast to FailedWithContinuation
680 if let Some(continuation_err) = self.downcast_ref::<FailedWithContinuation>() {
681 Some(continuation_err.continuation.clone())
682 } else {
683 None
684 }
685 }
686
687 fn has_continuation(&self) -> bool {
688 self.downcast_ref::<FailedWithContinuation>().is_some()
689 }
690}
691
/// Implementation of Continuation for simple async functions (no cancellation support)
struct FnContinuation<F> {
    // The stored closure. NOTE(review): the `Box` looks unnecessary since `F`
    // is `Sized` here (`f: F` would suffice) — confirm before changing, as the
    // `FailedWithContinuation` constructors build this with `Box::new(f)`.
    f: Box<F>,
}
696
697impl<F> std::fmt::Debug for FnContinuation<F> {
698 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
699 f.debug_struct("FnContinuation")
700 .field("f", &"<closure>")
701 .finish()
702 }
703}
704
705#[async_trait]
706impl<F, Fut, T> Continuation for FnContinuation<F>
707where
708 F: Fn() -> Fut + Send + Sync + 'static,
709 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
710 T: Send + 'static,
711{
712 async fn execute(
713 &self,
714 _cancel_token: CancellationToken,
715 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
716 match (self.f)().await {
717 Ok(result) => TaskExecutionResult::Success(Box::new(result)),
718 Err(error) => TaskExecutionResult::Error(error),
719 }
720 }
721}
722
/// Implementation of Continuation for cancellable async functions
struct CancellableFnContinuation<F> {
    // The stored closure, which takes a CancellationToken. NOTE(review): as
    // with FnContinuation, the `Box` appears redundant for a `Sized` `F` —
    // confirm before changing.
    f: Box<F>,
}
727
728impl<F> std::fmt::Debug for CancellableFnContinuation<F> {
729 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
730 f.debug_struct("CancellableFnContinuation")
731 .field("f", &"<closure>")
732 .finish()
733 }
734}
735
736#[async_trait]
737impl<F, Fut, T> Continuation for CancellableFnContinuation<F>
738where
739 F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
740 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
741 T: Send + 'static,
742{
743 async fn execute(
744 &self,
745 cancel_token: CancellationToken,
746 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
747 match (self.f)(cancel_token).await {
748 Ok(result) => TaskExecutionResult::Success(Box::new(result)),
749 Err(error) => TaskExecutionResult::Error(error),
750 }
751 }
752}
753
/// Common scheduling policies for task execution
///
/// These enums provide convenient access to built-in scheduling policies
/// without requiring manual construction of policy objects.
///
/// ## Cancellation Semantics
///
/// All schedulers follow the same cancellation behavior:
/// - Respect cancellation tokens before resource allocation (permits, etc.)
/// - Once task execution begins, always await completion
/// - Let tasks handle their own cancellation internally
#[derive(Debug, Clone)]
pub enum SchedulingPolicy {
    /// No concurrency limits - execute all tasks immediately
    Unlimited,
    /// Semaphore-based concurrency limiting; the `usize` is the number of
    /// permits, i.e. the maximum number of concurrently executing tasks
    Semaphore(usize),
    // TODO: Future scheduling policies to implement
    //
    // /// Token bucket rate limiting with burst capacity
    // /// Implementation: Use tokio::time::interval for refill, AtomicU64 for tokens.
    // /// acquire() decrements tokens, schedule() waits for refill if empty.
    // /// Burst allows temporary spikes above steady rate.
    // /// struct: { rate: f64, burst: usize, tokens: AtomicU64, last_refill: Mutex<Instant> }
    // /// Example: TokenBucket { rate: 10.0, burst: 5 } = 10 tasks/sec, burst up to 5
    // TokenBucket { rate: f64, burst: usize },
    //
    // /// Weighted fair scheduling across multiple priority classes
    // /// Implementation: Maintain separate VecDeque for each priority class.
    // /// Use weighted round-robin: serve N tasks from high, M from normal, etc.
    // /// Track deficit counters to ensure fairness over time.
    // /// struct: { queues: HashMap<String, VecDeque<Task>>, weights: Vec<(String, u32)> }
    // /// Example: WeightedFair { weights: vec![("high", 70), ("normal", 20), ("low", 10)] }
    // WeightedFair { weights: Vec<(String, u32)> },
    //
    // /// Memory-aware scheduling that limits tasks based on available memory
    // /// Implementation: Monitor system memory via /proc/meminfo or sysinfo crate.
    // /// Pause scheduling when available memory < threshold, resume when memory freed.
    // /// Use exponential backoff for memory checks to avoid overhead.
    // /// struct: { max_memory_mb: usize, check_interval: Duration, semaphore: Semaphore }
    // MemoryAware { max_memory_mb: usize },
    //
    // /// CPU-aware scheduling that adjusts concurrency based on CPU load
    // /// Implementation: Sample system load via sysinfo crate every N seconds.
    // /// Dynamically resize internal semaphore permits based on load average.
    // /// Use PID controller for smooth adjustments, avoid oscillation.
    // /// struct: { max_cpu_percent: f32, permits: Arc<Semaphore>, sampler: tokio::task }
    // CpuAware { max_cpu_percent: f32 },
    //
    // /// Adaptive scheduler that automatically adjusts concurrency based on performance
    // /// Implementation: Track task latency and throughput in sliding windows.
    // /// Increase permits if latency low & throughput stable, decrease if latency spikes.
    // /// Use additive increase, multiplicative decrease (AIMD) algorithm.
    // /// struct: { permits: AtomicUsize, latency_tracker: RingBuffer, throughput_tracker: RingBuffer }
    // Adaptive { initial_permits: usize },
    //
    // /// Throttling scheduler that enforces minimum time between task starts
    // /// Implementation: Store last_execution time in AtomicU64 (unix timestamp).
    // /// Before scheduling, check elapsed time and tokio::time::sleep if needed.
    // /// Useful for rate-limiting API calls to external services.
    // /// struct: { min_interval: Duration, last_execution: AtomicU64 }
    // Throttling { min_interval_ms: u64 },
    //
    // /// Batch scheduler that groups tasks and executes them together
    // /// Implementation: Collect tasks in Vec<Task>, use tokio::time::timeout for max_wait.
    // /// Execute batch when size reached OR timeout expires, whichever first.
    // /// Use futures::future::join_all for parallel execution within batch.
    // /// struct: { batch_size: usize, max_wait: Duration, pending: Mutex<Vec<Task>> }
    // Batch { batch_size: usize, max_wait_ms: u64 },
    //
    // /// Priority-based scheduler with separate queues for different priority levels
    // /// Implementation: Three separate semaphores for high/normal/low priorities.
    // /// Always serve high before normal, normal before low (strict priority).
    // /// Add starvation protection: promote normal->high after timeout.
    // /// struct: { high_sem: Semaphore, normal_sem: Semaphore, low_sem: Semaphore }
    // Priority { high: usize, normal: usize, low: usize },
    //
    // /// Backpressure-aware scheduler that monitors downstream capacity
    // /// Implementation: Track external queue depth via provided callback/metric.
    // /// Pause scheduling when queue_threshold exceeded, resume after pause_duration.
    // /// Use exponential backoff for repeated backpressure events.
    // /// struct: { queue_checker: Arc<dyn Fn() -> usize>, threshold: usize, pause_duration: Duration }
    // Backpressure { queue_threshold: usize, pause_duration_ms: u64 },
}
838
// NOTE(review): the prose below describes the error-policy machinery as a whole
// (see the `OnErrorPolicy` trait further down). It previously ran into the doc
// comment of `OnErrorContext`, so rustdoc attributed trait-level documentation
// to the struct. Kept here as plain comments so it no longer attaches.
//
// Error policies are lightweight, synchronous decision-makers that analyze task failures
// and return an ErrorResponse telling the TaskTracker what action to take. The TaskTracker
// handles all the actual work (cancellation, metrics, etc.) based on the policy's response.
//
// Key design principles:
// - **Synchronous**: Policies make fast decisions without async operations
// - **Stateless where possible**: TaskTracker manages cancellation tokens and state
// - **Composable**: Policies can be combined and nested in hierarchies
// - **Focused**: Each policy handles one specific error pattern or strategy

/// Per-task error handling context
///
/// Provides context information and state management for error policies.
/// The `state` field allows policies to maintain per-task state across multiple
/// error attempts of the same task.
pub struct OnErrorContext {
    /// Number of times this task has been attempted (starts at 1)
    pub attempt_count: u32,
    /// Unique identifier of the failed task
    pub task_id: TaskId,
    /// Full execution context with access to scheduler, metrics, etc.
    pub execution_context: TaskExecutionContext,
    /// Optional per-task state managed by the policy (None for stateless policies)
    ///
    /// Stateless policies leave this `None` so no heap allocation is paid; stateful
    /// policies populate it via `OnErrorPolicy::create_context` on first error.
    pub state: Option<Box<dyn std::any::Any + Send + 'static>>,
}
865
/// Error handling policy trait for task failures
///
/// Policies define how the TaskTracker responds to task failures.
/// They can be stateless (like LogOnlyPolicy) or maintain per-task state
/// (like ThresholdCancelPolicy with per-task failure counters).
///
/// Policies only *decide*; the TaskTracker performs the resulting work
/// (cancellation, metrics, rescheduling) based on the returned [`ErrorResponse`].
pub trait OnErrorPolicy: Send + Sync + std::fmt::Debug {
    /// Create a child policy for a child tracker
    ///
    /// This allows policies to maintain hierarchical relationships,
    /// such as child cancellation tokens or shared circuit breaker state.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy>;

    /// Create per-task context state (None if policy is stateless)
    ///
    /// This method is called once per task when the first error occurs.
    /// Stateless policies should return None to avoid unnecessary heap allocations.
    /// Stateful policies should return Some(Box::new(initial_state)).
    ///
    /// # Returns
    /// * `None` - Policy doesn't need per-task state (no heap allocation)
    /// * `Some(state)` - Initial state for this task (heap allocated when needed)
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>>;

    /// Handle a task failure and return the desired response
    ///
    /// # Arguments
    /// * `error` - The error that occurred
    /// * `context` - Mutable context with attempt count, task info, and optional state
    ///
    /// # Returns
    /// ErrorResponse indicating how the TaskTracker should handle this failure
    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse;

    /// Should continuations be allowed for this error?
    ///
    /// This method is called before checking if a task provided a continuation to determine
    /// whether the policy allows continuation-based retries at all. If this returns `false`,
    /// any `FailedWithContinuation` will be ignored and the error will be handled through
    /// the normal policy response.
    ///
    /// # Arguments
    /// * `error` - The error that occurred
    /// * `context` - Per-task context with attempt count and state
    ///
    /// # Returns
    /// * `true` - Allow continuations, check for `FailedWithContinuation` (default)
    /// * `false` - Reject continuations, handle through normal policy response
    fn allow_continuation(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
        true // Default: allow continuations
    }

    /// Should this continuation be rescheduled through the scheduler?
    ///
    /// This method is called when a continuation is about to be executed to determine
    /// whether it should go through the scheduler's acquisition process again or execute
    /// immediately with the current execution permission.
    ///
    /// **What this means:**
    /// - **Don't reschedule (`false`)**: Execute continuation immediately with current permission
    /// - **Reschedule (`true`)**: Release current permission, go through scheduler again
    ///
    /// Rescheduling means the continuation will be subject to the scheduler's policies
    /// again (rate limiting, concurrency limits, backoff delays, etc.).
    ///
    /// # Arguments
    /// * `error` - The error that triggered this retry decision
    /// * `context` - Per-task context with attempt count and state
    ///
    /// # Returns
    /// * `false` - Execute continuation immediately (default, efficient)
    /// * `true` - Reschedule through scheduler (for delays, rate limiting, backoff)
    fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
        false // Default: immediate execution
    }
}
941
/// Common error handling policies for task failure management
///
/// This enum provides convenient access to built-in error handling policies
/// without requiring manual construction of policy objects.
#[derive(Debug, Clone)]
pub enum ErrorPolicy {
    /// Log errors but continue execution - no cancellation
    LogOnly,
    /// Cancel all tasks on any error (using default error patterns)
    CancelOnError,
    /// Cancel all tasks when specific error patterns are encountered
    CancelOnPatterns(Vec<String>),
    /// Cancel after a threshold number of failures
    CancelOnThreshold { max_failures: usize },
    /// Cancel when failure rate exceeds threshold within time window
    CancelOnRate {
        /// Failure rate above which the tracker is cancelled
        /// (NOTE(review): exact units defined by the policy implementation — confirm)
        max_failure_rate: f32,
        /// Length of the observation window, in seconds
        window_secs: u64,
    },
    // TODO: Future error policies to implement
    //
    // /// Retry failed tasks with exponential backoff
    // /// Implementation: Store original task in retry queue with attempt count.
    // /// Use tokio::time::sleep for delays: backoff_ms * 2^attempt.
    // /// Spawn retry as new task, preserve original task_id for tracing.
    // /// Need task cloning support in scheduler interface.
    // /// struct: { max_attempts: usize, backoff_ms: u64, retry_queue: VecDeque<(Task, u32)> }
    // Retry { max_attempts: usize, backoff_ms: u64 },
    //
    // /// Send failed tasks to a dead letter queue for later processing
    // /// Implementation: Use tokio::sync::mpsc::channel for queue.
    // /// Serialize task info (id, error, payload) for persistence.
    // /// Background worker drains queue to external storage (Redis/DB).
    // /// Include retry count and timestamps for debugging.
    // /// struct: { queue: mpsc::Sender<DeadLetterItem>, storage: Arc<dyn DeadLetterStorage> }
    // DeadLetter { queue_name: String },
    //
    // /// Execute fallback logic when tasks fail
    // /// Implementation: Store fallback closure in Arc for thread-safety.
    // /// Execute fallback in same context as failed task (inherit cancel token).
    // /// Track fallback success/failure separately from original task metrics.
    // /// Consider using enum for common fallback patterns (default value, noop, etc).
    // /// struct: { fallback_fn: Arc<dyn Fn(TaskId, Error) -> BoxFuture<Result<()>>> }
    // Fallback { fallback_fn: Arc<dyn Fn() -> BoxFuture<'static, Result<()>>> },
    //
    // /// Circuit breaker pattern - stop executing after threshold failures
    // /// Implementation: Track state (Closed/Open/HalfOpen) with AtomicU8.
    // /// Use failure window (last N tasks) or time window for threshold.
    // /// In Open state, reject tasks immediately, use timer for recovery.
    // /// In HalfOpen, allow one test task to check if issues resolved.
    // /// struct: { state: AtomicU8, failure_count: AtomicU64, last_failure: AtomicU64 }
    // CircuitBreaker { failure_threshold: usize, timeout_secs: u64 },
    //
    // /// Resource protection policy that monitors memory/CPU usage
    // /// Implementation: Background task samples system resources via sysinfo.
    // /// Cancel tracker when memory > threshold, use process-level monitoring.
    // /// Implement graceful degradation: warn at 80%, cancel at 90%.
    // /// Include both system-wide and process-specific thresholds.
    // /// struct: { monitor_task: JoinHandle, thresholds: ResourceThresholds, cancel_token: CancellationToken }
    // ResourceProtection { max_memory_mb: usize },
    //
    // /// Timeout policy that cancels tasks exceeding maximum duration
    // /// Implementation: Wrap each task with tokio::time::timeout.
    // /// Store task start time, check duration in on_error callback.
    // /// Distinguish timeout errors from other task failures in metrics.
    // /// Consider per-task or global timeout strategies.
    // /// struct: { max_duration: Duration, timeout_tracker: HashMap<TaskId, Instant> }
    // Timeout { max_duration_secs: u64 },
    //
    // /// Sampling policy that only logs a percentage of errors
    // /// Implementation: Use thread-local RNG for sampling decisions.
    // /// Hash task_id for deterministic sampling (same task always sampled).
    // /// Store sample rate as f32, compare with rand::random::<f32>().
    // /// Include rate in log messages for context.
    // /// struct: { sample_rate: f32, rng: ThreadLocal<RefCell<SmallRng>> }
    // Sampling { sample_rate: f32 },
    //
    // /// Aggregating policy that batches error reports
    // /// Implementation: Collect errors in Vec, flush on size or time trigger.
    // /// Use tokio::time::interval for periodic flushing.
    // /// Group errors by type/pattern for better insights.
    // /// Include error frequency and rate statistics in reports.
    // /// struct: { window: Duration, batch: Mutex<Vec<ErrorEntry>>, flush_task: JoinHandle }
    // Aggregating { window_secs: u64, max_batch_size: usize },
    //
    // /// Alerting policy that sends notifications on error patterns
    // /// Implementation: Use reqwest for webhook HTTP calls.
    // /// Rate-limit alerts to prevent spam (max N per minute).
    // /// Include error context, task info, and system metrics in payload.
    // /// Support multiple notification channels (webhook, email, slack).
    // /// struct: { client: reqwest::Client, rate_limiter: RateLimiter, alert_config: AlertConfig }
    // Alerting { webhook_url: String, severity_threshold: String },
}
1035
/// Response type for error handling policies
///
/// This enum defines how the TaskTracker should respond to task failures.
/// Currently provides minimal functionality with planned extensions for common patterns.
#[derive(Debug)]
pub enum ErrorResponse {
    /// Just fail this task - error will be logged/counted, but tracker continues
    Fail,

    /// Shutdown this tracker and all child trackers
    Shutdown,

    /// Execute custom error handling logic with full context access
    ///
    /// The boxed [`OnErrorAction`] is executed asynchronously and its
    /// [`ActionResult`] determines what the tracker does next.
    Custom(Box<dyn OnErrorAction>),
    // TODO: Future specialized error responses to implement:
    //
    // /// Retry the failed task with configurable strategy
    // /// Implementation: Add RetryStrategy trait with delay(), should_continue(attempt_count),
    // /// release_and_reacquire_resources() methods. TaskTracker handles retry loop with
    // /// attempt counting and resource management. Supports exponential backoff, jitter.
    // /// Usage: ErrorResponse::Retry(Box::new(ExponentialBackoff { max_attempts: 3, base_delay: 100ms }))
    // Retry(Box<dyn RetryStrategy>),
    //
    // /// Execute fallback logic, then follow secondary action
    // /// Implementation: Add FallbackAction trait with execute(error, task_id) -> Result<(), Error>.
    // /// Execute fallback first, then recursively handle the 'then' response based on fallback result.
    // /// Enables patterns like: try fallback, if it works continue, if it fails retry original task.
    // /// Usage: ErrorResponse::Fallback { fallback: Box::new(DefaultValue(42)), then: Box::new(ErrorResponse::Continue) }
    // Fallback { fallback: Box<dyn FallbackAction>, then: Box<ErrorResponse> },
    //
    // /// Restart task with preserved state (for long-running/stateful tasks)
    // /// Implementation: Add TaskState trait for serialize/deserialize state, RestartStrategy trait
    // /// with create_continuation_task(state) -> Future. Task saves checkpoints during execution,
    // /// on error returns StatefulTaskError containing preserved state. Policy can restart from checkpoint.
    // /// Usage: ErrorResponse::RestartWithState { state: checkpointed_state, strategy: Box::new(CheckpointRestart { ... }) }
    // RestartWithState { state: Box<dyn TaskState>, strategy: Box<dyn RestartStrategy> },
}
1073
/// Trait for implementing custom error handling actions
///
/// This provides full access to the task execution context for complex error handling
/// scenarios that don't fit into the built-in response patterns. Unlike
/// [`OnErrorPolicy::on_error`], this hook is asynchronous and may await.
#[async_trait]
pub trait OnErrorAction: Send + Sync + std::fmt::Debug {
    /// Execute custom error handling logic
    ///
    /// # Arguments
    /// * `error` - The error that caused the task to fail
    /// * `task_id` - Unique identifier of the failed task
    /// * `attempt_count` - Number of times this task has been attempted (starts at 1)
    /// * `context` - Full execution context with access to scheduler, metrics, etc.
    ///
    /// # Returns
    /// ActionResult indicating what the TaskTracker should do next
    async fn execute(
        &self,
        error: &anyhow::Error,
        task_id: TaskId,
        attempt_count: u32,
        context: &TaskExecutionContext,
    ) -> ActionResult;
}
1098
/// Scheduler execution guard state for conditional re-acquisition during task retry loops
///
/// This controls whether a continuation should reuse the current scheduler execution permission
/// or go through the scheduler's acquisition process again. It mirrors the boolean returned
/// by `OnErrorPolicy::should_reschedule` (`false` → `Keep`, `true` → `Reschedule`).
#[derive(Debug, Clone, PartialEq, Eq)]
enum GuardState {
    /// Keep the current scheduler execution permission for immediate continuation
    ///
    /// The continuation will execute immediately without going through the scheduler again.
    /// This is efficient for simple retries that don't need delays or rate limiting.
    Keep,

    /// Release current permission and re-acquire through scheduler before continuation
    ///
    /// The continuation will be subject to the scheduler's policies again (concurrency limits,
    /// rate limiting, backoff delays, etc.). Use this for implementing retry delays or
    /// when the scheduler needs to apply its policies to the retry attempt.
    Reschedule,
}
1118
/// Result of a custom error action execution
///
/// Returned by [`OnErrorAction::execute`] to tell the TaskTracker how to proceed
/// after custom error handling has run.
#[derive(Debug)]
pub enum ActionResult {
    /// Just fail this task (error was logged/handled by policy)
    ///
    /// This means the policy has handled the error appropriately (e.g., logged it,
    /// updated metrics, etc.) and the task should fail with this error.
    /// The task execution terminates here.
    Fail,

    /// Continue execution with the provided task
    ///
    /// This provides a new executable to continue the retry loop with.
    /// The task execution continues with the provided continuation.
    Continue {
        continuation: Arc<dyn Continuation + Send + Sync + 'static>,
    },

    /// Shutdown this tracker and all child trackers
    ///
    /// This triggers shutdown of the entire tracker hierarchy.
    /// All running and pending tasks will be cancelled.
    Shutdown,
}
1143
/// Execution context provided to custom error actions
///
/// This gives custom actions full access to the task execution environment
/// for implementing complex error handling scenarios. Both fields are shared
/// (`Arc`) handles, so the context is cheap to pass around.
pub struct TaskExecutionContext {
    /// Scheduler for reacquiring resources or checking state
    pub scheduler: Arc<dyn TaskScheduler>,

    /// Metrics for custom tracking
    pub metrics: Arc<dyn HierarchicalTaskMetrics>,
    // TODO: Future context additions:
    // pub guard: Box<dyn ResourceGuard>, // Current resource guard (needs Debug impl)
    // pub cancel_token: CancellationToken, // For implementing custom cancellation
    // pub task_recreation: Box<dyn TaskRecreator>, // For implementing retry/restart
}
1159
/// Result of task execution - unified for both regular and cancellable tasks
///
/// Produced by [`TaskExecutor::execute`]; collapses the two task flavours into
/// one result type the execution pipeline can handle uniformly.
#[derive(Debug)]
pub enum TaskExecutionResult<T> {
    /// Task completed successfully
    Success(T),
    /// Task was cancelled (only possible for cancellable tasks)
    Cancelled,
    /// Task failed with an error
    Error(anyhow::Error),
}
1170
/// Trait for executing different types of tasks in a unified way
///
/// Implementors decide what to do with the token: `RegularTaskExecutor` ignores it,
/// while `CancellableTaskExecutor` forwards it into the task itself.
#[async_trait]
trait TaskExecutor<T>: Send {
    /// Execute the task with the given cancellation token
    async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T>;
}
1177
/// Task executor for regular (non-cancellable) tasks
///
/// Wraps a one-shot future. The future is consumed by the first `execute` call,
/// so regular tasks cannot be retried (see the `TaskExecutor` impl below).
struct RegularTaskExecutor<F, T>
where
    F: Future<Output = Result<T>> + Send + 'static,
    T: Send + 'static,
{
    /// The wrapped future; `Option` so `execute` can `take()` it exactly once
    future: Option<F>,
    /// Ties `T` to the struct without storing a value of that type
    _phantom: std::marker::PhantomData<T>,
}
1187
1188impl<F, T> RegularTaskExecutor<F, T>
1189where
1190 F: Future<Output = Result<T>> + Send + 'static,
1191 T: Send + 'static,
1192{
1193 fn new(future: F) -> Self {
1194 Self {
1195 future: Some(future),
1196 _phantom: std::marker::PhantomData,
1197 }
1198 }
1199}
1200
1201#[async_trait]
1202impl<F, T> TaskExecutor<T> for RegularTaskExecutor<F, T>
1203where
1204 F: Future<Output = Result<T>> + Send + 'static,
1205 T: Send + 'static,
1206{
1207 async fn execute(&mut self, _cancel_token: CancellationToken) -> TaskExecutionResult<T> {
1208 if let Some(future) = self.future.take() {
1209 match future.await {
1210 Ok(value) => TaskExecutionResult::Success(value),
1211 Err(error) => TaskExecutionResult::Error(error),
1212 }
1213 } else {
1214 // This should never happen since regular tasks don't support retry
1215 TaskExecutionResult::Error(anyhow::anyhow!("Regular task already consumed"))
1216 }
1217 }
1218}
1219
/// Task executor for cancellable tasks
///
/// Holds a task *factory* (`FnMut`) rather than a future, so a fresh future can
/// be produced for each execution attempt, each receiving the cancellation token.
struct CancellableTaskExecutor<F, Fut, T>
where
    F: FnMut(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
    T: Send + 'static,
{
    /// Factory invoked with the cancellation token on every `execute` call
    task_fn: F,
}
1229
1230impl<F, Fut, T> CancellableTaskExecutor<F, Fut, T>
1231where
1232 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1233 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1234 T: Send + 'static,
1235{
1236 fn new(task_fn: F) -> Self {
1237 Self { task_fn }
1238 }
1239}
1240
1241#[async_trait]
1242impl<F, Fut, T> TaskExecutor<T> for CancellableTaskExecutor<F, Fut, T>
1243where
1244 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1245 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1246 T: Send + 'static,
1247{
1248 async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T> {
1249 let future = (self.task_fn)(cancel_token);
1250 match future.await {
1251 CancellableTaskResult::Ok(value) => TaskExecutionResult::Success(value),
1252 CancellableTaskResult::Cancelled => TaskExecutionResult::Cancelled,
1253 CancellableTaskResult::Err(error) => TaskExecutionResult::Error(error),
1254 }
1255 }
1256}
1257
/// Common functionality for policy Arc construction
///
/// This trait provides a standardized `new_arc()` method for all policy types,
/// eliminating the need for manual `Arc::new()` calls in client code.
pub trait ArcPolicy: Sized + Send + Sync + 'static {
    /// Create an Arc-wrapped instance of this policy
    ///
    /// Consumes `self`; a blanket default so implementors get it for free.
    fn new_arc(self) -> Arc<Self> {
        Arc::new(self)
    }
}
1268
/// Unique identifier for a task
///
/// Newtype over a UUID (v4 — see `TaskId::new`); `Copy` and `Hash` so it can be
/// passed by value and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(Uuid);
1272
1273impl TaskId {
1274 fn new() -> Self {
1275 Self(Uuid::new_v4())
1276 }
1277}
1278
1279impl std::fmt::Display for TaskId {
1280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1281 write!(f, "task-{}", self.0)
1282 }
1283}
1284
/// Result of task execution
///
/// A type-erased, cloneable summary of how a task finished (the failure case
/// keeps only the error message, not the error value).
#[derive(Debug, Clone, PartialEq)]
pub enum CompletionStatus {
    /// Task completed successfully
    Ok,
    /// Task was cancelled before or during execution
    Cancelled,
    /// Task failed with an error (stringified error message)
    Failed(String),
}
1295
/// Result type for cancellable tasks that explicitly tracks cancellation
///
/// Returned by the futures produced for `CancellableTaskExecutor`; unlike a plain
/// `Result`, it lets the task itself report a clean cancellation.
#[derive(Debug)]
pub enum CancellableTaskResult<T> {
    /// Task completed successfully
    Ok(T),
    /// Task was cancelled (either via token or shutdown)
    Cancelled,
    /// Task failed with an error
    Err(anyhow::Error),
}
1306
/// Result of scheduling a task
///
/// Returned by [`TaskScheduler::acquire_execution_slot`] (with `T` being the
/// resource guard) and reused wherever scheduling outcomes are propagated.
#[derive(Debug)]
pub enum SchedulingResult<T> {
    /// Task was executed and completed
    Execute(T),
    /// Task was cancelled before execution
    Cancelled,
    /// Task was rejected due to scheduling policy (reason message included)
    Rejected(String),
}
1317
/// Resource guard for task execution
///
/// Represents resources (permits, slots, etc.) acquired from a scheduler that
/// must be held during task execution. The guard automatically releases its
/// resources when dropped, implementing proper RAII semantics.
///
/// This split enforces proper cancellation semantics by separating resource
/// management from task execution: once a guard is acquired, task execution
/// must always run to completion.
///
/// Guards are returned by `TaskScheduler::acquire_execution_slot()` and must
/// be held in scope while the task executes to ensure resources remain allocated.
pub trait ResourceGuard: Send + 'static {
    // Marker trait - resources are released via Drop on the concrete type
}
1334
/// Trait for implementing task scheduling policies
///
/// This trait enforces proper cancellation semantics by splitting resource
/// acquisition (which can be cancelled) from task execution (which cannot).
///
/// ## Design Philosophy
///
/// Tasks may or may not support cancellation (depending on whether they were
/// created with `spawn_cancellable` or regular `spawn`). This split design ensures:
///
/// - **Resource acquisition**: Can respect cancellation tokens to avoid unnecessary allocation
/// - **Task execution**: Always runs to completion; tasks handle their own cancellation
///
/// This makes it impossible to accidentally interrupt task execution with `tokio::select!`.
#[async_trait]
pub trait TaskScheduler: Send + Sync + std::fmt::Debug {
    /// Acquire resources needed for task execution and return a guard
    ///
    /// This method handles resource allocation (permits, queue slots, etc.) and
    /// can respect cancellation tokens to avoid unnecessary resource consumption.
    /// The returned guard releases its resources on drop (RAII).
    ///
    /// ## Cancellation Behavior
    ///
    /// The `cancel_token` is used for scheduler-level cancellation (e.g., "don't start new work").
    /// If cancellation is requested before or during resource acquisition, this method
    /// should return `SchedulingResult::Cancelled`.
    ///
    /// # Arguments
    /// * `cancel_token` - [`CancellationToken`] for scheduler-level cancellation
    ///
    /// # Returns
    /// * `SchedulingResult::Execute(guard)` - Resources acquired, ready to execute
    /// * `SchedulingResult::Cancelled` - Cancelled before or during resource acquisition
    /// * `SchedulingResult::Rejected(reason)` - Resources unavailable or policy violation
    async fn acquire_execution_slot(
        &self,
        cancel_token: CancellationToken,
    ) -> SchedulingResult<Box<dyn ResourceGuard>>;
}
1374
/// Trait for hierarchical task metrics that supports aggregation up the tracker tree
///
/// This trait provides different implementations for root and child trackers:
/// - Root trackers integrate with Prometheus metrics for observability
/// - Child trackers chain metric updates up to their parents for aggregation
/// - All implementations maintain thread-safe atomic operations
pub trait HierarchicalTaskMetrics: Send + Sync + std::fmt::Debug {
    /// Increment issued task counter
    fn increment_issued(&self);

    /// Increment started task counter
    fn increment_started(&self);

    /// Increment success counter
    fn increment_success(&self);

    /// Increment cancelled counter
    fn increment_cancelled(&self);

    /// Increment failed counter
    fn increment_failed(&self);

    /// Increment rejected counter
    fn increment_rejected(&self);

    /// Get current issued count (local to this tracker)
    fn issued(&self) -> u64;

    /// Get current started count (local to this tracker)
    fn started(&self) -> u64;

    /// Get current success count (local to this tracker)
    fn success(&self) -> u64;

    /// Get current cancelled count (local to this tracker)
    fn cancelled(&self) -> u64;

    /// Get current failed count (local to this tracker)
    fn failed(&self) -> u64;

    /// Get current rejected count (local to this tracker)
    fn rejected(&self) -> u64;

    /// Get total completed tasks (success + cancelled + failed + rejected)
    fn total_completed(&self) -> u64 {
        self.success() + self.cancelled() + self.failed() + self.rejected()
    }

    /// Get number of pending tasks (issued - completed)
    ///
    /// Uses `saturating_sub`: concurrent updates between the two reads can make
    /// the difference transiently negative, which is floored at 0.
    fn pending(&self) -> u64 {
        self.issued().saturating_sub(self.total_completed())
    }

    /// Get the number of tasks that are currently active (started - completed)
    fn active(&self) -> u64 {
        self.started().saturating_sub(self.total_completed())
    }

    /// Get number of tasks queued in scheduler (issued - started)
    fn queued(&self) -> u64 {
        self.issued().saturating_sub(self.started())
    }
}
1438
/// Task execution metrics for a tracker
///
/// Plain atomic counters (see the `HierarchicalTaskMetrics` impl below, which
/// uses relaxed ordering); `Default` zeroes every counter.
#[derive(Debug, Default)]
pub struct TaskMetrics {
    /// Number of tasks issued/submitted (via spawn methods)
    pub issued_count: AtomicU64,
    /// Number of tasks that have started execution
    pub started_count: AtomicU64,
    /// Number of successfully completed tasks
    pub success_count: AtomicU64,
    /// Number of cancelled tasks
    pub cancelled_count: AtomicU64,
    /// Number of failed tasks
    pub failed_count: AtomicU64,
    /// Number of rejected tasks (by scheduler)
    pub rejected_count: AtomicU64,
}
1455
1456impl TaskMetrics {
1457 /// Create new metrics instance
1458 pub fn new() -> Self {
1459 Self::default()
1460 }
1461}
1462
// All counters are independent monotonic totals, so relaxed ordering suffices;
// derived values (pending/active/queued) may be transiently inconsistent.
impl HierarchicalTaskMetrics for TaskMetrics {
    /// Increment issued task counter
    fn increment_issued(&self) {
        self.issued_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment started task counter
    fn increment_started(&self) {
        self.started_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment success counter
    fn increment_success(&self) {
        self.success_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment cancelled counter
    fn increment_cancelled(&self) {
        self.cancelled_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment failed counter
    fn increment_failed(&self) {
        self.failed_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment rejected counter
    fn increment_rejected(&self) {
        self.rejected_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Get current issued count
    fn issued(&self) -> u64 {
        self.issued_count.load(Ordering::Relaxed)
    }

    /// Get current started count
    fn started(&self) -> u64 {
        self.started_count.load(Ordering::Relaxed)
    }

    /// Get current success count
    fn success(&self) -> u64 {
        self.success_count.load(Ordering::Relaxed)
    }

    /// Get current cancelled count
    fn cancelled(&self) -> u64 {
        self.cancelled_count.load(Ordering::Relaxed)
    }

    /// Get current failed count
    fn failed(&self) -> u64 {
        self.failed_count.load(Ordering::Relaxed)
    }

    /// Get current rejected count
    fn rejected(&self) -> u64 {
        self.rejected_count.load(Ordering::Relaxed)
    }
}
1524
/// Root tracker metrics with Prometheus integration
///
/// This implementation maintains local counters and exposes them as Prometheus metrics
/// through the provided MetricsRegistry. The Prometheus counters are the only storage:
/// reads (`issued()`, etc.) come straight from `IntCounter::get()`.
#[derive(Debug)]
pub struct PrometheusTaskMetrics {
    /// Prometheus metrics integration
    prometheus_issued: prometheus::IntCounter,
    prometheus_started: prometheus::IntCounter,
    prometheus_success: prometheus::IntCounter,
    prometheus_cancelled: prometheus::IntCounter,
    prometheus_failed: prometheus::IntCounter,
    prometheus_rejected: prometheus::IntCounter,
}
1539
1540impl PrometheusTaskMetrics {
1541 /// Create new root metrics with Prometheus integration
1542 ///
1543 /// # Arguments
1544 /// * `registry` - MetricsRegistry for creating Prometheus metrics
1545 /// * `component_name` - Name for the component/tracker (used in metric names)
1546 ///
1547 /// # Example
1548 /// ```rust
1549 /// # use std::sync::Arc;
1550 /// # use dynamo_runtime::utils::tasks::tracker::PrometheusTaskMetrics;
1551 /// # use dynamo_runtime::metrics::MetricsRegistry;
1552 /// # fn example(registry: Arc<dyn MetricsRegistry>) -> anyhow::Result<()> {
1553 /// let metrics = PrometheusTaskMetrics::new(registry.as_ref(), "main_tracker")?;
1554 /// # Ok(())
1555 /// # }
1556 /// ```
1557 pub fn new<R: MetricsRegistry + ?Sized>(
1558 registry: &R,
1559 component_name: &str,
1560 ) -> anyhow::Result<Self> {
1561 let issued_counter = registry.create_intcounter(
1562 &format!("{}_{}", component_name, task_tracker::TASKS_ISSUED_TOTAL),
1563 "Total number of tasks issued/submitted",
1564 &[],
1565 )?;
1566
1567 let started_counter = registry.create_intcounter(
1568 &format!("{}_{}", component_name, task_tracker::TASKS_STARTED_TOTAL),
1569 "Total number of tasks started",
1570 &[],
1571 )?;
1572
1573 let success_counter = registry.create_intcounter(
1574 &format!("{}_{}", component_name, task_tracker::TASKS_SUCCESS_TOTAL),
1575 "Total number of successfully completed tasks",
1576 &[],
1577 )?;
1578
1579 let cancelled_counter = registry.create_intcounter(
1580 &format!("{}_{}", component_name, task_tracker::TASKS_CANCELLED_TOTAL),
1581 "Total number of cancelled tasks",
1582 &[],
1583 )?;
1584
1585 let failed_counter = registry.create_intcounter(
1586 &format!("{}_{}", component_name, task_tracker::TASKS_FAILED_TOTAL),
1587 "Total number of failed tasks",
1588 &[],
1589 )?;
1590
1591 let rejected_counter = registry.create_intcounter(
1592 &format!("{}_{}", component_name, task_tracker::TASKS_REJECTED_TOTAL),
1593 "Total number of rejected tasks",
1594 &[],
1595 )?;
1596
1597 Ok(Self {
1598 prometheus_issued: issued_counter,
1599 prometheus_started: started_counter,
1600 prometheus_success: success_counter,
1601 prometheus_cancelled: cancelled_counter,
1602 prometheus_failed: failed_counter,
1603 prometheus_rejected: rejected_counter,
1604 })
1605 }
1606}
1607
// The Prometheus counters are the sole storage: increments forward to `inc()`
// and reads report the live counter value via `get()`.
impl HierarchicalTaskMetrics for PrometheusTaskMetrics {
    fn increment_issued(&self) {
        self.prometheus_issued.inc();
    }

    fn increment_started(&self) {
        self.prometheus_started.inc();
    }

    fn increment_success(&self) {
        self.prometheus_success.inc();
    }

    fn increment_cancelled(&self) {
        self.prometheus_cancelled.inc();
    }

    fn increment_failed(&self) {
        self.prometheus_failed.inc();
    }

    fn increment_rejected(&self) {
        self.prometheus_rejected.inc();
    }

    fn issued(&self) -> u64 {
        self.prometheus_issued.get()
    }

    fn started(&self) -> u64 {
        self.prometheus_started.get()
    }

    fn success(&self) -> u64 {
        self.prometheus_success.get()
    }

    fn cancelled(&self) -> u64 {
        self.prometheus_cancelled.get()
    }

    fn failed(&self) -> u64 {
        self.prometheus_failed.get()
    }

    fn rejected(&self) -> u64 {
        self.prometheus_rejected.get()
    }
}
1657
/// Child tracker metrics that chain updates to parent
///
/// Maintains local counters and automatically forwards every metric update
/// to the parent tracker for hierarchical aggregation. Read accessors report
/// only this tracker's local counts, not the parent's aggregate.
/// Holds a strong reference to parent metrics for optimal performance.
#[derive(Debug)]
struct ChildTaskMetrics {
    /// Local metrics for this tracker
    local_metrics: TaskMetrics,
    /// Strong reference to parent metrics for fast chaining
    /// Safe to hold since metrics don't own trackers - no circular references
    parent_metrics: Arc<dyn HierarchicalTaskMetrics>,
}
1671
impl ChildTaskMetrics {
    /// Wrap a fresh, zeroed local counter set around `parent_metrics`.
    /// Increments recorded here are also forwarded to the parent.
    fn new(parent_metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
        Self {
            local_metrics: TaskMetrics::new(),
            parent_metrics,
        }
    }
}
1680
impl HierarchicalTaskMetrics for ChildTaskMetrics {
    // Each increment is recorded locally and then chained to the parent, so
    // every ancestor's counters also include this tracker's activity.
    fn increment_issued(&self) {
        self.local_metrics.increment_issued();
        self.parent_metrics.increment_issued();
    }

    fn increment_started(&self) {
        self.local_metrics.increment_started();
        self.parent_metrics.increment_started();
    }

    fn increment_success(&self) {
        self.local_metrics.increment_success();
        self.parent_metrics.increment_success();
    }

    fn increment_cancelled(&self) {
        self.local_metrics.increment_cancelled();
        self.parent_metrics.increment_cancelled();
    }

    fn increment_failed(&self) {
        self.local_metrics.increment_failed();
        self.parent_metrics.increment_failed();
    }

    fn increment_rejected(&self) {
        self.local_metrics.increment_rejected();
        self.parent_metrics.increment_rejected();
    }

    // Read accessors intentionally report only this tracker's local counts;
    // aggregate totals are obtained by reading an ancestor's metrics.
    fn issued(&self) -> u64 {
        self.local_metrics.issued()
    }

    fn started(&self) -> u64 {
        self.local_metrics.started()
    }

    fn success(&self) -> u64 {
        self.local_metrics.success()
    }

    fn cancelled(&self) -> u64 {
        self.local_metrics.cancelled()
    }

    fn failed(&self) -> u64 {
        self.local_metrics.failed()
    }

    fn rejected(&self) -> u64 {
        self.local_metrics.rejected()
    }
}
1736
1737/// Builder for creating child trackers with custom policies
1738///
1739/// Allows flexible customization of scheduling and error handling policies
1740/// for child trackers while maintaining parent-child relationships.
pub struct ChildTrackerBuilder<'parent> {
    /// Parent tracker the child will be registered under
    parent: &'parent TaskTracker,
    /// Scheduler override; `None` means inherit the parent's scheduler
    scheduler: Option<Arc<dyn TaskScheduler>>,
    /// Error policy override; `None` means derive one via `create_child()`
    error_policy: Option<Arc<dyn OnErrorPolicy>>,
}
1746
1747impl<'parent> ChildTrackerBuilder<'parent> {
1748 /// Create a new ChildTrackerBuilder
1749 pub fn new(parent: &'parent TaskTracker) -> Self {
1750 Self {
1751 parent,
1752 scheduler: None,
1753 error_policy: None,
1754 }
1755 }
1756
1757 /// Set custom scheduler for the child tracker
1758 ///
1759 /// If not set, the child will inherit the parent's scheduler.
1760 ///
1761 /// # Arguments
1762 /// * `scheduler` - The scheduler to use for this child tracker
1763 ///
1764 /// # Example
1765 /// ```rust
1766 /// # use std::sync::Arc;
1767 /// # use tokio::sync::Semaphore;
1768 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler};
1769 /// # fn example(parent: &TaskTracker) {
1770 /// let child = parent.child_tracker_builder()
1771 /// .scheduler(SemaphoreScheduler::with_permits(5))
1772 /// .build().unwrap();
1773 /// # }
1774 /// ```
1775 pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
1776 self.scheduler = Some(scheduler);
1777 self
1778 }
1779
1780 /// Set custom error policy for the child tracker
1781 ///
1782 /// If not set, the child will get a child policy from the parent's error policy
1783 /// (via `OnErrorPolicy::create_child()`).
1784 ///
1785 /// # Arguments
1786 /// * `error_policy` - The error policy to use for this child tracker
1787 ///
1788 /// # Example
1789 /// ```rust
1790 /// # use std::sync::Arc;
1791 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, LogOnlyPolicy};
1792 /// # fn example(parent: &TaskTracker) {
1793 /// let child = parent.child_tracker_builder()
1794 /// .error_policy(LogOnlyPolicy::new())
1795 /// .build().unwrap();
1796 /// # }
1797 /// ```
1798 pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
1799 self.error_policy = Some(error_policy);
1800 self
1801 }
1802
1803 /// Build the child tracker with the specified configuration
1804 ///
1805 /// Creates a new child tracker with:
1806 /// - Custom or inherited scheduler
1807 /// - Custom or child error policy
1808 /// - Hierarchical metrics that chain to parent
1809 /// - Child cancellation token from the parent
1810 /// - Independent lifecycle from parent
1811 ///
1812 /// # Returns
1813 /// A new `Arc<TaskTracker>` configured as a child of the parent
1814 ///
1815 /// # Errors
1816 /// Returns an error if the parent tracker is already closed
1817 pub fn build(self) -> anyhow::Result<TaskTracker> {
1818 // Validate that parent tracker is still active
1819 if self.parent.is_closed() {
1820 return Err(anyhow::anyhow!(
1821 "Cannot create child tracker from closed parent tracker"
1822 ));
1823 }
1824
1825 let parent = self.parent.0.clone();
1826
1827 let child_cancel_token = parent.cancel_token.child_token();
1828 let child_metrics = Arc::new(ChildTaskMetrics::new(parent.metrics.clone()));
1829
1830 // Use provided scheduler or inherit from parent
1831 let scheduler = self.scheduler.unwrap_or_else(|| parent.scheduler.clone());
1832
1833 // Use provided error policy or create child from parent's
1834 let error_policy = self
1835 .error_policy
1836 .unwrap_or_else(|| parent.error_policy.create_child());
1837
1838 let child = Arc::new(TaskTrackerInner {
1839 tokio_tracker: TokioTaskTracker::new(),
1840 parent: None, // No parent reference needed for hierarchical operations
1841 scheduler,
1842 error_policy,
1843 metrics: child_metrics,
1844 cancel_token: child_cancel_token,
1845 children: RwLock::new(Vec::new()),
1846 });
1847
1848 // Register this child with the parent for hierarchical operations
1849 parent
1850 .children
1851 .write()
1852 .unwrap()
1853 .push(Arc::downgrade(&child));
1854
1855 // Periodically clean up dead children to prevent unbounded growth
1856 parent.cleanup_dead_children();
1857
1858 Ok(TaskTracker(child))
1859 }
1860}
1861
1862/// Internal data for TaskTracker
1863///
1864/// This struct contains all the actual state and functionality of a TaskTracker.
1865/// TaskTracker itself is just a wrapper around Arc<TaskTrackerInner>.
struct TaskTrackerInner {
    /// Tokio's task tracker for lifecycle management
    tokio_tracker: TokioTaskTracker,
    /// Parent tracker (None for root)
    ///
    /// Strong child->parent edge; parents hold only `Weak` references to
    /// children (see `children`), so no reference cycle is possible.
    parent: Option<Arc<TaskTrackerInner>>,
    /// Scheduling policy (shared with children by default)
    scheduler: Arc<dyn TaskScheduler>,
    /// Error handling policy (child-specific via create_child)
    error_policy: Arc<dyn OnErrorPolicy>,
    /// Metrics for this tracker
    metrics: Arc<dyn HierarchicalTaskMetrics>,
    /// Cancellation token for this tracker (always present)
    /// Per-task and child-tracker tokens are derived from it via `child_token()`.
    cancel_token: CancellationToken,
    /// List of child trackers for hierarchical operations
    /// Weak entries; dead ones are pruned by `cleanup_dead_children`.
    children: RwLock<Vec<Weak<TaskTrackerInner>>>,
}
1882
1883/// Hierarchical task tracker with pluggable scheduling and error policies
1884///
1885/// TaskTracker provides a composable system for managing background tasks with:
1886/// - Configurable scheduling via [`TaskScheduler`] implementations
1887/// - Flexible error handling via [`OnErrorPolicy`] implementations
1888/// - Parent-child relationships with independent metrics
1889/// - Cancellation propagation and isolation
1890/// - Built-in cancellation token support
1891///
1892/// Built on top of `tokio_util::task::TaskTracker` for robust task lifecycle management.
1893///
1894/// # Example
1895///
1896/// ```rust
1897/// # use std::sync::Arc;
1898/// # use tokio::sync::Semaphore;
1899/// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy, CancellableTaskResult};
1900/// # async fn example() -> anyhow::Result<()> {
1901/// // Create a task tracker with semaphore-based scheduling
1902/// let scheduler = SemaphoreScheduler::with_permits(3);
1903/// let policy = LogOnlyPolicy::new();
1904/// let root = TaskTracker::builder()
1905/// .scheduler(scheduler)
1906/// .error_policy(policy)
1907/// .build()?;
1908///
1909/// // Spawn some tasks
1910/// let handle1 = root.spawn(async { Ok(1) });
1911/// let handle2 = root.spawn(async { Ok(2) });
1912///
1913/// // Get results and join all tasks
1914/// let result1 = handle1.await.unwrap().unwrap();
1915/// let result2 = handle2.await.unwrap().unwrap();
1916/// assert_eq!(result1, 1);
1917/// assert_eq!(result2, 2);
1918/// # Ok(())
1919/// # }
1920/// ```
#[derive(Clone)]
// Thin handle: all state lives in the shared `TaskTrackerInner`, so `Clone`
// is a cheap Arc bump and every clone refers to the same tracker.
pub struct TaskTracker(Arc<TaskTrackerInner>);
1923
1924/// Builder for TaskTracker
#[derive(Default)]
pub struct TaskTrackerBuilder {
    /// Required: scheduling policy (build() fails if unset)
    scheduler: Option<Arc<dyn TaskScheduler>>,
    /// Required: error handling policy (build() fails if unset)
    error_policy: Option<Arc<dyn OnErrorPolicy>>,
    /// Optional: defaults to plain in-process `TaskMetrics`
    metrics: Option<Arc<dyn HierarchicalTaskMetrics>>,
    /// Optional: defaults to a fresh `CancellationToken`
    cancel_token: Option<CancellationToken>,
}
1932
1933impl TaskTrackerBuilder {
1934 /// Set the scheduler for this TaskTracker
1935 pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
1936 self.scheduler = Some(scheduler);
1937 self
1938 }
1939
1940 /// Set the error policy for this TaskTracker
1941 pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
1942 self.error_policy = Some(error_policy);
1943 self
1944 }
1945
1946 /// Set custom metrics for this TaskTracker
1947 pub fn metrics(mut self, metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
1948 self.metrics = Some(metrics);
1949 self
1950 }
1951
1952 /// Set the cancellation token for this TaskTracker
1953 pub fn cancel_token(mut self, cancel_token: CancellationToken) -> Self {
1954 self.cancel_token = Some(cancel_token);
1955 self
1956 }
1957
1958 /// Build the TaskTracker
1959 pub fn build(self) -> anyhow::Result<TaskTracker> {
1960 let scheduler = self
1961 .scheduler
1962 .ok_or_else(|| anyhow::anyhow!("TaskTracker requires a scheduler"))?;
1963
1964 let error_policy = self
1965 .error_policy
1966 .ok_or_else(|| anyhow::anyhow!("TaskTracker requires an error policy"))?;
1967
1968 let metrics = self.metrics.unwrap_or_else(|| Arc::new(TaskMetrics::new()));
1969
1970 let cancel_token = self.cancel_token.unwrap_or_default();
1971
1972 let inner = TaskTrackerInner {
1973 tokio_tracker: TokioTaskTracker::new(),
1974 parent: None,
1975 scheduler,
1976 error_policy,
1977 metrics,
1978 cancel_token,
1979 children: RwLock::new(Vec::new()),
1980 };
1981
1982 Ok(TaskTracker(Arc::new(inner)))
1983 }
1984}
1985
impl TaskTracker {
    /// Create a new root task tracker using the builder pattern
    ///
    /// This is the preferred way to create new task trackers.
    ///
    /// # Example
    /// ```rust
    /// # use std::sync::Arc;
    /// # use tokio::sync::Semaphore;
    /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
    /// # fn main() -> anyhow::Result<()> {
    /// let scheduler = SemaphoreScheduler::with_permits(10);
    /// let error_policy = LogOnlyPolicy::new();
    /// let tracker = TaskTracker::builder()
    ///     .scheduler(scheduler)
    ///     .error_policy(error_policy)
    ///     .build()?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn builder() -> TaskTrackerBuilder {
        TaskTrackerBuilder::default()
    }

    /// Create a new root task tracker with simple parameters (legacy)
    ///
    /// This method is kept for backward compatibility. Use `builder()` for new code.
    /// Uses default metrics (no Prometheus integration).
    ///
    /// # Arguments
    /// * `scheduler` - Scheduling policy to use for all tasks
    /// * `error_policy` - Error handling policy for this tracker
    ///
    /// # Example
    /// ```rust
    /// # use std::sync::Arc;
    /// # use tokio::sync::Semaphore;
    /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
    /// # fn main() -> anyhow::Result<()> {
    /// let scheduler = SemaphoreScheduler::with_permits(10);
    /// let error_policy = LogOnlyPolicy::new();
    /// let tracker = TaskTracker::new(scheduler, error_policy)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        scheduler: Arc<dyn TaskScheduler>,
        error_policy: Arc<dyn OnErrorPolicy>,
    ) -> anyhow::Result<Self> {
        Self::builder()
            .scheduler(scheduler)
            .error_policy(error_policy)
            .build()
    }

    /// Create a new root task tracker with Prometheus metrics integration
    ///
    /// # Arguments
    /// * `scheduler` - Scheduling policy to use for all tasks
    /// * `error_policy` - Error handling policy for this tracker
    /// * `registry` - MetricsRegistry for Prometheus integration
    /// * `component_name` - Name for this tracker component
    ///
    /// # Example
    /// ```rust
    /// # use std::sync::Arc;
    /// # use tokio::sync::Semaphore;
    /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
    /// # use dynamo_runtime::metrics::MetricsRegistry;
    /// # fn example(registry: Arc<dyn MetricsRegistry>) -> anyhow::Result<()> {
    /// let scheduler = SemaphoreScheduler::with_permits(10);
    /// let error_policy = LogOnlyPolicy::new();
    /// let tracker = TaskTracker::new_with_prometheus(
    ///     scheduler,
    ///     error_policy,
    ///     registry.as_ref(),
    ///     "main_tracker"
    /// )?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_prometheus<R: MetricsRegistry + ?Sized>(
        scheduler: Arc<dyn TaskScheduler>,
        error_policy: Arc<dyn OnErrorPolicy>,
        registry: &R,
        component_name: &str,
    ) -> anyhow::Result<Self> {
        let prometheus_metrics = Arc::new(PrometheusTaskMetrics::new(registry, component_name)?);

        Self::builder()
            .scheduler(scheduler)
            .error_policy(error_policy)
            .metrics(prometheus_metrics)
            .build()
    }

    /// Create a child tracker that inherits scheduling policy
    ///
    /// The child tracker:
    /// - Gets its own independent tokio TaskTracker
    /// - Inherits the parent's scheduler
    /// - Gets a child error policy via `create_child()`
    /// - Has hierarchical metrics that chain to parent
    /// - Gets a child cancellation token from the parent
    /// - Is independent for cancellation (child cancellation doesn't affect parent)
    ///
    /// # Errors
    /// Returns an error if the parent tracker is already closed
    ///
    /// # Example
    /// ```rust
    /// # use std::sync::Arc;
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # fn example(root_tracker: TaskTracker) -> anyhow::Result<()> {
    /// let child_tracker = root_tracker.child_tracker()?;
    /// // Child inherits parent's policies but has separate metrics and lifecycle
    /// # Ok(())
    /// # }
    /// ```
    pub fn child_tracker(&self) -> anyhow::Result<TaskTracker> {
        Ok(TaskTracker(self.0.child_tracker()?))
    }

    /// Spawn a new task
    ///
    /// The task will be wrapped with scheduling and error handling logic,
    /// then executed according to the configured policies. For tasks that
    /// need to inspect cancellation tokens, use [`spawn_cancellable`] instead.
    ///
    /// # Arguments
    /// * `future` - The async task to execute
    ///
    /// # Returns
    /// A [`TaskHandle`] that can be used to await completion and access the task's cancellation token
    ///
    /// # Panics
    /// Panics if the tracker has been closed. This indicates a programming error
    /// where tasks are being spawned after the tracker lifecycle has ended.
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
    /// let handle = tracker.spawn(async {
    ///     // Your async work here
    ///     tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    ///     Ok(42)
    /// });
    ///
    /// // Access the task's cancellation token
    /// let cancel_token = handle.cancellation_token();
    ///
    /// let result = handle.await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn spawn<F, T>(&self, future: F) -> TaskHandle<T>
    where
        F: Future<Output = Result<T>> + Send + 'static,
        T: Send + 'static,
    {
        self.0
            .spawn(future)
            .expect("TaskTracker must not be closed when spawning tasks")
    }

    /// Spawn a cancellable task that receives a cancellation token
    ///
    /// This is useful for tasks that need to inspect the cancellation token
    /// and gracefully handle cancellation within their logic. The task function
    /// must return a `CancellableTaskResult` to properly track cancellation vs errors.
    ///
    /// # Arguments
    ///
    /// * `task_fn` - Function that takes a cancellation token and returns a future that resolves to `CancellableTaskResult<T>`
    ///
    /// # Returns
    /// A [`TaskHandle`] that can be used to await completion and access the task's cancellation token
    ///
    /// # Panics
    /// Panics if the tracker has been closed. This indicates a programming error
    /// where tasks are being spawned after the tracker lifecycle has ended.
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, CancellableTaskResult};
    /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
    /// let handle = tracker.spawn_cancellable(|cancel_token| async move {
    ///     tokio::select! {
    ///         _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
    ///             CancellableTaskResult::Ok(42)
    ///         },
    ///         _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
    ///     }
    /// });
    ///
    /// // Access the task's individual cancellation token
    /// let task_cancel_token = handle.cancellation_token();
    ///
    /// let result = handle.await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn spawn_cancellable<F, Fut, T>(&self, task_fn: F) -> TaskHandle<T>
    where
        F: FnMut(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
        T: Send + 'static,
    {
        self.0
            .spawn_cancellable(task_fn)
            .expect("TaskTracker must not be closed when spawning tasks")
    }

    /// Get metrics for this tracker
    ///
    /// Counters cover tasks spawned on this tracker and — because child
    /// trackers chain their increments upward — tasks spawned on any
    /// descendant tracker. They never include parent or sibling activity.
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # fn example(tracker: &TaskTracker) {
    /// let metrics = tracker.metrics();
    /// println!("Success: {}, Failed: {}", metrics.success(), metrics.failed());
    /// # }
    /// ```
    pub fn metrics(&self) -> &dyn HierarchicalTaskMetrics {
        self.0.metrics.as_ref()
    }

    /// Cancel this tracker and all its tasks
    ///
    /// This will signal cancellation to all currently running tasks and prevent new tasks from being spawned.
    /// The cancellation is immediate and forceful.
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
    /// // Spawn a long-running task
    /// let handle = tracker.spawn_cancellable(|cancel_token| async move {
    ///     tokio::select! {
    ///         _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
    ///             dynamo_runtime::utils::tasks::tracker::CancellableTaskResult::Ok(42)
    ///         }
    ///         _ = cancel_token.cancelled() => {
    ///             dynamo_runtime::utils::tasks::tracker::CancellableTaskResult::Cancelled
    ///         }
    ///     }
    /// });
    ///
    /// // Cancel the tracker (and thus the task), then wait for it to finish
    /// tracker.cancel();
    /// let _result = handle.await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn cancel(&self) {
        self.0.cancel();
    }

    /// Check if this tracker is closed
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }

    /// Get the cancellation token for this tracker
    ///
    /// This allows external code to observe or trigger cancellation of this tracker.
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # fn example(tracker: &TaskTracker) {
    /// let token = tracker.cancellation_token();
    /// // Can check cancellation state or cancel manually
    /// if !token.is_cancelled() {
    ///     token.cancel();
    /// }
    /// # }
    /// ```
    pub fn cancellation_token(&self) -> CancellationToken {
        self.0.cancellation_token()
    }

    /// Get the number of active child trackers
    ///
    /// This counts only child trackers that are still alive (not dropped).
    /// Dropped child trackers are automatically cleaned up.
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # fn example(tracker: &TaskTracker) {
    /// let child_count = tracker.child_count();
    /// println!("This tracker has {} active children", child_count);
    /// # }
    /// ```
    pub fn child_count(&self) -> usize {
        self.0.child_count()
    }

    /// Create a child tracker builder with custom configuration
    ///
    /// This provides fine-grained control over child tracker creation,
    /// allowing you to override the scheduler or error policy while
    /// maintaining the parent-child relationship.
    ///
    /// # Example
    /// ```rust
    /// # use std::sync::Arc;
    /// # use tokio::sync::Semaphore;
    /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
    /// # fn example(parent: &TaskTracker) {
    /// // Custom scheduler, inherit error policy
    /// let child1 = parent.child_tracker_builder()
    ///     .scheduler(SemaphoreScheduler::with_permits(5))
    ///     .build().unwrap();
    ///
    /// // Custom error policy, inherit scheduler
    /// let child2 = parent.child_tracker_builder()
    ///     .error_policy(LogOnlyPolicy::new())
    ///     .build().unwrap();
    ///
    /// // Inherit both policies from parent
    /// let child3 = parent.child_tracker_builder()
    ///     .build().unwrap();
    /// # }
    /// ```
    pub fn child_tracker_builder(&self) -> ChildTrackerBuilder<'_> {
        ChildTrackerBuilder::new(self)
    }

    /// Join this tracker and all child trackers
    ///
    /// This method gracefully shuts down the entire tracker hierarchy by:
    /// 1. Closing all trackers (preventing new task spawning)
    /// 2. Waiting for all existing tasks to complete
    ///
    /// Uses stack-safe traversal to prevent stack overflow in deep hierarchies.
    /// Children are processed before parents to ensure proper shutdown order.
    ///
    /// **Hierarchical Behavior:**
    /// - Processes children before parents to ensure proper shutdown order
    /// - Each tracker is closed before waiting (Tokio requirement)
    /// - Leaf trackers simply close and wait for their own tasks
    ///
    /// # Example
    /// ```rust
    /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
    /// # async fn example(tracker: TaskTracker) {
    /// tracker.join().await;
    /// # }
    /// ```
    pub async fn join(&self) {
        self.0.join().await
    }
}
2372
2373impl TaskTrackerInner {
    /// Creates child tracker with inherited scheduler/policy, independent metrics, and hierarchical cancellation
    ///
    /// The child:
    /// - shares the parent's scheduler `Arc`
    /// - gets a fresh error policy via `OnErrorPolicy::create_child()`
    /// - chains metric increments to the parent via `ChildTaskMetrics`
    /// - receives a token derived from the parent's, so cancelling the
    ///   parent cancels the child but not vice versa
    ///
    /// # Errors
    /// Fails if this tracker has already been closed.
    fn child_tracker(self: &Arc<Self>) -> anyhow::Result<Arc<TaskTrackerInner>> {
        // Validate that parent tracker is still active
        if self.is_closed() {
            return Err(anyhow::anyhow!(
                "Cannot create child tracker from closed parent tracker"
            ));
        }

        let child_cancel_token = self.cancel_token.child_token();
        let child_metrics = Arc::new(ChildTaskMetrics::new(self.metrics.clone()));

        let child = Arc::new(TaskTrackerInner {
            tokio_tracker: TokioTaskTracker::new(),
            parent: Some(self.clone()),
            scheduler: self.scheduler.clone(),
            error_policy: self.error_policy.create_child(),
            metrics: child_metrics,
            cancel_token: child_cancel_token,
            children: RwLock::new(Vec::new()),
        });

        // Register this child with the parent for hierarchical operations
        // (weak reference, so a dropped child does not leak)
        self.children.write().unwrap().push(Arc::downgrade(&child));

        // Periodically clean up dead children to prevent unbounded growth
        self.cleanup_dead_children();

        Ok(child)
    }
2404
    /// Spawn implementation - validates tracker state, generates task ID, applies policies, and tracks execution
    ///
    /// Returns `TaskError::TrackerClosed` instead of panicking when the
    /// tracker is closed; the public `TaskTracker::spawn` wrapper converts
    /// that error into a panic. The `issued` counter is bumped before the
    /// task is handed to tokio, so `issued` always leads `started`.
    fn spawn<F, T>(self: &Arc<Self>, future: F) -> Result<TaskHandle<T>, TaskError>
    where
        F: Future<Output = Result<T>> + Send + 'static,
        T: Send + 'static,
    {
        // Validate tracker is not closed
        if self.tokio_tracker.is_closed() {
            return Err(TaskError::TrackerClosed);
        }

        // Generate a unique task ID
        let task_id = self.generate_task_id();

        // Increment issued counter immediately when task is submitted
        self.metrics.increment_issued();

        // Create a child cancellation token for this specific task
        // (cancelling the tracker cancels the task; not vice versa)
        let task_cancel_token = self.cancel_token.child_token();
        let cancel_token = task_cancel_token.clone();

        // Clone the inner Arc to move into the task
        let inner = self.clone();

        // Wrap the user's future with our scheduling and error handling
        let wrapped_future =
            async move { Self::execute_with_policies(task_id, future, cancel_token, inner).await };

        // Let tokio handle the actual task tracking
        let join_handle = self.tokio_tracker.spawn(wrapped_future);

        // Wrap in TaskHandle with the child cancellation token
        Ok(TaskHandle::new(join_handle, task_cancel_token))
    }
2439
    /// Spawn cancellable implementation - validates state, provides cancellation token, handles CancellableTaskResult
    ///
    /// Mirrors `spawn`, but takes a task *factory* (`FnMut(CancellationToken) -> Fut`)
    /// so the future is only created inside the execution pipeline, after a
    /// scheduling guard has been acquired.
    fn spawn_cancellable<F, Fut, T>(
        self: &Arc<Self>,
        task_fn: F,
    ) -> Result<TaskHandle<T>, TaskError>
    where
        F: FnMut(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
        T: Send + 'static,
    {
        // Validate tracker is not closed
        if self.tokio_tracker.is_closed() {
            return Err(TaskError::TrackerClosed);
        }

        // Generate a unique task ID
        let task_id = self.generate_task_id();

        // Increment issued counter immediately when task is submitted
        self.metrics.increment_issued();

        // Create a child cancellation token for this specific task
        let task_cancel_token = self.cancel_token.child_token();
        let cancel_token = task_cancel_token.clone();

        // Clone the inner Arc to move into the task
        let inner = self.clone();

        // Use the new execution pipeline that defers task creation until after guard acquisition
        let wrapped_future = async move {
            Self::execute_cancellable_with_policies(task_id, task_fn, cancel_token, inner).await
        };

        // Let tokio handle the actual task tracking
        let join_handle = self.tokio_tracker.spawn(wrapped_future);

        // Wrap in TaskHandle with the child cancellation token
        Ok(TaskHandle::new(join_handle, task_cancel_token))
    }
2479
    /// Cancel this tracker and all its tasks - implementation
    ///
    /// Order matters: the tracker is closed first so no new task can slip in
    /// after the token fires. Per-task and child-tracker tokens are derived
    /// from `cancel_token`, so cancelling it propagates to all of them.
    fn cancel(&self) {
        // Close the tracker to prevent new tasks
        self.tokio_tracker.close();

        // Cancel our own token
        self.cancel_token.cancel();
    }
2488
    /// Returns true if the underlying tokio tracker is closed
    /// (i.e. no further tasks may be spawned on this tracker).
    fn is_closed(&self) -> bool {
        self.tokio_tracker.is_closed()
    }
2493
    /// Generates a unique task ID using TaskId::new()
    /// (uniqueness is delegated entirely to `TaskId::new`).
    fn generate_task_id(&self) -> TaskId {
        TaskId::new()
    }
2498
2499 /// Removes dead weak references from children list to prevent memory leaks
2500 fn cleanup_dead_children(&self) {
2501 let mut children_guard = self.children.write().unwrap();
2502 children_guard.retain(|weak| weak.upgrade().is_some());
2503 }
2504
    /// Returns a clone of the cancellation token
    /// (token clones share state, so cancelling either side is observed by both).
    fn cancellation_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }
2509
2510 /// Counts active child trackers (filters out dead weak references)
2511 fn child_count(&self) -> usize {
2512 let children_guard = self.children.read().unwrap();
2513 children_guard
2514 .iter()
2515 .filter(|weak| weak.upgrade().is_some())
2516 .count()
2517 }
2518
    /// Join implementation - closes all trackers in hierarchy then waits for task completion using stack-safe traversal
    async fn join(self: &Arc<Self>) {
        // Fast path for leaf trackers (no children): skips the hierarchy
        // collection allocation. The read guard is confined to this block so
        // it is not held across any await point.
        let is_leaf = {
            let children_guard = self.children.read().unwrap();
            children_guard.is_empty()
        };

        if is_leaf {
            self.tokio_tracker.close();
            self.tokio_tracker.wait().await;
            return;
        }

        // Stack-safe traversal for deep hierarchies
        // Processes children before parents to ensure proper shutdown order
        let trackers = self.collect_hierarchy();
        for t in trackers {
            // Each tracker must be closed before waiting (Tokio requirement).
            t.tokio_tracker.close();
            t.tokio_tracker.wait().await;
        }
    }
2541
2542 /// Collects hierarchy using iterative DFS, returns Vec in post-order (children before parents) for safe shutdown
2543 fn collect_hierarchy(self: &Arc<TaskTrackerInner>) -> Vec<Arc<TaskTrackerInner>> {
2544 let mut result = Vec::new();
2545 let mut stack = vec![self.clone()];
2546 let mut visited = HashSet::new();
2547
2548 // Collect all trackers using depth-first search
2549 while let Some(tracker) = stack.pop() {
2550 let tracker_ptr = Arc::as_ptr(&tracker) as usize;
2551 if visited.contains(&tracker_ptr) {
2552 continue;
2553 }
2554 visited.insert(tracker_ptr);
2555
2556 // Add current tracker to result
2557 result.push(tracker.clone());
2558
2559 // Add children to stack for processing
2560 if let Ok(children_guard) = tracker.children.read() {
2561 for weak_child in children_guard.iter() {
2562 if let Some(child) = weak_child.upgrade() {
2563 let child_ptr = Arc::as_ptr(&child) as usize;
2564 if !visited.contains(&child_ptr) {
2565 stack.push(child);
2566 }
2567 }
2568 }
2569 }
2570 }
2571
2572 // Reverse to get bottom-up order (children before parents)
2573 result.reverse();
2574 result
2575 }
2576
2577 /// Execute a regular task with scheduling and error handling policies
2578 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2579 async fn execute_with_policies<F, T>(
2580 task_id: TaskId,
2581 future: F,
2582 task_cancel_token: CancellationToken,
2583 inner: Arc<TaskTrackerInner>,
2584 ) -> Result<T, TaskError>
2585 where
2586 F: Future<Output = Result<T>> + Send + 'static,
2587 T: Send + 'static,
2588 {
2589 // Wrap regular future in a task executor that doesn't support retry
2590 let task_executor = RegularTaskExecutor::new(future);
2591 Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
2592 }
2593
2594 /// Execute a cancellable task with scheduling and error handling policies
2595 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2596 async fn execute_cancellable_with_policies<F, Fut, T>(
2597 task_id: TaskId,
2598 task_fn: F,
2599 task_cancel_token: CancellationToken,
2600 inner: Arc<TaskTrackerInner>,
2601 ) -> Result<T, TaskError>
2602 where
2603 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2604 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2605 T: Send + 'static,
2606 {
2607 // Wrap cancellable task function in a task executor that supports retry
2608 let task_executor = CancellableTaskExecutor::new(task_fn);
2609 Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
2610 }
2611
    /// Core execution loop with retry support - unified for both task types.
    ///
    /// Repeatedly: acquires a scheduler slot, runs the current executable
    /// (the original executor or a policy-supplied continuation), and routes
    /// failures through `handle_task_error` until the task succeeds, fails,
    /// or is cancelled. Continuation results arrive type-erased and are
    /// downcast back to `T`.
    #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
    async fn execute_with_retry_loop<E, T>(
        task_id: TaskId,
        initial_executor: E,
        task_cancellation_token: CancellationToken,
        inner: Arc<TaskTrackerInner>,
    ) -> Result<T, TaskError>
    where
        E: TaskExecutor<T> + Send + 'static,
        T: Send + 'static,
    {
        debug!("Starting task execution");

        // Marks the task as "started" in the metrics exactly once, on the
        // first successful resource acquisition.
        // NOTE(review): no `Drop` impl is visible for this guard despite
        // later comments about decrementing on drop — presumably the
        // metrics derive the active count from started/finished counters;
        // confirm before relying on it.
        struct ActiveCountGuard {
            metrics: Arc<dyn HierarchicalTaskMetrics>,
            is_active: bool,
        }

        impl ActiveCountGuard {
            // Starts inactive; nothing is recorded until `activate`.
            fn new(metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
                Self {
                    metrics,
                    is_active: false,
                }
            }

            // Idempotent: only the first call increments the started count.
            fn activate(&mut self) {
                if !self.is_active {
                    self.metrics.increment_started();
                    self.is_active = true;
                }
            }
        }

        // Current executable - either the original TaskExecutor or a
        // type-erased Continuation supplied by the error policy.
        enum CurrentExecutable<E>
        where
            E: Send + 'static,
        {
            TaskExecutor(E),
            Continuation(Arc<dyn Continuation + Send + Sync + 'static>),
        }

        let mut current_executable = CurrentExecutable::TaskExecutor(initial_executor);
        let mut active_guard = ActiveCountGuard::new(inner.metrics.clone());
        let mut error_context: Option<OnErrorContext> = None;
        let mut scheduler_guard_state = self::GuardState::Keep;
        // Initial slot acquisition, before entering the loop.
        // NOTE(review): the span is named "scheduler_resource_reacquisition"
        // even though this is the FIRST acquisition (likely copy-paste from
        // the reacquisition path below), and it derives its token from
        // `task_cancellation_token` while reacquisition derives from
        // `inner.cancel_token` — confirm both are intended.
        let mut guard_result = async {
            inner
                .scheduler
                .acquire_execution_slot(task_cancellation_token.child_token())
                .await
        }
        .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
        .await;

        loop {
            // Only re-acquire a slot when the previous error-handling round
            // told us the old guard must not be reused.
            if scheduler_guard_state == self::GuardState::Reschedule {
                guard_result = async {
                    inner
                        .scheduler
                        .acquire_execution_slot(inner.cancel_token.child_token())
                        .await
                }
                .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
                .await;
            }

            match &guard_result {
                SchedulingResult::Execute(_guard) => {
                    // Activate the RAII guard only once when we successfully acquire resources
                    active_guard.activate();

                    // Execute the current executable while holding the guard (RAII pattern)
                    let execution_result = async {
                        debug!("Executing task with acquired resources");
                        match &mut current_executable {
                            CurrentExecutable::TaskExecutor(executor) => {
                                executor.execute(inner.cancel_token.child_token()).await
                            }
                            CurrentExecutable::Continuation(continuation) => {
                                // Execute continuation and handle type erasure
                                match continuation.execute(inner.cancel_token.child_token()).await {
                                    TaskExecutionResult::Success(result) => {
                                        // Try to downcast the result to the expected type T
                                        if let Ok(typed_result) = result.downcast::<T>() {
                                            TaskExecutionResult::Success(*typed_result)
                                        } else {
                                            // Type mismatch - this shouldn't happen with proper usage
                                            let type_error = anyhow::anyhow!(
                                                "Continuation task returned wrong type"
                                            );
                                            error!(
                                                ?type_error,
                                                "Type mismatch in continuation task result"
                                            );
                                            TaskExecutionResult::Error(type_error)
                                        }
                                    }
                                    TaskExecutionResult::Cancelled => {
                                        TaskExecutionResult::Cancelled
                                    }
                                    TaskExecutionResult::Error(error) => {
                                        TaskExecutionResult::Error(error)
                                    }
                                }
                            }
                        }
                    }
                    .instrument(tracing::debug_span!("task_execution"))
                    .await;

                    // NOTE(review): the original comment claimed the active
                    // counter decrements when `active_guard` drops, but no
                    // Drop impl is visible in this function — confirm.

                    match execution_result {
                        TaskExecutionResult::Success(value) => {
                            inner.metrics.increment_success();
                            debug!("Task completed successfully");
                            return Ok(value);
                        }
                        TaskExecutionResult::Cancelled => {
                            inner.metrics.increment_cancelled();
                            debug!("Task was cancelled during execution");
                            return Err(TaskError::Cancelled);
                        }
                        TaskExecutionResult::Error(error) => {
                            debug!("Task failed - handling error through policy - {error:?}");

                            // Handle the error through the policy system
                            let (action_result, guard_state) = Self::handle_task_error(
                                &error,
                                &mut error_context,
                                task_id,
                                &inner,
                            )
                            .await;

                            // Update the scheduler guard state for evaluation after the match
                            scheduler_guard_state = guard_state;

                            match action_result {
                                ActionResult::Fail => {
                                    inner.metrics.increment_failed();
                                    debug!("Policy accepted error - task failed {error:?}");
                                    return Err(TaskError::Failed(error));
                                }
                                ActionResult::Shutdown => {
                                    inner.metrics.increment_failed();
                                    warn!("Policy triggered shutdown - {error:?}");
                                    inner.cancel();
                                    return Err(TaskError::Failed(error));
                                }
                                ActionResult::Continue { continuation } => {
                                    debug!(
                                        "Policy provided next executable - continuing loop - {error:?}"
                                    );

                                    // Update current executable
                                    current_executable =
                                        CurrentExecutable::Continuation(continuation);

                                    continue; // Continue the main loop with the new executable
                                }
                            }
                        }
                    }
                }
                SchedulingResult::Cancelled => {
                    inner.metrics.increment_cancelled();
                    debug!("Task was cancelled during resource acquisition");
                    return Err(TaskError::Cancelled);
                }
                SchedulingResult::Rejected(reason) => {
                    inner.metrics.increment_rejected();
                    debug!(reason, "Task was rejected by scheduler");
                    return Err(TaskError::Failed(anyhow::anyhow!(
                        "Task rejected: {}",
                        reason
                    )));
                }
            }
        }
    }
2797
2798 /// Handle task errors through the error policy and return the action to take
2799 async fn handle_task_error(
2800 error: &anyhow::Error,
2801 error_context: &mut Option<OnErrorContext>,
2802 task_id: TaskId,
2803 inner: &Arc<TaskTrackerInner>,
2804 ) -> (ActionResult, self::GuardState) {
2805 // Create or update the error context (lazy initialization)
2806 let context = error_context.get_or_insert_with(|| OnErrorContext {
2807 attempt_count: 0, // Will be incremented below
2808 task_id,
2809 execution_context: TaskExecutionContext {
2810 scheduler: inner.scheduler.clone(),
2811 metrics: inner.metrics.clone(),
2812 },
2813 state: inner.error_policy.create_context(),
2814 });
2815
2816 // Increment attempt count for this error
2817 context.attempt_count += 1;
2818 let current_attempt = context.attempt_count;
2819
2820 // First, check if the policy allows continuations for this error
2821 if inner.error_policy.allow_continuation(error, context) {
2822 // Policy allows continuations, check if this is a FailedWithContinuation (task-driven continuation)
2823 if let Some(continuation_err) = error.downcast_ref::<FailedWithContinuation>() {
2824 debug!(
2825 task_id = %task_id,
2826 attempt_count = current_attempt,
2827 "Task provided FailedWithContinuation and policy allows continuations - {error:?}"
2828 );
2829
2830 // Task has provided a continuation implementation for the next attempt
2831 // Clone the Arc to return it in ActionResult::Continue
2832 let continuation = continuation_err.continuation.clone();
2833
2834 // Ask policy whether to reschedule task-driven continuation
2835 let should_reschedule = inner.error_policy.should_reschedule(error, context);
2836
2837 let guard_state = if should_reschedule {
2838 self::GuardState::Reschedule
2839 } else {
2840 self::GuardState::Keep
2841 };
2842
2843 return (ActionResult::Continue { continuation }, guard_state);
2844 }
2845 } else {
2846 debug!(
2847 task_id = %task_id,
2848 attempt_count = current_attempt,
2849 "Policy rejected continuations, ignoring any FailedWithContinuation - {error:?}"
2850 );
2851 }
2852
2853 let response = inner.error_policy.on_error(error, context);
2854
2855 match response {
2856 ErrorResponse::Fail => (ActionResult::Fail, self::GuardState::Keep),
2857 ErrorResponse::Shutdown => (ActionResult::Shutdown, self::GuardState::Keep),
2858 ErrorResponse::Custom(action) => {
2859 debug!("Task failed - executing custom action - {error:?}");
2860
2861 // Execute the custom action asynchronously
2862 let action_result = action
2863 .execute(error, task_id, current_attempt, &context.execution_context)
2864 .await;
2865 debug!(?action_result, "Custom action completed");
2866
2867 // If the custom action returned Continue, ask policy about rescheduling
2868 let guard_state = match &action_result {
2869 ActionResult::Continue { .. } => {
2870 let should_reschedule =
2871 inner.error_policy.should_reschedule(error, context);
2872 if should_reschedule {
2873 self::GuardState::Reschedule
2874 } else {
2875 self::GuardState::Keep
2876 }
2877 }
2878 _ => self::GuardState::Keep, // Fail/Shutdown don't need guard state
2879 };
2880
2881 (action_result, guard_state)
2882 }
2883 }
2884 }
2885}
2886
// Marker `ArcPolicy` implementations for the built-in schedulers
impl ArcPolicy for UnlimitedScheduler {}
impl ArcPolicy for SemaphoreScheduler {}

// Marker `ArcPolicy` implementations for the built-in error policies
impl ArcPolicy for LogOnlyPolicy {}
impl ArcPolicy for CancelOnError {}
impl ArcPolicy for ThresholdCancelPolicy {}
impl ArcPolicy for RateCancelPolicy {}
2896
/// Resource guard for unlimited scheduling
///
/// This guard represents "unlimited" resources - no actual resource constraints.
/// Since there are no resources to manage, this guard is essentially a no-op;
/// dropping it releases nothing.
#[derive(Debug)]
pub struct UnlimitedGuard;

impl ResourceGuard for UnlimitedGuard {
    // No resources to manage - marker trait implementation only
}
2907
/// Unlimited task scheduler that executes all tasks immediately
///
/// This scheduler provides no concurrency limits and executes all submitted tasks
/// immediately. Useful for testing, high-throughput scenarios, or when external
/// systems provide the concurrency control.
///
/// ## Cancellation Behavior
///
/// - Respects cancellation tokens before resource acquisition
/// - Once execution begins (via ResourceGuard), always awaits task completion
/// - Tasks handle their own cancellation internally (if created with `spawn_cancellable`)
///
/// # Example
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::UnlimitedScheduler;
/// let scheduler = UnlimitedScheduler::new();
/// ```
// `Default` is derived: the manual impl it replaces was byte-for-byte what
// the derive generates for a unit struct (clippy::derivable_impls).
#[derive(Debug, Default)]
pub struct UnlimitedScheduler;

impl UnlimitedScheduler {
    /// Create a new unlimited scheduler returning Arc
    pub fn new() -> Arc<Self> {
        Arc::new(Self)
    }
}
2940
2941#[async_trait]
2942impl TaskScheduler for UnlimitedScheduler {
2943 async fn acquire_execution_slot(
2944 &self,
2945 cancel_token: CancellationToken,
2946 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
2947 debug!("Acquiring execution slot (unlimited scheduler)");
2948
2949 // Check for cancellation before allocating resources
2950 if cancel_token.is_cancelled() {
2951 debug!("Task cancelled before acquiring execution slot");
2952 return SchedulingResult::Cancelled;
2953 }
2954
2955 // No resource constraints for unlimited scheduler
2956 debug!("Execution slot acquired immediately");
2957 SchedulingResult::Execute(Box::new(UnlimitedGuard))
2958 }
2959}
2960
/// Resource guard for semaphore-based scheduling
///
/// This guard holds a semaphore permit and enforces that task execution
/// always runs to completion. The permit is automatically released when
/// the guard is dropped.
#[derive(Debug)]
pub struct SemaphoreGuard {
    // Held purely for its Drop behavior (releases the permit); never read.
    _permit: tokio::sync::OwnedSemaphorePermit,
}

impl ResourceGuard for SemaphoreGuard {
    // Permit is automatically released when the guard is dropped
}
2974
2975/// Semaphore-based task scheduler
2976///
2977/// Limits concurrent task execution using a [`tokio::sync::Semaphore`].
2978/// Tasks will wait for an available permit before executing.
2979///
2980/// ## Cancellation Behavior
2981///
2982/// - Respects cancellation tokens before and during permit acquisition
2983/// - Once a permit is acquired (via ResourceGuard), always awaits task completion
2984/// - Holds the permit until the task completes (regardless of cancellation)
2985/// - Tasks handle their own cancellation internally (if created with `spawn_cancellable`)
2986///
2987/// This ensures that permits are not leaked when tasks are cancelled, while still
2988/// allowing cancellable tasks to terminate gracefully on their own.
2989///
2990/// # Example
2991/// ```rust
2992/// # use std::sync::Arc;
2993/// # use tokio::sync::Semaphore;
2994/// # use dynamo_runtime::utils::tasks::tracker::SemaphoreScheduler;
2995/// // Allow up to 5 concurrent tasks
2996/// let semaphore = Arc::new(Semaphore::new(5));
2997/// let scheduler = SemaphoreScheduler::new(semaphore);
2998/// ```
2999#[derive(Debug)]
3000pub struct SemaphoreScheduler {
3001 semaphore: Arc<Semaphore>,
3002}
3003
3004impl SemaphoreScheduler {
3005 /// Create a new semaphore scheduler
3006 ///
3007 /// # Arguments
3008 /// * `semaphore` - Semaphore to use for concurrency control
3009 pub fn new(semaphore: Arc<Semaphore>) -> Self {
3010 Self { semaphore }
3011 }
3012
3013 /// Create a semaphore scheduler with the specified number of permits, returning Arc
3014 pub fn with_permits(permits: usize) -> Arc<Self> {
3015 Arc::new(Self::new(Arc::new(Semaphore::new(permits))))
3016 }
3017
3018 /// Get the number of available permits
3019 pub fn available_permits(&self) -> usize {
3020 self.semaphore.available_permits()
3021 }
3022}
3023
3024#[async_trait]
3025impl TaskScheduler for SemaphoreScheduler {
3026 async fn acquire_execution_slot(
3027 &self,
3028 cancel_token: CancellationToken,
3029 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
3030 debug!("Acquiring semaphore permit");
3031
3032 // Check for cancellation before attempting to acquire semaphore
3033 if cancel_token.is_cancelled() {
3034 debug!("Task cancelled before acquiring semaphore permit");
3035 return SchedulingResult::Cancelled;
3036 }
3037
3038 // Try to acquire a permit, with cancellation support
3039 let permit = {
3040 tokio::select! {
3041 result = self.semaphore.clone().acquire_owned() => {
3042 match result {
3043 Ok(permit) => permit,
3044 Err(_) => return SchedulingResult::Cancelled,
3045 }
3046 }
3047 _ = cancel_token.cancelled() => {
3048 debug!("Task cancelled while waiting for semaphore permit");
3049 return SchedulingResult::Cancelled;
3050 }
3051 }
3052 };
3053
3054 debug!("Acquired semaphore permit");
3055 SchedulingResult::Execute(Box::new(SemaphoreGuard { _permit: permit }))
3056 }
3057}
3058
/// Error policy that triggers cancellation based on error patterns
///
/// This policy analyzes error messages and returns `ErrorResponse::Shutdown` when:
/// - No patterns are specified (cancels on any error)
/// - Error message matches one of the specified patterns
///
/// The TaskTracker handles the actual cancellation - this policy just makes the decision.
///
/// # Example
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::CancelOnError;
/// // Cancel on any error
/// let policy = CancelOnError::new();
///
/// // Cancel only on specific error patterns
/// let (policy, _token) = CancelOnError::with_patterns(
///     vec!["OutOfMemory".to_string(), "DeviceError".to_string()]
/// );
/// ```
#[derive(Debug)]
pub struct CancelOnError {
    // Substrings matched against `error.to_string()`; empty = match all.
    error_patterns: Vec<String>,
}

impl CancelOnError {
    /// Create a new cancel-on-error policy that cancels on any error
    ///
    /// Returns a policy with no error patterns, meaning it will cancel the TaskTracker
    /// on any task failure.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            error_patterns: vec![], // Empty patterns = cancel on any error
        })
    }

    /// Create a new cancel-on-error policy with custom error patterns, returning Arc and token
    ///
    /// NOTE(review): the returned token is freshly created here and is never
    /// stored by, or cancelled from, the policy itself — confirm whether
    /// callers are expected to wire it up, or whether this is vestigial.
    ///
    /// # Arguments
    /// * `error_patterns` - List of error message patterns that trigger cancellation
    pub fn with_patterns(error_patterns: Vec<String>) -> (Arc<Self>, CancellationToken) {
        let token = CancellationToken::new();
        let policy = Arc::new(Self { error_patterns });
        (policy, token)
    }
}
3104
3105#[async_trait]
3106impl OnErrorPolicy for CancelOnError {
3107 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3108 // Child gets a child cancel token - when parent cancels, child cancels too
3109 // When child cancels, parent is unaffected
3110 Arc::new(CancelOnError {
3111 error_patterns: self.error_patterns.clone(),
3112 })
3113 }
3114
3115 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3116 None // Stateless policy - no heap allocation
3117 }
3118
3119 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3120 error!(?context.task_id, "Task failed - {error:?}");
3121
3122 if self.error_patterns.is_empty() {
3123 return ErrorResponse::Shutdown;
3124 }
3125
3126 // Check if this error should trigger cancellation
3127 let error_str = error.to_string();
3128 let should_cancel = self
3129 .error_patterns
3130 .iter()
3131 .any(|pattern| error_str.contains(pattern));
3132
3133 if should_cancel {
3134 ErrorResponse::Shutdown
3135 } else {
3136 ErrorResponse::Fail
3137 }
3138 }
3139}
3140
/// Simple error policy that only logs errors
///
/// This policy does not trigger cancellation and is useful for
/// non-critical tasks or when you want to handle errors externally.
// `Default` is derived: the manual impl it replaces was byte-for-byte what
// the derive generates for a unit struct (clippy::derivable_impls).
#[derive(Debug, Default)]
pub struct LogOnlyPolicy;

impl LogOnlyPolicy {
    /// Create a new log-only policy returning Arc
    pub fn new() -> Arc<Self> {
        Arc::new(Self)
    }
}
3160
3161impl OnErrorPolicy for LogOnlyPolicy {
3162 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3163 // Simple policies can just clone themselves
3164 Arc::new(LogOnlyPolicy)
3165 }
3166
3167 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3168 None // Stateless policy - no heap allocation
3169 }
3170
3171 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3172 error!(?context.task_id, "Task failed - logging only - {error:?}");
3173 ErrorResponse::Fail
3174 }
3175}
3176
/// Error policy that cancels tasks after a threshold number of failures
///
/// This policy tracks the number of failed tasks and triggers cancellation
/// when the failure count exceeds the specified threshold. Useful for
/// preventing cascading failures in distributed systems.
///
/// # Example
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::ThresholdCancelPolicy;
/// // Cancel after 5 failures
/// let policy = ThresholdCancelPolicy::with_threshold(5);
/// ```
#[derive(Debug)]
pub struct ThresholdCancelPolicy {
    max_failures: usize,
    failure_count: AtomicU64,
}

impl ThresholdCancelPolicy {
    /// Build a policy that shuts down once `max_failures` failures are seen,
    /// returning it wrapped in an Arc.
    ///
    /// # Arguments
    /// * `max_failures` - Maximum number of failures before cancellation
    pub fn with_threshold(max_failures: usize) -> Arc<Self> {
        Arc::new(Self {
            failure_count: AtomicU64::new(0),
            max_failures,
        })
    }

    /// Current value of the shared failure counter.
    pub fn failure_count(&self) -> u64 {
        self.failure_count.load(Ordering::Relaxed)
    }

    /// Reset the shared failure counter to zero.
    ///
    /// Primarily useful for tests that want a clean slate between cases.
    pub fn reset_failure_count(&self) {
        self.failure_count.store(0, Ordering::Relaxed);
    }
}
3220
/// Per-task state for ThresholdCancelPolicy
#[derive(Debug)]
struct ThresholdState {
    // Failures recorded for this specific task; this counter (not the
    // policy-wide atomic) drives the shutdown decision in `on_error`.
    failure_count: u32,
}
3226
3227impl OnErrorPolicy for ThresholdCancelPolicy {
3228 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3229 // Child gets a child cancel token and inherits the same failure threshold
3230 Arc::new(ThresholdCancelPolicy {
3231 max_failures: self.max_failures,
3232 failure_count: AtomicU64::new(0), // Child starts with fresh count
3233 })
3234 }
3235
3236 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3237 Some(Box::new(ThresholdState { failure_count: 0 }))
3238 }
3239
3240 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3241 error!(?context.task_id, "Task failed - {error:?}");
3242
3243 // Increment global counter for backwards compatibility
3244 let global_failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
3245
3246 // Get per-task state for the actual decision logic
3247 let state = context
3248 .state
3249 .as_mut()
3250 .expect("ThresholdCancelPolicy requires state")
3251 .downcast_mut::<ThresholdState>()
3252 .expect("Context type mismatch");
3253
3254 state.failure_count += 1;
3255 let current_failures = state.failure_count;
3256
3257 if current_failures >= self.max_failures as u32 {
3258 warn!(
3259 ?context.task_id,
3260 current_failures,
3261 global_failures,
3262 max_failures = self.max_failures,
3263 "Per-task failure threshold exceeded, triggering cancellation"
3264 );
3265 ErrorResponse::Shutdown
3266 } else {
3267 debug!(
3268 ?context.task_id,
3269 current_failures,
3270 global_failures,
3271 max_failures = self.max_failures,
3272 "Task failed, tracking per-task failure count"
3273 );
3274 ErrorResponse::Fail
3275 }
3276 }
3277}
3278
/// Error policy that cancels tasks when failure rate exceeds threshold within time window
///
/// This policy tracks failures over a rolling time window and triggers cancellation
/// when the failure rate exceeds the specified threshold. More sophisticated than
/// simple count-based thresholds as it considers the time dimension.
///
/// # Example
/// ```rust
/// # use dynamo_runtime::utils::tasks::tracker::RateCancelPolicy;
/// // Cancel if more than 50% of tasks fail within any 60-second window
/// let (policy, token) = RateCancelPolicy::builder()
///     .rate(0.5)
///     .window_secs(60)
///     .build();
/// ```
#[derive(Debug)]
pub struct RateCancelPolicy {
    // Currently only propagated to children via `create_child`; the
    // placeholder `on_error` never cancels it.
    cancel_token: CancellationToken,
    max_failure_rate: f32,
    window_secs: u64,
    // TODO: Implement time-window tracking when needed
    // For now, this is a placeholder structure with the interface defined
}

impl RateCancelPolicy {
    /// Create a builder for rate-based cancel policy
    pub fn builder() -> RateCancelPolicyBuilder {
        RateCancelPolicyBuilder::new()
    }
}
3309
3310/// Builder for RateCancelPolicy
3311pub struct RateCancelPolicyBuilder {
3312 max_failure_rate: Option<f32>,
3313 window_secs: Option<u64>,
3314}
3315
3316impl RateCancelPolicyBuilder {
3317 fn new() -> Self {
3318 Self {
3319 max_failure_rate: None,
3320 window_secs: None,
3321 }
3322 }
3323
3324 /// Set the maximum failure rate (0.0 to 1.0) before cancellation
3325 pub fn rate(mut self, max_failure_rate: f32) -> Self {
3326 self.max_failure_rate = Some(max_failure_rate);
3327 self
3328 }
3329
3330 /// Set the time window in seconds for rate calculation
3331 pub fn window_secs(mut self, window_secs: u64) -> Self {
3332 self.window_secs = Some(window_secs);
3333 self
3334 }
3335
3336 /// Build the policy, returning Arc and cancellation token
3337 pub fn build(self) -> (Arc<RateCancelPolicy>, CancellationToken) {
3338 let max_failure_rate = self.max_failure_rate.expect("rate must be set");
3339 let window_secs = self.window_secs.expect("window_secs must be set");
3340
3341 let token = CancellationToken::new();
3342 let policy = Arc::new(RateCancelPolicy {
3343 cancel_token: token.clone(),
3344 max_failure_rate,
3345 window_secs,
3346 });
3347 (policy, token)
3348 }
3349}
3350
3351#[async_trait]
3352impl OnErrorPolicy for RateCancelPolicy {
3353 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3354 Arc::new(RateCancelPolicy {
3355 cancel_token: self.cancel_token.child_token(),
3356 max_failure_rate: self.max_failure_rate,
3357 window_secs: self.window_secs,
3358 })
3359 }
3360
3361 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3362 None // Stateless policy for now (TODO: add time-window state)
3363 }
3364
3365 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3366 error!(?context.task_id, "Task failed - {error:?}");
3367
3368 // TODO: Implement time-window failure rate calculation
3369 // For now, just log the error and continue
3370 warn!(
3371 ?context.task_id,
3372 max_failure_rate = self.max_failure_rate,
3373 window_secs = self.window_secs,
3374 "Rate-based error policy - time window tracking not yet implemented"
3375 );
3376
3377 ErrorResponse::Fail
3378 }
3379}
3380
/// Custom action that triggers a cancellation token when executed
///
/// This action demonstrates the ErrorResponse::Custom behavior by capturing
/// an external cancellation token and triggering it when executed.
#[derive(Debug)]
pub struct TriggerCancellationTokenAction {
    // External token to fire when the action runs.
    cancel_token: CancellationToken,
}

impl TriggerCancellationTokenAction {
    /// Wrap the given token; it is cancelled when `execute` runs.
    pub fn new(cancel_token: CancellationToken) -> Self {
        Self { cancel_token }
    }
}
3395
#[async_trait]
impl OnErrorAction for TriggerCancellationTokenAction {
    /// Fire the captured token, then request tracker shutdown.
    async fn execute(
        &self,
        error: &anyhow::Error,
        task_id: TaskId,
        _attempt_count: u32,
        _context: &TaskExecutionContext,
    ) -> ActionResult {
        warn!(
            ?task_id,
            "Executing custom action: triggering cancellation token - {error:?}"
        );

        // Trigger the custom cancellation token
        self.cancel_token.cancel();

        // Escalate to Shutdown after firing the token.
        // NOTE(review): an earlier comment here claimed this "returns
        // success"; the code has always returned Shutdown — Shutdown looks
        // intentional given the token was just cancelled, but confirm.
        ActionResult::Shutdown
    }
}
3417
/// Test error policy that triggers a custom cancellation token on any error
///
/// This policy demonstrates the ErrorResponse::Custom behavior by capturing
/// an external cancellation token and triggering it when any error occurs.
/// Used for testing custom error handling actions.
///
/// # Example
/// ```rust
/// # use tokio_util::sync::CancellationToken;
/// # use dynamo_runtime::utils::tasks::tracker::TriggerCancellationTokenOnError;
/// let cancel_token = CancellationToken::new();
/// let policy = TriggerCancellationTokenOnError::new(cancel_token.clone());
///
/// // Policy will trigger the token on any error via ErrorResponse::Custom
/// ```
#[derive(Debug)]
pub struct TriggerCancellationTokenOnError {
    // External token fired (via TriggerCancellationTokenAction) on any error.
    cancel_token: CancellationToken,
}

impl TriggerCancellationTokenOnError {
    /// Create a new policy that triggers the given cancellation token on errors
    pub fn new(cancel_token: CancellationToken) -> Arc<Self> {
        Arc::new(Self { cancel_token })
    }
}
3444
impl OnErrorPolicy for TriggerCancellationTokenOnError {
    fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
        // Children share the SAME token (clone, not child_token), so an
        // error anywhere in the hierarchy fires the one external token.
        Arc::new(TriggerCancellationTokenOnError {
            cancel_token: self.cancel_token.clone(),
        })
    }

    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
        None // Stateless policy - no heap allocation
    }

    /// Always responds with a Custom action that fires the captured token.
    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
        error!(
            ?context.task_id,
            "Task failed - triggering custom cancellation token - {error:?}"
        );

        // Create the custom action that will trigger our token
        let action = TriggerCancellationTokenAction::new(self.cancel_token.clone());

        // Return Custom response with our action
        ErrorResponse::Custom(Box::new(action))
    }
}
3470
3471#[cfg(test)]
3472mod tests {
3473 use super::*;
3474 use rstest::*;
3475 use std::sync::atomic::AtomicU32;
3476 use std::time::Duration;
3477
3478 // Test fixtures using rstest
3479 #[fixture]
3480 fn semaphore_scheduler() -> Arc<SemaphoreScheduler> {
3481 Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))))
3482 }
3483
    // Fixture: scheduler with no concurrency limits.
    #[fixture]
    fn unlimited_scheduler() -> Arc<UnlimitedScheduler> {
        UnlimitedScheduler::new()
    }
3488
    // Fixture: error policy that only logs failures.
    #[fixture]
    fn log_policy() -> Arc<LogOnlyPolicy> {
        LogOnlyPolicy::new()
    }
3493
    // Fixture: error policy that shuts the tracker down on any error.
    #[fixture]
    fn cancel_policy() -> Arc<CancelOnError> {
        CancelOnError::new()
    }
3498
    // Fixture: tracker with unlimited scheduling and log-only error handling.
    #[fixture]
    fn basic_tracker(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        log_policy: Arc<LogOnlyPolicy>,
    ) -> TaskTracker {
        TaskTracker::new(unlimited_scheduler, log_policy).unwrap()
    }
3506
    #[rstest]
    #[tokio::test]
    async fn test_basic_task_execution(basic_tracker: TaskTracker) {
        // A successful task should return its value and count as one success.
        let (tx, rx) = tokio::sync::oneshot::channel();
        let handle = basic_tracker.spawn(async {
            // Wait for an explicit signal instead of sleeping (deterministic)
            rx.await.ok();
            Ok(42)
        });

        // Signal task to complete
        tx.send(()).ok();

        // Verify task completes successfully
        let result = handle
            .await
            .expect("Task should complete")
            .expect("Task should succeed");
        assert_eq!(result, 42);

        // Verify metrics: exactly one success, nothing failed/cancelled/active
        assert_eq!(basic_tracker.metrics().success(), 1);
        assert_eq!(basic_tracker.metrics().failed(), 0);
        assert_eq!(basic_tracker.metrics().cancelled(), 0);
        assert_eq!(basic_tracker.metrics().active(), 0);
    }
3534
3535 #[rstest]
3536 #[tokio::test]
3537 async fn test_task_failure(
3538 semaphore_scheduler: Arc<SemaphoreScheduler>,
3539 log_policy: Arc<LogOnlyPolicy>,
3540 ) {
3541 // Test task failure handling
3542 let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3543
3544 let handle = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
3545
3546 let result = handle.await.unwrap();
3547 assert!(result.is_err());
3548 assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
3549
3550 // Verify metrics
3551 assert_eq!(tracker.metrics().success(), 0);
3552 assert_eq!(tracker.metrics().failed(), 1);
3553 assert_eq!(tracker.metrics().cancelled(), 0);
3554 }
3555
3556 #[rstest]
3557 #[tokio::test]
3558 async fn test_semaphore_concurrency_limit(log_policy: Arc<LogOnlyPolicy>) {
3559 // Test that semaphore limits concurrent execution
3560 let limited_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2)))); // Only 2 concurrent tasks
3561 let tracker = TaskTracker::new(limited_scheduler, log_policy).unwrap();
3562
3563 let counter = Arc::new(AtomicU32::new(0));
3564 let max_concurrent = Arc::new(AtomicU32::new(0));
3565
3566 // Use broadcast channel to coordinate all tasks
3567 let (tx, _) = tokio::sync::broadcast::channel(1);
3568 let mut handles = Vec::new();
3569
3570 // Spawn 5 tasks that will track concurrency
3571 for _ in 0..5 {
3572 let counter_clone = counter.clone();
3573 let max_clone = max_concurrent.clone();
3574 let mut rx = tx.subscribe();
3575
3576 let handle = tracker.spawn(async move {
3577 // Increment active counter
3578 let current = counter_clone.fetch_add(1, Ordering::Relaxed) + 1;
3579
3580 // Track max concurrent
3581 max_clone.fetch_max(current, Ordering::Relaxed);
3582
3583 // Wait for signal to complete instead of sleep
3584 rx.recv().await.ok();
3585
3586 // Decrement when done
3587 counter_clone.fetch_sub(1, Ordering::Relaxed);
3588
3589 Ok(())
3590 });
3591 handles.push(handle);
3592 }
3593
3594 // Give tasks time to start and register concurrency
3595 tokio::task::yield_now().await;
3596 tokio::task::yield_now().await;
3597
3598 // Signal all tasks to complete
3599 tx.send(()).ok();
3600
3601 // Wait for all tasks to complete
3602 for handle in handles {
3603 handle.await.unwrap().unwrap();
3604 }
3605
3606 // Verify that no more than 2 tasks ran concurrently
3607 assert!(max_concurrent.load(Ordering::Relaxed) <= 2);
3608
3609 // Verify all tasks completed successfully
3610 assert_eq!(tracker.metrics().success(), 5);
3611 assert_eq!(tracker.metrics().failed(), 0);
3612 }
3613
3614 #[rstest]
3615 #[tokio::test]
3616 async fn test_cancel_on_error_policy() {
3617 // Test that CancelOnError policy works correctly
3618 let error_policy = cancel_policy();
3619 let scheduler = semaphore_scheduler();
3620 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
3621
3622 // Spawn a task that will trigger cancellation
3623 let handle =
3624 tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("OutOfMemory error occurred")) });
3625
3626 // Wait for the error to occur
3627 let result = handle.await.unwrap();
3628 assert!(result.is_err());
3629
3630 // Give cancellation time to propagate
3631 tokio::time::sleep(Duration::from_millis(10)).await;
3632
3633 // Verify the cancel token was triggered
3634 assert!(tracker.cancellation_token().is_cancelled());
3635 }
3636
3637 #[rstest]
3638 #[tokio::test]
3639 async fn test_tracker_cancellation() {
3640 // Test manual cancellation of tracker with CancelOnError policy
3641 let error_policy = cancel_policy();
3642 let scheduler = semaphore_scheduler();
3643 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
3644 let cancel_token = tracker.cancellation_token().child_token();
3645
3646 // Use oneshot channel instead of sleep for deterministic timing
3647 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
3648
3649 // Spawn a task that respects cancellation
3650 let handle = tracker.spawn({
3651 let cancel_token = cancel_token.clone();
3652 async move {
3653 tokio::select! {
3654 _ = rx => Ok(()),
3655 _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Task was cancelled")),
3656 }
3657 }
3658 });
3659
3660 // Cancel the tracker
3661 tracker.cancel();
3662
3663 // Task should be cancelled
3664 let result = handle.await.unwrap();
3665 assert!(result.is_err());
3666 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
3667 }
3668
3669 #[rstest]
3670 #[tokio::test]
3671 async fn test_child_tracker_independence(
3672 semaphore_scheduler: Arc<SemaphoreScheduler>,
3673 log_policy: Arc<LogOnlyPolicy>,
3674 ) {
3675 // Test that child tracker has independent lifecycle
3676 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3677
3678 let child = parent.child_tracker().unwrap();
3679
3680 // Both should be operational initially
3681 assert!(!parent.is_closed());
3682 assert!(!child.is_closed());
3683
3684 // Cancel child only
3685 child.cancel();
3686
3687 // Parent should remain operational
3688 assert!(!parent.is_closed());
3689
3690 // Parent can still spawn tasks
3691 let handle = parent.spawn(async { Ok(42) });
3692 let result = handle.await.unwrap().unwrap();
3693 assert_eq!(result, 42);
3694 }
3695
3696 #[rstest]
3697 #[tokio::test]
3698 async fn test_independent_metrics(
3699 semaphore_scheduler: Arc<SemaphoreScheduler>,
3700 log_policy: Arc<LogOnlyPolicy>,
3701 ) {
3702 // Test that parent and child have independent metrics
3703 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3704 let child = parent.child_tracker().unwrap();
3705
3706 // Run tasks in parent
3707 let handle1 = parent.spawn(async { Ok(1) });
3708 handle1.await.unwrap().unwrap();
3709
3710 // Run tasks in child
3711 let handle2 = child.spawn(async { Ok(2) });
3712 handle2.await.unwrap().unwrap();
3713
3714 // Each should have their own metrics, but parent sees aggregated
3715 assert_eq!(parent.metrics().success(), 2); // Parent sees its own + child's
3716 assert_eq!(child.metrics().success(), 1); // Child sees only its own
3717 assert_eq!(parent.metrics().total_completed(), 2); // Parent sees aggregated total
3718 assert_eq!(child.metrics().total_completed(), 1); // Child sees only its own
3719 }
3720
    #[rstest]
    #[tokio::test]
    async fn test_cancel_on_error_hierarchy() {
        // Test that child error policy cancellation doesn't affect parent
        let parent_error_policy = cancel_policy();
        let scheduler = semaphore_scheduler();
        let parent = TaskTracker::new(scheduler, parent_error_policy).unwrap();
        // Child token of the parent's cancellation token: if the parent's token
        // were cancelled, this token would be cancelled too.
        let parent_policy_token = parent.cancellation_token().child_token();
        let child = parent.child_tracker().unwrap();

        // Initially nothing should be cancelled
        assert!(!parent_policy_token.is_cancelled());

        // Use explicit synchronization instead of sleep
        let (error_tx, error_rx) = tokio::sync::oneshot::channel();
        let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();

        // Spawn a monitoring task to watch for the parent policy token cancellation
        // The 100ms sleep arm acts as the deadline: if the token is not cancelled
        // within that window, the monitor reports `false`.
        let parent_token_monitor = parent_policy_token.clone();
        let monitor_handle = tokio::spawn(async move {
            tokio::select! {
                _ = parent_token_monitor.cancelled() => {
                    cancel_tx.send(true).ok();
                }
                _ = tokio::time::sleep(Duration::from_millis(100)) => {
                    cancel_tx.send(false).ok();
                }
            }
        });

        // Spawn a task in the child that will trigger cancellation
        // (the child inherits a CancelOnError-style policy via create_child)
        let handle = child.spawn(async move {
            let result = Err::<(), _>(anyhow::anyhow!("OutOfMemory in child"));
            error_tx.send(()).ok(); // Signal that the error has occurred
            result
        });

        // Wait for the error to occur
        let error_result = handle.await.unwrap();
        assert!(error_result.is_err());

        // Wait for our error signal
        error_rx.await.ok();

        // Check if parent policy token was cancelled within timeout
        // (`unwrap_or(false)` treats a dropped monitor as "not cancelled")
        let was_cancelled = cancel_rx.await.unwrap_or(false);
        monitor_handle.await.ok();

        // Based on hierarchical design: child errors should NOT affect parent
        // The child gets its own policy with a child token, and child cancellation
        // should not propagate up to the parent policy token
        assert!(
            !was_cancelled,
            "Parent policy token should not be cancelled by child errors"
        );
        assert!(
            !parent_policy_token.is_cancelled(),
            "Parent policy token should remain active"
        );
    }
3781
3782 #[rstest]
3783 #[tokio::test]
3784 async fn test_graceful_shutdown(
3785 semaphore_scheduler: Arc<SemaphoreScheduler>,
3786 log_policy: Arc<LogOnlyPolicy>,
3787 ) {
3788 // Test graceful shutdown with close()
3789 let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3790
3791 // Use broadcast channel to coordinate task completion
3792 let (tx, _) = tokio::sync::broadcast::channel(1);
3793 let mut handles = Vec::new();
3794
3795 // Spawn some tasks
3796 for i in 0..3 {
3797 let mut rx = tx.subscribe();
3798 let handle = tracker.spawn(async move {
3799 // Wait for signal instead of sleep
3800 rx.recv().await.ok();
3801 Ok(i)
3802 });
3803 handles.push(handle);
3804 }
3805
3806 // Signal all tasks to complete before closing
3807 tx.send(()).ok();
3808
3809 // Close tracker and wait for completion
3810 tracker.join().await;
3811
3812 // All tasks should complete successfully
3813 for handle in handles {
3814 let result = handle.await.unwrap().unwrap();
3815 assert!(result < 3);
3816 }
3817
3818 // Tracker should be closed
3819 assert!(tracker.is_closed());
3820 }
3821
3822 #[rstest]
3823 #[tokio::test]
3824 async fn test_semaphore_scheduler_permit_tracking(log_policy: Arc<LogOnlyPolicy>) {
3825 // Test that SemaphoreScheduler properly tracks permits
3826 let semaphore = Arc::new(Semaphore::new(3));
3827 let scheduler = Arc::new(SemaphoreScheduler::new(semaphore.clone()));
3828 let tracker = TaskTracker::new(scheduler.clone(), log_policy).unwrap();
3829
3830 // Initially all permits should be available
3831 assert_eq!(scheduler.available_permits(), 3);
3832
3833 // Use broadcast channel to coordinate task completion
3834 let (tx, _) = tokio::sync::broadcast::channel(1);
3835 let mut handles = Vec::new();
3836
3837 // Spawn 3 tasks that will hold permits
3838 for _ in 0..3 {
3839 let mut rx = tx.subscribe();
3840 let handle = tracker.spawn(async move {
3841 // Wait for signal to complete
3842 rx.recv().await.ok();
3843 Ok(())
3844 });
3845 handles.push(handle);
3846 }
3847
3848 // Give tasks time to acquire permits
3849 tokio::task::yield_now().await;
3850 tokio::task::yield_now().await;
3851
3852 // All permits should be taken
3853 assert_eq!(scheduler.available_permits(), 0);
3854
3855 // Signal all tasks to complete
3856 tx.send(()).ok();
3857
3858 // Wait for tasks to complete
3859 for handle in handles {
3860 handle.await.unwrap().unwrap();
3861 }
3862
3863 // All permits should be available again
3864 assert_eq!(scheduler.available_permits(), 3);
3865 }
3866
3867 #[rstest]
3868 #[tokio::test]
3869 async fn test_builder_pattern(log_policy: Arc<LogOnlyPolicy>) {
3870 // Test that TaskTracker builder works correctly
3871 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3872 let error_policy = log_policy;
3873
3874 let tracker = TaskTracker::builder()
3875 .scheduler(scheduler)
3876 .error_policy(error_policy)
3877 .build()
3878 .unwrap();
3879
3880 // Tracker should have a cancellation token
3881 let token = tracker.cancellation_token();
3882 assert!(!token.is_cancelled());
3883
3884 // Should be able to spawn tasks
3885 let handle = tracker.spawn(async { Ok(42) });
3886 let result = handle.await.unwrap().unwrap();
3887 assert_eq!(result, 42);
3888 }
3889
3890 #[rstest]
3891 #[tokio::test]
3892 async fn test_all_trackers_have_cancellation_tokens(log_policy: Arc<LogOnlyPolicy>) {
3893 // Test that all trackers (root and children) have cancellation tokens
3894 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3895 let root = TaskTracker::new(scheduler, log_policy).unwrap();
3896 let child = root.child_tracker().unwrap();
3897 let grandchild = child.child_tracker().unwrap();
3898
3899 // All should have cancellation tokens
3900 let root_token = root.cancellation_token();
3901 let child_token = child.cancellation_token();
3902 let grandchild_token = grandchild.cancellation_token();
3903
3904 assert!(!root_token.is_cancelled());
3905 assert!(!child_token.is_cancelled());
3906 assert!(!grandchild_token.is_cancelled());
3907
3908 // Child tokens should be different from parent
3909 // (We can't directly compare tokens, but we can test behavior)
3910 root_token.cancel();
3911
3912 // Give cancellation time to propagate
3913 tokio::time::sleep(Duration::from_millis(10)).await;
3914
3915 // Root should be cancelled
3916 assert!(root_token.is_cancelled());
3917 // Children should also be cancelled (because they are child tokens)
3918 assert!(child_token.is_cancelled());
3919 assert!(grandchild_token.is_cancelled());
3920 }
3921
3922 #[rstest]
3923 #[tokio::test]
3924 async fn test_spawn_cancellable_task(log_policy: Arc<LogOnlyPolicy>) {
3925 // Test cancellable task spawning with proper result handling
3926 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3927 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
3928
3929 // Test successful completion
3930 let (tx, rx) = tokio::sync::oneshot::channel();
3931 let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
3932 let handle = tracker.spawn_cancellable(move |_cancel_token| {
3933 let rx = rx.clone();
3934 async move {
3935 // Wait for signal instead of sleep
3936 if let Some(rx) = rx.lock().await.take() {
3937 rx.await.ok();
3938 }
3939 CancellableTaskResult::Ok(42)
3940 }
3941 });
3942
3943 // Signal task to complete
3944 tx.send(()).ok();
3945
3946 let result = handle.await.unwrap().unwrap();
3947 assert_eq!(result, 42);
3948 assert_eq!(tracker.metrics().success(), 1);
3949
3950 // Test cancellation handling
3951 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
3952 let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
3953 let handle = tracker.spawn_cancellable(move |cancel_token| {
3954 let rx = rx.clone();
3955 async move {
3956 tokio::select! {
3957 _ = async {
3958 if let Some(rx) = rx.lock().await.take() {
3959 rx.await.ok();
3960 }
3961 } => CancellableTaskResult::Ok("should not complete"),
3962 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
3963 }
3964 }
3965 });
3966
3967 // Cancel the tracker
3968 tracker.cancel();
3969
3970 let result = handle.await.unwrap();
3971 assert!(result.is_err());
3972 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
3973 }
3974
    #[rstest]
    #[tokio::test]
    async fn test_cancellable_task_metrics_tracking(log_policy: Arc<LogOnlyPolicy>) {
        // Test that properly cancelled tasks increment cancelled metrics, not failed metrics
        let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
        let tracker = TaskTracker::new(scheduler, log_policy).unwrap();

        // Baseline metrics
        assert_eq!(tracker.metrics().cancelled(), 0);
        assert_eq!(tracker.metrics().failed(), 0);
        assert_eq!(tracker.metrics().success(), 0);

        // Test 1: Task that executes and THEN gets cancelled during execution
        // start_tx signals "the task is running"; _continue_tx is held but never
        // fired, so the task can only finish via cancellation.
        let (start_tx, start_rx) = tokio::sync::oneshot::channel::<()>();
        let (_continue_tx, continue_rx) = tokio::sync::oneshot::channel::<()>();

        // Oneshot endpoints are single-use, so they are shared with the task
        // closure through Arc<Mutex<Option<...>>> and consumed via take().
        let start_tx_shared = Arc::new(tokio::sync::Mutex::new(Some(start_tx)));
        let continue_rx_shared = Arc::new(tokio::sync::Mutex::new(Some(continue_rx)));

        let start_tx_for_task = start_tx_shared.clone();
        let continue_rx_for_task = continue_rx_shared.clone();

        let handle = tracker.spawn_cancellable(move |cancel_token| {
            let start_tx = start_tx_for_task.clone();
            let continue_rx = continue_rx_for_task.clone();
            async move {
                // Signal that we've started executing
                if let Some(tx) = start_tx.lock().await.take() {
                    tx.send(()).ok();
                }

                // Wait for either continuation signal or cancellation
                tokio::select! {
                    _ = async {
                        if let Some(rx) = continue_rx.lock().await.take() {
                            rx.await.ok();
                        }
                    } => CancellableTaskResult::Ok("completed normally"),
                    _ = cancel_token.cancelled() => {
                        println!("Task detected cancellation and is returning Cancelled");
                        CancellableTaskResult::Cancelled
                    },
                }
            }
        });

        // Wait for task to start executing
        start_rx.await.ok();

        // Now cancel while the task is running
        println!("Cancelling tracker while task is executing...");
        tracker.cancel();

        // Wait for the task to complete
        let result = handle.await.unwrap();

        // Debug output
        println!("Task result: {:?}", result);
        println!(
            "Cancelled: {}, Failed: {}, Success: {}",
            tracker.metrics().cancelled(),
            tracker.metrics().failed(),
            tracker.metrics().success()
        );

        // The task should be properly cancelled and counted correctly
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), TaskError::Cancelled));

        // Verify proper metrics: should be counted as cancelled, not failed
        assert_eq!(
            tracker.metrics().cancelled(),
            1,
            "Properly cancelled task should increment cancelled count"
        );
        assert_eq!(
            tracker.metrics().failed(),
            0,
            "Properly cancelled task should NOT increment failed count"
        );
    }
4056
4057 #[rstest]
4058 #[tokio::test]
4059 async fn test_cancellable_vs_error_metrics_distinction(log_policy: Arc<LogOnlyPolicy>) {
4060 // Test that we properly distinguish between cancellation and actual errors
4061 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
4062 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4063
4064 // Test 1: Actual error should increment failed count
4065 let handle1 = tracker.spawn_cancellable(|_cancel_token| async move {
4066 CancellableTaskResult::<i32>::Err(anyhow::anyhow!("This is a real error"))
4067 });
4068
4069 let result1 = handle1.await.unwrap();
4070 assert!(result1.is_err());
4071 assert!(matches!(result1.unwrap_err(), TaskError::Failed(_)));
4072 assert_eq!(tracker.metrics().failed(), 1);
4073 assert_eq!(tracker.metrics().cancelled(), 0);
4074
4075 // Test 2: Cancellation should increment cancelled count
4076 let handle2 = tracker.spawn_cancellable(|_cancel_token| async move {
4077 CancellableTaskResult::<i32>::Cancelled
4078 });
4079
4080 let result2 = handle2.await.unwrap();
4081 assert!(result2.is_err());
4082 assert!(matches!(result2.unwrap_err(), TaskError::Cancelled));
4083 assert_eq!(tracker.metrics().failed(), 1); // Still 1 from before
4084 assert_eq!(tracker.metrics().cancelled(), 1); // Now 1 from cancellation
4085 }
4086
4087 #[rstest]
4088 #[tokio::test]
4089 async fn test_spawn_cancellable_error_handling(log_policy: Arc<LogOnlyPolicy>) {
4090 // Test error handling in cancellable tasks
4091 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
4092 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4093
4094 // Test error result
4095 let handle = tracker.spawn_cancellable(|_cancel_token| async move {
4096 CancellableTaskResult::<i32>::Err(anyhow::anyhow!("test error"))
4097 });
4098
4099 let result = handle.await.unwrap();
4100 assert!(result.is_err());
4101 assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
4102 assert_eq!(tracker.metrics().failed(), 1);
4103 }
4104
4105 #[rstest]
4106 #[tokio::test]
4107 async fn test_cancellation_before_execution(log_policy: Arc<LogOnlyPolicy>) {
4108 // Test that spawning on a cancelled tracker panics (new behavior)
4109 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
4110 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4111
4112 // Cancel the tracker first
4113 tracker.cancel();
4114
4115 // Give cancellation time to propagate to the inner tracker
4116 tokio::time::sleep(Duration::from_millis(5)).await;
4117
4118 // Now try to spawn a task - it should panic since tracker is closed
4119 let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
4120 tracker.spawn(async { Ok(42) })
4121 }));
4122
4123 // Should panic with our new API
4124 assert!(
4125 panic_result.is_err(),
4126 "spawn() should panic when tracker is closed"
4127 );
4128
4129 // Verify the panic message contains expected text
4130 if let Err(panic_payload) = panic_result {
4131 if let Some(panic_msg) = panic_payload.downcast_ref::<String>() {
4132 assert!(
4133 panic_msg.contains("TaskTracker must not be closed"),
4134 "Panic message should indicate tracker is closed: {}",
4135 panic_msg
4136 );
4137 } else if let Some(panic_msg) = panic_payload.downcast_ref::<&str>() {
4138 assert!(
4139 panic_msg.contains("TaskTracker must not be closed"),
4140 "Panic message should indicate tracker is closed: {}",
4141 panic_msg
4142 );
4143 }
4144 }
4145 }
4146
4147 #[rstest]
4148 #[tokio::test]
4149 async fn test_semaphore_scheduler_with_cancellation(log_policy: Arc<LogOnlyPolicy>) {
4150 // Test that SemaphoreScheduler respects cancellation tokens
4151 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
4152 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4153
4154 // Start a long-running task to occupy the semaphore
4155 let blocker_token = tracker.cancellation_token();
4156 let _blocker_handle = tracker.spawn(async move {
4157 // Wait for cancellation
4158 blocker_token.cancelled().await;
4159 Ok(())
4160 });
4161
4162 // Give the blocker time to acquire the permit
4163 tokio::task::yield_now().await;
4164
4165 // Use oneshot channel for the second task
4166 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
4167
4168 // Spawn another task that will wait for semaphore
4169 let handle = tracker.spawn(async {
4170 rx.await.ok();
4171 Ok(42)
4172 });
4173
4174 // Cancel the tracker while second task is waiting for permit
4175 tracker.cancel();
4176
4177 // The waiting task should be cancelled
4178 let result = handle.await.unwrap();
4179 assert!(result.is_err());
4180 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
4181 }
4182
4183 #[rstest]
4184 #[tokio::test]
4185 async fn test_child_tracker_cancellation_independence(
4186 semaphore_scheduler: Arc<SemaphoreScheduler>,
4187 log_policy: Arc<LogOnlyPolicy>,
4188 ) {
4189 // Test that child tracker cancellation doesn't affect parent
4190 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4191 let child = parent.child_tracker().unwrap();
4192
4193 // Cancel only the child
4194 child.cancel();
4195
4196 // Parent should still be operational
4197 let parent_token = parent.cancellation_token();
4198 assert!(!parent_token.is_cancelled());
4199
4200 // Parent can still spawn tasks
4201 let handle = parent.spawn(async { Ok(42) });
4202 let result = handle.await.unwrap().unwrap();
4203 assert_eq!(result, 42);
4204
4205 // Child should be cancelled
4206 let child_token = child.cancellation_token();
4207 assert!(child_token.is_cancelled());
4208 }
4209
4210 #[rstest]
4211 #[tokio::test]
4212 async fn test_parent_cancellation_propagates_to_children(
4213 semaphore_scheduler: Arc<SemaphoreScheduler>,
4214 log_policy: Arc<LogOnlyPolicy>,
4215 ) {
4216 // Test that parent cancellation propagates to all children
4217 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4218 let child1 = parent.child_tracker().unwrap();
4219 let child2 = parent.child_tracker().unwrap();
4220 let grandchild = child1.child_tracker().unwrap();
4221
4222 // Cancel the parent
4223 parent.cancel();
4224
4225 // Give cancellation time to propagate
4226 tokio::time::sleep(Duration::from_millis(10)).await;
4227
4228 // All should be cancelled
4229 assert!(parent.cancellation_token().is_cancelled());
4230 assert!(child1.cancellation_token().is_cancelled());
4231 assert!(child2.cancellation_token().is_cancelled());
4232 assert!(grandchild.cancellation_token().is_cancelled());
4233 }
4234
4235 #[rstest]
4236 #[tokio::test]
4237 async fn test_issued_counter_tracking(log_policy: Arc<LogOnlyPolicy>) {
4238 // Test that issued counter is incremented when tasks are spawned
4239 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2))));
4240 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4241
4242 // Initially no tasks issued
4243 assert_eq!(tracker.metrics().issued(), 0);
4244 assert_eq!(tracker.metrics().pending(), 0);
4245
4246 // Spawn some tasks
4247 let handle1 = tracker.spawn(async { Ok(1) });
4248 let handle2 = tracker.spawn(async { Ok(2) });
4249 let handle3 = tracker.spawn_cancellable(|_| async { CancellableTaskResult::Ok(3) });
4250
4251 // Issued counter should be incremented immediately
4252 assert_eq!(tracker.metrics().issued(), 3);
4253 assert_eq!(tracker.metrics().pending(), 3); // None completed yet
4254
4255 // Complete the tasks
4256 assert_eq!(handle1.await.unwrap().unwrap(), 1);
4257 assert_eq!(handle2.await.unwrap().unwrap(), 2);
4258 assert_eq!(handle3.await.unwrap().unwrap(), 3);
4259
4260 // Check final accounting
4261 assert_eq!(tracker.metrics().issued(), 3);
4262 assert_eq!(tracker.metrics().success(), 3);
4263 assert_eq!(tracker.metrics().total_completed(), 3);
4264 assert_eq!(tracker.metrics().pending(), 0); // All completed
4265
4266 // Test hierarchical accounting
4267 let child = tracker.child_tracker().unwrap();
4268 let child_handle = child.spawn(async { Ok(42) });
4269
4270 // Both parent and child should see the issued task
4271 assert_eq!(child.metrics().issued(), 1);
4272 assert_eq!(tracker.metrics().issued(), 4); // Parent sees all
4273
4274 child_handle.await.unwrap().unwrap();
4275
4276 // Final hierarchical check
4277 assert_eq!(child.metrics().pending(), 0);
4278 assert_eq!(tracker.metrics().pending(), 0);
4279 assert_eq!(tracker.metrics().success(), 4); // Parent sees all successes
4280 }
4281
4282 #[rstest]
4283 #[tokio::test]
4284 async fn test_child_tracker_builder(log_policy: Arc<LogOnlyPolicy>) {
4285 // Test that child tracker builder allows custom policies
4286 let parent_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4287 let parent = TaskTracker::new(parent_scheduler, log_policy).unwrap();
4288
4289 // Create child with custom error policy
4290 let child_error_policy = CancelOnError::new();
4291 let child = parent
4292 .child_tracker_builder()
4293 .error_policy(child_error_policy)
4294 .build()
4295 .unwrap();
4296
4297 // Test that child works
4298 let handle = child.spawn(async { Ok(42) });
4299 let result = handle.await.unwrap().unwrap();
4300 assert_eq!(result, 42);
4301
4302 // Child should have its own metrics
4303 assert_eq!(child.metrics().success(), 1);
4304 assert_eq!(parent.metrics().total_completed(), 1); // Parent sees aggregated
4305 }
4306
4307 #[rstest]
4308 #[tokio::test]
4309 async fn test_hierarchical_metrics_aggregation(log_policy: Arc<LogOnlyPolicy>) {
4310 // Test that child metrics aggregate up to parent
4311 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4312 let parent = TaskTracker::new(scheduler, log_policy.clone()).unwrap();
4313
4314 // Create child1 with default settings
4315 let child1 = parent.child_tracker().unwrap();
4316
4317 // Create child2 with custom error policy
4318 let child_error_policy = CancelOnError::new();
4319 let child2 = parent
4320 .child_tracker_builder()
4321 .error_policy(child_error_policy)
4322 .build()
4323 .unwrap();
4324
4325 // Test both custom schedulers and policies
4326 let another_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(3))));
4327 let another_error_policy = CancelOnError::new();
4328 let child3 = parent
4329 .child_tracker_builder()
4330 .scheduler(another_scheduler)
4331 .error_policy(another_error_policy)
4332 .build()
4333 .unwrap();
4334
4335 // Test that all children are properly registered
4336 assert_eq!(parent.child_count(), 3);
4337
4338 // Test that custom schedulers work
4339 let handle1 = child1.spawn(async { Ok(1) });
4340 let handle2 = child2.spawn(async { Ok(2) });
4341 let handle3 = child3.spawn(async { Ok(3) });
4342
4343 assert_eq!(handle1.await.unwrap().unwrap(), 1);
4344 assert_eq!(handle2.await.unwrap().unwrap(), 2);
4345 assert_eq!(handle3.await.unwrap().unwrap(), 3);
4346
4347 // Verify metrics still work
4348 assert_eq!(parent.metrics().success(), 3); // All child successes roll up
4349 assert_eq!(child1.metrics().success(), 1);
4350 assert_eq!(child2.metrics().success(), 1);
4351 assert_eq!(child3.metrics().success(), 1);
4352 }
4353
4354 #[rstest]
4355 #[tokio::test]
4356 async fn test_scheduler_queue_depth_calculation(log_policy: Arc<LogOnlyPolicy>) {
4357 // Test that we can calculate tasks queued in scheduler
4358 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2)))); // Only 2 concurrent tasks
4359 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4360
4361 // Initially no tasks
4362 assert_eq!(tracker.metrics().issued(), 0);
4363 assert_eq!(tracker.metrics().active(), 0);
4364 assert_eq!(tracker.metrics().queued(), 0);
4365 assert_eq!(tracker.metrics().pending(), 0);
4366
4367 // Use a channel to control when tasks complete
4368 let (complete_tx, _complete_rx) = tokio::sync::broadcast::channel(1);
4369
4370 // Spawn 2 tasks that will hold semaphore permits
4371 let handle1 = tracker.spawn({
4372 let mut rx = complete_tx.subscribe();
4373 async move {
4374 // Wait for completion signal
4375 rx.recv().await.ok();
4376 Ok(1)
4377 }
4378 });
4379 let handle2 = tracker.spawn({
4380 let mut rx = complete_tx.subscribe();
4381 async move {
4382 // Wait for completion signal
4383 rx.recv().await.ok();
4384 Ok(2)
4385 }
4386 });
4387
4388 // Give tasks time to start and acquire permits
4389 tokio::task::yield_now().await;
4390 tokio::task::yield_now().await;
4391
4392 // Should have 2 active tasks, 0 queued
4393 assert_eq!(tracker.metrics().issued(), 2);
4394 assert_eq!(tracker.metrics().active(), 2);
4395 assert_eq!(tracker.metrics().queued(), 0);
4396 assert_eq!(tracker.metrics().pending(), 2);
4397
4398 // Spawn a third task - should be queued since semaphore is full
4399 let handle3 = tracker.spawn(async move { Ok(3) });
4400
4401 // Give time for task to be queued
4402 tokio::task::yield_now().await;
4403
4404 // Should have 2 active, 1 queued
4405 assert_eq!(tracker.metrics().issued(), 3);
4406 assert_eq!(tracker.metrics().active(), 2);
4407 assert_eq!(
4408 tracker.metrics().queued(),
4409 tracker.metrics().pending() - tracker.metrics().active()
4410 );
4411 assert_eq!(tracker.metrics().pending(), 3);
4412
4413 // Complete all tasks by sending the signal
4414 complete_tx.send(()).ok();
4415
4416 let result1 = handle1.await.unwrap().unwrap();
4417 let result2 = handle2.await.unwrap().unwrap();
4418 let result3 = handle3.await.unwrap().unwrap();
4419
4420 assert_eq!(result1, 1);
4421 assert_eq!(result2, 2);
4422 assert_eq!(result3, 3);
4423
4424 // All tasks should be completed
4425 assert_eq!(tracker.metrics().success(), 3);
4426 assert_eq!(tracker.metrics().active(), 0);
4427 assert_eq!(tracker.metrics().queued(), 0);
4428 assert_eq!(tracker.metrics().pending(), 0);
4429 }
4430
4431 #[rstest]
4432 #[tokio::test]
4433 async fn test_hierarchical_metrics_failure_aggregation(
4434 semaphore_scheduler: Arc<SemaphoreScheduler>,
4435 log_policy: Arc<LogOnlyPolicy>,
4436 ) {
4437 // Test that failed task metrics aggregate up to parent
4438 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4439 let child = parent.child_tracker().unwrap();
4440
4441 // Run some successful and some failed tasks
4442 let success_handle = child.spawn(async { Ok(42) });
4443 let failure_handle = child.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
4444
4445 // Wait for tasks to complete
4446 let _success_result = success_handle.await.unwrap().unwrap();
4447 let _failure_result = failure_handle.await.unwrap().unwrap_err();
4448
4449 // Check child metrics
4450 assert_eq!(child.metrics().success(), 1, "Child should have 1 success");
4451 assert_eq!(child.metrics().failed(), 1, "Child should have 1 failure");
4452
4453 // Parent should see the aggregated metrics
4454 // Note: Due to hierarchical aggregation, these metrics propagate up
4455 }
4456
4457 #[rstest]
4458 #[tokio::test]
4459 async fn test_metrics_independence_between_tracker_instances(
4460 semaphore_scheduler: Arc<SemaphoreScheduler>,
4461 log_policy: Arc<LogOnlyPolicy>,
4462 ) {
4463 // Test that different tracker instances have independent metrics
4464 let tracker1 = TaskTracker::new(semaphore_scheduler.clone(), log_policy.clone()).unwrap();
4465 let tracker2 = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4466
4467 // Run tasks in both trackers
4468 let handle1 = tracker1.spawn(async { Ok(1) });
4469 let handle2 = tracker2.spawn(async { Ok(2) });
4470
4471 handle1.await.unwrap().unwrap();
4472 handle2.await.unwrap().unwrap();
4473
4474 // Each tracker should only see its own metrics
4475 assert_eq!(tracker1.metrics().success(), 1);
4476 assert_eq!(tracker2.metrics().success(), 1);
4477 assert_eq!(tracker1.metrics().total_completed(), 1);
4478 assert_eq!(tracker2.metrics().total_completed(), 1);
4479 }
4480
4481 #[rstest]
4482 #[tokio::test]
4483 async fn test_hierarchical_join_waits_for_all(log_policy: Arc<LogOnlyPolicy>) {
4484 // Test that parent.join() waits for child tasks too
4485 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4486 let parent = TaskTracker::new(scheduler, log_policy).unwrap();
4487 let child1 = parent.child_tracker().unwrap();
4488 let child2 = parent.child_tracker().unwrap();
4489 let grandchild = child1.child_tracker().unwrap();
4490
4491 // Verify parent tracks children
4492 assert_eq!(parent.child_count(), 2);
4493 assert_eq!(child1.child_count(), 1);
4494 assert_eq!(child2.child_count(), 0);
4495 assert_eq!(grandchild.child_count(), 0);
4496
4497 // Track completion order
4498 let completion_order = Arc::new(Mutex::new(Vec::new()));
4499
4500 // Spawn tasks with different durations
4501 let order_clone = completion_order.clone();
4502 let parent_handle = parent.spawn(async move {
4503 tokio::time::sleep(Duration::from_millis(50)).await;
4504 order_clone.lock().unwrap().push("parent");
4505 Ok(())
4506 });
4507
4508 let order_clone = completion_order.clone();
4509 let child1_handle = child1.spawn(async move {
4510 tokio::time::sleep(Duration::from_millis(100)).await;
4511 order_clone.lock().unwrap().push("child1");
4512 Ok(())
4513 });
4514
4515 let order_clone = completion_order.clone();
4516 let child2_handle = child2.spawn(async move {
4517 tokio::time::sleep(Duration::from_millis(75)).await;
4518 order_clone.lock().unwrap().push("child2");
4519 Ok(())
4520 });
4521
4522 let order_clone = completion_order.clone();
4523 let grandchild_handle = grandchild.spawn(async move {
4524 tokio::time::sleep(Duration::from_millis(125)).await;
4525 order_clone.lock().unwrap().push("grandchild");
4526 Ok(())
4527 });
4528
4529 // Test hierarchical join - should wait for ALL tasks in hierarchy
4530 println!("[TEST] About to call parent.join()");
4531 let start = std::time::Instant::now();
4532 parent.join().await; // This should wait for ALL tasks
4533 let elapsed = start.elapsed();
4534 println!("[TEST] parent.join() completed in {:?}", elapsed);
4535
4536 // Should have waited for the longest task (grandchild at 125ms)
4537 assert!(
4538 elapsed >= Duration::from_millis(120),
4539 "Hierarchical join should wait for longest task"
4540 );
4541
4542 // All tasks should be complete
4543 assert!(parent_handle.is_finished());
4544 assert!(child1_handle.is_finished());
4545 assert!(child2_handle.is_finished());
4546 assert!(grandchild_handle.is_finished());
4547
4548 // Verify all tasks completed
4549 let final_order = completion_order.lock().unwrap();
4550 assert_eq!(final_order.len(), 4);
4551 assert!(final_order.contains(&"parent"));
4552 assert!(final_order.contains(&"child1"));
4553 assert!(final_order.contains(&"child2"));
4554 assert!(final_order.contains(&"grandchild"));
4555 }
4556
4557 #[rstest]
4558 #[tokio::test]
4559 async fn test_hierarchical_join_waits_for_children(
4560 semaphore_scheduler: Arc<SemaphoreScheduler>,
4561 log_policy: Arc<LogOnlyPolicy>,
4562 ) {
4563 // Test that join() waits for child tasks (hierarchical behavior)
4564 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4565 let child = parent.child_tracker().unwrap();
4566
4567 // Spawn a quick parent task and slow child task
4568 let _parent_handle = parent.spawn(async {
4569 tokio::time::sleep(Duration::from_millis(20)).await;
4570 Ok(())
4571 });
4572
4573 let _child_handle = child.spawn(async {
4574 tokio::time::sleep(Duration::from_millis(100)).await;
4575 Ok(())
4576 });
4577
4578 // Hierarchical join should wait for both parent and child tasks
4579 let start = std::time::Instant::now();
4580 parent.join().await; // Should wait for both (hierarchical by default)
4581 let elapsed = start.elapsed();
4582
4583 // Should have waited for the longer child task (100ms)
4584 assert!(
4585 elapsed >= Duration::from_millis(90),
4586 "Hierarchical join should wait for all child tasks"
4587 );
4588 }
4589
4590 #[rstest]
4591 #[tokio::test]
4592 async fn test_hierarchical_join_operations(
4593 semaphore_scheduler: Arc<SemaphoreScheduler>,
4594 log_policy: Arc<LogOnlyPolicy>,
4595 ) {
4596 // Test that parent.join() closes and waits for child trackers too
4597 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4598 let child = parent.child_tracker().unwrap();
4599 let grandchild = child.child_tracker().unwrap();
4600
4601 // Verify trackers start as open
4602 assert!(!parent.is_closed());
4603 assert!(!child.is_closed());
4604 assert!(!grandchild.is_closed());
4605
4606 // Join parent (hierarchical by default - closes and waits for all)
4607 parent.join().await;
4608
4609 // All should be closed (check child trackers since parent was moved)
4610 assert!(child.is_closed());
4611 assert!(grandchild.is_closed());
4612 }
4613
4614 #[rstest]
4615 #[tokio::test]
4616 async fn test_unlimited_scheduler() {
4617 // Test that UnlimitedScheduler executes tasks immediately
4618 let scheduler = UnlimitedScheduler::new();
4619 let error_policy = LogOnlyPolicy::new();
4620 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
4621
4622 let (tx, rx) = tokio::sync::oneshot::channel();
4623 let handle = tracker.spawn(async {
4624 rx.await.ok();
4625 Ok(42)
4626 });
4627
4628 // Task should be ready to execute immediately (no concurrency limit)
4629 tx.send(()).ok();
4630 let result = handle.await.unwrap().unwrap();
4631 assert_eq!(result, 42);
4632
4633 assert_eq!(tracker.metrics().success(), 1);
4634 }
4635
    #[rstest]
    #[tokio::test]
    async fn test_threshold_cancel_policy(semaphore_scheduler: Arc<SemaphoreScheduler>) {
        // Test that ThresholdCancelPolicy now uses per-task failure counting:
        // the threshold applies to failures of a single task (via continuations),
        // not to failures accumulated across independent tasks.
        let error_policy = ThresholdCancelPolicy::with_threshold(2); // Cancel after 2 failures per task
        let tracker = TaskTracker::new(semaphore_scheduler, error_policy.clone()).unwrap();
        let cancel_token = tracker.cancellation_token().child_token();

        // With per-task context, individual task failures don't accumulate
        // Each task starts with failure_count = 0, so single failures won't trigger cancellation
        // NOTE(review): yield_now() is relied on to let the spawned task finish
        // before the assertions run — TODO confirm this is deterministic under
        // the current runtime configuration.
        let _handle1 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("First failure")) });
        tokio::task::yield_now().await;
        assert!(!cancel_token.is_cancelled());
        assert_eq!(error_policy.failure_count(), 1); // Global counter still increments

        // Second failure from different task - still won't trigger cancellation
        let _handle2 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("Second failure")) });
        tokio::task::yield_now().await;
        assert!(!cancel_token.is_cancelled()); // Per-task context prevents cancellation
        assert_eq!(error_policy.failure_count(), 2); // Global counter increments

        // For cancellation to occur, a single task would need to fail multiple times
        // through continuations (which would require a more complex test setup)
    }
4660
4661 #[tokio::test]
4662 async fn test_policy_constructors() {
4663 // Test that all constructors follow the new clean API patterns
4664 let _unlimited = UnlimitedScheduler::new();
4665 let _semaphore = SemaphoreScheduler::with_permits(5);
4666 let _log_only = LogOnlyPolicy::new();
4667 let _cancel_policy = CancelOnError::new();
4668 let _threshold_policy = ThresholdCancelPolicy::with_threshold(3);
4669 let _rate_policy = RateCancelPolicy::builder()
4670 .rate(0.5)
4671 .window_secs(60)
4672 .build();
4673
4674 // All constructors return Arc directly - no more ugly ::new_arc patterns
4675 // This test ensures the clean API reduces boilerplate
4676 }
4677
4678 #[rstest]
4679 #[tokio::test]
4680 async fn test_child_creation_fails_after_join(
4681 semaphore_scheduler: Arc<SemaphoreScheduler>,
4682 log_policy: Arc<LogOnlyPolicy>,
4683 ) {
4684 // Test that child tracker creation fails from closed parent
4685 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4686
4687 // Initially, creating a child should work
4688 let _child = parent.child_tracker().unwrap();
4689
4690 // Close the parent tracker
4691 let parent_clone = parent.clone();
4692 parent.join().await;
4693 assert!(parent_clone.is_closed());
4694
4695 // Now, trying to create a child should fail
4696 let result = parent_clone.child_tracker();
4697 assert!(result.is_err());
4698 assert!(
4699 result
4700 .err()
4701 .unwrap()
4702 .to_string()
4703 .contains("closed parent tracker")
4704 );
4705 }
4706
4707 #[rstest]
4708 #[tokio::test]
4709 async fn test_child_builder_fails_after_join(
4710 semaphore_scheduler: Arc<SemaphoreScheduler>,
4711 log_policy: Arc<LogOnlyPolicy>,
4712 ) {
4713 // Test that child tracker builder creation fails from closed parent
4714 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4715
4716 // Initially, creating a child with builder should work
4717 let _child = parent.child_tracker_builder().build().unwrap();
4718
4719 // Close the parent tracker
4720 let parent_clone = parent.clone();
4721 parent.join().await;
4722 assert!(parent_clone.is_closed());
4723
4724 // Now, trying to create a child with builder should fail
4725 let result = parent_clone.child_tracker_builder().build();
4726 assert!(result.is_err());
4727 assert!(
4728 result
4729 .err()
4730 .unwrap()
4731 .to_string()
4732 .contains("closed parent tracker")
4733 );
4734 }
4735
4736 #[rstest]
4737 #[tokio::test]
4738 async fn test_child_creation_succeeds_before_join(
4739 semaphore_scheduler: Arc<SemaphoreScheduler>,
4740 log_policy: Arc<LogOnlyPolicy>,
4741 ) {
4742 // Test that child creation works normally before parent is joined
4743 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4744
4745 // Should be able to create multiple children before closing
4746 let child1 = parent.child_tracker().unwrap();
4747 let child2 = parent.child_tracker_builder().build().unwrap();
4748
4749 // Verify children can spawn tasks
4750 let handle1 = child1.spawn(async { Ok(42) });
4751 let handle2 = child2.spawn(async { Ok(24) });
4752
4753 let result1 = handle1.await.unwrap().unwrap();
4754 let result2 = handle2.await.unwrap().unwrap();
4755
4756 assert_eq!(result1, 42);
4757 assert_eq!(result2, 24);
4758 assert_eq!(parent.metrics().success(), 2); // Parent sees all successes
4759 }
4760
    #[rstest]
    #[tokio::test]
    async fn test_custom_error_response_with_cancellation_token(
        semaphore_scheduler: Arc<SemaphoreScheduler>,
    ) {
        // Test ErrorResponse::Custom behavior with TriggerCancellationTokenOnError:
        // a failing task should cause the policy to fire an externally supplied
        // cancellation token, which in turn cancels the tracker hierarchy.

        // Create a custom cancellation token
        let custom_cancel_token = CancellationToken::new();

        // Create the policy that will trigger our custom token
        let error_policy = TriggerCancellationTokenOnError::new(custom_cancel_token.clone());

        // Create tracker using builder with the custom policy.
        // The same token is also installed as the tracker's cancel token, so
        // triggering it cancels the tracker and its children.
        let tracker = TaskTracker::builder()
            .scheduler(semaphore_scheduler)
            .error_policy(error_policy)
            .cancel_token(custom_cancel_token.clone())
            .build()
            .unwrap();

        let child = tracker.child_tracker().unwrap();

        // Initially, the custom token should not be cancelled
        assert!(!custom_cancel_token.is_cancelled());

        // Spawn a task that will fail
        let handle = child.spawn(async {
            Err::<(), _>(anyhow::anyhow!("Test error to trigger custom response"))
        });

        // Wait for the task to complete (it will fail)
        let result = handle.await.unwrap();
        assert!(result.is_err());

        // Await a timeout/deadline or the cancellation token to be cancelled
        // The expectation is that the task will fail, and the cancellation token will be triggered
        // Hitting the deadline is a failure
        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Task should have failed, but hit the deadline");
            }
            _ = custom_cancel_token.cancelled() => {
                // Task should have failed, and the cancellation token should be triggered
            }
        }

        // The custom cancellation token should now be triggered by our policy
        assert!(
            custom_cancel_token.is_cancelled(),
            "Custom cancellation token should be triggered by ErrorResponse::Custom"
        );

        // Cancellation propagates to both the tracker and its child.
        assert!(tracker.cancellation_token().is_cancelled());
        assert!(child.cancellation_token().is_cancelled());

        // Verify the error was counted
        assert_eq!(tracker.metrics().failed(), 1);
    }
4820
4821 #[test]
4822 fn test_action_result_variants() {
4823 // Test that ActionResult variants can be created and pattern matched
4824
4825 // Test Fail variant
4826 let fail_result = ActionResult::Fail;
4827 match fail_result {
4828 ActionResult::Fail => {} // Expected
4829 _ => panic!("Expected Fail variant"),
4830 }
4831
4832 // Test Shutdown variant
4833 let shutdown_result = ActionResult::Shutdown;
4834 match shutdown_result {
4835 ActionResult::Shutdown => {} // Expected
4836 _ => panic!("Expected Shutdown variant"),
4837 }
4838
4839 // Test Continue variant with Continuation
4840 #[derive(Debug)]
4841 struct TestRestartable;
4842
4843 #[async_trait]
4844 impl Continuation for TestRestartable {
4845 async fn execute(
4846 &self,
4847 _cancel_token: CancellationToken,
4848 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4849 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4850 }
4851 }
4852
4853 let test_restartable = Arc::new(TestRestartable);
4854 let continue_result = ActionResult::Continue {
4855 continuation: test_restartable,
4856 };
4857
4858 match continue_result {
4859 ActionResult::Continue { continuation } => {
4860 // Verify we have a valid Continuation
4861 assert!(format!("{:?}", continuation).contains("TestRestartable"));
4862 }
4863 _ => panic!("Expected Continue variant"),
4864 }
4865 }
4866
4867 #[test]
4868 fn test_continuation_error_creation() {
4869 // Test RestartableError creation and conversion to anyhow::Error
4870
4871 // Create a dummy restartable task for testing
4872 #[derive(Debug)]
4873 struct DummyRestartable;
4874
4875 #[async_trait]
4876 impl Continuation for DummyRestartable {
4877 async fn execute(
4878 &self,
4879 _cancel_token: CancellationToken,
4880 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4881 TaskExecutionResult::Success(Box::new("restarted_result".to_string()))
4882 }
4883 }
4884
4885 let dummy_restartable = Arc::new(DummyRestartable);
4886 let source_error = anyhow::anyhow!("Original task failed");
4887
4888 // Test FailedWithContinuation::new
4889 let continuation_error = FailedWithContinuation::new(source_error, dummy_restartable);
4890
4891 // Verify the error displays correctly
4892 let error_string = format!("{}", continuation_error);
4893 assert!(error_string.contains("Task failed with continuation"));
4894 assert!(error_string.contains("Original task failed"));
4895
4896 // Test conversion to anyhow::Error
4897 let anyhow_error = anyhow::Error::new(continuation_error);
4898 assert!(
4899 anyhow_error
4900 .to_string()
4901 .contains("Task failed with continuation")
4902 );
4903 }
4904
4905 #[test]
4906 fn test_continuation_error_ext_trait() {
4907 // Test the RestartableErrorExt trait methods
4908
4909 // Test with regular anyhow::Error (not restartable)
4910 let regular_error = anyhow::anyhow!("Regular error");
4911 assert!(!regular_error.has_continuation());
4912 let extracted = regular_error.extract_continuation();
4913 assert!(extracted.is_none());
4914
4915 // Test with RestartableError
4916 #[derive(Debug)]
4917 struct TestRestartable;
4918
4919 #[async_trait]
4920 impl Continuation for TestRestartable {
4921 async fn execute(
4922 &self,
4923 _cancel_token: CancellationToken,
4924 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4925 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4926 }
4927 }
4928
4929 let test_restartable = Arc::new(TestRestartable);
4930 let source_error = anyhow::anyhow!("Source error");
4931 let continuation_error = FailedWithContinuation::new(source_error, test_restartable);
4932
4933 let anyhow_error = anyhow::Error::new(continuation_error);
4934 assert!(anyhow_error.has_continuation());
4935
4936 // Test extraction of restartable task
4937 let extracted = anyhow_error.extract_continuation();
4938 assert!(extracted.is_some());
4939 }
4940
4941 #[test]
4942 fn test_continuation_error_into_anyhow_helper() {
4943 // Test the convenience method for creating restartable errors
4944 // Note: This test uses a mock TaskExecutor since we don't have real ones yet
4945
4946 // For now, we'll test the type erasure concept with a simple type
4947 struct MockExecutor;
4948
4949 let _source_error = anyhow::anyhow!("Mock task failed");
4950
4951 // We can't test FailedWithContinuation::into_anyhow yet because it requires
4952 // a real TaskExecutor<T>. This will be tested in Phase 3.
4953 // For now, just verify the concept works with manual construction.
4954
4955 #[derive(Debug)]
4956 struct MockRestartable;
4957
4958 #[async_trait]
4959 impl Continuation for MockRestartable {
4960 async fn execute(
4961 &self,
4962 _cancel_token: CancellationToken,
4963 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4964 TaskExecutionResult::Success(Box::new("mock_result".to_string()))
4965 }
4966 }
4967
4968 let mock_restartable = Arc::new(MockRestartable);
4969 let continuation_error =
4970 FailedWithContinuation::new(anyhow::anyhow!("Mock task failed"), mock_restartable);
4971
4972 let anyhow_error = anyhow::Error::new(continuation_error);
4973 assert!(anyhow_error.has_continuation());
4974 }
4975
4976 #[test]
4977 fn test_continuation_error_with_task_executor() {
4978 // Test RestartableError creation with TaskExecutor
4979
4980 #[derive(Debug)]
4981 struct TestRestartableTask;
4982
4983 #[async_trait]
4984 impl Continuation for TestRestartableTask {
4985 async fn execute(
4986 &self,
4987 _cancel_token: CancellationToken,
4988 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4989 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4990 }
4991 }
4992
4993 let restartable_task = Arc::new(TestRestartableTask);
4994 let source_error = anyhow::anyhow!("Task failed");
4995
4996 // Test FailedWithContinuation::new with Restartable
4997 let continuation_error = FailedWithContinuation::new(source_error, restartable_task);
4998
4999 // Verify the error displays correctly
5000 let error_string = format!("{}", continuation_error);
5001 assert!(error_string.contains("Task failed with continuation"));
5002 assert!(error_string.contains("Task failed"));
5003
5004 // Test conversion to anyhow::Error
5005 let anyhow_error = anyhow::Error::new(continuation_error);
5006 assert!(anyhow_error.has_continuation());
5007
5008 // Test extraction (should work now with Restartable trait)
5009 let extracted = anyhow_error.extract_continuation();
5010 assert!(extracted.is_some()); // Should successfully extract the Restartable
5011 }
5012
5013 #[test]
5014 fn test_continuation_error_into_anyhow_convenience() {
5015 // Test the convenience method for creating restartable errors
5016
5017 #[derive(Debug)]
5018 struct ConvenienceRestartable;
5019
5020 #[async_trait]
5021 impl Continuation for ConvenienceRestartable {
5022 async fn execute(
5023 &self,
5024 _cancel_token: CancellationToken,
5025 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5026 TaskExecutionResult::Success(Box::new(42u32))
5027 }
5028 }
5029
5030 let restartable_task = Arc::new(ConvenienceRestartable);
5031 let source_error = anyhow::anyhow!("Computation failed");
5032
5033 // Test FailedWithContinuation::into_anyhow convenience method
5034 let anyhow_error = FailedWithContinuation::into_anyhow(source_error, restartable_task);
5035
5036 assert!(anyhow_error.has_continuation());
5037 assert!(
5038 anyhow_error
5039 .to_string()
5040 .contains("Task failed with continuation")
5041 );
5042 assert!(anyhow_error.to_string().contains("Computation failed"));
5043 }
5044
5045 #[test]
5046 fn test_handle_task_error_with_continuation_error() {
5047 // Test that handle_task_error properly detects RestartableError
5048
5049 // Create a mock Restartable task
5050 #[derive(Debug)]
5051 struct MockRestartableTask;
5052
5053 #[async_trait]
5054 impl Continuation for MockRestartableTask {
5055 async fn execute(
5056 &self,
5057 _cancel_token: CancellationToken,
5058 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5059 TaskExecutionResult::Success(Box::new("retry_result".to_string()))
5060 }
5061 }
5062
5063 let restartable_task = Arc::new(MockRestartableTask);
5064
5065 // Create RestartableError
5066 let source_error = anyhow::anyhow!("Task failed, but can retry");
5067 let continuation_error = FailedWithContinuation::new(source_error, restartable_task);
5068 let anyhow_error = anyhow::Error::new(continuation_error);
5069
5070 // Verify it's detected as restartable
5071 assert!(anyhow_error.has_continuation());
5072
5073 // Verify we can downcast to FailedWithContinuation
5074 let continuation_ref = anyhow_error.downcast_ref::<FailedWithContinuation>();
5075 assert!(continuation_ref.is_some());
5076
5077 // Verify the continuation task is present
5078 let continuation = continuation_ref.unwrap();
5079 // Note: We can verify the Arc is valid by checking that Arc::strong_count > 0
5080 assert!(Arc::strong_count(&continuation.continuation) > 0);
5081 }
5082
5083 #[test]
5084 fn test_handle_task_error_with_regular_error() {
5085 // Test that handle_task_error properly handles regular errors
5086
5087 let regular_error = anyhow::anyhow!("Regular task failure");
5088
5089 // Verify it's not detected as restartable
5090 assert!(!regular_error.has_continuation());
5091
5092 // Verify we cannot downcast to FailedWithContinuation
5093 let continuation_ref = regular_error.downcast_ref::<FailedWithContinuation>();
5094 assert!(continuation_ref.is_none());
5095 }
5096
5097 // ========================================
5098 // END-TO-END ACTIONRESULT TESTS
5099 // ========================================
5100
    #[rstest]
    #[tokio::test]
    async fn test_end_to_end_continuation_execution(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        log_policy: Arc<LogOnlyPolicy>,
    ) {
        // Test that a task returning FailedWithContinuation actually executes the continuation
        // and that the task's final outcome is the continuation's result.
        let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();

        // Shared state to track execution.
        // tokio::sync::Mutex is used because the lock is held across .await points.
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
        let log_clone = execution_log.clone();

        // Create a continuation that logs its execution
        #[derive(Debug)]
        struct LoggingContinuation {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
            result: String,
        }

        #[async_trait]
        impl Continuation for LoggingContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                self.log
                    .lock()
                    .await
                    .push("continuation_executed".to_string());
                TaskExecutionResult::Success(Box::new(self.result.clone()))
            }
        }

        let continuation = Arc::new(LoggingContinuation {
            log: log_clone,
            result: "continuation_result".to_string(),
        });

        // Task that fails with continuation: it logs, then returns an error
        // carrying the continuation for the pipeline to execute next.
        let log_for_task = execution_log.clone();
        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("original_task_executed".to_string());

            // Return FailedWithContinuation
            let error = anyhow::anyhow!("Original task failed");
            let result: Result<String, anyhow::Error> =
                Err(FailedWithContinuation::into_anyhow(error, continuation));
            result
        });

        // Execute and verify the continuation was called
        let result = handle.await.expect("Task should complete");
        assert!(result.is_ok(), "Continuation should succeed");

        // Verify execution order: original task first, then its continuation.
        let log = execution_log.lock().await;
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "original_task_executed");
        assert_eq!(log[1], "continuation_executed");

        // Verify metrics - should show 1 success (from continuation);
        // the original task's failure is superseded by the continuation.
        assert_eq!(tracker.metrics().success(), 1);
        assert_eq!(tracker.metrics().failed(), 0); // Continuation succeeded
        assert_eq!(tracker.metrics().cancelled(), 0);
    }
5170
    #[rstest]
    #[tokio::test]
    async fn test_end_to_end_multiple_continuations(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        log_policy: Arc<LogOnlyPolicy>,
    ) {
        // Test multiple continuation attempts: each failed continuation hands
        // back another continuation, forming a chain that eventually succeeds.
        let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();

        // Shared log of attempts plus a cross-attempt counter.
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
        let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        // Continuation that fails twice, then succeeds
        #[derive(Debug)]
        struct RetryingContinuation {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
            attempt_count: Arc<std::sync::atomic::AtomicU32>,
        }

        #[async_trait]
        impl Continuation for RetryingContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                // fetch_add returns the previous value; +1 makes this the
                // 1-based attempt number.
                let attempt = self
                    .attempt_count
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                    + 1;
                self.log
                    .lock()
                    .await
                    .push(format!("continuation_attempt_{}", attempt));

                if attempt < 3 {
                    // Fail with another continuation that shares the same
                    // log and counter, extending the retry chain.
                    let next_continuation = Arc::new(RetryingContinuation {
                        log: self.log.clone(),
                        attempt_count: self.attempt_count.clone(),
                    });
                    let error = anyhow::anyhow!("Continuation attempt {} failed", attempt);
                    TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
                        error,
                        next_continuation,
                    ))
                } else {
                    // Succeed on third attempt
                    TaskExecutionResult::Success(Box::new(format!(
                        "success_on_attempt_{}",
                        attempt
                    )))
                }
            }
        }

        let initial_continuation = Arc::new(RetryingContinuation {
            log: execution_log.clone(),
            attempt_count: attempt_count.clone(),
        });

        // Task that immediately fails with continuation
        let handle = tracker.spawn(async move {
            let error = anyhow::anyhow!("Original task failed");
            let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
                error,
                initial_continuation,
            ));
            result
        });

        // Execute and verify multiple continuations
        let result = handle.await.expect("Task should complete");
        assert!(result.is_ok(), "Final continuation should succeed");

        // Verify all attempts were made, in order.
        let log = execution_log.lock().await;
        assert_eq!(log.len(), 3);
        assert_eq!(log[0], "continuation_attempt_1");
        assert_eq!(log[1], "continuation_attempt_2");
        assert_eq!(log[2], "continuation_attempt_3");

        // Verify final attempt count
        assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);

        // Verify metrics - should show 1 success (final continuation);
        // intermediate continuation failures are not counted as task failures.
        assert_eq!(tracker.metrics().success(), 1);
        assert_eq!(tracker.metrics().failed(), 0);
    }
5259
    #[rstest]
    #[tokio::test]
    async fn test_end_to_end_continuation_failure(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        log_policy: Arc<LogOnlyPolicy>,
    ) {
        // Test continuation that ultimately fails without providing another continuation:
        // the chain terminates and the task's final outcome is a failure.
        let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();

        // Shared log to verify execution order across task and continuation.
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
        let log_clone = execution_log.clone();

        // Continuation that fails without providing another continuation
        #[derive(Debug)]
        struct FailingContinuation {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
        }

        #[async_trait]
        impl Continuation for FailingContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                self.log
                    .lock()
                    .await
                    .push("continuation_failed".to_string());
                // Plain error (no FailedWithContinuation) — ends the chain.
                TaskExecutionResult::Error(anyhow::anyhow!("Continuation failed permanently"))
            }
        }

        let continuation = Arc::new(FailingContinuation { log: log_clone });

        // Task that fails with continuation
        let log_for_task = execution_log.clone();
        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("original_task_executed".to_string());

            let error = anyhow::anyhow!("Original task failed");
            let result: Result<String, anyhow::Error> =
                Err(FailedWithContinuation::into_anyhow(error, continuation));
            result
        });

        // Execute and verify the continuation failed
        let result = handle.await.expect("Task should complete");
        assert!(result.is_err(), "Continuation should fail");

        // Verify execution order: original task first, then the failing continuation.
        let log = execution_log.lock().await;
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "original_task_executed");
        assert_eq!(log[1], "continuation_failed");

        // Verify metrics - should show 1 failure (from continuation)
        assert_eq!(tracker.metrics().success(), 0);
        assert_eq!(tracker.metrics().failed(), 1);
        assert_eq!(tracker.metrics().cancelled(), 0);
    }
5323
5324 #[rstest]
5325 #[tokio::test]
5326 async fn test_end_to_end_all_action_result_variants(
5327 unlimited_scheduler: Arc<UnlimitedScheduler>,
5328 ) {
5329 // Comprehensive test of Fail, Shutdown, and Continue paths
5330
5331 // Test 1: ActionResult::Fail (via LogOnlyPolicy)
5332 {
5333 let tracker =
5334 TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
5335 let handle = tracker.spawn(async {
5336 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
5337 result
5338 });
5339 let result = handle.await.expect("Task should complete");
5340 assert!(result.is_err(), "LogOnly should let error through");
5341 assert_eq!(tracker.metrics().failed(), 1);
5342 }
5343
5344 // Test 2: ActionResult::Shutdown (via CancelOnError)
5345 {
5346 let tracker =
5347 TaskTracker::new(unlimited_scheduler.clone(), CancelOnError::new()).unwrap();
5348 let handle = tracker.spawn(async {
5349 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
5350 result
5351 });
5352 let result = handle.await.expect("Task should complete");
5353 assert!(result.is_err(), "CancelOnError should fail task");
5354 assert!(
5355 tracker.cancellation_token().is_cancelled(),
5356 "Should cancel tracker"
5357 );
5358 assert_eq!(tracker.metrics().failed(), 1);
5359 }
5360
5361 // Test 3: ActionResult::Continue (via FailedWithContinuation)
5362 {
5363 let tracker =
5364 TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
5365
5366 #[derive(Debug)]
5367 struct TestContinuation;
5368
5369 #[async_trait]
5370 impl Continuation for TestContinuation {
5371 async fn execute(
5372 &self,
5373 _cancel_token: CancellationToken,
5374 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5375 TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
5376 }
5377 }
5378
5379 let continuation = Arc::new(TestContinuation);
5380 let handle = tracker.spawn(async move {
5381 let error = anyhow::anyhow!("Original failure");
5382 let result: Result<String, anyhow::Error> =
5383 Err(FailedWithContinuation::into_anyhow(error, continuation));
5384 result
5385 });
5386
5387 let result = handle.await.expect("Task should complete");
5388 assert!(result.is_ok(), "Continuation should succeed");
5389 assert_eq!(tracker.metrics().success(), 1);
5390 assert_eq!(tracker.metrics().failed(), 0);
5391 }
5392 }
5393
5394 // ========================================
5395 // LOOP BEHAVIOR AND POLICY INTERACTION TESTS
5396 // ========================================
5397 //
5398 // These tests demonstrate the current ActionResult system and identify
5399 // areas for future improvement:
5400 //
5401 // ✅ WHAT WORKS:
    // - All ActionResult variants (Fail, Shutdown, Continue) are tested
5403 // - Task-driven continuations work correctly
5404 // - Policy-driven continuations work correctly
5405 // - Mixed continuation sources work correctly
5406 // - Loop behavior with resource management works correctly
5407 //
5408 // 🔄 CURRENT LIMITATIONS:
5409 // - ThresholdCancelPolicy tracks failures GLOBALLY, not per-task
5410 // - OnErrorPolicy doesn't receive attempt_count parameter
5411 // - No per-task context for stateful retry policies
5412 //
5413 // 🚀 FUTURE IMPROVEMENTS IDENTIFIED:
5414 // - Add OnErrorContext associated type for per-task state
5415 // - Pass attempt_count to OnErrorPolicy::on_error
5416 // - Enable per-task failure tracking, backoff timers, etc.
5417 //
5418 // The tests below demonstrate both current capabilities and limitations.
5419
    /// Test retry loop behavior with different policies and continuation counts
    ///
    /// This test verifies that:
    /// 1. Tasks can provide multiple continuations in sequence
    /// 2. Different error policies can limit the number of continuation attempts
    /// 3. The retry loop correctly handles policy decisions about when to stop
    ///
    /// Key insight: Policies are only consulted for regular errors, not FailedWithContinuation.
    /// So we need continuations that eventually fail with regular errors to test policy limits.
    ///
    /// DESIGN LIMITATION: Current ThresholdCancelPolicy tracks failures GLOBALLY across all tasks,
    /// not per-task. This test demonstrates the current behavior but isn't ideal for retry loop testing.
    ///
    /// FUTURE IMPROVEMENT: Add OnErrorContext associated type to OnErrorPolicy:
    /// ```rust
    /// trait OnErrorPolicy {
    ///     type Context: Default + Send + Sync;
    ///     fn on_error(&self, error: &anyhow::Error, task_id: TaskId,
    ///                 attempt_count: u32, context: &mut Self::Context) -> ErrorResponse;
    /// }
    /// ```
    /// This would enable per-task failure tracking, backoff timers, etc.
    ///
    /// NOTE: Uses fresh policy instance for each test case to avoid global state interference.
    ///
    /// NOTE(review): later tests in this module (e.g. test_simple_threshold_policy_behavior)
    /// assert per-task failure counting, so the "GLOBALLY" description above may be stale —
    /// confirm against the current ThresholdCancelPolicy implementation.
    #[rstest]
    #[case(
        1,
        false,
        "Global policy with max_failures=1 should stop after first regular error"
    )]
    #[case(
        2,
        false, // Actually fails - ActionResult::Fail accepts the error and fails the task
        "Global policy with max_failures=2 allows error but ActionResult::Fail still fails the task"
    )]
    #[tokio::test]
    async fn test_continuation_loop_with_global_threshold_policy(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        #[case] max_failures: usize,
        #[case] should_succeed: bool,
        #[case] description: &str,
    ) {
        // Task that provides continuations, but continuations fail with regular errors
        // so the policy gets consulted and can limit retries

        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
        let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));

        // Create a continuation that fails with regular errors (not FailedWithContinuation)
        // This allows the policy to be consulted and potentially stop the retries
        #[derive(Debug)]
        struct PolicyTestContinuation {
            // Shared execution log for asserting ordering at the end of the test
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
            // Shared counter; also read by the test body after the task completes
            attempt_counter: Arc<std::sync::atomic::AtomicU32>,
            max_attempts_before_success: u32,
        }

        #[async_trait]
        impl Continuation for PolicyTestContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                // fetch_add returns the pre-increment value, so +1 yields a 1-based attempt number
                let attempt = self
                    .attempt_counter
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                    + 1;
                self.log
                    .lock()
                    .await
                    .push(format!("continuation_attempt_{}", attempt));

                if attempt < self.max_attempts_before_success {
                    // Fail with regular error - this will be seen by the policy
                    TaskExecutionResult::Error(anyhow::anyhow!(
                        "Continuation attempt {} failed (regular error)",
                        attempt
                    ))
                } else {
                    // Succeed after enough attempts
                    TaskExecutionResult::Success(Box::new(format!(
                        "success_on_attempt_{}",
                        attempt
                    )))
                }
            }
        }

        // Create fresh policy instance for each test case to avoid global state interference
        let policy = ThresholdCancelPolicy::with_threshold(max_failures);
        let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();

        // Original task that fails with continuation
        let log_for_task = execution_log.clone();
        // Set max_attempts_before_success so that:
        // - For max_failures=1: Continuation fails 1 time (attempt 1), policy cancels after 1 failure
        // - For max_failures=2: Continuation fails 1 time (attempt 1), succeeds on attempt 2
        let continuation = Arc::new(PolicyTestContinuation {
            log: execution_log.clone(),
            attempt_counter: attempt_counter.clone(),
            max_attempts_before_success: 2, // Always fail on attempt 1, succeed on attempt 2
        });

        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("original_task_executed".to_string());
            let error = anyhow::anyhow!("Original task failed");
            let result: Result<String, anyhow::Error> =
                Err(FailedWithContinuation::into_anyhow(error, continuation));
            result
        });

        // Execute and check result based on policy
        let result = handle.await.expect("Task should complete");

        // Debug: Print actual results
        // NOTE: `should_succeed` is informational only here — both #[case]s pass `false`
        // and the assertions below expect failure unconditionally.
        let log = execution_log.lock().await;
        let final_attempt_count = attempt_counter.load(std::sync::atomic::Ordering::Relaxed);
        println!(
            "Test case: max_failures={}, should_succeed={}",
            max_failures, should_succeed
        );
        println!("Result: {:?}", result.is_ok());
        println!("Log entries: {:?}", log);
        println!("Attempt count: {}", final_attempt_count);
        println!(
            "Metrics: success={}, failed={}",
            tracker.metrics().success(),
            tracker.metrics().failed()
        );
        drop(log); // Release the lock

        // Both test cases should fail because ActionResult::Fail accepts the error and fails the task
        assert!(result.is_err(), "{}: Task should fail", description);
        assert_eq!(
            tracker.metrics().success(),
            0,
            "{}: Should have 0 successes",
            description
        );
        assert_eq!(
            tracker.metrics().failed(),
            1,
            "{}: Should have 1 failure",
            description
        );

        // Should have stopped after 1 continuation attempt because ActionResult::Fail fails the task
        let log = execution_log.lock().await;
        assert_eq!(
            log.len(),
            2,
            "{}: Should have 2 log entries (original + 1 continuation attempt)",
            description
        );
        assert_eq!(log[0], "original_task_executed");
        assert_eq!(log[1], "continuation_attempt_1");

        assert_eq!(
            attempt_counter.load(std::sync::atomic::Ordering::Relaxed),
            1,
            "{}: Should have made 1 continuation attempt",
            description
        );

        // The key difference is whether the tracker gets cancelled
        if max_failures == 1 {
            assert!(
                tracker.cancellation_token().is_cancelled(),
                "Tracker should be cancelled with max_failures=1"
            );
        } else {
            assert!(
                !tracker.cancellation_token().is_cancelled(),
                "Tracker should NOT be cancelled with max_failures=2 (policy allows the error)"
            );
        }
    }
5600
5601 /// Simple test to understand ThresholdCancelPolicy behavior with per-task context
5602 #[rstest]
5603 #[tokio::test]
5604 async fn test_simple_threshold_policy_behavior(unlimited_scheduler: Arc<UnlimitedScheduler>) {
5605 // Test with max_failures=2 - now uses per-task failure counting
5606 let policy = ThresholdCancelPolicy::with_threshold(2);
5607 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
5608
5609 // Task 1: Should fail but not trigger cancellation (per-task failure count = 1)
5610 let handle1 = tracker.spawn(async {
5611 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("First failure"));
5612 result
5613 });
5614 let result1 = handle1.await.expect("Task should complete");
5615 assert!(result1.is_err(), "First task should fail");
5616 assert!(
5617 !tracker.cancellation_token().is_cancelled(),
5618 "Should not be cancelled after 1 failure"
5619 );
5620
5621 // Task 2: Should fail but not trigger cancellation (different task, per-task failure count = 1)
5622 let handle2 = tracker.spawn(async {
5623 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Second failure"));
5624 result
5625 });
5626 let result2 = handle2.await.expect("Task should complete");
5627 assert!(result2.is_err(), "Second task should fail");
5628 assert!(
5629 !tracker.cancellation_token().is_cancelled(),
5630 "Should NOT be cancelled - per-task context prevents global accumulation"
5631 );
5632
5633 println!("Policy global failure count: {}", policy.failure_count());
5634 assert_eq!(
5635 policy.failure_count(),
5636 2,
5637 "Policy should have counted 2 failures globally (for backwards compatibility)"
5638 );
5639 }
5640
5641 /// Test demonstrating that per-task error context solves the global failure tracking problem
5642 ///
5643 /// This test shows that with OnErrorContext, each task has independent failure tracking.
5644 #[rstest]
5645 #[tokio::test]
5646 async fn test_per_task_context_limitation_demo(unlimited_scheduler: Arc<UnlimitedScheduler>) {
5647 // Create a policy that should allow 2 failures per task
5648 let policy = ThresholdCancelPolicy::with_threshold(2);
5649 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
5650
5651 // Task 1: Fails once (per-task failure count = 1, below threshold)
5652 let handle1 = tracker.spawn(async {
5653 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 1 failure"));
5654 result
5655 });
5656 let result1 = handle1.await.expect("Task should complete");
5657 assert!(result1.is_err(), "Task 1 should fail");
5658
5659 // Task 2: Also fails once (per-task failure count = 1, below threshold)
5660 // With per-task context, this doesn't interfere with Task 1's failure budget
5661 let handle2 = tracker.spawn(async {
5662 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 2 failure"));
5663 result
5664 });
5665 let result2 = handle2.await.expect("Task should complete");
5666 assert!(result2.is_err(), "Task 2 should fail");
5667
5668 // With per-task context, tracker should NOT be cancelled
5669 // Each task failed only once, which is below the threshold of 2
5670 assert!(
5671 !tracker.cancellation_token().is_cancelled(),
5672 "Tracker should NOT be cancelled - per-task context prevents premature cancellation"
5673 );
5674
5675 println!("Global failure count: {}", policy.failure_count());
5676 assert_eq!(
5677 policy.failure_count(),
5678 2,
5679 "Global policy counted 2 failures across different tasks"
5680 );
5681
5682 // This demonstrates the limitation: we can't test per-task retry behavior
5683 // because failures from different tasks affect each other's retry budgets
5684 }
5685
    /// Test allow_continuation() policy method with attempt-based logic
    ///
    /// This test verifies that:
    /// 1. Policies can conditionally allow/reject continuations based on context
    /// 2. When allow_continuation() returns false, FailedWithContinuation is ignored
    /// 3. When allow_continuation() returns true, FailedWithContinuation is processed normally
    /// 4. The policy's decision takes precedence over task-provided continuations
    #[rstest]
    #[case(
        3,
        true,
        "Policy allows continuations up to 3 attempts - should succeed"
    )]
    #[case(
        2,
        true,
        "Policy allows continuations up to 2 attempts - should succeed"
    )]
    #[case(0, false, "Policy allows 0 attempts - should fail immediately")]
    #[tokio::test]
    async fn test_allow_continuation_policy_control(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        #[case] max_attempts: u32,
        #[case] should_succeed: bool,
        #[case] description: &str,
    ) {
        // Policy that allows continuations only up to max_attempts
        #[derive(Debug)]
        struct AttemptLimitPolicy {
            max_attempts: u32,
        }

        impl OnErrorPolicy for AttemptLimitPolicy {
            fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
                // Children inherit the same attempt budget
                Arc::new(AttemptLimitPolicy {
                    max_attempts: self.max_attempts,
                })
            }

            fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
                None // Stateless policy
            }

            // With max_attempts = 0 this can never pass (attempt_count is presumably
            // 1-based when a continuation is proposed — the 0-attempt case below relies
            // on every continuation being rejected; TODO confirm in the tracker impl).
            fn allow_continuation(&self, _error: &anyhow::Error, context: &OnErrorContext) -> bool {
                context.attempt_count <= self.max_attempts
            }

            fn on_error(
                &self,
                _error: &anyhow::Error,
                _context: &mut OnErrorContext,
            ) -> ErrorResponse {
                ErrorResponse::Fail // Just fail when continuations are not allowed
            }
        }

        let policy = Arc::new(AttemptLimitPolicy { max_attempts });
        let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));

        // Continuation that always tries to retry
        #[derive(Debug)]
        struct AlwaysRetryContinuation {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
            attempt: u32,
        }

        #[async_trait]
        impl Continuation for AlwaysRetryContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                self.log
                    .lock()
                    .await
                    .push(format!("continuation_attempt_{}", self.attempt));

                if self.attempt >= 2 {
                    // Success after 2 attempts
                    TaskExecutionResult::Success(Box::new("final_success".to_string()))
                } else {
                    // Try to continue with another continuation
                    // (each failure chains a successor with attempt + 1)
                    let next_continuation = Arc::new(AlwaysRetryContinuation {
                        log: self.log.clone(),
                        attempt: self.attempt + 1,
                    });
                    let error = anyhow::anyhow!("Continuation attempt {} failed", self.attempt);
                    TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
                        error,
                        next_continuation,
                    ))
                }
            }
        }

        // Task that immediately fails with a continuation
        let initial_continuation = Arc::new(AlwaysRetryContinuation {
            log: execution_log.clone(),
            attempt: 1,
        });

        let log_for_task = execution_log.clone();
        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("initial_task_failure".to_string());
            let error = anyhow::anyhow!("Initial task failure");
            let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
                error,
                initial_continuation,
            ));
            result
        });

        let result = handle.await.expect("Task should complete");

        if should_succeed {
            assert!(result.is_ok(), "{}: Task should succeed", description);
            assert_eq!(
                tracker.metrics().success(),
                1,
                "{}: Should have 1 success",
                description
            );

            // Should have executed multiple continuations
            let log = execution_log.lock().await;
            assert!(
                log.len() > 2,
                "{}: Should have multiple log entries",
                description
            );
            assert!(log.contains(&"continuation_attempt_1".to_string()));
        } else {
            assert!(result.is_err(), "{}: Task should fail", description);
            assert_eq!(
                tracker.metrics().failed(),
                1,
                "{}: Should have 1 failure",
                description
            );

            // Should have stopped early due to policy rejection
            let log = execution_log.lock().await;
            assert_eq!(
                log.len(),
                1,
                "{}: Should only have initial task entry",
                description
            );
            assert_eq!(log[0], "initial_task_failure");
            // Should NOT contain continuation attempts because policy rejected them
            assert!(
                !log.iter()
                    .any(|entry| entry.contains("continuation_attempt")),
                "{}: Should not have continuation attempts, but got: {:?}",
                description,
                *log
            );
        }
    }
5849
5850 /// Test TaskHandle functionality
5851 ///
5852 /// This test verifies that:
5853 /// 1. TaskHandle can be awaited like a JoinHandle
5854 /// 2. TaskHandle provides access to the task's cancellation token
5855 /// 3. Individual task cancellation works correctly
5856 /// 4. TaskHandle methods (abort, is_finished) work as expected
5857 #[tokio::test]
5858 async fn test_task_handle_functionality() {
5859 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
5860
5861 // Test basic functionality - TaskHandle can be awaited
5862 let handle1 = tracker.spawn(async {
5863 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
5864 Ok("completed".to_string())
5865 });
5866
5867 // Verify we can access the cancellation token
5868 let cancel_token = handle1.cancellation_token();
5869 assert!(
5870 !cancel_token.is_cancelled(),
5871 "Token should not be cancelled initially"
5872 );
5873
5874 // Await the task
5875 let result1 = handle1.await.expect("Task should complete");
5876 assert!(result1.is_ok(), "Task should succeed");
5877 assert_eq!(result1.unwrap(), "completed");
5878
5879 // Test individual task cancellation
5880 let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
5881 tokio::select! {
5882 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
5883 CancellableTaskResult::Ok("task_was_not_cancelled".to_string())
5884 },
5885 _ = cancel_token.cancelled() => {
5886 CancellableTaskResult::Cancelled
5887 },
5888
5889 }
5890 });
5891
5892 let cancel_token2 = handle2.cancellation_token();
5893
5894 // Cancel this specific task
5895 cancel_token2.cancel();
5896
5897 // The task should be cancelled
5898 let result2 = handle2.await.expect("Task should complete");
5899 assert!(result2.is_err(), "Task should be cancelled");
5900 assert!(
5901 result2.unwrap_err().is_cancellation(),
5902 "Should be a cancellation error"
5903 );
5904
5905 // Test that other tasks are not affected
5906 let handle3 = tracker.spawn(async { Ok("not_cancelled".to_string()) });
5907
5908 let result3 = handle3.await.expect("Task should complete");
5909 assert!(result3.is_ok(), "Other tasks should not be affected");
5910 assert_eq!(result3.unwrap(), "not_cancelled");
5911
5912 // Test abort functionality
5913 let handle4 = tracker.spawn(async {
5914 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
5915 Ok("should_be_aborted".to_string())
5916 });
5917
5918 // Check is_finished before abort
5919 assert!(!handle4.is_finished(), "Task should not be finished yet");
5920
5921 // Abort the task
5922 handle4.abort();
5923
5924 // Task should be aborted (JoinError)
5925 let result4 = handle4.await;
5926 assert!(result4.is_err(), "Aborted task should return JoinError");
5927
5928 // Verify metrics
5929 assert_eq!(
5930 tracker.metrics().success(),
5931 2,
5932 "Should have 2 successful tasks"
5933 );
5934 assert_eq!(
5935 tracker.metrics().cancelled(),
5936 1,
5937 "Should have 1 cancelled task"
5938 );
5939 // Note: aborted tasks don't count as cancelled in our metrics
5940 }
5941
5942 /// Test TaskHandle with cancellable tasks
5943 #[tokio::test]
5944 async fn test_task_handle_with_cancellable_tasks() {
5945 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
5946
5947 // Test cancellable task with TaskHandle
5948 let handle = tracker.spawn_cancellable(|cancel_token| async move {
5949 tokio::select! {
5950 _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
5951 CancellableTaskResult::Ok("completed".to_string())
5952 },
5953 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
5954 }
5955 });
5956
5957 // Verify we can access the task's individual cancellation token
5958 let task_cancel_token = handle.cancellation_token();
5959 assert!(
5960 !task_cancel_token.is_cancelled(),
5961 "Task token should not be cancelled initially"
5962 );
5963
5964 // Let the task complete normally
5965 let result = handle.await.expect("Task should complete");
5966 assert!(result.is_ok(), "Task should succeed");
5967 assert_eq!(result.unwrap(), "completed");
5968
5969 // Test cancellation of cancellable task
5970 let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
5971 tokio::select! {
5972 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
5973 CancellableTaskResult::Ok("should_not_complete".to_string())
5974 },
5975 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
5976 }
5977 });
5978
5979 // Cancel the specific task
5980 handle2.cancellation_token().cancel();
5981
5982 let result2 = handle2.await.expect("Task should complete");
5983 assert!(result2.is_err(), "Task should be cancelled");
5984 assert!(
5985 result2.unwrap_err().is_cancellation(),
5986 "Should be a cancellation error"
5987 );
5988
5989 // Verify metrics
5990 assert_eq!(
5991 tracker.metrics().success(),
5992 1,
5993 "Should have 1 successful task"
5994 );
5995 assert_eq!(
5996 tracker.metrics().cancelled(),
5997 1,
5998 "Should have 1 cancelled task"
5999 );
6000 }
6001
6002 /// Test FailedWithContinuation helper methods
6003 ///
6004 /// This test verifies that:
6005 /// 1. from_fn creates working continuations from simple closures
6006 /// 2. from_cancellable creates working continuations from cancellable closures
6007 /// 3. Both helpers integrate correctly with the task execution system
6008 #[tokio::test]
6009 async fn test_continuation_helpers() {
6010 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
6011
6012 // Test from_fn helper
6013 let handle1 = tracker.spawn(async {
6014 let error =
6015 FailedWithContinuation::from_fn(anyhow::anyhow!("Initial failure"), || async {
6016 Ok("Success from from_fn".to_string())
6017 });
6018 let result: Result<String, anyhow::Error> = Err(error);
6019 result
6020 });
6021
6022 let result1 = handle1.await.expect("Task should complete");
6023 assert!(
6024 result1.is_ok(),
6025 "Task with from_fn continuation should succeed"
6026 );
6027 assert_eq!(result1.unwrap(), "Success from from_fn");
6028
6029 // Test from_cancellable helper
6030 let handle2 = tracker.spawn(async {
6031 let error = FailedWithContinuation::from_cancellable(
6032 anyhow::anyhow!("Initial failure"),
6033 |_cancel_token| async move { Ok("Success from from_cancellable".to_string()) },
6034 );
6035 let result: Result<String, anyhow::Error> = Err(error);
6036 result
6037 });
6038
6039 let result2 = handle2.await.expect("Task should complete");
6040 assert!(
6041 result2.is_ok(),
6042 "Task with from_cancellable continuation should succeed"
6043 );
6044 assert_eq!(result2.unwrap(), "Success from from_cancellable");
6045
6046 // Verify metrics
6047 assert_eq!(
6048 tracker.metrics().success(),
6049 2,
6050 "Should have 2 successful tasks"
6051 );
6052 assert_eq!(tracker.metrics().failed(), 0, "Should have 0 failed tasks");
6053 }
6054
    /// Test should_reschedule() policy method with mock scheduler tracking
    ///
    /// This test verifies that:
    /// 1. When should_reschedule() returns false, the guard is reused (efficient)
    /// 2. When should_reschedule() returns true, the guard is re-acquired through scheduler
    /// 3. The scheduler's acquire_execution_slot is called the expected number of times
    /// 4. Rescheduling works for both task-driven and policy-driven continuations
    #[rstest]
    // expected_acquisitions: 1 = initial spawn only (guard reused for the continuation);
    // 2 = the continuation re-acquires a slot through the scheduler.
    #[case(false, 1, "Policy requests no rescheduling - should reuse guard")]
    #[case(true, 2, "Policy requests rescheduling - should re-acquire guard")]
    #[tokio::test]
    async fn test_should_reschedule_policy_control(
        #[case] should_reschedule: bool,
        #[case] expected_acquisitions: u32,
        #[case] description: &str,
    ) {
        // Mock scheduler that tracks acquisition calls
        #[derive(Debug)]
        struct MockScheduler {
            acquisition_count: Arc<AtomicU32>,
        }

        impl MockScheduler {
            fn new() -> Self {
                Self {
                    acquisition_count: Arc::new(AtomicU32::new(0)),
                }
            }

            fn acquisition_count(&self) -> u32 {
                self.acquisition_count.load(Ordering::Relaxed)
            }
        }

        #[async_trait]
        impl TaskScheduler for MockScheduler {
            async fn acquire_execution_slot(
                &self,
                _cancel_token: CancellationToken,
            ) -> SchedulingResult<Box<dyn ResourceGuard>> {
                // Count every acquisition so the test can distinguish guard reuse
                // from re-acquisition; always grant a slot.
                self.acquisition_count.fetch_add(1, Ordering::Relaxed);
                SchedulingResult::Execute(Box::new(UnlimitedGuard))
            }
        }

        // Policy that controls rescheduling behavior
        #[derive(Debug)]
        struct RescheduleTestPolicy {
            should_reschedule: bool,
        }

        impl OnErrorPolicy for RescheduleTestPolicy {
            fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
                Arc::new(RescheduleTestPolicy {
                    should_reschedule: self.should_reschedule,
                })
            }

            fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
                None // Stateless policy
            }

            fn allow_continuation(
                &self,
                _error: &anyhow::Error,
                _context: &OnErrorContext,
            ) -> bool {
                true // Always allow continuations for this test
            }

            // The single knob under test: forwarded verbatim from the #[case] input.
            fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
                self.should_reschedule
            }

            fn on_error(
                &self,
                _error: &anyhow::Error,
                _context: &mut OnErrorContext,
            ) -> ErrorResponse {
                ErrorResponse::Fail // Just fail when continuations are not allowed
            }
        }

        let mock_scheduler = Arc::new(MockScheduler::new());
        let policy = Arc::new(RescheduleTestPolicy { should_reschedule });
        let tracker = TaskTracker::new(mock_scheduler.clone(), policy).unwrap();
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));

        // Simple continuation that succeeds on second attempt
        #[derive(Debug)]
        struct SimpleRetryContinuation {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
        }

        #[async_trait]
        impl Continuation for SimpleRetryContinuation {
            async fn execute(
                &self,
                _cancel_token: CancellationToken,
            ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
                self.log
                    .lock()
                    .await
                    .push("continuation_executed".to_string());

                // Succeed immediately
                TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
            }
        }

        // Task that fails with a continuation
        let continuation = Arc::new(SimpleRetryContinuation {
            log: execution_log.clone(),
        });

        let log_for_task = execution_log.clone();
        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("initial_task_failure".to_string());
            let error = anyhow::anyhow!("Initial task failure");
            let result: Result<String, anyhow::Error> =
                Err(FailedWithContinuation::into_anyhow(error, continuation));
            result
        });

        let result = handle.await.expect("Task should complete");

        // Task should succeed regardless of rescheduling behavior
        assert!(result.is_ok(), "{}: Task should succeed", description);
        assert_eq!(
            tracker.metrics().success(),
            1,
            "{}: Should have 1 success",
            description
        );

        // Verify the execution log
        let log = execution_log.lock().await;
        assert_eq!(
            log.len(),
            2,
            "{}: Should have initial task + continuation",
            description
        );
        assert_eq!(log[0], "initial_task_failure");
        assert_eq!(log[1], "continuation_executed");

        // Most importantly: verify the scheduler acquisition count
        let actual_acquisitions = mock_scheduler.acquisition_count();
        assert_eq!(
            actual_acquisitions, expected_acquisitions,
            "{}: Expected {} scheduler acquisitions, got {}",
            description, expected_acquisitions, actual_acquisitions
        );
    }
6212
    /// Test continuation loop with custom action policies
    ///
    /// This tests that custom error actions can also provide continuations
    /// and that the loop behavior works correctly with policy-provided continuations
    ///
    /// Flow per retry round: the task (or the previous continuation) fails with a
    /// *regular* error, the policy answers `ErrorResponse::Custom(RetryAction)`,
    /// and the action returns `ActionResult::Continue { continuation }` until
    /// `max_retries` is exceeded, at which point it requests `Shutdown`.
    ///
    /// NOTE: Uses fresh policy/action instances to avoid global state interference.
    #[rstest]
    #[case(1, true, "Custom action with 1 retry should succeed")]
    #[case(3, true, "Custom action with 3 retries should succeed")]
    #[tokio::test]
    async fn test_continuation_loop_with_custom_action_policy(
        unlimited_scheduler: Arc<UnlimitedScheduler>,
        #[case] max_retries: u32,
        #[case] should_succeed: bool,
        #[case] description: &str,
    ) {
        // Shared observers: ordered event log + tally of policy-action invocations.
        let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
        let retry_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        // Custom action that provides continuations up to max_retries
        #[derive(Debug)]
        struct RetryAction {
            log: Arc<tokio::sync::Mutex<Vec<String>>>,
            retry_count: Arc<std::sync::atomic::AtomicU32>,
            max_retries: u32,
        }

        #[async_trait]
        impl OnErrorAction for RetryAction {
            async fn execute(
                &self,
                _error: &anyhow::Error,
                _task_id: TaskId,
                _attempt_count: u32,
                _context: &TaskExecutionContext,
            ) -> ActionResult {
                // Relaxed is sufficient: the counter is a plain tally and is not
                // used to order other memory operations.
                let current_retry = self
                    .retry_count
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                    + 1;
                self.log
                    .lock()
                    .await
                    .push(format!("custom_action_retry_{}", current_retry));

                if current_retry <= self.max_retries {
                    // Provide a continuation that succeeds if this is the final retry
                    #[derive(Debug)]
                    struct RetryContinuation {
                        log: Arc<tokio::sync::Mutex<Vec<String>>>,
                        retry_number: u32,
                        max_retries: u32,
                    }

                    #[async_trait]
                    impl Continuation for RetryContinuation {
                        async fn execute(
                            &self,
                            _cancel_token: CancellationToken,
                        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>
                        {
                            self.log
                                .lock()
                                .await
                                .push(format!("retry_continuation_{}", self.retry_number));

                            if self.retry_number >= self.max_retries {
                                // Final retry succeeds
                                TaskExecutionResult::Success(Box::new(format!(
                                    "success_after_{}_retries",
                                    self.retry_number
                                )))
                            } else {
                                // Still need more retries, fail with regular error (not FailedWithContinuation)
                                // This will trigger the custom action again
                                TaskExecutionResult::Error(anyhow::anyhow!(
                                    "Retry {} still failing",
                                    self.retry_number
                                ))
                            }
                        }
                    }

                    let continuation = Arc::new(RetryContinuation {
                        log: self.log.clone(),
                        retry_number: current_retry,
                        max_retries: self.max_retries,
                    });

                    ActionResult::Continue { continuation }
                } else {
                    // Exceeded max retries, cancel
                    ActionResult::Shutdown
                }
            }
        }

        // Custom policy that uses the retry action
        #[derive(Debug)]
        struct CustomRetryPolicy {
            action: Arc<RetryAction>,
        }

        impl OnErrorPolicy for CustomRetryPolicy {
            fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
                // Children alias the same action (and thus the same shared
                // log/counter), so retries are tallied globally across trackers.
                Arc::new(CustomRetryPolicy {
                    action: self.action.clone(),
                })
            }

            fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
                None // Stateless policy - no heap allocation
            }

            fn on_error(
                &self,
                _error: &anyhow::Error,
                _context: &mut OnErrorContext,
            ) -> ErrorResponse {
                // A fresh RetryAction is handed out per error, but it clones the
                // shared Arc'd log/counter, so state carries across invocations.
                ErrorResponse::Custom(Box::new(RetryAction {
                    log: self.action.log.clone(),
                    retry_count: self.action.retry_count.clone(),
                    max_retries: self.action.max_retries,
                }))
            }
        }

        let action = Arc::new(RetryAction {
            log: execution_log.clone(),
            retry_count: retry_count.clone(),
            max_retries,
        });
        let policy = Arc::new(CustomRetryPolicy { action });
        let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();

        // Task that always fails with regular error (not FailedWithContinuation)
        let log_for_task = execution_log.clone();
        let handle = tracker.spawn(async move {
            log_for_task
                .lock()
                .await
                .push("original_task_failed".to_string());
            let result: Result<String, anyhow::Error> =
                Err(anyhow::anyhow!("Original task failure"));
            result
        });

        // Execute and verify results
        let result = handle.await.expect("Task should complete");

        if should_succeed {
            assert!(result.is_ok(), "{}: Task should succeed", description);
            assert_eq!(
                tracker.metrics().success(),
                1,
                "{}: Should have 1 success",
                description
            );

            // Verify the retry sequence
            let log = execution_log.lock().await;
            let expected_entries = 1 + (max_retries * 2); // original + (action + continuation) per retry
            assert_eq!(
                log.len(),
                expected_entries as usize,
                "{}: Should have {} log entries",
                description,
                expected_entries
            );

            assert_eq!(
                retry_count.load(std::sync::atomic::Ordering::Relaxed),
                max_retries,
                "{}: Should have made {} retry attempts",
                description,
                max_retries
            );
        } else {
            // NOTE(review): no current #[case] exercises this branch — both cases
            // pass should_succeed = true. Kept for a future failure-path case;
            // consider adding one or removing the parameter.
            assert!(result.is_err(), "{}: Task should fail", description);
            assert!(
                tracker.cancellation_token().is_cancelled(),
                "{}: Should be cancelled",
                description
            );

            // Should have stopped after max_retries
            let final_retry_count = retry_count.load(std::sync::atomic::Ordering::Relaxed);
            assert!(
                final_retry_count > max_retries,
                "{}: Should have exceeded max_retries ({}), got {}",
                description,
                max_retries,
                final_retry_count
            );
        }
    }
6409
6410 /// Test mixed continuation sources (task-driven + policy-driven)
6411 ///
6412 /// This test verifies that both task-provided continuations and policy-provided
6413 /// continuations can work together in the same execution flow
6414 #[rstest]
6415 #[tokio::test]
6416 async fn test_mixed_continuation_sources(
6417 unlimited_scheduler: Arc<UnlimitedScheduler>,
6418 log_policy: Arc<LogOnlyPolicy>,
6419 ) {
6420 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
6421 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
6422
6423 // Task that provides a continuation, which then fails with regular error
6424 let log_for_task = execution_log.clone();
6425 let log_for_continuation = execution_log.clone();
6426
6427 #[derive(Debug)]
6428 struct MixedContinuation {
6429 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6430 }
6431
6432 #[async_trait]
6433 impl Continuation for MixedContinuation {
6434 async fn execute(
6435 &self,
6436 _cancel_token: CancellationToken,
6437 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
6438 self.log
6439 .lock()
6440 .await
6441 .push("task_continuation_executed".to_string());
6442 // This continuation fails with a regular error (not FailedWithContinuation)
6443 // So it will be handled by the policy (LogOnlyPolicy just continues)
6444 TaskExecutionResult::Error(anyhow::anyhow!("Task continuation failed"))
6445 }
6446 }
6447
6448 let continuation = Arc::new(MixedContinuation {
6449 log: log_for_continuation,
6450 });
6451
6452 let handle = tracker.spawn(async move {
6453 log_for_task
6454 .lock()
6455 .await
6456 .push("original_task_executed".to_string());
6457
6458 // Task provides continuation
6459 let error = anyhow::anyhow!("Original task failed");
6460 let result: Result<String, anyhow::Error> =
6461 Err(FailedWithContinuation::into_anyhow(error, continuation));
6462 result
6463 });
6464
6465 // Execute - should fail because continuation fails and LogOnlyPolicy just logs
6466 let result = handle.await.expect("Task should complete");
6467 assert!(
6468 result.is_err(),
6469 "Should fail because continuation fails and policy just logs"
6470 );
6471
6472 // Verify execution sequence
6473 let log = execution_log.lock().await;
6474 assert_eq!(log.len(), 2);
6475 assert_eq!(log[0], "original_task_executed");
6476 assert_eq!(log[1], "task_continuation_executed");
6477
6478 // Verify metrics - should show failure from continuation
6479 assert_eq!(tracker.metrics().success(), 0);
6480 assert_eq!(tracker.metrics().failed(), 1);
6481 }
6482
6483 /// Debug test to understand the threshold policy behavior in retry loop
6484 #[rstest]
6485 #[tokio::test]
6486 async fn debug_threshold_policy_in_retry_loop(unlimited_scheduler: Arc<UnlimitedScheduler>) {
6487 let policy = ThresholdCancelPolicy::with_threshold(2);
6488 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
6489
6490 // Simple continuation that always fails with regular error
6491 #[derive(Debug)]
6492 struct AlwaysFailContinuation {
6493 attempt: Arc<std::sync::atomic::AtomicU32>,
6494 }
6495
6496 #[async_trait]
6497 impl Continuation for AlwaysFailContinuation {
6498 async fn execute(
6499 &self,
6500 _cancel_token: CancellationToken,
6501 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
6502 let attempt_num = self
6503 .attempt
6504 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
6505 + 1;
6506 println!("Continuation attempt {}", attempt_num);
6507 TaskExecutionResult::Error(anyhow::anyhow!(
6508 "Continuation attempt {} failed",
6509 attempt_num
6510 ))
6511 }
6512 }
6513
6514 let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
6515 let continuation = Arc::new(AlwaysFailContinuation {
6516 attempt: attempt_counter.clone(),
6517 });
6518
6519 let handle = tracker.spawn(async move {
6520 println!("Original task executing");
6521 let error = anyhow::anyhow!("Original task failed");
6522 let result: Result<String, anyhow::Error> =
6523 Err(FailedWithContinuation::into_anyhow(error, continuation));
6524 result
6525 });
6526
6527 let result = handle.await.expect("Task should complete");
6528 println!("Final result: {:?}", result.is_ok());
6529 println!("Policy failure count: {}", policy.failure_count());
6530 println!(
6531 "Continuation attempts: {}",
6532 attempt_counter.load(std::sync::atomic::Ordering::Relaxed)
6533 );
6534 println!(
6535 "Tracker cancelled: {}",
6536 tracker.cancellation_token().is_cancelled()
6537 );
6538 println!(
6539 "Metrics: success={}, failed={}",
6540 tracker.metrics().success(),
6541 tracker.metrics().failed()
6542 );
6543
6544 // This should help us understand what's happening
6545 }
6546}