1use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use std::collections::HashMap;
6use std::fmt;
7use std::time::Duration;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19#[non_exhaustive]
20pub enum ErrorComponent {
21 Agent,
23 Model,
25 Tool,
27 Session,
29 Artifact,
31 Memory,
33 Graph,
35 Realtime,
37 Code,
39 Server,
41 Auth,
43 Guardrail,
45 Eval,
47 Deploy,
49}
50
51impl fmt::Display for ErrorComponent {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 let s = match self {
54 Self::Agent => "agent",
55 Self::Model => "model",
56 Self::Tool => "tool",
57 Self::Session => "session",
58 Self::Artifact => "artifact",
59 Self::Memory => "memory",
60 Self::Graph => "graph",
61 Self::Realtime => "realtime",
62 Self::Code => "code",
63 Self::Server => "server",
64 Self::Auth => "auth",
65 Self::Guardrail => "guardrail",
66 Self::Eval => "eval",
67 Self::Deploy => "deploy",
68 };
69 f.write_str(s)
70 }
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
87#[serde(rename_all = "snake_case")]
88#[non_exhaustive]
89pub enum ErrorCategory {
90 InvalidInput,
92 Unauthorized,
94 Forbidden,
96 NotFound,
98 RateLimited,
100 Timeout,
102 Unavailable,
104 Cancelled,
106 Internal,
108 Unsupported,
110}
111
112impl fmt::Display for ErrorCategory {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 let s = match self {
115 Self::InvalidInput => "invalid_input",
116 Self::Unauthorized => "unauthorized",
117 Self::Forbidden => "forbidden",
118 Self::NotFound => "not_found",
119 Self::RateLimited => "rate_limited",
120 Self::Timeout => "timeout",
121 Self::Unavailable => "unavailable",
122 Self::Cancelled => "cancelled",
123 Self::Internal => "internal",
124 Self::Unsupported => "unsupported",
125 };
126 f.write_str(s)
127 }
128}
129
130#[derive(Debug, Clone, Default, Serialize, Deserialize)]
132pub struct RetryHint {
133 pub should_retry: bool,
135 #[serde(default, skip_serializing_if = "Option::is_none")]
137 pub retry_after_ms: Option<u64>,
138 #[serde(default, skip_serializing_if = "Option::is_none")]
140 pub max_attempts: Option<u32>,
141}
142
143impl RetryHint {
144 pub fn for_category(category: ErrorCategory) -> Self {
146 match category {
147 ErrorCategory::RateLimited | ErrorCategory::Unavailable | ErrorCategory::Timeout => {
148 Self { should_retry: true, ..Default::default() }
149 }
150 _ => Self::default(),
151 }
152 }
153
154 pub fn retry_after(&self) -> Option<Duration> {
156 self.retry_after_ms.map(Duration::from_millis)
157 }
158
159 pub fn with_retry_after(mut self, duration: Duration) -> Self {
161 self.retry_after_ms = Some(duration.as_millis() as u64);
162 self
163 }
164}
165
166#[derive(Debug, Clone, Default, Serialize, Deserialize)]
168pub struct ErrorDetails {
169 #[serde(default, skip_serializing_if = "Option::is_none")]
171 pub upstream_status_code: Option<u16>,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
174 pub request_id: Option<String>,
175 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub provider: Option<String>,
178 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
180 pub metadata: HashMap<String, Value>,
181}
182
183pub struct AdkError {
213 pub component: ErrorComponent,
215 pub category: ErrorCategory,
217 pub code: &'static str,
219 pub message: String,
221 pub retry: RetryHint,
223 pub details: Box<ErrorDetails>,
225 source: Option<Box<dyn std::error::Error + Send + Sync>>,
226}
227
228impl fmt::Debug for AdkError {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 let mut d = f.debug_struct("AdkError");
231 d.field("component", &self.component)
232 .field("category", &self.category)
233 .field("code", &self.code)
234 .field("message", &self.message)
235 .field("retry", &self.retry)
236 .field("details", &self.details);
237 if let Some(src) = &self.source {
238 d.field("source", &format_args!("{src}"));
239 }
240 d.finish()
241 }
242}
243
244impl fmt::Display for AdkError {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 write!(f, "{}.{}: {}", self.component, self.category, self.message)
247 }
248}
249
250impl std::error::Error for AdkError {
251 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
252 self.source.as_ref().map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
253 }
254}
255
256const _: () = {
257 fn _assert_send<T: Send>() {}
258 fn _assert_sync<T: Sync>() {}
259 fn _assertions() {
260 _assert_send::<AdkError>();
261 _assert_sync::<AdkError>();
262 }
263};
264
265impl AdkError {
266 pub fn new(
268 component: ErrorComponent,
269 category: ErrorCategory,
270 code: &'static str,
271 message: impl Into<String>,
272 ) -> Self {
273 Self {
274 component,
275 category,
276 code,
277 message: message.into(),
278 retry: RetryHint::for_category(category),
279 details: Box::new(ErrorDetails::default()),
280 source: None,
281 }
282 }
283
284 pub fn with_source(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
286 self.source = Some(Box::new(source));
287 self
288 }
289
290 pub fn with_retry(mut self, retry: RetryHint) -> Self {
292 self.retry = retry;
293 self
294 }
295
296 pub fn with_details(mut self, details: ErrorDetails) -> Self {
298 self.details = Box::new(details);
299 self
300 }
301
302 pub fn with_upstream_status(mut self, status_code: u16) -> Self {
304 self.details.upstream_status_code = Some(status_code);
305 self
306 }
307
308 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
310 self.details.request_id = Some(request_id.into());
311 self
312 }
313
314 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
316 self.details.provider = Some(provider.into());
317 self
318 }
319}
320
321impl AdkError {
322 pub fn not_found(
324 component: ErrorComponent,
325 code: &'static str,
326 message: impl Into<String>,
327 ) -> Self {
328 Self::new(component, ErrorCategory::NotFound, code, message)
329 }
330
331 pub fn rate_limited(
333 component: ErrorComponent,
334 code: &'static str,
335 message: impl Into<String>,
336 ) -> Self {
337 Self::new(component, ErrorCategory::RateLimited, code, message)
338 }
339
340 pub fn unauthorized(
342 component: ErrorComponent,
343 code: &'static str,
344 message: impl Into<String>,
345 ) -> Self {
346 Self::new(component, ErrorCategory::Unauthorized, code, message)
347 }
348
349 pub fn internal(
351 component: ErrorComponent,
352 code: &'static str,
353 message: impl Into<String>,
354 ) -> Self {
355 Self::new(component, ErrorCategory::Internal, code, message)
356 }
357
358 pub fn timeout(
360 component: ErrorComponent,
361 code: &'static str,
362 message: impl Into<String>,
363 ) -> Self {
364 Self::new(component, ErrorCategory::Timeout, code, message)
365 }
366
367 pub fn unavailable(
369 component: ErrorComponent,
370 code: &'static str,
371 message: impl Into<String>,
372 ) -> Self {
373 Self::new(component, ErrorCategory::Unavailable, code, message)
374 }
375}
376
377impl AdkError {
378 pub fn agent(message: impl Into<String>) -> Self {
380 Self::new(ErrorComponent::Agent, ErrorCategory::Internal, "agent.legacy", message)
381 }
382
383 pub fn model(message: impl Into<String>) -> Self {
385 Self::new(ErrorComponent::Model, ErrorCategory::Internal, "model.legacy", message)
386 }
387
388 pub fn tool(message: impl Into<String>) -> Self {
390 Self::new(ErrorComponent::Tool, ErrorCategory::Internal, "tool.legacy", message)
391 }
392
393 pub fn session(message: impl Into<String>) -> Self {
395 Self::new(ErrorComponent::Session, ErrorCategory::Internal, "session.legacy", message)
396 }
397
398 pub fn memory(message: impl Into<String>) -> Self {
400 Self::new(ErrorComponent::Memory, ErrorCategory::Internal, "memory.legacy", message)
401 }
402
403 pub fn config(message: impl Into<String>) -> Self {
405 Self::new(ErrorComponent::Server, ErrorCategory::InvalidInput, "config.legacy", message)
406 }
407
408 pub fn artifact(message: impl Into<String>) -> Self {
410 Self::new(ErrorComponent::Artifact, ErrorCategory::Internal, "artifact.legacy", message)
411 }
412}
413
414impl AdkError {
415 pub fn is_agent(&self) -> bool {
417 self.component == ErrorComponent::Agent
418 }
419 pub fn is_model(&self) -> bool {
421 self.component == ErrorComponent::Model
422 }
423 pub fn is_tool(&self) -> bool {
425 self.component == ErrorComponent::Tool
426 }
427 pub fn is_session(&self) -> bool {
429 self.component == ErrorComponent::Session
430 }
431 pub fn is_artifact(&self) -> bool {
433 self.component == ErrorComponent::Artifact
434 }
435 pub fn is_memory(&self) -> bool {
437 self.component == ErrorComponent::Memory
438 }
439 pub fn is_config(&self) -> bool {
441 self.code == "config.legacy"
442 }
443}
444
445impl AdkError {
446 pub fn is_retryable(&self) -> bool {
448 self.retry.should_retry
449 }
450 pub fn is_not_found(&self) -> bool {
452 self.category == ErrorCategory::NotFound
453 }
454 pub fn is_unauthorized(&self) -> bool {
456 self.category == ErrorCategory::Unauthorized
457 }
458 pub fn is_rate_limited(&self) -> bool {
460 self.category == ErrorCategory::RateLimited
461 }
462 pub fn is_timeout(&self) -> bool {
464 self.category == ErrorCategory::Timeout
465 }
466}
467
468impl AdkError {
469 #[allow(unreachable_patterns)]
471 pub fn http_status_code(&self) -> u16 {
472 match self.category {
473 ErrorCategory::InvalidInput => 400,
474 ErrorCategory::Unauthorized => 401,
475 ErrorCategory::Forbidden => 403,
476 ErrorCategory::NotFound => 404,
477 ErrorCategory::RateLimited => 429,
478 ErrorCategory::Timeout => 408,
479 ErrorCategory::Unavailable => 503,
480 ErrorCategory::Cancelled => 499,
481 ErrorCategory::Internal => 500,
482 ErrorCategory::Unsupported => 501,
483 _ => 500,
484 }
485 }
486}
487
488impl AdkError {
489 pub fn to_problem_json(&self) -> Value {
491 json!({
492 "error": {
493 "code": self.code,
494 "message": self.message,
495 "component": self.component,
496 "category": self.category,
497 "requestId": self.details.request_id,
498 "retryAfter": self.retry.retry_after_ms,
499 "upstreamStatusCode": self.details.upstream_status_code,
500 }
501 })
502 }
503}
504
505pub type Result<T> = std::result::Result<T, AdkError>;
507
// Unit tests for `AdkError` and its supporting types: field population,
// `Display`/`Debug` output, default retry hints per category, HTTP status
// mapping, problem-JSON rendering and the legacy constructor shims.
#[cfg(test)]
mod tests {
    use super::*;

    // `new` stores its arguments verbatim and seeds the retry hint from the category.
    #[test]
    fn test_new_sets_fields() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "too many requests",
        );
        assert_eq!(err.component, ErrorComponent::Model);
        assert_eq!(err.category, ErrorCategory::RateLimited);
        assert_eq!(err.code, "model.rate_limited");
        assert_eq!(err.message, "too many requests");
        // RateLimited is one of the categories `RetryHint::for_category` marks retryable.
        assert!(err.retry.should_retry);
    }

    // Display is `<component>.<category>: <message>`; the `code` is not part of it.
    #[test]
    fn test_display_format() {
        let err = AdkError::new(
            ErrorComponent::Session,
            ErrorCategory::NotFound,
            "session.not_found",
            "session xyz not found",
        );
        assert_eq!(err.to_string(), "session.not_found: session xyz not found");
    }

    // The category convenience constructors: each sets its category and the
    // matching retry default.
    #[test]
    fn test_convenience_not_found() {
        let err = AdkError::not_found(ErrorComponent::Session, "session.not_found", "gone");
        assert_eq!(err.category, ErrorCategory::NotFound);
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_rate_limited() {
        let err = AdkError::rate_limited(ErrorComponent::Model, "model.rate_limited", "slow down");
        assert!(err.is_retryable());
        assert!(err.is_rate_limited());
    }

    #[test]
    fn test_convenience_unauthorized() {
        let err = AdkError::unauthorized(ErrorComponent::Auth, "auth.unauthorized", "bad token");
        assert!(err.is_unauthorized());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_internal() {
        let err = AdkError::internal(ErrorComponent::Agent, "agent.internal", "oops");
        assert_eq!(err.category, ErrorCategory::Internal);
    }

    #[test]
    fn test_convenience_timeout() {
        let err = AdkError::timeout(ErrorComponent::Model, "model.timeout", "timed out");
        assert!(err.is_timeout());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_convenience_unavailable() {
        let err = AdkError::unavailable(ErrorComponent::Model, "model.unavailable", "503");
        assert!(err.is_retryable());
    }

    // Legacy shims: one constructor per component, all `Internal` with a
    // `<name>.legacy` code (except `config`, checked separately below).
    #[test]
    fn test_backward_compat_agent() {
        let err = AdkError::agent("test error");
        assert!(err.is_agent());
        assert_eq!(err.code, "agent.legacy");
        assert_eq!(err.category, ErrorCategory::Internal);
        assert_eq!(err.to_string(), "agent.internal: test error");
    }

    #[test]
    fn test_backward_compat_model() {
        let err = AdkError::model("model fail");
        assert!(err.is_model());
        assert_eq!(err.code, "model.legacy");
    }

    #[test]
    fn test_backward_compat_tool() {
        let err = AdkError::tool("tool fail");
        assert!(err.is_tool());
        assert_eq!(err.code, "tool.legacy");
    }

    #[test]
    fn test_backward_compat_session() {
        let err = AdkError::session("session fail");
        assert!(err.is_session());
        assert_eq!(err.code, "session.legacy");
    }

    #[test]
    fn test_backward_compat_memory() {
        let err = AdkError::memory("memory fail");
        assert!(err.is_memory());
        assert_eq!(err.code, "memory.legacy");
    }

    #[test]
    fn test_backward_compat_artifact() {
        let err = AdkError::artifact("artifact fail");
        assert!(err.is_artifact());
        assert_eq!(err.code, "artifact.legacy");
    }

    // `config` has no dedicated component: it is filed under Server / InvalidInput
    // and recognised by its code.
    #[test]
    fn test_backward_compat_config() {
        let err = AdkError::config("bad config");
        assert!(err.is_config());
        assert_eq!(err.code, "config.legacy");
        assert_eq!(err.component, ErrorComponent::Server);
        assert_eq!(err.category, ErrorCategory::InvalidInput);
    }

    #[test]
    fn test_backward_compat_codes_end_with_legacy() {
        let errors = [
            AdkError::agent("a"),
            AdkError::model("m"),
            AdkError::tool("t"),
            AdkError::session("s"),
            AdkError::memory("mem"),
            AdkError::config("c"),
            AdkError::artifact("art"),
        ];
        for err in &errors {
            assert!(err.code.ends_with(".legacy"), "code '{}' should end with .legacy", err.code);
        }
    }

    #[test]
    fn test_is_config_false_for_non_config() {
        assert!(!AdkError::agent("not config").is_config());
    }

    // Retry defaults: exactly RateLimited, Unavailable and Timeout start retryable.
    #[test]
    fn test_retryable_categories_default_true() {
        for cat in [ErrorCategory::RateLimited, ErrorCategory::Unavailable, ErrorCategory::Timeout] {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(err.is_retryable(), "expected is_retryable() == true for {cat}");
        }
    }

    // An explicit hint from `with_retry` wins over the category default.
    #[test]
    fn test_retryable_override_to_false() {
        let err =
            AdkError::new(ErrorComponent::Model, ErrorCategory::RateLimited, "m.rl", "overridden")
                .with_retry(RetryHint { should_retry: false, ..Default::default() });
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_non_retryable_categories_default_false() {
        for cat in [
            ErrorCategory::InvalidInput,
            ErrorCategory::Unauthorized,
            ErrorCategory::Forbidden,
            ErrorCategory::NotFound,
            ErrorCategory::Cancelled,
            ErrorCategory::Internal,
            ErrorCategory::Unsupported,
        ] {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(!err.is_retryable(), "expected is_retryable() == false for {cat}");
        }
    }

    // Pins the full category -> HTTP status table.
    #[test]
    fn test_http_status_code_mapping() {
        let cases = [
            (ErrorCategory::InvalidInput, 400),
            (ErrorCategory::Unauthorized, 401),
            (ErrorCategory::Forbidden, 403),
            (ErrorCategory::NotFound, 404),
            (ErrorCategory::RateLimited, 429),
            (ErrorCategory::Timeout, 408),
            (ErrorCategory::Unavailable, 503),
            (ErrorCategory::Cancelled, 499),
            (ErrorCategory::Internal, 500),
            (ErrorCategory::Unsupported, 501),
        ];
        for (cat, expected) in &cases {
            let err = AdkError::new(ErrorComponent::Server, *cat, "test", "msg");
            assert_eq!(err.http_status_code(), *expected, "wrong status for {cat}");
        }
    }

    // `Error::source` exposes the cause attached with `with_source`, and nothing otherwise.
    #[test]
    fn test_source_returns_some_when_set() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err = AdkError::new(ErrorComponent::Session, ErrorCategory::NotFound, "s.f", "missing")
            .with_source(io_err);
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_source_returns_none_when_not_set() {
        assert!(std::error::Error::source(&AdkError::agent("no source")).is_none());
    }

    #[test]
    fn test_retry_hint_for_category() {
        assert!(RetryHint::for_category(ErrorCategory::RateLimited).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Unavailable).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Timeout).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::Internal).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::NotFound).should_retry);
    }

    // The delay is stored in milliseconds and round-trips through `retry_after`.
    #[test]
    fn test_retry_hint_with_retry_after() {
        let hint = RetryHint::default().with_retry_after(Duration::from_secs(5));
        assert_eq!(hint.retry_after_ms, Some(5000));
        assert_eq!(hint.retry_after(), Some(Duration::from_secs(5)));
    }

    // Problem JSON: component/category serialize in snake_case; optional
    // details use camelCase keys.
    #[test]
    fn test_to_problem_json() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "slow down",
        )
        .with_request_id("req-123")
        .with_upstream_status(429);
        let j = err.to_problem_json();
        let o = &j["error"];
        assert_eq!(o["code"], "model.rate_limited");
        assert_eq!(o["message"], "slow down");
        assert_eq!(o["component"], "model");
        assert_eq!(o["category"], "rate_limited");
        assert_eq!(o["requestId"], "req-123");
        assert_eq!(o["upstreamStatusCode"], 429);
    }

    // Unset optionals come out as JSON null (indexing a missing key would also
    // yield null, so this pins the value, not the key's presence).
    #[test]
    fn test_to_problem_json_null_optionals() {
        let j = AdkError::agent("simple").to_problem_json();
        let o = &j["error"];
        assert!(o["requestId"].is_null());
        assert!(o["retryAfter"].is_null());
        assert!(o["upstreamStatusCode"].is_null());
    }

    #[test]
    fn test_builder_chaining() {
        let err = AdkError::new(ErrorComponent::Model, ErrorCategory::Unavailable, "m.u", "down")
            .with_provider("openai")
            .with_request_id("req-456")
            .with_upstream_status(503)
            .with_retry(RetryHint {
                should_retry: true,
                retry_after_ms: Some(1000),
                max_attempts: Some(3),
            });
        assert_eq!(err.details.provider.as_deref(), Some("openai"));
        assert_eq!(err.details.request_id.as_deref(), Some("req-456"));
        assert_eq!(err.details.upstream_status_code, Some(503));
        assert!(err.is_retryable());
        assert_eq!(err.retry.retry_after_ms, Some(1000));
        assert_eq!(err.retry.max_attempts, Some(3));
    }

    #[test]
    fn test_error_component_display() {
        assert_eq!(ErrorComponent::Agent.to_string(), "agent");
        assert_eq!(ErrorComponent::Model.to_string(), "model");
        assert_eq!(ErrorComponent::Graph.to_string(), "graph");
        assert_eq!(ErrorComponent::Realtime.to_string(), "realtime");
        assert_eq!(ErrorComponent::Deploy.to_string(), "deploy");
    }

    #[test]
    fn test_error_category_display() {
        assert_eq!(ErrorCategory::InvalidInput.to_string(), "invalid_input");
        assert_eq!(ErrorCategory::RateLimited.to_string(), "rate_limited");
        assert_eq!(ErrorCategory::NotFound.to_string(), "not_found");
        assert_eq!(ErrorCategory::Internal.to_string(), "internal");
    }

    // The literal `Ok(42).unwrap()` is deliberate here: the test only checks that
    // the crate-level `Result` alias type-checks with `AdkError`.
    #[test]
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_result_type() {
        let ok: Result<i32> = Ok(42);
        assert_eq!(ok.unwrap(), 42);
        let err: Result<i32> = Err(AdkError::config("invalid"));
        assert!(err.is_err());
    }

    #[test]
    fn test_with_details() {
        let d = ErrorDetails {
            upstream_status_code: Some(502),
            request_id: Some("abc".into()),
            provider: Some("gemini".into()),
            metadata: HashMap::new(),
        };
        let err = AdkError::agent("test").with_details(d);
        assert_eq!(err.details.upstream_status_code, Some(502));
        assert_eq!(err.details.request_id.as_deref(), Some("abc"));
        assert_eq!(err.details.provider.as_deref(), Some("gemini"));
    }

    #[test]
    fn test_debug_impl() {
        let s = format!("{:?}", AdkError::agent("debug test"));
        assert!(s.contains("AdkError"));
        assert!(s.contains("agent.legacy"));
    }

    // Runtime mirror of the module-level compile-time Send/Sync assertion.
    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<AdkError>();
        assert_sync::<AdkError>();
    }
}