use std::sync::Arc;
use std::time::Duration;
use cognis_core::error::CognisError;
/// Action selected for when a retried operation ultimately fails.
///
/// NOTE(review): the handling of each variant lives outside this module; the
/// per-variant descriptions below are inferred from the names — confirm with
/// the caller that consumes `RetryConfig::on_failure`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OnFailure {
    /// Presumably swallow the final error and keep going (this is the value
    /// used by `RetryConfig::default()`).
    Continue,
    /// Presumably surface the final error to the caller.
    Error,
}
/// Decides which errors are eligible to be retried.
///
/// The default is [`RetryCondition::AnyError`].
#[derive(Default, Clone)]
pub enum RetryCondition {
    /// Every error is retryable.
    #[default]
    AnyError,
    /// An error is retryable when its `Display` text contains at least one of
    /// these substrings (case-sensitive). An empty list matches nothing.
    ErrorContains(Vec<String>),
    /// An error is retryable when this predicate returns `true`.
    ///
    /// Stored behind an `Arc` so cloning the condition shares the closure
    /// rather than requiring the closure itself to implement `Clone`.
    Custom(Arc<dyn Fn(&CognisError) -> bool + Send + Sync>),
}

impl std::fmt::Debug for RetryCondition {
    /// Formats like a derived `Debug` (and honours `{:#?}`), printing a
    /// placeholder for the closure, which has no `Debug` impl of its own.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RetryCondition::AnyError => f.write_str("AnyError"),
            RetryCondition::ErrorContains(v) => f.debug_tuple("ErrorContains").field(v).finish(),
            RetryCondition::Custom(_) => f
                .debug_tuple("Custom")
                .field(&format_args!("<fn>"))
                .finish(),
        }
    }
}
impl RetryCondition {
pub fn matches(&self, error: &CognisError) -> bool {
match self {
RetryCondition::AnyError => true,
RetryCondition::ErrorContains(substrings) => {
let msg = error.to_string();
substrings.iter().any(|s| msg.contains(s))
}
RetryCondition::Custom(f) => f(error),
}
}
}
/// Settings controlling how many times, how long between, and for which
/// errors an operation is retried.
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Retry budget. NOTE(review): whether this counts the initial attempt is
/// decided by the caller's retry loop — confirm there.
pub max_retries: usize,
/// Base delay in milliseconds; `calculate_delay(0)` starts from this value
/// (before capping and jitter).
pub initial_delay_ms: u64,
/// Factor applied once per attempt (`initial * multiplier^attempt`).
/// Exactly `0.0` is treated as `1.0`, i.e. a constant delay.
pub backoff_multiplier: f64,
/// Upper bound on the computed delay in milliseconds, applied before jitter.
pub max_delay_ms: u64,
/// When `true`, the capped delay is scaled by a clock-derived factor in
/// `[0.5, 1.0)`.
pub jitter: bool,
/// Action for when retrying does not succeed (interpreted by the caller).
pub on_failure: OnFailure,
/// Which errors are eligible to be retried.
pub retry_on: RetryCondition,
}
impl Default for RetryConfig {
    /// Two retries, a 1 s base delay doubling per attempt, a 60 s ceiling,
    /// jitter enabled, `on_failure` set to `OnFailure::Continue`, and every
    /// error treated as retryable.
    fn default() -> Self {
        let retry_on = RetryCondition::default();
        Self {
            retry_on,
            on_failure: OnFailure::Continue,
            jitter: true,
            max_delay_ms: 60_000,
            backoff_multiplier: 2.0,
            initial_delay_ms: 1000,
            max_retries: 2,
        }
    }
}
impl RetryConfig {
    /// Creates a config with the given retry budget; every other field takes
    /// its `Default` value.
    pub fn new(max_retries: usize) -> Self {
        Self {
            max_retries,
            ..Default::default()
        }
    }

    /// Computes the wait before retry number `attempt` (0-based).
    ///
    /// The delay is `initial_delay_ms * backoff_multiplier^attempt`, capped at
    /// `max_delay_ms`. A multiplier of exactly `0.0` is treated as `1.0`
    /// (constant delay). With `jitter` enabled the capped value is scaled by a
    /// factor in `[0.5, 1.0)` derived from the wall clock's sub-second
    /// nanoseconds — cheap, but not cryptographically random. The result is
    /// truncated to whole milliseconds and is never shorter than 1 ms.
    pub fn calculate_delay(&self, attempt: usize) -> Duration {
        let effective_multiplier = if self.backoff_multiplier == 0.0 {
            1.0
        } else {
            self.backoff_multiplier
        };
        // A bare `attempt as i32` wraps for attempt > i32::MAX (usize::MAX
        // becomes -1), flipping a huge exponent negative and collapsing the
        // delay instead of reaching the cap. Clamp first so the cast is
        // lossless and the exponent saturates.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let base_ms = self.initial_delay_ms as f64 * effective_multiplier.powi(exponent);
        let capped_ms = base_ms.min(self.max_delay_ms as f64);
        let final_ms = if self.jitter {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .subsec_nanos();
            // nanos % 1e6 ∈ [0, 1e6) → factor ∈ [0.5, 1.0)
            let jitter_factor = 0.5 + (nanos as f64 % 1_000_000.0) / 2_000_000.0;
            capped_ms * jitter_factor
        } else {
            capped_ms
        };
        // Floor at 1 ms; the `as u64` cast saturates on overflow.
        Duration::from_millis(final_ms.max(1.0) as u64)
    }
}
pub fn should_retry(error: &CognisError, condition: &RetryCondition) -> bool {
condition.matches(error)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Jitter-free config so delay assertions are deterministic.
    fn fixed(initial_delay_ms: u64, backoff_multiplier: f64, max_delay_ms: u64) -> RetryConfig {
        RetryConfig {
            initial_delay_ms,
            backoff_multiplier,
            max_delay_ms,
            jitter: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            initial_delay_ms,
            backoff_multiplier,
            on_failure,
            retry_on,
            ..
        } = RetryConfig::default();
        assert_eq!(max_retries, 2);
        assert_eq!(initial_delay_ms, 1000);
        assert_eq!(backoff_multiplier, 2.0);
        assert!(matches!(on_failure, OnFailure::Continue));
        assert!(matches!(retry_on, RetryCondition::AnyError));
    }

    #[test]
    fn test_calculate_delay_exponential() {
        let config = fixed(100, 2.0, 10_000);
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(config.calculate_delay(attempt).as_millis(), expected_ms);
        }
    }

    #[test]
    fn test_calculate_delay_capped() {
        let delay = fixed(5000, 3.0, 10_000).calculate_delay(1);
        assert_eq!(delay.as_millis(), 10_000);
    }

    #[test]
    fn test_calculate_delay_zero_multiplier() {
        let config = fixed(500, 0.0, 60_000);
        for attempt in [0, 1, 5] {
            assert_eq!(config.calculate_delay(attempt).as_millis(), 500);
        }
    }

    #[test]
    fn test_calculate_delay_with_jitter() {
        let config = RetryConfig {
            jitter: true,
            ..fixed(1000, 1.0, 60_000)
        };
        let millis = config.calculate_delay(0).as_millis();
        assert!((500..=1000).contains(&millis));
    }

    #[test]
    fn test_should_retry_any_error() {
        let cond = RetryCondition::AnyError;
        let errors = [
            CognisError::Other("timeout".into()),
            CognisError::ToolException("bad".into()),
        ];
        for err in &errors {
            assert!(should_retry(err, &cond));
        }
    }

    #[test]
    fn test_should_retry_error_contains() {
        let cond = RetryCondition::ErrorContains(vec!["timeout".into()]);
        let hit = CognisError::Other("connection timeout".into());
        let miss = CognisError::Other("bad input".into());
        assert!(should_retry(&hit, &cond));
        assert!(!should_retry(&miss, &cond));
    }

    #[test]
    fn test_should_retry_custom() {
        let cond = RetryCondition::Custom(Arc::new(|e| matches!(e, CognisError::HttpError { .. })));
        let http = CognisError::HttpError {
            status: 500,
            body: "err".into(),
        };
        let other = CognisError::Other("nope".into());
        assert!(should_retry(&http, &cond));
        assert!(!should_retry(&other, &cond));
    }

    #[test]
    fn test_retry_condition_clone() {
        let original = RetryCondition::ErrorContains(vec!["test".into()]);
        let duplicate = original.clone();
        assert!(matches!(duplicate, RetryCondition::ErrorContains(_)));
    }
}