use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::Value;
use crate::error::{CognisError, Result};
use super::base::Runnable;
use super::config::RunnableConfig;
use super::RunnableStream;
/// Shape of the optional predicate that decides whether an error is retryable.
type RetryPredicate = dyn Fn(&CognisError) -> bool + Send + Sync;

/// A [`Runnable`] decorator that re-runs the wrapped runnable when it fails,
/// pausing between attempts with a multiplicative backoff.
pub struct RunnableRetry {
    /// The wrapped runnable that every attempt is delegated to.
    inner: Arc<dyn Runnable>,
    /// Number of retries allowed after the first attempt
    /// (so at most `max_retries + 1` attempts are made).
    max_retries: u32,
    /// Starting value of the pause between attempts, in milliseconds.
    initial_delay_ms: u64,
    /// Factor the pause is multiplied by after each failed attempt.
    backoff_factor: f64,
    /// Cap applied to the pause each time it is multiplied by
    /// `backoff_factor`, in milliseconds.
    max_delay_ms: u64,
    /// Optional retry filter: when set, an error for which it returns `false`
    /// is returned to the caller without further attempts.
    retry_on: Option<Box<RetryPredicate>>,
}
impl RunnableRetry {
    /// Pause before the first retry unless overridden, in milliseconds.
    const DEFAULT_INITIAL_DELAY_MS: u64 = 500;
    /// Multiplier applied to the pause after each failure unless overridden.
    const DEFAULT_BACKOFF_FACTOR: f64 = 2.0;
    /// Cap on the grown pause unless overridden, in milliseconds.
    const DEFAULT_MAX_DELAY_MS: u64 = 30_000;

    /// Wraps `inner` so that it is retried up to `max_retries` times.
    ///
    /// Starts with a 500 ms initial delay, a backoff factor of 2.0, a
    /// 30 000 ms delay cap and no retry filter.
    pub fn new(inner: Arc<dyn Runnable>, max_retries: u32) -> Self {
        Self {
            inner,
            max_retries,
            retry_on: None,
            backoff_factor: Self::DEFAULT_BACKOFF_FACTOR,
            initial_delay_ms: Self::DEFAULT_INITIAL_DELAY_MS,
            max_delay_ms: Self::DEFAULT_MAX_DELAY_MS,
        }
    }

    /// Replaces the initial delay (milliseconds) and returns the updated value.
    pub fn with_initial_delay(self, delay_ms: u64) -> Self {
        Self {
            initial_delay_ms: delay_ms,
            ..self
        }
    }

    /// Replaces the backoff multiplier and returns the updated value.
    pub fn with_backoff_factor(self, factor: f64) -> Self {
        Self {
            backoff_factor: factor,
            ..self
        }
    }

    /// Replaces the delay cap (milliseconds) and returns the updated value.
    pub fn with_max_delay(self, max_delay_ms: u64) -> Self {
        Self {
            max_delay_ms,
            ..self
        }
    }

    /// Sets the initial delay and the delay cap (both in milliseconds) at once.
    pub fn with_wait(self, initial_ms: u64, max_ms: u64) -> Self {
        Self {
            initial_delay_ms: initial_ms,
            max_delay_ms: max_ms,
            ..self
        }
    }

    /// Installs a retry filter, replacing any previously installed one.
    ///
    /// `filter` receives each error produced by the wrapped runnable; the
    /// stored predicate is what the retry logic consults to decide whether
    /// another attempt is allowed.
    pub fn with_retry_on<F>(self, filter: F) -> Self
    where
        F: Fn(&CognisError) -> bool + Send + Sync + 'static,
    {
        Self {
            retry_on: Some(Box::new(filter)),
            ..self
        }
    }
}
#[async_trait]
impl Runnable for RunnableRetry {
fn name(&self) -> &str {
"RunnableRetry"
}
async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
let mut last_err = None;
let mut delay = self.initial_delay_ms;
for attempt in 0..=self.max_retries {
match self.inner.invoke(input.clone(), config).await {
Ok(result) => return Ok(result),
Err(e) => {
if let Some(ref filter) = self.retry_on {
if !filter(&e) {
return Err(e);
}
}
last_err = Some(e);
if attempt < self.max_retries {
tokio::time::sleep(Duration::from_millis(delay)).await;
delay = (delay as f64 * self.backoff_factor).min(self.max_delay_ms as f64)
as u64;
}
}
}
}
Err(last_err.unwrap())
}
async fn batch(
&self,
inputs: Vec<Value>,
config: Option<&RunnableConfig>,
) -> Result<Vec<Value>> {
let mut results = Vec::with_capacity(inputs.len());
for input in inputs {
results.push(self.invoke(input, config).await?);
}
Ok(results)
}
async fn stream(
&self,
input: Value,
config: Option<&RunnableConfig>,
) -> Result<RunnableStream> {
let mut last_err = None;
let mut delay = self.initial_delay_ms;
for attempt in 0..=self.max_retries {
match self.inner.stream(input.clone(), config).await {
Ok(stream) => return Ok(stream),
Err(e) => {
if let Some(ref filter) = self.retry_on {
if !filter(&e) {
return Err(e);
}
}
last_err = Some(e);
if attempt < self.max_retries {
tokio::time::sleep(Duration::from_millis(delay)).await;
delay = (delay as f64 * self.backoff_factor).min(self.max_delay_ms as f64)
as u64;
}
}
}
}
Err(last_err.unwrap())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test double that fails its first `fail_count` invocations and then
    /// echoes its input, counting every invocation.
    struct FailNTimes {
        fail_count: u32,
        attempts: AtomicU32,
    }

    impl FailNTimes {
        fn new(fail_count: u32) -> Self {
            Self {
                fail_count,
                attempts: AtomicU32::new(0),
            }
        }

        /// Number of times `invoke` has been called so far.
        fn attempts(&self) -> u32 {
            self.attempts.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Runnable for FailNTimes {
        fn name(&self) -> &str {
            "FailNTimes"
        }

        async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
            let attempt = self.attempts.fetch_add(1, Ordering::SeqCst);
            if attempt < self.fail_count {
                Err(CognisError::Other(format!("attempt {} failed", attempt)))
            } else {
                Ok(input)
            }
        }
    }

    /// Test double that always fails with `CognisError::Other`, counting
    /// every invocation.
    #[derive(Default)]
    struct AlwaysFails {
        attempts: AtomicU32,
    }

    impl AlwaysFails {
        /// Number of times `invoke` has been called so far.
        fn attempts(&self) -> u32 {
            self.attempts.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Runnable for AlwaysFails {
        fn name(&self) -> &str {
            "AlwaysFails"
        }

        async fn invoke(&self, _input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(CognisError::Other("always fails".into()))
        }
    }

    /// Test double that always fails with `CognisError::ToolException`,
    /// counting every invocation.
    #[derive(Default)]
    struct FailsWithToolError {
        attempts: AtomicU32,
    }

    impl FailsWithToolError {
        /// Number of times `invoke` has been called so far.
        fn attempts(&self) -> u32 {
            self.attempts.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Runnable for FailsWithToolError {
        fn name(&self) -> &str {
            "FailsWithToolError"
        }

        async fn invoke(&self, _input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(CognisError::ToolException("tool broke".into()))
        }
    }

    #[tokio::test]
    async fn test_retry_succeeds_first_try() {
        let inner = Arc::new(FailNTimes::new(0));
        let runnable: Arc<dyn Runnable> = inner.clone();
        let retry = RunnableRetry::new(runnable, 3);

        let result = retry.invoke(serde_json::json!(42), None).await;

        assert_eq!(result.unwrap(), serde_json::json!(42));
        // A first-try success must not trigger any retry.
        assert_eq!(inner.attempts(), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let inner = Arc::new(FailNTimes::new(2));
        let runnable: Arc<dyn Runnable> = inner.clone();
        let retry = RunnableRetry::new(runnable, 3)
            .with_initial_delay(1)
            .with_max_delay(10);

        let result = retry.invoke(serde_json::json!("hello"), None).await;

        assert_eq!(result.unwrap(), serde_json::json!("hello"));
        // Two failures followed by the successful third attempt.
        assert_eq!(inner.attempts(), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let inner = Arc::new(AlwaysFails::default());
        let runnable: Arc<dyn Runnable> = inner.clone();
        let retry = RunnableRetry::new(runnable, 2)
            .with_initial_delay(1)
            .with_max_delay(5);

        let result = retry.invoke(serde_json::json!("test"), None).await;

        let err = result.unwrap_err();
        assert!(err.to_string().contains("always fails"));
        // One initial attempt plus two retries.
        assert_eq!(inner.attempts(), 3);
    }

    #[tokio::test]
    async fn test_retry_with_filter() {
        // The filter rejects `Other` errors, so the first failure is final.
        let rejected = Arc::new(AlwaysFails::default());
        let rejected_runnable: Arc<dyn Runnable> = rejected.clone();
        let retry = RunnableRetry::new(rejected_runnable, 3)
            .with_initial_delay(1)
            .with_retry_on(|e| matches!(e, CognisError::ToolException(_)));

        let result = retry.invoke(serde_json::json!("test"), None).await;

        assert!(result.is_err());
        assert_eq!(rejected.attempts(), 1);

        // The filter accepts `ToolException`, so the retry budget is used up.
        let accepted = Arc::new(FailsWithToolError::default());
        let accepted_runnable: Arc<dyn Runnable> = accepted.clone();
        let retry2 = RunnableRetry::new(accepted_runnable, 1)
            .with_initial_delay(1)
            .with_retry_on(|e| matches!(e, CognisError::ToolException(_)));

        let result2 = retry2.invoke(serde_json::json!("test"), None).await;

        assert!(result2.unwrap_err().to_string().contains("tool broke"));
        assert_eq!(accepted.attempts(), 2);
    }

    #[tokio::test]
    async fn test_retry_backoff_delay() {
        let inner = Arc::new(AlwaysFails::default());
        let runnable: Arc<dyn Runnable> = inner.clone();
        let retry = RunnableRetry::new(runnable, 2)
            .with_initial_delay(50)
            .with_backoff_factor(2.0)
            .with_max_delay(30_000);

        let start = tokio::time::Instant::now();
        let _ = retry.invoke(serde_json::json!("test"), None).await;
        let elapsed = start.elapsed();

        // Two pauses of 50 ms and 100 ms; the threshold leaves 10 ms of slack.
        assert!(
            elapsed.as_millis() >= 140,
            "Expected at least ~150ms of backoff delay, got {}ms",
            elapsed.as_millis()
        );
        assert_eq!(inner.attempts(), 3);
    }
}