oxify_connect_llm/
retry.rs

1//! Retry logic with exponential backoff for LLM providers
2
3use crate::{
4    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5    LlmResponse, LlmStream, Result, StreamingLlmProvider,
6};
7use async_trait::async_trait;
8use std::time::Duration;
9
/// Configuration for retry behavior.
///
/// The wait before a retry starts at `initial_delay`, is multiplied by
/// `backoff_multiplier` for every further attempt, and never exceeds
/// `max_delay`. When `jitter` is enabled the computed wait is scaled down to
/// somewhere between 50% and 100% of that value.
///
/// `PartialEq` is derived so that configurations can be compared directly
/// (for example in tests); `Eq` is not available because of the `f64` field.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Maximum number of retry attempts made after the initial attempt (default: 3)
    pub max_retries: u32,

    /// Initial delay before first retry (default: 1s)
    pub initial_delay: Duration,

    /// Maximum delay between retries (default: 30s)
    pub max_delay: Duration,

    /// Multiplier for exponential backoff (default: 2.0)
    pub backoff_multiplier: f64,

    /// Whether to add jitter to delays (default: true)
    pub jitter: bool,
}
28
29impl Default for RetryConfig {
30    fn default() -> Self {
31        Self {
32            max_retries: 3,
33            initial_delay: Duration::from_secs(1),
34            max_delay: Duration::from_secs(30),
35            backoff_multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41impl RetryConfig {
42    /// Create a new config with custom max retries
43    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
44        self.max_retries = max_retries;
45        self
46    }
47
48    /// Create a new config with custom initial delay
49    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
50        self.initial_delay = delay;
51        self
52    }
53
54    /// Create a new config with custom max delay
55    pub fn with_max_delay(mut self, delay: Duration) -> Self {
56        self.max_delay = delay;
57        self
58    }
59
60    /// Create a new config with jitter disabled
61    pub fn without_jitter(mut self) -> Self {
62        self.jitter = false;
63        self
64    }
65
66    /// Calculate delay for a given attempt number (0-indexed)
67    fn calculate_delay(&self, attempt: u32) -> Duration {
68        let base_delay =
69            self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
70        let delay_ms = base_delay.min(self.max_delay.as_millis() as f64);
71
72        let delay_ms = if self.jitter {
73            // Add jitter: random value between 0 and delay_ms
74            use std::collections::hash_map::DefaultHasher;
75            use std::hash::{Hash, Hasher};
76            use std::time::SystemTime;
77
78            let mut hasher = DefaultHasher::new();
79            SystemTime::now().hash(&mut hasher);
80            attempt.hash(&mut hasher);
81            let hash = hasher.finish();
82
83            let jitter_factor = (hash % 1000) as f64 / 1000.0; // 0.0 to 1.0
84            delay_ms * (0.5 + jitter_factor * 0.5) // Between 50% and 100% of delay
85        } else {
86            delay_ms
87        };
88
89        Duration::from_millis(delay_ms as u64)
90    }
91}
92
93/// Check if an error is retryable
94fn is_retryable_error(error: &LlmError) -> bool {
95    matches!(
96        error,
97        LlmError::RateLimited(_)
98            | LlmError::NetworkError(_)
99            | LlmError::ApiError(_)
100            | LlmError::Timeout(_)
101    )
102}
103
104/// A wrapper that adds retry functionality to any LLM provider
105pub struct RetryProvider<P> {
106    inner: P,
107    config: RetryConfig,
108}
109
110impl<P> RetryProvider<P> {
111    /// Create a new RetryProvider with default configuration
112    pub fn new(provider: P) -> Self {
113        Self {
114            inner: provider,
115            config: RetryConfig::default(),
116        }
117    }
118
119    /// Create a new RetryProvider with custom configuration
120    pub fn with_config(provider: P, config: RetryConfig) -> Self {
121        Self {
122            inner: provider,
123            config,
124        }
125    }
126
127    /// Get a reference to the inner provider
128    pub fn inner(&self) -> &P {
129        &self.inner
130    }
131
132    /// Get a mutable reference to the inner provider
133    pub fn inner_mut(&mut self) -> &mut P {
134        &mut self.inner
135    }
136
137    /// Get the retry configuration
138    pub fn config(&self) -> &RetryConfig {
139        &self.config
140    }
141}
142
143#[async_trait]
144impl<P: LlmProvider> LlmProvider for RetryProvider<P> {
145    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
146        let mut last_error = None;
147
148        for attempt in 0..=self.config.max_retries {
149            match self.inner.complete(request.clone()).await {
150                Ok(response) => return Ok(response),
151                Err(e) => {
152                    if attempt < self.config.max_retries && is_retryable_error(&e) {
153                        // Use Retry-After header if available, otherwise use exponential backoff
154                        let delay = e
155                            .retry_after()
156                            .unwrap_or_else(|| self.config.calculate_delay(attempt));
157                        tracing::warn!(
158                            attempt = attempt + 1,
159                            max_retries = self.config.max_retries,
160                            delay_ms = delay.as_millis(),
161                            error = %e,
162                            "LLM request failed, retrying"
163                        );
164                        tokio::time::sleep(delay).await;
165                        last_error = Some(e);
166                    } else {
167                        return Err(e);
168                    }
169                }
170            }
171        }
172
173        Err(last_error.unwrap_or(LlmError::ApiError("Unknown retry error".to_string())))
174    }
175}
176
177#[async_trait]
178impl<P: StreamingLlmProvider> StreamingLlmProvider for RetryProvider<P> {
179    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
180        let mut last_error = None;
181
182        for attempt in 0..=self.config.max_retries {
183            match self.inner.complete_stream(request.clone()).await {
184                Ok(stream) => return Ok(stream),
185                Err(e) => {
186                    if attempt < self.config.max_retries && is_retryable_error(&e) {
187                        // Use Retry-After header if available, otherwise use exponential backoff
188                        let delay = e
189                            .retry_after()
190                            .unwrap_or_else(|| self.config.calculate_delay(attempt));
191                        tracing::warn!(
192                            attempt = attempt + 1,
193                            max_retries = self.config.max_retries,
194                            delay_ms = delay.as_millis(),
195                            error = %e,
196                            "LLM stream request failed, retrying"
197                        );
198                        tokio::time::sleep(delay).await;
199                        last_error = Some(e);
200                    } else {
201                        return Err(e);
202                    }
203                }
204            }
205        }
206
207        Err(last_error.unwrap_or(LlmError::ApiError("Unknown retry error".to_string())))
208    }
209}
210
211#[async_trait]
212impl<P: EmbeddingProvider> EmbeddingProvider for RetryProvider<P> {
213    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
214        let mut last_error = None;
215
216        for attempt in 0..=self.config.max_retries {
217            match self.inner.embed(request.clone()).await {
218                Ok(response) => return Ok(response),
219                Err(e) => {
220                    if attempt < self.config.max_retries && is_retryable_error(&e) {
221                        // Use Retry-After header if available, otherwise use exponential backoff
222                        let delay = e
223                            .retry_after()
224                            .unwrap_or_else(|| self.config.calculate_delay(attempt));
225                        tracing::warn!(
226                            attempt = attempt + 1,
227                            max_retries = self.config.max_retries,
228                            delay_ms = delay.as_millis(),
229                            error = %e,
230                            "Embedding request failed, retrying"
231                        );
232                        tokio::time::sleep(delay).await;
233                        last_error = Some(e);
234                    } else {
235                        return Err(e);
236                    }
237                }
238            }
239        }
240
241        Err(last_error.unwrap_or(LlmError::ApiError("Unknown retry error".to_string())))
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config carries the documented default values.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_secs(1));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    /// Each builder method overrides exactly the field it names.
    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::default()
            .with_max_retries(5)
            .with_initial_delay(Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(60))
            .without_jitter();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(60));
        assert!(!config.jitter);
    }

    /// Without jitter the delay doubles with every attempt.
    #[test]
    fn test_calculate_delay_no_jitter() {
        let config = RetryConfig::default().without_jitter();

        // Initial delay: 1s
        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
        // Second attempt: 1s * 2^1 = 2s
        assert_eq!(config.calculate_delay(1), Duration::from_secs(2));
        // Third attempt: 1s * 2^2 = 4s
        assert_eq!(config.calculate_delay(2), Duration::from_secs(4));
        // Fourth attempt: 1s * 2^3 = 8s
        assert_eq!(config.calculate_delay(3), Duration::from_secs(8));
    }

    /// The delay never exceeds `max_delay`.
    #[test]
    fn test_calculate_delay_with_max() {
        let config = RetryConfig::default()
            .with_max_delay(Duration::from_secs(5))
            .without_jitter();

        // First attempts should be normal
        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(config.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(config.calculate_delay(2), Duration::from_secs(4));

        // Higher attempts should be capped at max_delay
        assert_eq!(config.calculate_delay(3), Duration::from_secs(5));
        assert_eq!(config.calculate_delay(10), Duration::from_secs(5));
    }

    /// A non-default multiplier is applied per attempt.
    #[test]
    fn test_calculate_delay_custom_multiplier() {
        let config = RetryConfig {
            backoff_multiplier: 3.0,
            ..RetryConfig::default()
        }
        .without_jitter();

        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(config.calculate_delay(1), Duration::from_secs(3));
        assert_eq!(config.calculate_delay(2), Duration::from_secs(9));
    }

    /// With jitter the delay stays between 50% and 100% of the un-jittered value.
    #[test]
    fn test_calculate_delay_with_jitter_bounds() {
        let config = RetryConfig::default();
        assert!(config.jitter);

        for _ in 0..20 {
            // Un-jittered value for attempt 0 is 1s.
            let first = config.calculate_delay(0);
            assert!(first >= Duration::from_millis(500));
            assert!(first <= Duration::from_secs(1));

            // Attempt 10 is capped at max_delay (30s) before jitter is applied.
            let capped = config.calculate_delay(10);
            assert!(capped >= Duration::from_secs(15));
            assert!(capped <= Duration::from_secs(30));
        }
    }

    /// Rate-limit, API and timeout errors are retryable; config and
    /// invalid-request errors are not.
    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error(&LlmError::RateLimited(None)));
        assert!(is_retryable_error(&LlmError::RateLimited(Some(
            Duration::from_secs(5)
        ))));
        assert!(is_retryable_error(&LlmError::ApiError("error".to_string())));
        assert!(is_retryable_error(&LlmError::Timeout(Duration::from_secs(
            30
        ))));
        assert!(!is_retryable_error(&LlmError::ConfigError(
            "invalid".to_string()
        )));
        assert!(!is_retryable_error(&LlmError::InvalidRequest(
            "bad".to_string()
        )));
    }
}
320}