1use crate::error::ProviderError;
6use std::sync::Arc;
7use std::time::Duration;
8
/// Tunable parameters for a single completion request.
///
/// The struct is `#[non_exhaustive]`, so new fields can be added later without
/// a breaking change. Code outside this crate cannot build it with a struct
/// literal; start from [`CompletionConfig::default`] and overwrite the public
/// fields that need to differ.
///
/// `PartialEq` is derived so configurations can be compared directly (for
/// example in tests or change detection). `Eq` is intentionally not derived
/// because `temperature` is a floating-point value.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct CompletionConfig {
    /// Token limit for the completion. Presumably the maximum number of tokens
    /// the provider may generate; exact semantics are up to each provider.
    pub max_tokens: u32,
    /// Sampling temperature handed to the provider.
    pub temperature: f64,
}
22
23impl Default for CompletionConfig {
24 fn default() -> Self {
25 Self {
26 max_tokens: 4096,
27 temperature: 0.0,
28 }
29 }
30}
31
32#[async_trait::async_trait]
42pub trait LlmProvider: Send + Sync {
43 async fn complete(
53 &self,
54 system_prompt: &str,
55 user_prompt: &str,
56 config: &CompletionConfig,
57 ) -> Result<String, ProviderError>;
58
59 fn name(&self) -> &str;
61
62 fn model(&self) -> &str;
64}
65
66pub struct RetryProvider {
75 inner: Arc<dyn LlmProvider>,
76 pub max_retries: u32,
78 pub base_delay: Duration,
80}
81
82impl RetryProvider {
83 pub fn new(inner: Arc<dyn LlmProvider>) -> Self {
88 Self {
89 inner,
90 max_retries: 3,
91 base_delay: Duration::from_secs(1),
92 }
93 }
94
95 pub fn with_config(
102 inner: Arc<dyn LlmProvider>,
103 max_retries: u32,
104 base_delay: Duration,
105 ) -> Self {
106 Self {
107 inner,
108 max_retries,
109 base_delay,
110 }
111 }
112}
113
114fn is_retryable(error: &ProviderError) -> bool {
127 match error {
128 ProviderError::Timeout { .. } | ProviderError::Network { .. } => true,
129 ProviderError::Http { status, .. } => *status == 500 || *status == 429,
130 _ => false,
131 }
132}
133
134#[async_trait::async_trait]
135impl LlmProvider for RetryProvider {
136 async fn complete(
137 &self,
138 system_prompt: &str,
139 user_prompt: &str,
140 config: &CompletionConfig,
141 ) -> Result<String, ProviderError> {
142 let mut last_error = None;
143 let mut delay = self.base_delay;
144 for attempt in 0..=self.max_retries {
145 match self
146 .inner
147 .complete(system_prompt, user_prompt, config)
148 .await
149 {
150 Ok(response) => return Ok(response),
151 Err(err) => {
152 if !is_retryable(&err) || attempt == self.max_retries {
153 return Err(err);
154 }
155 last_error = Some(err);
156 tokio::time::sleep(delay).await;
157 delay = delay.saturating_mul(2);
158 }
159 }
160 }
161 Err(last_error.expect("at least one attempt must have been made"))
162 }
163
164 fn name(&self) -> &str {
165 self.inner.name()
166 }
167
168 fn model(&self) -> &str {
169 self.inner.model()
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Base delay used by the retry tests; small so the real sleeps between
    /// attempts stay negligible.
    const TEST_DELAY: Duration = Duration::from_millis(1);

    /// Scripted [`LlmProvider`] that replays a fixed sequence of results and
    /// counts how many times it was called.
    struct MockProvider {
        provider_name: String,
        provider_model: String,
        /// Scripted results, handed out front to back, one per call.
        responses: Mutex<VecDeque<Result<String, ProviderError>>>,
        call_count: AtomicU32,
    }

    impl MockProvider {
        /// Creates a mock with no scripted results; every call succeeds with
        /// the fallback response.
        fn new(name: &str, model: &str) -> Self {
            Self::with_responses(name, model, Vec::new())
        }

        /// Creates a mock that returns `responses` in order, then falls back
        /// to a successful default response once the script is exhausted.
        fn with_responses(
            name: &str,
            model: &str,
            responses: Vec<Result<String, ProviderError>>,
        ) -> Self {
            Self {
                provider_name: name.to_string(),
                provider_model: model.to_string(),
                responses: Mutex::new(VecDeque::from(responses)),
                call_count: AtomicU32::new(0),
            }
        }

        /// Number of `complete` calls observed so far.
        fn call_count(&self) -> u32 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            // Pop in a single statement so the lock guard is released right away.
            let scripted = self.responses.lock().unwrap().pop_front();
            scripted.unwrap_or_else(|| Ok("default response".to_string()))
        }

        fn name(&self) -> &str {
            &self.provider_name
        }

        fn model(&self) -> &str {
            &self.provider_model
        }
    }

    /// Builds a scripted timeout failure carrying `message`.
    fn timeout(message: &'static str) -> Result<String, ProviderError> {
        Err(ProviderError::Timeout {
            message: message.into(),
        })
    }

    /// Runs one completion through a [`RetryProvider`] wrapping a mock
    /// scripted with `responses`, returning the outcome together with the
    /// number of calls the mock received.
    async fn run_scripted(
        responses: Vec<Result<String, ProviderError>>,
        max_retries: u32,
    ) -> (Result<String, ProviderError>, u32) {
        let mock = Arc::new(MockProvider::with_responses("p", "m", responses));
        let retry = RetryProvider::with_config(mock.clone(), max_retries, TEST_DELAY);
        let config = CompletionConfig::default();
        let result = retry.complete("sys", "usr", &config).await;
        (result, mock.call_count())
    }

    #[test]
    fn test_completion_config_default_values() {
        let config = CompletionConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert!((config.temperature - 0.0).abs() < f64::EPSILON);
    }

    // `#[non_exhaustive]` has no effect inside the defining crate, so this
    // exercises the construction path it leaves to external callers: start
    // from `Default` and overwrite the public fields.
    #[test]
    #[allow(clippy::field_reassign_with_default)] // the reassignment is the point of the test
    fn test_completion_config_is_non_exhaustive() {
        let mut config = CompletionConfig::default();
        config.max_tokens = 256;
        config.temperature = 0.7;
        assert_eq!(config.max_tokens, 256);
        assert!((config.temperature - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn test_retry_provider_delegates_name() {
        let mock = Arc::new(MockProvider::new("test-provider", "test-model"));
        let retry = RetryProvider::new(mock);
        assert_eq!(retry.name(), "test-provider");
    }

    #[test]
    fn test_retry_provider_delegates_model() {
        let mock = Arc::new(MockProvider::new("test-provider", "test-model"));
        let retry = RetryProvider::new(mock);
        assert_eq!(retry.model(), "test-model");
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_timeout() {
        let script = vec![timeout("t1"), timeout("t2"), Ok("success".into())];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_http_500() {
        let script = vec![
            Err(ProviderError::Http {
                status: 500,
                body: "err".into(),
            }),
            Ok("ok".into()),
        ];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(calls, 2);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_http_429() {
        let script = vec![
            Err(ProviderError::Http {
                status: 429,
                body: "rate limit".into(),
            }),
            Ok("ok".into()),
        ];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(calls, 2);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_network() {
        let script = vec![
            Err(ProviderError::Network {
                message: "dns".into(),
            }),
            Ok("ok".into()),
        ];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(calls, 2);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_auth() {
        let script = vec![Err(ProviderError::Auth {
            message: "bad key".into(),
        })];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_process() {
        let script = vec![Err(ProviderError::Process {
            exit_code: Some(1),
            stderr: "fail".into(),
        })];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_nested_session() {
        let script = vec![Err(ProviderError::NestedSession)];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_http_4xx() {
        let script = vec![Err(ProviderError::Http {
            status: 403,
            body: "forbidden".into(),
        })];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_provider_returns_last_error_after_exhausting_retries() {
        // Two retries allow three attempts in total; the third failure is the
        // one that must be reported.
        let script = vec![timeout("t1"), timeout("t2"), timeout("t3")];
        let (result, calls) = run_scripted(script, 2).await;
        assert!(result.is_err());
        assert_eq!(calls, 3);
        match result.unwrap_err() {
            ProviderError::Timeout { message } => assert_eq!(message, "t3"),
            other => panic!("expected Timeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn test_retry_provider_returns_success_on_first_retry() {
        let script = vec![timeout("t1"), Ok("recovered".into())];
        let (result, calls) = run_scripted(script, 3).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "recovered");
        assert_eq!(calls, 2);
    }

    #[tokio::test]
    async fn test_retry_provider_zero_retries_makes_single_attempt() {
        // With no retry budget a retryable failure is returned immediately,
        // so the scripted success is never reached.
        let script = vec![timeout("t1"), Ok("unused".into())];
        let (result, calls) = run_scripted(script, 0).await;
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_provider_default_config() {
        let mock = Arc::new(MockProvider::new("p", "m"));
        let retry = RetryProvider::new(mock);
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.base_delay, Duration::from_secs(1));
    }
}