// litellm_rust/http.rs

1use crate::error::{LiteLLMError, Result};
2use reqwest::header::HeaderMap;
3use reqwest::StatusCode;
4use std::time::Duration;
5use tokio::time::sleep;
6
/// Maximum buffer size for SSE streaming, in bytes (16 MiB).
pub const MAX_SSE_BUFFER_SIZE: usize = 16 << 20;
9
/// Default number of retries made after the initial attempt.
pub const DEFAULT_MAX_RETRIES: u32 = 3;
/// Default delay before the first retry, in milliseconds.
pub const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000;
/// Default upper bound on any single retry delay, in milliseconds.
pub const DEFAULT_MAX_BACKOFF_MS: u64 = 30000;
/// Default factor by which the delay grows from one attempt to the next.
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;

/// Configuration for retry behavior.
///
/// The delay before retry number `attempt` (0-based) is
/// `initial_backoff_ms * backoff_multiplier^attempt`, capped at `max_backoff_ms`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts (the initial attempt is not counted).
    pub max_retries: u32,
    /// Initial backoff duration in milliseconds
    pub initial_backoff_ms: u64,
    /// Maximum backoff duration in milliseconds
    pub max_backoff_ms: u64,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Builds a config from the `DEFAULT_*` constants above.
    fn default() -> Self {
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
            max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
        }
    }
}

impl RetryConfig {
    /// Calculate the backoff duration for a given 0-based attempt number.
    ///
    /// Computes `initial_backoff_ms * backoff_multiplier^attempt` and clamps the
    /// result to `max_backoff_ms`. The float-to-integer cast saturates, so a
    /// negative product becomes a zero delay.
    fn backoff_duration(&self, attempt: u32) -> Duration {
        // `attempt as i32` would wrap values above `i32::MAX` into negative
        // exponents, turning a huge attempt count into a near-zero delay.
        // Saturating keeps the curve monotonic so the cap applies instead.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let backoff_ms = (self.initial_backoff_ms as f64) * self.backoff_multiplier.powi(exponent);
        let clamped_ms = backoff_ms.min(self.max_backoff_ms as f64) as u64;
        Duration::from_millis(clamped_ms)
    }
}
49
50/// Determines if a status code is retryable
51fn is_retryable_status(status: StatusCode) -> bool {
52    matches!(
53        status,
54        StatusCode::TOO_MANY_REQUESTS
55            | StatusCode::SERVICE_UNAVAILABLE
56            | StatusCode::GATEWAY_TIMEOUT
57            | StatusCode::BAD_GATEWAY
58            | StatusCode::REQUEST_TIMEOUT
59    )
60}
61
62/// Send a JSON request and parse the response.
63///
64/// This function includes retry logic with exponential backoff for transient failures.
65pub async fn send_json<T: serde::de::DeserializeOwned>(
66    req: reqwest::RequestBuilder,
67) -> Result<(T, HeaderMap)> {
68    send_json_with_retry(req, &RetryConfig::default()).await
69}
70
71/// Send a JSON request with custom retry configuration.
72///
73/// Note: Due to reqwest::RequestBuilder not implementing Clone, the retry config
74/// is currently unused for direct builder calls. Use `with_retry` for retryable requests.
75#[allow(unused_variables)]
76pub async fn send_json_with_retry<T: serde::de::DeserializeOwned>(
77    req: reqwest::RequestBuilder,
78    retry_config: &RetryConfig,
79) -> Result<(T, HeaderMap)> {
80    // We need to clone the request for retries, but RequestBuilder doesn't implement Clone.
81    // Instead, we'll try to build the request and handle retries at a higher level.
82    // For now, we execute once - the retry logic should be applied at the call site
83    // where the builder can be recreated.
84    send_json_once(req).await
85}
86
87/// Execute a request once without retries.
88pub async fn send_json_once<T: serde::de::DeserializeOwned>(
89    req: reqwest::RequestBuilder,
90) -> Result<(T, HeaderMap)> {
91    let resp = req.send().await.map_err(LiteLLMError::from)?;
92
93    let status = resp.status();
94    let headers = resp.headers().clone();
95
96    if !status.is_success() {
97        let text = resp.text().await.map_err(LiteLLMError::from)?;
98        let trimmed = text.lines().take(20).collect::<Vec<_>>().join("\n");
99        return Err(LiteLLMError::http(format!(
100            "http {}: {}",
101            status.as_u16(),
102            trimmed
103        )));
104    }
105
106    let parsed = resp
107        .json()
108        .await
109        .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
110    Ok((parsed, headers))
111}
112
113/// Helper to execute a request-building closure with retry logic.
114///
115/// The closure should build and return a fresh RequestBuilder for each attempt.
116pub async fn with_retry<T, F, Fut>(
117    retry_config: &RetryConfig,
118    mut build_request: F,
119) -> Result<(T, HeaderMap)>
120where
121    T: serde::de::DeserializeOwned,
122    F: FnMut() -> Fut,
123    Fut: std::future::Future<Output = Result<reqwest::RequestBuilder>>,
124{
125    let mut last_error = None;
126
127    for attempt in 0..=retry_config.max_retries {
128        let req = build_request().await?;
129        let resp = req.send().await;
130
131        match resp {
132            Ok(response) => {
133                let status = response.status();
134                let headers = response.headers().clone();
135
136                if status.is_success() {
137                    let parsed = response
138                        .json()
139                        .await
140                        .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
141                    return Ok((parsed, headers));
142                }
143
144                let text = response.text().await.map_err(LiteLLMError::from)?;
145
146                // Check if this is a retryable error
147                if is_retryable_status(status) && attempt < retry_config.max_retries {
148                    let backoff = retry_config.backoff_duration(attempt);
149                    sleep(backoff).await;
150                    last_error = Some(LiteLLMError::http(format!(
151                        "http {}: {}",
152                        status.as_u16(),
153                        text.lines().take(5).collect::<Vec<_>>().join("\n")
154                    )));
155                    continue;
156                }
157
158                // Non-retryable error or max retries exceeded
159                let trimmed = text.lines().take(20).collect::<Vec<_>>().join("\n");
160                return Err(LiteLLMError::http(format!(
161                    "http {}: {}",
162                    status.as_u16(),
163                    trimmed
164                )));
165            }
166            Err(e) => {
167                // Network errors are retryable
168                if attempt < retry_config.max_retries {
169                    let backoff = retry_config.backoff_duration(attempt);
170                    sleep(backoff).await;
171                    last_error = Some(LiteLLMError::from(e));
172                    continue;
173                }
174                return Err(LiteLLMError::from(e));
175            }
176        }
177    }
178
179    Err(last_error.unwrap_or_else(|| LiteLLMError::http("max retries exceeded")))
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn retry_config_backoff_calculation() {
        let config = RetryConfig {
            max_retries: 3,
            initial_backoff_ms: 1000,
            max_backoff_ms: 10000,
            backoff_multiplier: 2.0,
        };

        // (attempt, expected delay in ms). Uncapped, attempt 4 would be
        // 16_000 ms, so it must come back clamped to `max_backoff_ms`.
        let cases: [(u32, u64); 5] = [(0, 1000), (1, 2000), (2, 4000), (3, 8000), (4, 10000)];
        for &(attempt, expected_ms) in cases.iter() {
            assert_eq!(
                config.backoff_duration(attempt),
                Duration::from_millis(expected_ms),
                "unexpected backoff for attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn retryable_status_codes() {
        let retryable = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
            StatusCode::BAD_GATEWAY,
            StatusCode::REQUEST_TIMEOUT,
        ];
        let terminal = [
            StatusCode::OK,
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
        ];

        for &status in retryable.iter() {
            assert!(is_retryable_status(status), "{} should be retryable", status);
        }
        for &status in terminal.iter() {
            assert!(!is_retryable_status(status), "{} should not be retryable", status);
        }
    }
}