dynamo_runtime/utils/tasks/tracker.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! # Task Tracker - Hierarchical Task Management System
5//!
6//! A composable task management system with configurable scheduling and error handling policies.
7//! The TaskTracker enables controlled concurrent execution with proper resource management,
8//! cancellation semantics, and retry support.
9//!
10//! ## Architecture Overview
11//!
12//! The TaskTracker system is built around three core abstractions that compose together:
13//!
14//! ### 1. **TaskScheduler** - Resource Management
15//!
16//! Controls when and how tasks acquire execution resources (permits, slots, etc.).
17//! Schedulers implement resource acquisition with cancellation support:
18//!
19//! ```text
20//! TaskScheduler::acquire_execution_slot(cancel_token) -> SchedulingResult<ResourceGuard>
21//! ```
22//!
23//! - **Resource Acquisition**: Can be cancelled to avoid unnecessary allocation
24//! - **RAII Guards**: Resources are automatically released when guards are dropped
25//! - **Pluggable**: Different scheduling policies (unlimited, semaphore, rate-limited, etc.)
26//!
27//! ### 2. **OnErrorPolicy** - Error Handling
28//!
29//! Defines how the system responds to task failures:
30//!
31//! ```text
32//! OnErrorPolicy::on_error(error, task_id) -> ErrorResponse
33//! ```
34//!
35//! - **ErrorResponse::Fail**: Log error, fail this task
36//! - **ErrorResponse::Shutdown**: Shutdown tracker and all children
37//! - **ErrorResponse::Custom(action)**: Execute custom logic that can return:
38//! - `ActionResult::Fail`: Handle error and fail the task
39//! - `ActionResult::Shutdown`: Shutdown tracker
40//! - `ActionResult::Continue { continuation }`: Continue with provided task
41//!
42//! ### 3. **Execution Pipeline** - Task Orchestration
43//!
44//! The execution pipeline coordinates scheduling, execution, and error handling:
45//!
46//! ```text
47//! 1. Acquire resources (scheduler.acquire_execution_slot)
48//! 2. Create task future (only after resources acquired)
49//! 3. Execute task while holding guard (RAII pattern)
50//! 4. Handle errors through policy (with retry support for cancellable tasks)
51//! 5. Update metrics and release resources
52//! ```
53//!
54//! ## Key Design Principles
55//!
56//! ### **Separation of Concerns**
57//! - **Scheduling**: When/how to allocate resources
58//! - **Execution**: Running tasks with proper resource management
59//! - **Error Handling**: Responding to failures with configurable policies
60//!
61//! ### **Composability**
62//! - Schedulers and error policies are independent and can be mixed/matched
63//! - Custom policies can be implemented via traits
64//! - Execution pipeline handles the coordination automatically
65//!
66//! ### **Resource Safety**
67//! - Resources are acquired before task creation (prevents early execution)
68//! - RAII pattern ensures resources are always released
69//! - Cancellation is supported during resource acquisition, not during execution
70//!
71//! ### **Retry Support**
72//! - Regular tasks (`spawn`): Cannot be retried (future is consumed)
73//! - Cancellable tasks (`spawn_cancellable`): Support retry via `FnMut` closures
74//! - Error policies can provide next executors via `ActionResult::Continue`
75//!
76//! ## Task Types
77//!
78//! ### Regular Tasks
79//! ```rust
80//! # use dynamo_runtime::utils::tasks::tracker::*;
81//! # #[tokio::main]
82//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
83//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
84//! let handle = tracker.spawn(async { Ok(42) });
85//! # let _result = handle.await?;
86//! # Ok(())
87//! # }
88//! ```
89//! - Simple futures that run to completion
90//! - Cannot be retried (future is consumed on first execution)
91//! - Suitable for one-shot operations
92//!
93//! ### Cancellable Tasks
94//! ```rust
95//! # use dynamo_runtime::utils::tasks::tracker::*;
96//! # #[tokio::main]
97//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
98//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
99//! let handle = tracker.spawn_cancellable(|cancel_token| async move {
100//! // Task can check cancel_token.is_cancelled() or use tokio::select!
101//! CancellableTaskResult::Ok(42)
102//! });
103//! # let _result = handle.await?;
104//! # Ok(())
105//! # }
106//! ```
107//! - Receive a `CancellationToken` for cooperative cancellation
108//! - Support retry via `FnMut` closures (can be called multiple times)
109//! - Return `CancellableTaskResult` to indicate success/cancellation/error
110//!
111//! ## Hierarchical Structure
112//!
113//! TaskTrackers form parent-child relationships:
114//! - **Metrics**: Child metrics aggregate to parents
115//! - **Cancellation**: Parent cancellation propagates to children
116//! - **Independence**: Child cancellation doesn't affect parents
117//! - **Cleanup**: `join()` waits for all descendants bottom-up
118//!
119//! ## Metrics and Observability
120//!
121//! Built-in metrics track task lifecycle:
122//! - `issued`: Tasks submitted via spawn methods
123//! - `active`: Currently executing tasks
124//! - `success/failed/cancelled/rejected`: Final outcomes
125//! - `pending`: Issued but not completed (issued - completed)
126//! - `queued`: Waiting for resources (pending - active)
127//!
128//! Optional Prometheus integration available via `PrometheusTaskMetrics`.
129//!
130//! ## Usage Examples
131//!
132//! ### Basic Task Execution
133//! ```rust
134//! use dynamo_runtime::utils::tasks::tracker::*;
135//! use std::sync::Arc;
136//!
137//! # #[tokio::main]
138//! # async fn main() -> anyhow::Result<()> {
139//! let scheduler = SemaphoreScheduler::with_permits(10);
140//! let error_policy = LogOnlyPolicy::new();
141//! let tracker = TaskTracker::new(scheduler, error_policy)?;
142//!
143//! let handle = tracker.spawn(async { Ok(42) });
144//! let result = handle.await??;
145//! assert_eq!(result, 42);
146//! # Ok(())
147//! # }
148//! ```
149//!
150//! ### Cancellable Tasks with Retry
151//! ```rust
152//! # use dynamo_runtime::utils::tasks::tracker::*;
153//! # use std::sync::Arc;
154//! # #[tokio::main]
155//! # async fn main() -> anyhow::Result<()> {
156//! # let scheduler = SemaphoreScheduler::with_permits(10);
157//! # let error_policy = LogOnlyPolicy::new();
158//! # let tracker = TaskTracker::new(scheduler, error_policy)?;
159//! let handle = tracker.spawn_cancellable(|cancel_token| async move {
160//! tokio::select! {
161//! result = do_work() => CancellableTaskResult::Ok(result),
162//! _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
163//! }
164//! });
165//! # Ok(())
166//! # }
167//! # async fn do_work() -> i32 { 42 }
168//! ```
169//!
170//! ### Task-Driven Retry with Continuations
171//! ```rust
172//! # use dynamo_runtime::utils::tasks::tracker::*;
173//! # use anyhow::anyhow;
174//! # #[tokio::main]
175//! # async fn main() -> anyhow::Result<()> {
176//! # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
177//! let handle = tracker.spawn(async {
178//! // Simulate initial failure with retry logic
179//! let error = FailedWithContinuation::from_fn(
180//! anyhow!("Network timeout"),
181//! || async {
182//! println!("Retrying with exponential backoff...");
183//! tokio::time::sleep(std::time::Duration::from_millis(100)).await;
184//! Ok("Success after retry".to_string())
185//! }
186//! );
187//! let result: Result<String, anyhow::Error> = Err(error);
188//! result
189//! });
190//!
191//! let result = handle.await?;
192//! assert!(result.is_ok());
193//! # Ok(())
194//! # }
195//! ```
196//!
197//! ### Custom Error Policy with Continuation
198//! ```rust
199//! # use dynamo_runtime::utils::tasks::tracker::*;
200//! # use std::sync::Arc;
201//! # use async_trait::async_trait;
202//! # #[derive(Debug)]
203//! struct RetryPolicy {
204//! max_attempts: u32,
205//! }
206//!
207//! impl OnErrorPolicy for RetryPolicy {
208//! fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
209//! Arc::new(RetryPolicy { max_attempts: self.max_attempts })
210//! }
211//!
212//! fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
213//! None // Stateless policy
214//! }
215//!
216//! fn on_error(&self, _error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
217//! if context.attempt_count < self.max_attempts {
218//! ErrorResponse::Custom(Box::new(RetryAction))
219//! } else {
220//! ErrorResponse::Fail
221//! }
222//! }
223//! }
224//!
225//! # #[derive(Debug)]
226//! struct RetryAction;
227//!
228//! #[async_trait]
229//! impl OnErrorAction for RetryAction {
230//! async fn execute(
231//! &self,
232//! _error: &anyhow::Error,
233//! _task_id: TaskId,
234//! _attempt_count: u32,
235//! _context: &TaskExecutionContext,
236//! ) -> ActionResult {
237//! // In practice, you would create a continuation here
238//! ActionResult::Fail
239//! }
240//! }
241//! ```
242//!
243//! ## Future Extensibility
244//!
245//! The system is designed for extensibility. See the source code for detailed TODO comments
246//! describing additional policies that can be implemented:
247//! - **Scheduling**: Token bucket rate limiting, adaptive concurrency, memory-aware scheduling
248//! - **Error Handling**: Retry with backoff, circuit breakers, dead letter queues
249//!
250//! Each TODO comment includes complete implementation guidance with data structures,
251//! algorithms, and dependencies needed for future contributors.
252//!
253//! ## Hierarchical Organization
254//!
255//! ```rust
256//! use dynamo_runtime::utils::tasks::tracker::{
257//! TaskTracker, UnlimitedScheduler, ThresholdCancelPolicy, SemaphoreScheduler
258//! };
259//!
260//! # async fn example() -> anyhow::Result<()> {
261//! // Create root tracker with failure threshold policy
262//! let error_policy = ThresholdCancelPolicy::with_threshold(5);
263//! let root = TaskTracker::builder()
264//! .scheduler(UnlimitedScheduler::new())
265//! .error_policy(error_policy)
266//! .build()?;
267//!
268//! // Create child trackers for different components
269//! let api_handler = root.child_tracker()?; // Inherits policies
270//! let background_jobs = root.child_tracker()?;
271//!
272//! // Children can have custom policies
273//! let rate_limited = root.child_tracker_builder()
274//! .scheduler(SemaphoreScheduler::with_permits(2)) // Custom concurrency limit
275//! .build()?;
276//!
277//! // Tasks run independently but metrics roll up
278//! api_handler.spawn(async { Ok(()) });
279//! background_jobs.spawn(async { Ok(()) });
280//! rate_limited.spawn(async { Ok(()) });
281//!
282//! // Join all children hierarchically
283//! root.join().await;
284//! assert_eq!(root.metrics().success(), 3); // Sees all successes
285//! # Ok(())
286//! # }
287//! ```
288//!
289//! ## Policy Examples
290//!
291//! ```rust
292//! use dynamo_runtime::utils::tasks::tracker::{
293//! TaskTracker, CancelOnError, SemaphoreScheduler, ThresholdCancelPolicy
294//! };
295//!
296//! # async fn example() -> anyhow::Result<()> {
297//! // Pattern-based error cancellation
298//! let (error_policy, token) = CancelOnError::with_patterns(
299//! vec!["OutOfMemory".to_string(), "DeviceError".to_string()]
300//! );
301//! let simple = TaskTracker::builder()
302//! .scheduler(SemaphoreScheduler::with_permits(5))
303//! .error_policy(error_policy)
304//! .build()?;
305//!
306//! // Threshold-based cancellation with monitoring
307//! let scheduler = SemaphoreScheduler::with_permits(10); // Returns Arc<SemaphoreScheduler>
308//! let error_policy = ThresholdCancelPolicy::with_threshold(3); // Returns Arc<Policy>
309//!
310//! let advanced = TaskTracker::builder()
311//! .scheduler(scheduler)
312//! .error_policy(error_policy)
313//! .build()?;
314//!
315//! // Monitor cancellation externally
316//! if token.is_cancelled() {
317//! println!("Tracker cancelled due to failures");
318//! }
319//! # Ok(())
320//! # }
321//! ```
322//!
323//! ## Metrics and Observability
324//!
325//! ```rust
326//! use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
327//!
328//! # async fn example() -> anyhow::Result<()> {
329//! let tracker = std::sync::Arc::new(TaskTracker::builder()
330//! .scheduler(SemaphoreScheduler::with_permits(2)) // Only 2 concurrent tasks
331//! .error_policy(LogOnlyPolicy::new())
332//! .build()?);
333//!
334//! // Spawn multiple tasks
335//! for i in 0..5 {
336//! tracker.spawn(async move {
337//! tokio::time::sleep(std::time::Duration::from_millis(100)).await;
338//! Ok(i)
339//! });
340//! }
341//!
342//! // Check metrics
343//! let metrics = tracker.metrics();
344//! println!("Issued: {}", metrics.issued()); // 5 tasks issued
345//! println!("Active: {}", metrics.active()); // 2 tasks running (semaphore limit)
346//! println!("Queued: {}", metrics.queued()); // 3 tasks waiting in scheduler queue
347//! println!("Pending: {}", metrics.pending()); // 5 tasks not yet completed
348//!
349//! tracker.join().await;
350//! assert_eq!(metrics.success(), 5);
351//! assert_eq!(metrics.pending(), 0);
352//! # Ok(())
353//! # }
354//! ```
355//!
356//! ## Prometheus Integration
357//!
358//! ```rust
359//! use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
360//! use dynamo_runtime::DistributedRuntime;
361//!
362//! # async fn example(drt: &DistributedRuntime) -> anyhow::Result<()> {
363//! // Root tracker with Prometheus metrics
364//! let tracker = TaskTracker::new_with_prometheus(
365//! SemaphoreScheduler::with_permits(10),
366//! LogOnlyPolicy::new(),
367//! drt,
368//! "my_component"
369//! )?;
370//!
371//! // Metrics automatically exported to Prometheus:
372//! // - my_component_tasks_issued_total
373//! // - my_component_tasks_success_total
374//! // - my_component_tasks_failed_total
375//! // - my_component_tasks_active
376//! // - my_component_tasks_queued
377//! # Ok(())
378//! # }
379//! ```
380
381use std::future::Future;
382use std::pin::Pin;
383use std::sync::Arc;
384use std::sync::atomic::{AtomicU64, Ordering};
385
386use crate::metrics::MetricsHierarchy;
387use crate::metrics::prometheus_names::task_tracker;
388use anyhow::Result;
389use async_trait::async_trait;
390use derive_builder::Builder;
391use std::collections::HashSet;
392use std::sync::{Mutex, RwLock, Weak};
393use std::time::Duration;
394use thiserror::Error;
395use tokio::sync::Semaphore;
396use tokio::task::JoinHandle;
397use tokio_util::sync::CancellationToken;
398use tokio_util::task::TaskTracker as TokioTaskTracker;
399use tracing::{Instrument, debug, error, warn};
400use uuid::Uuid;
401
402/// Error type for task execution results
403///
404/// This enum distinguishes between task cancellation and actual failures,
405/// enabling proper metrics tracking and error handling.
406#[derive(Error, Debug)]
407pub enum TaskError {
408 /// Task was cancelled (either via cancellation token or tracker shutdown)
409 #[error("Task was cancelled")]
410 Cancelled,
411
412 /// Task failed with an error
413 #[error(transparent)]
414 Failed(#[from] anyhow::Error),
415
416 /// Cannot spawn task on a closed tracker
417 #[error("Cannot spawn task on a closed tracker")]
418 TrackerClosed,
419}
420
421impl TaskError {
422 /// Check if this error represents a cancellation
423 ///
424 /// This is a convenience method for compatibility and readability.
425 pub fn is_cancellation(&self) -> bool {
426 matches!(self, TaskError::Cancelled)
427 }
428
429 /// Check if this error represents a failure
430 pub fn is_failure(&self) -> bool {
431 matches!(self, TaskError::Failed(_))
432 }
433
434 /// Get the underlying anyhow::Error for failures, or a cancellation error for cancellations
435 ///
436 /// This is provided for compatibility with existing code that expects anyhow::Error.
437 pub fn into_anyhow(self) -> anyhow::Error {
438 match self {
439 TaskError::Failed(err) => err,
440 TaskError::Cancelled => anyhow::anyhow!("Task was cancelled"),
441 TaskError::TrackerClosed => anyhow::anyhow!("Cannot spawn task on a closed tracker"),
442 }
443 }
444}
445
446/// A handle to a spawned task that provides both join functionality and cancellation control
447///
448/// `TaskHandle` wraps a `JoinHandle` and provides access to the task's individual cancellation token.
449/// This allows fine-grained control over individual tasks while maintaining the familiar `JoinHandle` API.
450///
451/// # Example
452/// ```rust
453/// # use dynamo_runtime::utils::tasks::tracker::*;
454/// # #[tokio::main]
455/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
456/// # let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new())?;
457/// let handle = tracker.spawn(async {
458/// tokio::time::sleep(std::time::Duration::from_millis(100)).await;
459/// Ok(42)
460/// });
461///
462/// // Access the task's cancellation token
463/// let cancel_token = handle.cancellation_token();
464///
465/// // Can cancel the specific task
466/// // cancel_token.cancel();
467///
468/// // Await the task like a normal JoinHandle
469/// let result = handle.await?;
470/// assert_eq!(result?, 42);
471/// # Ok(())
472/// # }
473/// ```
474pub struct TaskHandle<T> {
475 join_handle: JoinHandle<Result<T, TaskError>>,
476 cancel_token: CancellationToken,
477}
478
479impl<T> TaskHandle<T> {
480 /// Create a new TaskHandle wrapping a JoinHandle and cancellation token
481 pub(crate) fn new(
482 join_handle: JoinHandle<Result<T, TaskError>>,
483 cancel_token: CancellationToken,
484 ) -> Self {
485 Self {
486 join_handle,
487 cancel_token,
488 }
489 }
490
491 /// Get the cancellation token for this specific task
492 ///
493 /// This token is a child of the tracker's cancellation token and can be used
494 /// to cancel just this individual task without affecting other tasks.
495 // FIXME: The doctest previously here failed intermittently and may
496 // indicate a bug in either the doctest example or the implementation.
497 pub fn cancellation_token(&self) -> &CancellationToken {
498 &self.cancel_token
499 }
500
501 /// Abort the task associated with this handle
502 ///
503 /// This is equivalent to calling `JoinHandle::abort()` and will cause the task
504 /// to be cancelled immediately without running any cleanup code.
505 pub fn abort(&self) {
506 self.join_handle.abort();
507 }
508
509 /// Check if the task associated with this handle has finished
510 ///
511 /// This is equivalent to calling `JoinHandle::is_finished()`.
512 pub fn is_finished(&self) -> bool {
513 self.join_handle.is_finished()
514 }
515}
516
517impl<T> std::future::Future for TaskHandle<T> {
518 type Output = Result<Result<T, TaskError>, tokio::task::JoinError>;
519
520 fn poll(
521 mut self: std::pin::Pin<&mut Self>,
522 cx: &mut std::task::Context<'_>,
523 ) -> std::task::Poll<Self::Output> {
524 std::pin::Pin::new(&mut self.join_handle).poll(cx)
525 }
526}
527
528impl<T> std::fmt::Debug for TaskHandle<T> {
529 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530 f.debug_struct("TaskHandle")
531 .field("join_handle", &"<JoinHandle>")
532 .field("cancel_token", &self.cancel_token)
533 .finish()
534 }
535}
536
537/// Trait for continuation tasks that execute after a failure
538///
539/// This trait allows tasks to define what should happen next after a failure,
540/// eliminating the need for complex type erasure and executor management.
541/// Tasks implement this trait to provide clean continuation logic.
542#[async_trait]
543pub trait Continuation: Send + Sync + std::fmt::Debug + std::any::Any {
544 /// Execute the continuation task after a failure
545 ///
546 /// This method is called when a task fails and a continuation is provided.
547 /// The implementation can perform retry logic, fallback operations,
548 /// transformations, or any other follow-up action.
549 /// Returns the result in a type-erased Box<dyn Any> for flexibility.
550 async fn execute(
551 &self,
552 cancel_token: CancellationToken,
553 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>;
554}
555
556/// Error type that signals a task failed but provided a continuation
557///
558/// This error type contains a continuation task that can be executed as a follow-up.
559/// The task defines its own continuation logic through the Continuation trait.
560#[derive(Error, Debug)]
561#[error("Task failed with continuation: {source}")]
562pub struct FailedWithContinuation {
563 /// The underlying error that caused the task to fail
564 #[source]
565 pub source: anyhow::Error,
566 /// The continuation task for follow-up execution
567 pub continuation: Arc<dyn Continuation + Send + Sync + 'static>,
568}
569
570impl FailedWithContinuation {
571 /// Create a new FailedWithContinuation with a continuation task
572 ///
573 /// The continuation task defines its own execution logic through the Continuation trait.
574 pub fn new(
575 source: anyhow::Error,
576 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
577 ) -> Self {
578 Self {
579 source,
580 continuation,
581 }
582 }
583
584 /// Create a FailedWithContinuation and convert it to anyhow::Error
585 ///
586 /// This is a convenience method for tasks to easily return continuation errors.
587 pub fn into_anyhow(
588 source: anyhow::Error,
589 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
590 ) -> anyhow::Error {
591 anyhow::Error::new(Self::new(source, continuation))
592 }
593
594 /// Create a FailedWithContinuation from a simple async function (no cancellation support)
595 ///
596 /// This is a convenience method for creating continuation errors from simple async closures
597 /// that don't need to handle cancellation. The function will be executed when the
598 /// continuation is triggered.
599 ///
600 /// # Example
601 /// ```rust
602 /// # use dynamo_runtime::utils::tasks::tracker::*;
603 /// # use anyhow::anyhow;
604 /// # #[tokio::main]
605 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
606 /// let error = FailedWithContinuation::from_fn(
607 /// anyhow!("Initial task failed"),
608 /// || async {
609 /// println!("Retrying operation...");
610 /// Ok("retry_result".to_string())
611 /// }
612 /// );
613 /// # Ok(())
614 /// # }
615 /// ```
616 pub fn from_fn<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
617 where
618 F: Fn() -> Fut + Send + Sync + 'static,
619 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
620 T: Send + 'static,
621 {
622 let continuation = Arc::new(FnContinuation { f: Box::new(f) });
623 Self::into_anyhow(source, continuation)
624 }
625
626 /// Create a FailedWithContinuation from a cancellable async function
627 ///
628 /// This is a convenience method for creating continuation errors from async closures
629 /// that can handle cancellation. The function receives a CancellationToken
630 /// and should check it periodically for cooperative cancellation.
631 ///
632 /// # Example
633 /// ```rust
634 /// # use dynamo_runtime::utils::tasks::tracker::*;
635 /// # use anyhow::anyhow;
636 /// # #[tokio::main]
637 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
638 /// let error = FailedWithContinuation::from_cancellable(
639 /// anyhow!("Initial task failed"),
640 /// |cancel_token| async move {
641 /// if cancel_token.is_cancelled() {
642 /// return Err(anyhow!("Cancelled"));
643 /// }
644 /// println!("Retrying operation with cancellation support...");
645 /// Ok("retry_result".to_string())
646 /// }
647 /// );
648 /// # Ok(())
649 /// # }
650 /// ```
651 pub fn from_cancellable<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
652 where
653 F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
654 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
655 T: Send + 'static,
656 {
657 let continuation = Arc::new(CancellableFnContinuation { f: Box::new(f) });
658 Self::into_anyhow(source, continuation)
659 }
660}
661
662/// Extension trait for extracting FailedWithContinuation from anyhow::Error
663///
664/// This trait provides methods to detect and extract continuation tasks
665/// from the type-erased anyhow::Error system.
666pub trait FailedWithContinuationExt {
667 /// Extract a continuation task if this error contains one
668 ///
669 /// Returns the continuation task if the error is a FailedWithContinuation,
670 /// None otherwise.
671 fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>>;
672
673 /// Check if this error has a continuation
674 fn has_continuation(&self) -> bool;
675}
676
677impl FailedWithContinuationExt for anyhow::Error {
678 fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>> {
679 // Try to downcast to FailedWithContinuation
680 if let Some(continuation_err) = self.downcast_ref::<FailedWithContinuation>() {
681 Some(continuation_err.continuation.clone())
682 } else {
683 None
684 }
685 }
686
687 fn has_continuation(&self) -> bool {
688 self.downcast_ref::<FailedWithContinuation>().is_some()
689 }
690}
691
692/// Implementation of Continuation for simple async functions (no cancellation support)
693struct FnContinuation<F> {
694 f: Box<F>,
695}
696
697impl<F> std::fmt::Debug for FnContinuation<F> {
698 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
699 f.debug_struct("FnContinuation")
700 .field("f", &"<closure>")
701 .finish()
702 }
703}
704
705#[async_trait]
706impl<F, Fut, T> Continuation for FnContinuation<F>
707where
708 F: Fn() -> Fut + Send + Sync + 'static,
709 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
710 T: Send + 'static,
711{
712 async fn execute(
713 &self,
714 _cancel_token: CancellationToken,
715 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
716 match (self.f)().await {
717 Ok(result) => TaskExecutionResult::Success(Box::new(result)),
718 Err(error) => TaskExecutionResult::Error(error),
719 }
720 }
721}
722
723/// Implementation of Continuation for cancellable async functions
724struct CancellableFnContinuation<F> {
725 f: Box<F>,
726}
727
728impl<F> std::fmt::Debug for CancellableFnContinuation<F> {
729 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
730 f.debug_struct("CancellableFnContinuation")
731 .field("f", &"<closure>")
732 .finish()
733 }
734}
735
736#[async_trait]
737impl<F, Fut, T> Continuation for CancellableFnContinuation<F>
738where
739 F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
740 Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
741 T: Send + 'static,
742{
743 async fn execute(
744 &self,
745 cancel_token: CancellationToken,
746 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
747 match (self.f)(cancel_token).await {
748 Ok(result) => TaskExecutionResult::Success(Box::new(result)),
749 Err(error) => TaskExecutionResult::Error(error),
750 }
751 }
752}
753
754/// Common scheduling policies for task execution
755///
756/// These enums provide convenient access to built-in scheduling policies
757/// without requiring manual construction of policy objects.
758///
759/// ## Cancellation Semantics
760///
761/// All schedulers follow the same cancellation behavior:
762/// - Respect cancellation tokens before resource allocation (permits, etc.)
763/// - Once task execution begins, always await completion
764/// - Let tasks handle their own cancellation internally
765#[derive(Debug, Clone)]
766pub enum SchedulingPolicy {
767 /// No concurrency limits - execute all tasks immediately
768 Unlimited,
769 /// Semaphore-based concurrency limiting
770 Semaphore(usize),
771 // TODO: Future scheduling policies to implement
772 //
773 // /// Token bucket rate limiting with burst capacity
774 // /// Implementation: Use tokio::time::interval for refill, AtomicU64 for tokens.
775 // /// acquire() decrements tokens, schedule() waits for refill if empty.
776 // /// Burst allows temporary spikes above steady rate.
777 // /// struct: { rate: f64, burst: usize, tokens: AtomicU64, last_refill: Mutex<Instant> }
778 // /// Example: TokenBucket { rate: 10.0, burst: 5 } = 10 tasks/sec, burst up to 5
779 // TokenBucket { rate: f64, burst: usize },
780 //
781 // /// Weighted fair scheduling across multiple priority classes
782 // /// Implementation: Maintain separate VecDeque for each priority class.
783 // /// Use weighted round-robin: serve N tasks from high, M from normal, etc.
784 // /// Track deficit counters to ensure fairness over time.
785 // /// struct: { queues: HashMap<String, VecDeque<Task>>, weights: Vec<(String, u32)> }
786 // /// Example: WeightedFair { weights: vec![("high", 70), ("normal", 20), ("low", 10)] }
787 // WeightedFair { weights: Vec<(String, u32)> },
788 //
789 // /// Memory-aware scheduling that limits tasks based on available memory
790 // /// Implementation: Monitor system memory via /proc/meminfo or sysinfo crate.
791 // /// Pause scheduling when available memory < threshold, resume when memory freed.
792 // /// Use exponential backoff for memory checks to avoid overhead.
793 // /// struct: { max_memory_mb: usize, check_interval: Duration, semaphore: Semaphore }
794 // MemoryAware { max_memory_mb: usize },
795 //
796 // /// CPU-aware scheduling that adjusts concurrency based on CPU load
797 // /// Implementation: Sample system load via sysinfo crate every N seconds.
798 // /// Dynamically resize internal semaphore permits based on load average.
799 // /// Use PID controller for smooth adjustments, avoid oscillation.
800 // /// struct: { max_cpu_percent: f32, permits: Arc<Semaphore>, sampler: tokio::task }
801 // CpuAware { max_cpu_percent: f32 },
802 //
803 // /// Adaptive scheduler that automatically adjusts concurrency based on performance
804 // /// Implementation: Track task latency and throughput in sliding windows.
805 // /// Increase permits if latency low & throughput stable, decrease if latency spikes.
806 // /// Use additive increase, multiplicative decrease (AIMD) algorithm.
807 // /// struct: { permits: AtomicUsize, latency_tracker: RingBuffer, throughput_tracker: RingBuffer }
808 // Adaptive { initial_permits: usize },
809 //
810 // /// Throttling scheduler that enforces minimum time between task starts
811 // /// Implementation: Store last_execution time in AtomicU64 (unix timestamp).
812 // /// Before scheduling, check elapsed time and tokio::time::sleep if needed.
813 // /// Useful for rate-limiting API calls to external services.
814 // /// struct: { min_interval: Duration, last_execution: AtomicU64 }
815 // Throttling { min_interval_ms: u64 },
816 //
817 // /// Batch scheduler that groups tasks and executes them together
818 // /// Implementation: Collect tasks in Vec<Task>, use tokio::time::timeout for max_wait.
819 // /// Execute batch when size reached OR timeout expires, whichever first.
820 // /// Use futures::future::join_all for parallel execution within batch.
821 // /// struct: { batch_size: usize, max_wait: Duration, pending: Mutex<Vec<Task>> }
822 // Batch { batch_size: usize, max_wait_ms: u64 },
823 //
824 // /// Priority-based scheduler with separate queues for different priority levels
825 // /// Implementation: Three separate semaphores for high/normal/low priorities.
826 // /// Always serve high before normal, normal before low (strict priority).
827 // /// Add starvation protection: promote normal->high after timeout.
828 // /// struct: { high_sem: Semaphore, normal_sem: Semaphore, low_sem: Semaphore }
829 // Priority { high: usize, normal: usize, low: usize },
830 //
831 // /// Backpressure-aware scheduler that monitors downstream capacity
832 // /// Implementation: Track external queue depth via provided callback/metric.
833 // /// Pause scheduling when queue_threshold exceeded, resume after pause_duration.
834 // /// Use exponential backoff for repeated backpressure events.
835 // /// struct: { queue_checker: Arc<dyn Fn() -> usize>, threshold: usize, pause_duration: Duration }
836 // Backpressure { queue_threshold: usize, pause_duration_ms: u64 },
837}
838
839/// Trait for implementing error handling policies
840///
841/// Error policies are lightweight, synchronous decision-makers that analyze task failures
842/// and return an ErrorResponse telling the TaskTracker what action to take. The TaskTracker
843/// handles all the actual work (cancellation, metrics, etc.) based on the policy's response.
844///
845/// ## Key Design Principles
846/// - **Synchronous**: Policies make fast decisions without async operations
847/// - **Stateless where possible**: TaskTracker manages cancellation tokens and state
848/// - **Composable**: Policies can be combined and nested in hierarchies
849/// - **Focused**: Each policy handles one specific error pattern or strategy
850///
851/// Per-task error handling context
852///
853/// Provides context information and state management for error policies.
854/// The state field allows policies to maintain per-task state across multiple error attempts.
855pub struct OnErrorContext {
856 /// Number of times this task has been attempted (starts at 1)
857 pub attempt_count: u32,
858 /// Unique identifier of the failed task
859 pub task_id: TaskId,
860 /// Full execution context with access to scheduler, metrics, etc.
861 pub execution_context: TaskExecutionContext,
862 /// Optional per-task state managed by the policy (None for stateless policies)
863 pub state: Option<Box<dyn std::any::Any + Send + 'static>>,
864}
865
866/// Error handling policy trait for task failures
867///
868/// Policies define how the TaskTracker responds to task failures.
869/// They can be stateless (like LogOnlyPolicy) or maintain per-task state
870/// (like ThresholdCancelPolicy with per-task failure counters).
871pub trait OnErrorPolicy: Send + Sync + std::fmt::Debug {
872 /// Create a child policy for a child tracker
873 ///
874 /// This allows policies to maintain hierarchical relationships,
875 /// such as child cancellation tokens or shared circuit breaker state.
876 fn create_child(&self) -> Arc<dyn OnErrorPolicy>;
877
878 /// Create per-task context state (None if policy is stateless)
879 ///
880 /// This method is called once per task when the first error occurs.
881 /// Stateless policies should return None to avoid unnecessary heap allocations.
882 /// Stateful policies should return Some(Box::new(initial_state)).
883 ///
884 /// # Returns
885 /// * `None` - Policy doesn't need per-task state (no heap allocation)
886 /// * `Some(state)` - Initial state for this task (heap allocated when needed)
887 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>>;
888
889 /// Handle a task failure and return the desired response
890 ///
891 /// # Arguments
892 /// * `error` - The error that occurred
893 /// * `context` - Mutable context with attempt count, task info, and optional state
894 ///
895 /// # Returns
896 /// ErrorResponse indicating how the TaskTracker should handle this failure
897 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse;
898
899 /// Should continuations be allowed for this error?
900 ///
901 /// This method is called before checking if a task provided a continuation to determine
902 /// whether the policy allows continuation-based retries at all. If this returns `false`,
903 /// any `FailedWithContinuation` will be ignored and the error will be handled through
904 /// the normal policy response.
905 ///
906 /// # Arguments
907 /// * `error` - The error that occurred
908 /// * `context` - Per-task context with attempt count and state
909 ///
910 /// # Returns
911 /// * `true` - Allow continuations, check for `FailedWithContinuation` (default)
912 /// * `false` - Reject continuations, handle through normal policy response
913 fn allow_continuation(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
914 true // Default: allow continuations
915 }
916
917 /// Should this continuation be rescheduled through the scheduler?
918 ///
919 /// This method is called when a continuation is about to be executed to determine
920 /// whether it should go through the scheduler's acquisition process again or execute
921 /// immediately with the current execution permission.
922 ///
923 /// **What this means:**
924 /// - **Don't reschedule (`false`)**: Execute continuation immediately with current permission
925 /// - **Reschedule (`true`)**: Release current permission, go through scheduler again
926 ///
927 /// Rescheduling means the continuation will be subject to the scheduler's policies
928 /// again (rate limiting, concurrency limits, backoff delays, etc.).
929 ///
930 /// # Arguments
931 /// * `error` - The error that triggered this retry decision
932 /// * `context` - Per-task context with attempt count and state
933 ///
934 /// # Returns
935 /// * `false` - Execute continuation immediately (default, efficient)
936 /// * `true` - Reschedule through scheduler (for delays, rate limiting, backoff)
937 fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
938 false // Default: immediate execution
939 }
940}
941
942/// Common error handling policies for task failure management
943///
944/// These enums provide convenient access to built-in error handling policies
945/// without requiring manual construction of policy objects.
946#[derive(Debug, Clone)]
947pub enum ErrorPolicy {
948 /// Log errors but continue execution - no cancellation
949 LogOnly,
950 /// Cancel all tasks on any error (using default error patterns)
951 CancelOnError,
952 /// Cancel all tasks when specific error patterns are encountered
953 CancelOnPatterns(Vec<String>),
954 /// Cancel after a threshold number of failures
955 CancelOnThreshold { max_failures: usize },
956 /// Cancel when failure rate exceeds threshold within time window
957 CancelOnRate {
958 max_failure_rate: f32,
959 window_secs: u64,
960 },
961 // TODO: Future error policies to implement
962 //
963 // /// Retry failed tasks with exponential backoff
964 // /// Implementation: Store original task in retry queue with attempt count.
965 // /// Use tokio::time::sleep for delays: backoff_ms * 2^attempt.
966 // /// Spawn retry as new task, preserve original task_id for tracing.
967 // /// Need task cloning support in scheduler interface.
968 // /// struct: { max_attempts: usize, backoff_ms: u64, retry_queue: VecDeque<(Task, u32)> }
969 // Retry { max_attempts: usize, backoff_ms: u64 },
970 //
971 // /// Send failed tasks to a dead letter queue for later processing
972 // /// Implementation: Use tokio::sync::mpsc::channel for queue.
973 // /// Serialize task info (id, error, payload) for persistence.
974 // /// Background worker drains queue to external storage (Redis/DB).
975 // /// Include retry count and timestamps for debugging.
976 // /// struct: { queue: mpsc::Sender<DeadLetterItem>, storage: Arc<dyn DeadLetterStorage> }
977 // DeadLetter { queue_name: String },
978 //
979 // /// Execute fallback logic when tasks fail
980 // /// Implementation: Store fallback closure in Arc for thread-safety.
981 // /// Execute fallback in same context as failed task (inherit cancel token).
982 // /// Track fallback success/failure separately from original task metrics.
983 // /// Consider using enum for common fallback patterns (default value, noop, etc).
984 // /// struct: { fallback_fn: Arc<dyn Fn(TaskId, Error) -> BoxFuture<Result<()>>> }
985 // Fallback { fallback_fn: Arc<dyn Fn() -> BoxFuture<'static, Result<()>>> },
986 //
987 // /// Circuit breaker pattern - stop executing after threshold failures
988 // /// Implementation: Track state (Closed/Open/HalfOpen) with AtomicU8.
989 // /// Use failure window (last N tasks) or time window for threshold.
990 // /// In Open state, reject tasks immediately, use timer for recovery.
991 // /// In HalfOpen, allow one test task to check if issues resolved.
992 // /// struct: { state: AtomicU8, failure_count: AtomicU64, last_failure: AtomicU64 }
993 // CircuitBreaker { failure_threshold: usize, timeout_secs: u64 },
994 //
995 // /// Resource protection policy that monitors memory/CPU usage
996 // /// Implementation: Background task samples system resources via sysinfo.
997 // /// Cancel tracker when memory > threshold, use process-level monitoring.
998 // /// Implement graceful degradation: warn at 80%, cancel at 90%.
999 // /// Include both system-wide and process-specific thresholds.
1000 // /// struct: { monitor_task: JoinHandle, thresholds: ResourceThresholds, cancel_token: CancellationToken }
1001 // ResourceProtection { max_memory_mb: usize },
1002 //
1003 // /// Timeout policy that cancels tasks exceeding maximum duration
1004 // /// Implementation: Wrap each task with tokio::time::timeout.
1005 // /// Store task start time, check duration in on_error callback.
1006 // /// Distinguish timeout errors from other task failures in metrics.
1007 // /// Consider per-task or global timeout strategies.
1008 // /// struct: { max_duration: Duration, timeout_tracker: HashMap<TaskId, Instant> }
1009 // Timeout { max_duration_secs: u64 },
1010 //
1011 // /// Sampling policy that only logs a percentage of errors
1012 // /// Implementation: Use thread-local RNG for sampling decisions.
1013 // /// Hash task_id for deterministic sampling (same task always sampled).
1014 // /// Store sample rate as f32, compare with rand::random::<f32>().
1015 // /// Include rate in log messages for context.
1016 // /// struct: { sample_rate: f32, rng: ThreadLocal<RefCell<SmallRng>> }
1017 // Sampling { sample_rate: f32 },
1018 //
1019 // /// Aggregating policy that batches error reports
1020 // /// Implementation: Collect errors in Vec, flush on size or time trigger.
1021 // /// Use tokio::time::interval for periodic flushing.
1022 // /// Group errors by type/pattern for better insights.
1023 // /// Include error frequency and rate statistics in reports.
1024 // /// struct: { window: Duration, batch: Mutex<Vec<ErrorEntry>>, flush_task: JoinHandle }
1025 // Aggregating { window_secs: u64, max_batch_size: usize },
1026 //
1027 // /// Alerting policy that sends notifications on error patterns
1028 // /// Implementation: Use reqwest for webhook HTTP calls.
1029 // /// Rate-limit alerts to prevent spam (max N per minute).
1030 // /// Include error context, task info, and system metrics in payload.
1031 // /// Support multiple notification channels (webhook, email, slack).
1032 // /// struct: { client: reqwest::Client, rate_limiter: RateLimiter, alert_config: AlertConfig }
1033 // Alerting { webhook_url: String, severity_threshold: String },
1034}
1035
1036/// Response type for error handling policies
1037///
1038/// This enum defines how the TaskTracker should respond to task failures.
1039/// Currently provides minimal functionality with planned extensions for common patterns.
1040#[derive(Debug)]
1041pub enum ErrorResponse {
1042 /// Just fail this task - error will be logged/counted, but tracker continues
1043 Fail,
1044
1045 /// Shutdown this tracker and all child trackers
1046 Shutdown,
1047
1048 /// Execute custom error handling logic with full context access
1049 Custom(Box<dyn OnErrorAction>),
1050 // TODO: Future specialized error responses to implement:
1051 //
1052 // /// Retry the failed task with configurable strategy
1053 // /// Implementation: Add RetryStrategy trait with delay(), should_continue(attempt_count),
1054 // /// release_and_reacquire_resources() methods. TaskTracker handles retry loop with
1055 // /// attempt counting and resource management. Supports exponential backoff, jitter.
1056 // /// Usage: ErrorResponse::Retry(Box::new(ExponentialBackoff { max_attempts: 3, base_delay: 100ms }))
1057 // Retry(Box<dyn RetryStrategy>),
1058 //
1059 // /// Execute fallback logic, then follow secondary action
1060 // /// Implementation: Add FallbackAction trait with execute(error, task_id) -> Result<(), Error>.
1061 // /// Execute fallback first, then recursively handle the 'then' response based on fallback result.
1062 // /// Enables patterns like: try fallback, if it works continue, if it fails retry original task.
1063 // /// Usage: ErrorResponse::Fallback { fallback: Box::new(DefaultValue(42)), then: Box::new(ErrorResponse::Continue) }
1064 // Fallback { fallback: Box<dyn FallbackAction>, then: Box<ErrorResponse> },
1065 //
1066 // /// Restart task with preserved state (for long-running/stateful tasks)
1067 // /// Implementation: Add TaskState trait for serialize/deserialize state, RestartStrategy trait
1068 // /// with create_continuation_task(state) -> Future. Task saves checkpoints during execution,
1069 // /// on error returns StatefulTaskError containing preserved state. Policy can restart from checkpoint.
1070 // /// Usage: ErrorResponse::RestartWithState { state: checkpointed_state, strategy: Box::new(CheckpointRestart { ... }) }
1071 // RestartWithState { state: Box<dyn TaskState>, strategy: Box<dyn RestartStrategy> },
1072}
1073
1074/// Trait for implementing custom error handling actions
1075///
1076/// This provides full access to the task execution context for complex error handling
1077/// scenarios that don't fit into the built-in response patterns.
1078#[async_trait]
1079pub trait OnErrorAction: Send + Sync + std::fmt::Debug {
1080 /// Execute custom error handling logic
1081 ///
1082 /// # Arguments
1083 /// * `error` - The error that caused the task to fail
1084 /// * `task_id` - Unique identifier of the failed task
1085 /// * `attempt_count` - Number of times this task has been attempted (starts at 1)
1086 /// * `context` - Full execution context with access to scheduler, metrics, etc.
1087 ///
1088 /// # Returns
1089 /// ActionResult indicating what the TaskTracker should do next
1090 async fn execute(
1091 &self,
1092 error: &anyhow::Error,
1093 task_id: TaskId,
1094 attempt_count: u32,
1095 context: &TaskExecutionContext,
1096 ) -> ActionResult;
1097}
1098
1099/// Scheduler execution guard state for conditional re-acquisition during task retry loops
1100///
1101/// This controls whether a continuation should reuse the current scheduler execution permission
1102/// or go through the scheduler's acquisition process again.
1103#[derive(Debug, Clone, PartialEq, Eq)]
1104enum GuardState {
1105 /// Keep the current scheduler execution permission for immediate continuation
1106 ///
1107 /// The continuation will execute immediately without going through the scheduler again.
1108 /// This is efficient for simple retries that don't need delays or rate limiting.
1109 Keep,
1110
1111 /// Release current permission and re-acquire through scheduler before continuation
1112 ///
1113 /// The continuation will be subject to the scheduler's policies again (concurrency limits,
1114 /// rate limiting, backoff delays, etc.). Use this for implementing retry delays or
1115 /// when the scheduler needs to apply its policies to the retry attempt.
1116 Reschedule,
1117}
1118
1119/// Result of a custom error action execution
1120#[derive(Debug)]
1121pub enum ActionResult {
1122 /// Just fail this task (error was logged/handled by policy)
1123 ///
1124 /// This means the policy has handled the error appropriately (e.g., logged it,
1125 /// updated metrics, etc.) and the task should fail with this error.
1126 /// The task execution terminates here.
1127 Fail,
1128
1129 /// Continue execution with the provided task
1130 ///
1131 /// This provides a new executable to continue the retry loop with.
1132 /// The task execution continues with the provided continuation.
1133 Continue {
1134 continuation: Arc<dyn Continuation + Send + Sync + 'static>,
1135 },
1136
1137 /// Shutdown this tracker and all child trackers
1138 ///
1139 /// This triggers shutdown of the entire tracker hierarchy.
1140 /// All running and pending tasks will be cancelled.
1141 Shutdown,
1142}
1143
1144/// Execution context provided to custom error actions
1145///
1146/// This gives custom actions full access to the task execution environment
1147/// for implementing complex error handling scenarios.
1148pub struct TaskExecutionContext {
1149 /// Scheduler for reacquiring resources or checking state
1150 pub scheduler: Arc<dyn TaskScheduler>,
1151
1152 /// Metrics for custom tracking
1153 pub metrics: Arc<dyn HierarchicalTaskMetrics>,
1154 // TODO: Future context additions:
1155 // pub guard: Box<dyn ResourceGuard>, // Current resource guard (needs Debug impl)
1156 // pub cancel_token: CancellationToken, // For implementing custom cancellation
1157 // pub task_recreation: Box<dyn TaskRecreator>, // For implementing retry/restart
1158}
1159
1160/// Result of task execution - unified for both regular and cancellable tasks
1161#[derive(Debug)]
1162pub enum TaskExecutionResult<T> {
1163 /// Task completed successfully
1164 Success(T),
1165 /// Task was cancelled (only possible for cancellable tasks)
1166 Cancelled,
1167 /// Task failed with an error
1168 Error(anyhow::Error),
1169}
1170
1171/// Trait for executing different types of tasks in a unified way
1172#[async_trait]
1173trait TaskExecutor<T>: Send {
1174 /// Execute the task with the given cancellation token
1175 async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T>;
1176}
1177
1178/// Task executor for regular (non-cancellable) tasks
1179struct RegularTaskExecutor<F, T>
1180where
1181 F: Future<Output = Result<T>> + Send + 'static,
1182 T: Send + 'static,
1183{
1184 future: Option<F>,
1185 _phantom: std::marker::PhantomData<T>,
1186}
1187
1188impl<F, T> RegularTaskExecutor<F, T>
1189where
1190 F: Future<Output = Result<T>> + Send + 'static,
1191 T: Send + 'static,
1192{
1193 fn new(future: F) -> Self {
1194 Self {
1195 future: Some(future),
1196 _phantom: std::marker::PhantomData,
1197 }
1198 }
1199}
1200
1201#[async_trait]
1202impl<F, T> TaskExecutor<T> for RegularTaskExecutor<F, T>
1203where
1204 F: Future<Output = Result<T>> + Send + 'static,
1205 T: Send + 'static,
1206{
1207 async fn execute(&mut self, _cancel_token: CancellationToken) -> TaskExecutionResult<T> {
1208 if let Some(future) = self.future.take() {
1209 match future.await {
1210 Ok(value) => TaskExecutionResult::Success(value),
1211 Err(error) => TaskExecutionResult::Error(error),
1212 }
1213 } else {
1214 // This should never happen since regular tasks don't support retry
1215 TaskExecutionResult::Error(anyhow::anyhow!("Regular task already consumed"))
1216 }
1217 }
1218}
1219
1220/// Task executor for cancellable tasks
1221struct CancellableTaskExecutor<F, Fut, T>
1222where
1223 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1224 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1225 T: Send + 'static,
1226{
1227 task_fn: F,
1228}
1229
1230impl<F, Fut, T> CancellableTaskExecutor<F, Fut, T>
1231where
1232 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1233 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1234 T: Send + 'static,
1235{
1236 fn new(task_fn: F) -> Self {
1237 Self { task_fn }
1238 }
1239}
1240
1241#[async_trait]
1242impl<F, Fut, T> TaskExecutor<T> for CancellableTaskExecutor<F, Fut, T>
1243where
1244 F: FnMut(CancellationToken) -> Fut + Send + 'static,
1245 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
1246 T: Send + 'static,
1247{
1248 async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T> {
1249 let future = (self.task_fn)(cancel_token);
1250 match future.await {
1251 CancellableTaskResult::Ok(value) => TaskExecutionResult::Success(value),
1252 CancellableTaskResult::Cancelled => TaskExecutionResult::Cancelled,
1253 CancellableTaskResult::Err(error) => TaskExecutionResult::Error(error),
1254 }
1255 }
1256}
1257
1258/// Common functionality for policy Arc construction
1259///
1260/// This trait provides a standardized `new_arc()` method for all policy types,
1261/// eliminating the need for manual `Arc::new()` calls in client code.
1262pub trait ArcPolicy: Sized + Send + Sync + 'static {
1263 /// Create an Arc-wrapped instance of this policy
1264 fn new_arc(self) -> Arc<Self> {
1265 Arc::new(self)
1266 }
1267}
1268
1269/// Unique identifier for a task
1270#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1271pub struct TaskId(Uuid);
1272
1273impl TaskId {
1274 fn new() -> Self {
1275 Self(Uuid::new_v4())
1276 }
1277}
1278
1279impl std::fmt::Display for TaskId {
1280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1281 write!(f, "task-{}", self.0)
1282 }
1283}
1284
1285/// Result of task execution
1286#[derive(Debug, Clone, PartialEq)]
1287pub enum CompletionStatus {
1288 /// Task completed successfully
1289 Ok,
1290 /// Task was cancelled before or during execution
1291 Cancelled,
1292 /// Task failed with an error
1293 Failed(String),
1294}
1295
1296/// Result type for cancellable tasks that explicitly tracks cancellation
1297#[derive(Debug)]
1298pub enum CancellableTaskResult<T> {
1299 /// Task completed successfully
1300 Ok(T),
1301 /// Task was cancelled (either via token or shutdown)
1302 Cancelled,
1303 /// Task failed with an error
1304 Err(anyhow::Error),
1305}
1306
1307/// Result of scheduling a task
1308#[derive(Debug)]
1309pub enum SchedulingResult<T> {
1310 /// Task was executed and completed
1311 Execute(T),
1312 /// Task was cancelled before execution
1313 Cancelled,
1314 /// Task was rejected due to scheduling policy
1315 Rejected(String),
1316}
1317
1318/// Resource guard that manages task execution
1319///
1320/// This trait enforces proper cancellation semantics by separating resource
1321/// management from task execution. Once a guard is acquired, task execution
1322/// must always run to completion.
1323/// Resource guard for task execution
1324///
1325/// This trait represents resources (permits, slots, etc.) acquired from a scheduler
1326/// that must be held during task execution. The guard automatically releases
1327/// resources when dropped, implementing proper RAII semantics.
1328///
1329/// Guards are returned by `TaskScheduler::acquire_execution_slot()` and must
1330/// be held in scope while the task executes to ensure resources remain allocated.
1331pub trait ResourceGuard: Send + 'static {
1332 // Marker trait - resources are released via Drop on the concrete type
1333}
1334
1335/// Trait for implementing task scheduling policies
1336///
1337/// This trait enforces proper cancellation semantics by splitting resource
1338/// acquisition (which can be cancelled) from task execution (which cannot).
1339///
1340/// ## Design Philosophy
1341///
1342/// Tasks may or may not support cancellation (depending on whether they were
1343/// created with `spawn_cancellable` or regular `spawn`). This split design ensures:
1344///
1345/// - **Resource acquisition**: Can respect cancellation tokens to avoid unnecessary allocation
1346/// - **Task execution**: Always runs to completion; tasks handle their own cancellation
1347///
1348/// This makes it impossible to accidentally interrupt task execution with `tokio::select!`.
1349#[async_trait]
1350pub trait TaskScheduler: Send + Sync + std::fmt::Debug {
1351 /// Acquire resources needed for task execution and return a guard
1352 ///
1353 /// This method handles resource allocation (permits, queue slots, etc.) and
1354 /// can respect cancellation tokens to avoid unnecessary resource consumption.
1355 ///
1356 /// ## Cancellation Behavior
1357 ///
1358 /// The `cancel_token` is used for scheduler-level cancellation (e.g., "don't start new work").
1359 /// If cancellation is requested before or during resource acquisition, this method
1360 /// should return `SchedulingResult::Cancelled`.
1361 ///
1362 /// # Arguments
1363 /// * `cancel_token` - [`CancellationToken`] for scheduler-level cancellation
1364 ///
1365 /// # Returns
1366 /// * `SchedulingResult::Execute(guard)` - Resources acquired, ready to execute
1367 /// * `SchedulingResult::Cancelled` - Cancelled before or during resource acquisition
1368 /// * `SchedulingResult::Rejected(reason)` - Resources unavailable or policy violation
1369 async fn acquire_execution_slot(
1370 &self,
1371 cancel_token: CancellationToken,
1372 ) -> SchedulingResult<Box<dyn ResourceGuard>>;
1373}
1374
1375/// Trait for hierarchical task metrics that supports aggregation up the tracker tree
1376///
1377/// This trait provides different implementations for root and child trackers:
1378/// - Root trackers integrate with Prometheus metrics for observability
1379/// - Child trackers chain metric updates up to their parents for aggregation
1380/// - All implementations maintain thread-safe atomic operations
1381pub trait HierarchicalTaskMetrics: Send + Sync + std::fmt::Debug {
1382 /// Increment issued task counter
1383 fn increment_issued(&self);
1384
1385 /// Increment started task counter
1386 fn increment_started(&self);
1387
1388 /// Increment success counter
1389 fn increment_success(&self);
1390
1391 /// Increment cancelled counter
1392 fn increment_cancelled(&self);
1393
1394 /// Increment failed counter
1395 fn increment_failed(&self);
1396
1397 /// Increment rejected counter
1398 fn increment_rejected(&self);
1399
1400 /// Get current issued count (local to this tracker)
1401 fn issued(&self) -> u64;
1402
1403 /// Get current started count (local to this tracker)
1404 fn started(&self) -> u64;
1405
1406 /// Get current success count (local to this tracker)
1407 fn success(&self) -> u64;
1408
1409 /// Get current cancelled count (local to this tracker)
1410 fn cancelled(&self) -> u64;
1411
1412 /// Get current failed count (local to this tracker)
1413 fn failed(&self) -> u64;
1414
1415 /// Get current rejected count (local to this tracker)
1416 fn rejected(&self) -> u64;
1417
1418 /// Get total completed tasks (success + cancelled + failed + rejected)
1419 fn total_completed(&self) -> u64 {
1420 self.success() + self.cancelled() + self.failed() + self.rejected()
1421 }
1422
1423 /// Get number of pending tasks (issued - completed)
1424 fn pending(&self) -> u64 {
1425 self.issued().saturating_sub(self.total_completed())
1426 }
1427
1428 /// Get the number of tasks that are currently active (started - completed)
1429 fn active(&self) -> u64 {
1430 self.started().saturating_sub(self.total_completed())
1431 }
1432
1433 /// Get number of tasks queued in scheduler (issued - started)
1434 fn queued(&self) -> u64 {
1435 self.issued().saturating_sub(self.started())
1436 }
1437}
1438
1439/// Task execution metrics for a tracker
1440#[derive(Debug, Default)]
1441pub struct TaskMetrics {
1442 /// Number of tasks issued/submitted (via spawn methods)
1443 pub issued_count: AtomicU64,
1444 /// Number of tasks that have started execution
1445 pub started_count: AtomicU64,
1446 /// Number of successfully completed tasks
1447 pub success_count: AtomicU64,
1448 /// Number of cancelled tasks
1449 pub cancelled_count: AtomicU64,
1450 /// Number of failed tasks
1451 pub failed_count: AtomicU64,
1452 /// Number of rejected tasks (by scheduler)
1453 pub rejected_count: AtomicU64,
1454}
1455
1456impl TaskMetrics {
1457 /// Create new metrics instance
1458 pub fn new() -> Self {
1459 Self::default()
1460 }
1461}
1462
1463impl HierarchicalTaskMetrics for TaskMetrics {
1464 /// Increment issued task counter
1465 fn increment_issued(&self) {
1466 self.issued_count.fetch_add(1, Ordering::Relaxed);
1467 }
1468
1469 /// Increment started task counter
1470 fn increment_started(&self) {
1471 self.started_count.fetch_add(1, Ordering::Relaxed);
1472 }
1473
1474 /// Increment success counter
1475 fn increment_success(&self) {
1476 self.success_count.fetch_add(1, Ordering::Relaxed);
1477 }
1478
1479 /// Increment cancelled counter
1480 fn increment_cancelled(&self) {
1481 self.cancelled_count.fetch_add(1, Ordering::Relaxed);
1482 }
1483
1484 /// Increment failed counter
1485 fn increment_failed(&self) {
1486 self.failed_count.fetch_add(1, Ordering::Relaxed);
1487 }
1488
1489 /// Increment rejected counter
1490 fn increment_rejected(&self) {
1491 self.rejected_count.fetch_add(1, Ordering::Relaxed);
1492 }
1493
1494 /// Get current issued count
1495 fn issued(&self) -> u64 {
1496 self.issued_count.load(Ordering::Relaxed)
1497 }
1498
1499 /// Get current started count
1500 fn started(&self) -> u64 {
1501 self.started_count.load(Ordering::Relaxed)
1502 }
1503
1504 /// Get current success count
1505 fn success(&self) -> u64 {
1506 self.success_count.load(Ordering::Relaxed)
1507 }
1508
1509 /// Get current cancelled count
1510 fn cancelled(&self) -> u64 {
1511 self.cancelled_count.load(Ordering::Relaxed)
1512 }
1513
1514 /// Get current failed count
1515 fn failed(&self) -> u64 {
1516 self.failed_count.load(Ordering::Relaxed)
1517 }
1518
1519 /// Get current rejected count
1520 fn rejected(&self) -> u64 {
1521 self.rejected_count.load(Ordering::Relaxed)
1522 }
1523}
1524
1525/// Root tracker metrics with Prometheus integration
1526///
1527/// This implementation maintains local counters and exposes them as Prometheus metrics
1528/// through the provided MetricsRegistry.
1529#[derive(Debug)]
1530pub struct PrometheusTaskMetrics {
1531 /// Prometheus metrics integration
1532 prometheus_issued: prometheus::IntCounter,
1533 prometheus_started: prometheus::IntCounter,
1534 prometheus_success: prometheus::IntCounter,
1535 prometheus_cancelled: prometheus::IntCounter,
1536 prometheus_failed: prometheus::IntCounter,
1537 prometheus_rejected: prometheus::IntCounter,
1538}
1539
1540impl PrometheusTaskMetrics {
1541 /// Create new root metrics with Prometheus integration
1542 ///
1543 /// # Arguments
1544 /// * `registry` - MetricsRegistry for creating Prometheus metrics
1545 /// * `component_name` - Name for the component/tracker (used in metric names)
1546 ///
1547 /// # Example
1548 /// ```rust
1549 /// # use std::sync::Arc;
1550 /// # use dynamo_runtime::utils::tasks::tracker::PrometheusTaskMetrics;
1551 /// # use dynamo_runtime::DistributedRuntime;
1552 /// # fn example(drt: &DistributedRuntime) -> anyhow::Result<()> {
1553 /// let metrics = PrometheusTaskMetrics::new(drt, "main_tracker")?;
1554 /// # Ok(())
1555 /// # }
1556 /// ```
1557 pub fn new<R: MetricsHierarchy>(registry: &R, component_name: &str) -> anyhow::Result<Self> {
1558 let metrics = registry.metrics();
1559 let issued_counter = metrics.create_intcounter(
1560 &format!("{}_{}", component_name, task_tracker::TASKS_ISSUED_TOTAL),
1561 "Total number of tasks issued/submitted",
1562 &[],
1563 )?;
1564
1565 let started_counter = metrics.create_intcounter(
1566 &format!("{}_{}", component_name, task_tracker::TASKS_STARTED_TOTAL),
1567 "Total number of tasks started",
1568 &[],
1569 )?;
1570
1571 let success_counter = metrics.create_intcounter(
1572 &format!("{}_{}", component_name, task_tracker::TASKS_SUCCESS_TOTAL),
1573 "Total number of successfully completed tasks",
1574 &[],
1575 )?;
1576
1577 let cancelled_counter = metrics.create_intcounter(
1578 &format!("{}_{}", component_name, task_tracker::TASKS_CANCELLED_TOTAL),
1579 "Total number of cancelled tasks",
1580 &[],
1581 )?;
1582
1583 let failed_counter = metrics.create_intcounter(
1584 &format!("{}_{}", component_name, task_tracker::TASKS_FAILED_TOTAL),
1585 "Total number of failed tasks",
1586 &[],
1587 )?;
1588
1589 let rejected_counter = metrics.create_intcounter(
1590 &format!("{}_{}", component_name, task_tracker::TASKS_REJECTED_TOTAL),
1591 "Total number of rejected tasks",
1592 &[],
1593 )?;
1594
1595 Ok(Self {
1596 prometheus_issued: issued_counter,
1597 prometheus_started: started_counter,
1598 prometheus_success: success_counter,
1599 prometheus_cancelled: cancelled_counter,
1600 prometheus_failed: failed_counter,
1601 prometheus_rejected: rejected_counter,
1602 })
1603 }
1604}
1605
1606impl HierarchicalTaskMetrics for PrometheusTaskMetrics {
1607 fn increment_issued(&self) {
1608 self.prometheus_issued.inc();
1609 }
1610
1611 fn increment_started(&self) {
1612 self.prometheus_started.inc();
1613 }
1614
1615 fn increment_success(&self) {
1616 self.prometheus_success.inc();
1617 }
1618
1619 fn increment_cancelled(&self) {
1620 self.prometheus_cancelled.inc();
1621 }
1622
1623 fn increment_failed(&self) {
1624 self.prometheus_failed.inc();
1625 }
1626
1627 fn increment_rejected(&self) {
1628 self.prometheus_rejected.inc();
1629 }
1630
1631 fn issued(&self) -> u64 {
1632 self.prometheus_issued.get()
1633 }
1634
1635 fn started(&self) -> u64 {
1636 self.prometheus_started.get()
1637 }
1638
1639 fn success(&self) -> u64 {
1640 self.prometheus_success.get()
1641 }
1642
1643 fn cancelled(&self) -> u64 {
1644 self.prometheus_cancelled.get()
1645 }
1646
1647 fn failed(&self) -> u64 {
1648 self.prometheus_failed.get()
1649 }
1650
1651 fn rejected(&self) -> u64 {
1652 self.prometheus_rejected.get()
1653 }
1654}
1655
1656/// Child tracker metrics that chain updates to parent
1657///
1658/// This implementation maintains local counters and automatically forwards
1659/// all metric updates to the parent tracker for hierarchical aggregation.
1660/// Holds a strong reference to parent metrics for optimal performance.
1661#[derive(Debug)]
1662struct ChildTaskMetrics {
1663 /// Local metrics for this tracker
1664 local_metrics: TaskMetrics,
1665 /// Strong reference to parent metrics for fast chaining
1666 /// Safe to hold since metrics don't own trackers - no circular references
1667 parent_metrics: Arc<dyn HierarchicalTaskMetrics>,
1668}
1669
1670impl ChildTaskMetrics {
1671 fn new(parent_metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
1672 Self {
1673 local_metrics: TaskMetrics::new(),
1674 parent_metrics,
1675 }
1676 }
1677}
1678
1679impl HierarchicalTaskMetrics for ChildTaskMetrics {
1680 fn increment_issued(&self) {
1681 self.local_metrics.increment_issued();
1682 self.parent_metrics.increment_issued();
1683 }
1684
1685 fn increment_started(&self) {
1686 self.local_metrics.increment_started();
1687 self.parent_metrics.increment_started();
1688 }
1689
1690 fn increment_success(&self) {
1691 self.local_metrics.increment_success();
1692 self.parent_metrics.increment_success();
1693 }
1694
1695 fn increment_cancelled(&self) {
1696 self.local_metrics.increment_cancelled();
1697 self.parent_metrics.increment_cancelled();
1698 }
1699
1700 fn increment_failed(&self) {
1701 self.local_metrics.increment_failed();
1702 self.parent_metrics.increment_failed();
1703 }
1704
1705 fn increment_rejected(&self) {
1706 self.local_metrics.increment_rejected();
1707 self.parent_metrics.increment_rejected();
1708 }
1709
1710 fn issued(&self) -> u64 {
1711 self.local_metrics.issued()
1712 }
1713
1714 fn started(&self) -> u64 {
1715 self.local_metrics.started()
1716 }
1717
1718 fn success(&self) -> u64 {
1719 self.local_metrics.success()
1720 }
1721
1722 fn cancelled(&self) -> u64 {
1723 self.local_metrics.cancelled()
1724 }
1725
1726 fn failed(&self) -> u64 {
1727 self.local_metrics.failed()
1728 }
1729
1730 fn rejected(&self) -> u64 {
1731 self.local_metrics.rejected()
1732 }
1733}
1734
1735/// Builder for creating child trackers with custom policies
1736///
1737/// Allows flexible customization of scheduling and error handling policies
1738/// for child trackers while maintaining parent-child relationships.
1739pub struct ChildTrackerBuilder<'parent> {
1740 parent: &'parent TaskTracker,
1741 scheduler: Option<Arc<dyn TaskScheduler>>,
1742 error_policy: Option<Arc<dyn OnErrorPolicy>>,
1743}
1744
1745impl<'parent> ChildTrackerBuilder<'parent> {
1746 /// Create a new ChildTrackerBuilder
1747 pub fn new(parent: &'parent TaskTracker) -> Self {
1748 Self {
1749 parent,
1750 scheduler: None,
1751 error_policy: None,
1752 }
1753 }
1754
1755 /// Set custom scheduler for the child tracker
1756 ///
1757 /// If not set, the child will inherit the parent's scheduler.
1758 ///
1759 /// # Arguments
1760 /// * `scheduler` - The scheduler to use for this child tracker
1761 ///
1762 /// # Example
1763 /// ```rust
1764 /// # use std::sync::Arc;
1765 /// # use tokio::sync::Semaphore;
1766 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler};
1767 /// # fn example(parent: &TaskTracker) {
1768 /// let child = parent.child_tracker_builder()
1769 /// .scheduler(SemaphoreScheduler::with_permits(5))
1770 /// .build().unwrap();
1771 /// # }
1772 /// ```
1773 pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
1774 self.scheduler = Some(scheduler);
1775 self
1776 }
1777
1778 /// Set custom error policy for the child tracker
1779 ///
1780 /// If not set, the child will get a child policy from the parent's error policy
1781 /// (via `OnErrorPolicy::create_child()`).
1782 ///
1783 /// # Arguments
1784 /// * `error_policy` - The error policy to use for this child tracker
1785 ///
1786 /// # Example
1787 /// ```rust
1788 /// # use std::sync::Arc;
1789 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, LogOnlyPolicy};
1790 /// # fn example(parent: &TaskTracker) {
1791 /// let child = parent.child_tracker_builder()
1792 /// .error_policy(LogOnlyPolicy::new())
1793 /// .build().unwrap();
1794 /// # }
1795 /// ```
1796 pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
1797 self.error_policy = Some(error_policy);
1798 self
1799 }
1800
1801 /// Build the child tracker with the specified configuration
1802 ///
1803 /// Creates a new child tracker with:
1804 /// - Custom or inherited scheduler
1805 /// - Custom or child error policy
1806 /// - Hierarchical metrics that chain to parent
1807 /// - Child cancellation token from the parent
1808 /// - Independent lifecycle from parent
1809 ///
1810 /// # Returns
1811 /// A new `Arc<TaskTracker>` configured as a child of the parent
1812 ///
1813 /// # Errors
1814 /// Returns an error if the parent tracker is already closed
1815 pub fn build(self) -> anyhow::Result<TaskTracker> {
1816 // Validate that parent tracker is still active
1817 if self.parent.is_closed() {
1818 return Err(anyhow::anyhow!(
1819 "Cannot create child tracker from closed parent tracker"
1820 ));
1821 }
1822
1823 let parent = self.parent.0.clone();
1824
1825 let child_cancel_token = parent.cancel_token.child_token();
1826 let child_metrics = Arc::new(ChildTaskMetrics::new(parent.metrics.clone()));
1827
1828 // Use provided scheduler or inherit from parent
1829 let scheduler = self.scheduler.unwrap_or_else(|| parent.scheduler.clone());
1830
1831 // Use provided error policy or create child from parent's
1832 let error_policy = self
1833 .error_policy
1834 .unwrap_or_else(|| parent.error_policy.create_child());
1835
1836 let child = Arc::new(TaskTrackerInner {
1837 tokio_tracker: TokioTaskTracker::new(),
1838 parent: None, // No parent reference needed for hierarchical operations
1839 scheduler,
1840 error_policy,
1841 metrics: child_metrics,
1842 cancel_token: child_cancel_token,
1843 children: RwLock::new(Vec::new()),
1844 });
1845
1846 // Register this child with the parent for hierarchical operations
1847 parent
1848 .children
1849 .write()
1850 .unwrap()
1851 .push(Arc::downgrade(&child));
1852
1853 // Periodically clean up dead children to prevent unbounded growth
1854 parent.cleanup_dead_children();
1855
1856 Ok(TaskTracker(child))
1857 }
1858}
1859
1860/// Internal data for TaskTracker
1861///
1862/// This struct contains all the actual state and functionality of a TaskTracker.
1863/// TaskTracker itself is just a wrapper around Arc<TaskTrackerInner>.
1864struct TaskTrackerInner {
1865 /// Tokio's task tracker for lifecycle management
1866 tokio_tracker: TokioTaskTracker,
1867 /// Parent tracker (None for root)
1868 parent: Option<Arc<TaskTrackerInner>>,
1869 /// Scheduling policy (shared with children by default)
1870 scheduler: Arc<dyn TaskScheduler>,
1871 /// Error handling policy (child-specific via create_child)
1872 error_policy: Arc<dyn OnErrorPolicy>,
1873 /// Metrics for this tracker
1874 metrics: Arc<dyn HierarchicalTaskMetrics>,
1875 /// Cancellation token for this tracker (always present)
1876 cancel_token: CancellationToken,
1877 /// List of child trackers for hierarchical operations
1878 children: RwLock<Vec<Weak<TaskTrackerInner>>>,
1879}
1880
1881/// Hierarchical task tracker with pluggable scheduling and error policies
1882///
1883/// TaskTracker provides a composable system for managing background tasks with:
1884/// - Configurable scheduling via [`TaskScheduler`] implementations
1885/// - Flexible error handling via [`OnErrorPolicy`] implementations
1886/// - Parent-child relationships with independent metrics
1887/// - Cancellation propagation and isolation
1888/// - Built-in cancellation token support
1889///
1890/// Built on top of `tokio_util::task::TaskTracker` for robust task lifecycle management.
1891///
1892/// # Example
1893///
1894/// ```rust
1895/// # use std::sync::Arc;
1896/// # use tokio::sync::Semaphore;
1897/// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy, CancellableTaskResult};
1898/// # async fn example() -> anyhow::Result<()> {
1899/// // Create a task tracker with semaphore-based scheduling
1900/// let scheduler = SemaphoreScheduler::with_permits(3);
1901/// let policy = LogOnlyPolicy::new();
1902/// let root = TaskTracker::builder()
1903/// .scheduler(scheduler)
1904/// .error_policy(policy)
1905/// .build()?;
1906///
1907/// // Spawn some tasks
1908/// let handle1 = root.spawn(async { Ok(1) });
1909/// let handle2 = root.spawn(async { Ok(2) });
1910///
1911/// // Get results and join all tasks
1912/// let result1 = handle1.await.unwrap().unwrap();
1913/// let result2 = handle2.await.unwrap().unwrap();
1914/// assert_eq!(result1, 1);
1915/// assert_eq!(result2, 2);
1916/// # Ok(())
1917/// # }
1918/// ```
1919#[derive(Clone)]
1920pub struct TaskTracker(Arc<TaskTrackerInner>);
1921
1922/// Builder for TaskTracker
1923#[derive(Default)]
1924pub struct TaskTrackerBuilder {
1925 scheduler: Option<Arc<dyn TaskScheduler>>,
1926 error_policy: Option<Arc<dyn OnErrorPolicy>>,
1927 metrics: Option<Arc<dyn HierarchicalTaskMetrics>>,
1928 cancel_token: Option<CancellationToken>,
1929}
1930
1931impl TaskTrackerBuilder {
1932 /// Set the scheduler for this TaskTracker
1933 pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
1934 self.scheduler = Some(scheduler);
1935 self
1936 }
1937
1938 /// Set the error policy for this TaskTracker
1939 pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
1940 self.error_policy = Some(error_policy);
1941 self
1942 }
1943
1944 /// Set custom metrics for this TaskTracker
1945 pub fn metrics(mut self, metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
1946 self.metrics = Some(metrics);
1947 self
1948 }
1949
1950 /// Set the cancellation token for this TaskTracker
1951 pub fn cancel_token(mut self, cancel_token: CancellationToken) -> Self {
1952 self.cancel_token = Some(cancel_token);
1953 self
1954 }
1955
1956 /// Build the TaskTracker
1957 pub fn build(self) -> anyhow::Result<TaskTracker> {
1958 let scheduler = self
1959 .scheduler
1960 .ok_or_else(|| anyhow::anyhow!("TaskTracker requires a scheduler"))?;
1961
1962 let error_policy = self
1963 .error_policy
1964 .ok_or_else(|| anyhow::anyhow!("TaskTracker requires an error policy"))?;
1965
1966 let metrics = self.metrics.unwrap_or_else(|| Arc::new(TaskMetrics::new()));
1967
1968 let cancel_token = self.cancel_token.unwrap_or_default();
1969
1970 let inner = TaskTrackerInner {
1971 tokio_tracker: TokioTaskTracker::new(),
1972 parent: None,
1973 scheduler,
1974 error_policy,
1975 metrics,
1976 cancel_token,
1977 children: RwLock::new(Vec::new()),
1978 };
1979
1980 Ok(TaskTracker(Arc::new(inner)))
1981 }
1982}
1983
1984impl TaskTracker {
1985 /// Create a new root task tracker using the builder pattern
1986 ///
1987 /// This is the preferred way to create new task trackers.
1988 ///
1989 /// # Example
1990 /// ```rust
1991 /// # use std::sync::Arc;
1992 /// # use tokio::sync::Semaphore;
1993 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
1994 /// # fn main() -> anyhow::Result<()> {
1995 /// let scheduler = SemaphoreScheduler::with_permits(10);
1996 /// let error_policy = LogOnlyPolicy::new();
1997 /// let tracker = TaskTracker::builder()
1998 /// .scheduler(scheduler)
1999 /// .error_policy(error_policy)
2000 /// .build()?;
2001 /// # Ok(())
2002 /// # }
2003 /// ```
2004 pub fn builder() -> TaskTrackerBuilder {
2005 TaskTrackerBuilder::default()
2006 }
2007
2008 /// Create a new root task tracker with simple parameters (legacy)
2009 ///
2010 /// This method is kept for backward compatibility. Use `builder()` for new code.
2011 /// Uses default metrics (no Prometheus integration).
2012 ///
2013 /// # Arguments
2014 /// * `scheduler` - Scheduling policy to use for all tasks
2015 /// * `error_policy` - Error handling policy for this tracker
2016 ///
2017 /// # Example
2018 /// ```rust
2019 /// # use std::sync::Arc;
2020 /// # use tokio::sync::Semaphore;
2021 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2022 /// # fn main() -> anyhow::Result<()> {
2023 /// let scheduler = SemaphoreScheduler::with_permits(10);
2024 /// let error_policy = LogOnlyPolicy::new();
2025 /// let tracker = TaskTracker::new(scheduler, error_policy)?;
2026 /// # Ok(())
2027 /// # }
2028 /// ```
2029 pub fn new(
2030 scheduler: Arc<dyn TaskScheduler>,
2031 error_policy: Arc<dyn OnErrorPolicy>,
2032 ) -> anyhow::Result<Self> {
2033 Self::builder()
2034 .scheduler(scheduler)
2035 .error_policy(error_policy)
2036 .build()
2037 }
2038
2039 /// Create a new root task tracker with Prometheus metrics integration
2040 ///
2041 /// # Arguments
2042 /// * `scheduler` - Scheduling policy to use for all tasks
2043 /// * `error_policy` - Error handling policy for this tracker
2044 /// * `registry` - MetricsRegistry for Prometheus integration
2045 /// * `component_name` - Name for this tracker component
2046 ///
2047 /// # Example
2048 /// ```rust
2049 /// # use std::sync::Arc;
2050 /// # use tokio::sync::Semaphore;
2051 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2052 /// # use dynamo_runtime::DistributedRuntime;
2053 /// # fn example(drt: &DistributedRuntime) -> anyhow::Result<()> {
2054 /// let scheduler = SemaphoreScheduler::with_permits(10);
2055 /// let error_policy = LogOnlyPolicy::new();
2056 /// let tracker = TaskTracker::new_with_prometheus(
2057 /// scheduler,
2058 /// error_policy,
2059 /// drt,
2060 /// "main_tracker"
2061 /// )?;
2062 /// # Ok(())
2063 /// # }
2064 /// ```
2065 pub fn new_with_prometheus<R: MetricsHierarchy>(
2066 scheduler: Arc<dyn TaskScheduler>,
2067 error_policy: Arc<dyn OnErrorPolicy>,
2068 registry: &R,
2069 component_name: &str,
2070 ) -> anyhow::Result<Self> {
2071 let prometheus_metrics = Arc::new(PrometheusTaskMetrics::new(registry, component_name)?);
2072
2073 Self::builder()
2074 .scheduler(scheduler)
2075 .error_policy(error_policy)
2076 .metrics(prometheus_metrics)
2077 .build()
2078 }
2079
2080 /// Create a child tracker that inherits scheduling policy
2081 ///
2082 /// The child tracker:
2083 /// - Gets its own independent tokio TaskTracker
2084 /// - Inherits the parent's scheduler
2085 /// - Gets a child error policy via `create_child()`
2086 /// - Has hierarchical metrics that chain to parent
2087 /// - Gets a child cancellation token from the parent
2088 /// - Is independent for cancellation (child cancellation doesn't affect parent)
2089 ///
2090 /// # Errors
2091 /// Returns an error if the parent tracker is already closed
2092 ///
2093 /// # Example
2094 /// ```rust
2095 /// # use std::sync::Arc;
2096 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2097 /// # fn example(root_tracker: TaskTracker) -> anyhow::Result<()> {
2098 /// let child_tracker = root_tracker.child_tracker()?;
2099 /// // Child inherits parent's policies but has separate metrics and lifecycle
2100 /// # Ok(())
2101 /// # }
2102 /// ```
2103 pub fn child_tracker(&self) -> anyhow::Result<TaskTracker> {
2104 Ok(TaskTracker(self.0.child_tracker()?))
2105 }
2106
2107 /// Create a child tracker builder for flexible customization
2108 ///
2109 /// The builder allows you to customize scheduling and error policies for the child tracker.
2110 /// If not specified, policies are inherited from the parent.
2111 ///
2112 /// # Example
2113 /// ```rust
2114 /// # use std::sync::Arc;
2115 /// # use tokio::sync::Semaphore;
2116 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2117 /// # fn example(root_tracker: TaskTracker) {
2118 /// // Custom scheduler, inherit error policy
2119 /// let child1 = root_tracker.child_tracker_builder()
2120 /// .scheduler(SemaphoreScheduler::with_permits(5))
2121 /// .build().unwrap();
2122 ///
2123 /// // Custom error policy, inherit scheduler
2124 /// let child2 = root_tracker.child_tracker_builder()
2125 /// .error_policy(LogOnlyPolicy::new())
2126 /// .build().unwrap();
2127 ///
2128 /// // Both custom
2129 /// let child3 = root_tracker.child_tracker_builder()
2130 /// .scheduler(SemaphoreScheduler::with_permits(3))
2131 /// .error_policy(LogOnlyPolicy::new())
2132 /// .build().unwrap();
2133 /// # }
2134 /// ```
2135 /// Spawn a new task
2136 ///
2137 /// The task will be wrapped with scheduling and error handling logic,
2138 /// then executed according to the configured policies. For tasks that
2139 /// need to inspect cancellation tokens, use [`spawn_cancellable`] instead.
2140 ///
2141 /// # Arguments
2142 /// * `future` - The async task to execute
2143 ///
2144 /// # Returns
2145 /// A [`TaskHandle`] that can be used to await completion and access the task's cancellation token
2146 ///
2147 /// # Panics
2148 /// Panics if the tracker has been closed. This indicates a programming error
2149 /// where tasks are being spawned after the tracker lifecycle has ended.
2150 ///
2151 /// # Example
2152 /// ```rust
2153 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2154 /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
2155 /// let handle = tracker.spawn(async {
2156 /// // Your async work here
2157 /// tokio::time::sleep(std::time::Duration::from_millis(100)).await;
2158 /// Ok(42)
2159 /// });
2160 ///
2161 /// // Access the task's cancellation token
2162 /// let cancel_token = handle.cancellation_token();
2163 ///
2164 /// let result = handle.await?;
2165 /// # Ok(())
2166 /// # }
2167 /// ```
2168 pub fn spawn<F, T>(&self, future: F) -> TaskHandle<T>
2169 where
2170 F: Future<Output = Result<T>> + Send + 'static,
2171 T: Send + 'static,
2172 {
2173 self.0
2174 .spawn(future)
2175 .expect("TaskTracker must not be closed when spawning tasks")
2176 }
2177
2178 /// Spawn a cancellable task that receives a cancellation token
2179 ///
2180 /// This is useful for tasks that need to inspect the cancellation token
2181 /// and gracefully handle cancellation within their logic. The task function
2182 /// must return a `CancellableTaskResult` to properly track cancellation vs errors.
2183 ///
2184 /// # Arguments
2185 ///
2186 /// * `task_fn` - Function that takes a cancellation token and returns a future that resolves to `CancellableTaskResult<T>`
2187 ///
2188 /// # Returns
2189 /// A [`TaskHandle`] that can be used to await completion and access the task's cancellation token
2190 ///
2191 /// # Panics
2192 /// Panics if the tracker has been closed. This indicates a programming error
2193 /// where tasks are being spawned after the tracker lifecycle has ended.
2194 ///
2195 /// # Example
2196 /// ```rust
2197 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, CancellableTaskResult};
2198 /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
2199 /// let handle = tracker.spawn_cancellable(|cancel_token| async move {
2200 /// tokio::select! {
2201 /// _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
2202 /// CancellableTaskResult::Ok(42)
2203 /// },
2204 /// _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
2205 /// }
2206 /// });
2207 ///
2208 /// // Access the task's individual cancellation token
2209 /// let task_cancel_token = handle.cancellation_token();
2210 ///
2211 /// let result = handle.await?;
2212 /// # Ok(())
2213 /// # }
2214 /// ```
2215 pub fn spawn_cancellable<F, Fut, T>(&self, task_fn: F) -> TaskHandle<T>
2216 where
2217 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2218 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2219 T: Send + 'static,
2220 {
2221 self.0
2222 .spawn_cancellable(task_fn)
2223 .expect("TaskTracker must not be closed when spawning tasks")
2224 }
2225
2226 /// Get metrics for this tracker
2227 ///
2228 /// Metrics are specific to this tracker and do not include
2229 /// metrics from parent or child trackers.
2230 ///
2231 /// # Example
2232 /// ```rust
2233 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2234 /// # fn example(tracker: &TaskTracker) {
2235 /// let metrics = tracker.metrics();
2236 /// println!("Success: {}, Failed: {}", metrics.success(), metrics.failed());
2237 /// # }
2238 /// ```
2239 pub fn metrics(&self) -> &dyn HierarchicalTaskMetrics {
2240 self.0.metrics.as_ref()
2241 }
2242
2243 /// Cancel this tracker and all its tasks
2244 ///
2245 /// This will signal cancellation to all currently running tasks and prevent new tasks from being spawned.
2246 /// The cancellation is immediate and forceful.
2247 ///
2248 /// # Example
2249 /// ```rust
2250 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2251 /// # async fn example(tracker: TaskTracker) -> anyhow::Result<()> {
2252 /// // Spawn a long-running task
2253 /// let handle = tracker.spawn_cancellable(|cancel_token| async move {
2254 /// tokio::select! {
2255 /// _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
2256 /// dynamo_runtime::utils::tasks::tracker::CancellableTaskResult::Ok(42)
2257 /// }
2258 /// _ = cancel_token.cancelled() => {
2259 /// dynamo_runtime::utils::tasks::tracker::CancellableTaskResult::Cancelled
2260 /// }
2261 /// }
2262 /// }).await?;
2263 ///
2264 /// // Cancel the tracker (and thus the task)
2265 /// tracker.cancel();
2266 /// # Ok(())
2267 /// # }
2268 /// ```
2269 pub fn cancel(&self) {
2270 self.0.cancel();
2271 }
2272
2273 /// Check if this tracker is closed
2274 pub fn is_closed(&self) -> bool {
2275 self.0.is_closed()
2276 }
2277
2278 /// Get the cancellation token for this tracker
2279 ///
2280 /// This allows external code to observe or trigger cancellation of this tracker.
2281 ///
2282 /// # Example
2283 /// ```rust
2284 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2285 /// # fn example(tracker: &TaskTracker) {
2286 /// let token = tracker.cancellation_token();
2287 /// // Can check cancellation state or cancel manually
2288 /// if !token.is_cancelled() {
2289 /// token.cancel();
2290 /// }
2291 /// # }
2292 /// ```
2293 pub fn cancellation_token(&self) -> CancellationToken {
2294 self.0.cancellation_token()
2295 }
2296
2297 /// Get the number of active child trackers
2298 ///
2299 /// This counts only child trackers that are still alive (not dropped).
2300 /// Dropped child trackers are automatically cleaned up.
2301 ///
2302 /// # Example
2303 /// ```rust
2304 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2305 /// # fn example(tracker: &TaskTracker) {
2306 /// let child_count = tracker.child_count();
2307 /// println!("This tracker has {} active children", child_count);
2308 /// # }
2309 /// ```
2310 pub fn child_count(&self) -> usize {
2311 self.0.child_count()
2312 }
2313
2314 /// Create a child tracker builder with custom configuration
2315 ///
2316 /// This provides fine-grained control over child tracker creation,
2317 /// allowing you to override the scheduler or error policy while
2318 /// maintaining the parent-child relationship.
2319 ///
2320 /// # Example
2321 /// ```rust
2322 /// # use std::sync::Arc;
2323 /// # use tokio::sync::Semaphore;
2324 /// # use dynamo_runtime::utils::tasks::tracker::{TaskTracker, SemaphoreScheduler, LogOnlyPolicy};
2325 /// # fn example(parent: &TaskTracker) {
2326 /// // Custom scheduler, inherit error policy
2327 /// let child1 = parent.child_tracker_builder()
2328 /// .scheduler(SemaphoreScheduler::with_permits(5))
2329 /// .build().unwrap();
2330 ///
2331 /// // Custom error policy, inherit scheduler
2332 /// let child2 = parent.child_tracker_builder()
2333 /// .error_policy(LogOnlyPolicy::new())
2334 /// .build().unwrap();
2335 ///
2336 /// // Inherit both policies from parent
2337 /// let child3 = parent.child_tracker_builder()
2338 /// .build().unwrap();
2339 /// # }
2340 /// ```
2341 pub fn child_tracker_builder(&self) -> ChildTrackerBuilder<'_> {
2342 ChildTrackerBuilder::new(self)
2343 }
2344
2345 /// Join this tracker and all child trackers
2346 ///
2347 /// This method gracefully shuts down the entire tracker hierarchy by:
2348 /// 1. Closing all trackers (preventing new task spawning)
2349 /// 2. Waiting for all existing tasks to complete
2350 ///
2351 /// Uses stack-safe traversal to prevent stack overflow in deep hierarchies.
2352 /// Children are processed before parents to ensure proper shutdown order.
2353 ///
2354 /// **Hierarchical Behavior:**
2355 /// - Processes children before parents to ensure proper shutdown order
2356 /// - Each tracker is closed before waiting (Tokio requirement)
2357 /// - Leaf trackers simply close and wait for their own tasks
2358 ///
2359 /// # Example
2360 /// ```rust
2361 /// # use dynamo_runtime::utils::tasks::tracker::TaskTracker;
2362 /// # async fn example(tracker: TaskTracker) {
2363 /// tracker.join().await;
2364 /// # }
2365 /// ```
2366 pub async fn join(&self) {
2367 self.0.join().await
2368 }
2369}
2370
2371impl TaskTrackerInner {
2372 /// Creates child tracker with inherited scheduler/policy, independent metrics, and hierarchical cancellation
2373 fn child_tracker(self: &Arc<Self>) -> anyhow::Result<Arc<TaskTrackerInner>> {
2374 // Validate that parent tracker is still active
2375 if self.is_closed() {
2376 return Err(anyhow::anyhow!(
2377 "Cannot create child tracker from closed parent tracker"
2378 ));
2379 }
2380
2381 let child_cancel_token = self.cancel_token.child_token();
2382 let child_metrics = Arc::new(ChildTaskMetrics::new(self.metrics.clone()));
2383
2384 let child = Arc::new(TaskTrackerInner {
2385 tokio_tracker: TokioTaskTracker::new(),
2386 parent: Some(self.clone()),
2387 scheduler: self.scheduler.clone(),
2388 error_policy: self.error_policy.create_child(),
2389 metrics: child_metrics,
2390 cancel_token: child_cancel_token,
2391 children: RwLock::new(Vec::new()),
2392 });
2393
2394 // Register this child with the parent for hierarchical operations
2395 self.children.write().unwrap().push(Arc::downgrade(&child));
2396
2397 // Periodically clean up dead children to prevent unbounded growth
2398 self.cleanup_dead_children();
2399
2400 Ok(child)
2401 }
2402
2403 /// Spawn implementation - validates tracker state, generates task ID, applies policies, and tracks execution
2404 fn spawn<F, T>(self: &Arc<Self>, future: F) -> Result<TaskHandle<T>, TaskError>
2405 where
2406 F: Future<Output = Result<T>> + Send + 'static,
2407 T: Send + 'static,
2408 {
2409 // Validate tracker is not closed
2410 if self.tokio_tracker.is_closed() {
2411 return Err(TaskError::TrackerClosed);
2412 }
2413
2414 // Generate a unique task ID
2415 let task_id = self.generate_task_id();
2416
2417 // Increment issued counter immediately when task is submitted
2418 self.metrics.increment_issued();
2419
2420 // Create a child cancellation token for this specific task
2421 let task_cancel_token = self.cancel_token.child_token();
2422 let cancel_token = task_cancel_token.clone();
2423
2424 // Clone the inner Arc to move into the task
2425 let inner = self.clone();
2426
2427 // Wrap the user's future with our scheduling and error handling
2428 let wrapped_future =
2429 async move { Self::execute_with_policies(task_id, future, cancel_token, inner).await };
2430
2431 // Let tokio handle the actual task tracking
2432 let join_handle = self.tokio_tracker.spawn(wrapped_future);
2433
2434 // Wrap in TaskHandle with the child cancellation token
2435 Ok(TaskHandle::new(join_handle, task_cancel_token))
2436 }
2437
2438 /// Spawn cancellable implementation - validates state, provides cancellation token, handles CancellableTaskResult
2439 fn spawn_cancellable<F, Fut, T>(
2440 self: &Arc<Self>,
2441 task_fn: F,
2442 ) -> Result<TaskHandle<T>, TaskError>
2443 where
2444 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2445 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2446 T: Send + 'static,
2447 {
2448 // Validate tracker is not closed
2449 if self.tokio_tracker.is_closed() {
2450 return Err(TaskError::TrackerClosed);
2451 }
2452
2453 // Generate a unique task ID
2454 let task_id = self.generate_task_id();
2455
2456 // Increment issued counter immediately when task is submitted
2457 self.metrics.increment_issued();
2458
2459 // Create a child cancellation token for this specific task
2460 let task_cancel_token = self.cancel_token.child_token();
2461 let cancel_token = task_cancel_token.clone();
2462
2463 // Clone the inner Arc to move into the task
2464 let inner = self.clone();
2465
2466 // Use the new execution pipeline that defers task creation until after guard acquisition
2467 let wrapped_future = async move {
2468 Self::execute_cancellable_with_policies(task_id, task_fn, cancel_token, inner).await
2469 };
2470
2471 // Let tokio handle the actual task tracking
2472 let join_handle = self.tokio_tracker.spawn(wrapped_future);
2473
2474 // Wrap in TaskHandle with the child cancellation token
2475 Ok(TaskHandle::new(join_handle, task_cancel_token))
2476 }
2477
2478 /// Cancel this tracker and all its tasks - implementation
2479 fn cancel(&self) {
2480 // Close the tracker to prevent new tasks
2481 self.tokio_tracker.close();
2482
2483 // Cancel our own token
2484 self.cancel_token.cancel();
2485 }
2486
2487 /// Returns true if the underlying tokio tracker is closed
2488 fn is_closed(&self) -> bool {
2489 self.tokio_tracker.is_closed()
2490 }
2491
2492 /// Generates a unique task ID using TaskId::new()
2493 fn generate_task_id(&self) -> TaskId {
2494 TaskId::new()
2495 }
2496
2497 /// Removes dead weak references from children list to prevent memory leaks
2498 fn cleanup_dead_children(&self) {
2499 let mut children_guard = self.children.write().unwrap();
2500 children_guard.retain(|weak| weak.upgrade().is_some());
2501 }
2502
2503 /// Returns a clone of the cancellation token
2504 fn cancellation_token(&self) -> CancellationToken {
2505 self.cancel_token.clone()
2506 }
2507
2508 /// Counts active child trackers (filters out dead weak references)
2509 fn child_count(&self) -> usize {
2510 let children_guard = self.children.read().unwrap();
2511 children_guard
2512 .iter()
2513 .filter(|weak| weak.upgrade().is_some())
2514 .count()
2515 }
2516
2517 /// Join implementation - closes all trackers in hierarchy then waits for task completion using stack-safe traversal
2518 async fn join(self: &Arc<Self>) {
2519 // Fast path for leaf trackers (no children)
2520 let is_leaf = {
2521 let children_guard = self.children.read().unwrap();
2522 children_guard.is_empty()
2523 };
2524
2525 if is_leaf {
2526 self.tokio_tracker.close();
2527 self.tokio_tracker.wait().await;
2528 return;
2529 }
2530
2531 // Stack-safe traversal for deep hierarchies
2532 // Processes children before parents to ensure proper shutdown order
2533 let trackers = self.collect_hierarchy();
2534 for t in trackers {
2535 t.tokio_tracker.close();
2536 t.tokio_tracker.wait().await;
2537 }
2538 }
2539
2540 /// Collects hierarchy using iterative DFS, returns Vec in post-order (children before parents) for safe shutdown
2541 fn collect_hierarchy(self: &Arc<TaskTrackerInner>) -> Vec<Arc<TaskTrackerInner>> {
2542 let mut result = Vec::new();
2543 let mut stack = vec![self.clone()];
2544 let mut visited = HashSet::new();
2545
2546 // Collect all trackers using depth-first search
2547 while let Some(tracker) = stack.pop() {
2548 let tracker_ptr = Arc::as_ptr(&tracker) as usize;
2549 if visited.contains(&tracker_ptr) {
2550 continue;
2551 }
2552 visited.insert(tracker_ptr);
2553
2554 // Add current tracker to result
2555 result.push(tracker.clone());
2556
2557 // Add children to stack for processing
2558 if let Ok(children_guard) = tracker.children.read() {
2559 for weak_child in children_guard.iter() {
2560 if let Some(child) = weak_child.upgrade() {
2561 let child_ptr = Arc::as_ptr(&child) as usize;
2562 if !visited.contains(&child_ptr) {
2563 stack.push(child);
2564 }
2565 }
2566 }
2567 }
2568 }
2569
2570 // Reverse to get bottom-up order (children before parents)
2571 result.reverse();
2572 result
2573 }
2574
2575 /// Execute a regular task with scheduling and error handling policies
2576 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2577 async fn execute_with_policies<F, T>(
2578 task_id: TaskId,
2579 future: F,
2580 task_cancel_token: CancellationToken,
2581 inner: Arc<TaskTrackerInner>,
2582 ) -> Result<T, TaskError>
2583 where
2584 F: Future<Output = Result<T>> + Send + 'static,
2585 T: Send + 'static,
2586 {
2587 // Wrap regular future in a task executor that doesn't support retry
2588 let task_executor = RegularTaskExecutor::new(future);
2589 Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
2590 }
2591
2592 /// Execute a cancellable task with scheduling and error handling policies
2593 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2594 async fn execute_cancellable_with_policies<F, Fut, T>(
2595 task_id: TaskId,
2596 task_fn: F,
2597 task_cancel_token: CancellationToken,
2598 inner: Arc<TaskTrackerInner>,
2599 ) -> Result<T, TaskError>
2600 where
2601 F: FnMut(CancellationToken) -> Fut + Send + 'static,
2602 Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
2603 T: Send + 'static,
2604 {
2605 // Wrap cancellable task function in a task executor that supports retry
2606 let task_executor = CancellableTaskExecutor::new(task_fn);
2607 Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
2608 }
2609
2610 /// Core execution loop with retry support - unified for both task types
2611 #[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
2612 async fn execute_with_retry_loop<E, T>(
2613 task_id: TaskId,
2614 initial_executor: E,
2615 task_cancellation_token: CancellationToken,
2616 inner: Arc<TaskTrackerInner>,
2617 ) -> Result<T, TaskError>
2618 where
2619 E: TaskExecutor<T> + Send + 'static,
2620 T: Send + 'static,
2621 {
2622 debug!("Starting task execution");
2623
2624 // RAII guard for active counter - increments on creation, decrements on drop
2625 struct ActiveCountGuard {
2626 metrics: Arc<dyn HierarchicalTaskMetrics>,
2627 is_active: bool,
2628 }
2629
2630 impl ActiveCountGuard {
2631 fn new(metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
2632 Self {
2633 metrics,
2634 is_active: false,
2635 }
2636 }
2637
2638 fn activate(&mut self) {
2639 if !self.is_active {
2640 self.metrics.increment_started();
2641 self.is_active = true;
2642 }
2643 }
2644 }
2645
2646 // Current executable - either the original TaskExecutor or a Continuation
2647 enum CurrentExecutable<E>
2648 where
2649 E: Send + 'static,
2650 {
2651 TaskExecutor(E),
2652 Continuation(Arc<dyn Continuation + Send + Sync + 'static>),
2653 }
2654
2655 let mut current_executable = CurrentExecutable::TaskExecutor(initial_executor);
2656 let mut active_guard = ActiveCountGuard::new(inner.metrics.clone());
2657 let mut error_context: Option<OnErrorContext> = None;
2658 let mut scheduler_guard_state = self::GuardState::Keep;
2659 let mut guard_result = async {
2660 inner
2661 .scheduler
2662 .acquire_execution_slot(task_cancellation_token.child_token())
2663 .await
2664 }
2665 .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
2666 .await;
2667
2668 loop {
2669 if scheduler_guard_state == self::GuardState::Reschedule {
2670 guard_result = async {
2671 inner
2672 .scheduler
2673 .acquire_execution_slot(inner.cancel_token.child_token())
2674 .await
2675 }
2676 .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
2677 .await;
2678 }
2679
2680 match &guard_result {
2681 SchedulingResult::Execute(_guard) => {
2682 // Activate the RAII guard only once when we successfully acquire resources
2683 active_guard.activate();
2684
2685 // Execute the current executable while holding the guard (RAII pattern)
2686 let execution_result = async {
2687 debug!("Executing task with acquired resources");
2688 match &mut current_executable {
2689 CurrentExecutable::TaskExecutor(executor) => {
2690 executor.execute(inner.cancel_token.child_token()).await
2691 }
2692 CurrentExecutable::Continuation(continuation) => {
2693 // Execute continuation and handle type erasure
2694 match continuation.execute(inner.cancel_token.child_token()).await {
2695 TaskExecutionResult::Success(result) => {
2696 // Try to downcast the result to the expected type T
2697 if let Ok(typed_result) = result.downcast::<T>() {
2698 TaskExecutionResult::Success(*typed_result)
2699 } else {
2700 // Type mismatch - this shouldn't happen with proper usage
2701 let type_error = anyhow::anyhow!(
2702 "Continuation task returned wrong type"
2703 );
2704 error!(
2705 ?type_error,
2706 "Type mismatch in continuation task result"
2707 );
2708 TaskExecutionResult::Error(type_error)
2709 }
2710 }
2711 TaskExecutionResult::Cancelled => {
2712 TaskExecutionResult::Cancelled
2713 }
2714 TaskExecutionResult::Error(error) => {
2715 TaskExecutionResult::Error(error)
2716 }
2717 }
2718 }
2719 }
2720 }
2721 .instrument(tracing::debug_span!("task_execution"))
2722 .await;
2723
2724 // Active counter will be decremented automatically when active_guard drops
2725
2726 match execution_result {
2727 TaskExecutionResult::Success(value) => {
2728 inner.metrics.increment_success();
2729 debug!("Task completed successfully");
2730 return Ok(value);
2731 }
2732 TaskExecutionResult::Cancelled => {
2733 inner.metrics.increment_cancelled();
2734 debug!("Task was cancelled during execution");
2735 return Err(TaskError::Cancelled);
2736 }
2737 TaskExecutionResult::Error(error) => {
2738 debug!("Task failed - handling error through policy - {error:?}");
2739
2740 // Handle the error through the policy system
2741 let (action_result, guard_state) = Self::handle_task_error(
2742 &error,
2743 &mut error_context,
2744 task_id,
2745 &inner,
2746 )
2747 .await;
2748
2749 // Update the scheduler guard state for evaluation after the match
2750 scheduler_guard_state = guard_state;
2751
2752 match action_result {
2753 ActionResult::Fail => {
2754 inner.metrics.increment_failed();
2755 debug!("Policy accepted error - task failed {error:?}");
2756 return Err(TaskError::Failed(error));
2757 }
2758 ActionResult::Shutdown => {
2759 inner.metrics.increment_failed();
2760 warn!("Policy triggered shutdown - {error:?}");
2761 inner.cancel();
2762 return Err(TaskError::Failed(error));
2763 }
2764 ActionResult::Continue { continuation } => {
2765 debug!(
2766 "Policy provided next executable - continuing loop - {error:?}"
2767 );
2768
2769 // Update current executable
2770 current_executable =
2771 CurrentExecutable::Continuation(continuation);
2772
2773 continue; // Continue the main loop with the new executable
2774 }
2775 }
2776 }
2777 }
2778 }
2779 SchedulingResult::Cancelled => {
2780 inner.metrics.increment_cancelled();
2781 debug!("Task was cancelled during resource acquisition");
2782 return Err(TaskError::Cancelled);
2783 }
2784 SchedulingResult::Rejected(reason) => {
2785 inner.metrics.increment_rejected();
2786 debug!(reason, "Task was rejected by scheduler");
2787 return Err(TaskError::Failed(anyhow::anyhow!(
2788 "Task rejected: {}",
2789 reason
2790 )));
2791 }
2792 }
2793 }
2794 }
2795
2796 /// Handle task errors through the error policy and return the action to take
2797 async fn handle_task_error(
2798 error: &anyhow::Error,
2799 error_context: &mut Option<OnErrorContext>,
2800 task_id: TaskId,
2801 inner: &Arc<TaskTrackerInner>,
2802 ) -> (ActionResult, self::GuardState) {
2803 // Create or update the error context (lazy initialization)
2804 let context = error_context.get_or_insert_with(|| OnErrorContext {
2805 attempt_count: 0, // Will be incremented below
2806 task_id,
2807 execution_context: TaskExecutionContext {
2808 scheduler: inner.scheduler.clone(),
2809 metrics: inner.metrics.clone(),
2810 },
2811 state: inner.error_policy.create_context(),
2812 });
2813
2814 // Increment attempt count for this error
2815 context.attempt_count += 1;
2816 let current_attempt = context.attempt_count;
2817
2818 // First, check if the policy allows continuations for this error
2819 if inner.error_policy.allow_continuation(error, context) {
2820 // Policy allows continuations, check if this is a FailedWithContinuation (task-driven continuation)
2821 if let Some(continuation_err) = error.downcast_ref::<FailedWithContinuation>() {
2822 debug!(
2823 task_id = %task_id,
2824 attempt_count = current_attempt,
2825 "Task provided FailedWithContinuation and policy allows continuations - {error:?}"
2826 );
2827
2828 // Task has provided a continuation implementation for the next attempt
2829 // Clone the Arc to return it in ActionResult::Continue
2830 let continuation = continuation_err.continuation.clone();
2831
2832 // Ask policy whether to reschedule task-driven continuation
2833 let should_reschedule = inner.error_policy.should_reschedule(error, context);
2834
2835 let guard_state = if should_reschedule {
2836 self::GuardState::Reschedule
2837 } else {
2838 self::GuardState::Keep
2839 };
2840
2841 return (ActionResult::Continue { continuation }, guard_state);
2842 }
2843 } else {
2844 debug!(
2845 task_id = %task_id,
2846 attempt_count = current_attempt,
2847 "Policy rejected continuations, ignoring any FailedWithContinuation - {error:?}"
2848 );
2849 }
2850
2851 let response = inner.error_policy.on_error(error, context);
2852
2853 match response {
2854 ErrorResponse::Fail => (ActionResult::Fail, self::GuardState::Keep),
2855 ErrorResponse::Shutdown => (ActionResult::Shutdown, self::GuardState::Keep),
2856 ErrorResponse::Custom(action) => {
2857 debug!("Task failed - executing custom action - {error:?}");
2858
2859 // Execute the custom action asynchronously
2860 let action_result = action
2861 .execute(error, task_id, current_attempt, &context.execution_context)
2862 .await;
2863 debug!(?action_result, "Custom action completed");
2864
2865 // If the custom action returned Continue, ask policy about rescheduling
2866 let guard_state = match &action_result {
2867 ActionResult::Continue { .. } => {
2868 let should_reschedule =
2869 inner.error_policy.should_reschedule(error, context);
2870 if should_reschedule {
2871 self::GuardState::Reschedule
2872 } else {
2873 self::GuardState::Keep
2874 }
2875 }
2876 _ => self::GuardState::Keep, // Fail/Shutdown don't need guard state
2877 };
2878
2879 (action_result, guard_state)
2880 }
2881 }
2882 }
2883}
2884
2885// Blanket implementation for all schedulers
2886impl ArcPolicy for UnlimitedScheduler {}
2887impl ArcPolicy for SemaphoreScheduler {}
2888
2889// Blanket implementation for all error policies
2890impl ArcPolicy for LogOnlyPolicy {}
2891impl ArcPolicy for CancelOnError {}
2892impl ArcPolicy for ThresholdCancelPolicy {}
2893impl ArcPolicy for RateCancelPolicy {}
2894
2895/// Resource guard for unlimited scheduling
2896///
2897/// This guard represents "unlimited" resources - no actual resource constraints.
2898/// Since there are no resources to manage, this guard is essentially a no-op.
2899#[derive(Debug)]
2900pub struct UnlimitedGuard;
2901
2902impl ResourceGuard for UnlimitedGuard {
2903 // No resources to manage - marker trait implementation only
2904}
2905
2906/// Unlimited task scheduler that executes all tasks immediately
2907///
2908/// This scheduler provides no concurrency limits and executes all submitted tasks
2909/// immediately. Useful for testing, high-throughput scenarios, or when external
2910/// systems provide the concurrency control.
2911///
2912/// ## Cancellation Behavior
2913///
2914/// - Respects cancellation tokens before resource acquisition
2915/// - Once execution begins (via ResourceGuard), always awaits task completion
2916/// - Tasks handle their own cancellation internally (if created with `spawn_cancellable`)
2917///
2918/// # Example
2919/// ```rust
2920/// # use dynamo_runtime::utils::tasks::tracker::UnlimitedScheduler;
2921/// let scheduler = UnlimitedScheduler::new();
2922/// ```
2923#[derive(Debug)]
2924pub struct UnlimitedScheduler;
2925
2926impl UnlimitedScheduler {
2927 /// Create a new unlimited scheduler returning Arc
2928 pub fn new() -> Arc<Self> {
2929 Arc::new(Self)
2930 }
2931}
2932
2933impl Default for UnlimitedScheduler {
2934 fn default() -> Self {
2935 UnlimitedScheduler
2936 }
2937}
2938
2939#[async_trait]
2940impl TaskScheduler for UnlimitedScheduler {
2941 async fn acquire_execution_slot(
2942 &self,
2943 cancel_token: CancellationToken,
2944 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
2945 debug!("Acquiring execution slot (unlimited scheduler)");
2946
2947 // Check for cancellation before allocating resources
2948 if cancel_token.is_cancelled() {
2949 debug!("Task cancelled before acquiring execution slot");
2950 return SchedulingResult::Cancelled;
2951 }
2952
2953 // No resource constraints for unlimited scheduler
2954 debug!("Execution slot acquired immediately");
2955 SchedulingResult::Execute(Box::new(UnlimitedGuard))
2956 }
2957}
2958
2959/// Resource guard for semaphore-based scheduling
2960///
2961/// This guard holds a semaphore permit and enforces that task execution
2962/// always runs to completion. The permit is automatically released when
2963/// the guard is dropped.
2964#[derive(Debug)]
2965pub struct SemaphoreGuard {
2966 _permit: tokio::sync::OwnedSemaphorePermit,
2967}
2968
2969impl ResourceGuard for SemaphoreGuard {
2970 // Permit is automatically released when the guard is dropped
2971}
2972
2973/// Semaphore-based task scheduler
2974///
2975/// Limits concurrent task execution using a [`tokio::sync::Semaphore`].
2976/// Tasks will wait for an available permit before executing.
2977///
2978/// ## Cancellation Behavior
2979///
2980/// - Respects cancellation tokens before and during permit acquisition
2981/// - Once a permit is acquired (via ResourceGuard), always awaits task completion
2982/// - Holds the permit until the task completes (regardless of cancellation)
2983/// - Tasks handle their own cancellation internally (if created with `spawn_cancellable`)
2984///
2985/// This ensures that permits are not leaked when tasks are cancelled, while still
2986/// allowing cancellable tasks to terminate gracefully on their own.
2987///
2988/// # Example
2989/// ```rust
2990/// # use std::sync::Arc;
2991/// # use tokio::sync::Semaphore;
2992/// # use dynamo_runtime::utils::tasks::tracker::SemaphoreScheduler;
2993/// // Allow up to 5 concurrent tasks
2994/// let semaphore = Arc::new(Semaphore::new(5));
2995/// let scheduler = SemaphoreScheduler::new(semaphore);
2996/// ```
2997#[derive(Debug)]
2998pub struct SemaphoreScheduler {
2999 semaphore: Arc<Semaphore>,
3000}
3001
3002impl SemaphoreScheduler {
3003 /// Create a new semaphore scheduler
3004 ///
3005 /// # Arguments
3006 /// * `semaphore` - Semaphore to use for concurrency control
3007 pub fn new(semaphore: Arc<Semaphore>) -> Self {
3008 Self { semaphore }
3009 }
3010
3011 /// Create a semaphore scheduler with the specified number of permits, returning Arc
3012 pub fn with_permits(permits: usize) -> Arc<Self> {
3013 Arc::new(Self::new(Arc::new(Semaphore::new(permits))))
3014 }
3015
3016 /// Get the number of available permits
3017 pub fn available_permits(&self) -> usize {
3018 self.semaphore.available_permits()
3019 }
3020}
3021
3022#[async_trait]
3023impl TaskScheduler for SemaphoreScheduler {
3024 async fn acquire_execution_slot(
3025 &self,
3026 cancel_token: CancellationToken,
3027 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
3028 debug!("Acquiring semaphore permit");
3029
3030 // Check for cancellation before attempting to acquire semaphore
3031 if cancel_token.is_cancelled() {
3032 debug!("Task cancelled before acquiring semaphore permit");
3033 return SchedulingResult::Cancelled;
3034 }
3035
3036 // Try to acquire a permit, with cancellation support
3037 let permit = {
3038 tokio::select! {
3039 result = self.semaphore.clone().acquire_owned() => {
3040 match result {
3041 Ok(permit) => permit,
3042 Err(_) => return SchedulingResult::Cancelled,
3043 }
3044 }
3045 _ = cancel_token.cancelled() => {
3046 debug!("Task cancelled while waiting for semaphore permit");
3047 return SchedulingResult::Cancelled;
3048 }
3049 }
3050 };
3051
3052 debug!("Acquired semaphore permit");
3053 SchedulingResult::Execute(Box::new(SemaphoreGuard { _permit: permit }))
3054 }
3055}
3056
3057/// Error policy that triggers cancellation based on error patterns
3058///
3059/// This policy analyzes error messages and returns `ErrorResponse::Shutdown` when:
3060/// - No patterns are specified (cancels on any error)
3061/// - Error message matches one of the specified patterns
3062///
3063/// The TaskTracker handles the actual cancellation - this policy just makes the decision.
3064///
3065/// # Example
3066/// ```rust
3067/// # use dynamo_runtime::utils::tasks::tracker::CancelOnError;
3068/// // Cancel on any error
3069/// let policy = CancelOnError::new();
3070///
3071/// // Cancel only on specific error patterns
3072/// let (policy, _token) = CancelOnError::with_patterns(
3073/// vec!["OutOfMemory".to_string(), "DeviceError".to_string()]
3074/// );
3075/// ```
3076#[derive(Debug)]
3077pub struct CancelOnError {
3078 error_patterns: Vec<String>,
3079}
3080
3081impl CancelOnError {
3082 /// Create a new cancel-on-error policy that cancels on any error
3083 ///
3084 /// Returns a policy with no error patterns, meaning it will cancel the TaskTracker
3085 /// on any task failure.
3086 pub fn new() -> Arc<Self> {
3087 Arc::new(Self {
3088 error_patterns: vec![], // Empty patterns = cancel on any error
3089 })
3090 }
3091
3092 /// Create a new cancel-on-error policy with custom error patterns, returning Arc and token
3093 ///
3094 /// # Arguments
3095 /// * `error_patterns` - List of error message patterns that trigger cancellation
3096 pub fn with_patterns(error_patterns: Vec<String>) -> (Arc<Self>, CancellationToken) {
3097 let token = CancellationToken::new();
3098 let policy = Arc::new(Self { error_patterns });
3099 (policy, token)
3100 }
3101}
3102
3103#[async_trait]
3104impl OnErrorPolicy for CancelOnError {
3105 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3106 // Child gets a child cancel token - when parent cancels, child cancels too
3107 // When child cancels, parent is unaffected
3108 Arc::new(CancelOnError {
3109 error_patterns: self.error_patterns.clone(),
3110 })
3111 }
3112
3113 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3114 None // Stateless policy - no heap allocation
3115 }
3116
3117 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3118 error!(?context.task_id, "Task failed - {error:?}");
3119
3120 if self.error_patterns.is_empty() {
3121 return ErrorResponse::Shutdown;
3122 }
3123
3124 // Check if this error should trigger cancellation
3125 let error_str = error.to_string();
3126 let should_cancel = self
3127 .error_patterns
3128 .iter()
3129 .any(|pattern| error_str.contains(pattern));
3130
3131 if should_cancel {
3132 ErrorResponse::Shutdown
3133 } else {
3134 ErrorResponse::Fail
3135 }
3136 }
3137}
3138
3139/// Simple error policy that only logs errors
3140///
3141/// This policy does not trigger cancellation and is useful for
3142/// non-critical tasks or when you want to handle errors externally.
3143#[derive(Debug)]
3144pub struct LogOnlyPolicy;
3145
3146impl LogOnlyPolicy {
3147 /// Create a new log-only policy returning Arc
3148 pub fn new() -> Arc<Self> {
3149 Arc::new(Self)
3150 }
3151}
3152
3153impl Default for LogOnlyPolicy {
3154 fn default() -> Self {
3155 LogOnlyPolicy
3156 }
3157}
3158
3159impl OnErrorPolicy for LogOnlyPolicy {
3160 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3161 // Simple policies can just clone themselves
3162 Arc::new(LogOnlyPolicy)
3163 }
3164
3165 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3166 None // Stateless policy - no heap allocation
3167 }
3168
3169 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3170 error!(?context.task_id, "Task failed - logging only - {error:?}");
3171 ErrorResponse::Fail
3172 }
3173}
3174
3175/// Error policy that cancels tasks after a threshold number of failures
3176///
3177/// This policy tracks the number of failed tasks and triggers cancellation
3178/// when the failure count exceeds the specified threshold. Useful for
3179/// preventing cascading failures in distributed systems.
3180///
3181/// # Example
3182/// ```rust
3183/// # use dynamo_runtime::utils::tasks::tracker::ThresholdCancelPolicy;
3184/// // Cancel after 5 failures
3185/// let policy = ThresholdCancelPolicy::with_threshold(5);
3186/// ```
3187#[derive(Debug)]
3188pub struct ThresholdCancelPolicy {
3189 max_failures: usize,
3190 failure_count: AtomicU64,
3191}
3192
3193impl ThresholdCancelPolicy {
3194 /// Create a new threshold cancel policy with specified failure threshold, returning Arc and token
3195 ///
3196 /// # Arguments
3197 /// * `max_failures` - Maximum number of failures before cancellation
3198 pub fn with_threshold(max_failures: usize) -> Arc<Self> {
3199 Arc::new(Self {
3200 max_failures,
3201 failure_count: AtomicU64::new(0),
3202 })
3203 }
3204
3205 /// Get the current failure count
3206 pub fn failure_count(&self) -> u64 {
3207 self.failure_count.load(Ordering::Relaxed)
3208 }
3209
3210 /// Reset the failure count to zero
3211 ///
3212 /// This is primarily useful for testing scenarios where you want to reset
3213 /// the policy state between test cases.
3214 pub fn reset_failure_count(&self) {
3215 self.failure_count.store(0, Ordering::Relaxed);
3216 }
3217}
3218
3219/// Per-task state for ThresholdCancelPolicy
3220#[derive(Debug)]
3221struct ThresholdState {
3222 failure_count: u32,
3223}
3224
3225impl OnErrorPolicy for ThresholdCancelPolicy {
3226 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3227 // Child gets a child cancel token and inherits the same failure threshold
3228 Arc::new(ThresholdCancelPolicy {
3229 max_failures: self.max_failures,
3230 failure_count: AtomicU64::new(0), // Child starts with fresh count
3231 })
3232 }
3233
3234 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3235 Some(Box::new(ThresholdState { failure_count: 0 }))
3236 }
3237
3238 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3239 error!(?context.task_id, "Task failed - {error:?}");
3240
3241 // Increment global counter for backwards compatibility
3242 let global_failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
3243
3244 // Get per-task state for the actual decision logic
3245 let state = context
3246 .state
3247 .as_mut()
3248 .expect("ThresholdCancelPolicy requires state")
3249 .downcast_mut::<ThresholdState>()
3250 .expect("Context type mismatch");
3251
3252 state.failure_count += 1;
3253 let current_failures = state.failure_count;
3254
3255 if current_failures >= self.max_failures as u32 {
3256 warn!(
3257 ?context.task_id,
3258 current_failures,
3259 global_failures,
3260 max_failures = self.max_failures,
3261 "Per-task failure threshold exceeded, triggering cancellation"
3262 );
3263 ErrorResponse::Shutdown
3264 } else {
3265 debug!(
3266 ?context.task_id,
3267 current_failures,
3268 global_failures,
3269 max_failures = self.max_failures,
3270 "Task failed, tracking per-task failure count"
3271 );
3272 ErrorResponse::Fail
3273 }
3274 }
3275}
3276
3277/// Error policy that cancels tasks when failure rate exceeds threshold within time window
3278///
3279/// This policy tracks failures over a rolling time window and triggers cancellation
3280/// when the failure rate exceeds the specified threshold. More sophisticated than
3281/// simple count-based thresholds as it considers the time dimension.
3282///
3283/// # Example
3284/// ```rust
3285/// # use dynamo_runtime::utils::tasks::tracker::RateCancelPolicy;
3286/// // Cancel if more than 50% of tasks fail within any 60-second window
3287/// let (policy, token) = RateCancelPolicy::builder()
3288/// .rate(0.5)
3289/// .window_secs(60)
3290/// .build();
3291/// ```
3292#[derive(Debug)]
3293pub struct RateCancelPolicy {
3294 cancel_token: CancellationToken,
3295 max_failure_rate: f32,
3296 window_secs: u64,
3297 // TODO: Implement time-window tracking when needed
3298 // For now, this is a placeholder structure with the interface defined
3299}
3300
3301impl RateCancelPolicy {
3302 /// Create a builder for rate-based cancel policy
3303 pub fn builder() -> RateCancelPolicyBuilder {
3304 RateCancelPolicyBuilder::new()
3305 }
3306}
3307
3308/// Builder for RateCancelPolicy
3309pub struct RateCancelPolicyBuilder {
3310 max_failure_rate: Option<f32>,
3311 window_secs: Option<u64>,
3312}
3313
3314impl RateCancelPolicyBuilder {
3315 fn new() -> Self {
3316 Self {
3317 max_failure_rate: None,
3318 window_secs: None,
3319 }
3320 }
3321
3322 /// Set the maximum failure rate (0.0 to 1.0) before cancellation
3323 pub fn rate(mut self, max_failure_rate: f32) -> Self {
3324 self.max_failure_rate = Some(max_failure_rate);
3325 self
3326 }
3327
3328 /// Set the time window in seconds for rate calculation
3329 pub fn window_secs(mut self, window_secs: u64) -> Self {
3330 self.window_secs = Some(window_secs);
3331 self
3332 }
3333
3334 /// Build the policy, returning Arc and cancellation token
3335 pub fn build(self) -> (Arc<RateCancelPolicy>, CancellationToken) {
3336 let max_failure_rate = self.max_failure_rate.expect("rate must be set");
3337 let window_secs = self.window_secs.expect("window_secs must be set");
3338
3339 let token = CancellationToken::new();
3340 let policy = Arc::new(RateCancelPolicy {
3341 cancel_token: token.clone(),
3342 max_failure_rate,
3343 window_secs,
3344 });
3345 (policy, token)
3346 }
3347}
3348
3349#[async_trait]
3350impl OnErrorPolicy for RateCancelPolicy {
3351 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3352 Arc::new(RateCancelPolicy {
3353 cancel_token: self.cancel_token.child_token(),
3354 max_failure_rate: self.max_failure_rate,
3355 window_secs: self.window_secs,
3356 })
3357 }
3358
3359 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3360 None // Stateless policy for now (TODO: add time-window state)
3361 }
3362
3363 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3364 error!(?context.task_id, "Task failed - {error:?}");
3365
3366 // TODO: Implement time-window failure rate calculation
3367 // For now, just log the error and continue
3368 warn!(
3369 ?context.task_id,
3370 max_failure_rate = self.max_failure_rate,
3371 window_secs = self.window_secs,
3372 "Rate-based error policy - time window tracking not yet implemented"
3373 );
3374
3375 ErrorResponse::Fail
3376 }
3377}
3378
3379/// Custom action that triggers a cancellation token when executed
3380///
3381/// This action demonstrates the ErrorResponse::Custom behavior by capturing
3382/// an external cancellation token and triggering it when executed.
3383#[derive(Debug)]
3384pub struct TriggerCancellationTokenAction {
3385 cancel_token: CancellationToken,
3386}
3387
3388impl TriggerCancellationTokenAction {
3389 pub fn new(cancel_token: CancellationToken) -> Self {
3390 Self { cancel_token }
3391 }
3392}
3393
3394#[async_trait]
3395impl OnErrorAction for TriggerCancellationTokenAction {
3396 async fn execute(
3397 &self,
3398 error: &anyhow::Error,
3399 task_id: TaskId,
3400 _attempt_count: u32,
3401 _context: &TaskExecutionContext,
3402 ) -> ActionResult {
3403 warn!(
3404 ?task_id,
3405 "Executing custom action: triggering cancellation token - {error:?}"
3406 );
3407
3408 // Trigger the custom cancellation token
3409 self.cancel_token.cancel();
3410
3411 // Return success - the action completed successfully
3412 ActionResult::Shutdown
3413 }
3414}
3415
3416/// Test error policy that triggers a custom cancellation token on any error
3417///
3418/// This policy demonstrates the ErrorResponse::Custom behavior by capturing
3419/// an external cancellation token and triggering it when any error occurs.
3420/// Used for testing custom error handling actions.
3421///
3422/// # Example
3423/// ```rust
3424/// # use tokio_util::sync::CancellationToken;
3425/// # use dynamo_runtime::utils::tasks::tracker::TriggerCancellationTokenOnError;
3426/// let cancel_token = CancellationToken::new();
3427/// let policy = TriggerCancellationTokenOnError::new(cancel_token.clone());
3428///
3429/// // Policy will trigger the token on any error via ErrorResponse::Custom
3430/// ```
3431#[derive(Debug)]
3432pub struct TriggerCancellationTokenOnError {
3433 cancel_token: CancellationToken,
3434}
3435
3436impl TriggerCancellationTokenOnError {
3437 /// Create a new policy that triggers the given cancellation token on errors
3438 pub fn new(cancel_token: CancellationToken) -> Arc<Self> {
3439 Arc::new(Self { cancel_token })
3440 }
3441}
3442
3443impl OnErrorPolicy for TriggerCancellationTokenOnError {
3444 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
3445 // Child gets a child cancel token
3446 Arc::new(TriggerCancellationTokenOnError {
3447 cancel_token: self.cancel_token.clone(),
3448 })
3449 }
3450
3451 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
3452 None // Stateless policy - no heap allocation
3453 }
3454
3455 fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
3456 error!(
3457 ?context.task_id,
3458 "Task failed - triggering custom cancellation token - {error:?}"
3459 );
3460
3461 // Create the custom action that will trigger our token
3462 let action = TriggerCancellationTokenAction::new(self.cancel_token.clone());
3463
3464 // Return Custom response with our action
3465 ErrorResponse::Custom(Box::new(action))
3466 }
3467}
3468
3469#[cfg(test)]
3470mod tests {
3471 use super::*;
3472 use rstest::*;
3473 use std::sync::atomic::AtomicU32;
3474 use std::time::Duration;
3475
3476 // Test fixtures using rstest
3477 #[fixture]
3478 fn semaphore_scheduler() -> Arc<SemaphoreScheduler> {
3479 Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))))
3480 }
3481
3482 #[fixture]
3483 fn unlimited_scheduler() -> Arc<UnlimitedScheduler> {
3484 UnlimitedScheduler::new()
3485 }
3486
3487 #[fixture]
3488 fn log_policy() -> Arc<LogOnlyPolicy> {
3489 LogOnlyPolicy::new()
3490 }
3491
3492 #[fixture]
3493 fn cancel_policy() -> Arc<CancelOnError> {
3494 CancelOnError::new()
3495 }
3496
3497 #[fixture]
3498 fn basic_tracker(
3499 unlimited_scheduler: Arc<UnlimitedScheduler>,
3500 log_policy: Arc<LogOnlyPolicy>,
3501 ) -> TaskTracker {
3502 TaskTracker::new(unlimited_scheduler, log_policy).unwrap()
3503 }
3504
3505 #[rstest]
3506 #[tokio::test]
3507 async fn test_basic_task_execution(basic_tracker: TaskTracker) {
3508 // Test successful task execution
3509 let (tx, rx) = tokio::sync::oneshot::channel();
3510 let handle = basic_tracker.spawn(async {
3511 // Wait for signal to complete instead of sleep
3512 rx.await.ok();
3513 Ok(42)
3514 });
3515
3516 // Signal task to complete
3517 tx.send(()).ok();
3518
3519 // Verify task completes successfully
3520 let result = handle
3521 .await
3522 .expect("Task should complete")
3523 .expect("Task should succeed");
3524 assert_eq!(result, 42);
3525
3526 // Verify metrics
3527 assert_eq!(basic_tracker.metrics().success(), 1);
3528 assert_eq!(basic_tracker.metrics().failed(), 0);
3529 assert_eq!(basic_tracker.metrics().cancelled(), 0);
3530 assert_eq!(basic_tracker.metrics().active(), 0);
3531 }
3532
3533 #[rstest]
3534 #[tokio::test]
3535 async fn test_task_failure(
3536 semaphore_scheduler: Arc<SemaphoreScheduler>,
3537 log_policy: Arc<LogOnlyPolicy>,
3538 ) {
3539 // Test task failure handling
3540 let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3541
3542 let handle = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
3543
3544 let result = handle.await.unwrap();
3545 assert!(result.is_err());
3546 assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
3547
3548 // Verify metrics
3549 assert_eq!(tracker.metrics().success(), 0);
3550 assert_eq!(tracker.metrics().failed(), 1);
3551 assert_eq!(tracker.metrics().cancelled(), 0);
3552 }
3553
3554 #[rstest]
3555 #[tokio::test]
3556 async fn test_semaphore_concurrency_limit(log_policy: Arc<LogOnlyPolicy>) {
3557 // Test that semaphore limits concurrent execution
3558 let limited_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2)))); // Only 2 concurrent tasks
3559 let tracker = TaskTracker::new(limited_scheduler, log_policy).unwrap();
3560
3561 let counter = Arc::new(AtomicU32::new(0));
3562 let max_concurrent = Arc::new(AtomicU32::new(0));
3563
3564 // Use broadcast channel to coordinate all tasks
3565 let (tx, _) = tokio::sync::broadcast::channel(1);
3566 let mut handles = Vec::new();
3567
3568 // Spawn 5 tasks that will track concurrency
3569 for _ in 0..5 {
3570 let counter_clone = counter.clone();
3571 let max_clone = max_concurrent.clone();
3572 let mut rx = tx.subscribe();
3573
3574 let handle = tracker.spawn(async move {
3575 // Increment active counter
3576 let current = counter_clone.fetch_add(1, Ordering::Relaxed) + 1;
3577
3578 // Track max concurrent
3579 max_clone.fetch_max(current, Ordering::Relaxed);
3580
3581 // Wait for signal to complete instead of sleep
3582 rx.recv().await.ok();
3583
3584 // Decrement when done
3585 counter_clone.fetch_sub(1, Ordering::Relaxed);
3586
3587 Ok(())
3588 });
3589 handles.push(handle);
3590 }
3591
3592 // Give tasks time to start and register concurrency
3593 tokio::task::yield_now().await;
3594 tokio::task::yield_now().await;
3595
3596 // Signal all tasks to complete
3597 tx.send(()).ok();
3598
3599 // Wait for all tasks to complete
3600 for handle in handles {
3601 handle.await.unwrap().unwrap();
3602 }
3603
3604 // Verify that no more than 2 tasks ran concurrently
3605 assert!(max_concurrent.load(Ordering::Relaxed) <= 2);
3606
3607 // Verify all tasks completed successfully
3608 assert_eq!(tracker.metrics().success(), 5);
3609 assert_eq!(tracker.metrics().failed(), 0);
3610 }
3611
3612 #[rstest]
3613 #[tokio::test]
3614 async fn test_cancel_on_error_policy() {
3615 // Test that CancelOnError policy works correctly
3616 let error_policy = cancel_policy();
3617 let scheduler = semaphore_scheduler();
3618 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
3619
3620 // Spawn a task that will trigger cancellation
3621 let handle =
3622 tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("OutOfMemory error occurred")) });
3623
3624 // Wait for the error to occur
3625 let result = handle.await.unwrap();
3626 assert!(result.is_err());
3627
3628 // Give cancellation time to propagate
3629 tokio::time::sleep(Duration::from_millis(10)).await;
3630
3631 // Verify the cancel token was triggered
3632 assert!(tracker.cancellation_token().is_cancelled());
3633 }
3634
3635 #[rstest]
3636 #[tokio::test]
3637 async fn test_tracker_cancellation() {
3638 // Test manual cancellation of tracker with CancelOnError policy
3639 let error_policy = cancel_policy();
3640 let scheduler = semaphore_scheduler();
3641 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
3642 let cancel_token = tracker.cancellation_token().child_token();
3643
3644 // Use oneshot channel instead of sleep for deterministic timing
3645 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
3646
3647 // Spawn a task that respects cancellation
3648 let handle = tracker.spawn({
3649 let cancel_token = cancel_token.clone();
3650 async move {
3651 tokio::select! {
3652 _ = rx => Ok(()),
3653 _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Task was cancelled")),
3654 }
3655 }
3656 });
3657
3658 // Cancel the tracker
3659 tracker.cancel();
3660
3661 // Task should be cancelled
3662 let result = handle.await.unwrap();
3663 assert!(result.is_err());
3664 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
3665 }
3666
3667 #[rstest]
3668 #[tokio::test]
3669 async fn test_child_tracker_independence(
3670 semaphore_scheduler: Arc<SemaphoreScheduler>,
3671 log_policy: Arc<LogOnlyPolicy>,
3672 ) {
3673 // Test that child tracker has independent lifecycle
3674 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3675
3676 let child = parent.child_tracker().unwrap();
3677
3678 // Both should be operational initially
3679 assert!(!parent.is_closed());
3680 assert!(!child.is_closed());
3681
3682 // Cancel child only
3683 child.cancel();
3684
3685 // Parent should remain operational
3686 assert!(!parent.is_closed());
3687
3688 // Parent can still spawn tasks
3689 let handle = parent.spawn(async { Ok(42) });
3690 let result = handle.await.unwrap().unwrap();
3691 assert_eq!(result, 42);
3692 }
3693
3694 #[rstest]
3695 #[tokio::test]
3696 async fn test_independent_metrics(
3697 semaphore_scheduler: Arc<SemaphoreScheduler>,
3698 log_policy: Arc<LogOnlyPolicy>,
3699 ) {
3700 // Test that parent and child have independent metrics
3701 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3702 let child = parent.child_tracker().unwrap();
3703
3704 // Run tasks in parent
3705 let handle1 = parent.spawn(async { Ok(1) });
3706 handle1.await.unwrap().unwrap();
3707
3708 // Run tasks in child
3709 let handle2 = child.spawn(async { Ok(2) });
3710 handle2.await.unwrap().unwrap();
3711
3712 // Each should have their own metrics, but parent sees aggregated
3713 assert_eq!(parent.metrics().success(), 2); // Parent sees its own + child's
3714 assert_eq!(child.metrics().success(), 1); // Child sees only its own
3715 assert_eq!(parent.metrics().total_completed(), 2); // Parent sees aggregated total
3716 assert_eq!(child.metrics().total_completed(), 1); // Child sees only its own
3717 }
3718
3719 #[rstest]
3720 #[tokio::test]
3721 async fn test_cancel_on_error_hierarchy() {
3722 // Test that child error policy cancellation doesn't affect parent
3723 let parent_error_policy = cancel_policy();
3724 let scheduler = semaphore_scheduler();
3725 let parent = TaskTracker::new(scheduler, parent_error_policy).unwrap();
3726 let parent_policy_token = parent.cancellation_token().child_token();
3727 let child = parent.child_tracker().unwrap();
3728
3729 // Initially nothing should be cancelled
3730 assert!(!parent_policy_token.is_cancelled());
3731
3732 // Use explicit synchronization instead of sleep
3733 let (error_tx, error_rx) = tokio::sync::oneshot::channel();
3734 let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
3735
3736 // Spawn a monitoring task to watch for the parent policy token cancellation
3737 let parent_token_monitor = parent_policy_token.clone();
3738 let monitor_handle = tokio::spawn(async move {
3739 tokio::select! {
3740 _ = parent_token_monitor.cancelled() => {
3741 cancel_tx.send(true).ok();
3742 }
3743 _ = tokio::time::sleep(Duration::from_millis(100)) => {
3744 cancel_tx.send(false).ok();
3745 }
3746 }
3747 });
3748
3749 // Spawn a task in the child that will trigger cancellation
3750 let handle = child.spawn(async move {
3751 let result = Err::<(), _>(anyhow::anyhow!("OutOfMemory in child"));
3752 error_tx.send(()).ok(); // Signal that the error has occurred
3753 result
3754 });
3755
3756 // Wait for the error to occur
3757 let error_result = handle.await.unwrap();
3758 assert!(error_result.is_err());
3759
3760 // Wait for our error signal
3761 error_rx.await.ok();
3762
3763 // Check if parent policy token was cancelled within timeout
3764 let was_cancelled = cancel_rx.await.unwrap_or(false);
3765 monitor_handle.await.ok();
3766
3767 // Based on hierarchical design: child errors should NOT affect parent
3768 // The child gets its own policy with a child token, and child cancellation
3769 // should not propagate up to the parent policy token
3770 assert!(
3771 !was_cancelled,
3772 "Parent policy token should not be cancelled by child errors"
3773 );
3774 assert!(
3775 !parent_policy_token.is_cancelled(),
3776 "Parent policy token should remain active"
3777 );
3778 }
3779
3780 #[rstest]
3781 #[tokio::test]
3782 async fn test_graceful_shutdown(
3783 semaphore_scheduler: Arc<SemaphoreScheduler>,
3784 log_policy: Arc<LogOnlyPolicy>,
3785 ) {
3786 // Test graceful shutdown with close()
3787 let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
3788
3789 // Use broadcast channel to coordinate task completion
3790 let (tx, _) = tokio::sync::broadcast::channel(1);
3791 let mut handles = Vec::new();
3792
3793 // Spawn some tasks
3794 for i in 0..3 {
3795 let mut rx = tx.subscribe();
3796 let handle = tracker.spawn(async move {
3797 // Wait for signal instead of sleep
3798 rx.recv().await.ok();
3799 Ok(i)
3800 });
3801 handles.push(handle);
3802 }
3803
3804 // Signal all tasks to complete before closing
3805 tx.send(()).ok();
3806
3807 // Close tracker and wait for completion
3808 tracker.join().await;
3809
3810 // All tasks should complete successfully
3811 for handle in handles {
3812 let result = handle.await.unwrap().unwrap();
3813 assert!(result < 3);
3814 }
3815
3816 // Tracker should be closed
3817 assert!(tracker.is_closed());
3818 }
3819
3820 #[rstest]
3821 #[tokio::test]
3822 async fn test_semaphore_scheduler_permit_tracking(log_policy: Arc<LogOnlyPolicy>) {
3823 // Test that SemaphoreScheduler properly tracks permits
3824 let semaphore = Arc::new(Semaphore::new(3));
3825 let scheduler = Arc::new(SemaphoreScheduler::new(semaphore.clone()));
3826 let tracker = TaskTracker::new(scheduler.clone(), log_policy).unwrap();
3827
3828 // Initially all permits should be available
3829 assert_eq!(scheduler.available_permits(), 3);
3830
3831 // Use broadcast channel to coordinate task completion
3832 let (tx, _) = tokio::sync::broadcast::channel(1);
3833 let mut handles = Vec::new();
3834
3835 // Spawn 3 tasks that will hold permits
3836 for _ in 0..3 {
3837 let mut rx = tx.subscribe();
3838 let handle = tracker.spawn(async move {
3839 // Wait for signal to complete
3840 rx.recv().await.ok();
3841 Ok(())
3842 });
3843 handles.push(handle);
3844 }
3845
3846 // Give tasks time to acquire permits
3847 tokio::task::yield_now().await;
3848 tokio::task::yield_now().await;
3849
3850 // All permits should be taken
3851 assert_eq!(scheduler.available_permits(), 0);
3852
3853 // Signal all tasks to complete
3854 tx.send(()).ok();
3855
3856 // Wait for tasks to complete
3857 for handle in handles {
3858 handle.await.unwrap().unwrap();
3859 }
3860
3861 // All permits should be available again
3862 assert_eq!(scheduler.available_permits(), 3);
3863 }
3864
3865 #[rstest]
3866 #[tokio::test]
3867 async fn test_builder_pattern(log_policy: Arc<LogOnlyPolicy>) {
3868 // Test that TaskTracker builder works correctly
3869 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3870 let error_policy = log_policy;
3871
3872 let tracker = TaskTracker::builder()
3873 .scheduler(scheduler)
3874 .error_policy(error_policy)
3875 .build()
3876 .unwrap();
3877
3878 // Tracker should have a cancellation token
3879 let token = tracker.cancellation_token();
3880 assert!(!token.is_cancelled());
3881
3882 // Should be able to spawn tasks
3883 let handle = tracker.spawn(async { Ok(42) });
3884 let result = handle.await.unwrap().unwrap();
3885 assert_eq!(result, 42);
3886 }
3887
3888 #[rstest]
3889 #[tokio::test]
3890 async fn test_all_trackers_have_cancellation_tokens(log_policy: Arc<LogOnlyPolicy>) {
3891 // Test that all trackers (root and children) have cancellation tokens
3892 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3893 let root = TaskTracker::new(scheduler, log_policy).unwrap();
3894 let child = root.child_tracker().unwrap();
3895 let grandchild = child.child_tracker().unwrap();
3896
3897 // All should have cancellation tokens
3898 let root_token = root.cancellation_token();
3899 let child_token = child.cancellation_token();
3900 let grandchild_token = grandchild.cancellation_token();
3901
3902 assert!(!root_token.is_cancelled());
3903 assert!(!child_token.is_cancelled());
3904 assert!(!grandchild_token.is_cancelled());
3905
3906 // Child tokens should be different from parent
3907 // (We can't directly compare tokens, but we can test behavior)
3908 root_token.cancel();
3909
3910 // Give cancellation time to propagate
3911 tokio::time::sleep(Duration::from_millis(10)).await;
3912
3913 // Root should be cancelled
3914 assert!(root_token.is_cancelled());
3915 // Children should also be cancelled (because they are child tokens)
3916 assert!(child_token.is_cancelled());
3917 assert!(grandchild_token.is_cancelled());
3918 }
3919
3920 #[rstest]
3921 #[tokio::test]
3922 async fn test_spawn_cancellable_task(log_policy: Arc<LogOnlyPolicy>) {
3923 // Test cancellable task spawning with proper result handling
3924 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3925 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
3926
3927 // Test successful completion
3928 let (tx, rx) = tokio::sync::oneshot::channel();
3929 let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
3930 let handle = tracker.spawn_cancellable(move |_cancel_token| {
3931 let rx = rx.clone();
3932 async move {
3933 // Wait for signal instead of sleep
3934 if let Some(rx) = rx.lock().await.take() {
3935 rx.await.ok();
3936 }
3937 CancellableTaskResult::Ok(42)
3938 }
3939 });
3940
3941 // Signal task to complete
3942 tx.send(()).ok();
3943
3944 let result = handle.await.unwrap().unwrap();
3945 assert_eq!(result, 42);
3946 assert_eq!(tracker.metrics().success(), 1);
3947
3948 // Test cancellation handling
3949 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
3950 let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
3951 let handle = tracker.spawn_cancellable(move |cancel_token| {
3952 let rx = rx.clone();
3953 async move {
3954 tokio::select! {
3955 _ = async {
3956 if let Some(rx) = rx.lock().await.take() {
3957 rx.await.ok();
3958 }
3959 } => CancellableTaskResult::Ok("should not complete"),
3960 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
3961 }
3962 }
3963 });
3964
3965 // Cancel the tracker
3966 tracker.cancel();
3967
3968 let result = handle.await.unwrap();
3969 assert!(result.is_err());
3970 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
3971 }
3972
3973 #[rstest]
3974 #[tokio::test]
3975 async fn test_cancellable_task_metrics_tracking(log_policy: Arc<LogOnlyPolicy>) {
3976 // Test that properly cancelled tasks increment cancelled metrics, not failed metrics
3977 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
3978 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
3979
3980 // Baseline metrics
3981 assert_eq!(tracker.metrics().cancelled(), 0);
3982 assert_eq!(tracker.metrics().failed(), 0);
3983 assert_eq!(tracker.metrics().success(), 0);
3984
3985 // Test 1: Task that executes and THEN gets cancelled during execution
3986 let (start_tx, start_rx) = tokio::sync::oneshot::channel::<()>();
3987 let (_continue_tx, continue_rx) = tokio::sync::oneshot::channel::<()>();
3988
3989 let start_tx_shared = Arc::new(tokio::sync::Mutex::new(Some(start_tx)));
3990 let continue_rx_shared = Arc::new(tokio::sync::Mutex::new(Some(continue_rx)));
3991
3992 let start_tx_for_task = start_tx_shared.clone();
3993 let continue_rx_for_task = continue_rx_shared.clone();
3994
3995 let handle = tracker.spawn_cancellable(move |cancel_token| {
3996 let start_tx = start_tx_for_task.clone();
3997 let continue_rx = continue_rx_for_task.clone();
3998 async move {
3999 // Signal that we've started executing
4000 if let Some(tx) = start_tx.lock().await.take() {
4001 tx.send(()).ok();
4002 }
4003
4004 // Wait for either continuation signal or cancellation
4005 tokio::select! {
4006 _ = async {
4007 if let Some(rx) = continue_rx.lock().await.take() {
4008 rx.await.ok();
4009 }
4010 } => CancellableTaskResult::Ok("completed normally"),
4011 _ = cancel_token.cancelled() => {
4012 println!("Task detected cancellation and is returning Cancelled");
4013 CancellableTaskResult::Cancelled
4014 },
4015 }
4016 }
4017 });
4018
4019 // Wait for task to start executing
4020 start_rx.await.ok();
4021
4022 // Now cancel while the task is running
4023 println!("Cancelling tracker while task is executing...");
4024 tracker.cancel();
4025
4026 // Wait for the task to complete
4027 let result = handle.await.unwrap();
4028
4029 // Debug output
4030 println!("Task result: {:?}", result);
4031 println!(
4032 "Cancelled: {}, Failed: {}, Success: {}",
4033 tracker.metrics().cancelled(),
4034 tracker.metrics().failed(),
4035 tracker.metrics().success()
4036 );
4037
4038 // The task should be properly cancelled and counted correctly
4039 assert!(result.is_err());
4040 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
4041
4042 // Verify proper metrics: should be counted as cancelled, not failed
4043 assert_eq!(
4044 tracker.metrics().cancelled(),
4045 1,
4046 "Properly cancelled task should increment cancelled count"
4047 );
4048 assert_eq!(
4049 tracker.metrics().failed(),
4050 0,
4051 "Properly cancelled task should NOT increment failed count"
4052 );
4053 }
4054
4055 #[rstest]
4056 #[tokio::test]
4057 async fn test_cancellable_vs_error_metrics_distinction(log_policy: Arc<LogOnlyPolicy>) {
4058 // Test that we properly distinguish between cancellation and actual errors
4059 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
4060 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4061
4062 // Test 1: Actual error should increment failed count
4063 let handle1 = tracker.spawn_cancellable(|_cancel_token| async move {
4064 CancellableTaskResult::<i32>::Err(anyhow::anyhow!("This is a real error"))
4065 });
4066
4067 let result1 = handle1.await.unwrap();
4068 assert!(result1.is_err());
4069 assert!(matches!(result1.unwrap_err(), TaskError::Failed(_)));
4070 assert_eq!(tracker.metrics().failed(), 1);
4071 assert_eq!(tracker.metrics().cancelled(), 0);
4072
4073 // Test 2: Cancellation should increment cancelled count
4074 let handle2 = tracker.spawn_cancellable(|_cancel_token| async move {
4075 CancellableTaskResult::<i32>::Cancelled
4076 });
4077
4078 let result2 = handle2.await.unwrap();
4079 assert!(result2.is_err());
4080 assert!(matches!(result2.unwrap_err(), TaskError::Cancelled));
4081 assert_eq!(tracker.metrics().failed(), 1); // Still 1 from before
4082 assert_eq!(tracker.metrics().cancelled(), 1); // Now 1 from cancellation
4083 }
4084
4085 #[rstest]
4086 #[tokio::test]
4087 async fn test_spawn_cancellable_error_handling(log_policy: Arc<LogOnlyPolicy>) {
4088 // Test error handling in cancellable tasks
4089 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
4090 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4091
4092 // Test error result
4093 let handle = tracker.spawn_cancellable(|_cancel_token| async move {
4094 CancellableTaskResult::<i32>::Err(anyhow::anyhow!("test error"))
4095 });
4096
4097 let result = handle.await.unwrap();
4098 assert!(result.is_err());
4099 assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
4100 assert_eq!(tracker.metrics().failed(), 1);
4101 }
4102
4103 #[rstest]
4104 #[tokio::test]
4105 async fn test_cancellation_before_execution(log_policy: Arc<LogOnlyPolicy>) {
4106 // Test that spawning on a cancelled tracker panics (new behavior)
4107 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
4108 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4109
4110 // Cancel the tracker first
4111 tracker.cancel();
4112
4113 // Give cancellation time to propagate to the inner tracker
4114 tokio::time::sleep(Duration::from_millis(5)).await;
4115
4116 // Now try to spawn a task - it should panic since tracker is closed
4117 let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
4118 tracker.spawn(async { Ok(42) })
4119 }));
4120
4121 // Should panic with our new API
4122 assert!(
4123 panic_result.is_err(),
4124 "spawn() should panic when tracker is closed"
4125 );
4126
4127 // Verify the panic message contains expected text
4128 if let Err(panic_payload) = panic_result {
4129 if let Some(panic_msg) = panic_payload.downcast_ref::<String>() {
4130 assert!(
4131 panic_msg.contains("TaskTracker must not be closed"),
4132 "Panic message should indicate tracker is closed: {}",
4133 panic_msg
4134 );
4135 } else if let Some(panic_msg) = panic_payload.downcast_ref::<&str>() {
4136 assert!(
4137 panic_msg.contains("TaskTracker must not be closed"),
4138 "Panic message should indicate tracker is closed: {}",
4139 panic_msg
4140 );
4141 }
4142 }
4143 }
4144
4145 #[rstest]
4146 #[tokio::test]
4147 async fn test_semaphore_scheduler_with_cancellation(log_policy: Arc<LogOnlyPolicy>) {
4148 // Test that SemaphoreScheduler respects cancellation tokens
4149 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
4150 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4151
4152 // Start a long-running task to occupy the semaphore
4153 let blocker_token = tracker.cancellation_token();
4154 let _blocker_handle = tracker.spawn(async move {
4155 // Wait for cancellation
4156 blocker_token.cancelled().await;
4157 Ok(())
4158 });
4159
4160 // Give the blocker time to acquire the permit
4161 tokio::task::yield_now().await;
4162
4163 // Use oneshot channel for the second task
4164 let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
4165
4166 // Spawn another task that will wait for semaphore
4167 let handle = tracker.spawn(async {
4168 rx.await.ok();
4169 Ok(42)
4170 });
4171
4172 // Cancel the tracker while second task is waiting for permit
4173 tracker.cancel();
4174
4175 // The waiting task should be cancelled
4176 let result = handle.await.unwrap();
4177 assert!(result.is_err());
4178 assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
4179 }
4180
4181 #[rstest]
4182 #[tokio::test]
4183 async fn test_child_tracker_cancellation_independence(
4184 semaphore_scheduler: Arc<SemaphoreScheduler>,
4185 log_policy: Arc<LogOnlyPolicy>,
4186 ) {
4187 // Test that child tracker cancellation doesn't affect parent
4188 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4189 let child = parent.child_tracker().unwrap();
4190
4191 // Cancel only the child
4192 child.cancel();
4193
4194 // Parent should still be operational
4195 let parent_token = parent.cancellation_token();
4196 assert!(!parent_token.is_cancelled());
4197
4198 // Parent can still spawn tasks
4199 let handle = parent.spawn(async { Ok(42) });
4200 let result = handle.await.unwrap().unwrap();
4201 assert_eq!(result, 42);
4202
4203 // Child should be cancelled
4204 let child_token = child.cancellation_token();
4205 assert!(child_token.is_cancelled());
4206 }
4207
4208 #[rstest]
4209 #[tokio::test]
4210 async fn test_parent_cancellation_propagates_to_children(
4211 semaphore_scheduler: Arc<SemaphoreScheduler>,
4212 log_policy: Arc<LogOnlyPolicy>,
4213 ) {
4214 // Test that parent cancellation propagates to all children
4215 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4216 let child1 = parent.child_tracker().unwrap();
4217 let child2 = parent.child_tracker().unwrap();
4218 let grandchild = child1.child_tracker().unwrap();
4219
4220 // Cancel the parent
4221 parent.cancel();
4222
4223 // Give cancellation time to propagate
4224 tokio::time::sleep(Duration::from_millis(10)).await;
4225
4226 // All should be cancelled
4227 assert!(parent.cancellation_token().is_cancelled());
4228 assert!(child1.cancellation_token().is_cancelled());
4229 assert!(child2.cancellation_token().is_cancelled());
4230 assert!(grandchild.cancellation_token().is_cancelled());
4231 }
4232
4233 #[rstest]
4234 #[tokio::test]
4235 async fn test_issued_counter_tracking(log_policy: Arc<LogOnlyPolicy>) {
4236 // Test that issued counter is incremented when tasks are spawned
4237 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2))));
4238 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4239
4240 // Initially no tasks issued
4241 assert_eq!(tracker.metrics().issued(), 0);
4242 assert_eq!(tracker.metrics().pending(), 0);
4243
4244 // Spawn some tasks
4245 let handle1 = tracker.spawn(async { Ok(1) });
4246 let handle2 = tracker.spawn(async { Ok(2) });
4247 let handle3 = tracker.spawn_cancellable(|_| async { CancellableTaskResult::Ok(3) });
4248
4249 // Issued counter should be incremented immediately
4250 assert_eq!(tracker.metrics().issued(), 3);
4251 assert_eq!(tracker.metrics().pending(), 3); // None completed yet
4252
4253 // Complete the tasks
4254 assert_eq!(handle1.await.unwrap().unwrap(), 1);
4255 assert_eq!(handle2.await.unwrap().unwrap(), 2);
4256 assert_eq!(handle3.await.unwrap().unwrap(), 3);
4257
4258 // Check final accounting
4259 assert_eq!(tracker.metrics().issued(), 3);
4260 assert_eq!(tracker.metrics().success(), 3);
4261 assert_eq!(tracker.metrics().total_completed(), 3);
4262 assert_eq!(tracker.metrics().pending(), 0); // All completed
4263
4264 // Test hierarchical accounting
4265 let child = tracker.child_tracker().unwrap();
4266 let child_handle = child.spawn(async { Ok(42) });
4267
4268 // Both parent and child should see the issued task
4269 assert_eq!(child.metrics().issued(), 1);
4270 assert_eq!(tracker.metrics().issued(), 4); // Parent sees all
4271
4272 child_handle.await.unwrap().unwrap();
4273
4274 // Final hierarchical check
4275 assert_eq!(child.metrics().pending(), 0);
4276 assert_eq!(tracker.metrics().pending(), 0);
4277 assert_eq!(tracker.metrics().success(), 4); // Parent sees all successes
4278 }
4279
4280 #[rstest]
4281 #[tokio::test]
4282 async fn test_child_tracker_builder(log_policy: Arc<LogOnlyPolicy>) {
4283 // Test that child tracker builder allows custom policies
4284 let parent_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4285 let parent = TaskTracker::new(parent_scheduler, log_policy).unwrap();
4286
4287 // Create child with custom error policy
4288 let child_error_policy = CancelOnError::new();
4289 let child = parent
4290 .child_tracker_builder()
4291 .error_policy(child_error_policy)
4292 .build()
4293 .unwrap();
4294
4295 // Test that child works
4296 let handle = child.spawn(async { Ok(42) });
4297 let result = handle.await.unwrap().unwrap();
4298 assert_eq!(result, 42);
4299
4300 // Child should have its own metrics
4301 assert_eq!(child.metrics().success(), 1);
4302 assert_eq!(parent.metrics().total_completed(), 1); // Parent sees aggregated
4303 }
4304
4305 #[rstest]
4306 #[tokio::test]
4307 async fn test_hierarchical_metrics_aggregation(log_policy: Arc<LogOnlyPolicy>) {
4308 // Test that child metrics aggregate up to parent
4309 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4310 let parent = TaskTracker::new(scheduler, log_policy.clone()).unwrap();
4311
4312 // Create child1 with default settings
4313 let child1 = parent.child_tracker().unwrap();
4314
4315 // Create child2 with custom error policy
4316 let child_error_policy = CancelOnError::new();
4317 let child2 = parent
4318 .child_tracker_builder()
4319 .error_policy(child_error_policy)
4320 .build()
4321 .unwrap();
4322
4323 // Test both custom schedulers and policies
4324 let another_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(3))));
4325 let another_error_policy = CancelOnError::new();
4326 let child3 = parent
4327 .child_tracker_builder()
4328 .scheduler(another_scheduler)
4329 .error_policy(another_error_policy)
4330 .build()
4331 .unwrap();
4332
4333 // Test that all children are properly registered
4334 assert_eq!(parent.child_count(), 3);
4335
4336 // Test that custom schedulers work
4337 let handle1 = child1.spawn(async { Ok(1) });
4338 let handle2 = child2.spawn(async { Ok(2) });
4339 let handle3 = child3.spawn(async { Ok(3) });
4340
4341 assert_eq!(handle1.await.unwrap().unwrap(), 1);
4342 assert_eq!(handle2.await.unwrap().unwrap(), 2);
4343 assert_eq!(handle3.await.unwrap().unwrap(), 3);
4344
4345 // Verify metrics still work
4346 assert_eq!(parent.metrics().success(), 3); // All child successes roll up
4347 assert_eq!(child1.metrics().success(), 1);
4348 assert_eq!(child2.metrics().success(), 1);
4349 assert_eq!(child3.metrics().success(), 1);
4350 }
4351
4352 #[rstest]
4353 #[tokio::test]
4354 async fn test_scheduler_queue_depth_calculation(log_policy: Arc<LogOnlyPolicy>) {
4355 // Test that we can calculate tasks queued in scheduler
4356 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2)))); // Only 2 concurrent tasks
4357 let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
4358
4359 // Initially no tasks
4360 assert_eq!(tracker.metrics().issued(), 0);
4361 assert_eq!(tracker.metrics().active(), 0);
4362 assert_eq!(tracker.metrics().queued(), 0);
4363 assert_eq!(tracker.metrics().pending(), 0);
4364
4365 // Use a channel to control when tasks complete
4366 let (complete_tx, _complete_rx) = tokio::sync::broadcast::channel(1);
4367
4368 // Spawn 2 tasks that will hold semaphore permits
4369 let handle1 = tracker.spawn({
4370 let mut rx = complete_tx.subscribe();
4371 async move {
4372 // Wait for completion signal
4373 rx.recv().await.ok();
4374 Ok(1)
4375 }
4376 });
4377 let handle2 = tracker.spawn({
4378 let mut rx = complete_tx.subscribe();
4379 async move {
4380 // Wait for completion signal
4381 rx.recv().await.ok();
4382 Ok(2)
4383 }
4384 });
4385
4386 // Give tasks time to start and acquire permits
4387 tokio::task::yield_now().await;
4388 tokio::task::yield_now().await;
4389
4390 // Should have 2 active tasks, 0 queued
4391 assert_eq!(tracker.metrics().issued(), 2);
4392 assert_eq!(tracker.metrics().active(), 2);
4393 assert_eq!(tracker.metrics().queued(), 0);
4394 assert_eq!(tracker.metrics().pending(), 2);
4395
4396 // Spawn a third task - should be queued since semaphore is full
4397 let handle3 = tracker.spawn(async move { Ok(3) });
4398
4399 // Give time for task to be queued
4400 tokio::task::yield_now().await;
4401
4402 // Should have 2 active, 1 queued
4403 assert_eq!(tracker.metrics().issued(), 3);
4404 assert_eq!(tracker.metrics().active(), 2);
4405 assert_eq!(
4406 tracker.metrics().queued(),
4407 tracker.metrics().pending() - tracker.metrics().active()
4408 );
4409 assert_eq!(tracker.metrics().pending(), 3);
4410
4411 // Complete all tasks by sending the signal
4412 complete_tx.send(()).ok();
4413
4414 let result1 = handle1.await.unwrap().unwrap();
4415 let result2 = handle2.await.unwrap().unwrap();
4416 let result3 = handle3.await.unwrap().unwrap();
4417
4418 assert_eq!(result1, 1);
4419 assert_eq!(result2, 2);
4420 assert_eq!(result3, 3);
4421
4422 // All tasks should be completed
4423 assert_eq!(tracker.metrics().success(), 3);
4424 assert_eq!(tracker.metrics().active(), 0);
4425 assert_eq!(tracker.metrics().queued(), 0);
4426 assert_eq!(tracker.metrics().pending(), 0);
4427 }
4428
4429 #[rstest]
4430 #[tokio::test]
4431 async fn test_hierarchical_metrics_failure_aggregation(
4432 semaphore_scheduler: Arc<SemaphoreScheduler>,
4433 log_policy: Arc<LogOnlyPolicy>,
4434 ) {
4435 // Test that failed task metrics aggregate up to parent
4436 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4437 let child = parent.child_tracker().unwrap();
4438
4439 // Run some successful and some failed tasks
4440 let success_handle = child.spawn(async { Ok(42) });
4441 let failure_handle = child.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
4442
4443 // Wait for tasks to complete
4444 let _success_result = success_handle.await.unwrap().unwrap();
4445 let _failure_result = failure_handle.await.unwrap().unwrap_err();
4446
4447 // Check child metrics
4448 assert_eq!(child.metrics().success(), 1, "Child should have 1 success");
4449 assert_eq!(child.metrics().failed(), 1, "Child should have 1 failure");
4450
4451 // Parent should see the aggregated metrics
4452 // Note: Due to hierarchical aggregation, these metrics propagate up
4453 }
4454
4455 #[rstest]
4456 #[tokio::test]
4457 async fn test_metrics_independence_between_tracker_instances(
4458 semaphore_scheduler: Arc<SemaphoreScheduler>,
4459 log_policy: Arc<LogOnlyPolicy>,
4460 ) {
4461 // Test that different tracker instances have independent metrics
4462 let tracker1 = TaskTracker::new(semaphore_scheduler.clone(), log_policy.clone()).unwrap();
4463 let tracker2 = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4464
4465 // Run tasks in both trackers
4466 let handle1 = tracker1.spawn(async { Ok(1) });
4467 let handle2 = tracker2.spawn(async { Ok(2) });
4468
4469 handle1.await.unwrap().unwrap();
4470 handle2.await.unwrap().unwrap();
4471
4472 // Each tracker should only see its own metrics
4473 assert_eq!(tracker1.metrics().success(), 1);
4474 assert_eq!(tracker2.metrics().success(), 1);
4475 assert_eq!(tracker1.metrics().total_completed(), 1);
4476 assert_eq!(tracker2.metrics().total_completed(), 1);
4477 }
4478
4479 #[rstest]
4480 #[tokio::test]
4481 async fn test_hierarchical_join_waits_for_all(log_policy: Arc<LogOnlyPolicy>) {
4482 // Test that parent.join() waits for child tasks too
4483 let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
4484 let parent = TaskTracker::new(scheduler, log_policy).unwrap();
4485 let child1 = parent.child_tracker().unwrap();
4486 let child2 = parent.child_tracker().unwrap();
4487 let grandchild = child1.child_tracker().unwrap();
4488
4489 // Verify parent tracks children
4490 assert_eq!(parent.child_count(), 2);
4491 assert_eq!(child1.child_count(), 1);
4492 assert_eq!(child2.child_count(), 0);
4493 assert_eq!(grandchild.child_count(), 0);
4494
4495 // Track completion order
4496 let completion_order = Arc::new(Mutex::new(Vec::new()));
4497
4498 // Spawn tasks with different durations
4499 let order_clone = completion_order.clone();
4500 let parent_handle = parent.spawn(async move {
4501 tokio::time::sleep(Duration::from_millis(50)).await;
4502 order_clone.lock().unwrap().push("parent");
4503 Ok(())
4504 });
4505
4506 let order_clone = completion_order.clone();
4507 let child1_handle = child1.spawn(async move {
4508 tokio::time::sleep(Duration::from_millis(100)).await;
4509 order_clone.lock().unwrap().push("child1");
4510 Ok(())
4511 });
4512
4513 let order_clone = completion_order.clone();
4514 let child2_handle = child2.spawn(async move {
4515 tokio::time::sleep(Duration::from_millis(75)).await;
4516 order_clone.lock().unwrap().push("child2");
4517 Ok(())
4518 });
4519
4520 let order_clone = completion_order.clone();
4521 let grandchild_handle = grandchild.spawn(async move {
4522 tokio::time::sleep(Duration::from_millis(125)).await;
4523 order_clone.lock().unwrap().push("grandchild");
4524 Ok(())
4525 });
4526
4527 // Test hierarchical join - should wait for ALL tasks in hierarchy
4528 println!("[TEST] About to call parent.join()");
4529 let start = std::time::Instant::now();
4530 parent.join().await; // This should wait for ALL tasks
4531 let elapsed = start.elapsed();
4532 println!("[TEST] parent.join() completed in {:?}", elapsed);
4533
4534 // Should have waited for the longest task (grandchild at 125ms)
4535 assert!(
4536 elapsed >= Duration::from_millis(120),
4537 "Hierarchical join should wait for longest task"
4538 );
4539
4540 // All tasks should be complete
4541 assert!(parent_handle.is_finished());
4542 assert!(child1_handle.is_finished());
4543 assert!(child2_handle.is_finished());
4544 assert!(grandchild_handle.is_finished());
4545
4546 // Verify all tasks completed
4547 let final_order = completion_order.lock().unwrap();
4548 assert_eq!(final_order.len(), 4);
4549 assert!(final_order.contains(&"parent"));
4550 assert!(final_order.contains(&"child1"));
4551 assert!(final_order.contains(&"child2"));
4552 assert!(final_order.contains(&"grandchild"));
4553 }
4554
4555 #[rstest]
4556 #[tokio::test]
4557 async fn test_hierarchical_join_waits_for_children(
4558 semaphore_scheduler: Arc<SemaphoreScheduler>,
4559 log_policy: Arc<LogOnlyPolicy>,
4560 ) {
4561 // Test that join() waits for child tasks (hierarchical behavior)
4562 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4563 let child = parent.child_tracker().unwrap();
4564
4565 // Spawn a quick parent task and slow child task
4566 let _parent_handle = parent.spawn(async {
4567 tokio::time::sleep(Duration::from_millis(20)).await;
4568 Ok(())
4569 });
4570
4571 let _child_handle = child.spawn(async {
4572 tokio::time::sleep(Duration::from_millis(100)).await;
4573 Ok(())
4574 });
4575
4576 // Hierarchical join should wait for both parent and child tasks
4577 let start = std::time::Instant::now();
4578 parent.join().await; // Should wait for both (hierarchical by default)
4579 let elapsed = start.elapsed();
4580
4581 // Should have waited for the longer child task (100ms)
4582 assert!(
4583 elapsed >= Duration::from_millis(90),
4584 "Hierarchical join should wait for all child tasks"
4585 );
4586 }
4587
4588 #[rstest]
4589 #[tokio::test]
4590 async fn test_hierarchical_join_operations(
4591 semaphore_scheduler: Arc<SemaphoreScheduler>,
4592 log_policy: Arc<LogOnlyPolicy>,
4593 ) {
4594 // Test that parent.join() closes and waits for child trackers too
4595 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4596 let child = parent.child_tracker().unwrap();
4597 let grandchild = child.child_tracker().unwrap();
4598
4599 // Verify trackers start as open
4600 assert!(!parent.is_closed());
4601 assert!(!child.is_closed());
4602 assert!(!grandchild.is_closed());
4603
4604 // Join parent (hierarchical by default - closes and waits for all)
4605 parent.join().await;
4606
4607 // All should be closed (check child trackers since parent was moved)
4608 assert!(child.is_closed());
4609 assert!(grandchild.is_closed());
4610 }
4611
4612 #[rstest]
4613 #[tokio::test]
4614 async fn test_unlimited_scheduler() {
4615 // Test that UnlimitedScheduler executes tasks immediately
4616 let scheduler = UnlimitedScheduler::new();
4617 let error_policy = LogOnlyPolicy::new();
4618 let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
4619
4620 let (tx, rx) = tokio::sync::oneshot::channel();
4621 let handle = tracker.spawn(async {
4622 rx.await.ok();
4623 Ok(42)
4624 });
4625
4626 // Task should be ready to execute immediately (no concurrency limit)
4627 tx.send(()).ok();
4628 let result = handle.await.unwrap().unwrap();
4629 assert_eq!(result, 42);
4630
4631 assert_eq!(tracker.metrics().success(), 1);
4632 }
4633
4634 #[rstest]
4635 #[tokio::test]
4636 async fn test_threshold_cancel_policy(semaphore_scheduler: Arc<SemaphoreScheduler>) {
4637 // Test that ThresholdCancelPolicy now uses per-task failure counting
4638 let error_policy = ThresholdCancelPolicy::with_threshold(2); // Cancel after 2 failures per task
4639 let tracker = TaskTracker::new(semaphore_scheduler, error_policy.clone()).unwrap();
4640 let cancel_token = tracker.cancellation_token().child_token();
4641
4642 // With per-task context, individual task failures don't accumulate
4643 // Each task starts with failure_count = 0, so single failures won't trigger cancellation
4644 let _handle1 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("First failure")) });
4645 tokio::task::yield_now().await;
4646 assert!(!cancel_token.is_cancelled());
4647 assert_eq!(error_policy.failure_count(), 1); // Global counter still increments
4648
4649 // Second failure from different task - still won't trigger cancellation
4650 let _handle2 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("Second failure")) });
4651 tokio::task::yield_now().await;
4652 assert!(!cancel_token.is_cancelled()); // Per-task context prevents cancellation
4653 assert_eq!(error_policy.failure_count(), 2); // Global counter increments
4654
4655 // For cancellation to occur, a single task would need to fail multiple times
4656 // through continuations (which would require a more complex test setup)
4657 }
4658
4659 #[tokio::test]
4660 async fn test_policy_constructors() {
4661 // Test that all constructors follow the new clean API patterns
4662 let _unlimited = UnlimitedScheduler::new();
4663 let _semaphore = SemaphoreScheduler::with_permits(5);
4664 let _log_only = LogOnlyPolicy::new();
4665 let _cancel_policy = CancelOnError::new();
4666 let _threshold_policy = ThresholdCancelPolicy::with_threshold(3);
4667 let _rate_policy = RateCancelPolicy::builder()
4668 .rate(0.5)
4669 .window_secs(60)
4670 .build();
4671
4672 // All constructors return Arc directly - no more ugly ::new_arc patterns
4673 // This test ensures the clean API reduces boilerplate
4674 }
4675
4676 #[rstest]
4677 #[tokio::test]
4678 async fn test_child_creation_fails_after_join(
4679 semaphore_scheduler: Arc<SemaphoreScheduler>,
4680 log_policy: Arc<LogOnlyPolicy>,
4681 ) {
4682 // Test that child tracker creation fails from closed parent
4683 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4684
4685 // Initially, creating a child should work
4686 let _child = parent.child_tracker().unwrap();
4687
4688 // Close the parent tracker
4689 let parent_clone = parent.clone();
4690 parent.join().await;
4691 assert!(parent_clone.is_closed());
4692
4693 // Now, trying to create a child should fail
4694 let result = parent_clone.child_tracker();
4695 assert!(result.is_err());
4696 assert!(
4697 result
4698 .err()
4699 .unwrap()
4700 .to_string()
4701 .contains("closed parent tracker")
4702 );
4703 }
4704
4705 #[rstest]
4706 #[tokio::test]
4707 async fn test_child_builder_fails_after_join(
4708 semaphore_scheduler: Arc<SemaphoreScheduler>,
4709 log_policy: Arc<LogOnlyPolicy>,
4710 ) {
4711 // Test that child tracker builder creation fails from closed parent
4712 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4713
4714 // Initially, creating a child with builder should work
4715 let _child = parent.child_tracker_builder().build().unwrap();
4716
4717 // Close the parent tracker
4718 let parent_clone = parent.clone();
4719 parent.join().await;
4720 assert!(parent_clone.is_closed());
4721
4722 // Now, trying to create a child with builder should fail
4723 let result = parent_clone.child_tracker_builder().build();
4724 assert!(result.is_err());
4725 assert!(
4726 result
4727 .err()
4728 .unwrap()
4729 .to_string()
4730 .contains("closed parent tracker")
4731 );
4732 }
4733
4734 #[rstest]
4735 #[tokio::test]
4736 async fn test_child_creation_succeeds_before_join(
4737 semaphore_scheduler: Arc<SemaphoreScheduler>,
4738 log_policy: Arc<LogOnlyPolicy>,
4739 ) {
4740 // Test that child creation works normally before parent is joined
4741 let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
4742
4743 // Should be able to create multiple children before closing
4744 let child1 = parent.child_tracker().unwrap();
4745 let child2 = parent.child_tracker_builder().build().unwrap();
4746
4747 // Verify children can spawn tasks
4748 let handle1 = child1.spawn(async { Ok(42) });
4749 let handle2 = child2.spawn(async { Ok(24) });
4750
4751 let result1 = handle1.await.unwrap().unwrap();
4752 let result2 = handle2.await.unwrap().unwrap();
4753
4754 assert_eq!(result1, 42);
4755 assert_eq!(result2, 24);
4756 assert_eq!(parent.metrics().success(), 2); // Parent sees all successes
4757 }
4758
4759 #[rstest]
4760 #[tokio::test]
4761 async fn test_custom_error_response_with_cancellation_token(
4762 semaphore_scheduler: Arc<SemaphoreScheduler>,
4763 ) {
4764 // Test ErrorResponse::Custom behavior with TriggerCancellationTokenOnError
4765
4766 // Create a custom cancellation token
4767 let custom_cancel_token = CancellationToken::new();
4768
4769 // Create the policy that will trigger our custom token
4770 let error_policy = TriggerCancellationTokenOnError::new(custom_cancel_token.clone());
4771
4772 // Create tracker using builder with the custom policy
4773 let tracker = TaskTracker::builder()
4774 .scheduler(semaphore_scheduler)
4775 .error_policy(error_policy)
4776 .cancel_token(custom_cancel_token.clone())
4777 .build()
4778 .unwrap();
4779
4780 let child = tracker.child_tracker().unwrap();
4781
4782 // Initially, the custom token should not be cancelled
4783 assert!(!custom_cancel_token.is_cancelled());
4784
4785 // Spawn a task that will fail
4786 let handle = child.spawn(async {
4787 Err::<(), _>(anyhow::anyhow!("Test error to trigger custom response"))
4788 });
4789
4790 // Wait for the task to complete (it will fail)
4791 let result = handle.await.unwrap();
4792 assert!(result.is_err());
4793
4794 // Await a timeout/deadline or the cancellation token to be cancelled
4795 // The expectation is that the task will fail, and the cancellation token will be triggered
4796 // Hitting the deadline is a failure
4797 tokio::select! {
4798 _ = tokio::time::sleep(Duration::from_secs(1)) => {
4799 panic!("Task should have failed, but hit the deadline");
4800 }
4801 _ = custom_cancel_token.cancelled() => {
4802 // Task should have failed, and the cancellation token should be triggered
4803 }
4804 }
4805
4806 // The custom cancellation token should now be triggered by our policy
4807 assert!(
4808 custom_cancel_token.is_cancelled(),
4809 "Custom cancellation token should be triggered by ErrorResponse::Custom"
4810 );
4811
4812 assert!(tracker.cancellation_token().is_cancelled());
4813 assert!(child.cancellation_token().is_cancelled());
4814
4815 // Verify the error was counted
4816 assert_eq!(tracker.metrics().failed(), 1);
4817 }
4818
4819 #[test]
4820 fn test_action_result_variants() {
4821 // Test that ActionResult variants can be created and pattern matched
4822
4823 // Test Fail variant
4824 let fail_result = ActionResult::Fail;
4825 match fail_result {
4826 ActionResult::Fail => {} // Expected
4827 _ => panic!("Expected Fail variant"),
4828 }
4829
4830 // Test Shutdown variant
4831 let shutdown_result = ActionResult::Shutdown;
4832 match shutdown_result {
4833 ActionResult::Shutdown => {} // Expected
4834 _ => panic!("Expected Shutdown variant"),
4835 }
4836
4837 // Test Continue variant with Continuation
4838 #[derive(Debug)]
4839 struct TestRestartable;
4840
4841 #[async_trait]
4842 impl Continuation for TestRestartable {
4843 async fn execute(
4844 &self,
4845 _cancel_token: CancellationToken,
4846 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4847 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4848 }
4849 }
4850
4851 let test_restartable = Arc::new(TestRestartable);
4852 let continue_result = ActionResult::Continue {
4853 continuation: test_restartable,
4854 };
4855
4856 match continue_result {
4857 ActionResult::Continue { continuation } => {
4858 // Verify we have a valid Continuation
4859 assert!(format!("{:?}", continuation).contains("TestRestartable"));
4860 }
4861 _ => panic!("Expected Continue variant"),
4862 }
4863 }
4864
4865 #[test]
4866 fn test_continuation_error_creation() {
4867 // Test RestartableError creation and conversion to anyhow::Error
4868
4869 // Create a dummy restartable task for testing
4870 #[derive(Debug)]
4871 struct DummyRestartable;
4872
4873 #[async_trait]
4874 impl Continuation for DummyRestartable {
4875 async fn execute(
4876 &self,
4877 _cancel_token: CancellationToken,
4878 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4879 TaskExecutionResult::Success(Box::new("restarted_result".to_string()))
4880 }
4881 }
4882
4883 let dummy_restartable = Arc::new(DummyRestartable);
4884 let source_error = anyhow::anyhow!("Original task failed");
4885
4886 // Test FailedWithContinuation::new
4887 let continuation_error = FailedWithContinuation::new(source_error, dummy_restartable);
4888
4889 // Verify the error displays correctly
4890 let error_string = format!("{}", continuation_error);
4891 assert!(error_string.contains("Task failed with continuation"));
4892 assert!(error_string.contains("Original task failed"));
4893
4894 // Test conversion to anyhow::Error
4895 let anyhow_error = anyhow::Error::new(continuation_error);
4896 assert!(
4897 anyhow_error
4898 .to_string()
4899 .contains("Task failed with continuation")
4900 );
4901 }
4902
4903 #[test]
4904 fn test_continuation_error_ext_trait() {
4905 // Test the RestartableErrorExt trait methods
4906
4907 // Test with regular anyhow::Error (not restartable)
4908 let regular_error = anyhow::anyhow!("Regular error");
4909 assert!(!regular_error.has_continuation());
4910 let extracted = regular_error.extract_continuation();
4911 assert!(extracted.is_none());
4912
4913 // Test with RestartableError
4914 #[derive(Debug)]
4915 struct TestRestartable;
4916
4917 #[async_trait]
4918 impl Continuation for TestRestartable {
4919 async fn execute(
4920 &self,
4921 _cancel_token: CancellationToken,
4922 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4923 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4924 }
4925 }
4926
4927 let test_restartable = Arc::new(TestRestartable);
4928 let source_error = anyhow::anyhow!("Source error");
4929 let continuation_error = FailedWithContinuation::new(source_error, test_restartable);
4930
4931 let anyhow_error = anyhow::Error::new(continuation_error);
4932 assert!(anyhow_error.has_continuation());
4933
4934 // Test extraction of restartable task
4935 let extracted = anyhow_error.extract_continuation();
4936 assert!(extracted.is_some());
4937 }
4938
4939 #[test]
4940 fn test_continuation_error_into_anyhow_helper() {
4941 // Test the convenience method for creating restartable errors
4942 // Note: This test uses a mock TaskExecutor since we don't have real ones yet
4943
4944 // For now, we'll test the type erasure concept with a simple type
4945 struct MockExecutor;
4946
4947 let _source_error = anyhow::anyhow!("Mock task failed");
4948
4949 // We can't test FailedWithContinuation::into_anyhow yet because it requires
4950 // a real TaskExecutor<T>. This will be tested in Phase 3.
4951 // For now, just verify the concept works with manual construction.
4952
4953 #[derive(Debug)]
4954 struct MockRestartable;
4955
4956 #[async_trait]
4957 impl Continuation for MockRestartable {
4958 async fn execute(
4959 &self,
4960 _cancel_token: CancellationToken,
4961 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4962 TaskExecutionResult::Success(Box::new("mock_result".to_string()))
4963 }
4964 }
4965
4966 let mock_restartable = Arc::new(MockRestartable);
4967 let continuation_error =
4968 FailedWithContinuation::new(anyhow::anyhow!("Mock task failed"), mock_restartable);
4969
4970 let anyhow_error = anyhow::Error::new(continuation_error);
4971 assert!(anyhow_error.has_continuation());
4972 }
4973
4974 #[test]
4975 fn test_continuation_error_with_task_executor() {
4976 // Test RestartableError creation with TaskExecutor
4977
4978 #[derive(Debug)]
4979 struct TestRestartableTask;
4980
4981 #[async_trait]
4982 impl Continuation for TestRestartableTask {
4983 async fn execute(
4984 &self,
4985 _cancel_token: CancellationToken,
4986 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
4987 TaskExecutionResult::Success(Box::new("test_result".to_string()))
4988 }
4989 }
4990
4991 let restartable_task = Arc::new(TestRestartableTask);
4992 let source_error = anyhow::anyhow!("Task failed");
4993
4994 // Test FailedWithContinuation::new with Restartable
4995 let continuation_error = FailedWithContinuation::new(source_error, restartable_task);
4996
4997 // Verify the error displays correctly
4998 let error_string = format!("{}", continuation_error);
4999 assert!(error_string.contains("Task failed with continuation"));
5000 assert!(error_string.contains("Task failed"));
5001
5002 // Test conversion to anyhow::Error
5003 let anyhow_error = anyhow::Error::new(continuation_error);
5004 assert!(anyhow_error.has_continuation());
5005
5006 // Test extraction (should work now with Restartable trait)
5007 let extracted = anyhow_error.extract_continuation();
5008 assert!(extracted.is_some()); // Should successfully extract the Restartable
5009 }
5010
5011 #[test]
5012 fn test_continuation_error_into_anyhow_convenience() {
5013 // Test the convenience method for creating restartable errors
5014
5015 #[derive(Debug)]
5016 struct ConvenienceRestartable;
5017
5018 #[async_trait]
5019 impl Continuation for ConvenienceRestartable {
5020 async fn execute(
5021 &self,
5022 _cancel_token: CancellationToken,
5023 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5024 TaskExecutionResult::Success(Box::new(42u32))
5025 }
5026 }
5027
5028 let restartable_task = Arc::new(ConvenienceRestartable);
5029 let source_error = anyhow::anyhow!("Computation failed");
5030
5031 // Test FailedWithContinuation::into_anyhow convenience method
5032 let anyhow_error = FailedWithContinuation::into_anyhow(source_error, restartable_task);
5033
5034 assert!(anyhow_error.has_continuation());
5035 assert!(
5036 anyhow_error
5037 .to_string()
5038 .contains("Task failed with continuation")
5039 );
5040 assert!(anyhow_error.to_string().contains("Computation failed"));
5041 }
5042
5043 #[test]
5044 fn test_handle_task_error_with_continuation_error() {
5045 // Test that handle_task_error properly detects RestartableError
5046
5047 // Create a mock Restartable task
5048 #[derive(Debug)]
5049 struct MockRestartableTask;
5050
5051 #[async_trait]
5052 impl Continuation for MockRestartableTask {
5053 async fn execute(
5054 &self,
5055 _cancel_token: CancellationToken,
5056 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5057 TaskExecutionResult::Success(Box::new("retry_result".to_string()))
5058 }
5059 }
5060
5061 let restartable_task = Arc::new(MockRestartableTask);
5062
5063 // Create RestartableError
5064 let source_error = anyhow::anyhow!("Task failed, but can retry");
5065 let continuation_error = FailedWithContinuation::new(source_error, restartable_task);
5066 let anyhow_error = anyhow::Error::new(continuation_error);
5067
5068 // Verify it's detected as restartable
5069 assert!(anyhow_error.has_continuation());
5070
5071 // Verify we can downcast to FailedWithContinuation
5072 let continuation_ref = anyhow_error.downcast_ref::<FailedWithContinuation>();
5073 assert!(continuation_ref.is_some());
5074
5075 // Verify the continuation task is present
5076 let continuation = continuation_ref.unwrap();
5077 // Note: We can verify the Arc is valid by checking that Arc::strong_count > 0
5078 assert!(Arc::strong_count(&continuation.continuation) > 0);
5079 }
5080
5081 #[test]
5082 fn test_handle_task_error_with_regular_error() {
5083 // Test that handle_task_error properly handles regular errors
5084
5085 let regular_error = anyhow::anyhow!("Regular task failure");
5086
5087 // Verify it's not detected as restartable
5088 assert!(!regular_error.has_continuation());
5089
5090 // Verify we cannot downcast to FailedWithContinuation
5091 let continuation_ref = regular_error.downcast_ref::<FailedWithContinuation>();
5092 assert!(continuation_ref.is_none());
5093 }
5094
5095 // ========================================
5096 // END-TO-END ACTIONRESULT TESTS
5097 // ========================================
5098
5099 #[rstest]
5100 #[tokio::test]
5101 async fn test_end_to_end_continuation_execution(
5102 unlimited_scheduler: Arc<UnlimitedScheduler>,
5103 log_policy: Arc<LogOnlyPolicy>,
5104 ) {
5105 // Test that a task returning FailedWithContinuation actually executes the continuation
5106 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
5107
5108 // Shared state to track execution
5109 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5110 let log_clone = execution_log.clone();
5111
5112 // Create a continuation that logs its execution
5113 #[derive(Debug)]
5114 struct LoggingContinuation {
5115 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5116 result: String,
5117 }
5118
5119 #[async_trait]
5120 impl Continuation for LoggingContinuation {
5121 async fn execute(
5122 &self,
5123 _cancel_token: CancellationToken,
5124 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5125 self.log
5126 .lock()
5127 .await
5128 .push("continuation_executed".to_string());
5129 TaskExecutionResult::Success(Box::new(self.result.clone()))
5130 }
5131 }
5132
5133 let continuation = Arc::new(LoggingContinuation {
5134 log: log_clone,
5135 result: "continuation_result".to_string(),
5136 });
5137
5138 // Task that fails with continuation
5139 let log_for_task = execution_log.clone();
5140 let handle = tracker.spawn(async move {
5141 log_for_task
5142 .lock()
5143 .await
5144 .push("original_task_executed".to_string());
5145
5146 // Return FailedWithContinuation
5147 let error = anyhow::anyhow!("Original task failed");
5148 let result: Result<String, anyhow::Error> =
5149 Err(FailedWithContinuation::into_anyhow(error, continuation));
5150 result
5151 });
5152
5153 // Execute and verify the continuation was called
5154 let result = handle.await.expect("Task should complete");
5155 assert!(result.is_ok(), "Continuation should succeed");
5156
5157 // Verify execution order
5158 let log = execution_log.lock().await;
5159 assert_eq!(log.len(), 2);
5160 assert_eq!(log[0], "original_task_executed");
5161 assert_eq!(log[1], "continuation_executed");
5162
5163 // Verify metrics - should show 1 success (from continuation)
5164 assert_eq!(tracker.metrics().success(), 1);
5165 assert_eq!(tracker.metrics().failed(), 0); // Continuation succeeded
5166 assert_eq!(tracker.metrics().cancelled(), 0);
5167 }
5168
5169 #[rstest]
5170 #[tokio::test]
5171 async fn test_end_to_end_multiple_continuations(
5172 unlimited_scheduler: Arc<UnlimitedScheduler>,
5173 log_policy: Arc<LogOnlyPolicy>,
5174 ) {
5175 // Test multiple continuation attempts
5176 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
5177
5178 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5179 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
5180
5181 // Continuation that fails twice, then succeeds
5182 #[derive(Debug)]
5183 struct RetryingContinuation {
5184 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5185 attempt_count: Arc<std::sync::atomic::AtomicU32>,
5186 }
5187
5188 #[async_trait]
5189 impl Continuation for RetryingContinuation {
5190 async fn execute(
5191 &self,
5192 _cancel_token: CancellationToken,
5193 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5194 let attempt = self
5195 .attempt_count
5196 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
5197 + 1;
5198 self.log
5199 .lock()
5200 .await
5201 .push(format!("continuation_attempt_{}", attempt));
5202
5203 if attempt < 3 {
5204 // Fail with another continuation
5205 let next_continuation = Arc::new(RetryingContinuation {
5206 log: self.log.clone(),
5207 attempt_count: self.attempt_count.clone(),
5208 });
5209 let error = anyhow::anyhow!("Continuation attempt {} failed", attempt);
5210 TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
5211 error,
5212 next_continuation,
5213 ))
5214 } else {
5215 // Succeed on third attempt
5216 TaskExecutionResult::Success(Box::new(format!(
5217 "success_on_attempt_{}",
5218 attempt
5219 )))
5220 }
5221 }
5222 }
5223
5224 let initial_continuation = Arc::new(RetryingContinuation {
5225 log: execution_log.clone(),
5226 attempt_count: attempt_count.clone(),
5227 });
5228
5229 // Task that immediately fails with continuation
5230 let handle = tracker.spawn(async move {
5231 let error = anyhow::anyhow!("Original task failed");
5232 let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
5233 error,
5234 initial_continuation,
5235 ));
5236 result
5237 });
5238
5239 // Execute and verify multiple continuations
5240 let result = handle.await.expect("Task should complete");
5241 assert!(result.is_ok(), "Final continuation should succeed");
5242
5243 // Verify all attempts were made
5244 let log = execution_log.lock().await;
5245 assert_eq!(log.len(), 3);
5246 assert_eq!(log[0], "continuation_attempt_1");
5247 assert_eq!(log[1], "continuation_attempt_2");
5248 assert_eq!(log[2], "continuation_attempt_3");
5249
5250 // Verify final attempt count
5251 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
5252
5253 // Verify metrics - should show 1 success (final continuation)
5254 assert_eq!(tracker.metrics().success(), 1);
5255 assert_eq!(tracker.metrics().failed(), 0);
5256 }
5257
5258 #[rstest]
5259 #[tokio::test]
5260 async fn test_end_to_end_continuation_failure(
5261 unlimited_scheduler: Arc<UnlimitedScheduler>,
5262 log_policy: Arc<LogOnlyPolicy>,
5263 ) {
5264 // Test continuation that ultimately fails without providing another continuation
5265 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
5266
5267 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5268 let log_clone = execution_log.clone();
5269
5270 // Continuation that fails without providing another continuation
5271 #[derive(Debug)]
5272 struct FailingContinuation {
5273 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5274 }
5275
5276 #[async_trait]
5277 impl Continuation for FailingContinuation {
5278 async fn execute(
5279 &self,
5280 _cancel_token: CancellationToken,
5281 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5282 self.log
5283 .lock()
5284 .await
5285 .push("continuation_failed".to_string());
5286 TaskExecutionResult::Error(anyhow::anyhow!("Continuation failed permanently"))
5287 }
5288 }
5289
5290 let continuation = Arc::new(FailingContinuation { log: log_clone });
5291
5292 // Task that fails with continuation
5293 let log_for_task = execution_log.clone();
5294 let handle = tracker.spawn(async move {
5295 log_for_task
5296 .lock()
5297 .await
5298 .push("original_task_executed".to_string());
5299
5300 let error = anyhow::anyhow!("Original task failed");
5301 let result: Result<String, anyhow::Error> =
5302 Err(FailedWithContinuation::into_anyhow(error, continuation));
5303 result
5304 });
5305
5306 // Execute and verify the continuation failed
5307 let result = handle.await.expect("Task should complete");
5308 assert!(result.is_err(), "Continuation should fail");
5309
5310 // Verify execution order
5311 let log = execution_log.lock().await;
5312 assert_eq!(log.len(), 2);
5313 assert_eq!(log[0], "original_task_executed");
5314 assert_eq!(log[1], "continuation_failed");
5315
5316 // Verify metrics - should show 1 failure (from continuation)
5317 assert_eq!(tracker.metrics().success(), 0);
5318 assert_eq!(tracker.metrics().failed(), 1);
5319 assert_eq!(tracker.metrics().cancelled(), 0);
5320 }
5321
5322 #[rstest]
5323 #[tokio::test]
5324 async fn test_end_to_end_all_action_result_variants(
5325 unlimited_scheduler: Arc<UnlimitedScheduler>,
5326 ) {
5327 // Comprehensive test of Fail, Shutdown, and Continue paths
5328
5329 // Test 1: ActionResult::Fail (via LogOnlyPolicy)
5330 {
5331 let tracker =
5332 TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
5333 let handle = tracker.spawn(async {
5334 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
5335 result
5336 });
5337 let result = handle.await.expect("Task should complete");
5338 assert!(result.is_err(), "LogOnly should let error through");
5339 assert_eq!(tracker.metrics().failed(), 1);
5340 }
5341
5342 // Test 2: ActionResult::Shutdown (via CancelOnError)
5343 {
5344 let tracker =
5345 TaskTracker::new(unlimited_scheduler.clone(), CancelOnError::new()).unwrap();
5346 let handle = tracker.spawn(async {
5347 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
5348 result
5349 });
5350 let result = handle.await.expect("Task should complete");
5351 assert!(result.is_err(), "CancelOnError should fail task");
5352 assert!(
5353 tracker.cancellation_token().is_cancelled(),
5354 "Should cancel tracker"
5355 );
5356 assert_eq!(tracker.metrics().failed(), 1);
5357 }
5358
5359 // Test 3: ActionResult::Continue (via FailedWithContinuation)
5360 {
5361 let tracker =
5362 TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
5363
5364 #[derive(Debug)]
5365 struct TestContinuation;
5366
5367 #[async_trait]
5368 impl Continuation for TestContinuation {
5369 async fn execute(
5370 &self,
5371 _cancel_token: CancellationToken,
5372 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5373 TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
5374 }
5375 }
5376
5377 let continuation = Arc::new(TestContinuation);
5378 let handle = tracker.spawn(async move {
5379 let error = anyhow::anyhow!("Original failure");
5380 let result: Result<String, anyhow::Error> =
5381 Err(FailedWithContinuation::into_anyhow(error, continuation));
5382 result
5383 });
5384
5385 let result = handle.await.expect("Task should complete");
5386 assert!(result.is_ok(), "Continuation should succeed");
5387 assert_eq!(tracker.metrics().success(), 1);
5388 assert_eq!(tracker.metrics().failed(), 0);
5389 }
5390 }
5391
5392 // ========================================
5393 // LOOP BEHAVIOR AND POLICY INTERACTION TESTS
5394 // ========================================
5395 //
5396 // These tests demonstrate the current ActionResult system and identify
5397 // areas for future improvement:
5398 //
5399 // ✅ WHAT WORKS:
5400 // - All ActionResult variants (Continue, Cancel, ExecuteNext) are tested
5401 // - Task-driven continuations work correctly
5402 // - Policy-driven continuations work correctly
5403 // - Mixed continuation sources work correctly
5404 // - Loop behavior with resource management works correctly
5405 //
5406 // 🔄 CURRENT LIMITATIONS:
5407 // - ThresholdCancelPolicy tracks failures GLOBALLY, not per-task
5408 // - OnErrorPolicy doesn't receive attempt_count parameter
5409 // - No per-task context for stateful retry policies
5410 //
5411 // 🚀 FUTURE IMPROVEMENTS IDENTIFIED:
5412 // - Add OnErrorContext associated type for per-task state
5413 // - Pass attempt_count to OnErrorPolicy::on_error
5414 // - Enable per-task failure tracking, backoff timers, etc.
5415 //
5416 // The tests below demonstrate both current capabilities and limitations.
5417
5418 /// Test retry loop behavior with different policies and continuation counts
5419 ///
5420 /// This test verifies that:
5421 /// 1. Tasks can provide multiple continuations in sequence
5422 /// 2. Different error policies can limit the number of continuation attempts
5423 /// 3. The retry loop correctly handles policy decisions about when to stop
5424 ///
5425 /// Key insight: Policies are only consulted for regular errors, not FailedWithContinuation.
5426 /// So we need continuations that eventually fail with regular errors to test policy limits.
5427 ///
5428 /// DESIGN LIMITATION: Current ThresholdCancelPolicy tracks failures GLOBALLY across all tasks,
5429 /// not per-task. This test demonstrates the current behavior but isn't ideal for retry loop testing.
5430 ///
5431 /// FUTURE IMPROVEMENT: Add OnErrorContext associated type to OnErrorPolicy:
5432 /// ```rust
5433 /// trait OnErrorPolicy {
5434 /// type Context: Default + Send + Sync;
5435 /// fn on_error(&self, error: &anyhow::Error, task_id: TaskId,
5436 /// attempt_count: u32, context: &mut Self::Context) -> ErrorResponse;
5437 /// }
5438 /// ```
5439 /// This would enable per-task failure tracking, backoff timers, etc.
5440 ///
5441 /// NOTE: Uses fresh policy instance for each test case to avoid global state interference.
5442 #[rstest]
5443 #[case(
5444 1,
5445 false,
5446 "Global policy with max_failures=1 should stop after first regular error"
5447 )]
5448 #[case(
5449 2,
5450 false, // Actually fails - ActionResult::Fail accepts the error and fails the task
5451 "Global policy with max_failures=2 allows error but ActionResult::Fail still fails the task"
5452 )]
5453 #[tokio::test]
5454 async fn test_continuation_loop_with_global_threshold_policy(
5455 unlimited_scheduler: Arc<UnlimitedScheduler>,
5456 #[case] max_failures: usize,
5457 #[case] should_succeed: bool,
5458 #[case] description: &str,
5459 ) {
5460 // Task that provides continuations, but continuations fail with regular errors
5461 // so the policy gets consulted and can limit retries
5462
5463 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5464 let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
5465
5466 // Create a continuation that fails with regular errors (not FailedWithContinuation)
5467 // This allows the policy to be consulted and potentially stop the retries
5468 #[derive(Debug)]
5469 struct PolicyTestContinuation {
5470 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5471 attempt_counter: Arc<std::sync::atomic::AtomicU32>,
5472 max_attempts_before_success: u32,
5473 }
5474
5475 #[async_trait]
5476 impl Continuation for PolicyTestContinuation {
5477 async fn execute(
5478 &self,
5479 _cancel_token: CancellationToken,
5480 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5481 let attempt = self
5482 .attempt_counter
5483 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
5484 + 1;
5485 self.log
5486 .lock()
5487 .await
5488 .push(format!("continuation_attempt_{}", attempt));
5489
5490 if attempt < self.max_attempts_before_success {
5491 // Fail with regular error - this will be seen by the policy
5492 TaskExecutionResult::Error(anyhow::anyhow!(
5493 "Continuation attempt {} failed (regular error)",
5494 attempt
5495 ))
5496 } else {
5497 // Succeed after enough attempts
5498 TaskExecutionResult::Success(Box::new(format!(
5499 "success_on_attempt_{}",
5500 attempt
5501 )))
5502 }
5503 }
5504 }
5505
5506 // Create fresh policy instance for each test case to avoid global state interference
5507 let policy = ThresholdCancelPolicy::with_threshold(max_failures);
5508 let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
5509
5510 // Original task that fails with continuation
5511 let log_for_task = execution_log.clone();
5512 // Set max_attempts_before_success so that:
5513 // - For max_failures=1: Continuation fails 1 time (attempt 1), policy cancels after 1 failure
5514 // - For max_failures=2: Continuation fails 1 time (attempt 1), succeeds on attempt 2
5515 let continuation = Arc::new(PolicyTestContinuation {
5516 log: execution_log.clone(),
5517 attempt_counter: attempt_counter.clone(),
5518 max_attempts_before_success: 2, // Always fail on attempt 1, succeed on attempt 2
5519 });
5520
5521 let handle = tracker.spawn(async move {
5522 log_for_task
5523 .lock()
5524 .await
5525 .push("original_task_executed".to_string());
5526 let error = anyhow::anyhow!("Original task failed");
5527 let result: Result<String, anyhow::Error> =
5528 Err(FailedWithContinuation::into_anyhow(error, continuation));
5529 result
5530 });
5531
5532 // Execute and check result based on policy
5533 let result = handle.await.expect("Task should complete");
5534
5535 // Debug: Print actual results
5536 let log = execution_log.lock().await;
5537 let final_attempt_count = attempt_counter.load(std::sync::atomic::Ordering::Relaxed);
5538 println!(
5539 "Test case: max_failures={}, should_succeed={}",
5540 max_failures, should_succeed
5541 );
5542 println!("Result: {:?}", result.is_ok());
5543 println!("Log entries: {:?}", log);
5544 println!("Attempt count: {}", final_attempt_count);
5545 println!(
5546 "Metrics: success={}, failed={}",
5547 tracker.metrics().success(),
5548 tracker.metrics().failed()
5549 );
5550 drop(log); // Release the lock
5551
5552 // Both test cases should fail because ActionResult::Fail accepts the error and fails the task
5553 assert!(result.is_err(), "{}: Task should fail", description);
5554 assert_eq!(
5555 tracker.metrics().success(),
5556 0,
5557 "{}: Should have 0 successes",
5558 description
5559 );
5560 assert_eq!(
5561 tracker.metrics().failed(),
5562 1,
5563 "{}: Should have 1 failure",
5564 description
5565 );
5566
5567 // Should have stopped after 1 continuation attempt because ActionResult::Fail fails the task
5568 let log = execution_log.lock().await;
5569 assert_eq!(
5570 log.len(),
5571 2,
5572 "{}: Should have 2 log entries (original + 1 continuation attempt)",
5573 description
5574 );
5575 assert_eq!(log[0], "original_task_executed");
5576 assert_eq!(log[1], "continuation_attempt_1");
5577
5578 assert_eq!(
5579 attempt_counter.load(std::sync::atomic::Ordering::Relaxed),
5580 1,
5581 "{}: Should have made 1 continuation attempt",
5582 description
5583 );
5584
5585 // The key difference is whether the tracker gets cancelled
5586 if max_failures == 1 {
5587 assert!(
5588 tracker.cancellation_token().is_cancelled(),
5589 "Tracker should be cancelled with max_failures=1"
5590 );
5591 } else {
5592 assert!(
5593 !tracker.cancellation_token().is_cancelled(),
5594 "Tracker should NOT be cancelled with max_failures=2 (policy allows the error)"
5595 );
5596 }
5597 }
5598
5599 /// Simple test to understand ThresholdCancelPolicy behavior with per-task context
5600 #[rstest]
5601 #[tokio::test]
5602 async fn test_simple_threshold_policy_behavior(unlimited_scheduler: Arc<UnlimitedScheduler>) {
5603 // Test with max_failures=2 - now uses per-task failure counting
5604 let policy = ThresholdCancelPolicy::with_threshold(2);
5605 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
5606
5607 // Task 1: Should fail but not trigger cancellation (per-task failure count = 1)
5608 let handle1 = tracker.spawn(async {
5609 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("First failure"));
5610 result
5611 });
5612 let result1 = handle1.await.expect("Task should complete");
5613 assert!(result1.is_err(), "First task should fail");
5614 assert!(
5615 !tracker.cancellation_token().is_cancelled(),
5616 "Should not be cancelled after 1 failure"
5617 );
5618
5619 // Task 2: Should fail but not trigger cancellation (different task, per-task failure count = 1)
5620 let handle2 = tracker.spawn(async {
5621 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Second failure"));
5622 result
5623 });
5624 let result2 = handle2.await.expect("Task should complete");
5625 assert!(result2.is_err(), "Second task should fail");
5626 assert!(
5627 !tracker.cancellation_token().is_cancelled(),
5628 "Should NOT be cancelled - per-task context prevents global accumulation"
5629 );
5630
5631 println!("Policy global failure count: {}", policy.failure_count());
5632 assert_eq!(
5633 policy.failure_count(),
5634 2,
5635 "Policy should have counted 2 failures globally (for backwards compatibility)"
5636 );
5637 }
5638
5639 /// Test demonstrating that per-task error context solves the global failure tracking problem
5640 ///
5641 /// This test shows that with OnErrorContext, each task has independent failure tracking.
5642 #[rstest]
5643 #[tokio::test]
5644 async fn test_per_task_context_limitation_demo(unlimited_scheduler: Arc<UnlimitedScheduler>) {
5645 // Create a policy that should allow 2 failures per task
5646 let policy = ThresholdCancelPolicy::with_threshold(2);
5647 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
5648
5649 // Task 1: Fails once (per-task failure count = 1, below threshold)
5650 let handle1 = tracker.spawn(async {
5651 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 1 failure"));
5652 result
5653 });
5654 let result1 = handle1.await.expect("Task should complete");
5655 assert!(result1.is_err(), "Task 1 should fail");
5656
5657 // Task 2: Also fails once (per-task failure count = 1, below threshold)
5658 // With per-task context, this doesn't interfere with Task 1's failure budget
5659 let handle2 = tracker.spawn(async {
5660 let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 2 failure"));
5661 result
5662 });
5663 let result2 = handle2.await.expect("Task should complete");
5664 assert!(result2.is_err(), "Task 2 should fail");
5665
5666 // With per-task context, tracker should NOT be cancelled
5667 // Each task failed only once, which is below the threshold of 2
5668 assert!(
5669 !tracker.cancellation_token().is_cancelled(),
5670 "Tracker should NOT be cancelled - per-task context prevents premature cancellation"
5671 );
5672
5673 println!("Global failure count: {}", policy.failure_count());
5674 assert_eq!(
5675 policy.failure_count(),
5676 2,
5677 "Global policy counted 2 failures across different tasks"
5678 );
5679
5680 // This demonstrates the limitation: we can't test per-task retry behavior
5681 // because failures from different tasks affect each other's retry budgets
5682 }
5683
5684 /// Test allow_continuation() policy method with attempt-based logic
5685 ///
5686 /// This test verifies that:
5687 /// 1. Policies can conditionally allow/reject continuations based on context
5688 /// 2. When allow_continuation() returns false, FailedWithContinuation is ignored
5689 /// 3. When allow_continuation() returns true, FailedWithContinuation is processed normally
5690 /// 4. The policy's decision takes precedence over task-provided continuations
5691 #[rstest]
5692 #[case(
5693 3,
5694 true,
5695 "Policy allows continuations up to 3 attempts - should succeed"
5696 )]
5697 #[case(
5698 2,
5699 true,
5700 "Policy allows continuations up to 2 attempts - should succeed"
5701 )]
5702 #[case(0, false, "Policy allows 0 attempts - should fail immediately")]
5703 #[tokio::test]
5704 async fn test_allow_continuation_policy_control(
5705 unlimited_scheduler: Arc<UnlimitedScheduler>,
5706 #[case] max_attempts: u32,
5707 #[case] should_succeed: bool,
5708 #[case] description: &str,
5709 ) {
5710 // Policy that allows continuations only up to max_attempts
5711 #[derive(Debug)]
5712 struct AttemptLimitPolicy {
5713 max_attempts: u32,
5714 }
5715
5716 impl OnErrorPolicy for AttemptLimitPolicy {
5717 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
5718 Arc::new(AttemptLimitPolicy {
5719 max_attempts: self.max_attempts,
5720 })
5721 }
5722
5723 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
5724 None // Stateless policy
5725 }
5726
5727 fn allow_continuation(&self, _error: &anyhow::Error, context: &OnErrorContext) -> bool {
5728 context.attempt_count <= self.max_attempts
5729 }
5730
5731 fn on_error(
5732 &self,
5733 _error: &anyhow::Error,
5734 _context: &mut OnErrorContext,
5735 ) -> ErrorResponse {
5736 ErrorResponse::Fail // Just fail when continuations are not allowed
5737 }
5738 }
5739
5740 let policy = Arc::new(AttemptLimitPolicy { max_attempts });
5741 let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
5742 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
5743
5744 // Continuation that always tries to retry
5745 #[derive(Debug)]
5746 struct AlwaysRetryContinuation {
5747 log: Arc<tokio::sync::Mutex<Vec<String>>>,
5748 attempt: u32,
5749 }
5750
5751 #[async_trait]
5752 impl Continuation for AlwaysRetryContinuation {
5753 async fn execute(
5754 &self,
5755 _cancel_token: CancellationToken,
5756 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
5757 self.log
5758 .lock()
5759 .await
5760 .push(format!("continuation_attempt_{}", self.attempt));
5761
5762 if self.attempt >= 2 {
5763 // Success after 2 attempts
5764 TaskExecutionResult::Success(Box::new("final_success".to_string()))
5765 } else {
5766 // Try to continue with another continuation
5767 let next_continuation = Arc::new(AlwaysRetryContinuation {
5768 log: self.log.clone(),
5769 attempt: self.attempt + 1,
5770 });
5771 let error = anyhow::anyhow!("Continuation attempt {} failed", self.attempt);
5772 TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
5773 error,
5774 next_continuation,
5775 ))
5776 }
5777 }
5778 }
5779
5780 // Task that immediately fails with a continuation
5781 let initial_continuation = Arc::new(AlwaysRetryContinuation {
5782 log: execution_log.clone(),
5783 attempt: 1,
5784 });
5785
5786 let log_for_task = execution_log.clone();
5787 let handle = tracker.spawn(async move {
5788 log_for_task
5789 .lock()
5790 .await
5791 .push("initial_task_failure".to_string());
5792 let error = anyhow::anyhow!("Initial task failure");
5793 let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
5794 error,
5795 initial_continuation,
5796 ));
5797 result
5798 });
5799
5800 let result = handle.await.expect("Task should complete");
5801
5802 if should_succeed {
5803 assert!(result.is_ok(), "{}: Task should succeed", description);
5804 assert_eq!(
5805 tracker.metrics().success(),
5806 1,
5807 "{}: Should have 1 success",
5808 description
5809 );
5810
5811 // Should have executed multiple continuations
5812 let log = execution_log.lock().await;
5813 assert!(
5814 log.len() > 2,
5815 "{}: Should have multiple log entries",
5816 description
5817 );
5818 assert!(log.contains(&"continuation_attempt_1".to_string()));
5819 } else {
5820 assert!(result.is_err(), "{}: Task should fail", description);
5821 assert_eq!(
5822 tracker.metrics().failed(),
5823 1,
5824 "{}: Should have 1 failure",
5825 description
5826 );
5827
5828 // Should have stopped early due to policy rejection
5829 let log = execution_log.lock().await;
5830 assert_eq!(
5831 log.len(),
5832 1,
5833 "{}: Should only have initial task entry",
5834 description
5835 );
5836 assert_eq!(log[0], "initial_task_failure");
5837 // Should NOT contain continuation attempts because policy rejected them
5838 assert!(
5839 !log.iter()
5840 .any(|entry| entry.contains("continuation_attempt")),
5841 "{}: Should not have continuation attempts, but got: {:?}",
5842 description,
5843 *log
5844 );
5845 }
5846 }
5847
5848 /// Test TaskHandle functionality
5849 ///
5850 /// This test verifies that:
5851 /// 1. TaskHandle can be awaited like a JoinHandle
5852 /// 2. TaskHandle provides access to the task's cancellation token
5853 /// 3. Individual task cancellation works correctly
5854 /// 4. TaskHandle methods (abort, is_finished) work as expected
5855 #[tokio::test]
5856 async fn test_task_handle_functionality() {
5857 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
5858
5859 // Test basic functionality - TaskHandle can be awaited
5860 let handle1 = tracker.spawn(async {
5861 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
5862 Ok("completed".to_string())
5863 });
5864
5865 // Verify we can access the cancellation token
5866 let cancel_token = handle1.cancellation_token();
5867 assert!(
5868 !cancel_token.is_cancelled(),
5869 "Token should not be cancelled initially"
5870 );
5871
5872 // Await the task
5873 let result1 = handle1.await.expect("Task should complete");
5874 assert!(result1.is_ok(), "Task should succeed");
5875 assert_eq!(result1.unwrap(), "completed");
5876
5877 // Test individual task cancellation
5878 let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
5879 tokio::select! {
5880 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
5881 CancellableTaskResult::Ok("task_was_not_cancelled".to_string())
5882 },
5883 _ = cancel_token.cancelled() => {
5884 CancellableTaskResult::Cancelled
5885 },
5886
5887 }
5888 });
5889
5890 let cancel_token2 = handle2.cancellation_token();
5891
5892 // Cancel this specific task
5893 cancel_token2.cancel();
5894
5895 // The task should be cancelled
5896 let result2 = handle2.await.expect("Task should complete");
5897 assert!(result2.is_err(), "Task should be cancelled");
5898 assert!(
5899 result2.unwrap_err().is_cancellation(),
5900 "Should be a cancellation error"
5901 );
5902
5903 // Test that other tasks are not affected
5904 let handle3 = tracker.spawn(async { Ok("not_cancelled".to_string()) });
5905
5906 let result3 = handle3.await.expect("Task should complete");
5907 assert!(result3.is_ok(), "Other tasks should not be affected");
5908 assert_eq!(result3.unwrap(), "not_cancelled");
5909
5910 // Test abort functionality
5911 let handle4 = tracker.spawn(async {
5912 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
5913 Ok("should_be_aborted".to_string())
5914 });
5915
5916 // Check is_finished before abort
5917 assert!(!handle4.is_finished(), "Task should not be finished yet");
5918
5919 // Abort the task
5920 handle4.abort();
5921
5922 // Task should be aborted (JoinError)
5923 let result4 = handle4.await;
5924 assert!(result4.is_err(), "Aborted task should return JoinError");
5925
5926 // Verify metrics
5927 assert_eq!(
5928 tracker.metrics().success(),
5929 2,
5930 "Should have 2 successful tasks"
5931 );
5932 assert_eq!(
5933 tracker.metrics().cancelled(),
5934 1,
5935 "Should have 1 cancelled task"
5936 );
5937 // Note: aborted tasks don't count as cancelled in our metrics
5938 }
5939
5940 /// Test TaskHandle with cancellable tasks
5941 #[tokio::test]
5942 async fn test_task_handle_with_cancellable_tasks() {
5943 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
5944
5945 // Test cancellable task with TaskHandle
5946 let handle = tracker.spawn_cancellable(|cancel_token| async move {
5947 tokio::select! {
5948 _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
5949 CancellableTaskResult::Ok("completed".to_string())
5950 },
5951 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
5952 }
5953 });
5954
5955 // Verify we can access the task's individual cancellation token
5956 let task_cancel_token = handle.cancellation_token();
5957 assert!(
5958 !task_cancel_token.is_cancelled(),
5959 "Task token should not be cancelled initially"
5960 );
5961
5962 // Let the task complete normally
5963 let result = handle.await.expect("Task should complete");
5964 assert!(result.is_ok(), "Task should succeed");
5965 assert_eq!(result.unwrap(), "completed");
5966
5967 // Test cancellation of cancellable task
5968 let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
5969 tokio::select! {
5970 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
5971 CancellableTaskResult::Ok("should_not_complete".to_string())
5972 },
5973 _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
5974 }
5975 });
5976
5977 // Cancel the specific task
5978 handle2.cancellation_token().cancel();
5979
5980 let result2 = handle2.await.expect("Task should complete");
5981 assert!(result2.is_err(), "Task should be cancelled");
5982 assert!(
5983 result2.unwrap_err().is_cancellation(),
5984 "Should be a cancellation error"
5985 );
5986
5987 // Verify metrics
5988 assert_eq!(
5989 tracker.metrics().success(),
5990 1,
5991 "Should have 1 successful task"
5992 );
5993 assert_eq!(
5994 tracker.metrics().cancelled(),
5995 1,
5996 "Should have 1 cancelled task"
5997 );
5998 }
5999
6000 /// Test FailedWithContinuation helper methods
6001 ///
6002 /// This test verifies that:
6003 /// 1. from_fn creates working continuations from simple closures
6004 /// 2. from_cancellable creates working continuations from cancellable closures
6005 /// 3. Both helpers integrate correctly with the task execution system
6006 #[tokio::test]
6007 async fn test_continuation_helpers() {
6008 let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
6009
6010 // Test from_fn helper
6011 let handle1 = tracker.spawn(async {
6012 let error =
6013 FailedWithContinuation::from_fn(anyhow::anyhow!("Initial failure"), || async {
6014 Ok("Success from from_fn".to_string())
6015 });
6016 let result: Result<String, anyhow::Error> = Err(error);
6017 result
6018 });
6019
6020 let result1 = handle1.await.expect("Task should complete");
6021 assert!(
6022 result1.is_ok(),
6023 "Task with from_fn continuation should succeed"
6024 );
6025 assert_eq!(result1.unwrap(), "Success from from_fn");
6026
6027 // Test from_cancellable helper
6028 let handle2 = tracker.spawn(async {
6029 let error = FailedWithContinuation::from_cancellable(
6030 anyhow::anyhow!("Initial failure"),
6031 |_cancel_token| async move { Ok("Success from from_cancellable".to_string()) },
6032 );
6033 let result: Result<String, anyhow::Error> = Err(error);
6034 result
6035 });
6036
6037 let result2 = handle2.await.expect("Task should complete");
6038 assert!(
6039 result2.is_ok(),
6040 "Task with from_cancellable continuation should succeed"
6041 );
6042 assert_eq!(result2.unwrap(), "Success from from_cancellable");
6043
6044 // Verify metrics
6045 assert_eq!(
6046 tracker.metrics().success(),
6047 2,
6048 "Should have 2 successful tasks"
6049 );
6050 assert_eq!(tracker.metrics().failed(), 0, "Should have 0 failed tasks");
6051 }
6052
6053 /// Test should_reschedule() policy method with mock scheduler tracking
6054 ///
6055 /// This test verifies that:
6056 /// 1. When should_reschedule() returns false, the guard is reused (efficient)
6057 /// 2. When should_reschedule() returns true, the guard is re-acquired through scheduler
6058 /// 3. The scheduler's acquire_execution_slot is called the expected number of times
6059 /// 4. Rescheduling works for both task-driven and policy-driven continuations
6060 #[rstest]
6061 #[case(false, 1, "Policy requests no rescheduling - should reuse guard")]
6062 #[case(true, 2, "Policy requests rescheduling - should re-acquire guard")]
6063 #[tokio::test]
6064 async fn test_should_reschedule_policy_control(
6065 #[case] should_reschedule: bool,
6066 #[case] expected_acquisitions: u32,
6067 #[case] description: &str,
6068 ) {
6069 // Mock scheduler that tracks acquisition calls
6070 #[derive(Debug)]
6071 struct MockScheduler {
6072 acquisition_count: Arc<AtomicU32>,
6073 }
6074
6075 impl MockScheduler {
6076 fn new() -> Self {
6077 Self {
6078 acquisition_count: Arc::new(AtomicU32::new(0)),
6079 }
6080 }
6081
6082 fn acquisition_count(&self) -> u32 {
6083 self.acquisition_count.load(Ordering::Relaxed)
6084 }
6085 }
6086
6087 #[async_trait]
6088 impl TaskScheduler for MockScheduler {
6089 async fn acquire_execution_slot(
6090 &self,
6091 _cancel_token: CancellationToken,
6092 ) -> SchedulingResult<Box<dyn ResourceGuard>> {
6093 self.acquisition_count.fetch_add(1, Ordering::Relaxed);
6094 SchedulingResult::Execute(Box::new(UnlimitedGuard))
6095 }
6096 }
6097
6098 // Policy that controls rescheduling behavior
6099 #[derive(Debug)]
6100 struct RescheduleTestPolicy {
6101 should_reschedule: bool,
6102 }
6103
6104 impl OnErrorPolicy for RescheduleTestPolicy {
6105 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
6106 Arc::new(RescheduleTestPolicy {
6107 should_reschedule: self.should_reschedule,
6108 })
6109 }
6110
6111 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
6112 None // Stateless policy
6113 }
6114
6115 fn allow_continuation(
6116 &self,
6117 _error: &anyhow::Error,
6118 _context: &OnErrorContext,
6119 ) -> bool {
6120 true // Always allow continuations for this test
6121 }
6122
6123 fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
6124 self.should_reschedule
6125 }
6126
6127 fn on_error(
6128 &self,
6129 _error: &anyhow::Error,
6130 _context: &mut OnErrorContext,
6131 ) -> ErrorResponse {
6132 ErrorResponse::Fail // Just fail when continuations are not allowed
6133 }
6134 }
6135
6136 let mock_scheduler = Arc::new(MockScheduler::new());
6137 let policy = Arc::new(RescheduleTestPolicy { should_reschedule });
6138 let tracker = TaskTracker::new(mock_scheduler.clone(), policy).unwrap();
6139 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
6140
6141 // Simple continuation that succeeds on second attempt
6142 #[derive(Debug)]
6143 struct SimpleRetryContinuation {
6144 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6145 }
6146
6147 #[async_trait]
6148 impl Continuation for SimpleRetryContinuation {
6149 async fn execute(
6150 &self,
6151 _cancel_token: CancellationToken,
6152 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
6153 self.log
6154 .lock()
6155 .await
6156 .push("continuation_executed".to_string());
6157
6158 // Succeed immediately
6159 TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
6160 }
6161 }
6162
6163 // Task that fails with a continuation
6164 let continuation = Arc::new(SimpleRetryContinuation {
6165 log: execution_log.clone(),
6166 });
6167
6168 let log_for_task = execution_log.clone();
6169 let handle = tracker.spawn(async move {
6170 log_for_task
6171 .lock()
6172 .await
6173 .push("initial_task_failure".to_string());
6174 let error = anyhow::anyhow!("Initial task failure");
6175 let result: Result<String, anyhow::Error> =
6176 Err(FailedWithContinuation::into_anyhow(error, continuation));
6177 result
6178 });
6179
6180 let result = handle.await.expect("Task should complete");
6181
6182 // Task should succeed regardless of rescheduling behavior
6183 assert!(result.is_ok(), "{}: Task should succeed", description);
6184 assert_eq!(
6185 tracker.metrics().success(),
6186 1,
6187 "{}: Should have 1 success",
6188 description
6189 );
6190
6191 // Verify the execution log
6192 let log = execution_log.lock().await;
6193 assert_eq!(
6194 log.len(),
6195 2,
6196 "{}: Should have initial task + continuation",
6197 description
6198 );
6199 assert_eq!(log[0], "initial_task_failure");
6200 assert_eq!(log[1], "continuation_executed");
6201
6202 // Most importantly: verify the scheduler acquisition count
6203 let actual_acquisitions = mock_scheduler.acquisition_count();
6204 assert_eq!(
6205 actual_acquisitions, expected_acquisitions,
6206 "{}: Expected {} scheduler acquisitions, got {}",
6207 description, expected_acquisitions, actual_acquisitions
6208 );
6209 }
6210
6211 /// Test continuation loop with custom action policies
6212 ///
6213 /// This tests that custom error actions can also provide continuations
6214 /// and that the loop behavior works correctly with policy-provided continuations
6215 ///
6216 /// NOTE: Uses fresh policy/action instances to avoid global state interference.
6217 #[rstest]
6218 #[case(1, true, "Custom action with 1 retry should succeed")]
6219 #[case(3, true, "Custom action with 3 retries should succeed")]
6220 #[tokio::test]
6221 async fn test_continuation_loop_with_custom_action_policy(
6222 unlimited_scheduler: Arc<UnlimitedScheduler>,
6223 #[case] max_retries: u32,
6224 #[case] should_succeed: bool,
6225 #[case] description: &str,
6226 ) {
6227 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
6228 let retry_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
6229
6230 // Custom action that provides continuations up to max_retries
6231 #[derive(Debug)]
6232 struct RetryAction {
6233 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6234 retry_count: Arc<std::sync::atomic::AtomicU32>,
6235 max_retries: u32,
6236 }
6237
6238 #[async_trait]
6239 impl OnErrorAction for RetryAction {
6240 async fn execute(
6241 &self,
6242 _error: &anyhow::Error,
6243 _task_id: TaskId,
6244 _attempt_count: u32,
6245 _context: &TaskExecutionContext,
6246 ) -> ActionResult {
6247 let current_retry = self
6248 .retry_count
6249 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
6250 + 1;
6251 self.log
6252 .lock()
6253 .await
6254 .push(format!("custom_action_retry_{}", current_retry));
6255
6256 if current_retry <= self.max_retries {
6257 // Provide a continuation that succeeds if this is the final retry
6258 #[derive(Debug)]
6259 struct RetryContinuation {
6260 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6261 retry_number: u32,
6262 max_retries: u32,
6263 }
6264
6265 #[async_trait]
6266 impl Continuation for RetryContinuation {
6267 async fn execute(
6268 &self,
6269 _cancel_token: CancellationToken,
6270 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>
6271 {
6272 self.log
6273 .lock()
6274 .await
6275 .push(format!("retry_continuation_{}", self.retry_number));
6276
6277 if self.retry_number >= self.max_retries {
6278 // Final retry succeeds
6279 TaskExecutionResult::Success(Box::new(format!(
6280 "success_after_{}_retries",
6281 self.retry_number
6282 )))
6283 } else {
6284 // Still need more retries, fail with regular error (not FailedWithContinuation)
6285 // This will trigger the custom action again
6286 TaskExecutionResult::Error(anyhow::anyhow!(
6287 "Retry {} still failing",
6288 self.retry_number
6289 ))
6290 }
6291 }
6292 }
6293
6294 let continuation = Arc::new(RetryContinuation {
6295 log: self.log.clone(),
6296 retry_number: current_retry,
6297 max_retries: self.max_retries,
6298 });
6299
6300 ActionResult::Continue { continuation }
6301 } else {
6302 // Exceeded max retries, cancel
6303 ActionResult::Shutdown
6304 }
6305 }
6306 }
6307
6308 // Custom policy that uses the retry action
6309 #[derive(Debug)]
6310 struct CustomRetryPolicy {
6311 action: Arc<RetryAction>,
6312 }
6313
6314 impl OnErrorPolicy for CustomRetryPolicy {
6315 fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
6316 Arc::new(CustomRetryPolicy {
6317 action: self.action.clone(),
6318 })
6319 }
6320
6321 fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
6322 None // Stateless policy - no heap allocation
6323 }
6324
6325 fn on_error(
6326 &self,
6327 _error: &anyhow::Error,
6328 _context: &mut OnErrorContext,
6329 ) -> ErrorResponse {
6330 ErrorResponse::Custom(Box::new(RetryAction {
6331 log: self.action.log.clone(),
6332 retry_count: self.action.retry_count.clone(),
6333 max_retries: self.action.max_retries,
6334 }))
6335 }
6336 }
6337
6338 let action = Arc::new(RetryAction {
6339 log: execution_log.clone(),
6340 retry_count: retry_count.clone(),
6341 max_retries,
6342 });
6343 let policy = Arc::new(CustomRetryPolicy { action });
6344 let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
6345
6346 // Task that always fails with regular error (not FailedWithContinuation)
6347 let log_for_task = execution_log.clone();
6348 let handle = tracker.spawn(async move {
6349 log_for_task
6350 .lock()
6351 .await
6352 .push("original_task_failed".to_string());
6353 let result: Result<String, anyhow::Error> =
6354 Err(anyhow::anyhow!("Original task failure"));
6355 result
6356 });
6357
6358 // Execute and verify results
6359 let result = handle.await.expect("Task should complete");
6360
6361 if should_succeed {
6362 assert!(result.is_ok(), "{}: Task should succeed", description);
6363 assert_eq!(
6364 tracker.metrics().success(),
6365 1,
6366 "{}: Should have 1 success",
6367 description
6368 );
6369
6370 // Verify the retry sequence
6371 let log = execution_log.lock().await;
6372 let expected_entries = 1 + (max_retries * 2); // original + (action + continuation) per retry
6373 assert_eq!(
6374 log.len(),
6375 expected_entries as usize,
6376 "{}: Should have {} log entries",
6377 description,
6378 expected_entries
6379 );
6380
6381 assert_eq!(
6382 retry_count.load(std::sync::atomic::Ordering::Relaxed),
6383 max_retries,
6384 "{}: Should have made {} retry attempts",
6385 description,
6386 max_retries
6387 );
6388 } else {
6389 assert!(result.is_err(), "{}: Task should fail", description);
6390 assert!(
6391 tracker.cancellation_token().is_cancelled(),
6392 "{}: Should be cancelled",
6393 description
6394 );
6395
6396 // Should have stopped after max_retries
6397 let final_retry_count = retry_count.load(std::sync::atomic::Ordering::Relaxed);
6398 assert!(
6399 final_retry_count > max_retries,
6400 "{}: Should have exceeded max_retries ({}), got {}",
6401 description,
6402 max_retries,
6403 final_retry_count
6404 );
6405 }
6406 }
6407
6408 /// Test mixed continuation sources (task-driven + policy-driven)
6409 ///
6410 /// This test verifies that both task-provided continuations and policy-provided
6411 /// continuations can work together in the same execution flow
6412 #[rstest]
6413 #[tokio::test]
6414 async fn test_mixed_continuation_sources(
6415 unlimited_scheduler: Arc<UnlimitedScheduler>,
6416 log_policy: Arc<LogOnlyPolicy>,
6417 ) {
6418 let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
6419 let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
6420
6421 // Task that provides a continuation, which then fails with regular error
6422 let log_for_task = execution_log.clone();
6423 let log_for_continuation = execution_log.clone();
6424
6425 #[derive(Debug)]
6426 struct MixedContinuation {
6427 log: Arc<tokio::sync::Mutex<Vec<String>>>,
6428 }
6429
6430 #[async_trait]
6431 impl Continuation for MixedContinuation {
6432 async fn execute(
6433 &self,
6434 _cancel_token: CancellationToken,
6435 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
6436 self.log
6437 .lock()
6438 .await
6439 .push("task_continuation_executed".to_string());
6440 // This continuation fails with a regular error (not FailedWithContinuation)
6441 // So it will be handled by the policy (LogOnlyPolicy just continues)
6442 TaskExecutionResult::Error(anyhow::anyhow!("Task continuation failed"))
6443 }
6444 }
6445
6446 let continuation = Arc::new(MixedContinuation {
6447 log: log_for_continuation,
6448 });
6449
6450 let handle = tracker.spawn(async move {
6451 log_for_task
6452 .lock()
6453 .await
6454 .push("original_task_executed".to_string());
6455
6456 // Task provides continuation
6457 let error = anyhow::anyhow!("Original task failed");
6458 let result: Result<String, anyhow::Error> =
6459 Err(FailedWithContinuation::into_anyhow(error, continuation));
6460 result
6461 });
6462
6463 // Execute - should fail because continuation fails and LogOnlyPolicy just logs
6464 let result = handle.await.expect("Task should complete");
6465 assert!(
6466 result.is_err(),
6467 "Should fail because continuation fails and policy just logs"
6468 );
6469
6470 // Verify execution sequence
6471 let log = execution_log.lock().await;
6472 assert_eq!(log.len(), 2);
6473 assert_eq!(log[0], "original_task_executed");
6474 assert_eq!(log[1], "task_continuation_executed");
6475
6476 // Verify metrics - should show failure from continuation
6477 assert_eq!(tracker.metrics().success(), 0);
6478 assert_eq!(tracker.metrics().failed(), 1);
6479 }
6480
6481 /// Debug test to understand the threshold policy behavior in retry loop
6482 #[rstest]
6483 #[tokio::test]
6484 async fn debug_threshold_policy_in_retry_loop(unlimited_scheduler: Arc<UnlimitedScheduler>) {
6485 let policy = ThresholdCancelPolicy::with_threshold(2);
6486 let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
6487
6488 // Simple continuation that always fails with regular error
6489 #[derive(Debug)]
6490 struct AlwaysFailContinuation {
6491 attempt: Arc<std::sync::atomic::AtomicU32>,
6492 }
6493
6494 #[async_trait]
6495 impl Continuation for AlwaysFailContinuation {
6496 async fn execute(
6497 &self,
6498 _cancel_token: CancellationToken,
6499 ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
6500 let attempt_num = self
6501 .attempt
6502 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
6503 + 1;
6504 println!("Continuation attempt {}", attempt_num);
6505 TaskExecutionResult::Error(anyhow::anyhow!(
6506 "Continuation attempt {} failed",
6507 attempt_num
6508 ))
6509 }
6510 }
6511
6512 let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
6513 let continuation = Arc::new(AlwaysFailContinuation {
6514 attempt: attempt_counter.clone(),
6515 });
6516
6517 let handle = tracker.spawn(async move {
6518 println!("Original task executing");
6519 let error = anyhow::anyhow!("Original task failed");
6520 let result: Result<String, anyhow::Error> =
6521 Err(FailedWithContinuation::into_anyhow(error, continuation));
6522 result
6523 });
6524
6525 let result = handle.await.expect("Task should complete");
6526 println!("Final result: {:?}", result.is_ok());
6527 println!("Policy failure count: {}", policy.failure_count());
6528 println!(
6529 "Continuation attempts: {}",
6530 attempt_counter.load(std::sync::atomic::Ordering::Relaxed)
6531 );
6532 println!(
6533 "Tracker cancelled: {}",
6534 tracker.cancellation_token().is_cancelled()
6535 );
6536 println!(
6537 "Metrics: success={}, failed={}",
6538 tracker.metrics().success(),
6539 tracker.metrics().failed()
6540 );
6541
6542 // This should help us understand what's happening
6543 }
6544}