1use std::time::Duration;
32use thiserror::Error;
33
/// Convenience alias: fallible operations in this module return [`LlmError`] on failure.
pub type Result<T> = std::result::Result<T, LlmError>;
36
/// How a failed LLM request should be retried, as chosen by
/// [`LlmError::retry_strategy`].
#[derive(Debug, Clone, PartialEq)]
pub enum RetryStrategy {
    /// Retry with exponentially growing delays; used for transient
    /// network and server-side failures (see `network_backoff` /
    /// `server_backoff` for the presets).
    ExponentialBackoff {
        /// Delay before the first retry attempt.
        base_delay: Duration,
        /// Upper bound on the delay between attempts.
        max_delay: Duration,
        /// Maximum number of attempts before giving up.
        max_attempts: u32,
    },

    /// Wait a fixed duration, then retry; used for rate-limit responses.
    WaitAndRetry {
        /// Fixed time to wait before retrying.
        wait: Duration,
    },

    /// Shrink the prompt/context and retry; used when the token limit
    /// was exceeded.
    ReduceContext,

    /// The error is not recoverable; do not retry.
    NoRetry,
}
69
70impl RetryStrategy {
71 pub fn network_backoff() -> Self {
73 Self::ExponentialBackoff {
74 base_delay: Duration::from_millis(125),
75 max_delay: Duration::from_secs(30),
76 max_attempts: 5,
77 }
78 }
79
80 pub fn server_backoff() -> Self {
82 Self::ExponentialBackoff {
83 base_delay: Duration::from_secs(1),
84 max_delay: Duration::from_secs(60),
85 max_attempts: 3,
86 }
87 }
88
89 pub fn should_retry(&self) -> bool {
91 !matches!(self, Self::NoRetry)
92 }
93}
94
/// Unified error type for LLM provider interactions.
///
/// The `#[error]` strings are the user-visible `Display` output; the test
/// module asserts on them verbatim, so treat them as part of the API.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Generic error reported by the provider's API.
    #[error("API error: {0}")]
    ApiError(String),

    /// The provider rejected the request due to rate limiting.
    #[error("Rate limit exceeded: {0}")]
    RateLimited(String),

    /// The request itself was malformed or had invalid parameters.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    /// Authentication/authorization failure (e.g. bad API key).
    #[error("Authentication error: {0}")]
    AuthError(String),

    /// The prompt/context exceeded the model's token budget.
    #[error("Token limit exceeded: max {max}, got {got}")]
    TokenLimitExceeded { max: usize, got: usize },

    /// The requested model name is unknown to the provider.
    #[error("Model not found: {0}")]
    ModelNotFound(String),

    /// Transport-level failure (DNS, connect, TLS, ...).
    #[error("Network error: {0}")]
    NetworkError(String),

    /// JSON (de)serialization failure; converts automatically from
    /// `serde_json::Error` via `#[from]`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// Local configuration problem (missing key, bad settings, ...).
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Provider-specific failure that doesn't fit a more precise variant.
    #[error("Provider error: {0}")]
    ProviderError(String),

    /// The request exceeded its time budget.
    #[error("Request timed out")]
    Timeout,

    /// The provider does not support the requested feature.
    #[error("Not supported: {0}")]
    NotSupported(String),

    /// Catch-all for errors that could not be classified.
    #[error("Unknown error: {0}")]
    Unknown(String),
}
154
155impl From<reqwest::Error> for LlmError {
156 fn from(err: reqwest::Error) -> Self {
157 if err.is_timeout() {
158 LlmError::Timeout
159 } else if err.is_connect() {
160 LlmError::NetworkError(format!("Connection failed: {}", err))
161 } else {
162 LlmError::NetworkError(err.to_string())
163 }
164 }
165}
166
167impl From<async_openai::error::OpenAIError> for LlmError {
168 fn from(err: async_openai::error::OpenAIError) -> Self {
169 match err {
170 async_openai::error::OpenAIError::ApiError(api_err) => {
171 let message = api_err.message.clone();
172 if message.contains("rate limit") || message.contains("Rate limit") {
173 LlmError::RateLimited(message)
174 } else if message.contains("authentication") || message.contains("invalid_api_key")
175 {
176 LlmError::AuthError(message)
177 } else if message.contains("model") && message.contains("not found") {
178 LlmError::ModelNotFound(message)
179 } else {
180 LlmError::ApiError(message)
181 }
182 }
183 async_openai::error::OpenAIError::Reqwest(req_err) => LlmError::from(req_err),
184 async_openai::error::OpenAIError::JSONDeserialize(json_err, _content) => {
185 LlmError::SerializationError(json_err)
186 }
187 _ => LlmError::ProviderError(err.to_string()),
188 }
189 }
190}
191
impl LlmError {
    /// Map this error onto the retry policy the caller should apply.
    ///
    /// NOTE: arm order matters — the guarded `ApiError` arm (messages
    /// mentioning 500/502/503) must stay before the unguarded `ApiError`
    /// arm in the final catch-all group.
    pub fn retry_strategy(&self) -> RetryStrategy {
        match self {
            // Transient transport problems: fast exponential backoff.
            Self::NetworkError(_) | Self::Timeout => RetryStrategy::network_backoff(),

            // Rate limits: back off for a fixed minute before retrying.
            Self::RateLimited(_) => RetryStrategy::WaitAndRetry {
                wait: Duration::from_secs(60),
            },

            // Server-side failures (message mentions a 5xx status): slower
            // backoff with fewer attempts.
            Self::ApiError(msg)
                if msg.contains("500") || msg.contains("502") || msg.contains("503") =>
            {
                RetryStrategy::server_backoff()
            }
            Self::ProviderError(_) => RetryStrategy::server_backoff(),

            // Context too large: shrink the prompt and try again.
            Self::TokenLimitExceeded { .. } => RetryStrategy::ReduceContext,

            // Caller-side problems that a retry cannot fix.
            Self::AuthError(_)
            | Self::InvalidRequest(_)
            | Self::ModelNotFound(_)
            | Self::ConfigError(_)
            | Self::NotSupported(_) => RetryStrategy::NoRetry,

            // Ambiguous failures: one cautious retry (2 attempts total).
            Self::ApiError(_) | Self::SerializationError(_) | Self::Unknown(_) => {
                RetryStrategy::ExponentialBackoff {
                    base_delay: Duration::from_secs(1),
                    max_delay: Duration::from_secs(30),
                    max_attempts: 2,
                }
            }
        }
    }

    /// Human-readable, action-oriented description of this error, suitable
    /// for showing to an end user (as opposed to the `Display` impl, which
    /// is terse and log-oriented). The test module asserts on substrings of
    /// these messages.
    pub fn user_description(&self) -> String {
        match self {
            Self::NetworkError(_) => {
                "Unable to connect to the API. Check your internet connection.".to_string()
            }
            Self::Timeout => "Request timed out. The server may be overloaded.".to_string(),
            Self::RateLimited(_) => "Rate limited by the API. Waiting before retry...".to_string(),
            Self::TokenLimitExceeded { max, got } => {
                format!(
                    "Context too large ({}/{} tokens). Reducing context and retrying...",
                    got, max
                )
            }
            Self::AuthError(_) => {
                "Authentication failed. Please check your API key is valid and not expired."
                    .to_string()
            }
            Self::ModelNotFound(model) => {
                format!(
                    "Model '{}' not found. Use a supported model like 'gpt-4o-mini'.",
                    model
                )
            }
            Self::InvalidRequest(msg) => {
                format!("Invalid request: {}. Check your parameters.", msg)
            }
            Self::ConfigError(msg) => format!("Configuration error: {}.", msg),
            Self::NotSupported(feature) => {
                format!("Feature '{}' is not supported by this provider.", feature)
            }
            Self::ApiError(_) | Self::ProviderError(_) => {
                "API server error. Retrying...".to_string()
            }
            Self::SerializationError(_) => {
                "Failed to parse API response. This may be a temporary issue.".to_string()
            }
            Self::Unknown(msg) => format!("An unexpected error occurred: {}", msg),
        }
    }

    /// Convenience: `true` iff [`Self::retry_strategy`] yields a strategy
    /// that performs at least one retry.
    pub fn is_recoverable(&self) -> bool {
        self.retry_strategy().should_retry()
    }
}
309
#[cfg(test)]
mod tests {
    //! Unit tests covering Display formatting, conversions, retry-strategy
    //! selection, and user-facing descriptions for `LlmError`.
    use super::*;

    // --- Display / thiserror formatting ---

    #[test]
    fn test_llm_error_display() {
        let error = LlmError::ApiError("something went wrong".to_string());
        assert_eq!(error.to_string(), "API error: something went wrong");

        let error = LlmError::RateLimited("too many requests".to_string());
        assert_eq!(error.to_string(), "Rate limit exceeded: too many requests");

        let error = LlmError::InvalidRequest("bad params".to_string());
        assert_eq!(error.to_string(), "Invalid request: bad params");
    }

    #[test]
    fn test_llm_error_auth() {
        let error = LlmError::AuthError("invalid key".to_string());
        assert_eq!(error.to_string(), "Authentication error: invalid key");
    }

    #[test]
    fn test_llm_error_token_limit() {
        let error = LlmError::TokenLimitExceeded {
            max: 4096,
            got: 5000,
        };
        assert_eq!(
            error.to_string(),
            "Token limit exceeded: max 4096, got 5000"
        );
    }

    #[test]
    fn test_llm_error_model_not_found() {
        let error = LlmError::ModelNotFound("gpt-5-turbo".to_string());
        assert_eq!(error.to_string(), "Model not found: gpt-5-turbo");
    }

    #[test]
    fn test_llm_error_network() {
        let error = LlmError::NetworkError("connection refused".to_string());
        assert_eq!(error.to_string(), "Network error: connection refused");
    }

    #[test]
    fn test_llm_error_config() {
        let error = LlmError::ConfigError("missing api key".to_string());
        assert_eq!(error.to_string(), "Configuration error: missing api key");
    }

    #[test]
    fn test_llm_error_provider() {
        let error = LlmError::ProviderError("openai specific error".to_string());
        assert_eq!(error.to_string(), "Provider error: openai specific error");
    }

    #[test]
    fn test_llm_error_timeout() {
        let error = LlmError::Timeout;
        assert_eq!(error.to_string(), "Request timed out");
    }

    #[test]
    fn test_llm_error_not_supported() {
        let error = LlmError::NotSupported("function calling".to_string());
        assert_eq!(error.to_string(), "Not supported: function calling");
    }

    #[test]
    fn test_llm_error_unknown() {
        let error = LlmError::Unknown("mystery error".to_string());
        assert_eq!(error.to_string(), "Unknown error: mystery error");
    }

    #[test]
    fn test_llm_error_debug() {
        let error = LlmError::ApiError("test".to_string());
        let debug = format!("{:?}", error);
        assert!(debug.contains("ApiError"));
        assert!(debug.contains("test"));
    }

    // --- Conversions ---

    #[test]
    fn test_llm_error_from_serde_json() {
        let json_str = "not json at all";
        let json_err: serde_json::Error =
            serde_json::from_str::<serde_json::Value>(json_str).unwrap_err();
        let llm_err: LlmError = json_err.into();
        assert!(matches!(llm_err, LlmError::SerializationError(_)));
    }

    // --- Retry-strategy selection ---

    #[test]
    fn test_network_error_retry_strategy() {
        let error = LlmError::NetworkError("connection failed".to_string());
        let strategy = error.retry_strategy();

        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 5);
            }
            _ => panic!("Expected ExponentialBackoff for network error"),
        }
        assert!(strategy.should_retry());
        assert!(error.is_recoverable());
    }

    #[test]
    fn test_timeout_retry_strategy() {
        let error = LlmError::Timeout;
        let strategy = error.retry_strategy();

        assert!(matches!(strategy, RetryStrategy::ExponentialBackoff { .. }));
        assert!(strategy.should_retry());
    }

    #[test]
    fn test_rate_limited_retry_strategy() {
        let error = LlmError::RateLimited("too many requests".to_string());
        let strategy = error.retry_strategy();

        match strategy {
            RetryStrategy::WaitAndRetry { wait } => {
                assert_eq!(wait, Duration::from_secs(60));
            }
            _ => panic!("Expected WaitAndRetry for rate limit"),
        }
        assert!(strategy.should_retry());
    }

    #[test]
    fn test_token_limit_reduce_context_strategy() {
        let error = LlmError::TokenLimitExceeded {
            max: 4096,
            got: 5000,
        };
        let strategy = error.retry_strategy();

        assert!(matches!(strategy, RetryStrategy::ReduceContext));
        assert!(strategy.should_retry());
    }

    #[test]
    fn test_auth_error_no_retry() {
        let error = LlmError::AuthError("invalid key".to_string());
        let strategy = error.retry_strategy();

        assert!(matches!(strategy, RetryStrategy::NoRetry));
        assert!(!strategy.should_retry());
        assert!(!error.is_recoverable());
    }

    #[test]
    fn test_invalid_request_no_retry() {
        let error = LlmError::InvalidRequest("bad params".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
    }

    #[test]
    fn test_model_not_found_no_retry() {
        let error = LlmError::ModelNotFound("gpt-5".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
    }

    // --- User-facing descriptions ---

    #[test]
    fn test_user_description_network() {
        let error = LlmError::NetworkError("connection refused".to_string());
        let desc = error.user_description();
        assert!(desc.contains("internet connection"));
    }

    #[test]
    fn test_user_description_auth() {
        let error = LlmError::AuthError("invalid".to_string());
        let desc = error.user_description();
        assert!(desc.contains("API key"));
    }

    #[test]
    fn test_user_description_token_limit() {
        let error = LlmError::TokenLimitExceeded {
            max: 4096,
            got: 5000,
        };
        let desc = error.user_description();
        assert!(desc.contains("5000/4096"));
        assert!(desc.contains("Reducing"));
    }

    #[test]
    fn test_retry_strategy_equality() {
        let s1 = RetryStrategy::network_backoff();
        let s2 = RetryStrategy::network_backoff();
        assert_eq!(s1, s2);

        let s3 = RetryStrategy::NoRetry;
        assert_ne!(s1, s3);
    }

    #[test]
    fn test_user_description_timeout() {
        let error = LlmError::Timeout;
        let desc = error.user_description();
        assert!(desc.contains("timed out"));
    }

    #[test]
    fn test_user_description_rate_limited() {
        let error = LlmError::RateLimited("slow down".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Rate limited"));
    }

    #[test]
    fn test_user_description_model_not_found() {
        let error = LlmError::ModelNotFound("gpt-5".to_string());
        let desc = error.user_description();
        assert!(desc.contains("gpt-5"));
        assert!(desc.contains("not found"));
    }

    #[test]
    fn test_user_description_not_supported() {
        let error = LlmError::NotSupported("streaming".to_string());
        let desc = error.user_description();
        assert!(desc.contains("streaming"));
        assert!(desc.contains("not supported"));
    }

    #[test]
    fn test_user_description_unknown() {
        let error = LlmError::Unknown("mystery".to_string());
        let desc = error.user_description();
        assert!(desc.contains("mystery"));
    }

    #[test]
    fn test_user_description_api_error() {
        let error = LlmError::ApiError("server crashed".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Retrying"));
    }

    #[test]
    fn test_user_description_provider_error() {
        let error = LlmError::ProviderError("internal failure".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Retrying"));
    }

    #[test]
    fn test_user_description_serialization() {
        let json_err = serde_json::from_str::<serde_json::Value>("bad").unwrap_err();
        let error = LlmError::SerializationError(json_err);
        let desc = error.user_description();
        assert!(desc.contains("parse"));
    }

    #[test]
    fn test_user_description_config() {
        let error = LlmError::ConfigError("missing field".to_string());
        let desc = error.user_description();
        assert!(desc.contains("Configuration"));
    }

    #[test]
    fn test_user_description_invalid_request() {
        let error = LlmError::InvalidRequest("empty prompt".to_string());
        let desc = error.user_description();
        assert!(desc.contains("empty prompt"));
    }

    // --- 5xx bucketing and fallback strategies ---

    #[test]
    fn test_api_error_500_server_backoff() {
        let error = LlmError::ApiError("HTTP 500 internal server error".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 3);
            }
            _ => panic!("Expected ExponentialBackoff for 500 error"),
        }
    }

    #[test]
    fn test_api_error_502_server_backoff() {
        let error = LlmError::ApiError("502 bad gateway".to_string());
        assert!(matches!(
            error.retry_strategy(),
            RetryStrategy::ExponentialBackoff { .. }
        ));
    }

    #[test]
    fn test_api_error_503_server_backoff() {
        let error = LlmError::ApiError("503 service unavailable".to_string());
        assert!(matches!(
            error.retry_strategy(),
            RetryStrategy::ExponentialBackoff { .. }
        ));
    }

    #[test]
    fn test_provider_error_server_backoff() {
        let error = LlmError::ProviderError("internal issue".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff {
                base_delay,
                max_delay,
                max_attempts,
            } => {
                assert_eq!(base_delay, Duration::from_secs(1));
                assert_eq!(max_delay, Duration::from_secs(60));
                assert_eq!(max_attempts, 3);
            }
            _ => panic!("Expected server_backoff for ProviderError"),
        }
    }

    #[test]
    fn test_unknown_error_retry_strategy() {
        let error = LlmError::Unknown("something".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 2);
            }
            _ => panic!("Expected ExponentialBackoff for Unknown"),
        }
    }

    #[test]
    fn test_serialization_error_retry_strategy() {
        let json_err = serde_json::from_str::<serde_json::Value>("bad").unwrap_err();
        let error = LlmError::SerializationError(json_err);
        let strategy = error.retry_strategy();
        assert!(matches!(strategy, RetryStrategy::ExponentialBackoff { .. }));
    }

    #[test]
    fn test_api_error_non_5xx_retry_strategy() {
        let error = LlmError::ApiError("generic error".to_string());
        let strategy = error.retry_strategy();
        match strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => {
                assert_eq!(max_attempts, 2);
            }
            _ => panic!("Expected ExponentialBackoff for generic ApiError"),
        }
    }

    #[test]
    fn test_config_error_no_retry() {
        let error = LlmError::ConfigError("bad config".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
        assert!(!error.is_recoverable());
    }

    #[test]
    fn test_not_supported_no_retry() {
        let error = LlmError::NotSupported("embeddings".to_string());
        assert!(matches!(error.retry_strategy(), RetryStrategy::NoRetry));
        assert!(!error.is_recoverable());
    }

    // --- Backoff presets and should_retry ---

    #[test]
    fn test_server_backoff_values() {
        let strategy = RetryStrategy::server_backoff();
        match strategy {
            RetryStrategy::ExponentialBackoff {
                base_delay,
                max_delay,
                max_attempts,
            } => {
                assert_eq!(base_delay, Duration::from_secs(1));
                assert_eq!(max_delay, Duration::from_secs(60));
                assert_eq!(max_attempts, 3);
            }
            _ => panic!("Expected ExponentialBackoff"),
        }
    }

    #[test]
    fn test_network_backoff_values() {
        let strategy = RetryStrategy::network_backoff();
        match strategy {
            RetryStrategy::ExponentialBackoff {
                base_delay,
                max_delay,
                max_attempts,
            } => {
                assert_eq!(base_delay, Duration::from_millis(125));
                assert_eq!(max_delay, Duration::from_secs(30));
                assert_eq!(max_attempts, 5);
            }
            _ => panic!("Expected ExponentialBackoff"),
        }
    }

    #[test]
    fn test_reduce_context_should_retry() {
        let strategy = RetryStrategy::ReduceContext;
        assert!(strategy.should_retry());
    }

    #[test]
    fn test_wait_and_retry_should_retry() {
        let strategy = RetryStrategy::WaitAndRetry {
            wait: Duration::from_secs(1),
        };
        assert!(strategy.should_retry());
    }

    // --- is_recoverable ---

    #[test]
    fn test_is_recoverable_network() {
        assert!(LlmError::NetworkError("fail".to_string()).is_recoverable());
    }

    #[test]
    fn test_is_recoverable_timeout() {
        assert!(LlmError::Timeout.is_recoverable());
    }

    #[test]
    fn test_is_recoverable_rate_limited() {
        assert!(LlmError::RateLimited("wait".to_string()).is_recoverable());
    }

    #[test]
    fn test_is_not_recoverable_invalid_request() {
        assert!(!LlmError::InvalidRequest("bad".to_string()).is_recoverable());
    }

    #[test]
    fn test_is_not_recoverable_model_not_found() {
        assert!(!LlmError::ModelNotFound("x".to_string()).is_recoverable());
    }
}