use std::time::Duration;
use crate::config::validate::parse_duration;
use crate::{ConfigError, SondaError};
/// User-facing retry settings, typically deserialized from a config file
/// (serde derives are gated behind the `config` feature).
///
/// Backoff values are kept as human-readable duration strings (e.g. `"100ms"`,
/// `"5s"`); they are parsed and validated by `RetryPolicy::from_config`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "config", derive(serde::Serialize, serde::Deserialize))]
pub struct RetryConfig {
    // Number of retries performed after the initial attempt (see
    // RetryPolicy::execute). Must be >= 1; validated in from_config.
    pub max_attempts: u32,
    // Backoff before the first retry, as a duration string understood by
    // parse_duration.
    pub initial_backoff: String,
    // Cap on any single backoff, as a duration string; must parse to a
    // duration >= initial_backoff (validated in from_config).
    pub max_backoff: String,
}
/// Validated, ready-to-use retry policy with parsed durations.
///
/// Built from a `RetryConfig` via `RetryPolicy::from_config`; fields are
/// private so a policy can only exist in a validated state.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    // Retry budget after the initial attempt; >= 1 by construction.
    max_attempts: u32,
    // Base delay for the exponential backoff schedule.
    initial_backoff: Duration,
    // Upper bound applied to every computed backoff; >= initial_backoff.
    max_backoff: Duration,
}
impl RetryPolicy {
    /// Builds a validated policy from raw configuration.
    ///
    /// # Errors
    ///
    /// Returns a config error when `max_attempts` is zero, when either
    /// backoff string fails to parse, or when the parsed `max_backoff` is
    /// smaller than the parsed `initial_backoff`.
    pub fn from_config(config: &RetryConfig) -> Result<Self, SondaError> {
        if config.max_attempts == 0 {
            return Err(SondaError::Config(ConfigError::invalid(
                "retry max_attempts must be at least 1",
            )));
        }
        // Parse initial_backoff first so its parse error wins when both
        // strings are invalid.
        let initial = parse_duration(&config.initial_backoff)?;
        let max = parse_duration(&config.max_backoff)?;
        if initial > max {
            return Err(SondaError::Config(ConfigError::invalid(format!(
                "retry max_backoff ({}) must be >= initial_backoff ({})",
                config.max_backoff, config.initial_backoff
            ))));
        }
        Ok(Self {
            max_attempts: config.max_attempts,
            initial_backoff: initial,
            max_backoff: max,
        })
    }

    /// Runs `operation` once, then retries up to `max_attempts` more times
    /// while `classify` reports the last error as retryable, sleeping a
    /// jittered exponential backoff between calls.
    ///
    /// Total calls to `operation` are therefore `1 + max_attempts` in the
    /// worst case. Returns `Ok(())` on the first success, or the last
    /// observed error once a non-retryable error is seen or the budget is
    /// exhausted.
    pub fn execute<F, C>(&self, mut operation: F, classify: C) -> Result<(), SondaError>
    where
        F: FnMut() -> Result<(), SondaError>,
        C: Fn(&SondaError) -> bool,
    {
        // One unconditional attempt up front; retries only happen on failure.
        let mut last_error = match operation() {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };
        let mut attempt = 0;
        while attempt < self.max_attempts {
            // Bail out immediately on errors the caller deems permanent.
            if !classify(&last_error) {
                return Err(last_error);
            }
            let delay = self.jittered_backoff(attempt);
            eprintln!(
                "sonda: retry {}/{} after {}ms (error: {})",
                attempt + 1,
                self.max_attempts,
                delay.as_millis(),
                last_error,
            );
            std::thread::sleep(delay);
            if let Err(err) = operation() {
                last_error = err;
            } else {
                return Ok(());
            }
            attempt += 1;
        }
        eprintln!(
            "sonda: all {} retries exhausted (last error: {})",
            self.max_attempts, last_error,
        );
        Err(last_error)
    }

    /// Computes the sleep for retry number `attempt` (0-based): full jitter
    /// in `[0, min(initial_backoff * 2^attempt, max_backoff)]`.
    ///
    /// The jitter is derived deterministically from the attempt number and
    /// the current thread's name, so the same attempt on the same thread
    /// always yields the same delay.
    fn jittered_backoff(&self, attempt: u32) -> Duration {
        // 2^attempt, saturating once the shift would overflow u32.
        let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
        let uncapped = self.initial_backoff.saturating_mul(factor);
        let ceiling = uncapped.min(self.max_backoff);
        let ceiling_nanos = ceiling.as_nanos() as u64;
        if ceiling_nanos == 0 {
            return Duration::ZERO;
        }
        // FNV-1a over the thread name gives different schedules to
        // differently named worker threads.
        let name_hash = std::thread::current()
            .name()
            .unwrap_or("")
            .bytes()
            .fold(0xcbf2_9ce4_8422_2325u64, |acc, byte| {
                (acc ^ u64::from(byte)).wrapping_mul(0x0100_0000_01b3)
            });
        // wrapping_add is commutative, so this matches
        // (attempt * K).wrapping_add(name_hash).
        let seed = name_hash.wrapping_add((attempt as u64).wrapping_mul(0x517c_c1b7_2722_0a95));
        let mixed = crate::util::splitmix64(seed);
        // ceiling_nanos >= 1 here, so the modulus never wraps to zero in
        // practice; range is inclusive of the ceiling.
        Duration::from_nanos(mixed % (ceiling_nanos + 1))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RetryConfig` validation, jittered backoff bounds, and
    //! the `execute` retry loop (which performs 1 initial call plus up to
    //! `max_attempts` retries).
    use super::*;

    #[test]
    fn from_config_with_valid_values_succeeds() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: "100ms".to_string(),
            max_backoff: "5s".to_string(),
        };
        let policy = RetryPolicy::from_config(&config).expect("should succeed");
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(5));
    }

    #[test]
    fn from_config_with_equal_backoffs_succeeds() {
        // Boundary case: max_backoff == initial_backoff must be accepted
        // (the check is strictly `max < initial`).
        let config = RetryConfig {
            max_attempts: 1,
            initial_backoff: "1s".to_string(),
            max_backoff: "1s".to_string(),
        };
        let policy = RetryPolicy::from_config(&config).expect("should succeed");
        assert_eq!(policy.initial_backoff, policy.max_backoff);
    }

    #[test]
    fn from_config_zero_attempts_returns_error() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_backoff: "100ms".to_string(),
            max_backoff: "5s".to_string(),
        };
        let err = RetryPolicy::from_config(&config).unwrap_err();
        let msg = err.to_string();
        // Pin the user-facing validation message, not just the error kind.
        assert!(
            msg.contains("max_attempts") && msg.contains("at least 1"),
            "expected validation message about max_attempts, got: {msg}"
        );
    }

    #[test]
    fn from_config_max_less_than_initial_returns_error() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: "5s".to_string(),
            max_backoff: "100ms".to_string(),
        };
        let err = RetryPolicy::from_config(&config).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("max_backoff") && msg.contains("initial_backoff"),
            "expected message about max_backoff >= initial_backoff, got: {msg}"
        );
    }

    #[test]
    fn from_config_invalid_initial_backoff_returns_error() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: "not-a-duration".to_string(),
            max_backoff: "5s".to_string(),
        };
        assert!(RetryPolicy::from_config(&config).is_err());
    }

    #[test]
    fn from_config_invalid_max_backoff_returns_error() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: "100ms".to_string(),
            max_backoff: "bad".to_string(),
        };
        assert!(RetryPolicy::from_config(&config).is_err());
    }

    #[cfg(feature = "config")]
    #[test]
    fn retry_config_deserializes_from_yaml() {
        // Backoffs stay raw strings at the config layer; parsing happens in
        // RetryPolicy::from_config.
        let yaml = r#"
max_attempts: 5
initial_backoff: 200ms
max_backoff: 10s
"#;
        let config: RetryConfig = serde_yaml_ng::from_str(yaml).expect("should deserialize");
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff, "200ms");
        assert_eq!(config.max_backoff, "10s");
    }

    #[cfg(feature = "config")]
    #[test]
    fn retry_config_round_trip_through_policy() {
        // End-to-end: YAML -> RetryConfig -> validated RetryPolicy.
        let yaml = r#"
max_attempts: 3
initial_backoff: 100ms
max_backoff: 5s
"#;
        let config: RetryConfig = serde_yaml_ng::from_str(yaml).expect("should deserialize");
        let policy = RetryPolicy::from_config(&config).expect("should validate");
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(5));
    }

    #[test]
    fn jittered_backoff_is_at_most_initial_for_attempt_zero() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
        };
        // Jitter is full-range: attempt 0 must never exceed initial_backoff.
        for _ in 0..100 {
            let backoff = policy.jittered_backoff(0);
            assert!(
                backoff <= Duration::from_millis(100),
                "attempt 0 backoff {} must be <= 100ms",
                backoff.as_millis()
            );
        }
    }

    #[test]
    fn jittered_backoff_capped_at_max_backoff() {
        // Attempt 10 would be initial * 1024 = ~102s uncapped; the cap must
        // hold it at max_backoff.
        let policy = RetryPolicy {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_millis(500),
        };
        for _ in 0..100 {
            let backoff = policy.jittered_backoff(10);
            assert!(
                backoff <= Duration::from_millis(500),
                "backoff {} must be <= 500ms max_backoff",
                backoff.as_millis()
            );
        }
    }

    #[test]
    fn jittered_backoff_with_zero_duration_returns_zero() {
        // Zero backoff short-circuits before the modulus (avoids `% 1` noise
        // and guarantees an exact zero sleep).
        let policy = RetryPolicy {
            max_attempts: 1,
            initial_backoff: Duration::ZERO,
            max_backoff: Duration::ZERO,
        };
        let backoff = policy.jittered_backoff(0);
        assert_eq!(backoff, Duration::ZERO);
    }

    #[test]
    fn execute_succeeds_on_first_attempt() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(1),
        };
        let mut calls = 0u32;
        let result = policy.execute(
            || {
                calls += 1;
                Ok(())
            },
            |_| true,
        );
        assert!(result.is_ok());
        assert_eq!(calls, 1, "should only call once on immediate success");
    }

    #[test]
    fn execute_retries_transient_error_then_succeeds() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(1),
        };
        let mut calls = 0u32;
        // Fail twice, then succeed: should stop retrying on first success.
        let result = policy.execute(
            || {
                calls += 1;
                if calls < 3 {
                    Err(SondaError::Sink(std::io::Error::new(
                        std::io::ErrorKind::ConnectionReset,
                        "transient",
                    )))
                } else {
                    Ok(())
                }
            },
            |_| true,
        );
        assert!(result.is_ok());
        assert_eq!(calls, 3, "should call 1 initial + 2 retries");
    }

    #[test]
    fn execute_exhausts_retries_returns_last_error() {
        // max_attempts counts retries, so total calls = 1 + max_attempts.
        let policy = RetryPolicy {
            max_attempts: 2,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(1),
        };
        let mut calls = 0u32;
        let result = policy.execute(
            || {
                calls += 1;
                Err(SondaError::Sink(std::io::Error::new(
                    std::io::ErrorKind::ConnectionRefused,
                    "always fails",
                )))
            },
            |_| true,
        );
        assert!(result.is_err());
        assert_eq!(calls, 3, "should call 1 initial + 2 retries");
    }

    #[test]
    fn execute_non_retryable_error_returns_immediately() {
        let policy = RetryPolicy {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(1),
        };
        let mut calls = 0u32;
        // Classifier rejects everything: no sleeping, no retrying.
        let result = policy.execute(
            || {
                calls += 1;
                Err(SondaError::Sink(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "permanent 4xx",
                )))
            },
            |_| false,
        );
        assert!(result.is_err());
        assert_eq!(calls, 1, "non-retryable error should not trigger retries");
    }

    #[test]
    fn execute_classifier_distinguishes_retryable_from_permanent() {
        let policy = RetryPolicy {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(1),
        };
        let mut calls = 0u32;
        // First error is retryable (ConnectionReset), second is not
        // (InvalidInput): classification happens before each retry.
        let result = policy.execute(
            || {
                calls += 1;
                if calls == 1 {
                    Err(SondaError::Sink(std::io::Error::new(
                        std::io::ErrorKind::ConnectionReset,
                        "transient",
                    )))
                } else {
                    Err(SondaError::Sink(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "permanent",
                    )))
                }
            },
            |err| {
                matches!(err, SondaError::Sink(ref io_err) if io_err.kind() == std::io::ErrorKind::ConnectionReset)
            },
        );
        assert!(result.is_err());
        assert_eq!(
            calls, 2,
            "should retry once (transient) then stop (permanent)"
        );
    }

    #[test]
    fn retry_policy_is_send_and_sync() {
        // Compile-time check that the policy can be shared across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<RetryPolicy>();
    }

    #[test]
    fn retry_config_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<RetryConfig>();
    }

    #[test]
    fn retry_policy_is_debuggable() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
        };
        let s = format!("{policy:?}");
        assert!(s.contains("RetryPolicy"));
    }

    #[test]
    fn retry_config_is_cloneable() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: "100ms".to_string(),
            max_backoff: "5s".to_string(),
        };
        let cloned = config.clone();
        assert_eq!(cloned.max_attempts, 3);
        assert_eq!(cloned.initial_backoff, "100ms");
        assert_eq!(cloned.max_backoff, "5s");
    }
}