use std::time::Duration;
use nexo_config::types::llm::RetryConfig;
/// Reads the header named `header_name` and converts it to a delay in milliseconds.
///
/// Two value formats are understood:
/// - an unsigned integer, interpreted as seconds and multiplied by 1000
///   (saturating at `u64::MAX`);
/// - an RFC 2822 date, converted to the time remaining from now. A date in the
///   past counts as zero remaining, and the result is never below 1000 ms.
///
/// Returns `fallback_ms` when the header is absent, is not valid visible ASCII,
/// or matches neither format.
pub fn parse_retry_after_ms(
    headers: &reqwest::header::HeaderMap,
    header_name: &str,
    fallback_ms: u64,
) -> u64 {
    let raw = match headers.get(header_name).map(|value| value.to_str()) {
        Some(Ok(text)) => text,
        _ => return fallback_ms,
    };

    // Integer form: the value is a number of seconds.
    if let Ok(seconds) = raw.parse::<u64>() {
        return seconds.saturating_mul(1000);
    }

    // Date form: wait until the given instant, with a one-second floor.
    match chrono::DateTime::parse_from_rfc2822(raw) {
        Ok(deadline) => {
            let now = chrono::Utc::now();
            let remaining = deadline.with_timezone(&chrono::Utc) - now;
            // Negative durations (deadline already passed) clamp to zero before the cast.
            let remaining_ms = remaining.num_milliseconds().max(0) as u64;
            remaining_ms.max(1_000)
        }
        Err(_) => fallback_ms,
    }
}
/// Error type returned by the LLM call helpers in this file.
///
/// [`with_retry`] retries `RateLimit` and `ServerError` and returns every
/// other variant to the caller immediately.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
/// The request was rate limited; `retry_after_ms` is the delay, in
/// milliseconds, reported in the error message.
#[error("rate limited — retry after {retry_after_ms}ms")]
RateLimit { retry_after_ms: u64 },
/// A server-side failure carrying the status code and the response body text.
#[error("server error {status}: {body}")]
ServerError { status: u16, body: String },
/// The credential was rejected; `hint` is free-form text shown in the message.
#[error("credential invalid: {hint}")]
CredentialInvalid { hint: String },
/// Any other failure. `transparent` forwards display and source to the
/// wrapped `anyhow::Error`, and `#[from]` provides the `From` conversion.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Retry category assigned to a status code by [`classify`].
///
/// Derives the common value-type traits so callers can log it with `{:?}`,
/// compare it with `==`, and pass it around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryClass {
    /// Status 429.
    RateLimit,
    /// Any status in `500..=599`.
    Server,
    /// Every other status.
    Fatal,
}
/// Maps an HTTP status code to its [`RetryClass`].
///
/// 429 is `RateLimit`, anything in `500..=599` is `Server`, and every other
/// value is `Fatal`.
pub fn classify(status: u16) -> RetryClass {
    if status == 429 {
        return RetryClass::RateLimit;
    }
    if (500..=599).contains(&status) {
        RetryClass::Server
    } else {
        RetryClass::Fatal
    }
}
/// Picks the next backoff delay, in milliseconds, with jitter.
///
/// The upper bound is `last_ms * multiplier`, raised to at least `base_ms` and
/// capped at `max_ms`. When that bound does not exceed `base_ms` the function
/// returns `min(base_ms, max_ms)` without jitter. Otherwise it returns a value
/// in `base_ms..=upper`, using the sub-second nanoseconds of the system clock
/// as the jitter source (0 if the clock reads before the Unix epoch).
///
/// The growth step is computed in `f32`, so very large `last_ms` values are
/// only approximate; the float-to-integer cast saturates rather than wrapping.
fn jittered_backoff(base_ms: u64, last_ms: u64, multiplier: f32, max_ms: u64) -> u64 {
    let grown = ((last_ms as f32) * multiplier).max(base_ms as f32) as u64;
    let hi = grown.min(max_ms).max(base_ms);
    if hi <= base_ms {
        return base_ms.min(max_ms);
    }
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.subsec_nanos() as u64)
        .unwrap_or(0);
    // `hi - base_ms` reaches u64::MAX when the range spans the whole type; a
    // plain `+ 1` would overflow there (panic in debug, zero divisor in release).
    let span = (hi - base_ms).saturating_add(1);
    base_ms + (nanos % span)
}
/// Calls `f` repeatedly until it succeeds or a failure is handed back to the caller.
///
/// Retry rules:
/// - `LlmError::RateLimit`: sleeps for the larger of the error's
///   `retry_after_ms` and the current backoff, then retries. Returned to the
///   caller once the failure count reaches 5.
/// - `LlmError::ServerError`: sleeps for the current backoff, then retries.
///   Returned to the caller once the failure count reaches 3.
/// - Any other error is returned immediately.
///
/// The failure count is shared by both retryable kinds, so earlier rate-limit
/// failures count towards the server-error cap and vice versa. The backoff
/// starts at `config.initial_backoff_ms` and is recomputed after every sleep
/// by `jittered_backoff` from `config.backoff_multiplier` and
/// `config.max_backoff_ms`.
///
/// # Errors
///
/// Returns the last `LlmError` produced by `f` when it is not retryable or the
/// relevant cap has been reached.
pub async fn with_retry<T, F, Fut>(config: &RetryConfig, mut f: F) -> Result<T, LlmError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, LlmError>>,
{
    // Failure counts at which each retryable error kind stops being retried.
    const RATE_LIMIT_ATTEMPT_CAP: u32 = 5;
    const SERVER_ERROR_ATTEMPT_CAP: u32 = 3;

    let mut attempt = 0u32;
    let mut backoff_ms = config.initial_backoff_ms;
    loop {
        let err = match f().await {
            Ok(value) => return Ok(value),
            Err(err) => err,
        };

        // Decide how long to wait, or bail out. The error is taken by value so
        // a final `ServerError` moves its body instead of cloning it.
        let wait_ms = match err {
            LlmError::RateLimit { retry_after_ms } => {
                attempt += 1;
                if attempt >= RATE_LIMIT_ATTEMPT_CAP {
                    return Err(LlmError::RateLimit { retry_after_ms });
                }
                // Wait at least as long as the error asks, and never less than our own backoff.
                let wait = retry_after_ms.max(backoff_ms);
                tracing::warn!(attempt, wait_ms = wait, "LLM rate limited — retrying");
                wait
            }
            LlmError::ServerError { status, body } => {
                attempt += 1;
                if attempt >= SERVER_ERROR_ATTEMPT_CAP {
                    return Err(LlmError::ServerError { status, body });
                }
                tracing::warn!(attempt, status, "LLM server error — retrying");
                backoff_ms
            }
            other => return Err(other),
        };

        tokio::time::sleep(Duration::from_millis(wait_ms)).await;
        backoff_ms = jittered_backoff(
            config.initial_backoff_ms,
            backoff_ms,
            config.backoff_multiplier,
            config.max_backoff_ms,
        );
    }
}
#[cfg(test)]
mod tests {
    use super::jittered_backoff;

    /// Repeated draws with base 100, last 400 and multiplier 2.0 stay in `100..=800`.
    #[test]
    fn jitter_bounded_by_range() {
        let bounds = 100u64..=800;
        (0..50)
            .map(|_| jittered_backoff(100, 400, 2.0, 10_000))
            .for_each(|b| assert!(bounds.contains(&b), "got {b}"));
    }

    /// When the grown value exceeds the cap, the result never goes above the cap.
    #[test]
    fn jitter_respects_max() {
        let cap = 5_000u64;
        let b = jittered_backoff(100, 10_000, 2.0, cap);
        assert!((100..=cap).contains(&b), "got {b}");
    }
}