use anyllm_client::retry::{
backoff_delay, is_quota_exhausted, is_retryable, parse_retry_after, RetryPolicy,
};
use crate::error::BackendError;
/// Retry budget used when `FORGE_UPSTREAM_MAX_RETRIES` is unset or unparsable.
///
/// This counts retries after the first attempt.
pub const DEFAULT_MAX_RETRIES: u32 = 2;

/// Interprets an optional raw retry-count setting.
///
/// Surrounding whitespace is ignored. A missing value falls back to
/// [`DEFAULT_MAX_RETRIES`], as does any value that does not parse as a `u32`
/// (including empty strings and negative numbers).
fn parse_max_retries(raw: Option<&str>) -> u32 {
    raw.and_then(|text| text.trim().parse::<u32>().ok())
        .unwrap_or(DEFAULT_MAX_RETRIES)
}
/// Reads the upstream retry budget from the `FORGE_UPSTREAM_MAX_RETRIES`
/// environment variable.
///
/// The value is interpreted by `parse_max_retries`. If the variable is unset,
/// is not valid Unicode, or does not parse as a `u32`, the result is
/// `DEFAULT_MAX_RETRIES`.
pub fn max_retries_from_env() -> u32 {
    let configured: Option<String> = std::env::var("FORGE_UPSTREAM_MAX_RETRIES").ok();
    parse_max_retries(configured.as_deref())
}
/// Builds the retry policy used for upstream requests.
///
/// The retry budget comes from `max_retries_from_env`, and transport-error
/// retries are enabled via `with_transport_retries(true)`.
pub fn upstream_retry_policy() -> RetryPolicy {
    let retry_budget = max_retries_from_env();
    let policy = RetryPolicy::new(retry_budget);
    policy.with_transport_retries(true)
}
/// Sends a request built by `build`, retrying according to `policy`.
///
/// `build` is called once per attempt because `RequestBuilder::send` consumes
/// the builder. At most `policy.max_retries + 1` attempts are made, and `label`
/// prefixes the warning written to stderr before each retry.
///
/// A retry is attempted (while budget remains) when:
/// - `send` fails with a connection error (`is_connect`) and
///   `policy.retry_transport_errors` is set; or
/// - the response status satisfies `is_retryable`, except for a 429 whose body
///   `is_quota_exhausted` flags. That case is returned immediately.
///
/// Each retry first sleeps for `backoff_delay(attempt, retry_after)`. For
/// status-based retries, `retry_after` is whatever `parse_retry_after` extracts
/// from the response headers. For transport errors it is `None`.
///
/// # Errors
///
/// Returns a [`BackendError`] in two cases:
/// - With status `0` and the transport error text, when `send` fails and is
///   not retried.
/// - With the upstream status code and response body, when a non-success
///   response is not retried. This covers non-retryable statuses, an exhausted
///   quota, and the final attempt.
pub async fn send_post_with_retry<F>(
    mut build: F,
    policy: &RetryPolicy,
    label: &str,
) -> Result<reqwest::Response, BackendError>
where
    F: FnMut() -> reqwest::RequestBuilder,
{
    let max_retries = policy.max_retries;
    // Widened to u64: the budget can be u32::MAX (e.g. via the env var), and
    // `max_retries + 1` in u32 would overflow when reporting attempt counts.
    let total_attempts = u64::from(max_retries) + 1;
    for attempt in 0..=max_retries {
        let has_budget = attempt < max_retries;
        let resp = match build().send().await {
            Ok(resp) => resp,
            Err(e) => {
                if policy.retry_transport_errors && has_budget && e.is_connect() {
                    let delay = backoff_delay(attempt, None);
                    log_retry_warning(
                        label,
                        "transport error",
                        attempt,
                        total_attempts,
                        delay.as_millis(),
                    );
                    tokio::time::sleep(delay).await;
                    continue;
                }
                return Err(BackendError::new(0, e.to_string()));
            }
        };
        if resp.status().is_success() {
            return Ok(resp);
        }
        let status = resp.status().as_u16();
        if has_budget && is_retryable(status) {
            // Headers must be read before the body is consumed below.
            let retry_after = parse_retry_after(resp.headers());
            if status == 429 {
                // A 429 can mean an exhausted quota; retrying won't help then.
                let body_text = resp.text().await.unwrap_or_default();
                if is_quota_exhausted(&body_text) {
                    return Err(BackendError::new(i64::from(status), body_text));
                }
            } else {
                // Drain the body before retrying; its contents are not needed.
                let _ = resp.bytes().await;
            }
            let delay = backoff_delay(attempt, retry_after);
            log_retry_warning(
                label,
                &format!("upstream status {status}"),
                attempt,
                total_attempts,
                delay.as_millis(),
            );
            tokio::time::sleep(delay).await;
            continue;
        }
        let body_text = resp.text().await.unwrap_or_default();
        return Err(BackendError::new(i64::from(status), body_text));
    }
    // Every path on the final iteration (attempt == max_retries) returns.
    unreachable!("loop runs max_retries + 1 times and always returns")
}

/// Writes the stderr warning emitted before a retry; `reason` names the failure.
fn log_retry_warning(label: &str, reason: &str, attempt: u32, total_attempts: u64, delay_ms: u128) {
    eprintln!(
        "warning: {label} {reason} (attempt {}/{}); retrying in {}ms",
        u64::from(attempt) + 1,
        total_attempts,
        delay_ms
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed and obviously malformed inputs.
    #[test]
    fn parse_max_retries_handles_values() {
        assert_eq!(parse_max_retries(None), DEFAULT_MAX_RETRIES);
        assert_eq!(parse_max_retries(Some("0")), 0);
        assert_eq!(parse_max_retries(Some(" 4 ")), 4);
        assert_eq!(parse_max_retries(Some("bogus")), DEFAULT_MAX_RETRIES);
    }

    /// Boundary inputs: blank, negative, and just past / at the u32 range limit.
    #[test]
    fn parse_max_retries_falls_back_on_edge_inputs() {
        // None of these parse as a u32, so each must yield the default.
        for raw in ["", "   ", "-1", "4294967296"] {
            assert_eq!(
                parse_max_retries(Some(raw)),
                DEFAULT_MAX_RETRIES,
                "input {raw:?}"
            );
        }
        assert_eq!(parse_max_retries(Some("4294967295")), u32::MAX);
    }

    /// The upstream policy must have transport-error retries switched on.
    #[test]
    fn default_policy_enables_connect_retries() {
        assert!(upstream_retry_policy().retry_transport_errors);
    }
}