1use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use std::collections::HashMap;
6use std::fmt;
7use std::time::Duration;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub enum ErrorComponent {
20 Agent,
22 Model,
24 Tool,
26 Session,
28 Artifact,
30 Memory,
32 Graph,
34 Realtime,
36 Code,
38 Server,
40 Auth,
42 Guardrail,
44 Eval,
46 Deploy,
48}
49
50impl fmt::Display for ErrorComponent {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 let s = match self {
53 Self::Agent => "agent",
54 Self::Model => "model",
55 Self::Tool => "tool",
56 Self::Session => "session",
57 Self::Artifact => "artifact",
58 Self::Memory => "memory",
59 Self::Graph => "graph",
60 Self::Realtime => "realtime",
61 Self::Code => "code",
62 Self::Server => "server",
63 Self::Auth => "auth",
64 Self::Guardrail => "guardrail",
65 Self::Eval => "eval",
66 Self::Deploy => "deploy",
67 };
68 f.write_str(s)
69 }
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
86#[serde(rename_all = "snake_case")]
87pub enum ErrorCategory {
88 InvalidInput,
90 Unauthorized,
92 Forbidden,
94 NotFound,
96 RateLimited,
98 Timeout,
100 Unavailable,
102 Cancelled,
104 Internal,
106 Unsupported,
108}
109
110impl fmt::Display for ErrorCategory {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 let s = match self {
113 Self::InvalidInput => "invalid_input",
114 Self::Unauthorized => "unauthorized",
115 Self::Forbidden => "forbidden",
116 Self::NotFound => "not_found",
117 Self::RateLimited => "rate_limited",
118 Self::Timeout => "timeout",
119 Self::Unavailable => "unavailable",
120 Self::Cancelled => "cancelled",
121 Self::Internal => "internal",
122 Self::Unsupported => "unsupported",
123 };
124 f.write_str(s)
125 }
126}
127
128#[derive(Debug, Clone, Default, Serialize, Deserialize)]
130pub struct RetryHint {
131 pub should_retry: bool,
133 #[serde(default, skip_serializing_if = "Option::is_none")]
135 pub retry_after_ms: Option<u64>,
136 #[serde(default, skip_serializing_if = "Option::is_none")]
138 pub max_attempts: Option<u32>,
139}
140
141impl RetryHint {
142 pub fn for_category(category: ErrorCategory) -> Self {
144 match category {
145 ErrorCategory::RateLimited | ErrorCategory::Unavailable | ErrorCategory::Timeout => {
146 Self { should_retry: true, ..Default::default() }
147 }
148 _ => Self::default(),
149 }
150 }
151
152 pub fn retry_after(&self) -> Option<Duration> {
154 self.retry_after_ms.map(Duration::from_millis)
155 }
156
157 pub fn with_retry_after(mut self, duration: Duration) -> Self {
159 self.retry_after_ms = Some(duration.as_millis() as u64);
160 self
161 }
162}
163
164#[derive(Debug, Clone, Default, Serialize, Deserialize)]
166pub struct ErrorDetails {
167 #[serde(default, skip_serializing_if = "Option::is_none")]
169 pub upstream_status_code: Option<u16>,
170 #[serde(default, skip_serializing_if = "Option::is_none")]
172 pub request_id: Option<String>,
173 #[serde(default, skip_serializing_if = "Option::is_none")]
175 pub provider: Option<String>,
176 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
178 pub metadata: HashMap<String, Value>,
179}
180
181pub struct AdkError {
211 pub component: ErrorComponent,
213 pub category: ErrorCategory,
215 pub code: &'static str,
217 pub message: String,
219 pub retry: RetryHint,
221 pub details: Box<ErrorDetails>,
223 source: Option<Box<dyn std::error::Error + Send + Sync>>,
224}
225
226impl fmt::Debug for AdkError {
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 let mut d = f.debug_struct("AdkError");
229 d.field("component", &self.component)
230 .field("category", &self.category)
231 .field("code", &self.code)
232 .field("message", &self.message)
233 .field("retry", &self.retry)
234 .field("details", &self.details);
235 if let Some(src) = &self.source {
236 d.field("source", &format_args!("{src}"));
237 }
238 d.finish()
239 }
240}
241
242impl fmt::Display for AdkError {
243 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244 write!(f, "{}.{}: {}", self.component, self.category, self.message)
245 }
246}
247
248impl std::error::Error for AdkError {
249 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
250 self.source.as_ref().map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
251 }
252}
253
254const _: () = {
255 fn _assert_send<T: Send>() {}
256 fn _assert_sync<T: Sync>() {}
257 fn _assertions() {
258 _assert_send::<AdkError>();
259 _assert_sync::<AdkError>();
260 }
261};
262
263impl AdkError {
264 pub fn new(
266 component: ErrorComponent,
267 category: ErrorCategory,
268 code: &'static str,
269 message: impl Into<String>,
270 ) -> Self {
271 Self {
272 component,
273 category,
274 code,
275 message: message.into(),
276 retry: RetryHint::for_category(category),
277 details: Box::new(ErrorDetails::default()),
278 source: None,
279 }
280 }
281
282 pub fn with_source(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
284 self.source = Some(Box::new(source));
285 self
286 }
287
288 pub fn with_retry(mut self, retry: RetryHint) -> Self {
290 self.retry = retry;
291 self
292 }
293
294 pub fn with_details(mut self, details: ErrorDetails) -> Self {
296 self.details = Box::new(details);
297 self
298 }
299
300 pub fn with_upstream_status(mut self, status_code: u16) -> Self {
302 self.details.upstream_status_code = Some(status_code);
303 self
304 }
305
306 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
308 self.details.request_id = Some(request_id.into());
309 self
310 }
311
312 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
314 self.details.provider = Some(provider.into());
315 self
316 }
317}
318
319impl AdkError {
320 pub fn not_found(
322 component: ErrorComponent,
323 code: &'static str,
324 message: impl Into<String>,
325 ) -> Self {
326 Self::new(component, ErrorCategory::NotFound, code, message)
327 }
328
329 pub fn rate_limited(
331 component: ErrorComponent,
332 code: &'static str,
333 message: impl Into<String>,
334 ) -> Self {
335 Self::new(component, ErrorCategory::RateLimited, code, message)
336 }
337
338 pub fn unauthorized(
340 component: ErrorComponent,
341 code: &'static str,
342 message: impl Into<String>,
343 ) -> Self {
344 Self::new(component, ErrorCategory::Unauthorized, code, message)
345 }
346
347 pub fn internal(
349 component: ErrorComponent,
350 code: &'static str,
351 message: impl Into<String>,
352 ) -> Self {
353 Self::new(component, ErrorCategory::Internal, code, message)
354 }
355
356 pub fn timeout(
358 component: ErrorComponent,
359 code: &'static str,
360 message: impl Into<String>,
361 ) -> Self {
362 Self::new(component, ErrorCategory::Timeout, code, message)
363 }
364
365 pub fn unavailable(
367 component: ErrorComponent,
368 code: &'static str,
369 message: impl Into<String>,
370 ) -> Self {
371 Self::new(component, ErrorCategory::Unavailable, code, message)
372 }
373}
374
375impl AdkError {
376 pub fn agent(message: impl Into<String>) -> Self {
378 Self::new(ErrorComponent::Agent, ErrorCategory::Internal, "agent.legacy", message)
379 }
380
381 pub fn model(message: impl Into<String>) -> Self {
383 Self::new(ErrorComponent::Model, ErrorCategory::Internal, "model.legacy", message)
384 }
385
386 pub fn tool(message: impl Into<String>) -> Self {
388 Self::new(ErrorComponent::Tool, ErrorCategory::Internal, "tool.legacy", message)
389 }
390
391 pub fn session(message: impl Into<String>) -> Self {
393 Self::new(ErrorComponent::Session, ErrorCategory::Internal, "session.legacy", message)
394 }
395
396 pub fn memory(message: impl Into<String>) -> Self {
398 Self::new(ErrorComponent::Memory, ErrorCategory::Internal, "memory.legacy", message)
399 }
400
401 pub fn config(message: impl Into<String>) -> Self {
403 Self::new(ErrorComponent::Server, ErrorCategory::InvalidInput, "config.legacy", message)
404 }
405
406 pub fn artifact(message: impl Into<String>) -> Self {
408 Self::new(ErrorComponent::Artifact, ErrorCategory::Internal, "artifact.legacy", message)
409 }
410}
411
412impl AdkError {
413 pub fn is_agent(&self) -> bool {
415 self.component == ErrorComponent::Agent
416 }
417 pub fn is_model(&self) -> bool {
419 self.component == ErrorComponent::Model
420 }
421 pub fn is_tool(&self) -> bool {
423 self.component == ErrorComponent::Tool
424 }
425 pub fn is_session(&self) -> bool {
427 self.component == ErrorComponent::Session
428 }
429 pub fn is_artifact(&self) -> bool {
431 self.component == ErrorComponent::Artifact
432 }
433 pub fn is_memory(&self) -> bool {
435 self.component == ErrorComponent::Memory
436 }
437 pub fn is_config(&self) -> bool {
439 self.code == "config.legacy"
440 }
441}
442
443impl AdkError {
444 pub fn is_retryable(&self) -> bool {
446 self.retry.should_retry
447 }
448 pub fn is_not_found(&self) -> bool {
450 self.category == ErrorCategory::NotFound
451 }
452 pub fn is_unauthorized(&self) -> bool {
454 self.category == ErrorCategory::Unauthorized
455 }
456 pub fn is_rate_limited(&self) -> bool {
458 self.category == ErrorCategory::RateLimited
459 }
460 pub fn is_timeout(&self) -> bool {
462 self.category == ErrorCategory::Timeout
463 }
464}
465
466impl AdkError {
467 #[allow(unreachable_patterns)]
469 pub fn http_status_code(&self) -> u16 {
470 match self.category {
471 ErrorCategory::InvalidInput => 400,
472 ErrorCategory::Unauthorized => 401,
473 ErrorCategory::Forbidden => 403,
474 ErrorCategory::NotFound => 404,
475 ErrorCategory::RateLimited => 429,
476 ErrorCategory::Timeout => 408,
477 ErrorCategory::Unavailable => 503,
478 ErrorCategory::Cancelled => 499,
479 ErrorCategory::Internal => 500,
480 ErrorCategory::Unsupported => 501,
481 _ => 500,
482 }
483 }
484}
485
486impl AdkError {
487 pub fn to_problem_json(&self) -> Value {
489 json!({
490 "error": {
491 "code": self.code,
492 "message": self.message,
493 "component": self.component,
494 "category": self.category,
495 "requestId": self.details.request_id,
496 "retryAfter": self.retry.retry_after_ms,
497 "upstreamStatusCode": self.details.upstream_status_code,
498 }
499 })
500 }
501}
502
503pub type Result<T> = std::result::Result<T, AdkError>;
505
#[cfg(test)]
mod tests {
    // Unit tests for the structured error type: construction, retry defaults, HTTP/JSON
    // mappings, legacy constructors, and agreement between `Display` and serde names.
    use super::*;

    #[test]
    fn test_new_sets_fields() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "too many requests",
        );
        assert_eq!(err.component, ErrorComponent::Model);
        assert_eq!(err.category, ErrorCategory::RateLimited);
        assert_eq!(err.code, "model.rate_limited");
        assert_eq!(err.message, "too many requests");
        assert!(err.retry.should_retry);
    }

    #[test]
    fn test_display_format() {
        let err = AdkError::new(
            ErrorComponent::Session,
            ErrorCategory::NotFound,
            "session.not_found",
            "session xyz not found",
        );
        assert_eq!(err.to_string(), "session.not_found: session xyz not found");
    }

    #[test]
    fn test_convenience_not_found() {
        let err = AdkError::not_found(ErrorComponent::Session, "session.not_found", "gone");
        assert_eq!(err.category, ErrorCategory::NotFound);
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_rate_limited() {
        let err = AdkError::rate_limited(ErrorComponent::Model, "model.rate_limited", "slow down");
        assert!(err.is_retryable());
        assert!(err.is_rate_limited());
    }

    #[test]
    fn test_convenience_unauthorized() {
        let err = AdkError::unauthorized(ErrorComponent::Auth, "auth.unauthorized", "bad token");
        assert!(err.is_unauthorized());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_internal() {
        let err = AdkError::internal(ErrorComponent::Agent, "agent.internal", "oops");
        assert_eq!(err.category, ErrorCategory::Internal);
    }

    #[test]
    fn test_convenience_timeout() {
        let err = AdkError::timeout(ErrorComponent::Model, "model.timeout", "timed out");
        assert!(err.is_timeout());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_convenience_unavailable() {
        let err = AdkError::unavailable(ErrorComponent::Model, "model.unavailable", "503");
        assert!(err.is_retryable());
    }

    #[test]
    fn test_backward_compat_agent() {
        let err = AdkError::agent("test error");
        assert!(err.is_agent());
        assert_eq!(err.code, "agent.legacy");
        assert_eq!(err.category, ErrorCategory::Internal);
        assert_eq!(err.to_string(), "agent.internal: test error");
    }

    #[test]
    fn test_backward_compat_model() {
        let err = AdkError::model("model fail");
        assert!(err.is_model());
        assert_eq!(err.code, "model.legacy");
    }

    #[test]
    fn test_backward_compat_tool() {
        let err = AdkError::tool("tool fail");
        assert!(err.is_tool());
        assert_eq!(err.code, "tool.legacy");
    }

    #[test]
    fn test_backward_compat_session() {
        let err = AdkError::session("session fail");
        assert!(err.is_session());
        assert_eq!(err.code, "session.legacy");
    }

    #[test]
    fn test_backward_compat_memory() {
        let err = AdkError::memory("memory fail");
        assert!(err.is_memory());
        assert_eq!(err.code, "memory.legacy");
    }

    #[test]
    fn test_backward_compat_artifact() {
        let err = AdkError::artifact("artifact fail");
        assert!(err.is_artifact());
        assert_eq!(err.code, "artifact.legacy");
    }

    #[test]
    fn test_backward_compat_config() {
        let err = AdkError::config("bad config");
        assert!(err.is_config());
        assert_eq!(err.code, "config.legacy");
        assert_eq!(err.component, ErrorComponent::Server);
        assert_eq!(err.category, ErrorCategory::InvalidInput);
    }

    #[test]
    fn test_backward_compat_codes_end_with_legacy() {
        let errors = [
            AdkError::agent("a"),
            AdkError::model("m"),
            AdkError::tool("t"),
            AdkError::session("s"),
            AdkError::memory("mem"),
            AdkError::config("c"),
            AdkError::artifact("art"),
        ];
        for err in &errors {
            assert!(err.code.ends_with(".legacy"), "code '{}' should end with .legacy", err.code);
        }
    }

    #[test]
    fn test_is_config_false_for_non_config() {
        assert!(!AdkError::agent("not config").is_config());
    }

    #[test]
    fn test_retryable_categories_default_true() {
        for cat in [ErrorCategory::RateLimited, ErrorCategory::Unavailable, ErrorCategory::Timeout]
        {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(err.is_retryable(), "expected is_retryable() == true for {cat}");
        }
    }

    #[test]
    fn test_retryable_override_to_false() {
        let err =
            AdkError::new(ErrorComponent::Model, ErrorCategory::RateLimited, "m.rl", "overridden")
                .with_retry(RetryHint { should_retry: false, ..Default::default() });
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_non_retryable_categories_default_false() {
        for cat in [
            ErrorCategory::InvalidInput,
            ErrorCategory::Unauthorized,
            ErrorCategory::Forbidden,
            ErrorCategory::NotFound,
            ErrorCategory::Cancelled,
            ErrorCategory::Internal,
            ErrorCategory::Unsupported,
        ] {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(!err.is_retryable(), "expected is_retryable() == false for {cat}");
        }
    }

    #[test]
    fn test_http_status_code_mapping() {
        let cases = [
            (ErrorCategory::InvalidInput, 400),
            (ErrorCategory::Unauthorized, 401),
            (ErrorCategory::Forbidden, 403),
            (ErrorCategory::NotFound, 404),
            (ErrorCategory::RateLimited, 429),
            (ErrorCategory::Timeout, 408),
            (ErrorCategory::Unavailable, 503),
            (ErrorCategory::Cancelled, 499),
            (ErrorCategory::Internal, 500),
            (ErrorCategory::Unsupported, 501),
        ];
        for (cat, expected) in &cases {
            let err = AdkError::new(ErrorComponent::Server, *cat, "test", "msg");
            assert_eq!(err.http_status_code(), *expected, "wrong status for {cat}");
        }
    }

    #[test]
    fn test_source_returns_some_when_set() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err = AdkError::new(ErrorComponent::Session, ErrorCategory::NotFound, "s.f", "missing")
            .with_source(io_err);
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_source_returns_none_when_not_set() {
        assert!(std::error::Error::source(&AdkError::agent("no source")).is_none());
    }

    #[test]
    fn test_retry_hint_for_category() {
        assert!(RetryHint::for_category(ErrorCategory::RateLimited).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Unavailable).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Timeout).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::Internal).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::NotFound).should_retry);
    }

    #[test]
    fn test_retry_hint_with_retry_after() {
        let hint = RetryHint::default().with_retry_after(Duration::from_secs(5));
        assert_eq!(hint.retry_after_ms, Some(5000));
        assert_eq!(hint.retry_after(), Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_to_problem_json() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "slow down",
        )
        .with_request_id("req-123")
        .with_upstream_status(429);
        let j = err.to_problem_json();
        let o = &j["error"];
        assert_eq!(o["code"], "model.rate_limited");
        assert_eq!(o["message"], "slow down");
        assert_eq!(o["component"], "model");
        assert_eq!(o["category"], "rate_limited");
        assert_eq!(o["requestId"], "req-123");
        assert_eq!(o["upstreamStatusCode"], 429);
    }

    #[test]
    fn test_to_problem_json_null_optionals() {
        let j = AdkError::agent("simple").to_problem_json();
        let o = &j["error"];
        assert!(o["requestId"].is_null());
        assert!(o["retryAfter"].is_null());
        assert!(o["upstreamStatusCode"].is_null());
    }

    #[test]
    fn test_to_problem_json_retry_after_is_millis() {
        // `retryAfter` carries `retry_after_ms` verbatim, not seconds.
        let hint = RetryHint::for_category(ErrorCategory::RateLimited)
            .with_retry_after(Duration::from_millis(1500));
        let err = AdkError::rate_limited(ErrorComponent::Model, "model.rate_limited", "wait")
            .with_retry(hint);
        let j = err.to_problem_json();
        assert_eq!(j["error"]["retryAfter"], 1500);
    }

    #[test]
    fn test_builder_chaining() {
        let err = AdkError::new(ErrorComponent::Model, ErrorCategory::Unavailable, "m.u", "down")
            .with_provider("openai")
            .with_request_id("req-456")
            .with_upstream_status(503)
            .with_retry(RetryHint {
                should_retry: true,
                retry_after_ms: Some(1000),
                max_attempts: Some(3),
            });
        assert_eq!(err.details.provider.as_deref(), Some("openai"));
        assert_eq!(err.details.request_id.as_deref(), Some("req-456"));
        assert_eq!(err.details.upstream_status_code, Some(503));
        assert!(err.is_retryable());
        assert_eq!(err.retry.retry_after_ms, Some(1000));
        assert_eq!(err.retry.max_attempts, Some(3));
    }

    #[test]
    fn test_error_component_display() {
        assert_eq!(ErrorComponent::Agent.to_string(), "agent");
        assert_eq!(ErrorComponent::Model.to_string(), "model");
        assert_eq!(ErrorComponent::Graph.to_string(), "graph");
        assert_eq!(ErrorComponent::Realtime.to_string(), "realtime");
        assert_eq!(ErrorComponent::Deploy.to_string(), "deploy");
    }

    #[test]
    fn test_error_category_display() {
        assert_eq!(ErrorCategory::InvalidInput.to_string(), "invalid_input");
        assert_eq!(ErrorCategory::RateLimited.to_string(), "rate_limited");
        assert_eq!(ErrorCategory::NotFound.to_string(), "not_found");
        assert_eq!(ErrorCategory::Internal.to_string(), "internal");
    }

    #[test]
    fn test_component_display_matches_serde() {
        // `Display` and serde's `rename_all = "snake_case"` are maintained separately;
        // check every variant so the two cannot drift apart unnoticed.
        let all = [
            ErrorComponent::Agent,
            ErrorComponent::Model,
            ErrorComponent::Tool,
            ErrorComponent::Session,
            ErrorComponent::Artifact,
            ErrorComponent::Memory,
            ErrorComponent::Graph,
            ErrorComponent::Realtime,
            ErrorComponent::Code,
            ErrorComponent::Server,
            ErrorComponent::Auth,
            ErrorComponent::Guardrail,
            ErrorComponent::Eval,
            ErrorComponent::Deploy,
        ];
        for component in all {
            let serialized = serde_json::to_value(component).unwrap();
            assert_eq!(serialized, Value::String(component.to_string()), "{component:?}");
            let back: ErrorComponent = serde_json::from_value(serialized).unwrap();
            assert_eq!(back, component);
        }
    }

    #[test]
    fn test_category_display_matches_serde() {
        let all = [
            ErrorCategory::InvalidInput,
            ErrorCategory::Unauthorized,
            ErrorCategory::Forbidden,
            ErrorCategory::NotFound,
            ErrorCategory::RateLimited,
            ErrorCategory::Timeout,
            ErrorCategory::Unavailable,
            ErrorCategory::Cancelled,
            ErrorCategory::Internal,
            ErrorCategory::Unsupported,
        ];
        for category in all {
            let serialized = serde_json::to_value(category).unwrap();
            assert_eq!(serialized, Value::String(category.to_string()), "{category:?}");
            let back: ErrorCategory = serde_json::from_value(serialized).unwrap();
            assert_eq!(back, category);
        }
    }

    #[test]
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_result_type() {
        let ok: Result<i32> = Ok(42);
        assert_eq!(ok.unwrap(), 42);
        let err: Result<i32> = Err(AdkError::config("invalid"));
        assert!(err.is_err());
    }

    #[test]
    fn test_with_details() {
        let d = ErrorDetails {
            upstream_status_code: Some(502),
            request_id: Some("abc".into()),
            provider: Some("gemini".into()),
            metadata: HashMap::new(),
        };
        let err = AdkError::agent("test").with_details(d);
        assert_eq!(err.details.upstream_status_code, Some(502));
        assert_eq!(err.details.request_id.as_deref(), Some("abc"));
        assert_eq!(err.details.provider.as_deref(), Some("gemini"));
    }

    #[test]
    fn test_debug_impl() {
        let s = format!("{:?}", AdkError::agent("debug test"));
        assert!(s.contains("AdkError"));
        assert!(s.contains("agent.legacy"));
    }

    #[test]
    fn test_debug_includes_source_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err = AdkError::internal(ErrorComponent::Artifact, "artifact.io", "read failed")
            .with_source(io_err);
        let s = format!("{err:?}");
        assert!(s.contains("source"));
        assert!(s.contains("file not found"));
    }

    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<AdkError>();
        assert_sync::<AdkError>();
    }
}