use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::metrics::MetricsHierarchy;
use crate::metrics::prometheus_names::task_tracker;
use anyhow::Result;
use async_trait::async_trait;
use derive_builder::Builder;
use std::collections::HashSet;
use std::sync::{Mutex, RwLock, Weak};
use std::time::Duration;
use thiserror::Error;
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker as TokioTaskTracker;
use tracing::{Instrument, debug, error, warn};
use uuid::Uuid;
/// Errors produced by tasks spawned on a `TaskTracker`.
#[derive(Error, Debug)]
pub enum TaskError {
    /// The task was cancelled before or during execution.
    #[error("Task was cancelled")]
    Cancelled,
    /// The task ran and returned an error.
    #[error(transparent)]
    Failed(#[from] anyhow::Error),
    /// A spawn was attempted on a tracker that is already closed.
    #[error("Cannot spawn task on a closed tracker")]
    TrackerClosed,
}
impl TaskError {
    /// True when this error is the cancellation variant.
    pub fn is_cancellation(&self) -> bool {
        matches!(self, TaskError::Cancelled)
    }

    /// True when this error wraps an underlying task failure.
    pub fn is_failure(&self) -> bool {
        matches!(self, TaskError::Failed(_))
    }

    /// Convert into `anyhow::Error`, preserving the original failure payload
    /// and synthesizing messages for the other variants.
    pub fn into_anyhow(self) -> anyhow::Error {
        match self {
            TaskError::Cancelled => anyhow::anyhow!("Task was cancelled"),
            TaskError::TrackerClosed => {
                anyhow::anyhow!("Cannot spawn task on a closed tracker")
            }
            TaskError::Failed(err) => err,
        }
    }
}
/// Handle to a spawned task: awaitable for the task's result, and carries the
/// per-task cancellation token.
pub struct TaskHandle<T> {
    // Resolves to the `Result<T, TaskError>` produced by the policy wrapper.
    join_handle: JoinHandle<Result<T, TaskError>>,
    // Per-task child token (see `TaskTrackerInner::spawn`).
    cancel_token: CancellationToken,
}
impl<T> TaskHandle<T> {
    /// Pair a join handle with the task's cancellation token.
    pub(crate) fn new(
        join_handle: JoinHandle<Result<T, TaskError>>,
        cancel_token: CancellationToken,
    ) -> Self {
        Self {
            cancel_token,
            join_handle,
        }
    }

    /// The token associated with this task.
    pub fn cancellation_token(&self) -> &CancellationToken {
        &self.cancel_token
    }

    /// Abort the underlying tokio task immediately.
    pub fn abort(&self) {
        self.join_handle.abort();
    }

    /// Whether the underlying tokio task has run to completion.
    pub fn is_finished(&self) -> bool {
        self.join_handle.is_finished()
    }
}
impl<T> std::future::Future for TaskHandle<T> {
    type Output = Result<Result<T, TaskError>, tokio::task::JoinError>;

    /// Delegate polling to the inner `JoinHandle` (which is `Unpin`).
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();
        std::pin::Pin::new(&mut this.join_handle).poll(cx)
    }
}
impl<T> std::fmt::Debug for TaskHandle<T> {
    // `JoinHandle` is not `Debug`-friendly here; print a placeholder for it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("TaskHandle");
        dbg.field("join_handle", &"<JoinHandle>");
        dbg.field("cancel_token", &self.cancel_token);
        dbg.finish()
    }
}
/// A follow-up unit of work that an error policy can hand back after a task
/// fails. The result is type-erased (`Box<dyn Any>`); the retry loop
/// downcasts it back to the task's concrete output type.
#[async_trait]
pub trait Continuation: Send + Sync + std::fmt::Debug + std::any::Any {
    /// Run the continuation; `cancel_token` signals cancellation.
    async fn execute(
        &self,
        cancel_token: CancellationToken,
    ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>;
}
/// Error wrapper carrying a continuation to run after the failure.
/// Recovered from an `anyhow::Error` via `FailedWithContinuationExt`.
#[derive(Error, Debug)]
#[error("Task failed with continuation: {source}")]
pub struct FailedWithContinuation {
    /// The underlying failure.
    #[source]
    pub source: anyhow::Error,
    /// Work to run next if the error policy allows continuation.
    pub continuation: Arc<dyn Continuation + Send + Sync + 'static>,
}
impl FailedWithContinuation {
    /// Pair an error with the continuation to run next.
    pub fn new(
        source: anyhow::Error,
        continuation: Arc<dyn Continuation + Send + Sync + 'static>,
    ) -> Self {
        Self {
            source,
            continuation,
        }
    }

    /// Build the wrapper and erase it into an `anyhow::Error`.
    pub fn into_anyhow(
        source: anyhow::Error,
        continuation: Arc<dyn Continuation + Send + Sync + 'static>,
    ) -> anyhow::Error {
        anyhow::Error::new(Self::new(source, continuation))
    }

    /// Attach a plain async closure as the continuation.
    pub fn from_fn<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
        T: Send + 'static,
    {
        Self::into_anyhow(source, Arc::new(FnContinuation { f: Box::new(f) }))
    }

    /// Attach a cancellation-aware async closure as the continuation.
    pub fn from_cancellable<F, Fut, T>(source: anyhow::Error, f: F) -> anyhow::Error
    where
        F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
        T: Send + 'static,
    {
        Self::into_anyhow(source, Arc::new(CancellableFnContinuation { f: Box::new(f) }))
    }
}
/// Extension methods for recovering a `FailedWithContinuation` payload from
/// an `anyhow::Error`.
pub trait FailedWithContinuationExt {
    /// The attached continuation, if the error carries one.
    fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>>;
    /// Whether the error carries a continuation.
    fn has_continuation(&self) -> bool;
}
impl FailedWithContinuationExt for anyhow::Error {
    /// Downcast to `FailedWithContinuation` and clone out its continuation.
    fn extract_continuation(&self) -> Option<Arc<dyn Continuation + Send + Sync + 'static>> {
        // `Option::map` replaces the manual `if let … Some/None` pattern.
        self.downcast_ref::<FailedWithContinuation>()
            .map(|err| err.continuation.clone())
    }

    /// True when this error is a `FailedWithContinuation`.
    fn has_continuation(&self) -> bool {
        self.downcast_ref::<FailedWithContinuation>().is_some()
    }
}
/// Continuation backed by a plain async closure (ignores cancellation).
struct FnContinuation<F> {
    f: Box<F>,
}

impl<F> std::fmt::Debug for FnContinuation<F> {
    // Closures have no `Debug` impl; print a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FnContinuation")
            .field("f", &"<closure>")
            .finish()
    }
}
#[async_trait]
impl<F, Fut, T> Continuation for FnContinuation<F>
where
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
    T: Send + 'static,
{
    /// Invoke the closure; the cancellation token is ignored by design.
    async fn execute(
        &self,
        _cancel_token: CancellationToken,
    ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
        let outcome = (self.f)().await;
        match outcome {
            Err(error) => TaskExecutionResult::Error(error),
            Ok(result) => TaskExecutionResult::Success(Box::new(result)),
        }
    }
}
/// Continuation backed by an async closure that receives a cancellation token.
struct CancellableFnContinuation<F> {
    f: Box<F>,
}

impl<F> std::fmt::Debug for CancellableFnContinuation<F> {
    // Closures have no `Debug` impl; print a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellableFnContinuation")
            .field("f", &"<closure>")
            .finish()
    }
}
#[async_trait]
impl<F, Fut, T> Continuation for CancellableFnContinuation<F>
where
    F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<T, anyhow::Error>> + Send + 'static,
    T: Send + 'static,
{
    /// Invoke the closure, forwarding the cancellation token.
    async fn execute(
        &self,
        cancel_token: CancellationToken,
    ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
        let outcome = (self.f)(cancel_token).await;
        match outcome {
            Err(error) => TaskExecutionResult::Error(error),
            Ok(result) => TaskExecutionResult::Success(Box::new(result)),
        }
    }
}
/// How tasks acquire execution slots.
#[derive(Debug, Clone)]
pub enum SchedulingPolicy {
    /// Run every task immediately, no concurrency bound.
    Unlimited,
    /// Bound concurrency with a semaphore of the given permit count.
    Semaphore(usize),
}

/// Mutable state threaded through the error policy across retries of one task.
pub struct OnErrorContext {
    /// Number of failed attempts so far (1 on the first error).
    pub attempt_count: u32,
    /// Id of the failing task.
    pub task_id: TaskId,
    /// Scheduler/metrics handles available to custom actions.
    pub execution_context: TaskExecutionContext,
    /// Policy-defined per-task state from `OnErrorPolicy::create_context`.
    pub state: Option<Box<dyn std::any::Any + Send + 'static>>,
}
/// Policy consulted whenever a task attempt fails.
pub trait OnErrorPolicy: Send + Sync + std::fmt::Debug {
    /// Policy instance to install on a child tracker.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy>;
    /// Optional per-task state, stored in `OnErrorContext::state`.
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>>;
    /// Decide how the retry loop should respond to `error`.
    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse;
    /// Whether a `FailedWithContinuation` may be honoured (default: yes).
    fn allow_continuation(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
        true
    }
    /// Whether a continuation must reacquire a scheduler slot (default: no).
    fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
        false
    }
}
/// Declarative error-handling choices (data-only counterpart to `OnErrorPolicy`).
#[derive(Debug, Clone)]
pub enum ErrorPolicy {
    /// Log failures and keep going.
    LogOnly,
    /// Cancel the tracker on the first failure.
    CancelOnError,
    /// Cancel when the error matches any of these patterns.
    CancelOnPatterns(Vec<String>),
    /// Cancel once this many failures have accumulated.
    CancelOnThreshold { max_failures: usize },
    /// Cancel when the failure rate within the window exceeds the bound.
    CancelOnRate {
        max_failure_rate: f32,
        window_secs: u64,
    },
}

/// What `OnErrorPolicy::on_error` tells the retry loop to do.
#[derive(Debug)]
pub enum ErrorResponse {
    /// Let the task fail with the original error.
    Fail,
    /// Fail the task and cancel the whole tracker.
    Shutdown,
    /// Delegate to a custom asynchronous action.
    Custom(Box<dyn OnErrorAction>),
}

/// Custom asynchronous reaction to a task error.
#[async_trait]
pub trait OnErrorAction: Send + Sync + std::fmt::Debug {
    async fn execute(
        &self,
        error: &anyhow::Error,
        task_id: TaskId,
        attempt_count: u32,
        context: &TaskExecutionContext,
    ) -> ActionResult;
}

/// Whether the retry loop keeps its current scheduler slot or reacquires one.
#[derive(Debug, Clone, PartialEq, Eq)]
enum GuardState {
    /// Keep the currently held execution slot.
    Keep,
    /// Acquire a fresh execution slot before the next attempt.
    Reschedule,
}

/// Outcome of an error action, consumed by the retry loop.
#[derive(Debug)]
pub enum ActionResult {
    /// Give up; surface the error.
    Fail,
    /// Run the given continuation as the next attempt.
    Continue {
        continuation: Arc<dyn Continuation + Send + Sync + 'static>,
    },
    /// Fail the task and cancel the tracker.
    Shutdown,
}
/// Handles exposed to custom error actions.
pub struct TaskExecutionContext {
    pub scheduler: Arc<dyn TaskScheduler>,
    pub metrics: Arc<dyn HierarchicalTaskMetrics>,
}

/// Tri-state outcome of one task attempt.
#[derive(Debug)]
pub enum TaskExecutionResult<T> {
    Success(T),
    Cancelled,
    Error(anyhow::Error),
}

/// Internal abstraction over the two spawn flavours (plain future vs
/// cancellation-aware closure); `execute` runs one attempt.
#[async_trait]
trait TaskExecutor<T>: Send {
    async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T>;
}
/// Executor for a one-shot future spawned via `spawn`.
struct RegularTaskExecutor<F, T>
where
    F: Future<Output = Result<T>> + Send + 'static,
    T: Send + 'static,
{
    // `Option` because a plain future can only be awaited once; `take()`
    // in `execute` enforces single consumption.
    future: Option<F>,
    _phantom: std::marker::PhantomData<T>,
}

impl<F, T> RegularTaskExecutor<F, T>
where
    F: Future<Output = Result<T>> + Send + 'static,
    T: Send + 'static,
{
    fn new(future: F) -> Self {
        Self {
            future: Some(future),
            _phantom: std::marker::PhantomData,
        }
    }
}
#[async_trait]
impl<F, T> TaskExecutor<T> for RegularTaskExecutor<F, T>
where
F: Future<Output = Result<T>> + Send + 'static,
T: Send + 'static,
{
async fn execute(&mut self, _cancel_token: CancellationToken) -> TaskExecutionResult<T> {
if let Some(future) = self.future.take() {
match future.await {
Ok(value) => TaskExecutionResult::Success(value),
Err(error) => TaskExecutionResult::Error(error),
}
} else {
TaskExecutionResult::Error(anyhow::anyhow!("Regular task already consumed"))
}
}
}
/// Executor for cancellation-aware tasks spawned via `spawn_cancellable`.
/// The closure is `FnMut`, so it can be invoked again for further attempts.
struct CancellableTaskExecutor<F, Fut, T>
where
    F: FnMut(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
    T: Send + 'static,
{
    task_fn: F,
}

impl<F, Fut, T> CancellableTaskExecutor<F, Fut, T>
where
    F: FnMut(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
    T: Send + 'static,
{
    fn new(task_fn: F) -> Self {
        Self { task_fn }
    }
}
#[async_trait]
impl<F, Fut, T> TaskExecutor<T> for CancellableTaskExecutor<F, Fut, T>
where
    F: FnMut(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
    T: Send + 'static,
{
    /// Call the closure with the provided token and map its tri-state result.
    async fn execute(&mut self, cancel_token: CancellationToken) -> TaskExecutionResult<T> {
        match (self.task_fn)(cancel_token).await {
            CancellableTaskResult::Cancelled => TaskExecutionResult::Cancelled,
            CancellableTaskResult::Ok(value) => TaskExecutionResult::Success(value),
            CancellableTaskResult::Err(error) => TaskExecutionResult::Error(error),
        }
    }
}
/// Convenience: wrap a policy/scheduler value in an `Arc`.
pub trait ArcPolicy: Sized + Send + Sync + 'static {
    fn new_arc(self) -> Arc<Self> {
        Arc::new(self)
    }
}

/// Unique task identifier (random UUID v4).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(Uuid);

impl TaskId {
    fn new() -> Self {
        Self(Uuid::new_v4())
    }
}

impl std::fmt::Display for TaskId {
    // Rendered as "task-<uuid>", used in tracing fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "task-{}", self.0)
    }
}
/// Terminal status of a task, with the failure rendered as a string.
#[derive(Debug, Clone, PartialEq)]
pub enum CompletionStatus {
    Ok,
    Cancelled,
    Failed(String),
}

/// Result type returned by cancellation-aware task closures.
#[derive(Debug)]
pub enum CancellableTaskResult<T> {
    Ok(T),
    Cancelled,
    Err(anyhow::Error),
}

/// Outcome of asking a scheduler for an execution slot.
#[derive(Debug)]
pub enum SchedulingResult<T> {
    /// Slot granted; payload is the resource guard to hold while running.
    Execute(T),
    /// Cancelled while waiting for a slot.
    Cancelled,
    /// Scheduler refused the task, with a reason.
    Rejected(String),
}

/// Marker for resources held for the duration of task execution
/// (e.g. a semaphore permit released on drop).
pub trait ResourceGuard: Send + 'static {
}

/// Admission control: hands out execution slots to tasks.
#[async_trait]
pub trait TaskScheduler: Send + Sync + std::fmt::Debug {
    async fn acquire_execution_slot(
        &self,
        cancel_token: CancellationToken,
    ) -> SchedulingResult<Box<dyn ResourceGuard>>;
}
/// Counter interface for task lifecycle metrics. Implementations may forward
/// increments to a parent (see `ChildTaskMetrics`), hence "hierarchical".
pub trait HierarchicalTaskMetrics: Send + Sync + std::fmt::Debug {
    fn increment_issued(&self);
    fn increment_started(&self);
    fn increment_success(&self);
    fn increment_cancelled(&self);
    fn increment_failed(&self);
    fn increment_rejected(&self);
    fn issued(&self) -> u64;
    fn started(&self) -> u64;
    fn success(&self) -> u64;
    fn cancelled(&self) -> u64;
    fn failed(&self) -> u64;
    fn rejected(&self) -> u64;
    /// Tasks that reached any terminal state.
    fn total_completed(&self) -> u64 {
        self.success() + self.cancelled() + self.failed() + self.rejected()
    }
    /// Issued but not yet terminal (saturating: counters are read racily).
    fn pending(&self) -> u64 {
        self.issued().saturating_sub(self.total_completed())
    }
    /// Started but not yet terminal.
    fn active(&self) -> u64 {
        self.started().saturating_sub(self.total_completed())
    }
    /// Issued but not yet started.
    fn queued(&self) -> u64 {
        self.issued().saturating_sub(self.started())
    }
}
/// In-process metrics backed by atomic counters.
#[derive(Debug, Default)]
pub struct TaskMetrics {
    pub issued_count: AtomicU64,
    pub started_count: AtomicU64,
    pub success_count: AtomicU64,
    pub cancelled_count: AtomicU64,
    pub failed_count: AtomicU64,
    pub rejected_count: AtomicU64,
}

impl TaskMetrics {
    /// All counters start at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
// All counters use `Ordering::Relaxed`: they are independent statistics and
// no cross-counter ordering is required.
impl HierarchicalTaskMetrics for TaskMetrics {
    fn increment_issued(&self) {
        self.issued_count.fetch_add(1, Ordering::Relaxed);
    }
    fn increment_started(&self) {
        self.started_count.fetch_add(1, Ordering::Relaxed);
    }
    fn increment_success(&self) {
        self.success_count.fetch_add(1, Ordering::Relaxed);
    }
    fn increment_cancelled(&self) {
        self.cancelled_count.fetch_add(1, Ordering::Relaxed);
    }
    fn increment_failed(&self) {
        self.failed_count.fetch_add(1, Ordering::Relaxed);
    }
    fn increment_rejected(&self) {
        self.rejected_count.fetch_add(1, Ordering::Relaxed);
    }
    fn issued(&self) -> u64 {
        self.issued_count.load(Ordering::Relaxed)
    }
    fn started(&self) -> u64 {
        self.started_count.load(Ordering::Relaxed)
    }
    fn success(&self) -> u64 {
        self.success_count.load(Ordering::Relaxed)
    }
    fn cancelled(&self) -> u64 {
        self.cancelled_count.load(Ordering::Relaxed)
    }
    fn failed(&self) -> u64 {
        self.failed_count.load(Ordering::Relaxed)
    }
    fn rejected(&self) -> u64 {
        self.rejected_count.load(Ordering::Relaxed)
    }
}
/// Metrics implementation that reports directly to Prometheus counters.
#[derive(Debug)]
pub struct PrometheusTaskMetrics {
    prometheus_issued: prometheus::IntCounter,
    prometheus_started: prometheus::IntCounter,
    prometheus_success: prometheus::IntCounter,
    prometheus_cancelled: prometheus::IntCounter,
    prometheus_failed: prometheus::IntCounter,
    prometheus_rejected: prometheus::IntCounter,
}
impl PrometheusTaskMetrics {
    /// Register the six task-lifecycle counters on `registry`, prefixing each
    /// metric name with `component_name`.
    ///
    /// # Errors
    /// Propagates any counter-registration failure from the registry.
    pub fn new<R: MetricsHierarchy>(registry: &R, component_name: &str) -> anyhow::Result<Self> {
        let metrics = registry.metrics();
        // One helper instead of six copies of the same registration call.
        let make = |suffix: &str, help: &str| {
            metrics.create_intcounter(&format!("{}_{}", component_name, suffix), help, &[])
        };
        Ok(Self {
            prometheus_issued: make(
                task_tracker::TASKS_ISSUED_TOTAL,
                "Total number of tasks issued/submitted",
            )?,
            prometheus_started: make(
                task_tracker::TASKS_STARTED_TOTAL,
                "Total number of tasks started",
            )?,
            prometheus_success: make(
                task_tracker::TASKS_SUCCESS_TOTAL,
                "Total number of successfully completed tasks",
            )?,
            prometheus_cancelled: make(
                task_tracker::TASKS_CANCELLED_TOTAL,
                "Total number of cancelled tasks",
            )?,
            prometheus_failed: make(
                task_tracker::TASKS_FAILED_TOTAL,
                "Total number of failed tasks",
            )?,
            prometheus_rejected: make(
                task_tracker::TASKS_REJECTED_TOTAL,
                "Total number of rejected tasks",
            )?,
        })
    }
}
// Reads report the Prometheus counter values directly; no local shadow copy.
impl HierarchicalTaskMetrics for PrometheusTaskMetrics {
    fn increment_issued(&self) {
        self.prometheus_issued.inc();
    }
    fn increment_started(&self) {
        self.prometheus_started.inc();
    }
    fn increment_success(&self) {
        self.prometheus_success.inc();
    }
    fn increment_cancelled(&self) {
        self.prometheus_cancelled.inc();
    }
    fn increment_failed(&self) {
        self.prometheus_failed.inc();
    }
    fn increment_rejected(&self) {
        self.prometheus_rejected.inc();
    }
    fn issued(&self) -> u64 {
        self.prometheus_issued.get()
    }
    fn started(&self) -> u64 {
        self.prometheus_started.get()
    }
    fn success(&self) -> u64 {
        self.prometheus_success.get()
    }
    fn cancelled(&self) -> u64 {
        self.prometheus_cancelled.get()
    }
    fn failed(&self) -> u64 {
        self.prometheus_failed.get()
    }
    fn rejected(&self) -> u64 {
        self.prometheus_rejected.get()
    }
}
/// Metrics node for a child tracker: counts locally and forwards every
/// increment to the parent, so ancestors see aggregated totals.
#[derive(Debug)]
struct ChildTaskMetrics {
    local_metrics: TaskMetrics,
    parent_metrics: Arc<dyn HierarchicalTaskMetrics>,
}

impl ChildTaskMetrics {
    fn new(parent_metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
        Self {
            local_metrics: TaskMetrics::new(),
            parent_metrics,
        }
    }
}
// Increments propagate up the hierarchy; reads report only the local counts.
impl HierarchicalTaskMetrics for ChildTaskMetrics {
    fn increment_issued(&self) {
        self.local_metrics.increment_issued();
        self.parent_metrics.increment_issued();
    }
    fn increment_started(&self) {
        self.local_metrics.increment_started();
        self.parent_metrics.increment_started();
    }
    fn increment_success(&self) {
        self.local_metrics.increment_success();
        self.parent_metrics.increment_success();
    }
    fn increment_cancelled(&self) {
        self.local_metrics.increment_cancelled();
        self.parent_metrics.increment_cancelled();
    }
    fn increment_failed(&self) {
        self.local_metrics.increment_failed();
        self.parent_metrics.increment_failed();
    }
    fn increment_rejected(&self) {
        self.local_metrics.increment_rejected();
        self.parent_metrics.increment_rejected();
    }
    fn issued(&self) -> u64 {
        self.local_metrics.issued()
    }
    fn started(&self) -> u64 {
        self.local_metrics.started()
    }
    fn success(&self) -> u64 {
        self.local_metrics.success()
    }
    fn cancelled(&self) -> u64 {
        self.local_metrics.cancelled()
    }
    fn failed(&self) -> u64 {
        self.local_metrics.failed()
    }
    fn rejected(&self) -> u64 {
        self.local_metrics.rejected()
    }
}
/// Builder for a child `TaskTracker` that allows the scheduler and/or error
/// policy to be overridden; unset options are inherited from the parent.
pub struct ChildTrackerBuilder<'parent> {
    parent: &'parent TaskTracker,
    scheduler: Option<Arc<dyn TaskScheduler>>,
    error_policy: Option<Arc<dyn OnErrorPolicy>>,
}
impl<'parent> ChildTrackerBuilder<'parent> {
    /// Start building a child of `parent` with no overrides.
    pub fn new(parent: &'parent TaskTracker) -> Self {
        Self {
            parent,
            scheduler: None,
            error_policy: None,
        }
    }

    /// Override the scheduler (default: share the parent's).
    pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Override the error policy (default: `parent.error_policy.create_child()`).
    pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
        self.error_policy = Some(error_policy);
        self
    }

    /// Create the child tracker, linked into the parent's cancellation and
    /// metrics hierarchy.
    ///
    /// # Errors
    /// Fails if the parent tracker has already been closed.
    pub fn build(self) -> anyhow::Result<TaskTracker> {
        if self.parent.is_closed() {
            return Err(anyhow::anyhow!(
                "Cannot create child tracker from closed parent tracker"
            ));
        }
        let parent = self.parent.0.clone();
        // Child token: cancelling the parent cancels the child, not vice versa.
        let child_cancel_token = parent.cancel_token.child_token();
        let child_metrics = Arc::new(ChildTaskMetrics::new(parent.metrics.clone()));
        let scheduler = self.scheduler.unwrap_or_else(|| parent.scheduler.clone());
        let error_policy = self
            .error_policy
            .unwrap_or_else(|| parent.error_policy.create_child());
        let child = Arc::new(TaskTrackerInner {
            tokio_tracker: TokioTaskTracker::new(),
            // Fix: keep the upward link, consistent with
            // `TaskTrackerInner::child_tracker` (this path previously left it
            // `None`, silently dropping the parent reference).
            parent: Some(parent.clone()),
            scheduler,
            error_policy,
            metrics: child_metrics,
            cancel_token: child_cancel_token,
            children: RwLock::new(Vec::new()),
        });
        parent
            .children
            .write()
            .unwrap()
            .push(Arc::downgrade(&child));
        parent.cleanup_dead_children();
        Ok(TaskTracker(child))
    }
}
/// Shared state behind a `TaskTracker`.
struct TaskTrackerInner {
    /// Underlying tokio tracker owning the spawned tasks.
    tokio_tracker: TokioTaskTracker,
    /// Strong upward link (keeps ancestors alive while this child exists).
    parent: Option<Arc<TaskTrackerInner>>,
    scheduler: Arc<dyn TaskScheduler>,
    error_policy: Arc<dyn OnErrorPolicy>,
    metrics: Arc<dyn HierarchicalTaskMetrics>,
    /// Cancelling this token reaches all tasks (and children, via child tokens).
    cancel_token: CancellationToken,
    /// Weak downward links; dead entries are pruned opportunistically.
    children: RwLock<Vec<Weak<TaskTrackerInner>>>,
}

/// Cheaply clonable handle to one node of a task-tracking hierarchy.
#[derive(Clone)]
pub struct TaskTracker(Arc<TaskTrackerInner>);

/// Builder for a root `TaskTracker`; scheduler and error policy are required.
#[derive(Default)]
pub struct TaskTrackerBuilder {
    scheduler: Option<Arc<dyn TaskScheduler>>,
    error_policy: Option<Arc<dyn OnErrorPolicy>>,
    metrics: Option<Arc<dyn HierarchicalTaskMetrics>>,
    cancel_token: Option<CancellationToken>,
}
impl TaskTrackerBuilder {
pub fn scheduler(mut self, scheduler: Arc<dyn TaskScheduler>) -> Self {
self.scheduler = Some(scheduler);
self
}
pub fn error_policy(mut self, error_policy: Arc<dyn OnErrorPolicy>) -> Self {
self.error_policy = Some(error_policy);
self
}
pub fn metrics(mut self, metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
self.metrics = Some(metrics);
self
}
pub fn cancel_token(mut self, cancel_token: CancellationToken) -> Self {
self.cancel_token = Some(cancel_token);
self
}
pub fn build(self) -> anyhow::Result<TaskTracker> {
let scheduler = self
.scheduler
.ok_or_else(|| anyhow::anyhow!("TaskTracker requires a scheduler"))?;
let error_policy = self
.error_policy
.ok_or_else(|| anyhow::anyhow!("TaskTracker requires an error policy"))?;
let metrics = self.metrics.unwrap_or_else(|| Arc::new(TaskMetrics::new()));
let cancel_token = self.cancel_token.unwrap_or_default();
let inner = TaskTrackerInner {
tokio_tracker: TokioTaskTracker::new(),
parent: None,
scheduler,
error_policy,
metrics,
cancel_token,
children: RwLock::new(Vec::new()),
};
Ok(TaskTracker(Arc::new(inner)))
}
}
impl TaskTracker {
    /// Start building a root tracker.
    pub fn builder() -> TaskTrackerBuilder {
        TaskTrackerBuilder::default()
    }
    /// Root tracker with default metrics and a fresh cancellation token.
    pub fn new(
        scheduler: Arc<dyn TaskScheduler>,
        error_policy: Arc<dyn OnErrorPolicy>,
    ) -> anyhow::Result<Self> {
        Self::builder()
            .scheduler(scheduler)
            .error_policy(error_policy)
            .build()
    }
    /// Root tracker reporting metrics to Prometheus under `component_name`.
    pub fn new_with_prometheus<R: MetricsHierarchy>(
        scheduler: Arc<dyn TaskScheduler>,
        error_policy: Arc<dyn OnErrorPolicy>,
        registry: &R,
        component_name: &str,
    ) -> anyhow::Result<Self> {
        let prometheus_metrics = Arc::new(PrometheusTaskMetrics::new(registry, component_name)?);
        Self::builder()
            .scheduler(scheduler)
            .error_policy(error_policy)
            .metrics(prometheus_metrics)
            .build()
    }
    /// Child tracker inheriting scheduler and error policy.
    ///
    /// # Errors
    /// Fails if this tracker is already closed.
    pub fn child_tracker(&self) -> anyhow::Result<TaskTracker> {
        Ok(TaskTracker(self.0.child_tracker()?))
    }
    /// Spawn a plain future on this tracker.
    ///
    /// # Panics
    /// Panics if the tracker has already been closed.
    pub fn spawn<F, T>(&self, future: F) -> TaskHandle<T>
    where
        F: Future<Output = Result<T>> + Send + 'static,
        T: Send + 'static,
    {
        self.0
            .spawn(future)
            .expect("TaskTracker must not be closed when spawning tasks")
    }
    /// Spawn a cancellation-aware closure task.
    ///
    /// # Panics
    /// Panics if the tracker has already been closed.
    pub fn spawn_cancellable<F, Fut, T>(&self, task_fn: F) -> TaskHandle<T>
    where
        F: FnMut(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
        T: Send + 'static,
    {
        self.0
            .spawn_cancellable(task_fn)
            .expect("TaskTracker must not be closed when spawning tasks")
    }
    /// Metrics sink of this tracker node.
    pub fn metrics(&self) -> &dyn HierarchicalTaskMetrics {
        self.0.metrics.as_ref()
    }
    /// Close the tracker to new tasks and cancel everything tracked.
    pub fn cancel(&self) {
        self.0.cancel();
    }
    /// Whether this tracker no longer accepts new tasks.
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }
    /// Clone of this tracker's cancellation token.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.0.cancellation_token()
    }
    /// Number of live child trackers.
    pub fn child_count(&self) -> usize {
        self.0.child_count()
    }
    /// Builder for a child with scheduler/policy overrides.
    pub fn child_tracker_builder(&self) -> ChildTrackerBuilder<'_> {
        ChildTrackerBuilder::new(self)
    }
    /// Close and wait for all tasks in this tracker and its descendants.
    pub async fn join(&self) {
        self.0.join().await
    }
}
impl TaskTrackerInner {
/// Create a child node: child token, child metrics, child error policy; the
/// scheduler is shared. The child is registered with a weak link so it does
/// not keep the parent's child list alive.
fn child_tracker(self: &Arc<Self>) -> anyhow::Result<Arc<TaskTrackerInner>> {
    if self.is_closed() {
        return Err(anyhow::anyhow!(
            "Cannot create child tracker from closed parent tracker"
        ));
    }
    // Cancelling the parent token cancels the child, not vice versa.
    let child_cancel_token = self.cancel_token.child_token();
    let child_metrics = Arc::new(ChildTaskMetrics::new(self.metrics.clone()));
    let child = Arc::new(TaskTrackerInner {
        tokio_tracker: TokioTaskTracker::new(),
        parent: Some(self.clone()),
        scheduler: self.scheduler.clone(),
        error_policy: self.error_policy.create_child(),
        metrics: child_metrics,
        cancel_token: child_cancel_token,
        children: RwLock::new(Vec::new()),
    });
    self.children.write().unwrap().push(Arc::downgrade(&child));
    // Opportunistically drop entries whose trackers have been dropped.
    self.cleanup_dead_children();
    Ok(child)
}
/// Spawn `future` on the tokio tracker, wrapped with scheduling, metrics and
/// error-policy handling. Returns `TrackerClosed` if no longer accepting tasks.
fn spawn<F, T>(self: &Arc<Self>, future: F) -> Result<TaskHandle<T>, TaskError>
where
    F: Future<Output = Result<T>> + Send + 'static,
    T: Send + 'static,
{
    if self.tokio_tracker.is_closed() {
        return Err(TaskError::TrackerClosed);
    }
    let task_id = self.generate_task_id();
    // Counted as issued even if it never starts (e.g. rejected by scheduler).
    self.metrics.increment_issued();
    // Per-task child token: cancelling it affects only this task.
    let task_cancel_token = self.cancel_token.child_token();
    let cancel_token = task_cancel_token.clone();
    let inner = self.clone();
    let wrapped_future =
        async move { Self::execute_with_policies(task_id, future, cancel_token, inner).await };
    let join_handle = self.tokio_tracker.spawn(wrapped_future);
    Ok(TaskHandle::new(join_handle, task_cancel_token))
}
/// Like `spawn`, but for a cancellation-aware closure; the closure receives a
/// token and may report `CancellableTaskResult::Cancelled` on its own.
fn spawn_cancellable<F, Fut, T>(
    self: &Arc<Self>,
    task_fn: F,
) -> Result<TaskHandle<T>, TaskError>
where
    F: FnMut(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
    T: Send + 'static,
{
    if self.tokio_tracker.is_closed() {
        return Err(TaskError::TrackerClosed);
    }
    let task_id = self.generate_task_id();
    // Counted as issued even if it never starts.
    self.metrics.increment_issued();
    // Per-task child token: cancelling it affects only this task.
    let task_cancel_token = self.cancel_token.child_token();
    let cancel_token = task_cancel_token.clone();
    let inner = self.clone();
    let wrapped_future = async move {
        Self::execute_cancellable_with_policies(task_id, task_fn, cancel_token, inner).await
    };
    let join_handle = self.tokio_tracker.spawn(wrapped_future);
    Ok(TaskHandle::new(join_handle, task_cancel_token))
}
/// Close to new tasks and cancel everything currently tracked
/// (children are reached through their child tokens).
fn cancel(&self) {
    self.tokio_tracker.close();
    self.cancel_token.cancel();
}
/// Closed trackers reject new spawns.
fn is_closed(&self) -> bool {
    self.tokio_tracker.is_closed()
}
/// Fresh random task id.
fn generate_task_id(&self) -> TaskId {
    TaskId::new()
}
/// Drop weak child entries whose trackers have been dropped.
fn cleanup_dead_children(&self) {
    let mut children_guard = self.children.write().unwrap();
    children_guard.retain(|weak| weak.upgrade().is_some());
}
fn cancellation_token(&self) -> CancellationToken {
    self.cancel_token.clone()
}
/// Count children still alive (read-only; does not prune dead entries).
fn child_count(&self) -> usize {
    let children_guard = self.children.read().unwrap();
    children_guard
        .iter()
        .filter(|weak| weak.upgrade().is_some())
        .count()
}
/// Close this tracker (and its descendants) and wait for all tasks to finish.
async fn join(self: &Arc<Self>) {
    // Fast path for a leaf: close and wait on the local tracker only.
    let is_leaf = {
        // Scope ensures the read lock is dropped before any `.await` below.
        let children_guard = self.children.read().unwrap();
        children_guard.is_empty()
    };
    if is_leaf {
        self.tokio_tracker.close();
        self.tokio_tracker.wait().await;
        return;
    }
    // `collect_hierarchy` returns children before parents, so leaves are
    // closed and awaited first, each one sequentially.
    let trackers = self.collect_hierarchy();
    for t in trackers {
        t.tokio_tracker.close();
        t.tokio_tracker.wait().await;
    }
}
/// Snapshot this tracker and all transitively reachable live children.
/// Pointer identity de-duplicates shared nodes; the result is reversed so
/// children come before their parents (the order `join` wants).
fn collect_hierarchy(self: &Arc<TaskTrackerInner>) -> Vec<Arc<TaskTrackerInner>> {
    let mut result = Vec::new();
    let mut stack = vec![self.clone()];
    let mut visited = HashSet::new();
    while let Some(tracker) = stack.pop() {
        // `insert` returns false when the key was already present: one
        // hash lookup instead of the previous `contains` + `insert` pair.
        if !visited.insert(Arc::as_ptr(&tracker) as usize) {
            continue;
        }
        result.push(tracker.clone());
        if let Ok(children_guard) = tracker.children.read() {
            for weak_child in children_guard.iter() {
                if let Some(child) = weak_child.upgrade() {
                    if !visited.contains(&(Arc::as_ptr(&child) as usize)) {
                        stack.push(child);
                    }
                }
            }
        }
    }
    // Parents were recorded before their children; reverse for child-first.
    result.reverse();
    result
}
/// Wrap a plain future in the shared retry loop.
#[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
async fn execute_with_policies<F, T>(
    task_id: TaskId,
    future: F,
    task_cancel_token: CancellationToken,
    inner: Arc<TaskTrackerInner>,
) -> Result<T, TaskError>
where
    F: Future<Output = Result<T>> + Send + 'static,
    T: Send + 'static,
{
    let task_executor = RegularTaskExecutor::new(future);
    Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
}
/// Wrap a cancellation-aware closure in the shared retry loop.
#[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
async fn execute_cancellable_with_policies<F, Fut, T>(
    task_id: TaskId,
    task_fn: F,
    task_cancel_token: CancellationToken,
    inner: Arc<TaskTrackerInner>,
) -> Result<T, TaskError>
where
    F: FnMut(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = CancellableTaskResult<T>> + Send + 'static,
    T: Send + 'static,
{
    let task_executor = CancellableTaskExecutor::new(task_fn);
    Self::execute_with_retry_loop(task_id, task_executor, task_cancel_token, inner).await
}
/// Core retry loop shared by both spawn flavours: acquire a scheduler slot,
/// run the executor (or a policy-supplied continuation), and route every
/// failure through the error policy until the task succeeds, fails
/// terminally, or is cancelled.
#[tracing::instrument(level = "debug", skip_all, fields(task_id = %task_id))]
async fn execute_with_retry_loop<E, T>(
    task_id: TaskId,
    initial_executor: E,
    task_cancellation_token: CancellationToken,
    inner: Arc<TaskTrackerInner>,
) -> Result<T, TaskError>
where
    E: TaskExecutor<T> + Send + 'static,
    T: Send + 'static,
{
    debug!("Starting task execution");
    // Ensures `started` is incremented at most once per task, regardless of
    // how many attempts run.
    struct ActiveCountGuard {
        metrics: Arc<dyn HierarchicalTaskMetrics>,
        is_active: bool,
    }
    impl ActiveCountGuard {
        fn new(metrics: Arc<dyn HierarchicalTaskMetrics>) -> Self {
            Self {
                metrics,
                is_active: false,
            }
        }
        fn activate(&mut self) {
            if !self.is_active {
                self.metrics.increment_started();
                self.is_active = true;
            }
        }
    }
    // What runs next: the original executor or a policy-supplied continuation.
    enum CurrentExecutable<E>
    where
        E: Send + 'static,
    {
        TaskExecutor(E),
        Continuation(Arc<dyn Continuation + Send + Sync + 'static>),
    }
    let mut current_executable = CurrentExecutable::TaskExecutor(initial_executor);
    let mut active_guard = ActiveCountGuard::new(inner.metrics.clone());
    // Created lazily on the first error; persists across attempts.
    let mut error_context: Option<OnErrorContext> = None;
    let mut scheduler_guard_state = self::GuardState::Keep;
    // Initial slot acquisition, cancellable via the per-task token.
    // NOTE(review): the span is named "reacquisition" on this first
    // acquisition too — confirm whether a distinct name was intended.
    let mut guard_result = async {
        inner
            .scheduler
            .acquire_execution_slot(task_cancellation_token.child_token())
            .await
    }
    .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
    .await;
    loop {
        if scheduler_guard_state == self::GuardState::Reschedule {
            // Policy asked for the slot to be given back between attempts.
            // NOTE(review): this path derives its token from the tracker's
            // token rather than the per-task token used above — confirm
            // whether that difference is intentional.
            guard_result = async {
                inner
                    .scheduler
                    .acquire_execution_slot(inner.cancel_token.child_token())
                    .await
            }
            .instrument(tracing::debug_span!("scheduler_resource_reacquisition"))
            .await;
        }
        match &guard_result {
            SchedulingResult::Execute(_guard) => {
                // The guard stays alive (inside `guard_result`) for the whole
                // attempt; mark the task as started exactly once.
                active_guard.activate();
                let execution_result = async {
                    debug!("Executing task with acquired resources");
                    match &mut current_executable {
                        CurrentExecutable::TaskExecutor(executor) => {
                            executor.execute(inner.cancel_token.child_token()).await
                        }
                        CurrentExecutable::Continuation(continuation) => {
                            match continuation.execute(inner.cancel_token.child_token()).await {
                                TaskExecutionResult::Success(result) => {
                                    // Continuations return a type-erased box;
                                    // restore the task's concrete output type.
                                    if let Ok(typed_result) = result.downcast::<T>() {
                                        TaskExecutionResult::Success(*typed_result)
                                    } else {
                                        let type_error = anyhow::anyhow!(
                                            "Continuation task returned wrong type"
                                        );
                                        error!(
                                            ?type_error,
                                            "Type mismatch in continuation task result"
                                        );
                                        TaskExecutionResult::Error(type_error)
                                    }
                                }
                                TaskExecutionResult::Cancelled => {
                                    TaskExecutionResult::Cancelled
                                }
                                TaskExecutionResult::Error(error) => {
                                    TaskExecutionResult::Error(error)
                                }
                            }
                        }
                    }
                }
                .instrument(tracing::debug_span!("task_execution"))
                .await;
                match execution_result {
                    TaskExecutionResult::Success(value) => {
                        inner.metrics.increment_success();
                        debug!("Task completed successfully");
                        return Ok(value);
                    }
                    TaskExecutionResult::Cancelled => {
                        inner.metrics.increment_cancelled();
                        debug!("Task was cancelled during execution");
                        return Err(TaskError::Cancelled);
                    }
                    TaskExecutionResult::Error(error) => {
                        debug!("Task failed - handling error through policy - {error:?}");
                        let (action_result, guard_state) = Self::handle_task_error(
                            &error,
                            &mut error_context,
                            task_id,
                            &inner,
                        )
                        .await;
                        // Applied at the top of the next loop iteration.
                        scheduler_guard_state = guard_state;
                        match action_result {
                            ActionResult::Fail => {
                                inner.metrics.increment_failed();
                                debug!("Policy accepted error - task failed {error:?}");
                                return Err(TaskError::Failed(error));
                            }
                            ActionResult::Shutdown => {
                                inner.metrics.increment_failed();
                                warn!("Policy triggered shutdown - {error:?}");
                                // Cancels the whole tracker, not just this task.
                                inner.cancel();
                                return Err(TaskError::Failed(error));
                            }
                            ActionResult::Continue { continuation } => {
                                debug!(
                                    "Policy provided next executable - continuing loop - {error:?}"
                                );
                                current_executable =
                                    CurrentExecutable::Continuation(continuation);
                                continue;
                            }
                        }
                    }
                }
            }
            SchedulingResult::Cancelled => {
                inner.metrics.increment_cancelled();
                debug!("Task was cancelled during resource acquisition");
                return Err(TaskError::Cancelled);
            }
            SchedulingResult::Rejected(reason) => {
                inner.metrics.increment_rejected();
                debug!(reason, "Task was rejected by scheduler");
                return Err(TaskError::Failed(anyhow::anyhow!(
                    "Task rejected: {}",
                    reason
                )));
            }
        }
    }
}
/// Route one failure through the error policy. Returns what the retry loop
/// should do next and whether the scheduler slot must be reacquired.
async fn handle_task_error(
    error: &anyhow::Error,
    error_context: &mut Option<OnErrorContext>,
    task_id: TaskId,
    inner: &Arc<TaskTrackerInner>,
) -> (ActionResult, self::GuardState) {
    // Context is created on the first error and reused across attempts, so
    // `attempt_count` and policy `state` accumulate per task.
    let context = error_context.get_or_insert_with(|| OnErrorContext {
        attempt_count: 0,
        task_id,
        execution_context: TaskExecutionContext {
            scheduler: inner.scheduler.clone(),
            metrics: inner.metrics.clone(),
        },
        state: inner.error_policy.create_context(),
    });
    context.attempt_count += 1;
    let current_attempt = context.attempt_count;
    // A `FailedWithContinuation` short-circuits the policy's `on_error`,
    // but only if the policy permits continuations for this error.
    if inner.error_policy.allow_continuation(error, context) {
        if let Some(continuation_err) = error.downcast_ref::<FailedWithContinuation>() {
            debug!(
                task_id = %task_id,
                attempt_count = current_attempt,
                "Task provided FailedWithContinuation and policy allows continuations - {error:?}"
            );
            let continuation = continuation_err.continuation.clone();
            let should_reschedule = inner.error_policy.should_reschedule(error, context);
            let guard_state = if should_reschedule {
                self::GuardState::Reschedule
            } else {
                self::GuardState::Keep
            };
            return (ActionResult::Continue { continuation }, guard_state);
        }
    } else {
        debug!(
            task_id = %task_id,
            attempt_count = current_attempt,
            "Policy rejected continuations, ignoring any FailedWithContinuation - {error:?}"
        );
    }
    let response = inner.error_policy.on_error(error, context);
    match response {
        ErrorResponse::Fail => (ActionResult::Fail, self::GuardState::Keep),
        ErrorResponse::Shutdown => (ActionResult::Shutdown, self::GuardState::Keep),
        ErrorResponse::Custom(action) => {
            debug!("Task failed - executing custom action - {error:?}");
            let action_result = action
                .execute(error, task_id, current_attempt, &context.execution_context)
                .await;
            debug!(?action_result, "Custom action completed");
            // Only a `Continue` result can ask for rescheduling.
            let guard_state = match &action_result {
                ActionResult::Continue { .. } => {
                    let should_reschedule =
                        inner.error_policy.should_reschedule(error, context);
                    if should_reschedule {
                        self::GuardState::Reschedule
                    } else {
                        self::GuardState::Keep
                    }
                }
                _ => self::GuardState::Keep,
            };
            (action_result, guard_state)
        }
    }
}
}
// Marker-trait opt-ins: these scheduler/policy types are constructed as
// `Arc<Self>` and can be used wherever an `ArcPolicy` is accepted (the trait
// itself is defined elsewhere in this file/module).
impl ArcPolicy for UnlimitedScheduler {}
impl ArcPolicy for SemaphoreScheduler {}
impl ArcPolicy for LogOnlyPolicy {}
impl ArcPolicy for CancelOnError {}
impl ArcPolicy for ThresholdCancelPolicy {}
impl ArcPolicy for RateCancelPolicy {}
/// Resource guard handed out by [`UnlimitedScheduler`].
///
/// Holds no resources, so dropping it releases nothing; it exists only to
/// satisfy the `ResourceGuard` contract.
#[derive(Debug)]
pub struct UnlimitedGuard;
impl ResourceGuard for UnlimitedGuard {
}
/// Scheduler that imposes no concurrency limit: every task is granted an
/// execution slot immediately.
#[derive(Debug)]
pub struct UnlimitedScheduler;

impl UnlimitedScheduler {
    /// Creates a new scheduler, pre-wrapped in an [`Arc`] for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }
}

impl Default for UnlimitedScheduler {
    fn default() -> Self {
        Self
    }
}
#[async_trait]
impl TaskScheduler for UnlimitedScheduler {
    /// Hands out an execution slot immediately; the only way this scheduler
    /// refuses a task is when its token was already cancelled at scheduling
    /// time.
    async fn acquire_execution_slot(
        &self,
        cancel_token: CancellationToken,
    ) -> SchedulingResult<Box<dyn ResourceGuard>> {
        debug!("Acquiring execution slot (unlimited scheduler)");
        if cancel_token.is_cancelled() {
            debug!("Task cancelled before acquiring execution slot");
            SchedulingResult::Cancelled
        } else {
            debug!("Execution slot acquired immediately");
            SchedulingResult::Execute(Box::new(UnlimitedGuard))
        }
    }
}
/// Resource guard handed out by [`SemaphoreScheduler`].
#[derive(Debug)]
pub struct SemaphoreGuard {
    // Held only for its Drop impl: dropping the owned permit returns the
    // execution slot to the semaphore.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
impl ResourceGuard for SemaphoreGuard {
}
/// Scheduler that bounds concurrency with a [`Semaphore`]: at most the
/// configured number of tasks hold an execution slot at any one time.
#[derive(Debug)]
pub struct SemaphoreScheduler {
    semaphore: Arc<Semaphore>,
}

impl SemaphoreScheduler {
    /// Wraps an existing semaphore, allowing the permit pool to be shared
    /// with other components.
    pub fn new(semaphore: Arc<Semaphore>) -> Self {
        SemaphoreScheduler { semaphore }
    }

    /// Convenience constructor: a fresh semaphore with `permits` slots,
    /// pre-wrapped in an [`Arc`].
    pub fn with_permits(permits: usize) -> Arc<Self> {
        let semaphore = Arc::new(Semaphore::new(permits));
        Arc::new(Self::new(semaphore))
    }

    /// Number of execution slots currently free.
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }
}
#[async_trait]
impl TaskScheduler for SemaphoreScheduler {
    /// Waits for a semaphore permit, racing acquisition against cancellation.
    ///
    /// Returns `Cancelled` if the token fires first (or was already set), and
    /// `Execute` with a guard holding the permit otherwise.
    async fn acquire_execution_slot(
        &self,
        cancel_token: CancellationToken,
    ) -> SchedulingResult<Box<dyn ResourceGuard>> {
        debug!("Acquiring semaphore permit");
        // Fast path: don't queue for a permit if already cancelled.
        if cancel_token.is_cancelled() {
            debug!("Task cancelled before acquiring semaphore permit");
            return SchedulingResult::Cancelled;
        }
        let permit = {
            tokio::select! {
                // `acquire_owned` on a cloned Arc ties the permit to the
                // semaphore allocation rather than `&self`, so the guard can
                // outlive this call.
                result = self.semaphore.clone().acquire_owned() => {
                    match result {
                        Ok(permit) => permit,
                        // Err means the semaphore was closed; treat it the
                        // same as cancellation.
                        Err(_) => return SchedulingResult::Cancelled,
                    }
                }
                _ = cancel_token.cancelled() => {
                    debug!("Task cancelled while waiting for semaphore permit");
                    return SchedulingResult::Cancelled;
                }
            }
        };
        debug!("Acquired semaphore permit");
        // Dropping the guard releases the permit, freeing the slot.
        SchedulingResult::Execute(Box::new(SemaphoreGuard { _permit: permit }))
    }
}
/// Error policy that requests tracker shutdown when an error matches one of
/// the configured message patterns (or on any error when no patterns are
/// configured).
#[derive(Debug)]
pub struct CancelOnError {
    // Substrings matched against `error.to_string()`; empty means
    // "shut down on every error".
    error_patterns: Vec<String>,
}
impl CancelOnError {
    /// Policy with no patterns: every error yields `ErrorResponse::Shutdown`.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            error_patterns: vec![],
        })
    }
    /// Policy that shuts down only on errors containing one of
    /// `error_patterns`.
    ///
    /// NOTE(review): the returned token is freshly created and never stored in
    /// the policy, so nothing in this file cancels it — verify callers wire it
    /// up (the tracker appears to act on `ErrorResponse::Shutdown` itself).
    pub fn with_patterns(error_patterns: Vec<String>) -> (Arc<Self>, CancellationToken) {
        let token = CancellationToken::new();
        let policy = Arc::new(Self { error_patterns });
        (policy, token)
    }
}
#[async_trait]
impl OnErrorPolicy for CancelOnError {
    /// Children share the same pattern list but are independent policy objects.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
        let error_patterns = self.error_patterns.clone();
        Arc::new(CancelOnError { error_patterns })
    }

    /// This policy keeps no per-task state.
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
        None
    }

    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
        error!(?context.task_id, "Task failed - {error:?}");
        // No patterns configured: treat every error as fatal.
        if self.error_patterns.is_empty() {
            return ErrorResponse::Shutdown;
        }
        // Otherwise shut down only when the rendered message matches a pattern.
        let message = error.to_string();
        match self
            .error_patterns
            .iter()
            .any(|pattern| message.contains(pattern))
        {
            true => ErrorResponse::Shutdown,
            false => ErrorResponse::Fail,
        }
    }
}
/// Error policy that records failures in the log and takes no other action.
#[derive(Debug)]
pub struct LogOnlyPolicy;

impl LogOnlyPolicy {
    /// Creates the policy pre-wrapped in an [`Arc`] for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }
}

impl Default for LogOnlyPolicy {
    fn default() -> Self {
        Self
    }
}
impl OnErrorPolicy for LogOnlyPolicy {
    /// Stateless policy, so a child is simply another instance.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
        Arc::new(LogOnlyPolicy)
    }
    /// No per-task state needed.
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
        None
    }
    /// Logs the failure; the task is marked failed but nothing is cancelled.
    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
        error!(?context.task_id, "Task failed - logging only - {error:?}");
        ErrorResponse::Fail
    }
}
/// Error policy that requests shutdown once a single task has failed
/// `max_failures` times.
///
/// Failures are judged per task (see the `OnErrorPolicy` impl); the
/// `failure_count` atomic additionally accumulates a global tally across all
/// tasks for observability.
#[derive(Debug)]
pub struct ThresholdCancelPolicy {
    max_failures: usize,
    failure_count: AtomicU64,
}

impl ThresholdCancelPolicy {
    /// Creates a policy that tolerates up to `max_failures - 1` failures per
    /// task before shutting down, pre-wrapped in an [`Arc`].
    pub fn with_threshold(max_failures: usize) -> Arc<Self> {
        let policy = Self {
            max_failures,
            failure_count: AtomicU64::new(0),
        };
        Arc::new(policy)
    }

    /// Total failures observed across all tasks since creation or last reset.
    pub fn failure_count(&self) -> u64 {
        self.failure_count.load(Ordering::Relaxed)
    }

    /// Zeroes the global failure tally.
    pub fn reset_failure_count(&self) {
        self.failure_count.store(0, Ordering::Relaxed);
    }
}
/// Per-task failure counter stored in the policy context created by
/// `ThresholdCancelPolicy::create_context`.
#[derive(Debug)]
struct ThresholdState {
    failure_count: u32,
}
impl OnErrorPolicy for ThresholdCancelPolicy {
    /// Children inherit the threshold but start a fresh global tally.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
        Arc::new(ThresholdCancelPolicy {
            max_failures: self.max_failures,
            failure_count: AtomicU64::new(0),
        })
    }
    /// Per-task state: an independent failure counter.
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
        Some(Box::new(ThresholdState { failure_count: 0 }))
    }
    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
        error!(?context.task_id, "Task failed - {error:?}");
        // Global tally across all tasks — used only for logging below; the
        // shutdown decision is driven by the per-task count.
        let global_failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        // The context state is created by `create_context` above, so both
        // expects hold unless callers mix policies and contexts.
        let state = context
            .state
            .as_mut()
            .expect("ThresholdCancelPolicy requires state")
            .downcast_mut::<ThresholdState>()
            .expect("Context type mismatch");
        state.failure_count += 1;
        let current_failures = state.failure_count;
        // Shut down once this one task has failed `max_failures` times.
        if current_failures >= self.max_failures as u32 {
            warn!(
                ?context.task_id,
                current_failures,
                global_failures,
                max_failures = self.max_failures,
                "Per-task failure threshold exceeded, triggering cancellation"
            );
            ErrorResponse::Shutdown
        } else {
            debug!(
                ?context.task_id,
                current_failures,
                global_failures,
                max_failures = self.max_failures,
                "Task failed, tracking per-task failure count"
            );
            ErrorResponse::Fail
        }
    }
}
/// Error policy intended to shut down when the failure rate within a sliding
/// window exceeds `max_failure_rate`. The window tracking itself is not yet
/// implemented — see `on_error` in the `OnErrorPolicy` impl.
#[derive(Debug)]
pub struct RateCancelPolicy {
    // NOTE(review): stored and cloned into children, but never triggered in
    // the visible code — confirm how the surrounding tracker uses it.
    cancel_token: CancellationToken,
    max_failure_rate: f32,
    window_secs: u64,
}
impl RateCancelPolicy {
    /// Entry point for configuring a policy; see [`RateCancelPolicyBuilder`].
    pub fn builder() -> RateCancelPolicyBuilder {
        RateCancelPolicyBuilder::new()
    }
}
/// Builder for [`RateCancelPolicy`]; both `rate` and `window_secs` are
/// required before calling [`RateCancelPolicyBuilder::build`].
pub struct RateCancelPolicyBuilder {
    max_failure_rate: Option<f32>,
    window_secs: Option<u64>,
}

impl RateCancelPolicyBuilder {
    fn new() -> Self {
        RateCancelPolicyBuilder {
            max_failure_rate: None,
            window_secs: None,
        }
    }

    /// Sets the maximum tolerated failure rate (required).
    pub fn rate(mut self, max_failure_rate: f32) -> Self {
        self.max_failure_rate = Some(max_failure_rate);
        self
    }

    /// Sets the sliding-window length in seconds (required).
    pub fn window_secs(mut self, window_secs: u64) -> Self {
        self.window_secs = Some(window_secs);
        self
    }

    /// Builds the policy together with the cancellation token it holds.
    ///
    /// # Panics
    /// Panics if `rate` or `window_secs` was not set.
    pub fn build(self) -> (Arc<RateCancelPolicy>, CancellationToken) {
        let Self {
            max_failure_rate,
            window_secs,
        } = self;
        let max_failure_rate = max_failure_rate.expect("rate must be set");
        let window_secs = window_secs.expect("window_secs must be set");
        let token = CancellationToken::new();
        let policy = RateCancelPolicy {
            cancel_token: token.clone(),
            max_failure_rate,
            window_secs,
        };
        (Arc::new(policy), token)
    }
}
#[async_trait]
impl OnErrorPolicy for RateCancelPolicy {
    /// Children observe a child token so parent-side cancellation propagates.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
        Arc::new(RateCancelPolicy {
            cancel_token: self.cancel_token.child_token(),
            max_failure_rate: self.max_failure_rate,
            window_secs: self.window_secs,
        })
    }
    /// No per-task state yet (this is where the rate window would live).
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
        None
    }
    /// Placeholder: until window tracking is implemented, every error is
    /// reported and degraded to a plain `Fail`.
    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
        error!(?context.task_id, "Task failed - {error:?}");
        warn!(
            ?context.task_id,
            max_failure_rate = self.max_failure_rate,
            window_secs = self.window_secs,
            "Rate-based error policy - time window tracking not yet implemented"
        );
        ErrorResponse::Fail
    }
}
/// Custom error action that fires a caller-supplied [`CancellationToken`] and
/// then requests shutdown.
#[derive(Debug)]
pub struct TriggerCancellationTokenAction {
    cancel_token: CancellationToken,
}
impl TriggerCancellationTokenAction {
    /// Wraps the token that `execute` will cancel.
    pub fn new(cancel_token: CancellationToken) -> Self {
        Self { cancel_token }
    }
}
#[async_trait]
impl OnErrorAction for TriggerCancellationTokenAction {
    /// Cancels the held token and tells the tracker to shut down.
    async fn execute(
        &self,
        error: &anyhow::Error,
        task_id: TaskId,
        _attempt_count: u32,
        _context: &TaskExecutionContext,
    ) -> ActionResult {
        warn!(
            ?task_id,
            "Executing custom action: triggering cancellation token - {error:?}"
        );
        self.cancel_token.cancel();
        ActionResult::Shutdown
    }
}
/// Error policy that, on any failure, schedules a custom action which fires
/// the provided [`CancellationToken`].
#[derive(Debug)]
pub struct TriggerCancellationTokenOnError {
    cancel_token: CancellationToken,
}
impl TriggerCancellationTokenOnError {
    /// Wraps the token to trigger, pre-wrapped in an [`Arc`].
    pub fn new(cancel_token: CancellationToken) -> Arc<Self> {
        Arc::new(Self { cancel_token })
    }
}
impl OnErrorPolicy for TriggerCancellationTokenOnError {
    /// Children share the same token, so a failure anywhere in the hierarchy
    /// triggers it.
    fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
        let cancel_token = self.cancel_token.clone();
        Arc::new(TriggerCancellationTokenOnError { cancel_token })
    }

    /// No per-task state.
    fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
        None
    }

    fn on_error(&self, error: &anyhow::Error, context: &mut OnErrorContext) -> ErrorResponse {
        error!(
            ?context.task_id,
            "Task failed - triggering custom cancellation token - {error:?}"
        );
        // Defer the actual token trigger to a custom action so it runs with
        // the task's execution context.
        ErrorResponse::Custom(Box::new(TriggerCancellationTokenAction::new(
            self.cancel_token.clone(),
        )))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
use std::sync::atomic::AtomicU32;
use std::time::Duration;
// Fixture: semaphore scheduler with 5 permits.
#[fixture]
fn semaphore_scheduler() -> Arc<SemaphoreScheduler> {
    Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))))
}
// Fixture: scheduler with no concurrency limit.
#[fixture]
fn unlimited_scheduler() -> Arc<UnlimitedScheduler> {
    UnlimitedScheduler::new()
}
// Fixture: error policy that only logs failures.
#[fixture]
fn log_policy() -> Arc<LogOnlyPolicy> {
    LogOnlyPolicy::new()
}
// Fixture: policy that shuts the tracker down on any error (no patterns).
#[fixture]
fn cancel_policy() -> Arc<CancelOnError> {
    CancelOnError::new()
}
// Fixture: tracker with unlimited scheduling and log-only error handling.
#[fixture]
fn basic_tracker(
    unlimited_scheduler: Arc<UnlimitedScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) -> TaskTracker {
    TaskTracker::new(unlimited_scheduler, log_policy).unwrap()
}
// A spawned task should run to completion and count as one success.
#[rstest]
#[tokio::test]
async fn test_basic_task_execution(basic_tracker: TaskTracker) {
    let (tx, rx) = tokio::sync::oneshot::channel();
    let handle = basic_tracker.spawn(async {
        rx.await.ok();
        Ok(42)
    });
    tx.send(()).ok();
    let result = handle
        .await
        .expect("Task should complete")
        .expect("Task should succeed");
    assert_eq!(result, 42);
    assert_eq!(basic_tracker.metrics().success(), 1);
    assert_eq!(basic_tracker.metrics().failed(), 0);
    assert_eq!(basic_tracker.metrics().cancelled(), 0);
    assert_eq!(basic_tracker.metrics().active(), 0);
}
// A task returning Err should surface TaskError::Failed and bump only the
// failed counter.
#[rstest]
#[tokio::test]
async fn test_task_failure(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let handle = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
    let result = handle.await.unwrap();
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
    assert_eq!(tracker.metrics().success(), 0);
    assert_eq!(tracker.metrics().failed(), 1);
    assert_eq!(tracker.metrics().cancelled(), 0);
}
// With 2 permits, no more than 2 of the 5 tasks may run concurrently.
#[rstest]
#[tokio::test]
async fn test_semaphore_concurrency_limit(log_policy: Arc<LogOnlyPolicy>) {
    let limited_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2))));
    let tracker = TaskTracker::new(limited_scheduler, log_policy).unwrap();
    let counter = Arc::new(AtomicU32::new(0));
    let max_concurrent = Arc::new(AtomicU32::new(0));
    let (tx, _) = tokio::sync::broadcast::channel(1);
    let mut handles = Vec::new();
    for _ in 0..5 {
        let counter_clone = counter.clone();
        let max_clone = max_concurrent.clone();
        let mut rx = tx.subscribe();
        let handle = tracker.spawn(async move {
            // Record the high-water mark of concurrently running tasks.
            let current = counter_clone.fetch_add(1, Ordering::Relaxed) + 1;
            max_clone.fetch_max(current, Ordering::Relaxed);
            rx.recv().await.ok();
            counter_clone.fetch_sub(1, Ordering::Relaxed);
            Ok(())
        });
        handles.push(handle);
    }
    tokio::task::yield_now().await;
    tokio::task::yield_now().await;
    tx.send(()).ok();
    for handle in handles {
        handle.await.unwrap().unwrap();
    }
    assert!(max_concurrent.load(Ordering::Relaxed) <= 2);
    assert_eq!(tracker.metrics().success(), 5);
    assert_eq!(tracker.metrics().failed(), 0);
}
// CancelOnError with no patterns should cancel the tracker on any failure.
#[rstest]
#[tokio::test]
async fn test_cancel_on_error_policy() {
    let error_policy = cancel_policy();
    let scheduler = semaphore_scheduler();
    let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
    let handle =
        tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("OutOfMemory error occurred")) });
    let result = handle.await.unwrap();
    assert!(result.is_err());
    // Give the shutdown a moment to propagate to the tracker token.
    tokio::time::sleep(Duration::from_millis(10)).await;
    assert!(tracker.cancellation_token().is_cancelled());
}
// Cancelling the tracker should resolve in-flight tasks as Cancelled.
#[rstest]
#[tokio::test]
async fn test_tracker_cancellation() {
    let error_policy = cancel_policy();
    let scheduler = semaphore_scheduler();
    let tracker = TaskTracker::new(scheduler, error_policy).unwrap();
    let cancel_token = tracker.cancellation_token().child_token();
    let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
    let handle = tracker.spawn({
        let cancel_token = cancel_token.clone();
        async move {
            tokio::select! {
                _ = rx => Ok(()),
                _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Task was cancelled")),
            }
        }
    });
    tracker.cancel();
    let result = handle.await.unwrap();
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
}
// Cancelling a child must not close or cancel its parent.
#[rstest]
#[tokio::test]
async fn test_child_tracker_independence(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let child = parent.child_tracker().unwrap();
    assert!(!parent.is_closed());
    assert!(!child.is_closed());
    child.cancel();
    assert!(!parent.is_closed());
    let handle = parent.spawn(async { Ok(42) });
    let result = handle.await.unwrap().unwrap();
    assert_eq!(result, 42);
}
// Parent metrics aggregate children; a child counts only its own tasks.
#[rstest]
#[tokio::test]
async fn test_independent_metrics(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let child = parent.child_tracker().unwrap();
    let handle1 = parent.spawn(async { Ok(1) });
    handle1.await.unwrap().unwrap();
    let handle2 = child.spawn(async { Ok(2) });
    handle2.await.unwrap().unwrap();
    assert_eq!(parent.metrics().success(), 2);
    assert_eq!(child.metrics().success(), 1);
    assert_eq!(parent.metrics().total_completed(), 2);
    assert_eq!(child.metrics().total_completed(), 1);
}
// A child's policy-driven shutdown must not cancel the parent's token.
#[rstest]
#[tokio::test]
async fn test_cancel_on_error_hierarchy() {
    let parent_error_policy = cancel_policy();
    let scheduler = semaphore_scheduler();
    let parent = TaskTracker::new(scheduler, parent_error_policy).unwrap();
    let parent_policy_token = parent.cancellation_token().child_token();
    let child = parent.child_tracker().unwrap();
    assert!(!parent_policy_token.is_cancelled());
    let (error_tx, error_rx) = tokio::sync::oneshot::channel();
    let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
    let parent_token_monitor = parent_policy_token.clone();
    // Watch the parent token for 100ms and report whether it fired.
    let monitor_handle = tokio::spawn(async move {
        tokio::select! {
            _ = parent_token_monitor.cancelled() => {
                cancel_tx.send(true).ok();
            }
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                cancel_tx.send(false).ok();
            }
        }
    });
    let handle = child.spawn(async move {
        let result = Err::<(), _>(anyhow::anyhow!("OutOfMemory in child"));
        error_tx.send(()).ok();
        result
    });
    let error_result = handle.await.unwrap();
    assert!(error_result.is_err());
    error_rx.await.ok();
    let was_cancelled = cancel_rx.await.unwrap_or(false);
    monitor_handle.await.ok();
    assert!(
        !was_cancelled,
        "Parent policy token should not be cancelled by child errors"
    );
    assert!(
        !parent_policy_token.is_cancelled(),
        "Parent policy token should remain active"
    );
}
// join() should wait for all in-flight tasks and close the tracker.
#[rstest]
#[tokio::test]
async fn test_graceful_shutdown(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let tracker = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let (tx, _) = tokio::sync::broadcast::channel(1);
    let mut handles = Vec::new();
    for i in 0..3 {
        let mut rx = tx.subscribe();
        let handle = tracker.spawn(async move {
            rx.recv().await.ok();
            Ok(i)
        });
        handles.push(handle);
    }
    tx.send(()).ok();
    tracker.join().await;
    for handle in handles {
        let result = handle.await.unwrap().unwrap();
        assert!(result < 3);
    }
    assert!(tracker.is_closed());
}
// Permits should be consumed while tasks run and returned afterwards.
#[rstest]
#[tokio::test]
async fn test_semaphore_scheduler_permit_tracking(log_policy: Arc<LogOnlyPolicy>) {
    let semaphore = Arc::new(Semaphore::new(3));
    let scheduler = Arc::new(SemaphoreScheduler::new(semaphore.clone()));
    let tracker = TaskTracker::new(scheduler.clone(), log_policy).unwrap();
    assert_eq!(scheduler.available_permits(), 3);
    let (tx, _) = tokio::sync::broadcast::channel(1);
    let mut handles = Vec::new();
    for _ in 0..3 {
        let mut rx = tx.subscribe();
        let handle = tracker.spawn(async move {
            rx.recv().await.ok();
            Ok(())
        });
        handles.push(handle);
    }
    tokio::task::yield_now().await;
    tokio::task::yield_now().await;
    assert_eq!(scheduler.available_permits(), 0);
    tx.send(()).ok();
    for handle in handles {
        handle.await.unwrap().unwrap();
    }
    assert_eq!(scheduler.available_permits(), 3);
}
// Builder-constructed trackers behave like ones from TaskTracker::new.
#[rstest]
#[tokio::test]
async fn test_builder_pattern(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
    let error_policy = log_policy;
    let tracker = TaskTracker::builder()
        .scheduler(scheduler)
        .error_policy(error_policy)
        .build()
        .unwrap();
    let token = tracker.cancellation_token();
    assert!(!token.is_cancelled());
    let handle = tracker.spawn(async { Ok(42) });
    let result = handle.await.unwrap().unwrap();
    assert_eq!(result, 42);
}
// Root token cancellation must propagate through the whole token hierarchy.
#[rstest]
#[tokio::test]
async fn test_all_trackers_have_cancellation_tokens(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
    let root = TaskTracker::new(scheduler, log_policy).unwrap();
    let child = root.child_tracker().unwrap();
    let grandchild = child.child_tracker().unwrap();
    let root_token = root.cancellation_token();
    let child_token = child.cancellation_token();
    let grandchild_token = grandchild.cancellation_token();
    assert!(!root_token.is_cancelled());
    assert!(!child_token.is_cancelled());
    assert!(!grandchild_token.is_cancelled());
    root_token.cancel();
    tokio::time::sleep(Duration::from_millis(10)).await;
    assert!(root_token.is_cancelled());
    assert!(child_token.is_cancelled());
    assert!(grandchild_token.is_cancelled());
}
// spawn_cancellable: completes normally when not cancelled, and maps a
// Cancelled result to TaskError::Cancelled when the tracker is cancelled.
#[rstest]
#[tokio::test]
async fn test_spawn_cancellable_task(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    let (tx, rx) = tokio::sync::oneshot::channel();
    let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
    let handle = tracker.spawn_cancellable(move |_cancel_token| {
        let rx = rx.clone();
        async move {
            if let Some(rx) = rx.lock().await.take() {
                rx.await.ok();
            }
            CancellableTaskResult::Ok(42)
        }
    });
    tx.send(()).ok();
    let result = handle.await.unwrap().unwrap();
    assert_eq!(result, 42);
    assert_eq!(tracker.metrics().success(), 1);
    let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
    let rx = Arc::new(tokio::sync::Mutex::new(Some(rx)));
    let handle = tracker.spawn_cancellable(move |cancel_token| {
        let rx = rx.clone();
        async move {
            tokio::select! {
                _ = async {
                    if let Some(rx) = rx.lock().await.take() {
                        rx.await.ok();
                    }
                } => CancellableTaskResult::Ok("should not complete"),
                _ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
            }
        }
    });
    tracker.cancel();
    let result = handle.await.unwrap();
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
}
// A task cancelled mid-flight should count as cancelled, not failed.
#[rstest]
#[tokio::test]
async fn test_cancellable_task_metrics_tracking(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    assert_eq!(tracker.metrics().cancelled(), 0);
    assert_eq!(tracker.metrics().failed(), 0);
    assert_eq!(tracker.metrics().success(), 0);
    let (start_tx, start_rx) = tokio::sync::oneshot::channel::<()>();
    let (_continue_tx, continue_rx) = tokio::sync::oneshot::channel::<()>();
    let start_tx_shared = Arc::new(tokio::sync::Mutex::new(Some(start_tx)));
    let continue_rx_shared = Arc::new(tokio::sync::Mutex::new(Some(continue_rx)));
    let start_tx_for_task = start_tx_shared.clone();
    let continue_rx_for_task = continue_rx_shared.clone();
    let handle = tracker.spawn_cancellable(move |cancel_token| {
        let start_tx = start_tx_for_task.clone();
        let continue_rx = continue_rx_for_task.clone();
        async move {
            // Signal the test that the task has actually started running.
            if let Some(tx) = start_tx.lock().await.take() {
                tx.send(()).ok();
            }
            tokio::select! {
                _ = async {
                    if let Some(rx) = continue_rx.lock().await.take() {
                        rx.await.ok();
                    }
                } => CancellableTaskResult::Ok("completed normally"),
                _ = cancel_token.cancelled() => {
                    println!("Task detected cancellation and is returning Cancelled");
                    CancellableTaskResult::Cancelled
                },
            }
        }
    });
    start_rx.await.ok();
    println!("Cancelling tracker while task is executing...");
    tracker.cancel();
    let result = handle.await.unwrap();
    println!("Task result: {:?}", result);
    println!(
        "Cancelled: {}, Failed: {}, Success: {}",
        tracker.metrics().cancelled(),
        tracker.metrics().failed(),
        tracker.metrics().success()
    );
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
    assert_eq!(
        tracker.metrics().cancelled(),
        1,
        "Properly cancelled task should increment cancelled count"
    );
    assert_eq!(
        tracker.metrics().failed(),
        0,
        "Properly cancelled task should NOT increment failed count"
    );
}
// Errors and cancellations must land in separate metrics counters.
#[rstest]
#[tokio::test]
async fn test_cancellable_vs_error_metrics_distinction(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    let handle1 = tracker.spawn_cancellable(|_cancel_token| async move {
        CancellableTaskResult::<i32>::Err(anyhow::anyhow!("This is a real error"))
    });
    let result1 = handle1.await.unwrap();
    assert!(result1.is_err());
    assert!(matches!(result1.unwrap_err(), TaskError::Failed(_)));
    assert_eq!(tracker.metrics().failed(), 1);
    assert_eq!(tracker.metrics().cancelled(), 0);
    let handle2 = tracker.spawn_cancellable(|_cancel_token| async move {
        CancellableTaskResult::<i32>::Cancelled
    });
    let result2 = handle2.await.unwrap();
    assert!(result2.is_err());
    assert!(matches!(result2.unwrap_err(), TaskError::Cancelled));
    assert_eq!(tracker.metrics().failed(), 1);
    assert_eq!(tracker.metrics().cancelled(), 1);
}
// CancellableTaskResult::Err should surface as TaskError::Failed.
#[rstest]
#[tokio::test]
async fn test_spawn_cancellable_error_handling(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(5))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    let handle = tracker.spawn_cancellable(|_cancel_token| async move {
        CancellableTaskResult::<i32>::Err(anyhow::anyhow!("test error"))
    });
    let result = handle.await.unwrap();
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), TaskError::Failed(_)));
    assert_eq!(tracker.metrics().failed(), 1);
}
// Spawning on an already-cancelled (closed) tracker should panic.
#[rstest]
#[tokio::test]
async fn test_cancellation_before_execution(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    tracker.cancel();
    tokio::time::sleep(Duration::from_millis(5)).await;
    let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        tracker.spawn(async { Ok(42) })
    }));
    assert!(
        panic_result.is_err(),
        "spawn() should panic when tracker is closed"
    );
    // The panic payload may be either String or &str depending on how the
    // panic was raised; check the message in both representations.
    if let Err(panic_payload) = panic_result {
        if let Some(panic_msg) = panic_payload.downcast_ref::<String>() {
            assert!(
                panic_msg.contains("TaskTracker must not be closed"),
                "Panic message should indicate tracker is closed: {}",
                panic_msg
            );
        } else if let Some(panic_msg) = panic_payload.downcast_ref::<&str>() {
            assert!(
                panic_msg.contains("TaskTracker must not be closed"),
                "Panic message should indicate tracker is closed: {}",
                panic_msg
            );
        }
    }
}
// A task queued behind a busy semaphore should resolve as Cancelled when the
// tracker is cancelled before it obtains a permit.
#[rstest]
#[tokio::test]
async fn test_semaphore_scheduler_with_cancellation(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(1))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    let blocker_token = tracker.cancellation_token();
    // Occupies the single permit until cancellation.
    let _blocker_handle = tracker.spawn(async move {
        blocker_token.cancelled().await;
        Ok(())
    });
    tokio::task::yield_now().await;
    let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
    let handle = tracker.spawn(async {
        rx.await.ok();
        Ok(42)
    });
    tracker.cancel();
    let result = handle.await.unwrap();
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), TaskError::Cancelled));
}
// Child cancellation leaves the parent fully operational.
#[rstest]
#[tokio::test]
async fn test_child_tracker_cancellation_independence(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let child = parent.child_tracker().unwrap();
    child.cancel();
    let parent_token = parent.cancellation_token();
    assert!(!parent_token.is_cancelled());
    let handle = parent.spawn(async { Ok(42) });
    let result = handle.await.unwrap().unwrap();
    assert_eq!(result, 42);
    let child_token = child.cancellation_token();
    assert!(child_token.is_cancelled());
}
// Parent cancellation reaches children and grandchildren.
#[rstest]
#[tokio::test]
async fn test_parent_cancellation_propagates_to_children(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let child1 = parent.child_tracker().unwrap();
    let child2 = parent.child_tracker().unwrap();
    let grandchild = child1.child_tracker().unwrap();
    parent.cancel();
    tokio::time::sleep(Duration::from_millis(10)).await;
    assert!(parent.cancellation_token().is_cancelled());
    assert!(child1.cancellation_token().is_cancelled());
    assert!(child2.cancellation_token().is_cancelled());
    assert!(grandchild.cancellation_token().is_cancelled());
}
// issued() counts spawned tasks; pending() returns to 0 on completion, and
// parent counters aggregate child activity.
#[rstest]
#[tokio::test]
async fn test_issued_counter_tracking(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    assert_eq!(tracker.metrics().issued(), 0);
    assert_eq!(tracker.metrics().pending(), 0);
    let handle1 = tracker.spawn(async { Ok(1) });
    let handle2 = tracker.spawn(async { Ok(2) });
    let handle3 = tracker.spawn_cancellable(|_| async { CancellableTaskResult::Ok(3) });
    assert_eq!(tracker.metrics().issued(), 3);
    assert_eq!(tracker.metrics().pending(), 3);
    assert_eq!(handle1.await.unwrap().unwrap(), 1);
    assert_eq!(handle2.await.unwrap().unwrap(), 2);
    assert_eq!(handle3.await.unwrap().unwrap(), 3);
    assert_eq!(tracker.metrics().issued(), 3);
    assert_eq!(tracker.metrics().success(), 3);
    assert_eq!(tracker.metrics().total_completed(), 3);
    assert_eq!(tracker.metrics().pending(), 0);
    let child = tracker.child_tracker().unwrap();
    let child_handle = child.spawn(async { Ok(42) });
    assert_eq!(child.metrics().issued(), 1);
    assert_eq!(tracker.metrics().issued(), 4);
    child_handle.await.unwrap().unwrap();
    assert_eq!(child.metrics().pending(), 0);
    assert_eq!(tracker.metrics().pending(), 0);
    assert_eq!(tracker.metrics().success(), 4);
}
// child_tracker_builder can override the error policy for one child.
#[rstest]
#[tokio::test]
async fn test_child_tracker_builder(log_policy: Arc<LogOnlyPolicy>) {
    let parent_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
    let parent = TaskTracker::new(parent_scheduler, log_policy).unwrap();
    let child_error_policy = CancelOnError::new();
    let child = parent
        .child_tracker_builder()
        .error_policy(child_error_policy)
        .build()
        .unwrap();
    let handle = child.spawn(async { Ok(42) });
    let result = handle.await.unwrap().unwrap();
    assert_eq!(result, 42);
    assert_eq!(child.metrics().success(), 1);
    assert_eq!(parent.metrics().total_completed(), 1);
}
// Parent success count aggregates across heterogeneous children.
#[rstest]
#[tokio::test]
async fn test_hierarchical_metrics_aggregation(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
    let parent = TaskTracker::new(scheduler, log_policy.clone()).unwrap();
    let child1 = parent.child_tracker().unwrap();
    let child_error_policy = CancelOnError::new();
    let child2 = parent
        .child_tracker_builder()
        .error_policy(child_error_policy)
        .build()
        .unwrap();
    let another_scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(3))));
    let another_error_policy = CancelOnError::new();
    let child3 = parent
        .child_tracker_builder()
        .scheduler(another_scheduler)
        .error_policy(another_error_policy)
        .build()
        .unwrap();
    assert_eq!(parent.child_count(), 3);
    let handle1 = child1.spawn(async { Ok(1) });
    let handle2 = child2.spawn(async { Ok(2) });
    let handle3 = child3.spawn(async { Ok(3) });
    assert_eq!(handle1.await.unwrap().unwrap(), 1);
    assert_eq!(handle2.await.unwrap().unwrap(), 2);
    assert_eq!(handle3.await.unwrap().unwrap(), 3);
    assert_eq!(parent.metrics().success(), 3);
    assert_eq!(child1.metrics().success(), 1);
    assert_eq!(child2.metrics().success(), 1);
    assert_eq!(child3.metrics().success(), 1);
}
// queued() should equal pending() minus active() while tasks wait on permits.
#[rstest]
#[tokio::test]
async fn test_scheduler_queue_depth_calculation(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(2))));
    let tracker = TaskTracker::new(scheduler, log_policy).unwrap();
    assert_eq!(tracker.metrics().issued(), 0);
    assert_eq!(tracker.metrics().active(), 0);
    assert_eq!(tracker.metrics().queued(), 0);
    assert_eq!(tracker.metrics().pending(), 0);
    let (complete_tx, _complete_rx) = tokio::sync::broadcast::channel(1);
    let handle1 = tracker.spawn({
        let mut rx = complete_tx.subscribe();
        async move {
            rx.recv().await.ok();
            Ok(1)
        }
    });
    let handle2 = tracker.spawn({
        let mut rx = complete_tx.subscribe();
        async move {
            rx.recv().await.ok();
            Ok(2)
        }
    });
    tokio::task::yield_now().await;
    tokio::task::yield_now().await;
    assert_eq!(tracker.metrics().issued(), 2);
    assert_eq!(tracker.metrics().active(), 2);
    assert_eq!(tracker.metrics().queued(), 0);
    assert_eq!(tracker.metrics().pending(), 2);
    // A third task exceeds the 2 permits and must queue.
    let handle3 = tracker.spawn(async move { Ok(3) });
    tokio::task::yield_now().await;
    assert_eq!(tracker.metrics().issued(), 3);
    assert_eq!(tracker.metrics().active(), 2);
    assert_eq!(
        tracker.metrics().queued(),
        tracker.metrics().pending() - tracker.metrics().active()
    );
    assert_eq!(tracker.metrics().pending(), 3);
    complete_tx.send(()).ok();
    let result1 = handle1.await.unwrap().unwrap();
    let result2 = handle2.await.unwrap().unwrap();
    let result3 = handle3.await.unwrap().unwrap();
    assert_eq!(result1, 1);
    assert_eq!(result2, 2);
    assert_eq!(result3, 3);
    assert_eq!(tracker.metrics().success(), 3);
    assert_eq!(tracker.metrics().active(), 0);
    assert_eq!(tracker.metrics().queued(), 0);
    assert_eq!(tracker.metrics().pending(), 0);
}
// Child failures and successes are both recorded in the child's metrics.
#[rstest]
#[tokio::test]
async fn test_hierarchical_metrics_failure_aggregation(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let child = parent.child_tracker().unwrap();
    let success_handle = child.spawn(async { Ok(42) });
    let failure_handle = child.spawn(async { Err::<(), _>(anyhow::anyhow!("test error")) });
    let _success_result = success_handle.await.unwrap().unwrap();
    let _failure_result = failure_handle.await.unwrap().unwrap_err();
    assert_eq!(child.metrics().success(), 1, "Child should have 1 success");
    assert_eq!(child.metrics().failed(), 1, "Child should have 1 failure");
}
// Two sibling trackers sharing a scheduler keep fully separate metrics.
#[rstest]
#[tokio::test]
async fn test_metrics_independence_between_tracker_instances(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    let tracker1 = TaskTracker::new(semaphore_scheduler.clone(), log_policy.clone()).unwrap();
    let tracker2 = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let handle1 = tracker1.spawn(async { Ok(1) });
    let handle2 = tracker2.spawn(async { Ok(2) });
    handle1.await.unwrap().unwrap();
    handle2.await.unwrap().unwrap();
    assert_eq!(tracker1.metrics().success(), 1);
    assert_eq!(tracker2.metrics().success(), 1);
    assert_eq!(tracker1.metrics().total_completed(), 1);
    assert_eq!(tracker2.metrics().total_completed(), 1);
}
// parent.join() must block until every descendant task has finished.
#[rstest]
#[tokio::test]
async fn test_hierarchical_join_waits_for_all(log_policy: Arc<LogOnlyPolicy>) {
    let scheduler = Arc::new(SemaphoreScheduler::new(Arc::new(Semaphore::new(10))));
    let parent = TaskTracker::new(scheduler, log_policy).unwrap();
    let child1 = parent.child_tracker().unwrap();
    let child2 = parent.child_tracker().unwrap();
    let grandchild = child1.child_tracker().unwrap();
    assert_eq!(parent.child_count(), 2);
    assert_eq!(child1.child_count(), 1);
    assert_eq!(child2.child_count(), 0);
    assert_eq!(grandchild.child_count(), 0);
    // Each task records its name on completion; staggered sleeps make the
    // grandchild (125ms) the last to finish.
    let completion_order = Arc::new(Mutex::new(Vec::new()));
    let order_clone = completion_order.clone();
    let parent_handle = parent.spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        order_clone.lock().unwrap().push("parent");
        Ok(())
    });
    let order_clone = completion_order.clone();
    let child1_handle = child1.spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;
        order_clone.lock().unwrap().push("child1");
        Ok(())
    });
    let order_clone = completion_order.clone();
    let child2_handle = child2.spawn(async move {
        tokio::time::sleep(Duration::from_millis(75)).await;
        order_clone.lock().unwrap().push("child2");
        Ok(())
    });
    let order_clone = completion_order.clone();
    let grandchild_handle = grandchild.spawn(async move {
        tokio::time::sleep(Duration::from_millis(125)).await;
        order_clone.lock().unwrap().push("grandchild");
        Ok(())
    });
    println!("[TEST] About to call parent.join()");
    let start = std::time::Instant::now();
    parent.join().await;
    let elapsed = start.elapsed();
    println!("[TEST] parent.join() completed in {:?}", elapsed);
    assert!(
        elapsed >= Duration::from_millis(120),
        "Hierarchical join should wait for longest task"
    );
    assert!(parent_handle.is_finished());
    assert!(child1_handle.is_finished());
    assert!(child2_handle.is_finished());
    assert!(grandchild_handle.is_finished());
    let final_order = completion_order.lock().unwrap();
    assert_eq!(final_order.len(), 4);
    assert!(final_order.contains(&"parent"));
    assert!(final_order.contains(&"child1"));
    assert!(final_order.contains(&"child2"));
    assert!(final_order.contains(&"grandchild"));
}
#[rstest]
#[tokio::test]
// parent.join() must wait for tasks spawned on a child tracker, not just the
// parent's own (shorter) task.
async fn test_hierarchical_join_waits_for_children(
semaphore_scheduler: Arc<SemaphoreScheduler>,
log_policy: Arc<LogOnlyPolicy>,
) {
let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
let child = parent.child_tracker().unwrap();
// Parent task finishes quickly (20ms)...
let _parent_handle = parent.spawn(async {
tokio::time::sleep(Duration::from_millis(20)).await;
Ok(())
});
// ...while the child task runs 5x longer (100ms).
let _child_handle = child.spawn(async {
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(())
});
let start = std::time::Instant::now();
parent.join().await; let elapsed = start.elapsed();
// 90ms lower bound gives slack below the 100ms child sleep.
assert!(
elapsed >= Duration::from_millis(90),
"Hierarchical join should wait for all child tasks"
);
}
#[rstest]
#[tokio::test]
async fn test_hierarchical_join_operations(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    // Joining the root of a three-level hierarchy must close every descendant.
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let child = parent.child_tracker().unwrap();
    let grandchild = child.child_tracker().unwrap();

    // All three trackers start open.
    for tracker in [&parent, &child, &grandchild] {
        assert!(!tracker.is_closed());
    }

    parent.join().await;

    // join() on the parent closes the whole subtree.
    assert!(child.is_closed());
    assert!(grandchild.is_closed());
}
#[rstest]
#[tokio::test]
async fn test_unlimited_scheduler() {
    // One round-trip through a tracker using the unlimited scheduler: the task
    // blocks on a oneshot until we release it, then completes successfully.
    let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();

    let (release, gate) = tokio::sync::oneshot::channel();
    let handle = tracker.spawn(async {
        gate.await.ok();
        Ok(42)
    });

    release.send(()).ok();
    assert_eq!(handle.await.unwrap().unwrap(), 42);
    assert_eq!(tracker.metrics().success(), 1);
}
#[rstest]
#[tokio::test]
// Two failed tasks on a ThresholdCancelPolicy(2) tracker: the policy's global
// failure counter advances, but the tracker's cancellation token is never
// tripped. NOTE(review): relies on yield_now() giving the spawned failing
// tasks a chance to run before each assertion — scheduling-order sensitive.
async fn test_threshold_cancel_policy(semaphore_scheduler: Arc<SemaphoreScheduler>) {
let error_policy = ThresholdCancelPolicy::with_threshold(2); let tracker = TaskTracker::new(semaphore_scheduler, error_policy.clone()).unwrap();
let cancel_token = tracker.cancellation_token().child_token();
// First failure: below threshold, no cancellation expected.
let _handle1 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("First failure")) });
tokio::task::yield_now().await;
assert!(!cancel_token.is_cancelled());
assert_eq!(error_policy.failure_count(), 1);
// Second failure: counter reaches 2, yet the token stays uncancelled
// (presumably due to per-task error contexts — see the per-task context tests).
let _handle2 = tracker.spawn(async { Err::<(), _>(anyhow::anyhow!("Second failure")) });
tokio::task::yield_now().await;
assert!(!cancel_token.is_cancelled()); assert_eq!(error_policy.failure_count(), 2);
}
#[tokio::test]
async fn test_policy_constructors() {
    // Smoke test: every scheduler and error-policy constructor is callable.
    let _ = UnlimitedScheduler::new();
    let _ = SemaphoreScheduler::with_permits(5);
    let _ = LogOnlyPolicy::new();
    let _ = CancelOnError::new();
    let _ = ThresholdCancelPolicy::with_threshold(3);
    let _ = RateCancelPolicy::builder().rate(0.5).window_secs(60).build();
}
#[rstest]
#[tokio::test]
async fn test_child_creation_fails_after_join(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    // After join(), a tracker is closed and must refuse to create children.
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let _child = parent.child_tracker().unwrap();

    // Keep a clone so we can poke the tracker after join() consumes `parent`.
    let after_join = parent.clone();
    parent.join().await;
    assert!(after_join.is_closed());

    let result = after_join.child_tracker();
    assert!(result.is_err());
    let message = result.err().unwrap().to_string();
    assert!(message.contains("closed parent tracker"));
}
#[rstest]
#[tokio::test]
async fn test_child_builder_fails_after_join(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    // The builder-based child API must also fail once the parent is closed.
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let _child = parent.child_tracker_builder().build().unwrap();

    // Keep a clone so we can poke the tracker after join() consumes `parent`.
    let after_join = parent.clone();
    parent.join().await;
    assert!(after_join.is_closed());

    let result = after_join.child_tracker_builder().build();
    assert!(result.is_err());
    let message = result.err().unwrap().to_string();
    assert!(message.contains("closed parent tracker"));
}
#[rstest]
#[tokio::test]
async fn test_child_creation_succeeds_before_join(
    semaphore_scheduler: Arc<SemaphoreScheduler>,
    log_policy: Arc<LogOnlyPolicy>,
) {
    // While the parent is open, both child-creation APIs work, and child task
    // completions roll up into the parent's metrics.
    let parent = TaskTracker::new(semaphore_scheduler, log_policy).unwrap();
    let via_method = parent.child_tracker().unwrap();
    let via_builder = parent.child_tracker_builder().build().unwrap();

    let first = via_method.spawn(async { Ok(42) });
    let second = via_builder.spawn(async { Ok(24) });

    assert_eq!(first.await.unwrap().unwrap(), 42);
    assert_eq!(second.await.unwrap().unwrap(), 24);
    // Parent aggregates successes from both children.
    assert_eq!(parent.metrics().success(), 2);
}
#[rstest]
#[tokio::test]
// A policy wired to an external CancellationToken must trigger that token when
// a task fails, and the cancellation must be visible on both the tracker's and
// the child tracker's tokens (the tracker was built with the same token).
async fn test_custom_error_response_with_cancellation_token(
semaphore_scheduler: Arc<SemaphoreScheduler>,
) {
let custom_cancel_token = CancellationToken::new();
let error_policy = TriggerCancellationTokenOnError::new(custom_cancel_token.clone());
// The tracker shares the same token the policy will trigger.
let tracker = TaskTracker::builder()
.scheduler(semaphore_scheduler)
.error_policy(error_policy)
.cancel_token(custom_cancel_token.clone())
.build()
.unwrap();
let child = tracker.child_tracker().unwrap();
assert!(!custom_cancel_token.is_cancelled());
// Failing task spawned on the CHILD; the policy should still fire.
let handle = child.spawn(async {
Err::<(), _>(anyhow::anyhow!("Test error to trigger custom response"))
});
let result = handle.await.unwrap();
assert!(result.is_err());
// Wait (bounded by a 1s deadline) for the policy to cancel the token; the
// cancellation is asynchronous relative to the task's completion.
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(1)) => {
panic!("Task should have failed, but hit the deadline");
}
_ = custom_cancel_token.cancelled() => {
}
}
assert!(
custom_cancel_token.is_cancelled(),
"Custom cancellation token should be triggered by ErrorResponse::Custom"
);
// Tracker and child tokens derive from the same custom token, so both report cancelled.
assert!(tracker.cancellation_token().is_cancelled());
assert!(child.cancellation_token().is_cancelled());
assert_eq!(tracker.metrics().failed(), 1);
}
#[test]
fn test_action_result_variants() {
    // Verify each ActionResult variant can be constructed and matched.
    // For the unit-like variants a full `match` with a panicking catch-all is
    // noise; `matches!` expresses the discriminant check idiomatically.
    assert!(matches!(ActionResult::Fail, ActionResult::Fail));
    assert!(matches!(ActionResult::Shutdown, ActionResult::Shutdown));

    // Continue carries a continuation payload, so match to extract and inspect it.
    #[derive(Debug)]
    struct TestRestartable;
    #[async_trait]
    impl Continuation for TestRestartable {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new("test_result".to_string()))
        }
    }
    let test_restartable = Arc::new(TestRestartable);
    let continue_result = ActionResult::Continue {
        continuation: test_restartable,
    };
    match continue_result {
        ActionResult::Continue { continuation } => {
            // Debug formatting of the trait object should name the concrete type.
            assert!(format!("{:?}", continuation).contains("TestRestartable"));
        }
        _ => panic!("Expected Continue variant"),
    }
}
#[test]
fn test_continuation_error_creation() {
    // FailedWithContinuation's Display must include both the wrapper prefix
    // and the source error text, and must survive conversion into anyhow.
    #[derive(Debug)]
    struct DummyRestartable;
    #[async_trait]
    impl Continuation for DummyRestartable {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new("restarted_result".to_string()))
        }
    }

    let wrapped = FailedWithContinuation::new(
        anyhow::anyhow!("Original task failed"),
        Arc::new(DummyRestartable),
    );

    let rendered = format!("{}", wrapped);
    assert!(rendered.contains("Task failed with continuation"));
    assert!(rendered.contains("Original task failed"));

    // The wrapper text is preserved through anyhow::Error conversion.
    let as_anyhow = anyhow::Error::new(wrapped);
    assert!(as_anyhow.to_string().contains("Task failed with continuation"));
}
#[test]
fn test_continuation_error_ext_trait() {
    // A plain anyhow error exposes no continuation...
    let plain = anyhow::anyhow!("Regular error");
    assert!(!plain.has_continuation());
    assert!(plain.extract_continuation().is_none());

    // ...while a FailedWithContinuation wrapped in anyhow exposes one.
    #[derive(Debug)]
    struct TestRestartable;
    #[async_trait]
    impl Continuation for TestRestartable {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new("test_result".to_string()))
        }
    }

    let wrapped = anyhow::Error::new(FailedWithContinuation::new(
        anyhow::anyhow!("Source error"),
        Arc::new(TestRestartable),
    ));
    assert!(wrapped.has_continuation());
    assert!(wrapped.extract_continuation().is_some());
}
#[test]
fn test_continuation_error_into_anyhow_helper() {
    // A FailedWithContinuation wrapped in anyhow::Error must still report a
    // continuation via has_continuation().
    // (Removed: an unused `struct MockExecutor;` and an unused `_source_error`
    // local that duplicated the error built below — both were dead code.)
    #[derive(Debug)]
    struct MockRestartable;
    #[async_trait]
    impl Continuation for MockRestartable {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new("mock_result".to_string()))
        }
    }
    let continuation_error = FailedWithContinuation::new(
        anyhow::anyhow!("Mock task failed"),
        Arc::new(MockRestartable),
    );
    let anyhow_error = anyhow::Error::new(continuation_error);
    assert!(anyhow_error.has_continuation());
}
#[test]
fn test_continuation_error_with_task_executor() {
    // Display output carries both the wrapper prefix and the source message,
    // and the continuation remains discoverable after anyhow conversion.
    #[derive(Debug)]
    struct TestRestartableTask;
    #[async_trait]
    impl Continuation for TestRestartableTask {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new("test_result".to_string()))
        }
    }

    let wrapped = FailedWithContinuation::new(
        anyhow::anyhow!("Task failed"),
        Arc::new(TestRestartableTask),
    );

    let rendered = format!("{}", wrapped);
    assert!(rendered.contains("Task failed with continuation"));
    assert!(rendered.contains("Task failed"));

    let as_anyhow = anyhow::Error::new(wrapped);
    assert!(as_anyhow.has_continuation());
    assert!(as_anyhow.extract_continuation().is_some());
}
#[test]
fn test_continuation_error_into_anyhow_convenience() {
    // into_anyhow() is the one-step shortcut for new() + anyhow::Error::new();
    // the result must expose the continuation and both message fragments.
    #[derive(Debug)]
    struct ConvenienceRestartable;
    #[async_trait]
    impl Continuation for ConvenienceRestartable {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new(42u32))
        }
    }

    let err = FailedWithContinuation::into_anyhow(
        anyhow::anyhow!("Computation failed"),
        Arc::new(ConvenienceRestartable),
    );

    assert!(err.has_continuation());
    let text = err.to_string();
    assert!(text.contains("Task failed with continuation"));
    assert!(text.contains("Computation failed"));
}
#[test]
fn test_handle_task_error_with_continuation_error() {
    // An anyhow error wrapping FailedWithContinuation can be downcast back to
    // the concrete type, and its continuation Arc is alive.
    #[derive(Debug)]
    struct MockRestartableTask;
    #[async_trait]
    impl Continuation for MockRestartableTask {
        async fn execute(
            &self,
            _cancel_token: CancellationToken,
        ) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
            TaskExecutionResult::Success(Box::new("retry_result".to_string()))
        }
    }

    let anyhow_error = anyhow::Error::new(FailedWithContinuation::new(
        anyhow::anyhow!("Task failed, but can retry"),
        Arc::new(MockRestartableTask),
    ));
    assert!(anyhow_error.has_continuation());

    let downcast = anyhow_error.downcast_ref::<FailedWithContinuation>();
    assert!(downcast.is_some());
    // strong_count >= 1 while the error (and thus the Arc) is held.
    assert!(Arc::strong_count(&downcast.unwrap().continuation) > 0);
}
#[test]
fn test_handle_task_error_with_regular_error() {
    // A plain anyhow error has no continuation and cannot be downcast to
    // FailedWithContinuation.
    let plain = anyhow::anyhow!("Regular task failure");
    assert!(!plain.has_continuation());
    assert!(plain.downcast_ref::<FailedWithContinuation>().is_none());
}
#[rstest]
#[tokio::test]
// End-to-end: a task fails with a continuation attached; the tracker runs the
// continuation, which succeeds, so the overall task counts as a success and
// the log records original-then-continuation execution order.
async fn test_end_to_end_continuation_execution(
unlimited_scheduler: Arc<UnlimitedScheduler>,
log_policy: Arc<LogOnlyPolicy>,
) {
let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
// async Mutex because the log is locked across .await-free but async contexts.
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
let log_clone = execution_log.clone();
#[derive(Debug)]
struct LoggingContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
result: String,
}
#[async_trait]
impl Continuation for LoggingContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
self.log
.lock()
.await
.push("continuation_executed".to_string());
TaskExecutionResult::Success(Box::new(self.result.clone()))
}
}
let continuation = Arc::new(LoggingContinuation {
log: log_clone,
result: "continuation_result".to_string(),
});
let log_for_task = execution_log.clone();
// The task logs its own execution, then fails with the continuation attached.
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("original_task_executed".to_string());
let error = anyhow::anyhow!("Original task failed");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_ok(), "Continuation should succeed");
// Exactly two log entries, in original -> continuation order.
let log = execution_log.lock().await;
assert_eq!(log.len(), 2);
assert_eq!(log[0], "original_task_executed");
assert_eq!(log[1], "continuation_executed");
// The recovered task counts as a success, not a failure or cancellation.
assert_eq!(tracker.metrics().success(), 1);
assert_eq!(tracker.metrics().failed(), 0); assert_eq!(tracker.metrics().cancelled(), 0);
}
#[rstest]
#[tokio::test]
// A continuation that fails can hand back ANOTHER continuation; the tracker
// keeps executing them in sequence. Here attempts 1 and 2 fail (each chaining
// a fresh RetryingContinuation) and attempt 3 succeeds.
async fn test_end_to_end_multiple_continuations(
unlimited_scheduler: Arc<UnlimitedScheduler>,
log_policy: Arc<LogOnlyPolicy>,
) {
let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
// Shared counter so every chained continuation sees the same attempt number.
let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
#[derive(Debug)]
struct RetryingContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
attempt_count: Arc<std::sync::atomic::AtomicU32>,
}
#[async_trait]
impl Continuation for RetryingContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
// fetch_add returns the previous value; +1 makes attempts 1-based.
let attempt = self
.attempt_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
self.log
.lock()
.await
.push(format!("continuation_attempt_{}", attempt));
if attempt < 3 {
// Fail again, chaining a new continuation that shares our state.
let next_continuation = Arc::new(RetryingContinuation {
log: self.log.clone(),
attempt_count: self.attempt_count.clone(),
});
let error = anyhow::anyhow!("Continuation attempt {} failed", attempt);
TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
error,
next_continuation,
))
} else {
TaskExecutionResult::Success(Box::new(format!(
"success_on_attempt_{}",
attempt
)))
}
}
}
let initial_continuation = Arc::new(RetryingContinuation {
log: execution_log.clone(),
attempt_count: attempt_count.clone(),
});
// Original task fails immediately, kicking off the retry chain.
let handle = tracker.spawn(async move {
let error = anyhow::anyhow!("Original task failed");
let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
error,
initial_continuation,
));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_ok(), "Final continuation should succeed");
// Exactly three attempts, logged in order.
let log = execution_log.lock().await;
assert_eq!(log.len(), 3);
assert_eq!(log[0], "continuation_attempt_1");
assert_eq!(log[1], "continuation_attempt_2");
assert_eq!(log[2], "continuation_attempt_3");
assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
// A task rescued by its final continuation counts as one success, no failures.
assert_eq!(tracker.metrics().success(), 1);
assert_eq!(tracker.metrics().failed(), 0);
}
#[rstest]
#[tokio::test]
// When the continuation itself fails (with a plain error, no further
// continuation), the task ends as a failure and metrics record it as such.
async fn test_end_to_end_continuation_failure(
unlimited_scheduler: Arc<UnlimitedScheduler>,
log_policy: Arc<LogOnlyPolicy>,
) {
let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
let log_clone = execution_log.clone();
#[derive(Debug)]
struct FailingContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
}
#[async_trait]
impl Continuation for FailingContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
self.log
.lock()
.await
.push("continuation_failed".to_string());
// Plain Error (not FailedWithContinuation) terminates the retry chain.
TaskExecutionResult::Error(anyhow::anyhow!("Continuation failed permanently"))
}
}
let continuation = Arc::new(FailingContinuation { log: log_clone });
let log_for_task = execution_log.clone();
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("original_task_executed".to_string());
let error = anyhow::anyhow!("Original task failed");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_err(), "Continuation should fail");
// Both the original task and the (single) continuation attempt ran.
let log = execution_log.lock().await;
assert_eq!(log.len(), 2);
assert_eq!(log[0], "original_task_executed");
assert_eq!(log[1], "continuation_failed");
// Permanent failure: one failed task, nothing succeeded or cancelled.
assert_eq!(tracker.metrics().success(), 0);
assert_eq!(tracker.metrics().failed(), 1);
assert_eq!(tracker.metrics().cancelled(), 0);
}
#[rstest]
#[tokio::test]
// Exercises the three error-policy outcomes end-to-end, one scoped block each:
// LogOnly (error passes through), CancelOnError (error cancels the tracker),
// and a successful continuation (error is recovered into a success).
async fn test_end_to_end_all_action_result_variants(
unlimited_scheduler: Arc<UnlimitedScheduler>,
) {
{
// LogOnly: the failure is recorded but nothing else happens.
let tracker =
TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
let handle = tracker.spawn(async {
let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_err(), "LogOnly should let error through");
assert_eq!(tracker.metrics().failed(), 1);
}
{
// CancelOnError: the first failure trips the tracker's cancellation token.
let tracker =
TaskTracker::new(unlimited_scheduler.clone(), CancelOnError::new()).unwrap();
let handle = tracker.spawn(async {
let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Test error"));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_err(), "CancelOnError should fail task");
assert!(
tracker.cancellation_token().is_cancelled(),
"Should cancel tracker"
);
assert_eq!(tracker.metrics().failed(), 1);
}
{
// Continuation: the attached continuation succeeds, converting the
// failure into a success (0 failures recorded).
let tracker =
TaskTracker::new(unlimited_scheduler.clone(), LogOnlyPolicy::new()).unwrap();
#[derive(Debug)]
struct TestContinuation;
#[async_trait]
impl Continuation for TestContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
}
}
let continuation = Arc::new(TestContinuation);
let handle = tracker.spawn(async move {
let error = anyhow::anyhow!("Original failure");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_ok(), "Continuation should succeed");
assert_eq!(tracker.metrics().success(), 1);
assert_eq!(tracker.metrics().failed(), 0);
}
}
#[rstest]
#[case(
1,
false,
"Global policy with max_failures=1 should stop after first regular error"
)]
#[case(
2,
false, // Actually fails - ActionResult::Fail accepts the error and fails the task
"Global policy with max_failures=2 allows error but ActionResult::Fail still fails the task"
)]
#[tokio::test]
// Interaction of a continuation chain with ThresholdCancelPolicy: in BOTH
// cases the task ultimately fails after one continuation attempt (the policy's
// Fail response ends the task), but only max_failures=1 also cancels the
// tracker. `should_succeed` is retained for the case table / debug prints only.
async fn test_continuation_loop_with_global_threshold_policy(
unlimited_scheduler: Arc<UnlimitedScheduler>,
#[case] max_failures: usize,
#[case] should_succeed: bool,
#[case] description: &str,
) {
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
// Continuation that fails with REGULAR errors (no chained continuation)
// until max_attempts_before_success, so the error policy decides its fate.
#[derive(Debug)]
struct PolicyTestContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
attempt_counter: Arc<std::sync::atomic::AtomicU32>,
max_attempts_before_success: u32,
}
#[async_trait]
impl Continuation for PolicyTestContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
// fetch_add returns the previous value; +1 makes attempts 1-based.
let attempt = self
.attempt_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
self.log
.lock()
.await
.push(format!("continuation_attempt_{}", attempt));
if attempt < self.max_attempts_before_success {
TaskExecutionResult::Error(anyhow::anyhow!(
"Continuation attempt {} failed (regular error)",
attempt
))
} else {
TaskExecutionResult::Success(Box::new(format!(
"success_on_attempt_{}",
attempt
)))
}
}
}
let policy = ThresholdCancelPolicy::with_threshold(max_failures);
let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
let log_for_task = execution_log.clone();
// Would succeed on attempt 2 — but the policy fails the task on attempt 1's error.
let continuation = Arc::new(PolicyTestContinuation {
log: execution_log.clone(),
attempt_counter: attempt_counter.clone(),
max_attempts_before_success: 2, });
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("original_task_executed".to_string());
let error = anyhow::anyhow!("Original task failed");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
let log = execution_log.lock().await;
let final_attempt_count = attempt_counter.load(std::sync::atomic::Ordering::Relaxed);
// Diagnostic dump (kept for debugging the case matrix).
println!(
"Test case: max_failures={}, should_succeed={}",
max_failures, should_succeed
);
println!("Result: {:?}", result.is_ok());
println!("Log entries: {:?}", log);
println!("Attempt count: {}", final_attempt_count);
println!(
"Metrics: success={}, failed={}",
tracker.metrics().success(),
tracker.metrics().failed()
);
// Release the lock before re-acquiring it for the assertions below.
drop(log);
assert!(result.is_err(), "{}: Task should fail", description);
assert_eq!(
tracker.metrics().success(),
0,
"{}: Should have 0 successes",
description
);
assert_eq!(
tracker.metrics().failed(),
1,
"{}: Should have 1 failure",
description
);
let log = execution_log.lock().await;
assert_eq!(
log.len(),
2,
"{}: Should have 2 log entries (original + 1 continuation attempt)",
description
);
assert_eq!(log[0], "original_task_executed");
assert_eq!(log[1], "continuation_attempt_1");
assert_eq!(
attempt_counter.load(std::sync::atomic::Ordering::Relaxed),
1,
"{}: Should have made 1 continuation attempt",
description
);
// Only the threshold=1 case trips tracker-wide cancellation.
if max_failures == 1 {
assert!(
tracker.cancellation_token().is_cancelled(),
"Tracker should be cancelled with max_failures=1"
);
} else {
assert!(
!tracker.cancellation_token().is_cancelled(),
"Tracker should NOT be cancelled with max_failures=2 (policy allows the error)"
);
}
}
#[rstest]
#[tokio::test]
async fn test_simple_threshold_policy_behavior(unlimited_scheduler: Arc<UnlimitedScheduler>) {
    // Two failures on separate tasks: per-task error contexts keep the
    // threshold (2) from tripping tracker-wide cancellation, while the
    // policy's global counter still records both failures.
    let policy = ThresholdCancelPolicy::with_threshold(2);
    let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();

    let first = tracker.spawn(async {
        let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("First failure"));
        result
    });
    let first_outcome = first.await.expect("Task should complete");
    assert!(first_outcome.is_err(), "First task should fail");
    assert!(
        !tracker.cancellation_token().is_cancelled(),
        "Should not be cancelled after 1 failure"
    );

    let second = tracker.spawn(async {
        let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Second failure"));
        result
    });
    let second_outcome = second.await.expect("Task should complete");
    assert!(second_outcome.is_err(), "Second task should fail");
    assert!(
        !tracker.cancellation_token().is_cancelled(),
        "Should NOT be cancelled - per-task context prevents global accumulation"
    );

    println!("Policy global failure count: {}", policy.failure_count());
    assert_eq!(
        policy.failure_count(),
        2,
        "Policy should have counted 2 failures globally (for backwards compatibility)"
    );
}
#[rstest]
#[tokio::test]
async fn test_per_task_context_limitation_demo(unlimited_scheduler: Arc<UnlimitedScheduler>) {
    // Demonstrates the per-task context limitation: failures in DIFFERENT
    // tasks do not accumulate toward the tracker-wide cancellation threshold,
    // even though the policy's global counter still sees them.
    let policy = ThresholdCancelPolicy::with_threshold(2);
    let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();

    let first = tracker.spawn(async {
        let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 1 failure"));
        result
    });
    assert!(
        first.await.expect("Task should complete").is_err(),
        "Task 1 should fail"
    );

    let second = tracker.spawn(async {
        let result: Result<String, anyhow::Error> = Err(anyhow::anyhow!("Task 2 failure"));
        result
    });
    assert!(
        second.await.expect("Task should complete").is_err(),
        "Task 2 should fail"
    );

    assert!(
        !tracker.cancellation_token().is_cancelled(),
        "Tracker should NOT be cancelled - per-task context prevents premature cancellation"
    );

    println!("Global failure count: {}", policy.failure_count());
    assert_eq!(
        policy.failure_count(),
        2,
        "Global policy counted 2 failures across different tasks"
    );
}
#[rstest]
#[case(
3,
true,
"Policy allows continuations up to 3 attempts - should succeed"
)]
#[case(
2,
true,
"Policy allows continuations up to 2 attempts - should succeed"
)]
#[case(0, false, "Policy allows 0 attempts - should fail immediately")]
#[tokio::test]
// Verifies that OnErrorPolicy::allow_continuation gates the continuation loop:
// with enough allowed attempts the retry chain reaches success on attempt 2;
// with 0 attempts no continuation ever runs and the task fails immediately.
async fn test_allow_continuation_policy_control(
unlimited_scheduler: Arc<UnlimitedScheduler>,
#[case] max_attempts: u32,
#[case] should_succeed: bool,
#[case] description: &str,
) {
// Minimal policy: always Fail on error, but cap continuation attempts.
#[derive(Debug)]
struct AttemptLimitPolicy {
max_attempts: u32,
}
impl OnErrorPolicy for AttemptLimitPolicy {
fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
Arc::new(AttemptLimitPolicy {
max_attempts: self.max_attempts,
})
}
fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
// No per-task state needed; attempt_count comes from OnErrorContext.
None }
fn allow_continuation(&self, _error: &anyhow::Error, context: &OnErrorContext) -> bool {
context.attempt_count <= self.max_attempts
}
fn on_error(
&self,
_error: &anyhow::Error,
_context: &mut OnErrorContext,
) -> ErrorResponse {
ErrorResponse::Fail }
}
let policy = Arc::new(AttemptLimitPolicy { max_attempts });
let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
// Continuation that chains itself until attempt 2, then succeeds.
#[derive(Debug)]
struct AlwaysRetryContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
attempt: u32,
}
#[async_trait]
impl Continuation for AlwaysRetryContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
self.log
.lock()
.await
.push(format!("continuation_attempt_{}", self.attempt));
if self.attempt >= 2 {
TaskExecutionResult::Success(Box::new("final_success".to_string()))
} else {
let next_continuation = Arc::new(AlwaysRetryContinuation {
log: self.log.clone(),
attempt: self.attempt + 1,
});
let error = anyhow::anyhow!("Continuation attempt {} failed", self.attempt);
TaskExecutionResult::Error(FailedWithContinuation::into_anyhow(
error,
next_continuation,
))
}
}
}
let initial_continuation = Arc::new(AlwaysRetryContinuation {
log: execution_log.clone(),
attempt: 1,
});
let log_for_task = execution_log.clone();
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("initial_task_failure".to_string());
let error = anyhow::anyhow!("Initial task failure");
let result: Result<String, anyhow::Error> = Err(FailedWithContinuation::into_anyhow(
error,
initial_continuation,
));
result
});
let result = handle.await.expect("Task should complete");
if should_succeed {
// Enough attempts allowed: chain runs to success on attempt 2.
assert!(result.is_ok(), "{}: Task should succeed", description);
assert_eq!(
tracker.metrics().success(),
1,
"{}: Should have 1 success",
description
);
let log = execution_log.lock().await;
assert!(
log.len() > 2,
"{}: Should have multiple log entries",
description
);
assert!(log.contains(&"continuation_attempt_1".to_string()));
} else {
// Zero attempts allowed: the continuation never runs at all.
assert!(result.is_err(), "{}: Task should fail", description);
assert_eq!(
tracker.metrics().failed(),
1,
"{}: Should have 1 failure",
description
);
let log = execution_log.lock().await;
assert_eq!(
log.len(),
1,
"{}: Should only have initial task entry",
description
);
assert_eq!(log[0], "initial_task_failure");
assert!(
!log.iter()
.any(|entry| entry.contains("continuation_attempt")),
"{}: Should not have continuation attempts, but got: {:?}",
description,
*log
);
}
}
#[tokio::test]
// Covers the TaskHandle surface: token access, normal completion, cooperative
// cancellation via the per-task token, isolation between tasks, and hard
// abort() returning a JoinError. Final metrics: 2 successes, 1 cancelled
// (the aborted task is not counted as completed).
async fn test_task_handle_functionality() {
let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
// 1) Normal completion while holding a reference to the task's token.
let handle1 = tracker.spawn(async {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
Ok("completed".to_string())
});
let cancel_token = handle1.cancellation_token();
assert!(
!cancel_token.is_cancelled(),
"Token should not be cancelled initially"
);
let result1 = handle1.await.expect("Task should complete");
assert!(result1.is_ok(), "Task should succeed");
assert_eq!(result1.unwrap(), "completed");
// 2) Cooperative cancellation: the task races a long sleep against its token.
let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
CancellableTaskResult::Ok("task_was_not_cancelled".to_string())
},
_ = cancel_token.cancelled() => {
CancellableTaskResult::Cancelled
},
}
});
let cancel_token2 = handle2.cancellation_token();
cancel_token2.cancel();
let result2 = handle2.await.expect("Task should complete");
assert!(result2.is_err(), "Task should be cancelled");
assert!(
result2.unwrap_err().is_cancellation(),
"Should be a cancellation error"
);
// 3) Cancelling one task must not affect others on the same tracker.
let handle3 = tracker.spawn(async { Ok("not_cancelled".to_string()) });
let result3 = handle3.await.expect("Task should complete");
assert!(result3.is_ok(), "Other tasks should not be affected");
assert_eq!(result3.unwrap(), "not_cancelled");
// 4) abort() forcefully kills the task; awaiting then yields a JoinError.
let handle4 = tracker.spawn(async {
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
Ok("should_be_aborted".to_string())
});
assert!(!handle4.is_finished(), "Task should not be finished yet");
handle4.abort();
let result4 = handle4.await;
assert!(result4.is_err(), "Aborted task should return JoinError");
assert_eq!(
tracker.metrics().success(),
2,
"Should have 2 successful tasks"
);
assert_eq!(
tracker.metrics().cancelled(),
1,
"Should have 1 cancelled task"
);
}
#[tokio::test]
// spawn_cancellable round-trips: one task completes before its token is
// touched, a second is cancelled via its own handle token. Metrics must show
// exactly one success and one cancellation.
async fn test_task_handle_with_cancellable_tasks() {
let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
// Short sleep (100ms) wins the race because nobody cancels the token.
let handle = tracker.spawn_cancellable(|cancel_token| async move {
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
CancellableTaskResult::Ok("completed".to_string())
},
_ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
}
});
let task_cancel_token = handle.cancellation_token();
assert!(
!task_cancel_token.is_cancelled(),
"Task token should not be cancelled initially"
);
let result = handle.await.expect("Task should complete");
assert!(result.is_ok(), "Task should succeed");
assert_eq!(result.unwrap(), "completed");
// Long sleep (10s) loses: we cancel the token right after spawning.
let handle2 = tracker.spawn_cancellable(|cancel_token| async move {
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
CancellableTaskResult::Ok("should_not_complete".to_string())
},
_ = cancel_token.cancelled() => CancellableTaskResult::Cancelled,
}
});
handle2.cancellation_token().cancel();
let result2 = handle2.await.expect("Task should complete");
assert!(result2.is_err(), "Task should be cancelled");
assert!(
result2.unwrap_err().is_cancellation(),
"Should be a cancellation error"
);
assert_eq!(
tracker.metrics().success(),
1,
"Should have 1 successful task"
);
assert_eq!(
tracker.metrics().cancelled(),
1,
"Should have 1 cancelled task"
);
}
#[tokio::test]
async fn test_continuation_helpers() {
let tracker = TaskTracker::new(UnlimitedScheduler::new(), LogOnlyPolicy::new()).unwrap();
let handle1 = tracker.spawn(async {
let error =
FailedWithContinuation::from_fn(anyhow::anyhow!("Initial failure"), || async {
Ok("Success from from_fn".to_string())
});
let result: Result<String, anyhow::Error> = Err(error);
result
});
let result1 = handle1.await.expect("Task should complete");
assert!(
result1.is_ok(),
"Task with from_fn continuation should succeed"
);
assert_eq!(result1.unwrap(), "Success from from_fn");
let handle2 = tracker.spawn(async {
let error = FailedWithContinuation::from_cancellable(
anyhow::anyhow!("Initial failure"),
|_cancel_token| async move { Ok("Success from from_cancellable".to_string()) },
);
let result: Result<String, anyhow::Error> = Err(error);
result
});
let result2 = handle2.await.expect("Task should complete");
assert!(
result2.is_ok(),
"Task with from_cancellable continuation should succeed"
);
assert_eq!(result2.unwrap(), "Success from from_cancellable");
assert_eq!(
tracker.metrics().success(),
2,
"Should have 2 successful tasks"
);
assert_eq!(tracker.metrics().failed(), 0, "Should have 0 failed tasks");
}
// Verifies that `OnErrorPolicy::should_reschedule` controls whether the task
// execution loop goes back through the scheduler before running a
// continuation: `false` is expected to reuse the already-held execution slot
// (1 acquisition total), `true` to re-acquire a slot (2 acquisitions).
#[rstest]
#[case(false, 1, "Policy requests no rescheduling - should reuse guard")]
#[case(true, 2, "Policy requests rescheduling - should re-acquire guard")]
#[tokio::test]
async fn test_should_reschedule_policy_control(
#[case] should_reschedule: bool,
#[case] expected_acquisitions: u32,
#[case] description: &str,
) {
// Scheduler stub: always grants a slot and counts how many times one was
// requested — the count is the observable under test.
#[derive(Debug)]
struct MockScheduler {
acquisition_count: Arc<AtomicU32>,
}
impl MockScheduler {
fn new() -> Self {
Self {
acquisition_count: Arc::new(AtomicU32::new(0)),
}
}
// Number of `acquire_execution_slot` calls observed so far.
fn acquisition_count(&self) -> u32 {
self.acquisition_count.load(Ordering::Relaxed)
}
}
#[async_trait]
impl TaskScheduler for MockScheduler {
async fn acquire_execution_slot(
&self,
_cancel_token: CancellationToken,
) -> SchedulingResult<Box<dyn ResourceGuard>> {
// Count the request, then grant it unconditionally.
self.acquisition_count.fetch_add(1, Ordering::Relaxed);
SchedulingResult::Execute(Box::new(UnlimitedGuard))
}
}
// Policy whose only parameterized behavior is the `should_reschedule`
// answer; it always allows continuations and otherwise fails the task.
#[derive(Debug)]
struct RescheduleTestPolicy {
should_reschedule: bool,
}
impl OnErrorPolicy for RescheduleTestPolicy {
fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
Arc::new(RescheduleTestPolicy {
should_reschedule: self.should_reschedule,
})
}
fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
None }
fn allow_continuation(
&self,
_error: &anyhow::Error,
_context: &OnErrorContext,
) -> bool {
true }
fn should_reschedule(&self, _error: &anyhow::Error, _context: &OnErrorContext) -> bool {
self.should_reschedule
}
fn on_error(
&self,
_error: &anyhow::Error,
_context: &mut OnErrorContext,
) -> ErrorResponse {
ErrorResponse::Fail }
}
let mock_scheduler = Arc::new(MockScheduler::new());
let policy = Arc::new(RescheduleTestPolicy { should_reschedule });
let tracker = TaskTracker::new(mock_scheduler.clone(), policy).unwrap();
// Shared log recording the order of execution steps for the final asserts.
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
// Continuation that records that it ran and then succeeds.
#[derive(Debug)]
struct SimpleRetryContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
}
#[async_trait]
impl Continuation for SimpleRetryContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
self.log
.lock()
.await
.push("continuation_executed".to_string());
TaskExecutionResult::Success(Box::new("continuation_success".to_string()))
}
}
let continuation = Arc::new(SimpleRetryContinuation {
log: execution_log.clone(),
});
let log_for_task = execution_log.clone();
// The task fails exactly once, attaching the continuation that rescues it.
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("initial_task_failure".to_string());
let error = anyhow::anyhow!("Initial task failure");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
assert!(result.is_ok(), "{}: Task should succeed", description);
assert_eq!(
tracker.metrics().success(),
1,
"{}: Should have 1 success",
description
);
// Exactly two steps ran, in order: failing task body, then continuation.
let log = execution_log.lock().await;
assert_eq!(
log.len(),
2,
"{}: Should have initial task + continuation",
description
);
assert_eq!(log[0], "initial_task_failure");
assert_eq!(log[1], "continuation_executed");
// The acquisition count is the only difference between the two cases.
let actual_acquisitions = mock_scheduler.acquisition_count();
assert_eq!(
actual_acquisitions, expected_acquisitions,
"{}: Expected {} scheduler acquisitions, got {}",
description, expected_acquisitions, actual_acquisitions
);
}
// Exercises the continuation retry loop driven by a custom `OnErrorAction`:
// each failure triggers the action, which schedules a continuation that keeps
// failing until `max_retries` is reached, then succeeds.
// NOTE(review): both #[case]s pass `should_succeed = true`, so the `else`
// branch at the bottom (shutdown/cancellation path) is currently dead code —
// either add a failing case or remove the branch; confirm expected behavior
// when retries are exhausted before adding such a case.
#[rstest]
#[case(1, true, "Custom action with 1 retry should succeed")]
#[case(3, true, "Custom action with 3 retries should succeed")]
#[tokio::test]
async fn test_continuation_loop_with_custom_action_policy(
unlimited_scheduler: Arc<UnlimitedScheduler>,
#[case] max_retries: u32,
#[case] should_succeed: bool,
#[case] description: &str,
) {
// Shared step log and retry counter observed by the final asserts.
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
let retry_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
// Custom error action: logs each invocation and, while under the retry
// budget, hands back a continuation; past the budget it requests Shutdown.
#[derive(Debug)]
struct RetryAction {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
retry_count: Arc<std::sync::atomic::AtomicU32>,
max_retries: u32,
}
#[async_trait]
impl OnErrorAction for RetryAction {
async fn execute(
&self,
_error: &anyhow::Error,
_task_id: TaskId,
_attempt_count: u32,
_context: &TaskExecutionContext,
) -> ActionResult {
// 1-based retry number for this invocation.
let current_retry = self
.retry_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
self.log
.lock()
.await
.push(format!("custom_action_retry_{}", current_retry));
if current_retry <= self.max_retries {
// Continuation for this retry: fails until the final retry,
// then succeeds — so the loop runs exactly `max_retries` times.
#[derive(Debug)]
struct RetryContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
retry_number: u32,
max_retries: u32,
}
#[async_trait]
impl Continuation for RetryContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>>
{
self.log
.lock()
.await
.push(format!("retry_continuation_{}", self.retry_number));
if self.retry_number >= self.max_retries {
TaskExecutionResult::Success(Box::new(format!(
"success_after_{}_retries",
self.retry_number
)))
} else {
TaskExecutionResult::Error(anyhow::anyhow!(
"Retry {} still failing",
self.retry_number
))
}
}
}
let continuation = Arc::new(RetryContinuation {
log: self.log.clone(),
retry_number: current_retry,
max_retries: self.max_retries,
});
ActionResult::Continue { continuation }
} else {
// Budget exhausted — ask the tracker to shut down.
ActionResult::Shutdown
}
}
}
// Policy that routes every error to a fresh clone of the RetryAction.
#[derive(Debug)]
struct CustomRetryPolicy {
action: Arc<RetryAction>,
}
impl OnErrorPolicy for CustomRetryPolicy {
fn create_child(&self) -> Arc<dyn OnErrorPolicy> {
Arc::new(CustomRetryPolicy {
action: self.action.clone(),
})
}
fn create_context(&self) -> Option<Box<dyn std::any::Any + Send + 'static>> {
None }
fn on_error(
&self,
_error: &anyhow::Error,
_context: &mut OnErrorContext,
) -> ErrorResponse {
// Shares the log/counter/budget of the outer action so all
// invocations are observed by the test.
ErrorResponse::Custom(Box::new(RetryAction {
log: self.action.log.clone(),
retry_count: self.action.retry_count.clone(),
max_retries: self.action.max_retries,
}))
}
}
let action = Arc::new(RetryAction {
log: execution_log.clone(),
retry_count: retry_count.clone(),
max_retries,
});
let policy = Arc::new(CustomRetryPolicy { action });
let tracker = TaskTracker::new(unlimited_scheduler, policy).unwrap();
let log_for_task = execution_log.clone();
// The original task logs itself and fails with a plain error (no attached
// continuation) — retries come solely from the policy's custom action.
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("original_task_failed".to_string());
let result: Result<String, anyhow::Error> =
Err(anyhow::anyhow!("Original task failure"));
result
});
let result = handle.await.expect("Task should complete");
if should_succeed {
assert!(result.is_ok(), "{}: Task should succeed", description);
assert_eq!(
tracker.metrics().success(),
1,
"{}: Should have 1 success",
description
);
let log = execution_log.lock().await;
// 1 entry for the original task + 2 per retry (action + continuation).
let expected_entries = 1 + (max_retries * 2); assert_eq!(
log.len(),
expected_entries as usize,
"{}: Should have {} log entries",
description,
expected_entries
);
assert_eq!(
retry_count.load(std::sync::atomic::Ordering::Relaxed),
max_retries,
"{}: Should have made {} retry attempts",
description,
max_retries
);
} else {
// NOTE(review): unreached with the current cases (see note above).
assert!(result.is_err(), "{}: Task should fail", description);
assert!(
tracker.cancellation_token().is_cancelled(),
"{}: Should be cancelled",
description
);
let final_retry_count = retry_count.load(std::sync::atomic::Ordering::Relaxed);
assert!(
final_retry_count > max_retries,
"{}: Should have exceeded max_retries ({}), got {}",
description,
max_retries,
final_retry_count
);
}
}
#[rstest]
#[tokio::test]
async fn test_mixed_continuation_sources(
unlimited_scheduler: Arc<UnlimitedScheduler>,
log_policy: Arc<LogOnlyPolicy>,
) {
let execution_log = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));
let tracker = TaskTracker::new(unlimited_scheduler, log_policy).unwrap();
let log_for_task = execution_log.clone();
let log_for_continuation = execution_log.clone();
#[derive(Debug)]
struct MixedContinuation {
log: Arc<tokio::sync::Mutex<Vec<String>>>,
}
#[async_trait]
impl Continuation for MixedContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
self.log
.lock()
.await
.push("task_continuation_executed".to_string());
TaskExecutionResult::Error(anyhow::anyhow!("Task continuation failed"))
}
}
let continuation = Arc::new(MixedContinuation {
log: log_for_continuation,
});
let handle = tracker.spawn(async move {
log_for_task
.lock()
.await
.push("original_task_executed".to_string());
let error = anyhow::anyhow!("Original task failed");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
assert!(
result.is_err(),
"Should fail because continuation fails and policy just logs"
);
let log = execution_log.lock().await;
assert_eq!(log.len(), 2);
assert_eq!(log[0], "original_task_executed");
assert_eq!(log[1], "task_continuation_executed");
assert_eq!(tracker.metrics().success(), 0);
assert_eq!(tracker.metrics().failed(), 1);
}
// Diagnostic probe for ThresholdCancelPolicy interacting with the retry loop:
// a task fails with a continuation that itself always fails, and the test
// prints the resulting failure counts, cancellation state, and metrics.
// NOTE(review): this test asserts nothing — it passes as long as nothing
// panics. Once the expected counts for threshold=2 are confirmed, promote the
// println!s below into assert!/assert_eq! so regressions are caught.
#[rstest]
#[tokio::test]
async fn debug_threshold_policy_in_retry_loop(unlimited_scheduler: Arc<UnlimitedScheduler>) {
// Policy that (presumably) cancels the tracker after 2 failures — TODO
// confirm against ThresholdCancelPolicy's definition.
let policy = ThresholdCancelPolicy::with_threshold(2);
let tracker = TaskTracker::new(unlimited_scheduler, policy.clone()).unwrap();
// Continuation that fails on every attempt, bumping a shared counter so
// the total number of attempts can be printed afterwards.
#[derive(Debug)]
struct AlwaysFailContinuation {
attempt: Arc<std::sync::atomic::AtomicU32>,
}
#[async_trait]
impl Continuation for AlwaysFailContinuation {
async fn execute(
&self,
_cancel_token: CancellationToken,
) -> TaskExecutionResult<Box<dyn std::any::Any + Send + 'static>> {
// 1-based attempt number for this invocation.
let attempt_num = self
.attempt
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
println!("Continuation attempt {}", attempt_num);
TaskExecutionResult::Error(anyhow::anyhow!(
"Continuation attempt {} failed",
attempt_num
))
}
}
let attempt_counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
let continuation = Arc::new(AlwaysFailContinuation {
attempt: attempt_counter.clone(),
});
// Original task fails immediately, attaching the always-failing
// continuation to enter the retry loop.
let handle = tracker.spawn(async move {
println!("Original task executing");
let error = anyhow::anyhow!("Original task failed");
let result: Result<String, anyhow::Error> =
Err(FailedWithContinuation::into_anyhow(error, continuation));
result
});
let result = handle.await.expect("Task should complete");
// Diagnostics only — see NOTE(review) above about missing assertions.
println!("Final result: {:?}", result.is_ok());
println!("Policy failure count: {}", policy.failure_count());
println!(
"Continuation attempts: {}",
attempt_counter.load(std::sync::atomic::Ordering::Relaxed)
);
println!(
"Tracker cancelled: {}",
tracker.cancellation_token().is_cancelled()
);
println!(
"Metrics: success={}, failed={}",
tracker.metrics().success(),
tracker.metrics().failed()
);
}
}