use std::sync::LazyLock;
use std::time::Duration;
use regex::Regex;
/// Lazily compiled pattern that recognises an HTTP 5xx status inside error text.
///
/// Two shapes match:
/// * a marker word (`status`, `status_code`, `code`, `error`, `http`,
///   `response`, `request`, `returned`, `returns`), optional whitespace, an
///   optional `:` / `=` / `-`, optional whitespace, then `5` plus two digits,
///   followed by a non-digit or the end of input;
/// * `5` plus two digits (at the start of input or after a non-digit),
///   whitespace, then one of `service`, `gateway`, `internal`, `bad`, `server`.
///
/// The pattern carries no case-insensitivity flag and every literal is
/// lowercase; `classify_error` applies it to the lowercased message.
/// The `expect` can only fire if the literal pattern itself fails to compile.
static STATUS_5XX_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(
r"(?x)
(?:
# prefix-anchored: status / code / error / http /
# response / request / returned, with optional
# `:`/`=`/`-`/whitespace between marker and number.
(?:status(?:_code)?|code|error|http|response|request|returned|returns)
\s*[:=\-]?\s*
5\d{2}
(?:\D|$)
)
|
(?:
# leading status + HTTP reason phrase (5xx Service / 5xx
# Gateway / 5xx Internal / 5xx Bad / 5xx Server).
(?:^|\D)
5\d{2}
\s+
(?:service|gateway|internal|bad|server)
)
",
)
.expect("static regex compiles")
});
/// Coarse category assigned to a failure message by `classify_error`.
///
/// `RecoveryPolicy::should_retry` only retries `Network` and `RateLimit`;
/// every other variant is returned to the caller immediately.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets it be used as a map/set key and in contexts that
/// require total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// The request exceeded the model's context window / token limit.
    ContextLength,
    /// The provider throttled the request (429, "rate limit", "overloaded", ...).
    RateLimit,
    /// Transport-level or 5xx-style failure (connection, TLS, timeout, HTML error page, ...).
    Network,
    /// Credential or billing/quota failure (401/403, invalid API key, ...).
    Auth,
    /// Anything the classifier did not recognise.
    Other,
}
/// Retry budget and backoff settings consulted by `run_with_retry`.
///
/// Both fields are private; the `Default` impl supplies 5 retries with a
/// one-second base.
#[derive(Debug, Clone)]
pub struct RecoveryPolicy {
// Retry ceiling: `should_retry` returns false once `attempts` reaches this value.
max_retries: usize,
// Base delay; `backoff_duration` multiplies it by 2^min(attempts, 6) and adds jitter.
backoff_base: Duration,
}
impl Default for RecoveryPolicy {
    /// Stock policy: up to five retries, starting from a one-second backoff base.
    fn default() -> Self {
        let max_retries = 5;
        let backoff_base = Duration::from_secs(1);
        Self {
            max_retries,
            backoff_base,
        }
    }
}
impl RecoveryPolicy {
    /// Returns the configured retry ceiling.
    pub fn max_retries(&self) -> usize {
        self.max_retries
    }

    /// Decides whether another attempt is allowed.
    ///
    /// `attempts` is the number of retries already performed. Returns `true`
    /// only while `attempts` is below the ceiling and `kind` is
    /// `ErrorKind::Network` or `ErrorKind::RateLimit`.
    pub fn should_retry(&self, attempts: usize, kind: ErrorKind) -> bool {
        attempts < self.max_retries
            && matches!(kind, ErrorKind::Network | ErrorKind::RateLimit)
    }

    /// Computes the exponential backoff delay for the given retry count.
    ///
    /// The delay is `backoff_base * 2^min(attempts, 6)` milliseconds plus a
    /// jitter drawn from `[0, max(delay / 4, 1))`. All arithmetic saturates
    /// instead of wrapping.
    pub fn backoff_duration(&self, attempts: usize) -> Duration {
        // Exponent is capped at 6 so the multiplier never exceeds 64.
        let exp = 1u64 << attempts.min(6);
        // `as_millis` yields a u128; clamp before narrowing so an oversized
        // base saturates at u64::MAX instead of silently dropping high bits.
        let base = self.backoff_base.as_millis().min(u128::from(u64::MAX)) as u64;
        let ms = base.saturating_mul(exp);
        // `.max(1)` keeps the modulus non-zero when the delay is under 4 ms.
        let jitter = pseudo_random(attempts as u64) % (ms / 4).max(1);
        Duration::from_millis(ms.saturating_add(jitter))
    }

    /// Like [`backoff_duration`](Self::backoff_duration), but honours a
    /// server-provided retry hint found in `error_msg`.
    ///
    /// When `retry_after_from_error_msg` extracts a delay, the longer of that
    /// delay and the computed backoff is used, capped at 300 seconds. Without
    /// a hint the computed backoff is returned unchanged (no cap applied).
    pub fn backoff_duration_for_msg(&self, attempts: usize, error_msg: &str) -> Duration {
        const CAP: Duration = Duration::from_secs(300);
        let computed = self.backoff_duration(attempts);
        match retry_after_from_error_msg(error_msg) {
            Some(server_wants) => server_wants.max(computed).min(CAP),
            None => computed,
        }
    }

    /// Test-only constructor that sets both private fields directly.
    #[cfg(test)]
    pub(crate) fn with_backoff(max_retries: usize, backoff_base: Duration) -> Self {
        Self {
            max_retries,
            backoff_base,
        }
    }
}
/// Drives `attempt` until it succeeds or `policy` refuses another retry.
///
/// Each failure is rendered with `Display`, classified via `classify_error`,
/// and checked against `policy.should_retry`. When a retry is permitted a
/// warning is logged (tagged with `label`), the task sleeps for
/// `policy.backoff_duration_for_msg`, and the closure is invoked again.
///
/// Returns the first `Ok` value, or the error of the attempt on which the
/// policy declined to retry.
pub async fn run_with_retry<T, E, F, Fut>(
    policy: &RecoveryPolicy,
    label: &str,
    mut attempt: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    // Number of retries already performed (0 while the first call is in flight).
    let mut attempts = 0;
    loop {
        let failure = match attempt().await {
            Ok(value) => return Ok(value),
            Err(failure) => failure,
        };

        let text = failure.to_string();
        let category = classify_error(&text);
        if !policy.should_retry(attempts, category) {
            return Err(failure);
        }

        let pause = policy.backoff_duration_for_msg(attempts, &text);
        tracing::warn!(
            op = label,
            attempt = attempts + 1,
            max = policy.max_retries(),
            delay_ms = pause.as_millis() as u64,
            kind = ?category,
            error = %text,
            "retrying after transient failure",
        );
        tokio::time::sleep(pause).await;
        attempts += 1;
    }
}
/// Extracts a server-requested retry delay from an error message, if present.
///
/// Labels are tried in this order, each matched ASCII-case-insensitively at
/// its first occurrence:
/// 1. `retry-after-ms` — integer milliseconds,
/// 2. `retry-after` — integer seconds,
/// 3. `retry_after` — integer seconds,
/// 4. `retry-after` followed by a date, via `parse_http_date_retry_after`.
///
/// Returns `None` when no label yields a value.
pub(crate) fn retry_after_from_error_msg(msg: &str) -> Option<Duration> {
    /// Finds the first occurrence of `label` and parses the unsigned integer
    /// that follows it, skipping any run of `:`, space, tab and `"`.
    /// At most the first 11 digits are read.
    fn parse_after_label(msg: &str, label: &str) -> Option<u64> {
        let needle = label.as_bytes();
        let start = msg
            .as_bytes()
            .windows(needle.len())
            .position(|window| window.eq_ignore_ascii_case(needle))?;

        let after = start + needle.len();
        if !msg.is_char_boundary(after) {
            return None;
        }

        let digits: String = msg[after..]
            .trim_start_matches([':', ' ', '\t', '"'])
            .trim_start()
            .chars()
            .take_while(|c| c.is_ascii_digit())
            .take(11)
            .collect();

        if digits.is_empty() {
            None
        } else {
            digits.parse().ok()
        }
    }

    parse_after_label(msg, "retry-after-ms")
        .map(Duration::from_millis)
        .or_else(|| parse_after_label(msg, "retry-after").map(Duration::from_secs))
        .or_else(|| parse_after_label(msg, "retry_after").map(Duration::from_secs))
        .or_else(|| parse_http_date_retry_after(msg))
}
/// Parses a date-valued `retry-after` entry and returns the time left until it.
///
/// The label is located ASCII-case-insensitively at its first occurrence. The
/// value runs up to the next newline, carriage return or `"`, and is parsed
/// first as RFC 2822 and then with the `%a %b %e %H:%M:%S %Y` layout
/// (interpreted as UTC). A date in the past yields a zero duration; an absent
/// label or unparsable value yields `None`.
fn parse_http_date_retry_after(msg: &str) -> Option<Duration> {
    const LABEL: &str = "retry-after";

    let start = msg
        .as_bytes()
        .windows(LABEL.len())
        .position(|window| window.eq_ignore_ascii_case(LABEL.as_bytes()))?;
    let after = start + LABEL.len();
    if !msg.is_char_boundary(after) {
        return None;
    }

    // Skip separators, then keep everything up to the end of the line / quote.
    let rest = msg[after..].trim_start_matches([':', ' ', '\t', '"']);
    let end = rest
        .find(|c: char| matches!(c, '\n' | '\r' | '"'))
        .unwrap_or(rest.len());
    let value = rest[..end].trim();
    if value.is_empty() {
        return None;
    }

    let target = match chrono::DateTime::parse_from_rfc2822(value) {
        Ok(stamp) => stamp,
        Err(_) => chrono::NaiveDateTime::parse_from_str(value, "%a %b %e %H:%M:%S %Y")
            .ok()?
            .and_utc()
            .fixed_offset(),
    };

    let remaining = target - chrono::Utc::now().fixed_offset();
    // Negative deltas (date already passed) clamp to zero before the cast.
    Some(Duration::from_secs(remaining.num_seconds().max(0) as u64))
}
/// Produces a cheap, non-cryptographic 64-bit value used for backoff jitter.
///
/// The seed combines the sub-second nanoseconds of the wall clock, the
/// caller's `salt`, and a process-wide call counter, then runs it through
/// three xor-shift / multiply rounds. If the clock reads before the Unix
/// epoch the nanosecond component is taken as zero.
fn pseudo_random(salt: u64) -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    const COUNTER_STRIDE: u64 = 0xA240_2A1F_1CE4_E5B9;
    const SEED_OFFSET: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_FIRST: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_SECOND: u64 = 0x94D0_49BB_1331_11EB;

    // Incremented on every call so back-to-back calls differ even when the
    // clock has not advanced.
    static CALLS: AtomicU64 = AtomicU64::new(0);

    let call_no = CALLS.fetch_add(1, Ordering::Relaxed);
    let clock_nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.subsec_nanos() as u64,
        Err(_) => 0,
    };

    let seed = clock_nanos
        .wrapping_add(salt)
        .wrapping_add(call_no.wrapping_mul(COUNTER_STRIDE))
        .wrapping_add(SEED_OFFSET);
    let round_one = (seed ^ (seed >> 30)).wrapping_mul(MIX_FIRST);
    let round_two = (round_one ^ (round_one >> 27)).wrapping_mul(MIX_SECOND);
    round_two ^ (round_two >> 31)
}
/// Maps a raw error message to an [`ErrorKind`] by substring / pattern matching.
///
/// The message is lowercased once and then tested in a fixed order; the first
/// group that matches wins:
/// 1. auth markers (401/403, key/authentication wording, quota/billing codes),
/// 2. rate-limit markers (429, "rate limit", resource exhaustion, "overloaded"),
/// 3. a 5xx status recognised by `STATUS_5XX_RE`,
/// 4. context-length wording,
/// 5. HTML/proxy error pages and transport failures,
/// 6. otherwise `ErrorKind::Other`.
pub fn classify_error(msg: &str) -> ErrorKind {
    /// True when `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    const AUTH_MARKERS: &[&str] = &[
        " 401 ",
        " 403 ",
        "error 401",
        "error 403",
        "unauthorized",
        "invalid api key",
        "authentication failed",
        "insufficient_quota",
        "billing_not_active",
        "billing_hard_limit_reached",
    ];
    const RATE_LIMIT_MARKERS: &[&str] = &[
        "rate limit",
        "too many requests",
        " 429 ",
        "error 429",
        "resource_exhausted",
        "resource has been exhausted",
        "overloaded",
    ];
    const CONTEXT_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "maximum context length",
        "reduce the length of the messages",
        "request too large",
        "prompt is too long",
        "input is too long",
        "input token count exceeds",
        "tokens exceed",
        "exceeds the model's context",
        "max_tokens is too large",
        "too many tokens",
        "range of input length",
        "messages.length too large",
    ];
    const NETWORK_MARKERS: &[&str] = &[
        "<!doctype html",
        "<html",
        "bad gateway",
        "service unavailable",
        "gateway timeout",
        "cloudflare",
        "connection refused",
        "connection reset",
        "broken pipe",
        "dns error",
        "tls",
        "ssl",
        "timed out",
        "request timeout",
        "server error",
        "error sending request",
        "connect error",
        "tcp connect",
        "error decoding response body",
        "invalid response body",
        "decode error",
    ];

    let lower = msg.to_lowercase();

    let is_auth = contains_any(&lower, AUTH_MARKERS)
        || lower.starts_with("401 ")
        || lower.starts_with("403 ");
    if is_auth {
        return ErrorKind::Auth;
    }

    let is_rate_limited =
        contains_any(&lower, RATE_LIMIT_MARKERS) || lower.starts_with("429 ");
    if is_rate_limited {
        return ErrorKind::RateLimit;
    }

    // The 5xx check runs before the context-length wording, as in the
    // original ordering.
    if STATUS_5XX_RE.is_match(&lower) {
        return ErrorKind::Network;
    }

    if contains_any(&lower, CONTEXT_MARKERS) {
        return ErrorKind::ContextLength;
    }

    if contains_any(&lower, NETWORK_MARKERS) {
        return ErrorKind::Network;
    }

    ErrorKind::Other
}
/// Renders a raw provider error as a multi-line message for the user.
///
/// The output is a headline chosen from the message's [`ErrorKind`], an
/// " (after N attempt(s))" suffix when `attempts` is greater than 1, a hint
/// line, and the original `msg` as the cause line. Network errors that
/// mention "error decoding response body" get their own headline and hint.
#[allow(dead_code)]
pub fn user_facing_error(msg: &str, attempts: usize) -> String {
    let (headline, hint) = match classify_error(msg) {
        ErrorKind::Auth => (
            "authentication failed talking to the LLM provider",
            "check your API key env var (e.g. OPENROUTER_API_KEY) and provider config",
        ),
        ErrorKind::RateLimit => (
            "provider rate-limited the request",
            "wait a moment and retry, or switch to a different model via /model",
        ),
        ErrorKind::ContextLength => (
            "conversation exceeds the model's context window",
            "run /compress to summarize older turns and try again",
        ),
        ErrorKind::Network => {
            let truncated_body = msg
                .to_lowercase()
                .contains("error decoding response body");
            if truncated_body {
                (
                    "lost the response stream from the provider (truncated or malformed body)",
                    "usually transient — retry. If it persists the provider may be having issues or returning non-JSON (HTML error pages, plaintext)",
                )
            } else {
                (
                    "network error reaching the LLM provider",
                    "check connectivity / firewall / proxy; the request will retry automatically",
                )
            }
        }
        ErrorKind::Other => (
            "the LLM provider returned an error we didn't recognize",
            "see the cause below; consider /model to try a different provider",
        ),
    };

    // A single attempt is the normal case, so the note is only added beyond it.
    let attempts_note = match attempts {
        0 | 1 => String::new(),
        _ => format!(" (after {} attempt(s))", attempts),
    };

    format!(
        "{}{}\n ↳ hint: {}\n ↳ cause: {}",
        headline, attempts_note, hint, msg
    )
}
// Unit tests for the retry policy, the retry loop, retry-after parsing,
// error classification and user-facing error rendering.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
// Default policy: 5 retries, Network retried at attempts 0 and 4 but not 5, Auth never.
#[test]
fn default_budget_retries_transient_failures_up_to_five_times() {
let p = RecoveryPolicy::default();
assert_eq!(p.max_retries(), 5);
assert!(p.should_retry(0, ErrorKind::Network));
assert!(p.should_retry(4, ErrorKind::Network));
assert!(!p.should_retry(5, ErrorKind::Network));
assert!(!p.should_retry(0, ErrorKind::Auth));
}
// A closure that succeeds on the first call is invoked exactly once.
#[tokio::test]
async fn run_with_retry_returns_first_success() {
let policy = RecoveryPolicy::default();
let calls = AtomicUsize::new(0);
let r: Result<u32, String> = run_with_retry(&policy, "t", || {
calls.fetch_add(1, Ordering::SeqCst);
async { Ok(7) }
})
.await;
assert_eq!(r.unwrap(), 7);
assert_eq!(calls.load(Ordering::SeqCst), 1, "no retry on success");
}
// An auth-classified error is returned after a single call.
#[tokio::test]
async fn run_with_retry_bails_immediately_on_non_retryable() {
let policy = RecoveryPolicy::default();
let calls = AtomicUsize::new(0);
let r: Result<u32, String> = run_with_retry(&policy, "t", || {
calls.fetch_add(1, Ordering::SeqCst);
async { Err("invalid api key".to_string()) }
})
.await;
assert!(r.is_err());
assert_eq!(
calls.load(Ordering::SeqCst),
1,
"auth error must not be retried"
);
}
// Two rate-limit failures followed by success: three calls in total.
// Uses a 1 ms backoff base to keep the sleeps short.
#[tokio::test]
async fn run_with_retry_retries_transient_then_succeeds() {
let policy = RecoveryPolicy::with_backoff(3, Duration::from_millis(1));
let calls = AtomicUsize::new(0);
let r: Result<u32, String> = run_with_retry(&policy, "t", || {
let n = calls.fetch_add(1, Ordering::SeqCst);
async move {
if n < 2 {
Err("rate limit exceeded".to_string())
} else {
Ok(42)
}
}
})
.await;
assert_eq!(r.unwrap(), 42);
assert_eq!(calls.load(Ordering::SeqCst), 3, "two retries then success");
}
// With max_retries = 2 a permanently failing closure runs 3 times
// (initial call + 2 retries) and the error is returned.
#[tokio::test]
async fn run_with_retry_exhausts_then_returns_last_error() {
let policy = RecoveryPolicy::with_backoff(2, Duration::from_millis(1));
let calls = AtomicUsize::new(0);
let r: Result<u32, String> = run_with_retry(&policy, "t", || {
calls.fetch_add(1, Ordering::SeqCst);
async { Err("rate limit exceeded".to_string()) }
})
.await;
assert!(r.is_err());
assert_eq!(calls.load(Ordering::SeqCst), 3);
}
// Connect / send failure messages classify as Network and are retryable.
#[test]
fn classify_connect_send_failures_as_network() {
for msg in [
"ProviderError: Http client error: error sending request for url (https://api.deepseek.com/v1/chat/completions)",
"error sending request for url (https://api.openai.com/v1/chat/completions)",
"reqwest::Error { kind: Connect, ... }: tcp connect error",
"Http client error: connect error",
] {
assert_eq!(
classify_error(msg),
ErrorKind::Network,
"connect/send failure must be retryable: {msg}"
);
}
let policy = RecoveryPolicy::default();
assert!(
policy.should_retry(0, classify_error("error sending request for url (x)")),
"the DeepSeek connect failure must be retried"
);
}
// Basic context-length phrasings.
#[test]
fn test_classify_context_length() {
assert_eq!(
classify_error("context_length_exceeded: prompt too long"),
ErrorKind::ContextLength
);
assert_eq!(
classify_error("reduce the length of the messages"),
ErrorKind::ContextLength
);
assert_eq!(
classify_error("request too large for model"),
ErrorKind::ContextLength
);
}
// Further context-length phrasings, including mixed-case input.
#[test]
fn test_classify_context_length_provider_variants() {
assert_eq!(
classify_error("prompt is too long: 250000 tokens > 200000 maximum"),
ErrorKind::ContextLength
);
assert_eq!(
classify_error(
"This model's maximum context length is 128000 tokens. However, your messages resulted in 130000 tokens."
),
ErrorKind::ContextLength
);
assert_eq!(
classify_error("input is too long for the requested model"),
ErrorKind::ContextLength
);
assert_eq!(
classify_error("The input token count exceeds the maximum number of tokens allowed"),
ErrorKind::ContextLength
);
assert_eq!(
classify_error("Total tokens exceed model's context window"),
ErrorKind::ContextLength
);
assert_eq!(
classify_error("the messages array exceeds the model's context length"),
ErrorKind::ContextLength
);
}
// HTML bodies (e.g. proxy error pages) classify as Network.
#[test]
fn test_classify_html_proxy_response_as_network() {
assert_eq!(
classify_error("<!DOCTYPE html><html><head><title>502 Bad Gateway</title>"),
ErrorKind::Network
);
assert_eq!(
classify_error("<html><body><h1>503 Service Unavailable</h1></body></html>"),
ErrorKind::Network
);
assert_eq!(
classify_error("ProviderError: <html><head><meta http-equiv=\"refresh\""),
ErrorKind::Network
);
}
// A date-valued Retry-After 30 s in the future parses to roughly 30 s
// (a 25..=35 window absorbs clock movement during the test).
#[test]
fn retry_after_http_date_parses() {
let future = chrono::Utc::now() + chrono::Duration::seconds(30);
let header = future.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
let msg = format!("429 Too Many Requests\nRetry-After: {}", header);
let parsed = retry_after_from_error_msg(&msg).expect("HTTP-date should parse");
let secs = parsed.as_secs();
assert!(
(25..=35).contains(&secs),
"expected ~30s, got {}s (header={})",
secs,
header
);
}
// A date in the past yields a zero duration rather than None.
#[test]
fn retry_after_http_date_in_past_clamps_to_zero() {
let msg = "Retry-After: Thu, 01 Jan 1970 00:00:00 GMT";
let parsed = retry_after_from_error_msg(msg).expect("past HTTP-date should parse");
assert_eq!(parsed, Duration::from_secs(0));
}
// Transport failures, decode failures and the various 5xx spellings all
// classify as Network.
#[test]
fn test_classify_network() {
assert_eq!(classify_error("connection refused"), ErrorKind::Network);
assert_eq!(
classify_error("connection reset by peer"),
ErrorKind::Network
);
assert_eq!(classify_error("request timed out"), ErrorKind::Network);
assert_eq!(
classify_error("503 service unavailable"),
ErrorKind::Network
);
assert_eq!(
classify_error(
"CompletionError: ProviderError: Http client error: error decoding response body"
),
ErrorKind::Network
);
assert_eq!(classify_error("decode error: EOF"), ErrorKind::Network);
assert_eq!(
classify_error("500 Internal Server Error"),
ErrorKind::Network
);
assert_eq!(classify_error("Http status: 500"), ErrorKind::Network);
assert_eq!(classify_error("status=502"), ErrorKind::Network);
assert_eq!(classify_error("status_code=503"), ErrorKind::Network);
assert_eq!(classify_error("code: 504"), ErrorKind::Network);
assert_eq!(
classify_error("CompletionError: error 500: backend hiccup"),
ErrorKind::Network
);
assert_eq!(
classify_error("received http 502 from upstream"),
ErrorKind::Network
);
}
// The rendered message carries the decode-specific headline, both labels
// and the raw cause text.
#[test]
fn user_facing_error_includes_cause() {
let raw = "CompletionError: ProviderError: Http client error: error decoding response body";
let pretty = user_facing_error(raw, 1);
assert!(pretty.contains("lost the response stream"));
assert!(pretty.contains("hint:"));
assert!(pretty.contains("cause:"));
assert!(pretty.contains(raw));
}
// Auth errors get the authentication headline and the API-key hint.
#[test]
fn user_facing_error_classifies_auth() {
let pretty = user_facing_error("401 unauthorized", 1);
assert!(pretty.contains("authentication failed"));
assert!(pretty.contains("API key"));
}
// Context-length errors point the user at /compress.
#[test]
fn user_facing_error_classifies_context_length() {
let pretty = user_facing_error("maximum context length exceeded", 1);
assert!(pretty.contains("/compress"));
}
// Rate-limit wording and a leading 429 both classify as RateLimit.
#[test]
fn test_classify_rate_limit() {
assert_eq!(classify_error("rate limit exceeded"), ErrorKind::RateLimit);
assert_eq!(
classify_error("429 too many requests"),
ErrorKind::RateLimit
);
}
// "overloaded" messages classify as RateLimit, which the policy retries.
#[test]
fn classify_anthropic_overloaded_error_as_retryable() {
assert_eq!(
classify_error("overloaded_error: Anthropic API is overloaded"),
ErrorKind::RateLimit,
);
assert_eq!(
classify_error("Provider overloaded; please retry later"),
ErrorKind::RateLimit,
);
}
// Leading 401 and "invalid api key" classify as Auth.
#[test]
fn test_classify_auth() {
assert_eq!(classify_error("401 unauthorized"), ErrorKind::Auth);
assert_eq!(classify_error("invalid api key"), ErrorKind::Auth);
}
// Unrelated messages fall through to Other, including near-misses such as
// "connection closed", "reset" and a bare "500" that is not a status code.
#[test]
fn test_classify_other() {
assert_eq!(classify_error("something else"), ErrorKind::Other);
assert_eq!(classify_error("file not found"), ErrorKind::Other);
assert_eq!(
classify_error("database connection closed"),
ErrorKind::Other
);
assert_eq!(classify_error("form reset successful"), ErrorKind::Other);
assert_eq!(classify_error("processed 500 items"), ErrorKind::Other);
}
// Full should_retry matrix for the default policy: only Network and
// RateLimit are retryable, and only below the retry ceiling.
#[test]
fn test_retry_policy() {
let policy = RecoveryPolicy::default();
assert!(policy.should_retry(0, ErrorKind::Network));
assert!(policy.should_retry(2, ErrorKind::Network));
assert!(policy.should_retry(4, ErrorKind::Network));
assert!(!policy.should_retry(5, ErrorKind::Network));
assert!(policy.should_retry(0, ErrorKind::RateLimit));
assert!(!policy.should_retry(0, ErrorKind::ContextLength));
assert!(!policy.should_retry(0, ErrorKind::Auth));
assert!(!policy.should_retry(0, ErrorKind::Other));
}
// Lower bounds double per attempt (1 s, 2 s, 4 s); jitter only adds time.
#[test]
fn test_backoff_duration() {
let policy = RecoveryPolicy::default();
let d0 = policy.backoff_duration(0);
let d1 = policy.backoff_duration(1);
let d2 = policy.backoff_duration(2);
assert!(d0 >= Duration::from_secs(1));
assert!(d1 >= Duration::from_secs(2));
assert!(d2 >= Duration::from_secs(4));
}
// A large attempt count is clamped to the 2^6 multiplier: 64 s plus
// jitter below a quarter of that, so under 81 s.
#[test]
fn test_backoff_overflow_guard() {
let policy = RecoveryPolicy::default();
let d = policy.backoff_duration(20); assert!(d >= Duration::from_secs(64));
assert!(d < Duration::from_secs(81));
}
// Eight calls with the same attempt count should not all return the same delay.
#[test]
fn test_backoff_jitter_present() {
let policy = RecoveryPolicy::default();
let mut seen = std::collections::HashSet::new();
for _ in 0..8 {
seen.insert(policy.backoff_duration(3));
std::thread::sleep(Duration::from_millis(1));
}
assert!(
seen.len() > 1,
"expected jittered backoff to vary across calls"
);
}
// `retry-after-ms` is read as milliseconds.
#[test]
fn retry_after_parses_anthropic_ms() {
let msg = "rate limited: retry-after-ms: 5000";
assert_eq!(
retry_after_from_error_msg(msg),
Some(Duration::from_millis(5000)),
);
}
// `Retry-After: <int>` is read as seconds.
#[test]
fn retry_after_parses_standard_seconds() {
let msg = "HTTP 429 Too Many Requests\nRetry-After: 30";
assert_eq!(
retry_after_from_error_msg(msg),
Some(Duration::from_secs(30)),
);
}
// The underscore spelling inside a JSON body is read as seconds.
#[test]
fn retry_after_parses_json_body() {
let msg = r#"{"error":"rate_limit","retry_after":12}"#;
assert_eq!(
retry_after_from_error_msg(msg),
Some(Duration::from_secs(12)),
);
}
// The colon after the label is optional.
#[test]
fn retry_after_parses_no_colon() {
let msg = "got 429, retry-after 7 next time";
assert_eq!(
retry_after_from_error_msg(msg),
Some(Duration::from_secs(7)),
);
}
// No label present yields None.
#[test]
fn retry_after_returns_none_when_absent() {
let msg = "generic network error: connection reset";
assert_eq!(retry_after_from_error_msg(msg), None);
}
// A multi-byte character ahead of the label must not break the byte-offset slicing.
#[test]
fn retry_after_handles_unicode_before_label() {
let msg = "İoError: Retry-After: 8";
assert_eq!(
retry_after_from_error_msg(msg),
Some(Duration::from_secs(8)),
);
}
// Label matching ignores ASCII case.
#[test]
fn retry_after_label_match_is_case_insensitive() {
assert_eq!(
retry_after_from_error_msg("rate limited: RETRY-AFTER-MS: 750"),
Some(Duration::from_millis(750)),
);
assert_eq!(
retry_after_from_error_msg("Retry-After-Ms: 750"),
Some(Duration::from_millis(750)),
);
}
// An over-long digit run still parses (only a bounded prefix is read) and
// the resulting backoff stays within the 300 s cap.
#[test]
fn retry_after_caps_pathological_digit_run() {
let msg = "Retry-After: 99999999999999999999999";
let parsed = retry_after_from_error_msg(msg);
assert!(parsed.is_some(), "must parse, not return None");
let policy = RecoveryPolicy::default();
let d = policy.backoff_duration_for_msg(0, msg);
assert!(
d <= Duration::from_secs(300),
"backoff must cap at 5min; got {:?}",
d,
);
}
// The longer of the server hint and the computed backoff wins: a 10 s hint
// beats the ~1 s backoff at attempt 0, while a 50 ms hint loses to the
// ~8 s backoff at attempt 3.
#[test]
fn backoff_duration_for_msg_prefers_longer_value() {
let policy = RecoveryPolicy::default();
let d = policy.backoff_duration_for_msg(0, "Retry-After: 10");
assert!(d >= Duration::from_secs(10) && d < Duration::from_secs(11));
let d = policy.backoff_duration_for_msg(3, "retry-after-ms: 50");
assert!(d >= Duration::from_secs(8));
}
// A server hint above 300 s is clamped to the cap.
#[test]
fn backoff_duration_for_msg_caps_at_5_minutes() {
let policy = RecoveryPolicy::default();
let d = policy.backoff_duration_for_msg(0, "Retry-After: 9999");
assert!(d <= Duration::from_secs(300));
}
}