use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Configuration controlling how many times an operation is attempted and how long
/// to wait between attempts.
///
/// Both delay fields are stored on the wire as whole seconds via `duration_secs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Upper bound on attempts. `no_retry` sets this to 1, so it presumably counts the
    /// initial try as well as the retries.
    pub max_attempts: u32,

    /// Shape of the delay curve between attempts.
    pub backoff: BackoffStrategy,

    /// Delay used for attempt 1 by `delay_for_attempt` and the base for later attempts.
    /// Serialized as whole seconds (sub-second part is dropped).
    #[serde(with = "duration_secs")]
    pub initial_delay: Duration,

    /// Ceiling applied to every delay `delay_for_attempt` computes.
    /// Serialized as whole seconds (sub-second part is dropped).
    #[serde(with = "duration_secs")]
    pub max_delay: Duration,

    /// Flag requesting randomized jitter on delays. `delay_for_attempt` itself does not
    /// apply any jitter.
    pub jitter: bool,

    /// Error categories this policy is meant to retry; presumably consulted by callers.
    pub retryable_on: Vec<ErrorClass>,
}
impl RetryPolicy {
    /// Preset for transient I/O failures.
    ///
    /// Up to 3 attempts, exponential backoff starting at 1s and capped at 30s, jitter
    /// requested, retrying on [`ErrorClass::IoError`], [`ErrorClass::Timeout`] and
    /// [`ErrorClass::ConnectionReset`].
    pub fn io_default() -> Self {
        Self {
            max_attempts: 3,
            backoff: BackoffStrategy::Exponential,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            jitter: true,
            retryable_on: vec![
                ErrorClass::IoError,
                ErrorClass::Timeout,
                ErrorClass::ConnectionReset,
            ],
        }
    }

    /// Preset intended (per its name) for LLM calls.
    ///
    /// Up to 3 attempts, exponential backoff starting at 2s and capped at 60s, jitter
    /// requested, retrying on [`ErrorClass::RateLimit`], [`ErrorClass::Timeout`] and
    /// [`ErrorClass::ServerError`].
    pub fn llm_default() -> Self {
        Self {
            max_attempts: 3,
            backoff: BackoffStrategy::Exponential,
            initial_delay: Duration::from_secs(2),
            max_delay: Duration::from_secs(60),
            jitter: true,
            retryable_on: vec![
                ErrorClass::RateLimit,
                ErrorClass::Timeout,
                ErrorClass::ServerError,
            ],
        }
    }

    /// Preset that makes a single attempt, never waits, and lists no retryable errors.
    pub fn no_retry() -> Self {
        Self {
            max_attempts: 1,
            backoff: BackoffStrategy::Fixed,
            initial_delay: Duration::ZERO,
            max_delay: Duration::ZERO,
            jitter: false,
            retryable_on: vec![],
        }
    }

    /// Deterministic delay to wait before attempt number `attempt`.
    ///
    /// `attempt == 0` yields [`Duration::ZERO`]; attempt `n >= 1` yields:
    /// - `Fixed`: `initial_delay`
    /// - `Linear`: `initial_delay * n`
    /// - `Exponential`: `initial_delay * 2^(n-1)`
    ///
    /// The result is always capped at `max_delay`. All arithmetic saturates, so very
    /// large `attempt` values produce `max_delay` rather than panicking or wrapping.
    ///
    /// No jitter is applied here; use [`delay_with_jitter`](Self::delay_with_jitter)
    /// for that.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }
        let base = match self.backoff {
            BackoffStrategy::Fixed => self.initial_delay,
            // `Duration * u32` panics on overflow; saturate instead so the cap applies.
            BackoffStrategy::Linear => self.initial_delay.saturating_mul(attempt),
            BackoffStrategy::Exponential => {
                // Saturate in u32 directly: casting a u64 power of two with `as u32`
                // truncates 2^32 and above to 0, which would make the delay vanish.
                let factor = 2u32.saturating_pow(attempt - 1);
                self.initial_delay.saturating_mul(factor)
            }
        };
        base.min(self.max_delay)
    }

    /// Delay for `attempt` with "full jitter" applied when `self.jitter` is set.
    ///
    /// `unit` is a caller-supplied random sample, expected in `[0.0, 1.0]`. The caller
    /// supplies it so this type stays free of an RNG dependency and stays
    /// deterministic under test. When jitter is enabled, the result is
    /// `delay_for_attempt(attempt) * unit`. When jitter is disabled, `unit` is ignored.
    ///
    /// Out-of-range `unit` values are clamped to `[0.0, 1.0]`, and NaN is treated as
    /// `1.0`. The result therefore never exceeds `delay_for_attempt(attempt)`, and this
    /// method never panics.
    pub fn delay_with_jitter(&self, attempt: u32, unit: f64) -> Duration {
        let delay = self.delay_for_attempt(attempt);
        if !self.jitter {
            return delay;
        }
        // `Duration::mul_f64` panics on negative/non-finite factors; sanitize first.
        let unit = if unit.is_nan() { 1.0 } else { unit.clamp(0.0, 1.0) };
        delay.mul_f64(unit)
    }
}
/// Growth pattern of the delay between attempts, as interpreted by
/// `RetryPolicy::delay_for_attempt` (every result is then capped at `max_delay`).
///
/// Serialized in `snake_case`: `"fixed"`, `"linear"`, `"exponential"`.
///
/// The enum is fieldless, so it derives `Copy`/`PartialEq`/`Eq`. Callers can then
/// compare strategies and pass them by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BackoffStrategy {
    /// Every retry waits `initial_delay`.
    Fixed,
    /// Retry `n` waits `initial_delay * n`.
    Linear,
    /// Retry `n` waits `initial_delay * 2^(n-1)`.
    Exponential,
}
/// Coarse category of a failure; `RetryPolicy::retryable_on` holds a list of these.
///
/// Variants are serialized in `snake_case` (e.g. `io_error`, `rate_limit`).
/// Variant order is intentionally left unchanged, because some serde formats encode
/// variants by index.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorClass {
    /// Generic I/O failure.
    IoError,
    /// The operation did not complete in time.
    Timeout,
    /// The remote side signalled rate limiting.
    RateLimit,
    /// The remote side reported an internal failure.
    ServerError,
    /// The connection was reset mid-operation.
    ConnectionReset,
    /// Application-defined category identified by name.
    Custom(String),
}
mod duration_secs {
    //! Serde adapter (for `#[serde(with = "duration_secs")]`) that stores a
    //! [`Duration`] as an unsigned integer number of seconds.
    //!
    //! Serialization uses `Duration::as_secs`, which truncates the sub-second part.
    //! For example, `Duration::from_millis(1500)` is written as `1` and reads back
    //! as one second.

    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes the whole-second part of `d` as a `u64`.
    pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let whole_secs = d.as_secs();
        s.serialize_u64(whole_secs)
    }

    /// Reads a `u64` second count and converts it into a `Duration`.
    pub fn deserialize<'de, D>(d: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(d).map(Duration::from_secs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exponential backoff without jitter: doubles from `initial_delay`, zero for
    /// attempt 0, and clamps at `max_delay`.
    #[test]
    fn exponential_delay() {
        let policy = RetryPolicy {
            max_attempts: 5,
            backoff: BackoffStrategy::Exponential,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            jitter: false,
            retryable_on: Vec::new(),
        };

        // (attempt, expected delay in whole seconds)
        let cases = [(0, 0), (1, 1), (2, 2), (3, 4), (6, 30)];
        for (attempt, expected_secs) in cases {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_secs(expected_secs),
                "unexpected delay for attempt {}",
                attempt
            );
        }
    }
}