1use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use std::collections::HashMap;
6use std::fmt;
7use std::time::Duration;
8
/// The subsystem that raised an [`AdkError`].
///
/// Serialized in `snake_case` (e.g. `ErrorComponent::Realtime` is `"realtime"`), and the
/// `Display` impl renders the same lowercase names. Marked `#[non_exhaustive]`, so code in
/// other crates must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ErrorComponent {
    Agent,
    Model,
    Tool,
    Session,
    Artifact,
    Memory,
    Graph,
    Realtime,
    Code,
    /// Also used by the legacy `AdkError::config` constructor.
    Server,
    Auth,
    Guardrail,
    Eval,
    Deploy,
}
36
37impl fmt::Display for ErrorComponent {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 let s = match self {
40 Self::Agent => "agent",
41 Self::Model => "model",
42 Self::Tool => "tool",
43 Self::Session => "session",
44 Self::Artifact => "artifact",
45 Self::Memory => "memory",
46 Self::Graph => "graph",
47 Self::Realtime => "realtime",
48 Self::Code => "code",
49 Self::Server => "server",
50 Self::Auth => "auth",
51 Self::Guardrail => "guardrail",
52 Self::Eval => "eval",
53 Self::Deploy => "deploy",
54 };
55 f.write_str(s)
56 }
57}
58
/// The kind of failure an [`AdkError`] represents.
///
/// The category determines the error's default retry hint (`RetryHint::for_category`) and its
/// HTTP status (`AdkError::http_status_code`). Serialized and displayed in `snake_case`
/// (e.g. `ErrorCategory::RateLimited` is `"rate_limited"`). Marked `#[non_exhaustive]`, so code
/// in other crates must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ErrorCategory {
    /// HTTP 400. Also the category of the legacy `AdkError::config` constructor.
    InvalidInput,
    /// HTTP 401.
    Unauthorized,
    /// HTTP 403.
    Forbidden,
    /// HTTP 404.
    NotFound,
    /// HTTP 429; retryable by default.
    RateLimited,
    /// HTTP 408; retryable by default.
    Timeout,
    /// HTTP 503; retryable by default.
    Unavailable,
    /// Non-standard HTTP 499.
    Cancelled,
    /// HTTP 500. Used by most legacy constructors (`AdkError::agent`, `model`, ...).
    Internal,
    /// HTTP 501.
    Unsupported,
}
87
88impl fmt::Display for ErrorCategory {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 let s = match self {
91 Self::InvalidInput => "invalid_input",
92 Self::Unauthorized => "unauthorized",
93 Self::Forbidden => "forbidden",
94 Self::NotFound => "not_found",
95 Self::RateLimited => "rate_limited",
96 Self::Timeout => "timeout",
97 Self::Unavailable => "unavailable",
98 Self::Cancelled => "cancelled",
99 Self::Internal => "internal",
100 Self::Unsupported => "unsupported",
101 };
102 f.write_str(s)
103 }
104}
105
/// Guidance on whether, and after what delay, a failed operation may be retried.
///
/// `RetryHint::for_category` supplies the default for each [`ErrorCategory`]. The `Default`
/// value means "do not retry", with no delay and no attempt limit.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetryHint {
    /// Whether the error should be retried; this is what `AdkError::is_retryable` reports.
    pub should_retry: bool,
    /// Suggested delay before retrying, in milliseconds (see `retry_after`). Omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_after_ms: Option<u64>,
    /// Optional cap on the number of attempts. It is only stored and serialized; nothing in
    /// this file enforces it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_attempts: Option<u32>,
}
115
116impl RetryHint {
117 pub fn for_category(category: ErrorCategory) -> Self {
119 match category {
120 ErrorCategory::RateLimited | ErrorCategory::Unavailable | ErrorCategory::Timeout => {
121 Self { should_retry: true, ..Default::default() }
122 }
123 _ => Self::default(),
124 }
125 }
126
127 pub fn retry_after(&self) -> Option<Duration> {
129 self.retry_after_ms.map(Duration::from_millis)
130 }
131
132 pub fn with_retry_after(mut self, duration: Duration) -> Self {
134 self.retry_after_ms = Some(duration.as_millis() as u64);
135 self
136 }
137}
138
/// Optional extra context attached to an [`AdkError`].
///
/// Every field is optional. `None` fields and an empty `metadata` map are skipped when
/// serializing, and missing fields default when deserializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ErrorDetails {
    /// Upstream status code, if any (set via `AdkError::with_upstream_status`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_status_code: Option<u16>,
    /// Request identifier, if known (set via `AdkError::with_request_id`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
    /// Provider name, if known (set via `AdkError::with_provider`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<String>,
    /// Free-form extra context as JSON values. Not included in `AdkError::to_problem_json`.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}
151
/// Structured error carrying where it happened, what kind of failure it was, a stable code,
/// and a human-readable message.
///
/// Construct it with `AdkError::new` or a category shorthand (`not_found`, `timeout`, ...) and
/// refine it with the `with_*` builders. `Display` renders `"<component>.<category>: <message>"`.
pub struct AdkError {
    /// Subsystem that raised the error.
    pub component: ErrorComponent,
    /// Failure class. It determines the default `retry` hint and the HTTP status code.
    pub category: ErrorCategory,
    /// Stable machine-readable identifier such as `"model.rate_limited"`. The legacy
    /// constructors use `"<name>.legacy"`.
    pub code: &'static str,
    /// Human-readable description.
    pub message: String,
    /// Retry guidance. `new` derives it from `category`; `with_retry` overrides it.
    pub retry: RetryHint,
    /// Extra context. NOTE(review): boxed presumably to keep `AdkError` (and `Result<T>`)
    /// small; confirm.
    pub details: Box<ErrorDetails>,
    /// Underlying cause, set via `with_source` and exposed through `Error::source`.
    source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
190
191impl fmt::Debug for AdkError {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 let mut d = f.debug_struct("AdkError");
194 d.field("component", &self.component)
195 .field("category", &self.category)
196 .field("code", &self.code)
197 .field("message", &self.message)
198 .field("retry", &self.retry)
199 .field("details", &self.details);
200 if let Some(src) = &self.source {
201 d.field("source", &format_args!("{src}"));
202 }
203 d.finish()
204 }
205}
206
207impl fmt::Display for AdkError {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 write!(f, "{}.{}: {}", self.component, self.category, self.message)
210 }
211}
212
213impl std::error::Error for AdkError {
214 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
215 self.source.as_ref().map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
216 }
217}
218
219const _: () = {
220 fn _assert_send<T: Send>() {}
221 fn _assert_sync<T: Sync>() {}
222 fn _assertions() {
223 _assert_send::<AdkError>();
224 _assert_sync::<AdkError>();
225 }
226};
227
228impl AdkError {
229 pub fn new(
230 component: ErrorComponent,
231 category: ErrorCategory,
232 code: &'static str,
233 message: impl Into<String>,
234 ) -> Self {
235 Self {
236 component,
237 category,
238 code,
239 message: message.into(),
240 retry: RetryHint::for_category(category),
241 details: Box::new(ErrorDetails::default()),
242 source: None,
243 }
244 }
245
246 pub fn with_source(mut self, source: impl std::error::Error + Send + Sync + 'static) -> Self {
247 self.source = Some(Box::new(source));
248 self
249 }
250
251 pub fn with_retry(mut self, retry: RetryHint) -> Self {
252 self.retry = retry;
253 self
254 }
255
256 pub fn with_details(mut self, details: ErrorDetails) -> Self {
257 self.details = Box::new(details);
258 self
259 }
260
261 pub fn with_upstream_status(mut self, status_code: u16) -> Self {
262 self.details.upstream_status_code = Some(status_code);
263 self
264 }
265
266 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
267 self.details.request_id = Some(request_id.into());
268 self
269 }
270
271 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
272 self.details.provider = Some(provider.into());
273 self
274 }
275}
276
277impl AdkError {
278 pub fn not_found(
279 component: ErrorComponent,
280 code: &'static str,
281 message: impl Into<String>,
282 ) -> Self {
283 Self::new(component, ErrorCategory::NotFound, code, message)
284 }
285
286 pub fn rate_limited(
287 component: ErrorComponent,
288 code: &'static str,
289 message: impl Into<String>,
290 ) -> Self {
291 Self::new(component, ErrorCategory::RateLimited, code, message)
292 }
293
294 pub fn unauthorized(
295 component: ErrorComponent,
296 code: &'static str,
297 message: impl Into<String>,
298 ) -> Self {
299 Self::new(component, ErrorCategory::Unauthorized, code, message)
300 }
301
302 pub fn internal(
303 component: ErrorComponent,
304 code: &'static str,
305 message: impl Into<String>,
306 ) -> Self {
307 Self::new(component, ErrorCategory::Internal, code, message)
308 }
309
310 pub fn timeout(
311 component: ErrorComponent,
312 code: &'static str,
313 message: impl Into<String>,
314 ) -> Self {
315 Self::new(component, ErrorCategory::Timeout, code, message)
316 }
317
318 pub fn unavailable(
319 component: ErrorComponent,
320 code: &'static str,
321 message: impl Into<String>,
322 ) -> Self {
323 Self::new(component, ErrorCategory::Unavailable, code, message)
324 }
325}
326
327impl AdkError {
328 pub fn agent(message: impl Into<String>) -> Self {
329 Self::new(ErrorComponent::Agent, ErrorCategory::Internal, "agent.legacy", message)
330 }
331
332 pub fn model(message: impl Into<String>) -> Self {
333 Self::new(ErrorComponent::Model, ErrorCategory::Internal, "model.legacy", message)
334 }
335
336 pub fn tool(message: impl Into<String>) -> Self {
337 Self::new(ErrorComponent::Tool, ErrorCategory::Internal, "tool.legacy", message)
338 }
339
340 pub fn session(message: impl Into<String>) -> Self {
341 Self::new(ErrorComponent::Session, ErrorCategory::Internal, "session.legacy", message)
342 }
343
344 pub fn memory(message: impl Into<String>) -> Self {
345 Self::new(ErrorComponent::Memory, ErrorCategory::Internal, "memory.legacy", message)
346 }
347
348 pub fn config(message: impl Into<String>) -> Self {
349 Self::new(ErrorComponent::Server, ErrorCategory::InvalidInput, "config.legacy", message)
350 }
351
352 pub fn artifact(message: impl Into<String>) -> Self {
353 Self::new(ErrorComponent::Artifact, ErrorCategory::Internal, "artifact.legacy", message)
354 }
355}
356
357impl AdkError {
358 pub fn is_agent(&self) -> bool {
359 self.component == ErrorComponent::Agent
360 }
361 pub fn is_model(&self) -> bool {
362 self.component == ErrorComponent::Model
363 }
364 pub fn is_tool(&self) -> bool {
365 self.component == ErrorComponent::Tool
366 }
367 pub fn is_session(&self) -> bool {
368 self.component == ErrorComponent::Session
369 }
370 pub fn is_artifact(&self) -> bool {
371 self.component == ErrorComponent::Artifact
372 }
373 pub fn is_memory(&self) -> bool {
374 self.component == ErrorComponent::Memory
375 }
376 pub fn is_config(&self) -> bool {
377 self.code == "config.legacy"
378 }
379}
380
381impl AdkError {
382 pub fn is_retryable(&self) -> bool {
383 self.retry.should_retry
384 }
385 pub fn is_not_found(&self) -> bool {
386 self.category == ErrorCategory::NotFound
387 }
388 pub fn is_unauthorized(&self) -> bool {
389 self.category == ErrorCategory::Unauthorized
390 }
391 pub fn is_rate_limited(&self) -> bool {
392 self.category == ErrorCategory::RateLimited
393 }
394 pub fn is_timeout(&self) -> bool {
395 self.category == ErrorCategory::Timeout
396 }
397}
398
399impl AdkError {
400 #[allow(unreachable_patterns)]
401 pub fn http_status_code(&self) -> u16 {
402 match self.category {
403 ErrorCategory::InvalidInput => 400,
404 ErrorCategory::Unauthorized => 401,
405 ErrorCategory::Forbidden => 403,
406 ErrorCategory::NotFound => 404,
407 ErrorCategory::RateLimited => 429,
408 ErrorCategory::Timeout => 408,
409 ErrorCategory::Unavailable => 503,
410 ErrorCategory::Cancelled => 499,
411 ErrorCategory::Internal => 500,
412 ErrorCategory::Unsupported => 501,
413 _ => 500,
414 }
415 }
416}
417
418impl AdkError {
419 pub fn to_problem_json(&self) -> Value {
420 json!({
421 "error": {
422 "code": self.code,
423 "message": self.message,
424 "component": self.component,
425 "category": self.category,
426 "requestId": self.details.request_id,
427 "retryAfter": self.retry.retry_after_ms,
428 "upstreamStatusCode": self.details.upstream_status_code,
429 }
430 })
431 }
432}
433
434pub type Result<T> = std::result::Result<T, AdkError>;
436
// Unit tests for the structured error model: constructors, legacy shims, predicates, retry
// defaults, HTTP status mapping, JSON rendering, and the Debug/Display/Error impls.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Core constructor and Display ---

    #[test]
    fn test_new_sets_fields() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "too many requests",
        );
        assert_eq!(err.component, ErrorComponent::Model);
        assert_eq!(err.category, ErrorCategory::RateLimited);
        assert_eq!(err.code, "model.rate_limited");
        assert_eq!(err.message, "too many requests");
        // RateLimited is retryable by default via RetryHint::for_category.
        assert!(err.retry.should_retry);
    }

    #[test]
    fn test_display_format() {
        let err = AdkError::new(
            ErrorComponent::Session,
            ErrorCategory::NotFound,
            "session.not_found",
            "session xyz not found",
        );
        // Display is "<component>.<category>: <message>". It does not print `code`; the code
        // just happens to have the same spelling in this case.
        assert_eq!(err.to_string(), "session.not_found: session xyz not found");
    }

    // --- Category shorthand constructors ---

    #[test]
    fn test_convenience_not_found() {
        let err = AdkError::not_found(ErrorComponent::Session, "session.not_found", "gone");
        assert_eq!(err.category, ErrorCategory::NotFound);
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_rate_limited() {
        let err = AdkError::rate_limited(ErrorComponent::Model, "model.rate_limited", "slow down");
        assert!(err.is_retryable());
        assert!(err.is_rate_limited());
    }

    #[test]
    fn test_convenience_unauthorized() {
        let err = AdkError::unauthorized(ErrorComponent::Auth, "auth.unauthorized", "bad token");
        assert!(err.is_unauthorized());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_convenience_internal() {
        let err = AdkError::internal(ErrorComponent::Agent, "agent.internal", "oops");
        assert_eq!(err.category, ErrorCategory::Internal);
    }

    #[test]
    fn test_convenience_timeout() {
        let err = AdkError::timeout(ErrorComponent::Model, "model.timeout", "timed out");
        assert!(err.is_timeout());
        assert!(err.is_retryable());
    }

    #[test]
    fn test_convenience_unavailable() {
        let err = AdkError::unavailable(ErrorComponent::Model, "model.unavailable", "503");
        assert!(err.is_retryable());
    }

    // --- Legacy constructors: each uses a "<name>.legacy" code ---

    #[test]
    fn test_backward_compat_agent() {
        let err = AdkError::agent("test error");
        assert!(err.is_agent());
        assert_eq!(err.code, "agent.legacy");
        assert_eq!(err.category, ErrorCategory::Internal);
        assert_eq!(err.to_string(), "agent.internal: test error");
    }

    #[test]
    fn test_backward_compat_model() {
        let err = AdkError::model("model fail");
        assert!(err.is_model());
        assert_eq!(err.code, "model.legacy");
    }

    #[test]
    fn test_backward_compat_tool() {
        let err = AdkError::tool("tool fail");
        assert!(err.is_tool());
        assert_eq!(err.code, "tool.legacy");
    }

    #[test]
    fn test_backward_compat_session() {
        let err = AdkError::session("session fail");
        assert!(err.is_session());
        assert_eq!(err.code, "session.legacy");
    }

    #[test]
    fn test_backward_compat_memory() {
        let err = AdkError::memory("memory fail");
        assert!(err.is_memory());
        assert_eq!(err.code, "memory.legacy");
    }

    #[test]
    fn test_backward_compat_artifact() {
        let err = AdkError::artifact("artifact fail");
        assert!(err.is_artifact());
        assert_eq!(err.code, "artifact.legacy");
    }

    #[test]
    fn test_backward_compat_config() {
        let err = AdkError::config("bad config");
        assert!(err.is_config());
        assert_eq!(err.code, "config.legacy");
        // Config errors have no dedicated component; they are filed under Server.
        assert_eq!(err.component, ErrorComponent::Server);
        assert_eq!(err.category, ErrorCategory::InvalidInput);
    }

    #[test]
    fn test_backward_compat_codes_end_with_legacy() {
        let errors = [
            AdkError::agent("a"),
            AdkError::model("m"),
            AdkError::tool("t"),
            AdkError::session("s"),
            AdkError::memory("mem"),
            AdkError::config("c"),
            AdkError::artifact("art"),
        ];
        for err in &errors {
            assert!(err.code.ends_with(".legacy"), "code '{}' should end with .legacy", err.code);
        }
    }

    #[test]
    fn test_is_config_false_for_non_config() {
        assert!(!AdkError::agent("not config").is_config());
    }

    // --- Retry defaults and overrides ---

    #[test]
    fn test_retryable_categories_default_true() {
        for cat in [ErrorCategory::RateLimited, ErrorCategory::Unavailable, ErrorCategory::Timeout]
        {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(err.is_retryable(), "expected is_retryable() == true for {cat}");
        }
    }

    #[test]
    fn test_retryable_override_to_false() {
        // with_retry replaces the category-derived hint entirely.
        let err =
            AdkError::new(ErrorComponent::Model, ErrorCategory::RateLimited, "m.rl", "overridden")
                .with_retry(RetryHint { should_retry: false, ..Default::default() });
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_non_retryable_categories_default_false() {
        for cat in [
            ErrorCategory::InvalidInput,
            ErrorCategory::Unauthorized,
            ErrorCategory::Forbidden,
            ErrorCategory::NotFound,
            ErrorCategory::Cancelled,
            ErrorCategory::Internal,
            ErrorCategory::Unsupported,
        ] {
            let err = AdkError::new(ErrorComponent::Model, cat, "test", "msg");
            assert!(!err.is_retryable(), "expected is_retryable() == false for {cat}");
        }
    }

    // --- HTTP status mapping (pins every category) ---

    #[test]
    fn test_http_status_code_mapping() {
        let cases = [
            (ErrorCategory::InvalidInput, 400),
            (ErrorCategory::Unauthorized, 401),
            (ErrorCategory::Forbidden, 403),
            (ErrorCategory::NotFound, 404),
            (ErrorCategory::RateLimited, 429),
            (ErrorCategory::Timeout, 408),
            (ErrorCategory::Unavailable, 503),
            (ErrorCategory::Cancelled, 499),
            (ErrorCategory::Internal, 500),
            (ErrorCategory::Unsupported, 501),
        ];
        for (cat, expected) in &cases {
            let err = AdkError::new(ErrorComponent::Server, *cat, "test", "msg");
            assert_eq!(err.http_status_code(), *expected, "wrong status for {cat}");
        }
    }

    // --- std::error::Error::source ---

    #[test]
    fn test_source_returns_some_when_set() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err = AdkError::new(ErrorComponent::Session, ErrorCategory::NotFound, "s.f", "missing")
            .with_source(io_err);
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_source_returns_none_when_not_set() {
        assert!(std::error::Error::source(&AdkError::agent("no source")).is_none());
    }

    // --- RetryHint helpers ---

    #[test]
    fn test_retry_hint_for_category() {
        assert!(RetryHint::for_category(ErrorCategory::RateLimited).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Unavailable).should_retry);
        assert!(RetryHint::for_category(ErrorCategory::Timeout).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::Internal).should_retry);
        assert!(!RetryHint::for_category(ErrorCategory::NotFound).should_retry);
    }

    #[test]
    fn test_retry_hint_with_retry_after() {
        let hint = RetryHint::default().with_retry_after(Duration::from_secs(5));
        assert_eq!(hint.retry_after_ms, Some(5000));
        assert_eq!(hint.retry_after(), Some(Duration::from_secs(5)));
    }

    // --- JSON rendering ---

    #[test]
    fn test_to_problem_json() {
        let err = AdkError::new(
            ErrorComponent::Model,
            ErrorCategory::RateLimited,
            "model.rate_limited",
            "slow down",
        )
        .with_request_id("req-123")
        .with_upstream_status(429);
        let j = err.to_problem_json();
        let o = &j["error"];
        assert_eq!(o["code"], "model.rate_limited");
        assert_eq!(o["message"], "slow down");
        assert_eq!(o["component"], "model");
        assert_eq!(o["category"], "rate_limited");
        assert_eq!(o["requestId"], "req-123");
        assert_eq!(o["upstreamStatusCode"], 429);
    }

    #[test]
    fn test_to_problem_json_null_optionals() {
        // Unset optional fields are emitted as JSON null rather than omitted.
        let j = AdkError::agent("simple").to_problem_json();
        let o = &j["error"];
        assert!(o["requestId"].is_null());
        assert!(o["retryAfter"].is_null());
        assert!(o["upstreamStatusCode"].is_null());
    }

    // --- Builders ---

    #[test]
    fn test_builder_chaining() {
        let err = AdkError::new(ErrorComponent::Model, ErrorCategory::Unavailable, "m.u", "down")
            .with_provider("openai")
            .with_request_id("req-456")
            .with_upstream_status(503)
            .with_retry(RetryHint {
                should_retry: true,
                retry_after_ms: Some(1000),
                max_attempts: Some(3),
            });
        assert_eq!(err.details.provider.as_deref(), Some("openai"));
        assert_eq!(err.details.request_id.as_deref(), Some("req-456"));
        assert_eq!(err.details.upstream_status_code, Some(503));
        assert!(err.is_retryable());
        assert_eq!(err.retry.retry_after_ms, Some(1000));
        assert_eq!(err.retry.max_attempts, Some(3));
    }

    // --- Display of the classification enums ---

    #[test]
    fn test_error_component_display() {
        assert_eq!(ErrorComponent::Agent.to_string(), "agent");
        assert_eq!(ErrorComponent::Model.to_string(), "model");
        assert_eq!(ErrorComponent::Graph.to_string(), "graph");
        assert_eq!(ErrorComponent::Realtime.to_string(), "realtime");
        assert_eq!(ErrorComponent::Deploy.to_string(), "deploy");
    }

    #[test]
    fn test_error_category_display() {
        assert_eq!(ErrorCategory::InvalidInput.to_string(), "invalid_input");
        assert_eq!(ErrorCategory::RateLimited.to_string(), "rate_limited");
        assert_eq!(ErrorCategory::NotFound.to_string(), "not_found");
        assert_eq!(ErrorCategory::Internal.to_string(), "internal");
    }

    // --- Misc: Result alias, details replacement, Debug, thread safety ---

    #[test]
    #[allow(clippy::unnecessary_literal_unwrap)]
    fn test_result_type() {
        let ok: Result<i32> = Ok(42);
        assert_eq!(ok.unwrap(), 42);
        let err: Result<i32> = Err(AdkError::config("invalid"));
        assert!(err.is_err());
    }

    #[test]
    fn test_with_details() {
        let d = ErrorDetails {
            upstream_status_code: Some(502),
            request_id: Some("abc".into()),
            provider: Some("gemini".into()),
            metadata: HashMap::new(),
        };
        let err = AdkError::agent("test").with_details(d);
        assert_eq!(err.details.upstream_status_code, Some(502));
        assert_eq!(err.details.request_id.as_deref(), Some("abc"));
        assert_eq!(err.details.provider.as_deref(), Some("gemini"));
    }

    #[test]
    fn test_debug_impl() {
        let s = format!("{:?}", AdkError::agent("debug test"));
        assert!(s.contains("AdkError"));
        assert!(s.contains("agent.legacy"));
    }

    #[test]
    fn test_send_sync() {
        // Runtime mirror of the module-level compile-time Send/Sync assertion.
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<AdkError>();
        assert_sync::<AdkError>();
    }
}