use std::time::Duration;
use caliban_provider::Error as ProviderError;
use tokio_util::sync::CancellationToken;
/// Tunable parameters that control how a failed provider call is retried.
///
/// The fields are consumed by [`compute_backoff`], [`sleep_for`] and
/// [`with_retry`].
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Upper bound of the `1..=max_attempts` attempt loop in [`with_retry`],
    /// so the first call counts as attempt 1.
    pub max_attempts: u32,
    /// Base delay; this is the nominal backoff after the first failed attempt.
    pub initial_backoff: Duration,
    /// Factor the nominal backoff is multiplied by for each further attempt.
    pub backoff_multiplier: f32,
    /// Ceiling applied to the nominal (pre-jitter) backoff.
    pub max_backoff: Duration,
    /// When `true`, the capped backoff is scaled down by a random factor.
    pub jitter: bool,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_backoff: Duration::from_millis(500),
backoff_multiplier: 2.0,
max_backoff: Duration::from_secs(30),
jitter: true,
}
}
}
impl RetryPolicy {
#[must_use]
pub fn no_retry() -> Self {
Self {
max_attempts: 1,
..Self::default()
}
}
}
/// Reports whether `e` belongs to the set of errors that are retried.
///
/// Retryable errors are rate limits, network failures, interrupted streams,
/// and server errors whose `status` lies in `500..=599`. Everything else
/// (including server errors with a status outside that range) is not.
#[must_use]
pub fn is_retryable(e: &ProviderError) -> bool {
    match e {
        ProviderError::RateLimit { .. }
        | ProviderError::Network(_)
        | ProviderError::StreamInterrupted(_)
        | ProviderError::ServerError {
            status: 500..=599, ..
        } => true,
        _ => false,
    }
}
/// Computes the delay to wait after the given (1-based) failed `attempt`.
///
/// The nominal delay is `initial_backoff * backoff_multiplier^(attempt - 1)`,
/// evaluated in whole milliseconds and then capped at `max_backoff`. An
/// `attempt` of 0 is treated like 1 (exponent 0).
///
/// When `policy.jitter` is set, the capped delay is multiplied by
/// `0.5 + rand::random::<f32>() * 0.5`, so the result never exceeds the
/// capped nominal delay.
#[must_use]
// The float <-> integer `as` casts are intentional: float-to-int `as`
// saturates (and maps NaN to 0), and the result is capped by `max_backoff`.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss
)]
pub fn compute_backoff(policy: &RetryPolicy, attempt: u32) -> Duration {
    // Exponent is zero-based; clamp instead of failing if it exceeds `i32`.
    let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(i32::MAX);
    let growth = f64::from(policy.backoff_multiplier.powi(exponent));
    let base_ms = policy.initial_backoff.as_millis() as f64;
    let capped = Duration::from_millis((base_ms * growth) as u64).min(policy.max_backoff);

    if !policy.jitter {
        return capped;
    }

    let scale = 0.5 + rand::random::<f32>() * 0.5;
    let scaled_ms = (capped.as_millis() as f64 * f64::from(scale)) as u64;
    Duration::from_millis(scaled_ms)
}
/// Chooses how long to wait before the next attempt.
///
/// A `RateLimit` error carrying `retry_after: Some(..)` wins: that duration is
/// returned verbatim (it is neither jittered nor clamped to
/// `policy.max_backoff`). Any other error falls back to
/// [`compute_backoff`] for the given `attempt`.
#[must_use]
pub fn sleep_for(policy: &RetryPolicy, error: &ProviderError, attempt: u32) -> Duration {
    match error {
        ProviderError::RateLimit {
            retry_after: Some(hint),
        } => *hint,
        _ => compute_backoff(policy, attempt),
    }
}
/// Runs `f` up to `policy.max_attempts` times, backing off between failures.
///
/// Each attempt calls `f()` and awaits the returned future. On success the
/// value is returned immediately. On failure the error is returned as-is when
/// it is not [`is_retryable`] or when this was the final attempt; otherwise a
/// warning is logged and the task sleeps for [`sleep_for`] before trying again.
///
/// `cancel` is checked before every attempt and raced against the backoff
/// sleep. An in-flight call to `f` is not interrupted by this function.
///
/// # Errors
///
/// * `ProviderError::Cancelled` if `cancel` fires before an attempt starts or
///   during a backoff sleep.
/// * The error produced by `f` if it is not retryable or the attempts are
///   used up.
/// * `ProviderError::Adapter` with the message `"retry exhausted"` if
///   `policy.max_attempts` is 0, in which case `f` is never called.
pub async fn with_retry<F, Fut, T>(
    policy: &RetryPolicy,
    cancel: &CancellationToken,
    mut f: F,
) -> std::result::Result<T, ProviderError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, ProviderError>>,
{
    for attempt in 1..=policy.max_attempts {
        if cancel.is_cancelled() {
            return Err(ProviderError::Cancelled);
        }

        let err = match f().await {
            Ok(value) => return Ok(value),
            Err(err) => err,
        };

        // Give up right away on permanent errors or once the budget is spent,
        // handing the caller the real failure rather than a synthetic one.
        if !is_retryable(&err) || attempt == policy.max_attempts {
            return Err(err);
        }

        let delay = sleep_for(policy, &err, attempt);
        tracing::warn!(
            attempt,
            backoff_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
            error = %err,
            "provider call failed; retrying"
        );

        // Wake early if the caller cancels while we are backing off.
        tokio::select! {
            () = tokio::time::sleep(delay) => {}
            () = cancel.cancelled() => {
                return Err(ProviderError::Cancelled);
            }
        }
    }

    // Every path through the final loop iteration returns, so this point is
    // reached only when the attempt range was empty (`max_attempts == 0`).
    Err(ProviderError::Adapter(
        Box::<dyn std::error::Error + Send + Sync>::from("retry exhausted"),
    ))
}