1use std::fmt;
23use std::time::Duration;
24
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
28use crate::types::ToolCallId;
29
30pub type Result<T, E = Error> = std::result::Result<T, E>;
45
46#[derive(Debug, Error)]
80pub enum Error {
81 #[error("provider error: {kind}")]
83 Provider {
84 kind: ProviderErrorKind,
85 suggestion: Option<String>,
87 },
88
89 #[error("tool `{name}` failed: {message}")]
92 Tool {
93 name: String,
94 call_id: ToolCallId,
95 message: String,
96 #[source]
97 source: Option<Box<dyn std::error::Error + Send + Sync>>,
98 },
99
100 #[error("context overflow: {used}/{limit} tokens")]
103 ContextOverflow { used: usize, limit: usize },
104
105 #[error("config error: {message}")]
107 Config { message: String },
108
109 #[error("cancelled")]
111 Cancelled,
112
113 #[error("{0}")]
115 Internal(#[from] Box<dyn std::error::Error + Send + Sync>),
116}
117
118impl Error {
119 pub fn tool(
130 name: impl Into<String>,
131 call_id: ToolCallId,
132 message: impl Into<String>,
133 ) -> Self {
134 Self::Tool {
135 name: name.into(),
136 call_id,
137 message: message.into(),
138 source: None,
139 }
140 }
141
142 pub fn tool_with_source(
144 name: impl Into<String>,
145 call_id: ToolCallId,
146 message: impl Into<String>,
147 source: impl std::error::Error + Send + Sync + 'static,
148 ) -> Self {
149 Self::Tool {
150 name: name.into(),
151 call_id,
152 message: message.into(),
153 source: Some(Box::new(source)),
154 }
155 }
156
157 pub fn provider(kind: ProviderErrorKind, suggestion: impl Into<String>) -> Self {
173 Self::Provider {
174 kind,
175 suggestion: Some(suggestion.into()),
176 }
177 }
178
179 pub fn retryable(&self) -> bool {
183 match self {
184 Self::Provider { kind, .. } => kind.retryable(),
185 _ => false,
186 }
187 }
188
189 pub fn retry_after(&self) -> Option<Duration> {
194 match self {
195 Self::Provider { kind, .. } => kind.retry_after(),
196 _ => None,
197 }
198 }
199}
200
201#[derive(Debug, Error, Clone, Serialize, Deserialize)]
210pub enum ProviderErrorKind {
211 #[error("authentication failed: {message} ({kind})")]
213 Authentication {
214 message: String,
215 kind: AuthErrorKind,
216 },
217
218 #[error("rate limited: {message}")]
220 RateLimit {
221 message: String,
222 #[serde(with = "option_duration_millis")]
224 retry_after: Option<Duration>,
225 },
226
227 #[error("quota exceeded: {message}")]
229 QuotaExceeded { message: String },
230
231 #[error("invalid request: {message}")]
233 InvalidRequest { message: String },
234
235 #[error("content filtered: {message}")]
237 ContentFiltered { message: String },
238
239 #[error("provider internal error ({status}): {message}")]
241 ServerError {
242 message: String,
243 status: u16,
244 #[serde(with = "option_duration_millis")]
245 retry_after: Option<Duration>,
246 },
247
248 #[error("transport error: {message}")]
250 Transport { message: String },
251
252 #[error("invalid response: {message}")]
254 InvalidResponse { message: String },
255
256 #[error("unknown provider error: {message}")]
258 Unknown { message: String },
259}
260
261impl ProviderErrorKind {
262 pub fn retryable(&self) -> bool {
267 matches!(self, Self::RateLimit { .. } | Self::ServerError { .. })
268 }
269
270 pub fn retry_after(&self) -> Option<Duration> {
277 match self {
278 Self::RateLimit { retry_after, .. } => {
279 retry_after.or(Some(Duration::from_secs(30)))
280 }
281 Self::ServerError { retry_after, .. } => {
282 retry_after.or(Some(Duration::from_secs(20)))
283 }
284 _ => None,
285 }
286 }
287}
288
/// Finer-grained reason attached to a [`ProviderErrorKind::Authentication`] failure.
///
/// Serialized in `snake_case` (e.g. `insufficient_permissions`); the `Display`
/// impl produces the same strings. Variant meanings are name-derived; confirm
/// against the code that constructs them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
    /// Wire/display name `missing` (presumably: no credential available).
    Missing,
    /// Wire/display name `invalid` (presumably: credential rejected).
    Invalid,
    /// Wire/display name `expired` (presumably: credential no longer valid).
    Expired,
    /// Wire/display name `insufficient_permissions`.
    InsufficientPermissions,
    /// Wire/display name `unknown`; reason not classified.
    Unknown,
}
313
314impl fmt::Display for AuthErrorKind {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 match self {
317 Self::Missing => write!(f, "missing"),
318 Self::Invalid => write!(f, "invalid"),
319 Self::Expired => write!(f, "expired"),
320 Self::InsufficientPermissions => write!(f, "insufficient_permissions"),
321 Self::Unknown => write!(f, "unknown"),
322 }
323 }
324}
325
326mod option_duration_millis {
332 use serde::{Deserialize, Deserializer, Serialize, Serializer};
333 use std::time::Duration;
334
335 pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
336 where
337 S: Serializer,
338 {
339 match value {
340 Some(d) => d.as_millis().serialize(serializer),
341 None => serializer.serialize_none(),
342 }
343 }
344
345 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
346 where
347 D: Deserializer<'de>,
348 {
349 let opt: Option<u64> = Option::deserialize(deserializer)?;
350 Ok(opt.map(Duration::from_millis))
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    // --- Error constructors ---------------------------------------------------

    #[test]
    fn test_error_tool_construction() {
        let err = Error::tool("read_file", ToolCallId::new("call_1"), "not found");
        match &err {
            Error::Tool {
                name,
                call_id,
                message,
                source,
            } => {
                assert_eq!(name, "read_file");
                assert_eq!(call_id.as_str(), "call_1");
                assert_eq!(message, "not found");
                assert!(source.is_none());
            }
            _ => panic!("expected Error::Tool"),
        }
        // Without an attached cause the error chain ends here.
        assert!(std::error::Error::source(&err).is_none());
    }

    #[test]
    fn test_error_tool_with_source() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "no such file");
        let err = Error::tool_with_source(
            "read_file",
            ToolCallId::new("call_2"),
            "failed",
            io_err,
        );
        match &err {
            Error::Tool { source, .. } => {
                assert!(source.is_some());
            }
            _ => panic!("expected Error::Tool"),
        }
        // The cause must be reachable through the standard error chain, not just stored.
        let cause = std::error::Error::source(&err).expect("tool error should expose its cause");
        assert_eq!(cause.to_string(), "no such file");
        assert!(cause.downcast_ref::<std::io::Error>().is_some());
    }

    #[test]
    fn test_error_provider_construction() {
        let err = Error::provider(
            ProviderErrorKind::Authentication {
                message: "bad key".into(),
                kind: AuthErrorKind::Invalid,
            },
            "run /login",
        );
        match &err {
            Error::Provider { kind, suggestion } => {
                assert!(matches!(kind, ProviderErrorKind::Authentication { .. }));
                assert_eq!(suggestion.as_deref(), Some("run /login"));
            }
            _ => panic!("expected Error::Provider"),
        }
    }

    #[test]
    fn test_internal_error_from_boxed_error() {
        let inner: Box<dyn std::error::Error + Send + Sync> = "boom".into();
        let err = Error::from(inner);
        assert!(matches!(err, Error::Internal(_)));
        // Display shows the wrapped message once, with no extra prefix.
        assert_eq!(err.to_string(), "boom");
        assert!(!err.retryable());
        assert!(err.retry_after().is_none());
    }

    // --- Non-provider errors --------------------------------------------------

    #[test]
    fn test_cancelled_not_retryable() {
        let err = Error::Cancelled;
        assert!(!err.retryable());
        assert!(err.retry_after().is_none());
    }

    #[test]
    fn test_context_overflow_display() {
        let err = Error::ContextOverflow {
            used: 130000,
            limit: 128000,
        };
        assert_eq!(err.to_string(), "context overflow: 130000/128000 tokens");
    }

    #[test]
    fn test_config_error_display() {
        let err = Error::Config {
            message: "missing api_key".into(),
        };
        assert!(err.to_string().contains("missing api_key"));
    }

    // --- ProviderErrorKind retry policy ---------------------------------------

    #[test]
    fn test_rate_limit_retryable() {
        let kind = ProviderErrorKind::RateLimit {
            message: "429".into(),
            retry_after: Some(Duration::from_secs(60)),
        };
        assert!(kind.retryable());
        assert_eq!(kind.retry_after(), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_rate_limit_default_retry_after() {
        let kind = ProviderErrorKind::RateLimit {
            message: "slow down".into(),
            retry_after: None,
        };
        assert_eq!(kind.retry_after(), Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_server_error_retryable() {
        let kind = ProviderErrorKind::ServerError {
            message: "internal".into(),
            status: 500,
            retry_after: None,
        };
        assert!(kind.retryable());
        assert_eq!(kind.retry_after(), Some(Duration::from_secs(20)));
    }

    #[test]
    fn test_server_error_explicit_retry_after() {
        // A provider-supplied hint takes precedence over the 20s fallback.
        let kind = ProviderErrorKind::ServerError {
            message: "busy".into(),
            status: 503,
            retry_after: Some(Duration::from_secs(2)),
        };
        assert!(kind.retryable());
        assert_eq!(kind.retry_after(), Some(Duration::from_secs(2)));
    }

    #[test]
    fn test_authentication_not_retryable() {
        let kind = ProviderErrorKind::Authentication {
            message: "invalid".into(),
            kind: AuthErrorKind::Invalid,
        };
        assert!(!kind.retryable());
        assert!(kind.retry_after().is_none());
    }

    #[test]
    fn test_quota_exceeded_not_retryable() {
        let kind = ProviderErrorKind::QuotaExceeded {
            message: "out of credits".into(),
        };
        assert!(!kind.retryable());
    }

    #[test]
    fn test_invalid_request_not_retryable() {
        let kind = ProviderErrorKind::InvalidRequest {
            message: "model not found".into(),
        };
        assert!(!kind.retryable());
    }

    #[test]
    fn test_content_filtered_not_retryable() {
        let kind = ProviderErrorKind::ContentFiltered {
            message: "blocked".into(),
        };
        assert!(!kind.retryable());
    }

    #[test]
    fn test_transport_not_retryable() {
        let kind = ProviderErrorKind::Transport {
            message: "dns failed".into(),
        };
        assert!(!kind.retryable());
    }

    #[test]
    fn test_invalid_response_not_retryable() {
        let kind = ProviderErrorKind::InvalidResponse {
            message: "bad json".into(),
        };
        assert!(!kind.retryable());
    }

    #[test]
    fn test_unknown_not_retryable() {
        let kind = ProviderErrorKind::Unknown {
            message: "???".into(),
        };
        assert!(!kind.retryable());
    }

    // --- Serde ----------------------------------------------------------------

    #[test]
    fn test_provider_error_kind_serde_roundtrip() {
        let kind = ProviderErrorKind::RateLimit {
            message: "429 too many".into(),
            retry_after: Some(Duration::from_millis(5000)),
        };
        let json = serde_json::to_string(&kind).unwrap();
        // Wire format: retry_after is a plain integer count of milliseconds.
        assert!(json.contains("\"retry_after\":5000"), "unexpected json: {json}");
        let restored: ProviderErrorKind = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            restored,
            ProviderErrorKind::RateLimit {
                retry_after: Some(d),
                ..
            } if d == Duration::from_millis(5000)
        ));
    }

    #[test]
    fn test_server_error_serde_roundtrip() {
        let kind = ProviderErrorKind::ServerError {
            message: "overloaded".into(),
            status: 529,
            retry_after: None,
        };
        let json = serde_json::to_string(&kind).unwrap();
        let restored: ProviderErrorKind = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            restored,
            ProviderErrorKind::ServerError { status: 529, retry_after: None, .. }
        ));
    }

    // --- AuthErrorKind --------------------------------------------------------

    #[test]
    fn test_auth_error_kind_display() {
        assert_eq!(AuthErrorKind::Missing.to_string(), "missing");
        assert_eq!(AuthErrorKind::Invalid.to_string(), "invalid");
        assert_eq!(AuthErrorKind::Expired.to_string(), "expired");
        assert_eq!(
            AuthErrorKind::InsufficientPermissions.to_string(),
            "insufficient_permissions"
        );
        assert_eq!(AuthErrorKind::Unknown.to_string(), "unknown");
    }

    #[test]
    fn test_auth_error_kind_serde_roundtrip() {
        for kind in [
            AuthErrorKind::Missing,
            AuthErrorKind::Invalid,
            AuthErrorKind::Expired,
            AuthErrorKind::InsufficientPermissions,
            AuthErrorKind::Unknown,
        ] {
            let json = serde_json::to_string(&kind).unwrap();
            let restored: AuthErrorKind = serde_json::from_str(&json).unwrap();
            assert_eq!(kind, restored);
        }
    }

    // --- Error-level delegation -----------------------------------------------

    #[test]
    fn test_error_retryable_delegates() {
        let retryable_err = Error::Provider {
            kind: ProviderErrorKind::RateLimit {
                message: "wait".into(),
                retry_after: Some(Duration::from_secs(10)),
            },
            suggestion: None,
        };
        assert!(retryable_err.retryable());
        assert_eq!(retryable_err.retry_after(), Some(Duration::from_secs(10)));

        let non_retryable_err = Error::Provider {
            kind: ProviderErrorKind::Authentication {
                message: "bad".into(),
                kind: AuthErrorKind::Invalid,
            },
            suggestion: None,
        };
        assert!(!non_retryable_err.retryable());
        assert!(non_retryable_err.retry_after().is_none());
    }

    #[test]
    fn test_non_provider_errors_not_retryable() {
        assert!(!Error::Cancelled.retryable());
        assert!(!Error::Config { message: "x".into() }.retryable());
        assert!(!Error::ContextOverflow { used: 1, limit: 1 }.retryable());
        assert!(!Error::tool("t", ToolCallId::new("c"), "m").retryable());
    }

    // --- Display formats ------------------------------------------------------

    #[test]
    fn test_error_display_formats() {
        let tool_err = Error::tool("bash", ToolCallId::new("c1"), "permission denied");
        assert_eq!(
            tool_err.to_string(),
            "tool `bash` failed: permission denied"
        );

        let provider_err = Error::Provider {
            kind: ProviderErrorKind::Transport {
                message: "connection reset".into(),
            },
            suggestion: Some("check your network".into()),
        };
        // The suggestion is intentionally not part of Display.
        assert_eq!(
            provider_err.to_string(),
            "provider error: transport error: connection reset"
        );

        let auth_err = Error::provider(
            ProviderErrorKind::Authentication {
                message: "bad key".into(),
                kind: AuthErrorKind::Expired,
            },
            "run /login",
        );
        assert_eq!(
            auth_err.to_string(),
            "provider error: authentication failed: bad key (expired)"
        );

        let server_kind = ProviderErrorKind::ServerError {
            message: "overloaded".into(),
            status: 503,
            retry_after: None,
        };
        assert_eq!(
            server_kind.to_string(),
            "provider internal error (503): overloaded"
        );
    }
}