use crate::circuit::CircuitBreaker;
use crate::error::CanoError;
use crate::resource::Resources;
use cano_macros::task;
use rand::RngExt;
use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "tracing")]
use tracing::{debug, error, info, info_span, instrument, warn};
/// Policy describing how many times a failed task attempt is retried and how
/// long to wait between attempts.
#[derive(Debug, Clone)]
pub enum RetryMode {
/// Never retry: a single attempt is made and no delay is ever produced.
None,
/// Retry up to `retries` times, waiting the same `delay` before each retry.
Fixed { retries: usize, delay: Duration },
/// Retry with a delay that grows geometrically per attempt.
ExponentialBackoff {
/// Maximum number of retries after the initial attempt.
max_retries: usize,
/// Delay before the first retry; read with millisecond resolution.
base_delay: Duration,
/// Factor the delay is multiplied by for each subsequent retry.
multiplier: f64,
/// Upper bound applied to the computed delay *before* jitter is added.
max_delay: Duration,
/// Fraction of the capped delay by which the final delay may randomly vary
/// in either direction (`0.0` disables jitter).
jitter: f64,
},
}
impl RetryMode {
    /// Creates a policy that retries `retries` times with a constant `delay`
    /// between attempts.
    pub fn fixed(retries: usize, delay: Duration) -> Self {
        Self::Fixed { retries, delay }
    }

    /// Creates an exponential-backoff policy with the stock parameters:
    /// 100 ms base delay, multiplier 2.0, 30 s cap and 10 % jitter.
    pub fn exponential(max_retries: usize) -> Self {
        Self::ExponentialBackoff {
            max_retries,
            base_delay: Duration::from_millis(100),
            multiplier: 2.0,
            max_delay: Duration::from_secs(30),
            jitter: 0.1,
        }
    }

    /// Creates an exponential-backoff policy with caller-supplied parameters.
    ///
    /// `jitter` is clamped into `[0.0, 1.0]`.
    pub fn exponential_custom(
        max_retries: usize,
        base_delay: Duration,
        multiplier: f64,
        max_delay: Duration,
        jitter: f64,
    ) -> Self {
        Self::ExponentialBackoff {
            max_retries,
            base_delay,
            multiplier,
            max_delay,
            jitter: jitter.clamp(0.0, 1.0),
        }
    }

    /// Returns the total number of attempts (the initial one plus all retries).
    ///
    /// The addition saturates, so a retry count of `usize::MAX` means
    /// "effectively unlimited" instead of overflowing.
    pub fn max_attempts(&self) -> usize {
        let retries = match self {
            Self::None => 0,
            Self::Fixed { retries, .. } => *retries,
            Self::ExponentialBackoff { max_retries, .. } => *max_retries,
        };
        retries.saturating_add(1)
    }

    /// Returns the delay to wait after the zero-based failed `attempt`, or
    /// `None` when no further retry is allowed.
    ///
    /// For exponential backoff the delay is
    /// `base_delay * multiplier^attempt`, capped at `max_delay`, then scaled
    /// by a random factor in `[1 - jitter, 1 + jitter]`. The result is
    /// truncated to whole milliseconds and never negative.
    pub fn delay_for_attempt(&self, attempt: usize) -> Option<Duration> {
        match self {
            Self::None => None,
            Self::Fixed { retries, delay } => (attempt < *retries).then_some(*delay),
            Self::ExponentialBackoff {
                max_retries,
                base_delay,
                multiplier,
                max_delay,
                jitter,
            } => {
                if attempt >= *max_retries {
                    return None;
                }

                // Saturate instead of wrapping: a wrapped (negative) exponent
                // would shrink the delay for very large attempt numbers.
                let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
                let base_ms = base_delay.as_millis() as f64;
                let capped_ms =
                    (base_ms * multiplier.powi(exponent)).min(max_delay.as_millis() as f64);

                // The variant's fields are public, so re-apply the same clamp
                // `exponential_custom` uses in case it was built directly.
                let jitter = (*jitter).clamp(0.0, 1.0);
                let jitter_factor = if jitter > 0.0 {
                    let mut rng = rand::rng();
                    let random_factor: f64 = rng.random_range(-1.0..=1.0);
                    1.0 + (jitter * random_factor)
                } else {
                    1.0
                };

                // `max(0.0)` also maps NaN to zero; the float-to-int `as` cast
                // saturates at `u64::MAX`, so no explicit overflow branch is needed.
                let final_ms = (capped_ms * jitter_factor).max(0.0);
                Some(Duration::from_millis(final_ms as u64))
            }
        }
    }
}
impl Default for RetryMode {
fn default() -> Self {
Self::ExponentialBackoff {
max_retries: 3,
base_delay: Duration::from_millis(100),
multiplier: 2.0,
max_delay: Duration::from_secs(30),
jitter: 0.1,
}
}
}
/// Execution settings for a task: retry policy, optional per-attempt timeout
/// and optional circuit breaker.
///
/// The derived `Default` uses `RetryMode::default()` with no timeout and no
/// circuit breaker.
#[must_use]
#[derive(Clone, Default)]
pub struct TaskConfig {
/// Retry policy that determines the attempt count and the delay between attempts.
pub retry_mode: RetryMode,
/// When set, each individual attempt is bounded by this duration.
pub attempt_timeout: Option<Duration>,
/// When set, a permit is acquired from this breaker before every attempt and
/// the attempt's outcome is recorded on it. The `Arc` allows one breaker to be
/// shared by several configs.
pub circuit_breaker: Option<Arc<CircuitBreaker>>,
}
impl TaskConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a configuration with retries disabled, no per-attempt timeout
    /// and no circuit breaker.
    pub fn minimal() -> Self {
        Self {
            retry_mode: RetryMode::None,
            ..Self::default()
        }
    }

    /// Replaces the retry policy.
    pub fn with_retry(self, retry_mode: RetryMode) -> Self {
        Self { retry_mode, ..self }
    }

    /// Switches to a fixed-delay retry policy.
    pub fn with_fixed_retry(self, retries: usize, delay: Duration) -> Self {
        let policy = RetryMode::fixed(retries, delay);
        self.with_retry(policy)
    }

    /// Switches to exponential backoff with the stock parameters.
    pub fn with_exponential_retry(self, max_retries: usize) -> Self {
        let policy = RetryMode::exponential(max_retries);
        self.with_retry(policy)
    }

    /// Bounds every individual attempt by `timeout`.
    pub fn with_attempt_timeout(self, timeout: Duration) -> Self {
        Self {
            attempt_timeout: Some(timeout),
            ..self
        }
    }

    /// Attaches a (possibly shared) circuit breaker.
    pub fn with_circuit_breaker(self, breaker: Arc<CircuitBreaker>) -> Self {
        Self {
            circuit_breaker: Some(breaker),
            ..self
        }
    }
}
/// Runs `run_fn` until it succeeds or the retry budget of `config` is spent.
///
/// Each loop iteration:
/// 1. acquires a permit from the circuit breaker, if one is configured;
/// 2. invokes `run_fn`, bounded by `config.attempt_timeout` when set;
/// 3. records the outcome on the circuit breaker;
/// 4. on failure, sleeps for the delay dictated by `config.retry_mode` and
///    tries again while attempts remain.
///
/// # Errors
///
/// * The error from the breaker's `try_acquire` is returned unchanged and
///   immediately, without invoking `run_fn` or sleeping.
/// * An attempt that exceeds `attempt_timeout` is treated as a failed attempt
///   carrying a `CanoError::timeout` error.
/// * When only one attempt is allowed, the attempt's own error is returned as-is.
/// * Otherwise, after the last failed attempt, a `CanoError::retry_exhausted`
///   error wrapping the final error's message is returned.
#[cfg_attr(feature = "tracing", instrument(
    skip(config, run_fn),
    fields(max_attempts = config.retry_mode.max_attempts())
))]
pub async fn run_with_retries<TState, F, Fut>(
    config: &TaskConfig,
    run_fn: F,
) -> Result<TState, CanoError>
where
    TState: Send + Sync,
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<TState, CanoError>>,
{
    let max_attempts = config.retry_mode.max_attempts();
    let mut attempt = 0;
    let breaker = config.circuit_breaker.as_ref();

    #[cfg(feature = "tracing")]
    info!(max_attempts, "Starting task execution with retry logic");

    loop {
        // The per-attempt span is never held entered across an `.await`:
        // doing so yields incorrect traces and makes this future `!Send`.
        // Synchronous events go through `in_scope`; the attempt future itself
        // is instrumented so the span is entered only while it is polled.
        #[cfg(feature = "tracing")]
        let attempt_span = info_span!("task_attempt", attempt = attempt + 1, max_attempts);
        #[cfg(feature = "tracing")]
        attempt_span.in_scope(|| debug!(attempt = attempt + 1, "Executing task attempt"));

        // A refused permit short-circuits the whole call; it does not consume
        // a retry attempt and no backoff sleep happens.
        let permit = match breaker {
            Some(b) => match b.try_acquire() {
                Ok(p) => Some(p),
                Err(e) => {
                    #[cfg(feature = "tracing")]
                    attempt_span.in_scope(
                        || warn!(error = %e, "Circuit breaker open; short-circuiting task"),
                    );
                    return Err(e);
                }
            },
            None => None,
        };

        #[cfg(feature = "tracing")]
        let attempt_fut =
            tracing::Instrument::instrument(attempt_span.in_scope(&run_fn), attempt_span.clone());
        #[cfg(not(feature = "tracing"))]
        let attempt_fut = run_fn();

        let attempt_outcome = match config.attempt_timeout {
            Some(d) => match tokio::time::timeout(d, attempt_fut).await {
                Ok(inner) => inner,
                Err(_) => Err(CanoError::timeout("task attempt exceeded attempt_timeout")),
            },
            None => attempt_fut.await,
        };

        // Timeouts count as failures for the breaker, like any other error.
        if let (Some(b), Some(p)) = (breaker, permit) {
            match &attempt_outcome {
                Ok(_) => b.record_success(p),
                Err(_) => b.record_failure(p),
            }
        }

        match attempt_outcome {
            Ok(result) => {
                #[cfg(feature = "tracing")]
                attempt_span.in_scope(|| info!(attempt = attempt + 1, "Task execution successful"));
                return Ok(result);
            }
            Err(e) => {
                attempt += 1;

                #[cfg(feature = "tracing")]
                attempt_span.in_scope(|| {
                    if attempt >= max_attempts {
                        error!(
                            error = %e,
                            final_attempt = attempt,
                            max_attempts,
                            "Task execution failed after all retry attempts"
                        );
                    } else {
                        warn!(
                            error = %e,
                            attempt,
                            max_attempts,
                            "Task execution failed, will retry"
                        );
                    }
                });

                if attempt >= max_attempts {
                    // With retries disabled, surface the original error
                    // variant rather than wrapping it.
                    if max_attempts <= 1 {
                        return Err(e);
                    }
                    return Err(CanoError::retry_exhausted(format!(
                        "Task failed after {} attempt(s): {}",
                        attempt, e
                    )));
                } else if let Some(delay) = config.retry_mode.delay_for_attempt(attempt - 1) {
                    #[cfg(feature = "tracing")]
                    attempt_span
                        .in_scope(|| debug!(delay_ms = delay.as_millis(), "Waiting before retry"));
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
}
/// Outcome of a successful task run, expressed as the state(s) to move to next.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskResult<TState> {
/// The task produced exactly one next state.
Single(TState),
/// The task produced several next states.
/// NOTE(review): presumably consumed by the workflow engine to fan out into
/// multiple branches — confirm against the caller.
Split(Vec<TState>),
}
/// A unit of work that yields the next workflow state(s).
///
/// Implementors override either [`Task::run`] (when resources are needed) or
/// [`Task::run_bare`] (when they are not). If neither is overridden, running
/// the task fails with a configuration error naming the implementing type.
#[task]
pub trait Task<TState, TResourceKey = Cow<'static, str>>: Send + Sync
where
    TState: Clone + fmt::Debug + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Execution settings for this task; the default is `TaskConfig::default()`.
    fn config(&self) -> TaskConfig {
        Default::default()
    }

    /// Runs the task with access to `res`.
    ///
    /// The default implementation ignores the resources and delegates to
    /// [`Task::run_bare`].
    async fn run(&self, res: &Resources<TResourceKey>) -> Result<TaskResult<TState>, CanoError> {
        // Intentionally unused by the default implementation.
        let _ = res;
        self.run_bare().await
    }

    /// Runs the task without resources.
    ///
    /// The default implementation always returns a configuration error, which
    /// signals that the implementor overrode neither entry point.
    async fn run_bare(&self) -> Result<TaskResult<TState>, CanoError> {
        let implementor = std::any::type_name::<Self>();
        let message = format!(
            "Task<{}>: neither `run` nor `run_bare` was implemented; override one of them",
            implementor,
        );
        Err(CanoError::configuration(message))
    }
}
#[task]
impl<TState, TResourceKey, N> Task<TState, TResourceKey> for N
where
N: crate::node::Node<TState, TResourceKey>,
TState: Clone + fmt::Debug + Send + Sync + 'static,
TResourceKey: Hash + Eq + Send + Sync + 'static,
{
fn config(&self) -> TaskConfig {
crate::node::Node::config(self)
}
#[cfg_attr(
feature = "tracing",
instrument(skip(self, res), fields(task_type = "node_adapter"))
)]
async fn run(&self, res: &Resources<TResourceKey>) -> Result<TaskResult<TState>, CanoError> {
#[cfg(feature = "tracing")]
debug!("Executing task through Node adapter");
let prep_result = crate::node::Node::prep(self, res).await?;
let exec_result = crate::node::Node::exec(self, prep_result).await;
let next_state = crate::node::Node::post(self, res, exec_result).await?;
#[cfg(feature = "tracing")]
info!(next_state = ?next_state, "Task execution completed successfully");
Ok(TaskResult::Single(next_state))
}
}
/// Trait-object form of [`Task`] that is `Send + Sync`.
pub type DynTask<TState, TResourceKey = Cow<'static, str>> =
dyn Task<TState, TResourceKey> + Send + Sync;
/// Reference-counted handle to a dynamically dispatched task.
pub type TaskObject<TState, TResourceKey = Cow<'static, str>> =
std::sync::Arc<DynTask<TState, TResourceKey>>;
// Unit tests covering the `Task` trait defaults, the Node-to-Task adapter,
// `RetryMode` delay computation, `TaskConfig` builders, and the behavior of
// `run_with_retries` (retries, per-attempt timeout, circuit breaker).
#[cfg(test)]
mod tests {
use super::*;
use crate::resource::Resources;
use cano_macros::{node, task};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
// NOTE(review): a bare `use tokio;` is likely redundant on the 2018+ edition — confirm and drop.
use tokio;
// State enum used as `TState` by the test tasks; not every variant is used.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(dead_code)]
enum TestAction {
Continue,
Complete,
Error,
Retry,
}
// Task that always succeeds and counts how many times it was run.
struct SimpleTask {
execution_count: Arc<AtomicU32>,
}
impl SimpleTask {
fn new() -> Self {
Self {
execution_count: Arc::new(AtomicU32::new(0)),
}
}
fn execution_count(&self) -> u32 {
self.execution_count.load(Ordering::SeqCst)
}
}
#[task]
impl Task<TestAction> for SimpleTask {
async fn run_bare(&self) -> Result<TaskResult<TestAction>, CanoError> {
self.execution_count.fetch_add(1, Ordering::SeqCst);
Ok(TaskResult::Single(TestAction::Complete))
}
}
// Task whose outcome is fixed at construction: fails when `should_fail` is true.
struct FailingTask {
should_fail: bool,
}
impl FailingTask {
fn new(should_fail: bool) -> Self {
Self { should_fail }
}
}
#[task]
impl Task<TestAction> for FailingTask {
async fn run_bare(&self) -> Result<TaskResult<TestAction>, CanoError> {
if self.should_fail {
Err(CanoError::task_execution("Task intentionally failed"))
} else {
Ok(TaskResult::Single(TestAction::Complete))
}
}
}
// Task that returns a `Split` result with two states.
struct SplitTask;
#[task]
impl Task<TestAction> for SplitTask {
async fn run_bare(&self) -> Result<TaskResult<TestAction>, CanoError> {
Ok(TaskResult::Split(vec![
TestAction::Continue,
TestAction::Complete,
]))
}
}
// A single run returns `Complete` and bumps the counter once.
#[tokio::test]
async fn test_simple_task_execution() {
let task = SimpleTask::new();
let result = task.run_bare().await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
assert_eq!(task.execution_count(), 1);
}
// Covers both the success and the failure configuration of `FailingTask`.
#[tokio::test]
async fn test_failing_task() {
let success_task = FailingTask::new(false);
let result = success_task.run_bare().await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
let fail_task = FailingTask::new(true);
let result = fail_task.run_bare().await;
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.to_string().contains("Task intentionally failed"));
}
// A `Split` result carries both states in order.
#[tokio::test]
async fn test_split_task() {
let task = SplitTask;
let result = task.run_bare().await.unwrap();
assert_eq!(
result,
TaskResult::Split(vec![TestAction::Continue, TestAction::Complete])
);
}
// A task overriding neither `run` nor `run_bare` yields a configuration error
// whose message names the implementing type.
#[tokio::test]
async fn test_unimplemented_run_returns_configuration_error() {
struct ForgotToImplement;
#[task]
impl Task<TestAction> for ForgotToImplement {}
let task = ForgotToImplement;
let res = Resources::new();
let err = task.run(&res).await.unwrap_err();
assert_eq!(err.category(), "configuration");
assert!(
err.message().contains("ForgotToImplement"),
"error should name the offending type, got: {}",
err.message()
);
}
// Ten spawned runs of one shared task all succeed and are all counted.
#[tokio::test]
async fn test_concurrent_task_execution() {
use tokio::task;
let task = Arc::new(SimpleTask::new());
let mut handles = vec![];
for _ in 0..10 {
let task_clone = Arc::clone(&task);
let handle = task::spawn(async move { task_clone.run_bare().await });
handles.push(handle);
}
let mut success_count = 0;
for handle in handles {
let result = handle.await.unwrap();
if let Ok(TaskResult::Single(TestAction::Complete)) = result {
success_count += 1;
}
}
assert_eq!(success_count, 10);
assert_eq!(task.execution_count(), 10);
}
// Sequential runs increment the counter by one each time.
#[tokio::test]
async fn test_multiple_task_executions() {
let task = SimpleTask::new();
for i in 1..=5 {
let result = task.run_bare().await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
assert_eq!(task.execution_count(), i);
}
}
use crate::node::Node;
// Node whose `post` returns `Complete` only when `exec` saw the value produced by `prep`.
struct TestNode;
#[node]
impl Node<TestAction> for TestNode {
type PrepResult = String;
type ExecResult = bool;
async fn prep(&self, _res: &Resources) -> Result<Self::PrepResult, CanoError> {
Ok("node_prepared".to_string())
}
async fn exec(&self, prep_res: Self::PrepResult) -> Self::ExecResult {
prep_res == "node_prepared"
}
async fn post(
&self,
_res: &Resources,
exec_res: Self::ExecResult,
) -> Result<TestAction, CanoError> {
if exec_res {
Ok(TestAction::Complete)
} else {
Ok(TestAction::Error)
}
}
}
// The blanket adapter runs prep/exec/post and wraps the state in `Single`.
#[tokio::test]
async fn test_node_as_task_compatibility() {
let node = TestNode;
let res = Resources::new();
let result = Task::run(&node, &res).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), TaskResult::Single(TestAction::Complete));
}
// `RetryMode::None`: one attempt, never a delay.
#[test]
fn test_retry_mode_none() {
let retry_mode = RetryMode::None;
assert_eq!(retry_mode.max_attempts(), 1);
assert_eq!(retry_mode.delay_for_attempt(0), None);
assert_eq!(retry_mode.delay_for_attempt(1), None);
}
// Fixed mode: constant delay for attempts 0..retries, `None` afterwards.
#[test]
fn test_retry_mode_fixed() {
let retry_mode = RetryMode::fixed(3, Duration::from_millis(100));
assert_eq!(retry_mode.max_attempts(), 4);
assert_eq!(
retry_mode.delay_for_attempt(0),
Some(Duration::from_millis(100))
);
assert_eq!(
retry_mode.delay_for_attempt(1),
Some(Duration::from_millis(100))
);
assert_eq!(
retry_mode.delay_for_attempt(2),
Some(Duration::from_millis(100))
);
assert_eq!(retry_mode.delay_for_attempt(3), None); assert_eq!(retry_mode.delay_for_attempt(4), None);
}
// Stock exponential mode: delays exist for attempts 0..3 and stop afterwards.
// NOTE(review): the `>= previous / 2` checks are loose; with 10 % jitter the
// delays could be asserted to strictly grow.
#[test]
fn test_retry_mode_exponential_basic() {
let retry_mode = RetryMode::exponential(3);
assert_eq!(retry_mode.max_attempts(), 4);
let delay0 = retry_mode.delay_for_attempt(0).unwrap();
let delay1 = retry_mode.delay_for_attempt(1).unwrap();
let delay2 = retry_mode.delay_for_attempt(2).unwrap();
assert!(delay1.as_millis() >= delay0.as_millis() / 2); assert!(delay2.as_millis() >= delay1.as_millis() / 2);
assert_eq!(retry_mode.delay_for_attempt(3), None);
assert_eq!(retry_mode.delay_for_attempt(4), None);
}
// Jitter disabled, multiplier 3: delays are exactly 50 ms then 150 ms.
#[test]
fn test_retry_mode_exponential_custom() {
let retry_mode = RetryMode::exponential_custom(
2, Duration::from_millis(50), 3.0, Duration::from_secs(5), 0.0, );
assert_eq!(retry_mode.max_attempts(), 3);
assert_eq!(
retry_mode.delay_for_attempt(0),
Some(Duration::from_millis(50))
);
assert_eq!(
retry_mode.delay_for_attempt(1),
Some(Duration::from_millis(150))
);
assert_eq!(retry_mode.delay_for_attempt(2), None);
}
// Jitter disabled, multiplier 10, cap 500 ms: growth stops at the cap.
#[test]
fn test_retry_mode_exponential_max_delay_cap() {
let retry_mode = RetryMode::exponential_custom(
5,
Duration::from_millis(100), 10.0, Duration::from_millis(500), 0.0, );
let delay0 = retry_mode.delay_for_attempt(0).unwrap();
let delay1 = retry_mode.delay_for_attempt(1).unwrap();
let delay2 = retry_mode.delay_for_attempt(2).unwrap();
assert_eq!(delay0, Duration::from_millis(100)); assert_eq!(delay1, Duration::from_millis(500)); assert_eq!(delay2, Duration::from_millis(500)); }
// With 50 % jitter on a 100 ms base, sampled delays stay within 50..=150 ms.
#[test]
fn test_retry_mode_exponential_jitter_bounds() {
let retry_mode = RetryMode::exponential_custom(
3,
Duration::from_millis(100),
2.0,
Duration::from_secs(30),
0.5, );
let mut delays = Vec::new();
for _ in 0..20 {
if let Some(delay) = retry_mode.delay_for_attempt(0) {
delays.push(delay.as_millis());
}
}
let min_delay = delays.iter().min().unwrap();
let max_delay = delays.iter().max().unwrap();
assert!(*min_delay >= 50); assert!(*max_delay <= 150); }
// Out-of-range jitter values are accepted by the constructor.
// NOTE(review): this only checks that a delay is produced; it does not assert
// that the stored jitter was actually clamped to 0.0 / 1.0.
#[test]
fn test_retry_mode_jitter_clamping() {
let retry_mode1 = RetryMode::exponential_custom(
1,
Duration::from_millis(100),
2.0,
Duration::from_secs(30),
-0.5, );
let retry_mode2 = RetryMode::exponential_custom(
1,
Duration::from_millis(100),
2.0,
Duration::from_secs(30),
1.5, );
assert!(retry_mode1.delay_for_attempt(0).is_some());
assert!(retry_mode2.delay_for_attempt(0).is_some());
}
// Default mode allows three retries (four attempts).
#[test]
fn test_retry_mode_default() {
let retry_mode = RetryMode::default();
assert_eq!(retry_mode.max_attempts(), 4);
assert!(retry_mode.delay_for_attempt(0).is_some());
assert!(retry_mode.delay_for_attempt(1).is_some());
assert!(retry_mode.delay_for_attempt(2).is_some());
assert!(retry_mode.delay_for_attempt(3).is_none());
}
// `TaskConfig::new()` uses the default retry mode.
#[test]
fn test_task_config_creation() {
let config = TaskConfig::new();
assert_eq!(config.retry_mode.max_attempts(), 4);
}
// `TaskConfig::default()` uses the default retry mode.
#[test]
fn test_task_config_default() {
let config = TaskConfig::default();
assert_eq!(config.retry_mode.max_attempts(), 4);
}
// `TaskConfig::minimal()` disables retries.
#[test]
fn test_task_config_minimal() {
let config = TaskConfig::minimal();
assert_eq!(config.retry_mode.max_attempts(), 1);
}
// `with_fixed_retry(5, ..)` yields six attempts.
#[test]
fn test_task_config_with_fixed_retry() {
let config = TaskConfig::new().with_fixed_retry(5, Duration::from_millis(100));
assert_eq!(config.retry_mode.max_attempts(), 6);
}
// NOTE(review): despite the name, only a single builder call is exercised here.
#[test]
fn test_task_config_builder_pattern() {
let config = TaskConfig::new().with_fixed_retry(10, Duration::from_secs(1));
assert_eq!(config.retry_mode.max_attempts(), 11);
}
// A closure that succeeds immediately is invoked exactly once.
#[tokio::test]
async fn test_run_with_retries_success() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::minimal();
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("success".to_string()))
}
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("success".to_string()));
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// Two failures followed by a success: the call succeeds on the third attempt.
// NOTE(review): the name says "failure" but the overall call is expected to succeed.
#[tokio::test]
async fn test_run_with_retries_failure() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
let count = counter.fetch_add(1, Ordering::SeqCst);
if count < 2 {
Err(CanoError::task_execution("failure"))
} else {
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("success".to_string()))
}
}
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("success".to_string()));
assert_eq!(counter.load(Ordering::SeqCst), 3); }
// A closure that always fails is invoked `retries + 1` times, then an error is returned.
#[tokio::test]
async fn test_run_with_retries_exhausted() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("always fails"))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3); }
// With retries disabled the original error variant is returned unwrapped.
#[tokio::test]
async fn test_run_with_retries_mode_none() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::minimal();
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("immediate fail"))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 1);
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::TaskExecution(_)),
"expected original TaskExecution variant when retries disabled, got: {err}"
);
assert!(err.to_string().contains("immediate fail"));
}
// Task that overrides only `run_bare`.
struct BareTask;
#[task]
impl Task<TestAction> for BareTask {
async fn run_bare(&self) -> Result<TaskResult<TestAction>, CanoError> {
Ok(TaskResult::Single(TestAction::Complete))
}
}
// Task that overrides both entry points and counts calls to `run_bare`.
struct ExplicitRunTask {
bare_called: Arc<AtomicU32>,
}
#[task]
impl Task<TestAction> for ExplicitRunTask {
async fn run(&self, _res: &Resources) -> Result<TaskResult<TestAction>, CanoError> {
Ok(TaskResult::Single(TestAction::Continue))
}
async fn run_bare(&self) -> Result<TaskResult<TestAction>, CanoError> {
self.bare_called.fetch_add(1, Ordering::SeqCst);
Ok(TaskResult::Single(TestAction::Error))
}
}
// The default `run` delegates to `run_bare`.
#[tokio::test]
async fn test_run_bare_called_when_run_not_overridden() {
let task = BareTask;
let res = Resources::new();
let result = task.run(&res).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
}
// An overridden `run` is used directly; `run_bare` is never called.
#[tokio::test]
async fn test_run_overrides_bypass_bare() {
let bare_called = Arc::new(AtomicU32::new(0));
let task = ExplicitRunTask {
bare_called: Arc::clone(&bare_called),
};
let res = Resources::new();
let result = task.run(&res).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Continue));
assert_eq!(bare_called.load(Ordering::SeqCst), 0);
}
// Every attempt sleeps past the 20 ms timeout, so all three attempts time out
// and the call ends with `RetryExhausted` mentioning the timeout.
#[tokio::test]
async fn test_attempt_timeout_triggers_retry() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new()
.with_fixed_retry(2, Duration::from_millis(1))
.with_attempt_timeout(Duration::from_millis(20));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(200)).await;
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("never".to_string()))
}
})
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::RetryExhausted(_)),
"expected RetryExhausted, got: {err}"
);
let msg = err.to_string();
assert!(
msg.contains("Timeout error") || msg.contains("attempt_timeout"),
"expected timeout context in error, got: {msg}"
);
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
// Without a configured timeout a short sleep inside the attempt is fine.
#[tokio::test]
async fn test_attempt_timeout_none_unchanged() {
let config = TaskConfig::new();
assert!(config.attempt_timeout.is_none());
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || async {
tokio::time::sleep(Duration::from_millis(5)).await;
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("ok".to_string()))
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("ok".to_string()));
}
// The timeout applies per attempt: the first attempt times out, the second
// (fast) attempt succeeds.
#[tokio::test]
async fn test_attempt_timeout_resets_per_attempt() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new()
.with_fixed_retry(1, Duration::from_millis(1))
.with_attempt_timeout(Duration::from_millis(30));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
let n = counter.fetch_add(1, Ordering::SeqCst);
if n == 0 {
tokio::time::sleep(Duration::from_millis(200)).await;
}
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("ok".to_string()))
}
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("ok".to_string()));
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
// After the retry budget is spent the error variant is `RetryExhausted`.
#[tokio::test]
async fn test_retry_exhausted_error_type() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution(
"persistent failure",
))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3);
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::RetryExhausted(_)),
"expected RetryExhausted after retry exhaustion, got: {err}"
);
}
use crate::circuit::{CircuitBreaker, CircuitPolicy, CircuitState};
// Builds a breaker policy with the given failure threshold, a 20 ms reset
// timeout and a single half-open call.
fn cb_policy(threshold: u32) -> CircuitPolicy {
CircuitPolicy {
failure_threshold: threshold,
reset_timeout: Duration::from_millis(20),
half_open_max_calls: 1,
}
}
// Three failing calls open the breaker; the next call returns `CircuitOpen`
// without invoking the task body.
#[tokio::test]
async fn test_circuit_open_short_circuits_without_invoking_task() {
use std::sync::atomic::{AtomicUsize, Ordering};
let breaker = Arc::new(CircuitBreaker::new(cb_policy(3)));
let config = TaskConfig::minimal().with_circuit_breaker(Arc::clone(&breaker));
let counter = Arc::new(AtomicUsize::new(0));
for _ in 0..3 {
let counter = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("boom"))
}
})
.await;
assert!(result.is_err());
}
assert_eq!(counter.load(Ordering::SeqCst), 3);
assert!(matches!(breaker.state(), CircuitState::Open { .. }));
let counter_before = counter.load(Ordering::SeqCst);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("never".into()))
}
})
.await;
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::CircuitOpen(_)),
"expected CircuitOpen, got: {err}"
);
assert_eq!(
counter.load(Ordering::SeqCst),
counter_before,
"task body must not run when breaker is open"
);
}
// With an already-open breaker the raw `CircuitOpen` error is returned even
// though a retry policy is configured (it is not wrapped as retry exhaustion).
#[tokio::test]
async fn test_circuit_open_does_not_consume_retry_attempts() {
let breaker = Arc::new(CircuitBreaker::new(cb_policy(1)));
let config = TaskConfig::new()
.with_fixed_retry(5, Duration::from_millis(1))
.with_circuit_breaker(Arc::clone(&breaker));
let _ = run_with_retries::<TaskResult<String>, _, _>(&config, || async {
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("boom"))
})
.await;
assert!(matches!(breaker.state(), CircuitState::Open { .. }));
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || async {
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("never".into()))
})
.await;
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::CircuitOpen(_)),
"expected raw CircuitOpen, got: {err}"
);
}
// After waiting longer than the reset timeout, one successful call closes the breaker.
// NOTE(review): relies on wall-clock sleep (40 ms vs. 20 ms reset timeout).
#[tokio::test]
async fn test_circuit_half_open_recovery_via_run_with_retries() {
use std::sync::atomic::{AtomicUsize, Ordering};
let breaker = Arc::new(CircuitBreaker::new(cb_policy(2)));
let config = TaskConfig::minimal().with_circuit_breaker(Arc::clone(&breaker));
for _ in 0..2 {
let _ = run_with_retries::<TaskResult<String>, _, _>(&config, || async {
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("boom"))
})
.await;
}
assert!(matches!(breaker.state(), CircuitState::Open { .. }));
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("ok".into()))
}
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("ok".to_string()));
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert_eq!(breaker.state(), CircuitState::Closed);
}
// A breaker tripped through one config also blocks a second config sharing it.
#[tokio::test]
async fn test_circuit_breaker_shared_across_tasks() {
use std::sync::atomic::{AtomicUsize, Ordering};
let breaker = Arc::new(CircuitBreaker::new(cb_policy(3)));
let config_a = TaskConfig::minimal().with_circuit_breaker(Arc::clone(&breaker));
let config_b = TaskConfig::minimal().with_circuit_breaker(Arc::clone(&breaker));
for _ in 0..3 {
let _ = run_with_retries::<TaskResult<String>, _, _>(&config_a, || async {
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("A failed"))
})
.await;
}
assert!(matches!(breaker.state(), CircuitState::Open { .. }));
let b_invocations = Arc::new(AtomicUsize::new(0));
let b_invocations_clone = Arc::clone(&b_invocations);
let result = run_with_retries::<TaskResult<String>, _, _>(&config_b, || {
let b = Arc::clone(&b_invocations_clone);
async move {
b.fetch_add(1, Ordering::SeqCst);
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("B ok".into()))
}
})
.await;
let err = result.unwrap_err();
assert!(matches!(err, CanoError::CircuitOpen(_)));
assert_eq!(
b_invocations.load(Ordering::SeqCst),
0,
"task B must not run while breaker tripped by task A is open"
);
}
// With threshold 2 and five retries, the breaker opens after the second
// failure inside one call; the third acquire fails and `CircuitOpen` is returned.
#[tokio::test]
async fn test_circuit_open_mid_loop_with_fixed_retry_short_circuits_remaining_attempts() {
use std::sync::atomic::{AtomicUsize, Ordering};
let breaker = Arc::new(CircuitBreaker::new(cb_policy(2)));
let config = TaskConfig::new()
.with_fixed_retry(5, Duration::from_millis(1))
.with_circuit_breaker(Arc::clone(&breaker));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("boom"))
}
})
.await;
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::CircuitOpen(_)),
"expected raw CircuitOpen after mid-loop trip, got: {err}"
);
assert_eq!(
counter.load(Ordering::SeqCst),
2,
"task body must run exactly threshold times before the breaker trips short-circuits the rest of the retry budget"
);
assert!(matches!(breaker.state(), CircuitState::Open { .. }));
}
// A config without a breaker behaves like a plain retry loop.
#[tokio::test]
async fn test_circuit_breaker_unused_does_not_change_behavior() {
let config = TaskConfig::new().with_fixed_retry(1, Duration::from_millis(1));
assert!(config.circuit_breaker.is_none());
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || async {
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("ok".into()))
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("ok".into()));
}
// Timed-out attempts are recorded as breaker failures: two timeouts open the
// breaker, and the following call is rejected without running the task body.
#[tokio::test]
async fn test_attempt_timeout_counts_as_circuit_failure() {
use std::sync::atomic::{AtomicUsize, Ordering};
let breaker = Arc::new(CircuitBreaker::new(cb_policy(2)));
let config = TaskConfig::minimal()
.with_attempt_timeout(Duration::from_millis(10))
.with_circuit_breaker(Arc::clone(&breaker));
let invocations = Arc::new(AtomicUsize::new(0));
for _ in 0..2 {
let invocations = Arc::clone(&invocations);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let invocations = Arc::clone(&invocations);
async move {
invocations.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(100)).await;
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("never".into()))
}
})
.await;
assert!(matches!(result, Err(CanoError::Timeout(_))));
}
assert_eq!(invocations.load(Ordering::SeqCst), 2);
assert!(
matches!(breaker.state(), CircuitState::Open { .. }),
"two timed-out attempts must trip the breaker, got {:?}",
breaker.state()
);
let before = invocations.load(Ordering::SeqCst);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let invocations = Arc::clone(&invocations);
async move {
invocations.fetch_add(1, Ordering::SeqCst);
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("never".into()))
}
})
.await;
assert!(matches!(result, Err(CanoError::CircuitOpen(_))));
assert_eq!(
invocations.load(Ordering::SeqCst),
before,
"task body must not run when the breaker is open"
);
}
}