1use crate::error::{LiteLLMError, Result};
2use reqwest::header::HeaderMap;
3use reqwest::StatusCode;
4use std::time::Duration;
5use tokio::time::sleep;
6
/// Size limit of 16 MiB (`16 * 1024 * 1024`).
///
/// Not referenced in this module; presumably it caps how much server-sent-event
/// data a streaming reader buffers — confirm at the use site.
pub const MAX_SSE_BUFFER_SIZE: usize = 16 << 20;
9
/// Default number of retries allowed after the initial attempt.
pub const DEFAULT_MAX_RETRIES: u32 = 3;
/// Default delay, in milliseconds, after the first failed attempt.
pub const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000;
/// Default upper bound, in milliseconds, on any single retry delay.
pub const DEFAULT_MAX_BACKOFF_MS: u64 = 30000;
/// Default factor by which the delay grows with each further attempt.
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;

/// Exponential-backoff retry policy.
///
/// After failed attempt `n` (0-based) the caller waits
/// `initial_backoff_ms * backoff_multiplier^n` milliseconds, capped at `max_backoff_ms`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries allowed after the initial attempt.
    pub max_retries: u32,
    /// Delay, in milliseconds, after the first failed attempt.
    pub initial_backoff_ms: u64,
    /// Upper bound, in milliseconds, on any single delay.
    pub max_backoff_ms: u64,
    /// Growth factor applied to the delay for each further attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Builds a policy from the `DEFAULT_*` constants of this module.
    fn default() -> Self {
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
            max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
        }
    }
}

impl RetryConfig {
    /// Delay to wait after failed attempt number `attempt` (0-based).
    ///
    /// Computes `initial_backoff_ms * backoff_multiplier^attempt` and caps the result at
    /// `max_backoff_ms`; the float-to-integer conversion saturates.
    fn backoff_duration(&self, attempt: u32) -> Duration {
        // A plain `attempt as i32` wraps for attempts above `i32::MAX`, producing a
        // negative exponent and therefore a near-zero delay. Saturate the exponent instead
        // so very large attempt numbers still land on the cap.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let backoff_ms =
            (self.initial_backoff_ms as f64) * self.backoff_multiplier.powi(exponent);
        let clamped_ms = backoff_ms.min(self.max_backoff_ms as f64) as u64;
        Duration::from_millis(clamped_ms)
    }
}
49
50fn is_retryable_status(status: StatusCode) -> bool {
52 matches!(
53 status,
54 StatusCode::TOO_MANY_REQUESTS
55 | StatusCode::SERVICE_UNAVAILABLE
56 | StatusCode::GATEWAY_TIMEOUT
57 | StatusCode::BAD_GATEWAY
58 | StatusCode::REQUEST_TIMEOUT
59 )
60}
61
62pub async fn send_json<T: serde::de::DeserializeOwned>(
66 req: reqwest::RequestBuilder,
67) -> Result<(T, HeaderMap)> {
68 send_json_with_retry(req, &RetryConfig::default()).await
69}
70
71#[allow(unused_variables)]
76pub async fn send_json_with_retry<T: serde::de::DeserializeOwned>(
77 req: reqwest::RequestBuilder,
78 retry_config: &RetryConfig,
79) -> Result<(T, HeaderMap)> {
80 send_json_once(req).await
85}
86
87pub async fn send_json_once<T: serde::de::DeserializeOwned>(
89 req: reqwest::RequestBuilder,
90) -> Result<(T, HeaderMap)> {
91 let resp = req.send().await.map_err(LiteLLMError::from)?;
92
93 let status = resp.status();
94 let headers = resp.headers().clone();
95
96 if !status.is_success() {
97 let text = resp.text().await.map_err(LiteLLMError::from)?;
98 let trimmed = text.lines().take(20).collect::<Vec<_>>().join("\n");
99 return Err(LiteLLMError::http(format!(
100 "http {}: {}",
101 status.as_u16(),
102 trimmed
103 )));
104 }
105
106 let parsed = resp
107 .json()
108 .await
109 .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
110 Ok((parsed, headers))
111}
112
113pub async fn with_retry<T, F, Fut>(
117 retry_config: &RetryConfig,
118 mut build_request: F,
119) -> Result<(T, HeaderMap)>
120where
121 T: serde::de::DeserializeOwned,
122 F: FnMut() -> Fut,
123 Fut: std::future::Future<Output = Result<reqwest::RequestBuilder>>,
124{
125 let mut last_error = None;
126
127 for attempt in 0..=retry_config.max_retries {
128 let req = build_request().await?;
129 let resp = req.send().await;
130
131 match resp {
132 Ok(response) => {
133 let status = response.status();
134 let headers = response.headers().clone();
135
136 if status.is_success() {
137 let parsed = response
138 .json()
139 .await
140 .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
141 return Ok((parsed, headers));
142 }
143
144 let text = response.text().await.map_err(LiteLLMError::from)?;
145
146 if is_retryable_status(status) && attempt < retry_config.max_retries {
148 let backoff = retry_config.backoff_duration(attempt);
149 sleep(backoff).await;
150 last_error = Some(LiteLLMError::http(format!(
151 "http {}: {}",
152 status.as_u16(),
153 text.lines().take(5).collect::<Vec<_>>().join("\n")
154 )));
155 continue;
156 }
157
158 let trimmed = text.lines().take(20).collect::<Vec<_>>().join("\n");
160 return Err(LiteLLMError::http(format!(
161 "http {}: {}",
162 status.as_u16(),
163 trimmed
164 )));
165 }
166 Err(e) => {
167 if attempt < retry_config.max_retries {
169 let backoff = retry_config.backoff_duration(attempt);
170 sleep(backoff).await;
171 last_error = Some(LiteLLMError::from(e));
172 continue;
173 }
174 return Err(LiteLLMError::from(e));
175 }
176 }
177 }
178
179 Err(last_error.unwrap_or_else(|| LiteLLMError::http("max retries exceeded")))
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn retry_config_backoff_calculation() {
        let config = RetryConfig {
            max_retries: 3,
            initial_backoff_ms: 1000,
            max_backoff_ms: 10000,
            backoff_multiplier: 2.0,
        };

        // Delays double per attempt until the 10 s cap takes over at attempt 4.
        let expected_ms: [u64; 5] = [1000, 2000, 4000, 8000, 10000];
        for (attempt, &ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(
                config.backoff_duration(attempt),
                Duration::from_millis(ms),
                "unexpected backoff for attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn retryable_status_codes() {
        let retryable = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
            StatusCode::BAD_GATEWAY,
            StatusCode::REQUEST_TIMEOUT,
        ];
        let not_retryable = [
            StatusCode::OK,
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::NOT_FOUND,
            // 500 must stay outside the retry set.
            StatusCode::INTERNAL_SERVER_ERROR,
        ];

        for &status in retryable.iter() {
            assert!(is_retryable_status(status), "{} should be retryable", status);
        }
        for &status in not_retryable.iter() {
            assert!(!is_retryable_status(status), "{} should not be retryable", status);
        }
    }
}