1use crate::error::ProviderError;
6use std::sync::Arc;
7use std::time::Duration;
8
/// Per-request settings handed to [`LlmProvider::complete`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; start from [`CompletionConfig::default`] and
/// adjust the public fields instead.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CompletionConfig {
    /// Token limit for the completion (how it is enforced is up to the provider).
    pub max_tokens: u32,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f64,
}

impl Default for CompletionConfig {
    /// Builds the baseline configuration: 4096 tokens at temperature 0.0.
    fn default() -> Self {
        const DEFAULT_MAX_TOKENS: u32 = 4096;
        const DEFAULT_TEMPERATURE: f64 = 0.0;

        Self {
            max_tokens: DEFAULT_MAX_TOKENS,
            temperature: DEFAULT_TEMPERATURE,
        }
    }
}
31
32#[async_trait::async_trait]
42pub trait LlmProvider: Send + Sync {
43 async fn complete(
53 &self,
54 system_prompt: &str,
55 user_prompt: &str,
56 config: &CompletionConfig,
57 ) -> Result<String, ProviderError>;
58
59 fn name(&self) -> &str;
61
62 fn model(&self) -> &str;
64}
65
66pub fn resolve_claude_alias(model: &str) -> Result<String, ProviderError> {
92 match model {
93 "sonnet" => Ok("claude-sonnet-4-6".to_string()),
94 "opus" => Ok("claude-opus-4-7".to_string()),
95 "haiku" => Ok("claude-haiku-4-5-20251001".to_string()),
96 m if m.contains("claude-") => Ok(m.to_string()),
97 _ => Err(ProviderError::Auth {
98 message: format!("unknown model alias: {model}"),
99 }),
100 }
101}
102
103pub struct RetryProvider {
112 inner: Arc<dyn LlmProvider>,
113 pub max_retries: u32,
115 pub base_delay: Duration,
117}
118
119impl RetryProvider {
120 pub fn new(inner: Arc<dyn LlmProvider>) -> Self {
125 Self {
126 inner,
127 max_retries: 3,
128 base_delay: Duration::from_secs(1),
129 }
130 }
131
132 pub fn with_config(
139 inner: Arc<dyn LlmProvider>,
140 max_retries: u32,
141 base_delay: Duration,
142 ) -> Self {
143 Self {
144 inner,
145 max_retries,
146 base_delay,
147 }
148 }
149}
150
151fn is_retryable(error: &ProviderError) -> bool {
164 match error {
165 ProviderError::Timeout { .. } | ProviderError::Network { .. } => true,
166 ProviderError::Http { status, .. } => *status == 500 || *status == 429,
167 _ => false,
168 }
169}
170
171#[async_trait::async_trait]
172impl LlmProvider for RetryProvider {
173 async fn complete(
174 &self,
175 system_prompt: &str,
176 user_prompt: &str,
177 config: &CompletionConfig,
178 ) -> Result<String, ProviderError> {
179 let mut last_error = None;
180 let mut delay = self.base_delay;
181 for attempt in 0..=self.max_retries {
182 match self
183 .inner
184 .complete(system_prompt, user_prompt, config)
185 .await
186 {
187 Ok(response) => return Ok(response),
188 Err(err) => {
189 if !is_retryable(&err) || attempt == self.max_retries {
190 return Err(err);
191 }
192 last_error = Some(err);
193 tokio::time::sleep(delay).await;
194 delay = delay.saturating_mul(2);
195 }
196 }
197 }
198 Err(last_error.expect("at least one attempt must have been made"))
199 }
200
201 fn name(&self) -> &str {
202 self.inner.name()
203 }
204
205 fn model(&self) -> &str {
206 self.inner.model()
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// One scripted outcome of a `MockProvider::complete` call.
    type Scripted = Result<String, ProviderError>;

    /// Test double that replays a fixed script of outcomes and counts calls.
    struct MockProvider {
        provider_name: String,
        provider_model: String,
        /// Outcomes still to be returned, front first.
        script: Mutex<VecDeque<Scripted>>,
        calls: AtomicU32,
    }

    impl MockProvider {
        /// Mock with an empty script: every call yields the default response.
        fn new(name: &str, model: &str) -> Self {
            Self::with_responses(name, model, Vec::new())
        }

        /// Mock that returns `responses` in order, then the default response.
        fn with_responses(name: &str, model: &str, responses: Vec<Scripted>) -> Self {
            Self {
                provider_name: name.to_string(),
                provider_model: model.to_string(),
                script: Mutex::new(VecDeque::from(responses)),
                calls: AtomicU32::new(0),
            }
        }

        /// Number of `complete` calls observed so far.
        fn call_count(&self) -> u32 {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let next = self.script.lock().unwrap().pop_front();
            next.unwrap_or_else(|| Ok("default response".to_string()))
        }

        fn name(&self) -> &str {
            &self.provider_name
        }

        fn model(&self) -> &str {
            &self.provider_model
        }
    }

    /// Shorthand for a timeout error carrying `message`.
    fn timeout(message: &str) -> ProviderError {
        ProviderError::Timeout {
            message: message.into(),
        }
    }

    /// Drives one `complete` call through a `RetryProvider` (1 ms base delay)
    /// backed by a mock replaying `script`; returns the outcome together with
    /// the number of calls the mock received.
    async fn run_script(script: Vec<Scripted>, max_retries: u32) -> (Scripted, u32) {
        let mock = Arc::new(MockProvider::with_responses("p", "m", script));
        let retry = RetryProvider::with_config(mock.clone(), max_retries, Duration::from_millis(1));
        let config = CompletionConfig::default();
        let outcome = retry.complete("sys", "usr", &config).await;
        (outcome, mock.call_count())
    }

    /// Asserts that `error` followed by a success is recovered in two calls.
    async fn assert_retried_once(error: ProviderError) {
        let (outcome, calls) = run_script(vec![Err(error), Ok("ok".into())], 3).await;
        assert!(outcome.is_ok());
        assert_eq!(calls, 2);
    }

    /// Asserts that `error` is surfaced after a single call, without retrying.
    async fn assert_not_retried(error: ProviderError) {
        let (outcome, calls) = run_script(vec![Err(error)], 3).await;
        assert!(outcome.is_err());
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_completion_config_default_values() {
        let config = CompletionConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert!((config.temperature - 0.0).abs() < f64::EPSILON);
    }

    // `#[non_exhaustive]` has no effect inside the defining crate, so this can
    // only confirm that `Default` remains a working way to obtain a config.
    #[test]
    fn test_completion_config_is_non_exhaustive() {
        let config = CompletionConfig::default();
        assert_eq!(config.max_tokens, 4096);
        assert!((config.temperature).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_retry_provider_delegates_name() {
        let retry = RetryProvider::new(Arc::new(MockProvider::new("test-provider", "test-model")));
        assert_eq!(retry.name(), "test-provider");
    }

    #[tokio::test]
    async fn test_retry_provider_delegates_model() {
        let retry = RetryProvider::new(Arc::new(MockProvider::new("test-provider", "test-model")));
        assert_eq!(retry.model(), "test-model");
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_timeout() {
        let script = vec![
            Err(timeout("t1")),
            Err(timeout("t2")),
            Ok("success".into()),
        ];
        let (outcome, calls) = run_script(script, 3).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_http_500() {
        assert_retried_once(ProviderError::Http {
            status: 500,
            body: "err".into(),
        })
        .await;
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_http_429() {
        assert_retried_once(ProviderError::Http {
            status: 429,
            body: "rate limit".into(),
        })
        .await;
    }

    #[tokio::test]
    async fn test_retry_provider_retries_on_network() {
        assert_retried_once(ProviderError::Network {
            message: "dns".into(),
        })
        .await;
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_auth() {
        assert_not_retried(ProviderError::Auth {
            message: "bad key".into(),
        })
        .await;
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_process() {
        assert_not_retried(ProviderError::Process {
            exit_code: Some(1),
            stderr: "fail".into(),
        })
        .await;
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_nested_session() {
        assert_not_retried(ProviderError::NestedSession).await;
    }

    #[tokio::test]
    async fn test_retry_provider_does_not_retry_on_http_4xx() {
        assert_not_retried(ProviderError::Http {
            status: 403,
            body: "forbidden".into(),
        })
        .await;
    }

    #[tokio::test]
    async fn test_retry_provider_returns_last_error_after_exhausting_retries() {
        // Two retries allow three attempts in total, so all three scripted
        // failures are consumed and the last one must be reported.
        let script = vec![Err(timeout("t1")), Err(timeout("t2")), Err(timeout("t3"))];
        let (outcome, calls) = run_script(script, 2).await;
        assert!(outcome.is_err());
        assert_eq!(calls, 3);
        match outcome.unwrap_err() {
            ProviderError::Timeout { message } => assert_eq!(message, "t3"),
            other => panic!("expected Timeout, got: {other}"),
        }
    }

    #[tokio::test]
    async fn test_retry_provider_returns_success_on_first_retry() {
        let script = vec![Err(timeout("t1")), Ok("recovered".into())];
        let (outcome, calls) = run_script(script, 3).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "recovered");
        assert_eq!(calls, 2);
    }

    #[test]
    fn test_retry_provider_default_config() {
        let retry = RetryProvider::new(Arc::new(MockProvider::new("p", "m")));
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.base_delay, Duration::from_secs(1));
    }

    #[test]
    fn test_resolve_claude_alias_opus_returns_claude_opus_4_7() {
        assert_eq!(resolve_claude_alias("opus").unwrap(), "claude-opus-4-7");
    }

    #[test]
    fn test_resolve_claude_alias_sonnet_returns_claude_sonnet_4_6() {
        assert_eq!(resolve_claude_alias("sonnet").unwrap(), "claude-sonnet-4-6");
    }

    #[test]
    fn test_resolve_claude_alias_haiku_returns_claude_haiku_4_5_20251001() {
        assert_eq!(
            resolve_claude_alias("haiku").unwrap(),
            "claude-haiku-4-5-20251001"
        );
    }

    #[test]
    fn test_resolve_claude_alias_accepts_literal_claude_opus_4_6_passthrough() {
        let resolved = resolve_claude_alias("claude-opus-4-6").unwrap();
        assert_eq!(resolved, "claude-opus-4-6");
    }
}