use crate::core::utils::join_url;
use crate::error::{Error, Result};
use futures::Stream;
use futures::StreamExt;
use reqwest;
use reqwest::IntoUrl;
use reqwest_eventsource::{Event, RequestBuilderExt};
use serde::de::DeserializeOwned;
use std::pin::Pin;
use std::time::Duration;
/// Tunables for the retry loop in `retry_request`.
#[derive(Debug, Clone)]
struct RetryConfig {
    /// Retries allowed after the first attempt (total attempts = `max_retries + 1`).
    max_retries: u32,
    /// Delay before the first retry; doubled for every subsequent retry.
    initial_wait: Duration,
    /// Cap applied to the exponential delay and to any server-requested delay.
    max_wait: Duration,
    /// Whether the exponential delay is perturbed by a small random offset.
    use_jitter: bool,
}

impl Default for RetryConfig {
    /// Five retries, starting at one second, capped at thirty seconds, jitter enabled.
    fn default() -> Self {
        let initial_wait = Duration::from_millis(1_000);
        let max_wait = Duration::from_millis(30_000);
        RetryConfig {
            max_retries: 5,
            initial_wait,
            max_wait,
            use_jitter: true,
        }
    }
}
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
matches!(
status,
reqwest::StatusCode::TOO_MANY_REQUESTS
| reqwest::StatusCode::BAD_GATEWAY
| reqwest::StatusCode::SERVICE_UNAVAILABLE
| reqwest::StatusCode::GATEWAY_TIMEOUT
)
}
/// Extracts a wait duration from the `Retry-After` response header.
///
/// Only the delay-seconds form (an unsigned integer such as `120`) is understood. A missing
/// header, a value that is not visible ASCII, an HTTP-date, or a negative or fractional
/// number all yield `None`.
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
    let raw = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?;
    let seconds: u64 = raw.parse().ok()?;
    Some(Duration::from_secs(seconds))
}
/// Computes how long to sleep before the next retry.
///
/// * `retry_count` is the zero-based index of the retry about to be made. The exponential
///   delay is `initial_wait * 2^retry_count`, using saturating arithmetic so huge counts
///   cannot overflow.
/// * `retry_after` is a server-requested delay. When present it replaces the exponential
///   delay, and no jitter is applied.
///
/// Both the exponential delay and `retry_after` are capped at `config.max_wait`. When
/// `config.use_jitter` is set, the capped exponential delay is shifted by a random offset
/// between -10% and +9.9% of itself (truncated to whole milliseconds). A jittered result
/// can therefore land slightly above `max_wait`.
fn calculate_backoff(
    retry_count: u32,
    config: &RetryConfig,
    retry_after: Option<Duration>,
) -> Duration {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    if let Some(requested) = retry_after {
        return requested.min(config.max_wait);
    }
    let backoff = config
        .initial_wait
        .saturating_mul(2_u32.saturating_pow(retry_count))
        .min(config.max_wait);
    if !config.use_jitter {
        return backoff;
    }
    // Entropy comes from a freshly keyed SipHash instead of `SystemTime` nanoseconds.
    // On platforms whose clock only reports micro- or 100ns-resolution, `nanos % 200` is
    // constant (or takes just two values), which silently turned jitter into a fixed offset.
    let entropy = RandomState::new().build_hasher().finish();
    // Map to a fraction in [-0.100, +0.099] in steps of 0.001.
    let jitter_pct = ((entropy % 200) as i64 - 100) as f64 / 1000.0;
    let jitter_ms = (backoff.as_millis() as f64 * jitter_pct) as i64;
    let jitter = Duration::from_millis(jitter_ms.unsigned_abs());
    if jitter_ms >= 0 {
        backoff.saturating_add(jitter)
    } else {
        backoff.saturating_sub(jitter)
    }
}
async fn retry_request<F, T>(
url: reqwest::Url,
method: reqwest::Method,
headers: reqwest::header::HeaderMap,
query_params: Vec<(&str, &str)>,
body_fn: F,
config: RetryConfig,
) -> Result<T>
where
F: Fn() -> reqwest::Body,
T: DeserializeOwned + std::fmt::Debug,
{
let client = reqwest::Client::new();
let mut retry_count = 0;
loop {
let body = body_fn();
let resp = client
.request(method.clone(), url.clone())
.headers(headers.clone())
.query(&query_params)
.body(body)
.send()
.await
.map_err(|e| {
if e.is_timeout() || e.is_connect() {
log::warn!(
"Request failed with retryable error (attempt {}/{}): {}",
retry_count + 1,
config.max_retries + 1,
e
);
} else {
log::error!("Request failed: {e}");
}
Error::ApiError {
status_code: e.status(),
details: e.to_string(),
}
})?;
let status = resp.status();
let response_headers = resp.headers().clone();
let resp_text = resp.text().await.map_err(|e| Error::ApiError {
status_code: e.status(),
details: format!("Failed to read response: {e}"),
})?;
if status.is_success() {
log::debug!("Request succeeded on attempt {}", retry_count + 1);
return serde_json::from_str(&resp_text).map_err(|e| Error::ApiError {
status_code: Some(status),
details: format!("Failed to parse response: {e}"),
});
}
if is_retryable_status(status) && retry_count < config.max_retries {
retry_count += 1;
let retry_after = parse_retry_after(&response_headers);
let wait_time = calculate_backoff(retry_count - 1, &config, retry_after);
log::warn!(
"Request failed with status {} (attempt {}/{}). Retrying after {:?}...",
status,
retry_count,
config.max_retries + 1,
wait_time
);
tokio::time::sleep(wait_time).await;
continue;
}
if retry_count >= config.max_retries {
log::error!(
"Request failed after {} retries with status {}: {}",
retry_count + 1,
status,
resp_text
);
} else {
log::error!("Request failed with non-retryable status {status}: {resp_text}");
}
return Err(Error::ApiError {
status_code: Some(status),
details: resp_text,
});
}
}
#[allow(dead_code)]
pub(crate) trait LanguageModelClient {
type Response: DeserializeOwned + std::fmt::Debug + Clone;
type StreamEvent: DeserializeOwned + std::fmt::Debug + Clone;
fn path(&self) -> String;
fn method(&self) -> reqwest::Method;
fn query_params(&self) -> Vec<(&str, &str)>;
fn body(&self) -> reqwest::Body;
fn headers(&self) -> reqwest::header::HeaderMap;
async fn send(&self, base_url: impl IntoUrl) -> Result<Self::Response> {
let url = join_url(base_url, &self.path())?;
let body_bytes = {
let body = self.body();
match body.as_bytes() {
Some(bytes) => bytes.to_vec(),
None => {
log::warn!("Request body is not retryable (streaming body)");
vec![]
}
}
};
let method = self.method();
let headers = self.headers();
let query_params = self.query_params();
let config = RetryConfig::default();
retry_request(
url,
method,
headers,
query_params,
move || reqwest::Body::from(body_bytes.clone()),
config,
)
.await
}
fn parse_stream_sse(
event: std::result::Result<Event, reqwest_eventsource::Error>,
) -> Result<Self::StreamEvent>;
fn end_stream(event: &Self::StreamEvent) -> bool;
async fn send_and_stream(
&self,
base_url: impl IntoUrl,
) -> Result<Pin<Box<dyn Stream<Item = Result<Self::StreamEvent>> + Send>>>
where
Self::StreamEvent: Send + 'static,
Self: Sync,
{
let client = reqwest::Client::new();
let url = join_url(base_url, &self.path())?;
let events_stream = client
.request(self.method(), url.clone())
.headers(self.headers())
.query(&self.query_params())
.body(self.body())
.eventsource()
.map_err(|e| Error::ApiError {
status_code: None,
details: format!("SSE stream error: {e}"),
})?;
let mapped_stream = events_stream.map(|event_result| Self::parse_stream_sse(event_result));
let ended = std::sync::Arc::new(std::sync::Mutex::new(false));
let stream = mapped_stream.scan(ended, |ended, res| {
let mut ended = ended.lock().unwrap();
if *ended {
return futures::future::ready(None); }
*ended = res.as_ref().map_or(true, |evt| Self::end_stream(evt));
futures::future::ready(Some(res)) });
Ok(Box::pin(stream))
}
}
#[allow(dead_code)]
pub(crate) trait EmbeddingClient {
type Response: DeserializeOwned + std::fmt::Debug + Clone;
fn path(&self) -> String;
fn method(&self) -> reqwest::Method;
fn query_params(&self) -> Vec<(&str, &str)>;
fn body(&self) -> reqwest::Body;
fn headers(&self) -> reqwest::header::HeaderMap;
async fn send(&self, base_url: impl IntoUrl) -> Result<Self::Response> {
let base_url = base_url
.into_url()
.map_err(|_| Error::InvalidInput("Invalid base URL".into()))?;
let url = join_url(base_url, &self.path())?;
let body_bytes = {
let body = self.body();
match body.as_bytes() {
Some(bytes) => bytes.to_vec(),
None => {
log::warn!("Request body is not retryable (streaming body)");
vec![]
}
}
};
let method = self.method();
let headers = self.headers();
let query_params = self.query_params();
let config = RetryConfig::default();
retry_request(
url,
method,
headers,
query_params,
move || reqwest::Body::from(body_bytes.clone()),
config,
)
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RetryConfig` from millisecond values to keep the test tables compact.
    fn test_config(
        max_retries: u32,
        initial_wait_ms: u64,
        max_wait_ms: u64,
        use_jitter: bool,
    ) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_wait: Duration::from_millis(initial_wait_ms),
            max_wait: Duration::from_millis(max_wait_ms),
            use_jitter,
        }
    }

    /// Returns a header map containing only `Retry-After: value`.
    fn retry_after_headers(value: &'static str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static(value),
        );
        headers
    }

    /// Asserts that `actual` lies within `base_ms` +/- 10% (inclusive), the jitter envelope.
    fn assert_within_jitter(actual: Duration, base_ms: u64) {
        let min = Duration::from_millis(base_ms - base_ms / 10);
        let max = Duration::from_millis(base_ms + base_ms / 10);
        assert!(
            actual >= min && actual <= max,
            "Result {actual:?} should be between {min:?} and {max:?}"
        );
    }

    #[test]
    fn test_calculate_backoff_retry_after_overrides_and_is_capped() {
        // (retry_count, max_wait_ms, Retry-After delay, expected delay)
        let cases: &[(u32, u64, Duration, Duration)] = &[
            (0, 30_000, Duration::from_secs(5), Duration::from_secs(5)),
            (0, 10_000, Duration::from_secs(60), Duration::from_millis(10_000)),
            (3, 30_000, Duration::from_millis(5_000), Duration::from_millis(5_000)),
            (0, 30_000, Duration::ZERO, Duration::ZERO),
            (0, 1_000, Duration::from_secs(u64::MAX / 2), Duration::from_millis(1_000)),
        ];
        for &(retry_count, max_wait_ms, retry_after, expected) in cases {
            let config = test_config(5, 1_000, max_wait_ms, false);
            assert_eq!(
                calculate_backoff(retry_count, &config, Some(retry_after)),
                expected,
                "retry_after={retry_after:?}, max_wait={max_wait_ms}ms"
            );
        }
    }

    #[test]
    fn test_calculate_backoff_exponential_without_jitter() {
        // (initial_wait_ms, max_wait_ms, retry_count, expected_ms)
        let cases: &[(u64, u64, u32, u64)] = &[
            (1_000, 30_000, 0, 1_000),
            (1_000, 30_000, 1, 2_000),
            (1_000, 30_000, 2, 4_000),
            (1_000, 30_000, 3, 8_000),
            (1_000, 30_000, 4, 16_000),
            // Capped at max_wait.
            (1_000, 30_000, 5, 30_000),
            (1_000, 5_000, 3, 5_000),
            (1_000, 8_000, 3, 8_000),
            (1_000, 30_000, 20, 30_000),
            // Saturating arithmetic: huge products and exponents still clamp to max_wait.
            (1_000_000, 60_000, 10, 60_000),
            (1_000, 60_000, 63, 60_000),
            (1_000, 30_000, u32::MAX, 30_000),
            // Zero bounds.
            (0, 30_000, 5, 0),
            (1_000, 0, 0, 0),
            (0, 0, 10, 0),
        ];
        for &(initial_ms, max_ms, retry_count, expected_ms) in cases {
            let config = test_config(5, initial_ms, max_ms, false);
            assert_eq!(
                calculate_backoff(retry_count, &config, None),
                Duration::from_millis(expected_ms),
                "initial={initial_ms}ms, max={max_ms}ms, retry_count={retry_count}"
            );
        }
    }

    #[test]
    fn test_calculate_backoff_large_durations_without_jitter() {
        let large = RetryConfig {
            max_retries: 5,
            initial_wait: Duration::from_secs(1_000_000),
            max_wait: Duration::from_secs(2_000_000),
            use_jitter: false,
        };
        assert_eq!(calculate_backoff(0, &large, None), Duration::from_secs(1_000_000));

        let overflowing = RetryConfig {
            max_retries: 100,
            initial_wait: Duration::from_millis(u64::MAX / 2),
            max_wait: Duration::from_secs(60),
            use_jitter: false,
        };
        assert_eq!(calculate_backoff(10, &overflowing, None), Duration::from_secs(60));
    }

    #[test]
    fn test_calculate_backoff_jitter_stays_within_ten_percent() {
        let config = test_config(5, 1_000, 30_000, true);
        for retry_count in 0..5 {
            let base_ms = 1_000 * 2_u64.pow(retry_count);
            // Several draws per count, since the offset can differ between calls.
            for _ in 0..8 {
                assert_within_jitter(calculate_backoff(retry_count, &config, None), base_ms);
            }
        }
    }

    #[test]
    fn test_calculate_backoff_jitter_applied_after_max_wait_cap() {
        // Jitter is applied after capping, so the result may land slightly above max_wait.
        let config = test_config(5, 1_000, 10_000, true);
        assert_within_jitter(calculate_backoff(4, &config, None), 10_000);
    }

    #[test]
    fn test_calculate_backoff_jitter_with_zero_and_small_bases() {
        let zero = test_config(5, 0, 30_000, true);
        assert_eq!(calculate_backoff(0, &zero, None), Duration::ZERO);
        let tiny = test_config(5, 10, 30_000, true);
        assert_within_jitter(calculate_backoff(0, &tiny, None), 10);
    }

    #[test]
    fn test_calculate_backoff_sequence_increases_exponentially() {
        let config = test_config(5, 1_000, 100_000, false);
        let mut prev_result = Duration::ZERO;
        for retry_count in 0..10 {
            let result = calculate_backoff(retry_count, &config, None);
            assert!(
                result >= prev_result,
                "Retry count {retry_count}: {result:?} should be >= {prev_result:?}"
            );
            if retry_count > 0 && result < config.max_wait {
                let ratio = result.as_millis() as f64 / prev_result.as_millis() as f64;
                assert!(
                    (ratio - 2.0).abs() < 0.01,
                    "Retry count {retry_count}: ratio {ratio} should be ~2.0"
                );
            }
            prev_result = result;
        }
    }

    #[test]
    fn test_is_retryable_status() {
        use reqwest::StatusCode;
        for status in [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ] {
            assert!(is_retryable_status(status), "{status} should be retryable");
        }
        for status in [
            StatusCode::OK,
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
        ] {
            assert!(!is_retryable_status(status), "{status} should not be retryable");
        }
    }

    #[test]
    fn test_parse_retry_after_delay_seconds() {
        for (value, seconds) in [("120", 120), ("0", 0), ("86400", 86_400)] {
            assert_eq!(
                parse_retry_after(&retry_after_headers(value)),
                Some(Duration::from_secs(seconds)),
                "value {value:?}"
            );
        }
    }

    #[test]
    fn test_parse_retry_after_missing_header() {
        let headers = reqwest::header::HeaderMap::new();
        assert_eq!(parse_retry_after(&headers), None);
    }

    #[test]
    fn test_parse_retry_after_unsupported_values() {
        // Garbage, HTTP-date, negative and fractional forms are not understood.
        for value in ["invalid", "Wed, 21 Oct 2025 07:28:00 GMT", "-10", "10.5"] {
            assert_eq!(
                parse_retry_after(&retry_after_headers(value)),
                None,
                "value {value:?}"
            );
        }
    }
}