use crate::error::CanoError;
use crate::store::MemoryStore;
use async_trait::async_trait;
use rand::RngExt;
use std::collections::HashMap;
use std::time::Duration;
#[cfg(feature = "tracing")]
use tracing::{debug, error, info, info_span, instrument, warn};
/// Policy describing whether, how often, and with what delay a failed task is retried.
#[derive(Debug, Clone)]
pub enum RetryMode {
    /// Never retry: the task gets exactly one attempt.
    None,
    /// Retry up to `retries` times, waiting the same `delay` before each retry.
    Fixed {
        /// Number of retries allowed after the initial attempt.
        retries: usize,
        /// Pause inserted before every retry.
        delay: Duration,
    },
    /// Retry with a delay that grows geometrically, is capped, and is optionally jittered.
    ExponentialBackoff {
        /// Number of retries allowed after the initial attempt.
        max_retries: usize,
        /// Delay used before the first retry (prior to jitter).
        base_delay: Duration,
        /// Growth factor applied once per additional retry.
        multiplier: f64,
        /// Upper bound applied to the delay before jitter is added.
        max_delay: Duration,
        /// Fraction by which the capped delay may be randomly scaled up or down.
        jitter: f64,
    },
}
impl RetryMode {
    /// Builds a [`RetryMode::Fixed`] policy: up to `retries` retries, each preceded by `delay`.
    pub fn fixed(retries: usize, delay: Duration) -> Self {
        Self::Fixed { retries, delay }
    }

    /// Builds a [`RetryMode::ExponentialBackoff`] policy with stock settings:
    /// 100 ms base delay, multiplier 2.0, 30 s cap and a jitter fraction of 0.1.
    pub fn exponential(max_retries: usize) -> Self {
        Self::ExponentialBackoff {
            max_retries,
            base_delay: Duration::from_millis(100),
            multiplier: 2.0,
            max_delay: Duration::from_secs(30),
            jitter: 0.1,
        }
    }

    /// Builds a fully parameterised [`RetryMode::ExponentialBackoff`] policy.
    ///
    /// `jitter` is clamped into `[0.0, 1.0]`; every other argument is stored as given.
    pub fn exponential_custom(
        max_retries: usize,
        base_delay: Duration,
        multiplier: f64,
        max_delay: Duration,
        jitter: f64,
    ) -> Self {
        Self::ExponentialBackoff {
            max_retries,
            base_delay,
            multiplier,
            max_delay,
            jitter: jitter.clamp(0.0, 1.0),
        }
    }

    /// Total number of attempts this policy allows: the initial attempt plus every retry.
    ///
    /// The sum saturates at `usize::MAX`, so a "retry forever" configuration such as
    /// `RetryMode::fixed(usize::MAX, ..)` neither panics (debug builds) nor wraps around
    /// to zero attempts (release builds).
    pub fn max_attempts(&self) -> usize {
        match self {
            Self::None => 1,
            Self::Fixed { retries, .. } => retries.saturating_add(1),
            Self::ExponentialBackoff { max_retries, .. } => max_retries.saturating_add(1),
        }
    }

    /// Delay to wait after the failed attempt with zero-based index `attempt`.
    ///
    /// Returns `None` when no retry follows that attempt (policy is [`RetryMode::None`], or
    /// `attempt` is not below the configured retry count).
    ///
    /// For exponential backoff the delay is `base_delay * multiplier^attempt` computed in
    /// whole milliseconds, capped at `max_delay`, then scaled by a random factor drawn from
    /// `1 ± jitter` when `jitter > 0`. The result is never negative.
    pub fn delay_for_attempt(&self, attempt: usize) -> Option<Duration> {
        match self {
            Self::None => None,
            Self::Fixed { retries, delay } => (attempt < *retries).then_some(*delay),
            Self::ExponentialBackoff {
                max_retries,
                base_delay,
                multiplier,
                max_delay,
                jitter,
            } => {
                if attempt >= *max_retries {
                    return None;
                }

                // A plain `attempt as i32` wraps to a negative exponent for attempt numbers
                // above `i32::MAX`, which would shrink the delay to zero instead of letting
                // it grow to the cap. Saturate instead.
                let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);

                let base_ms = base_delay.as_millis() as f64;
                let cap_ms = max_delay.as_millis() as f64;
                let capped_ms = (base_ms * multiplier.powi(exponent)).min(cap_ms);

                let jitter_factor = if *jitter > 0.0 {
                    let mut rng = rand::rng();
                    let random_factor: f64 = rng.random_range(-1.0..=1.0);
                    1.0 + (jitter * random_factor)
                } else {
                    1.0
                };

                // Float-to-integer `as` casts saturate at the integer bounds and map NaN
                // to 0, so no explicit `u64::MAX` guard is needed here.
                let final_ms = (capped_ms * jitter_factor).max(0.0) as u64;
                Some(Duration::from_millis(final_ms))
            }
        }
    }
}
impl Default for RetryMode {
fn default() -> Self {
Self::ExponentialBackoff {
max_retries: 3,
base_delay: Duration::from_millis(100),
multiplier: 2.0,
max_delay: Duration::from_secs(30),
jitter: 0.1,
}
}
}
/// Execution settings for a task; currently this is only the retry policy.
///
/// `Debug` is derived so configurations can be logged and used in assertion messages.
#[must_use]
#[derive(Debug, Clone, Default)]
pub struct TaskConfig {
    /// Retry policy applied when the task fails.
    pub retry_mode: RetryMode,
}
impl TaskConfig {
    /// Creates a configuration with the default retry policy (same as `TaskConfig::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a configuration that never retries.
    pub fn minimal() -> Self {
        let retry_mode = RetryMode::None;
        Self { retry_mode }
    }

    /// Returns this configuration with its retry policy replaced by `retry_mode`.
    pub fn with_retry(self, retry_mode: RetryMode) -> Self {
        let mut updated = self;
        updated.retry_mode = retry_mode;
        updated
    }

    /// Returns this configuration with a fixed-delay retry policy.
    pub fn with_fixed_retry(self, retries: usize, delay: Duration) -> Self {
        let policy = RetryMode::fixed(retries, delay);
        self.with_retry(policy)
    }

    /// Returns this configuration with the stock exponential-backoff retry policy.
    pub fn with_exponential_retry(self, max_retries: usize) -> Self {
        let policy = RetryMode::exponential(max_retries);
        self.with_retry(policy)
    }
}
/// Runs `run_fn` until it succeeds or the retry policy in `config` is exhausted.
///
/// `run_fn` is called once per attempt and must produce a fresh future each time. After a
/// failed attempt that still has retries left, the function sleeps for the delay reported by
/// [`RetryMode::delay_for_attempt`] (if any) before trying again.
///
/// # Errors
///
/// * When the policy allows at most one attempt, the original error from `run_fn` is
///   returned unchanged.
/// * Otherwise, once all attempts have failed, a `CanoError::retry_exhausted` error is
///   returned whose message contains the attempt count and the last error.
#[cfg_attr(feature = "tracing", instrument(
    skip(config, run_fn),
    fields(max_attempts = config.retry_mode.max_attempts())
))]
pub async fn run_with_retries<TState, F, Fut>(
    config: &TaskConfig,
    run_fn: F,
) -> Result<TState, CanoError>
where
    TState: Send + Sync,
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<TState, CanoError>>,
{
    let max_attempts = config.retry_mode.max_attempts();
    let mut attempt = 0;

    #[cfg(feature = "tracing")]
    info!(max_attempts, "Starting task execution with retry logic");

    loop {
        #[cfg(feature = "tracing")]
        let attempt_span = info_span!("task_attempt", attempt = attempt + 1, max_attempts);

        #[cfg(feature = "tracing")]
        attempt_span.in_scope(|| debug!(attempt = attempt + 1, "Executing task attempt"));

        // A `Span::enter()` guard must not be held across an `.await`: while this future is
        // suspended the span would stay entered on the executor thread and be attributed to
        // unrelated work. Instrumenting the future enters the span only while it is polled.
        #[cfg(feature = "tracing")]
        let outcome = tracing::Instrument::instrument(run_fn(), attempt_span.clone()).await;
        #[cfg(not(feature = "tracing"))]
        let outcome = run_fn().await;

        match outcome {
            Ok(result) => {
                #[cfg(feature = "tracing")]
                attempt_span
                    .in_scope(|| info!(attempt = attempt + 1, "Task execution successful"));
                return Ok(result);
            }
            Err(e) => {
                attempt += 1;
                let exhausted = attempt >= max_attempts;

                #[cfg(feature = "tracing")]
                attempt_span.in_scope(|| {
                    if exhausted {
                        error!(
                            error = %e,
                            final_attempt = attempt,
                            max_attempts,
                            "Task execution failed after all retry attempts"
                        );
                    } else {
                        warn!(
                            error = %e,
                            attempt,
                            max_attempts,
                            "Task execution failed, will retry"
                        );
                    }
                });

                if exhausted {
                    // With retries disabled, surface the caller's own error variant rather
                    // than wrapping it in a retry-exhausted error.
                    if max_attempts <= 1 {
                        return Err(e);
                    }
                    return Err(CanoError::retry_exhausted(format!(
                        "Task failed after {} attempt(s): {}",
                        attempt, e
                    )));
                }

                // `attempt` was already incremented, so the failed attempt's zero-based
                // index is `attempt - 1`.
                if let Some(delay) = config.retry_mode.delay_for_attempt(attempt - 1) {
                    #[cfg(feature = "tracing")]
                    attempt_span.in_scope(|| {
                        debug!(delay_ms = delay.as_millis(), "Waiting before retry")
                    });
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }
}
/// Default parameter type for [`Task`]: a map from string keys to string values.
pub type DefaultTaskParams = HashMap<String, String>;
/// Outcome of a successful task run: either one next state or several of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskResult<TState> {
    /// The task produced exactly one state.
    Single(TState),
    /// The task produced a list of states.
    Split(Vec<TState>),
}
#[async_trait]
pub trait Task<TState, TStore = MemoryStore, TParams = DefaultTaskParams>: Send + Sync
where
TState: Clone + std::fmt::Debug + Send + Sync + 'static,
TParams: Send + Sync + Clone,
TStore: Send + Sync + 'static,
{
fn set_params(&mut self, _params: TParams) {
}
fn config(&self) -> TaskConfig {
TaskConfig::default()
}
async fn run(&self, store: &TStore) -> Result<TaskResult<TState>, CanoError>;
}
/// Adapter that lets every `Node` be used wherever a [`Task`] is expected.
#[async_trait]
impl<TState, TStore, TParams, N> Task<TState, TStore, TParams> for N
where
    N: crate::node::Node<TState, TStore, TParams>,
    TState: Clone + std::fmt::Debug + Send + Sync + 'static,
    TParams: Send + Sync + Clone,
    TStore: Send + Sync + 'static,
{
    /// Forwards the parameters to the node's own `set_params`.
    fn set_params(&mut self, params: TParams) {
        <N as crate::node::Node<TState, TStore, TParams>>::set_params(self, params);
    }

    /// Builds a [`TaskConfig`] carrying the retry policy from the node's configuration.
    fn config(&self) -> TaskConfig {
        let retry_mode =
            <N as crate::node::Node<TState, TStore, TParams>>::config(self).retry_mode;
        TaskConfig { retry_mode }
    }

    /// Runs the node's `prep`, `exec` and `post` phases in order and wraps the state
    /// returned by `post` in [`TaskResult::Single`]. Errors from `prep` or `post` are
    /// propagated to the caller.
    #[cfg_attr(
        feature = "tracing",
        instrument(skip(self, store), fields(task_type = "node_adapter"))
    )]
    async fn run(&self, store: &TStore) -> Result<TaskResult<TState>, CanoError> {
        #[cfg(feature = "tracing")]
        debug!("Executing task through Node adapter");

        let prepared =
            <N as crate::node::Node<TState, TStore, TParams>>::prep(self, store).await?;
        let executed =
            <N as crate::node::Node<TState, TStore, TParams>>::exec(self, prepared).await;
        let next_state =
            <N as crate::node::Node<TState, TStore, TParams>>::post(self, store, executed)
                .await?;

        #[cfg(feature = "tracing")]
        info!(next_state = ?next_state, "Task execution completed successfully");

        Ok(TaskResult::Single(next_state))
    }
}
/// Marker trait for tasks that use the default store (`MemoryStore`) and the default
/// parameter type ([`DefaultTaskParams`]), leaving `TState` as the only type parameter.
pub trait DynTask<TState>: Task<TState, MemoryStore, DefaultTaskParams>
where
    TState: Clone + std::fmt::Debug + Send + Sync + 'static,
{
}
/// Blanket implementation: every [`Task`] over `MemoryStore` and [`DefaultTaskParams`]
/// is automatically a [`DynTask`].
impl<TState, T> DynTask<TState> for T
where
    TState: Clone + std::fmt::Debug + Send + Sync + 'static,
    T: Task<TState, MemoryStore, DefaultTaskParams>,
{
}
/// Trait-object type for a [`DynTask`] that is additionally `Send + Sync`.
pub type TaskObject<TState> = dyn DynTask<TState> + Send + Sync;
// Unit tests for this module: hand-written `Task` implementations, the Node-to-Task
// adapter, `RetryMode` delay calculation, `TaskConfig` builders and `run_with_retries`.
#[cfg(test)]
mod tests {
use super::*;
use crate::store::{KeyValueStore, MemoryStore};
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio;
// State type returned by the test tasks below.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(dead_code)]
enum TestAction {
Continue,
Complete,
Error,
Retry,
}
// Task that counts its runs and records a flag in the store.
struct SimpleTask {
execution_count: Arc<AtomicU32>,
}
impl SimpleTask {
fn new() -> Self {
Self {
execution_count: Arc::new(AtomicU32::new(0)),
}
}
// Number of times `run` has been called so far.
fn execution_count(&self) -> u32 {
self.execution_count.load(Ordering::SeqCst)
}
}
#[async_trait]
impl Task<TestAction> for SimpleTask {
async fn run(&self, store: &MemoryStore) -> Result<TaskResult<TestAction>, CanoError> {
self.execution_count.fetch_add(1, Ordering::SeqCst);
store.put("simple_task_executed", true)?;
Ok(TaskResult::Single(TestAction::Complete))
}
}
// Task whose stored result is `base_value * multiplier`, both taken from its params.
struct ParameterizedTask {
params: DefaultTaskParams,
multiplier: i32,
}
impl ParameterizedTask {
fn new() -> Self {
Self {
params: HashMap::new(),
multiplier: 1,
}
}
}
#[async_trait]
impl Task<TestAction> for ParameterizedTask {
// Stores the params and, when "multiplier" parses as an i32, updates the multiplier.
fn set_params(&mut self, params: DefaultTaskParams) {
self.params = params;
if let Some(multiplier_str) = self.params.get("multiplier")
&& let Ok(multiplier) = multiplier_str.parse::<i32>()
{
self.multiplier = multiplier;
}
}
async fn run(&self, store: &MemoryStore) -> Result<TaskResult<TestAction>, CanoError> {
// Falls back to 10 when "base_value" is missing or not an i32.
let base_value = self
.params
.get("base_value")
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or(10);
let result = base_value * self.multiplier;
store.put("result", result)?;
Ok(TaskResult::Single(TestAction::Complete))
}
}
// Task that either fails with a fixed error or records a flag in the store.
struct FailingTask {
should_fail: bool,
}
impl FailingTask {
fn new(should_fail: bool) -> Self {
Self { should_fail }
}
}
#[async_trait]
impl Task<TestAction> for FailingTask {
async fn run(&self, store: &MemoryStore) -> Result<TaskResult<TestAction>, CanoError> {
if self.should_fail {
Err(CanoError::task_execution("Task intentionally failed"))
} else {
store.put("failing_task_executed", true)?;
Ok(TaskResult::Single(TestAction::Complete))
}
}
}
// Task that reads a string from `input_key` and writes "processed: <value>" to `output_key`.
struct DataProcessingTask {
input_key: String,
output_key: String,
}
impl DataProcessingTask {
fn new(input_key: &str, output_key: &str) -> Self {
Self {
input_key: input_key.to_string(),
output_key: output_key.to_string(),
}
}
}
#[async_trait]
impl Task<TestAction> for DataProcessingTask {
async fn run(&self, store: &MemoryStore) -> Result<TaskResult<TestAction>, CanoError> {
// A store read failure is reported as a task-execution error.
let input_data: String = store
.get(&self.input_key)
.map_err(|e| CanoError::task_execution(format!("Failed to read input: {e}")))?;
let processed_data = format!("processed: {input_data}");
store.put(&self.output_key, processed_data)?;
Ok(TaskResult::Single(TestAction::Complete))
}
}
// A single run returns `Complete`, bumps the counter and sets the store flag.
#[tokio::test]
async fn test_simple_task_execution() {
let task = SimpleTask::new();
let store = MemoryStore::new();
let result = task.run(&store).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
assert_eq!(task.execution_count(), 1);
let executed: bool = store.get("simple_task_executed").unwrap();
assert!(executed);
}
// Without params the result is 10 * 1; after `set_params` it is 5 * 3.
#[tokio::test]
async fn test_parameterized_task() {
let mut task = ParameterizedTask::new();
let store = MemoryStore::new();
let result = task.run(&store).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
let stored_result: i32 = store.get("result").unwrap();
assert_eq!(stored_result, 10);
let mut params = HashMap::new();
params.insert("base_value".to_string(), "5".to_string());
params.insert("multiplier".to_string(), "3".to_string());
task.set_params(params);
let result2 = task.run(&store).await.unwrap();
assert_eq!(result2, TaskResult::Single(TestAction::Complete));
let stored_result2: i32 = store.get("result").unwrap();
assert_eq!(stored_result2, 15); }
// Covers both the success path and the error path of `FailingTask`.
#[tokio::test]
async fn test_failing_task() {
let store = MemoryStore::new();
let success_task = FailingTask::new(false);
let result = success_task.run(&store).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
let executed: bool = store.get("failing_task_executed").unwrap();
assert!(executed);
let fail_task = FailingTask::new(true);
let result = fail_task.run(&store).await;
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.to_string().contains("Task intentionally failed"));
}
// Input present in the store is transformed and written to the output key.
#[tokio::test]
async fn test_data_processing_task() {
let store = MemoryStore::new();
let task = DataProcessingTask::new("input_data", "output_data");
store.put("input_data", "test_value".to_string()).unwrap();
let result = task.run(&store).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
let output: String = store.get("output_data").unwrap();
assert_eq!(output, "processed: test_value");
}
// A missing input key surfaces as an error mentioning "Failed to read input".
#[tokio::test]
async fn test_data_processing_task_missing_input() {
let store = MemoryStore::new();
let task = DataProcessingTask::new("missing_input", "output_data");
let result = task.run(&store).await;
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.to_string().contains("Failed to read input"));
}
// Ten spawned runs sharing one task and one store all succeed and are all counted.
#[tokio::test]
async fn test_concurrent_task_execution() {
use tokio::task;
let task = Arc::new(SimpleTask::new());
let store = Arc::new(MemoryStore::new());
let mut handles = vec![];
for _ in 0..10 {
let task_clone = Arc::clone(&task);
let store_clone = Arc::clone(&store);
let handle = task::spawn(async move { task_clone.run(&*store_clone).await });
handles.push(handle);
}
let mut success_count = 0;
for handle in handles {
let result = handle.await.unwrap();
if let Ok(TaskResult::Single(TestAction::Complete)) = result {
success_count += 1;
}
}
assert_eq!(success_count, 10);
assert_eq!(task.execution_count(), 10);
}
// Compile-time check that `SimpleTask` satisfies the `Task` bound with default store/params.
// NOTE(review): despite the name, no `dyn` trait object is constructed here.
#[tokio::test]
async fn test_task_trait_object_compatibility() {
let _store = MemoryStore::new();
let task = SimpleTask::new();
fn assert_task_traits<T>(_: &T)
where
T: Task<TestAction, MemoryStore, DefaultTaskParams>,
{
}
assert_task_traits(&task);
}
// The same task instance can be run repeatedly; the counter tracks each run.
#[tokio::test]
async fn test_multiple_task_executions() {
let task = SimpleTask::new();
let store = MemoryStore::new();
for i in 1..=5 {
let result = task.run(&store).await.unwrap();
assert_eq!(result, TaskResult::Single(TestAction::Complete));
assert_eq!(task.execution_count(), i);
}
}
// Tasks running against separate stores do not see each other's outputs.
#[tokio::test]
async fn test_task_state_isolation() {
let store1 = MemoryStore::new();
let store2 = MemoryStore::new();
let task1 = DataProcessingTask::new("input", "output1");
let task2 = DataProcessingTask::new("input", "output2");
store1.put("input", "data1".to_string()).unwrap();
store2.put("input", "data2".to_string()).unwrap();
task1.run(&store1).await.unwrap();
task2.run(&store2).await.unwrap();
let result1: String = store1.get("output1").unwrap();
let result2: String = store2.get("output2").unwrap();
assert_eq!(result1, "processed: data1");
assert_eq!(result2, "processed: data2");
assert!(store1.get::<String>("output2").is_err());
assert!(store2.get::<String>("output1").is_err());
}
use crate::node::Node;
// Minimal `Node` used to exercise the blanket Node-to-Task adapter.
struct TestNode;
#[async_trait]
impl Node<TestAction> for TestNode {
type PrepResult = String;
type ExecResult = bool;
async fn prep(&self, _store: &MemoryStore) -> Result<Self::PrepResult, CanoError> {
Ok("node_prepared".to_string())
}
async fn exec(&self, prep_res: Self::PrepResult) -> Self::ExecResult {
prep_res == "node_prepared"
}
async fn post(
&self,
store: &MemoryStore,
exec_res: Self::ExecResult,
) -> Result<TestAction, CanoError> {
store.put("node_executed", exec_res)?;
if exec_res {
Ok(TestAction::Complete)
} else {
Ok(TestAction::Error)
}
}
}
// Calling `Task::run` on a `Node` drives prep/exec/post and yields a `Single` result.
#[tokio::test]
async fn test_node_as_task_compatibility() {
let node = TestNode;
let store = MemoryStore::new();
let result = Task::run(&node, &store).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), TaskResult::Single(TestAction::Complete));
let executed: bool = store.get("node_executed").unwrap();
assert!(executed);
}
// `RetryMode::None` allows one attempt and never yields a delay.
#[test]
fn test_retry_mode_none() {
let retry_mode = RetryMode::None;
assert_eq!(retry_mode.max_attempts(), 1);
assert_eq!(retry_mode.delay_for_attempt(0), None);
assert_eq!(retry_mode.delay_for_attempt(1), None);
}
// Fixed mode yields the same delay for attempts 0..retries and `None` afterwards.
#[test]
fn test_retry_mode_fixed() {
let retry_mode = RetryMode::fixed(3, Duration::from_millis(100));
assert_eq!(retry_mode.max_attempts(), 4);
assert_eq!(
retry_mode.delay_for_attempt(0),
Some(Duration::from_millis(100))
);
assert_eq!(
retry_mode.delay_for_attempt(1),
Some(Duration::from_millis(100))
);
assert_eq!(
retry_mode.delay_for_attempt(2),
Some(Duration::from_millis(100))
);
assert_eq!(retry_mode.delay_for_attempt(3), None); assert_eq!(retry_mode.delay_for_attempt(4), None);
}
// Exponential mode: delays are randomised by jitter, so only loose ordering is asserted.
#[test]
fn test_retry_mode_exponential_basic() {
let retry_mode = RetryMode::exponential(3);
assert_eq!(retry_mode.max_attempts(), 4);
let delay0 = retry_mode.delay_for_attempt(0).unwrap();
let delay1 = retry_mode.delay_for_attempt(1).unwrap();
let delay2 = retry_mode.delay_for_attempt(2).unwrap();
assert!(delay1.as_millis() >= delay0.as_millis() / 2); assert!(delay2.as_millis() >= delay1.as_millis() / 2);
assert_eq!(retry_mode.delay_for_attempt(3), None);
assert_eq!(retry_mode.delay_for_attempt(4), None);
}
// `TaskConfig::new()` uses the default policy: 3 retries, hence 4 attempts.
#[test]
fn test_task_config_creation() {
let config = TaskConfig::new();
assert_eq!(config.retry_mode.max_attempts(), 4);
}
// `TaskConfig::default()` matches `new()`.
#[test]
fn test_task_config_default() {
let config = TaskConfig::default();
assert_eq!(config.retry_mode.max_attempts(), 4);
}
// `TaskConfig::minimal()` disables retries.
#[test]
fn test_task_config_minimal() {
let config = TaskConfig::minimal();
assert_eq!(config.retry_mode.max_attempts(), 1);
}
// `with_fixed_retry(5, ..)` allows 6 attempts in total.
#[test]
fn test_task_config_with_fixed_retry() {
let config = TaskConfig::new().with_fixed_retry(5, Duration::from_millis(100));
assert_eq!(config.retry_mode.max_attempts(), 6);
}
// Builder call replaces the default policy from `new()`.
#[test]
fn test_task_config_builder_pattern() {
let config = TaskConfig::new().with_fixed_retry(10, Duration::from_secs(1));
assert_eq!(config.retry_mode.max_attempts(), 11);
}
// A closure that succeeds immediately is invoked exactly once.
#[tokio::test]
async fn test_run_with_retries_success() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::minimal();
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("success".to_string()))
}
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("success".to_string()));
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// Two failures followed by a success: the third attempt's result is returned.
#[tokio::test]
async fn test_run_with_retries_failure() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
let count = counter.fetch_add(1, Ordering::SeqCst);
if count < 2 {
Err(CanoError::task_execution("failure"))
} else {
Ok::<TaskResult<String>, CanoError>(TaskResult::Single("success".to_string()))
}
}
})
.await
.unwrap();
assert_eq!(result, TaskResult::Single("success".to_string()));
assert_eq!(counter.load(Ordering::SeqCst), 3); }
// A closure that always fails is invoked once per allowed attempt (1 + 2 retries).
#[tokio::test]
async fn test_run_with_retries_exhausted() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("always fails"))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3); }
// With retries disabled the original error variant is returned unwrapped.
#[tokio::test]
async fn test_run_with_retries_mode_none() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::minimal();
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution("immediate fail"))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 1);
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::TaskExecution(_)),
"expected original TaskExecution variant when retries disabled, got: {err}"
);
assert!(err.to_string().contains("immediate fail"));
}
// After retries are used up the error is the `RetryExhausted` variant.
#[tokio::test]
async fn test_retry_exhausted_error_type() {
use std::sync::atomic::{AtomicUsize, Ordering};
let config = TaskConfig::new().with_fixed_retry(2, Duration::from_millis(1));
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let result = run_with_retries::<TaskResult<String>, _, _>(&config, || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<TaskResult<String>, CanoError>(CanoError::task_execution(
"persistent failure",
))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3);
let err = result.unwrap_err();
assert!(
matches!(err, CanoError::RetryExhausted(_)),
"expected RetryExhausted after retry exhaustion, got: {err}"
);
}
}