use anyhow::Result;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
/// Classification of a failed attempt, as produced by `RetryPolicy::classify`.
///
/// `with_retry` stops immediately on [`ErrorClass::Fatal`]; every other class is
/// retried until the policy's attempt budget is exhausted. The class of the previous
/// failure is handed back to the operation through `RetryContext::last_error_class`,
/// presumably so the operation can adapt its next attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorClass {
    /// A failure expected to go away on a plain retry.
    Transient,
    /// Name suggests the request input exceeded a size limit (TODO confirm with classifiers).
    PromptTooLong,
    /// Name suggests the response hit an output-token limit (TODO confirm with classifiers).
    MaxOutputTokens,
    /// Name suggests the upstream provider is temporarily overloaded (TODO confirm).
    ProviderOverloaded,
    /// Not retryable: `with_retry` returns the error without further attempts.
    Fatal,
}
pub struct RetryContext {
pub attempt: u32,
pub last_error_class: Option<ErrorClass>,
}
/// Configuration for `with_retry`: how many attempts to make, how long to wait
/// between them, and how to classify failures.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first. `with_retry` always makes at
    /// least one attempt, so `0` behaves like `1`.
    pub max_attempts: u32,
    /// Delay slept after the first failed attempt.
    pub base_delay: Duration,
    /// Upper bound on any single delay between attempts.
    pub max_delay: Duration,
    /// Growth factor for the delay: after zero-based attempt `n` fails, the delay is
    /// `base_delay * backoff_factor^n`, capped at `max_delay`.
    pub backoff_factor: f64,
    /// Maps an attempt's error to an [`ErrorClass`]; returning `ErrorClass::Fatal`
    /// stops retrying immediately.
    pub classify: Arc<dyn Fn(&anyhow::Error) -> ErrorClass + Send + Sync>,
}

// `classify` is an opaque closure, so `Debug` cannot be derived; it is omitted here.
impl std::fmt::Debug for RetryPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RetryPolicy")
            .field("max_attempts", &self.max_attempts)
            .field("base_delay", &self.base_delay)
            .field("max_delay", &self.max_delay)
            .field("backoff_factor", &self.backoff_factor)
            .finish_non_exhaustive()
    }
}
impl RetryPolicy {
pub fn default_transient() -> Self {
Self {
max_attempts: 3,
base_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(2),
backoff_factor: 2.0,
classify: Arc::new(|_| ErrorClass::Transient),
}
}
}
/// Runs the operation `f` until it succeeds, fails with an [`ErrorClass::Fatal`]
/// error, or `policy.max_attempts` attempts have been made.
///
/// Each call to `f` receives a [`RetryContext`] with the zero-based attempt index
/// and the classification of the previous failure (`None` on the first attempt).
/// After a retryable failure of zero-based attempt `n`, the task sleeps for
/// `min(base_delay * backoff_factor^n, max_delay)` before the next attempt. No sleep
/// happens after the final attempt.
///
/// At least one attempt is always made, even when `max_attempts` is `0`.
///
/// # Errors
///
/// Returns the error of the last attempt when it is classified `Fatal` or when the
/// attempt budget is exhausted. Errors from earlier attempts are discarded.
pub async fn with_retry<T, F, Fut>(policy: &RetryPolicy, mut f: F) -> Result<T>
where
    F: FnMut(RetryContext) -> Fut,
    Fut: Future<Output = Result<T>>,
{
    let mut attempt = 0u32;
    let mut last_error_class = None;
    loop {
        let ctx = RetryContext {
            attempt,
            last_error_class,
        };
        let err = match f(ctx).await {
            Ok(v) => return Ok(v),
            Err(e) => e,
        };
        let class = (policy.classify)(&err);
        // `attempt + 1` is the number of attempts made so far. It cannot overflow,
        // because `attempt` only grows while it is below `max_attempts - 1`.
        if class == ErrorClass::Fatal || attempt + 1 >= policy.max_attempts {
            return Err(err);
        }
        sleep(backoff_delay(policy, attempt)).await;
        last_error_class = Some(class);
        attempt += 1;
    }
}

/// Delay to sleep after zero-based attempt `attempt` has failed:
/// `base_delay * backoff_factor^attempt`, capped at `max_delay`.
///
/// Computed in floating-point seconds so sub-millisecond delays are not truncated.
/// NaN or non-positive results (e.g. from a negative `backoff_factor`) yield zero.
/// Results at or above the cap, including overflow to infinity, yield `max_delay`.
fn backoff_delay(policy: &RetryPolicy, attempt: u32) -> Duration {
    // Saturating u32 -> i32 conversion for `powi`.
    let exponent = attempt.min(i32::MAX as u32) as i32;
    let secs = policy.base_delay.as_secs_f64() * policy.backoff_factor.powi(exponent);
    let cap = policy.max_delay;
    if secs.is_nan() || secs <= 0.0 {
        // `Duration::from_secs_f64` would panic on NaN or negative input.
        Duration::ZERO
    } else if secs >= cap.as_secs_f64() {
        cap
    } else {
        // 0 < secs < cap, so the value is finite and representable.
        Duration::from_secs_f64(secs)
    }
}
impl std::fmt::Display for ErrorClass {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorClass::Transient => write!(f, "Transient"),
ErrorClass::PromptTooLong => write!(f, "PromptTooLong"),
ErrorClass::MaxOutputTokens => write!(f, "MaxOutputTokens"),
ErrorClass::ProviderOverloaded => write!(f, "ProviderOverloaded"),
ErrorClass::Fatal => write!(f, "Fatal"),
}
}
}