1use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use std::collections::HashMap;
6use std::fmt;
7use std::time::Duration;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19#[non_exhaustive]
20pub enum ErrorComponent {
21 Agent,
22 Model,
23 Tool,
24 Session,
25 Artifact,
26 Memory,
27 Graph,
28 Realtime,
29 Code,
30 Server,
31 Auth,
32 Guardrail,
33 Eval,
34 Deploy,
35}
36
37impl fmt::Display for ErrorComponent {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 let s = match self {
40 Self::Agent => "agent",
41 Self::Model => "model",
42 Self::Tool => "tool",
43 Self::Session => "session",
44 Self::Artifact => "artifact",
45 Self::Memory => "memory",
46 Self::Graph => "graph",
47 Self::Realtime => "realtime",
48 Self::Code => "code",
49 Self::Server => "server",
50 Self::Auth => "auth",
51 Self::Guardrail => "guardrail",
52 Self::Eval => "eval",
53 Self::Deploy => "deploy",
54 };
55 f.write_str(s)
56 }
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
73#[serde(rename_all = "snake_case")]
74#[non_exhaustive]
75pub enum ErrorCategory {
76 InvalidInput,
77 Unauthorized,
78 Forbidden,
79 NotFound,
80 RateLimited,
81 Timeout,
82 Unavailable,
83 Cancelled,
84 Internal,
85 Unsupported,
86}
87
88impl fmt::Display for ErrorCategory {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 let s = match self {
91 Self::InvalidInput => "invalid_input",
92 Self::Unauthorized => "unauthorized",
93 Self::Forbidden => "forbidden",
94 Self::NotFound => "not_found",
95 Self::RateLimited => "rate_limited",
96 Self::Timeout => "timeout",
97 Self::Unavailable => "unavailable",
98 Self::Cancelled => "cancelled",
99 Self::Internal => "internal",
100 Self::Unsupported => "unsupported",
101 };
102 f.write_str(s)
103 }
104}
105
106#[derive(Debug, Clone, Default, Serialize, Deserialize)]
108pub struct RetryHint {
109 pub should_retry: bool,
110 #[serde(default, skip_serializing_if = "Option::is_none")]
111 pub retry_after_ms: Option<u64>,
112 #[serde(default, skip_serializing_if = "Option::is_none")]
113 pub max_attempts: Option<u32>,
114}
115
116impl RetryHint {
117 pub fn for_category(category: ErrorCategory) -> Self {
119 match category {
120 ErrorCategory::RateLimited | ErrorCategory::Unavailable | ErrorCategory::Timeout => {
121 Self { should_retry: true, ..Default::default() }
122 }
123 _ => Self::default(),
124 }
125 }
126
127 pub fn retry_after(&self) -> Option<Duration> {
129 self.retry_after_ms.map(Duration::from_millis)
130 }
131
132 pub fn with_retry_after(mut self, duration: Duration) -> Self {
134 self.retry_after_ms = Some(duration.as_millis() as u64);
135 self
136 }
137}
138
139#[derive(Debug, Clone, Default, Serialize, Deserialize)]
141pub struct ErrorDetails {
142 #[serde(default, skip_serializing_if = "Option::is_none")]
143 pub upstream_status_code: Option<u16>,
144 #[serde(default, skip_serializing_if = "Option::is_none")]
145 pub request_id: Option<String>,
146 #[serde(default, skip_serializing_if = "Option::is_none")]
147 pub provider: Option<String>,
148 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
149 pub metadata: HashMap<String, Value>,
150}
151
152pub struct AdkError {
184 pub component: ErrorComponent,
185 pub category: ErrorCategory,
186 pub code: &'static str,
187 pub message: String,
188 pub retry: RetryHint,
189 pub details: Box<ErrorDetails>,
190 source: Option<Box<dyn std::error::Error + Send + Sync>>,
191}
192
193impl fmt::Debug for AdkError {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 let mut d = f.debug_struct("AdkError");
196 d.field("component", &self.component)
197 .field("category", &self.category)
198 .field("code", &self.code)
199 .field("message", &self.message)
200 .field("retry", &self.retry)
201 .field("details", &self.details);
202 if let Some(src) = &self.source {
203 d.field("source", &format_args!("{src}"));
204 }
205 d.finish()
206 }
207}
208
209impl fmt::Display for AdkError {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 write!(f, "{}.{}: {}", self.component, self.category, self.message)
212 }
213}
214
215impl std::error::Error for AdkError {
216 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
217 self.source.as_ref().map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
218 }
219}
220
221const _: () = {
222 fn _assert_send<T: Send>() {}
223 fn _assert_sync<T: Sync>() {}
224 fn _assertions() {
225 _assert_send::<AdkError>();
226 _assert_sync::<AdkError>();
227 }
228};
229
230impl AdkError {
231 pub fn new(
232 component: ErrorComponent,
233 category: ErrorCategory,
234 code: &'static str,
235 message: impl Into<String>,
236 ) -> Self {
237 Self {
238 component,
239 category,
240 code,
241 message: message.into(),
242 retry: RetryHint::for_category(category),
243 details: Box::new(ErrorDetails::default()),
244 source: None,
245 }
246 }
247
248 pub fn with_source(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
249 self.source = Some(Box::new(source));
250 self
251 }
252
253 pub fn with_retry(mut self, retry: RetryHint) -> Self {
254 self.retry = retry;
255 self
256 }
257
258 pub fn with_details(mut self, details: ErrorDetails) -> Self {
259 self.details = Box::new(details);
260 self
261 }
262
263 pub fn with_upstream_status(mut self, status_code: u16) -> Self {
264 self.details.upstream_status_code = Some(status_code);
265 self
266 }
267
268 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
269 self.details.request_id = Some(request_id.into());
270 self
271 }
272
273 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
274 self.details.provider = Some(provider.into());
275 self
276 }
277}
278
279impl AdkError {
280 pub fn not_found(
281 component: ErrorComponent,
282 code: &'static str,
283 message: impl Into<String>,
284 ) -> Self {
285 Self::new(component, ErrorCategory::NotFound, code, message)
286 }
287
288 pub fn rate_limited(
289 component: ErrorComponent,
290 code: &'static str,
291 message: impl Into<String>,
292 ) -> Self {
293 Self::new(component, ErrorCategory::RateLimited, code, message)
294 }
295
296 pub fn unauthorized(
297 component: ErrorComponent,
298 code: &'static str,
299 message: impl Into<String>,
300 ) -> Self {
301 Self::new(component, ErrorCategory::Unauthorized, code, message)
302 }
303
304 pub fn internal(
305 component: ErrorComponent,
306 code: &'static str,
307 message: impl Into<String>,
308 ) -> Self {
309 Self::new(component, ErrorCategory::Internal, code, message)
310 }
311
312 pub fn timeout(
313 component: ErrorComponent,
314 code: &'static str,
315 message: impl Into<String>,
316 ) -> Self {
317 Self::new(component, ErrorCategory::Timeout, code, message)
318 }
319
320 pub fn unavailable(
321 component: ErrorComponent,
322 code: &'static str,
323 message: impl Into<String>,
324 ) -> Self {
325 Self::new(component, ErrorCategory::Unavailable, code, message)
326 }
327}
328
329impl AdkError {
330 pub fn agent(message: impl Into<String>) -> Self {
331 Self::new(ErrorComponent::Agent, ErrorCategory::Internal, "agent.legacy", message)
332 }
333
334 pub fn model(message: impl Into<String>) -> Self {
335 Self::new(ErrorComponent::Model, ErrorCategory::Internal, "model.legacy", message)
336 }
337
338 pub fn tool(message: impl Into<String>) -> Self {
339 Self::new(ErrorComponent::Tool, ErrorCategory::Internal, "tool.legacy", message)
340 }
341
342 pub fn session(message: impl Into<String>) -> Self {
343 Self::new(ErrorComponent::Session, ErrorCategory::Internal, "session.legacy", message)
344 }
345
346 pub fn memory(message: impl Into<String>) -> Self {
347 Self::new(ErrorComponent::Memory, ErrorCategory::Internal, "memory.legacy", message)
348 }
349
350 pub fn config(message: impl Into<String>) -> Self {
351 Self::new(ErrorComponent::Server, ErrorCategory::InvalidInput, "config.legacy", message)
352 }
353
354 pub fn artifact(message: impl Into<String>) -> Self {
355 Self::new(ErrorComponent::Artifact, ErrorCategory::Internal, "artifact.legacy", message)
356 }
357}
358
359impl AdkError {
360 pub fn is_agent(&self) -> bool {
361 self.component == ErrorComponent::Agent
362 }
363 pub fn is_model(&self) -> bool {
364 self.component == ErrorComponent::Model
365 }
366 pub fn is_tool(&self) -> bool {
367 self.component == ErrorComponent::Tool
368 }
369 pub fn is_session(&self) -> bool {
370 self.component == ErrorComponent::Session
371 }
372 pub fn is_artifact(&self) -> bool {
373 self.component == ErrorComponent::Artifact
374 }
375 pub fn is_memory(&self) -> bool {
376 self.component == ErrorComponent::Memory
377 }
378 pub fn is_config(&self) -> bool {
379 self.code == "config.legacy"
380 }
381}
382
383impl AdkError {
384 pub fn is_retryable(&self) -> bool {
385 self.retry.should_retry
386 }
387 pub fn is_not_found(&self) -> bool {
388 self.category == ErrorCategory::NotFound
389 }
390 pub fn is_unauthorized(&self) -> bool {
391 self.category == ErrorCategory::Unauthorized
392 }
393 pub fn is_rate_limited(&self) -> bool {
394 self.category == ErrorCategory::RateLimited
395 }
396 pub fn is_timeout(&self) -> bool {
397 self.category == ErrorCategory::Timeout
398 }
399}
400
401impl AdkError {
402 #[allow(unreachable_patterns)]
403 pub fn http_status_code(&self) -> u16 {
404 match self.category {
405 ErrorCategory::InvalidInput => 400,
406 ErrorCategory::Unauthorized => 401,
407 ErrorCategory::Forbidden => 403,
408 ErrorCategory::NotFound => 404,
409 ErrorCategory::RateLimited => 429,
410 ErrorCategory::Timeout => 408,
411 ErrorCategory::Unavailable => 503,
412 ErrorCategory::Cancelled => 499,
413 ErrorCategory::Internal => 500,
414 ErrorCategory::Unsupported => 501,
415 _ => 500,
416 }
417 }
418}
419
420impl AdkError {
421 pub fn to_problem_json(&self) -> Value {
422 json!({
423 "error": {
424 "code": self.code,
425 "message": self.message,
426 "component": self.component,
427 "category": self.category,
428 "requestId": self.details.request_id,
429 "retryAfter": self.retry.retry_after_ms,
430 "upstreamStatusCode": self.details.upstream_status_code,
431 }
432 })
433 }
434}
435
436pub type Result<T> = std::result::Result<T, AdkError>;
438
// Unit tests for the error model. They pin the public contract: stored fields,
// the `Display` format, per-category retry defaults, the HTTP status mapping,
// JSON rendering, builder methods and the legacy constructors.
#[cfg(test)]
mod tests {
    use super::*;

    // `new` stores its arguments verbatim and derives `retry` from the category
    // (RateLimited is retryable by default).
    #[test]
    fn test_new_sets_fields() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "too many requests",
        );
        assert_eq!(err.component, ErrorComponent::Model);
        assert_eq!(err.category, ErrorCategory::RateLimited);
        assert_eq!(err.code, "model.rate_limited");
        assert_eq!(err.message, "too many requests");
        assert!(err.retry.should_retry);
    }

    // Display renders `<component>.<category>: <message>`. The `code` here happens
    // to spell the same prefix, but it is not part of the output.
    #[test]
    fn test_display_format() {
        let err = AdkError::new(
            ErrorComponent::Session,
            ErrorCategory::NotFound,
            "session.not_found",
            "session xyz not found",
        );
        assert_eq!(err.to_string(), "session.not_found: session xyz not found");
    }

    // Category convenience constructors: check the category and retry default.
    #[test]
    fn test_convenience_not_found() {
        let err = AdkError::not_found(ErrorComponent::Session, "session.not_found", "gone");
        assert_eq!(err.category, ErrorCategory::NotFound);
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_rate_limited() {
        let err = AdkError::rate_limited(ErrorComponent::Model, "model.rate_limited", "slow down");
        assert!(err.is_retryable());
        assert!(err.is_rate_limited());
    }

    #[test]
    fn test_convenience_unauthorized() {
        let err = AdkError::unauthorized(ErrorComponent::Auth, "auth.unauthorized", "bad token");
        assert!(err.is_unauthorized());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_internal() {
        let err = AdkError::internal(ErrorComponent::Agent, "agent.internal", "oops");
        assert_eq!(err.category, ErrorCategory::Internal);
    }

    #[test]
    fn test_convenience_timeout() {
        let err = AdkError::timeout(ErrorComponent::Model, "model.timeout", "timed out");
        assert!(err.is_timeout());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_convenience_unavailable() {
        let err = AdkError::unavailable(ErrorComponent::Model, "model.unavailable", "503");
        assert!(err.is_retryable());
    }

    // Legacy constructors: a fixed `<component>.legacy` code and the `Internal`
    // category, so Display shows `<component>.internal`.
    #[test]
    fn test_backward_compat_agent() {
        let err = AdkError::agent("test error");
        assert!(err.is_agent());
        assert_eq!(err.code, "agent.legacy");
        assert_eq!(err.category, ErrorCategory::Internal);
        assert_eq!(err.to_string(), "agent.internal: test error");
    }

    #[test]
    fn test_backward_compat_model() {
        let err = AdkError::model("model fail");
        assert!(err.is_model());
        assert_eq!(err.code, "model.legacy");
    }

    #[test]
    fn test_backward_compat_tool() {
        let err = AdkError::tool("tool fail");
        assert!(err.is_tool());
        assert_eq!(err.code, "tool.legacy");
    }

    #[test]
    fn test_backward_compat_session() {
        let err = AdkError::session("session fail");
        assert!(err.is_session());
        assert_eq!(err.code, "session.legacy");
    }

    #[test]
    fn test_backward_compat_memory() {
        let err = AdkError::memory("memory fail");
        assert!(err.is_memory());
        assert_eq!(err.code, "memory.legacy");
    }

    #[test]
    fn test_backward_compat_artifact() {
        let err = AdkError::artifact("artifact fail");
        assert!(err.is_artifact());
        assert_eq!(err.code, "artifact.legacy");
    }

    // `config` is the exception: `Server` component and `InvalidInput` category.
    // `is_config` therefore relies on the code, not the component.
    #[test]
    fn test_backward_compat_config() {
        let err = AdkError::config("bad config");
        assert!(err.is_config());
        assert_eq!(err.code, "config.legacy");
        assert_eq!(err.component, ErrorComponent::Server);
        assert_eq!(err.category, ErrorCategory::InvalidInput);
    }

    #[test]
    fn test_backward_compat_codes_end_with_legacy() {
        let errors = [
            AdkError::agent("a"),
            AdkError::model("m"),
            AdkError::tool("t"),
            AdkError::session("s"),
            AdkError::memory("mem"),
            AdkError::config("c"),
            AdkError::artifact("art"),
        ];
        for err in &errors {
            assert!(err.code.ends_with(".legacy"), "code '{}' should end with .legacy", err.code);
        }
    }

    #[test]
    fn test_is_config_false_for_non_config() {
        assert!(!AdkError::agent("not config").is_config());
    }

    // Default retry hints: only RateLimited, Unavailable and Timeout retry.
    #[test]
    fn test_retryable_categories_default_true() {
        for cat in [ErrorCategory::RateLimited, ErrorCategory::Unavailable, ErrorCategory::Timeout]
        {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(err.is_retryable(), "expected is_retryable() == true for {cat}");
        }
    }

    // An explicit `with_retry` hint takes precedence over the category default.
    #[test]
    fn test_retryable_override_to_false() {
        let err =
            AdkError::new(ErrorComponent::Model, ErrorCategory::RateLimited, "m.rl", "overridden")
                .with_retry(RetryHint { should_retry: false, ..Default::default() });
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_non_retryable_categories_default_false() {
        for cat in [
            ErrorCategory::InvalidInput,
            ErrorCategory::Unauthorized,
            ErrorCategory::Forbidden,
            ErrorCategory::NotFound,
            ErrorCategory::Cancelled,
            ErrorCategory::Internal,
            ErrorCategory::Unsupported,
        ] {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(!err.is_retryable(), "expected is_retryable() == false for {cat}");
        }
    }

    // Pins the full category -> HTTP status table.
    #[test]
    fn test_http_status_code_mapping() {
        let cases = [
            (ErrorCategory::InvalidInput, 400),
            (ErrorCategory::Unauthorized, 401),
            (ErrorCategory::Forbidden, 403),
            (ErrorCategory::NotFound, 404),
            (ErrorCategory::RateLimited, 429),
            (ErrorCategory::Timeout, 408),
            (ErrorCategory::Unavailable, 503),
            (ErrorCategory::Cancelled, 499),
            (ErrorCategory::Internal, 500),
            (ErrorCategory::Unsupported, 501),
        ];
        for (cat, expected) in &cases {
            let err = AdkError::new(ErrorComponent::Server, *cat, "test", "msg");
            assert_eq!(err.http_status_code(), *expected, "wrong status for {cat}");
        }
    }

    // `std::error::Error::source` exposes the cause attached via `with_source`.
    #[test]
    fn test_source_returns_some_when_set() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err = AdkError::new(ErrorComponent::Session, ErrorCategory::NotFound, "s.f", "missing")
            .with_source(io_err);
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_source_returns_none_when_not_set() {
        assert!(std::error::Error::source(&AdkError::agent("no source")).is_none());
    }

    #[test]
    fn test_retry_hint_for_category() {
        assert!(RetryHint::for_category(ErrorCategory::RateLimited).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Unavailable).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Timeout).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::Internal).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::NotFound).should_retry);
    }

    // `retry_after_ms` is stored in milliseconds; `retry_after` converts back.
    #[test]
    fn test_retry_hint_with_retry_after() {
        let hint = RetryHint::default().with_retry_after(Duration::from_secs(5));
        assert_eq!(hint.retry_after_ms, Some(5000));
        assert_eq!(hint.retry_after(), Some(Duration::from_secs(5)));
    }

    // Problem JSON uses camelCase keys under a top-level "error" object;
    // component/category serialize as snake_case strings.
    #[test]
    fn test_to_problem_json() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "slow down",
        )
        .with_request_id("req-123")
        .with_upstream_status(429);
        let j = err.to_problem_json();
        let o = &j["error"];
        assert_eq!(o["code"], "model.rate_limited");
        assert_eq!(o["message"], "slow down");
        assert_eq!(o["component"], "model");
        assert_eq!(o["category"], "rate_limited");
        assert_eq!(o["requestId"], "req-123");
        assert_eq!(o["upstreamStatusCode"], 429);
    }

    // Unset optional fields are emitted as JSON null, not omitted.
    #[test]
    fn test_to_problem_json_null_optionals() {
        let j = AdkError::agent("simple").to_problem_json();
        let o = &j["error"];
        assert!(o["requestId"].is_null());
        assert!(o["retryAfter"].is_null());
        assert!(o["upstreamStatusCode"].is_null());
    }

    #[test]
    fn test_builder_chaining() {
        let err = AdkError::new(ErrorComponent::Model, ErrorCategory::Unavailable, "m.u", "down")
            .with_provider("openai")
            .with_request_id("req-456")
            .with_upstream_status(503)
            .with_retry(RetryHint {
                should_retry: true,
                retry_after_ms: Some(1000),
                max_attempts: Some(3),
            });
        assert_eq!(err.details.provider.as_deref(), Some("openai"));
        assert_eq!(err.details.request_id.as_deref(), Some("req-456"));
        assert_eq!(err.details.upstream_status_code, Some(503));
        assert!(err.is_retryable());
        assert_eq!(err.retry.retry_after_ms, Some(1000));
        assert_eq!(err.retry.max_attempts, Some(3));
    }

    #[test]
    fn test_error_component_display() {
        assert_eq!(ErrorComponent::Agent.to_string(), "agent");
        assert_eq!(ErrorComponent::Model.to_string(), "model");
        assert_eq!(ErrorComponent::Graph.to_string(), "graph");
        assert_eq!(ErrorComponent::Realtime.to_string(), "realtime");
        assert_eq!(ErrorComponent::Deploy.to_string(), "deploy");
    }

    #[test]
    fn test_error_category_display() {
        assert_eq!(ErrorCategory::InvalidInput.to_string(), "invalid_input");
        assert_eq!(ErrorCategory::RateLimited.to_string(), "rate_limited");
        assert_eq!(ErrorCategory::NotFound.to_string(), "not_found");
        assert_eq!(ErrorCategory::Internal.to_string(), "internal");
    }

    // The unwrap on a literal `Ok` is intentional, hence the clippy allow.
    #[test]
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_result_type() {
        let ok: Result<i32> = Ok(42);
        assert_eq!(ok.unwrap(), 42);
        let err: Result<i32> = Err(AdkError::config("invalid"));
        assert!(err.is_err());
    }

    // `with_details` replaces the whole details struct.
    #[test]
    fn test_with_details() {
        let d = ErrorDetails {
            upstream_status_code: Some(502),
            request_id: Some("abc".into()),
            provider: Some("gemini".into()),
            metadata: HashMap::new(),
        };
        let err = AdkError::agent("test").with_details(d);
        assert_eq!(err.details.upstream_status_code, Some(502));
        assert_eq!(err.details.request_id.as_deref(), Some("abc"));
        assert_eq!(err.details.provider.as_deref(), Some("gemini"));
    }

    #[test]
    fn test_debug_impl() {
        let s = format!("{:?}", AdkError::agent("debug test"));
        assert!(s.contains("AdkError"));
        assert!(s.contains("agent.legacy"));
    }

    // Runtime counterpart of the module-level compile-time Send/Sync assertion.
    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<AdkError>();
        assert_sync::<AdkError>();
    }
}
767}