oxify_connect_llm/
retry.rs1use crate::{
4 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5 LlmResponse, LlmStream, Result, StreamingLlmProvider,
6};
7use async_trait::async_trait;
8use std::time::Duration;
9
/// Policy controlling how [`RetryProvider`] retries failed requests.
///
/// Computed delays grow exponentially from `initial_delay` by `backoff_multiplier` per attempt
/// and are capped at `max_delay`. With `jitter` enabled each computed delay is scaled by a
/// pseudo-random factor into `[delay / 2, delay)`. A `retry_after()` hint carried by an error
/// is used instead of the computed delay and is not subject to these limits.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the first attempt (at most `max_retries + 1` calls).
    pub max_retries: u32,

    /// Computed delay before the first retry; later delays are derived from it.
    pub initial_delay: Duration,

    /// Upper bound on the computed (pre-jitter) backoff delay.
    pub max_delay: Duration,

    /// Factor by which the computed delay grows with each successive attempt.
    pub backoff_multiplier: f64,

    /// Whether to randomly shorten each computed delay to spread retries out.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three retries, starting at 1 s and doubling up to 30 s, with jitter enabled.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }
}

impl RetryConfig {
    /// Sets the maximum number of retries after the initial attempt.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the computed delay used before the first retry.
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the cap applied to the computed backoff delay.
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Sets the factor by which the computed delay grows per attempt.
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Disables jitter so computed delays are deterministic.
    pub fn without_jitter(mut self) -> Self {
        self.jitter = false;
        self
    }

    /// Computes the backoff delay to wait after the zero-based `attempt` has failed.
    ///
    /// The un-jittered delay is `initial_delay * backoff_multiplier^attempt`, capped at
    /// `max_delay`. With jitter enabled the result is scaled by a factor in `[0.5, 1.0)`.
    /// Arithmetic is done in nanoseconds so sub-millisecond delays are not truncated to zero.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`; saturate rather than let large attempts wrap to a negative
        // exponent, which would shrink the delay towards zero instead of hitting the cap.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let base_ns =
            self.initial_delay.as_nanos() as f64 * self.backoff_multiplier.powi(exponent);
        let capped_ns = base_ns.min(self.max_delay.as_nanos() as f64);

        let delay_ns = if self.jitter {
            capped_ns * (0.5 + Self::jitter_fraction(attempt) * 0.5)
        } else {
            capped_ns
        };

        // Float-to-int `as` saturates: a negative product yields zero, overflow yields u64::MAX.
        Duration::from_nanos(delay_ns as u64)
    }

    /// Returns a pseudo-random fraction in `[0.0, 1.0)` with 1/1000 granularity.
    fn jitter_fraction(attempt: u32) -> f64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hash, Hasher};
        use std::time::SystemTime;

        // `RandomState` hashers are randomly keyed, unlike `DefaultHasher::new()` (fixed keys),
        // so the value depends on more than just the current timestamp.
        let mut hasher = RandomState::new().build_hasher();
        SystemTime::now().hash(&mut hasher);
        attempt.hash(&mut hasher);
        (hasher.finish() % 1000) as f64 / 1000.0
    }
}
92
93fn is_retryable_error(error: &LlmError) -> bool {
95 matches!(
96 error,
97 LlmError::RateLimited(_)
98 | LlmError::NetworkError(_)
99 | LlmError::ApiError(_)
100 | LlmError::Timeout(_)
101 )
102}
103
104pub struct RetryProvider<P> {
106 inner: P,
107 config: RetryConfig,
108}
109
110impl<P> RetryProvider<P> {
111 pub fn new(provider: P) -> Self {
113 Self {
114 inner: provider,
115 config: RetryConfig::default(),
116 }
117 }
118
119 pub fn with_config(provider: P, config: RetryConfig) -> Self {
121 Self {
122 inner: provider,
123 config,
124 }
125 }
126
127 pub fn inner(&self) -> &P {
129 &self.inner
130 }
131
132 pub fn inner_mut(&mut self) -> &mut P {
134 &mut self.inner
135 }
136
137 pub fn config(&self) -> &RetryConfig {
139 &self.config
140 }
141}
142
143#[async_trait]
144impl<P: LlmProvider> LlmProvider for RetryProvider<P> {
145 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
146 let mut last_error = None;
147
148 for attempt in 0..=self.config.max_retries {
149 match self.inner.complete(request.clone()).await {
150 Ok(response) => return Ok(response),
151 Err(e) => {
152 if attempt < self.config.max_retries && is_retryable_error(&e) {
153 let delay = e
155 .retry_after()
156 .unwrap_or_else(|| self.config.calculate_delay(attempt));
157 tracing::warn!(
158 attempt = attempt + 1,
159 max_retries = self.config.max_retries,
160 delay_ms = delay.as_millis(),
161 error = %e,
162 "LLM request failed, retrying"
163 );
164 tokio::time::sleep(delay).await;
165 last_error = Some(e);
166 } else {
167 return Err(e);
168 }
169 }
170 }
171 }
172
173 Err(last_error.unwrap_or(LlmError::ApiError("Unknown retry error".to_string())))
174 }
175}
176
177#[async_trait]
178impl<P: StreamingLlmProvider> StreamingLlmProvider for RetryProvider<P> {
179 async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
180 let mut last_error = None;
181
182 for attempt in 0..=self.config.max_retries {
183 match self.inner.complete_stream(request.clone()).await {
184 Ok(stream) => return Ok(stream),
185 Err(e) => {
186 if attempt < self.config.max_retries && is_retryable_error(&e) {
187 let delay = e
189 .retry_after()
190 .unwrap_or_else(|| self.config.calculate_delay(attempt));
191 tracing::warn!(
192 attempt = attempt + 1,
193 max_retries = self.config.max_retries,
194 delay_ms = delay.as_millis(),
195 error = %e,
196 "LLM stream request failed, retrying"
197 );
198 tokio::time::sleep(delay).await;
199 last_error = Some(e);
200 } else {
201 return Err(e);
202 }
203 }
204 }
205 }
206
207 Err(last_error.unwrap_or(LlmError::ApiError("Unknown retry error".to_string())))
208 }
209}
210
211#[async_trait]
212impl<P: EmbeddingProvider> EmbeddingProvider for RetryProvider<P> {
213 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
214 let mut last_error = None;
215
216 for attempt in 0..=self.config.max_retries {
217 match self.inner.embed(request.clone()).await {
218 Ok(response) => return Ok(response),
219 Err(e) => {
220 if attempt < self.config.max_retries && is_retryable_error(&e) {
221 let delay = e
223 .retry_after()
224 .unwrap_or_else(|| self.config.calculate_delay(attempt));
225 tracing::warn!(
226 attempt = attempt + 1,
227 max_retries = self.config.max_retries,
228 delay_ms = delay.as_millis(),
229 error = %e,
230 "Embedding request failed, retrying"
231 );
232 tokio::time::sleep(delay).await;
233 last_error = Some(e);
234 } else {
235 return Err(e);
236 }
237 }
238 }
239 }
240
241 Err(last_error.unwrap_or(LlmError::ApiError("Unknown retry error".to_string())))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier,
            jitter,
        } = RetryConfig::default();

        assert_eq!(max_retries, 3);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert_eq!(backoff_multiplier, 2.0);
        assert!(jitter);
    }

    #[test]
    fn test_retry_config_builder() {
        let built = RetryConfig::default()
            .with_max_retries(5)
            .with_initial_delay(Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(60))
            .without_jitter();

        assert_eq!(
            (built.max_retries, built.initial_delay, built.max_delay, built.jitter),
            (5, Duration::from_millis(500), Duration::from_secs(60), false)
        );
    }

    /// Checks `calculate_delay(attempt)` against each `(attempt, expected_seconds)` pair.
    fn assert_delays(config: &RetryConfig, cases: &[(u32, u64)]) {
        for &(attempt, secs) in cases {
            assert_eq!(
                config.calculate_delay(attempt),
                Duration::from_secs(secs),
                "unexpected delay for attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn test_calculate_delay_no_jitter() {
        let config = RetryConfig::default().without_jitter();
        // Doubling from 1 s, still below the default 30 s cap.
        assert_delays(&config, &[(0, 1), (1, 2), (2, 4), (3, 8)]);
    }

    #[test]
    fn test_calculate_delay_with_max() {
        let config = RetryConfig::default()
            .with_max_delay(Duration::from_secs(5))
            .without_jitter();
        // From attempt 3 onwards the doubled delay exceeds 5 s and is clamped.
        assert_delays(&config, &[(0, 1), (1, 2), (2, 4), (3, 5), (10, 5)]);
    }

    #[test]
    fn test_is_retryable_error() {
        let retryable = [
            LlmError::RateLimited(None),
            LlmError::RateLimited(Some(Duration::from_secs(5))),
            LlmError::ApiError("error".to_string()),
            LlmError::Timeout(Duration::from_secs(30)),
        ];
        for error in &retryable {
            assert!(is_retryable_error(error), "expected retryable: {}", error);
        }

        let permanent = [
            LlmError::ConfigError("invalid".to_string()),
            LlmError::InvalidRequest("bad".to_string()),
        ];
        for error in &permanent {
            assert!(!is_retryable_error(error), "expected non-retryable: {}", error);
        }
    }
}