1use std::time::Duration;
4
5use fast_rands::Rand;
6use pyo3::prelude::*;
7
8use crate::streaming::StreamError;
9
/// Error type returned by the agent bridge.
///
/// Python exceptions are converted into this type by `classify_py_error`
/// (through `From<PyErr>`), and it is turned back into a Python
/// `RuntimeError` carrying the `Display` text when crossing into Python
/// (`From<Error> for PyErr`). Display strings come from the `#[error(...)]`
/// attributes below. The enum derives `Clone`, so every variant holds owned,
/// cloneable data.
#[non_exhaustive]
#[derive(Debug, Clone, thiserror::Error)]
pub enum Error {
    /// The agent has not been started, or has already been shut down.
    #[error("Agent is not started or has been shut down")]
    AgentNotStarted,
    /// A failure reported by the Python backend.
    ///
    /// `classify_py_error` produces this for Antigravity validation errors,
    /// pydantic `ValidationError`s and `ImportError`s. It is also the
    /// fallback for any other Python exception, in which case `message` is
    /// the formatted Python traceback. `is_retryable` / `is_quota_error`
    /// look for HTTP status codes in `message`.
    #[error("Backend error: {message}")]
    BackendError {
        /// Description of the failure; may span several lines (traceback).
        message: String,
    },

    /// A connection failure. This is the only variant `with_retry` retries.
    ///
    /// Produced from Python `AntigravityConnectionError` exceptions.
    #[error("Connection error: {message}")]
    ConnectionError {
        /// Description of the connection failure.
        message: String,
    },

    /// A quota was exceeded.
    #[error("Quota exceeded, retry after {retry_after:?}")]
    QuotaExceeded {
        /// Suggested delay before trying again.
        retry_after: Duration,
    },

    /// A channel was closed.
    #[error("Channel closed: {message}")]
    ChannelClosed {
        /// Identifies or describes the closed channel.
        message: String,
    },

    /// Output was blocked by a safety filter.
    ///
    /// `From<StreamError>` produces this when the stream error message
    /// mentions "safety" (case-insensitive).
    #[error("Blocked by safety filter")]
    Safety,

    /// Generation stopped at the token limit.
    ///
    /// `From<StreamError>` produces this when the stream error message
    /// mentions "max tokens" or "token limit" (case-insensitive).
    #[error("Max tokens reached")]
    MaxTokens,

    /// The connection is permanently closed.
    ///
    /// Unlike `ConnectionError`, this is not retried by `with_retry`.
    #[error("Connection permanently closed: {message}")]
    ConnectionClosed {
        /// Description of why the connection closed.
        message: String,
    },

    /// An operation did not finish in time; produced by `with_timeout`.
    #[error("Timeout after {duration:?}: {operation}")]
    Timeout {
        /// The time budget that was exceeded.
        duration: Duration,
        /// Label of the operation that timed out.
        operation: String,
    },

    /// A streaming error that `From<StreamError>` did not map to `Safety`
    /// or `MaxTokens`. Display and source are forwarded to the inner error.
    #[error(transparent)]
    Stream(StreamError),

    /// A configuration value was rejected.
    #[error("Invalid configuration: {message}")]
    InvalidConfig {
        /// Description of the invalid configuration.
        message: String,
    },

    /// An I/O failure, stored as its rendered message and kind rather than
    /// the `std::io::Error` itself (which does not implement `Clone`).
    #[error("I/O error: {message}")]
    Io {
        /// `Display` text of the original I/O error.
        message: String,
        /// `kind()` of the original I/O error.
        kind: std::io::ErrorKind,
    },
}
89
90impl Error {
91 #[must_use]
103 pub fn is_retryable(&self) -> bool {
104 match self {
105 Self::ConnectionError { .. } | Self::QuotaExceeded { .. } => true,
106 Self::BackendError { message } => message.contains("503"),
107 _ => false,
108 }
109 }
110
111 #[must_use]
117 pub fn is_quota_error(&self) -> bool {
118 match self {
119 Self::QuotaExceeded { .. } => true,
120 Self::BackendError { message } => {
121 message.contains("429")
122 || message.contains("503")
123 || message.contains("RESOURCE_EXHAUSTED")
124 }
125 _ => false,
126 }
127 }
128}
129
130impl From<std::io::Error> for Error {
140 fn from(err: std::io::Error) -> Self {
141 Self::Io {
142 message: err.to_string(),
143 kind: err.kind(),
144 }
145 }
146}
147
148impl From<StreamError> for Error {
149 fn from(err: StreamError) -> Self {
150 let msg = err.message.to_lowercase();
151 if msg.contains("safety") {
152 Self::Safety
153 } else if msg.contains("max tokens") || msg.contains("token limit") {
154 Self::MaxTokens
155 } else {
156 Self::Stream(err)
157 }
158 }
159}
160
161#[doc(hidden)]
162impl From<PyErr> for Error {
163 fn from(err: PyErr) -> Self {
164 Python::attach(|py| classify_py_error(py, &err))
165 }
166}
167
168#[doc(hidden)]
169impl From<Error> for PyErr {
170 fn from(err: Error) -> Self {
171 pyo3::exceptions::PyRuntimeError::new_err(err.to_string())
172 }
173}
174
175pub(crate) fn classify_py_error(py: Python<'_>, err: &PyErr) -> Error {
181 if let Some(classified) = check_antigravity_error(py, err) {
182 return classified;
183 }
184 if let Some(classified) = check_pydantic_error(py, err) {
185 return classified;
186 }
187 if let Some(classified) = check_builtin_error(py, err) {
188 return classified;
189 }
190
191 let message = format_backend_error(py, err);
192 Error::BackendError { message }
193}
194
195fn check_antigravity_error(py: Python<'_>, err: &PyErr) -> Option<Error> {
196 match py.import("google.antigravity.types") {
197 Ok(types_mod) => {
198 if let Ok(conn_err_cls) = types_mod.getattr("AntigravityConnectionError")
199 && err.is_instance(py, &conn_err_cls)
200 {
201 return Some(Error::ConnectionError {
202 message: err.to_string(),
203 });
204 }
205 if let Ok(val_err_cls) = types_mod.getattr("AntigravityValidationError")
206 && err.is_instance(py, &val_err_cls)
207 {
208 return Some(Error::BackendError {
209 message: err.to_string(),
210 });
211 }
212 }
213 Err(import_err) => {
214 tracing::debug!(
215 error = %import_err,
216 "antigravity.types not available, skipping AntigravityError classification"
217 );
218 }
219 }
220 None
221}
222
223fn check_pydantic_error(py: Python<'_>, err: &PyErr) -> Option<Error> {
224 match py.import("pydantic") {
225 Ok(pydantic) => {
226 if let Ok(validation_err_cls) = pydantic.getattr("ValidationError")
227 && err.is_instance(py, &validation_err_cls)
228 {
229 return Some(Error::BackendError {
230 message: err.to_string(),
231 });
232 }
233 }
234 Err(import_err) => {
235 tracing::debug!(
236 error = %import_err,
237 "pydantic not available, skipping ValidationError classification"
238 );
239 }
240 }
241 None
242}
243
244fn check_builtin_error(py: Python<'_>, err: &PyErr) -> Option<Error> {
245 if let Ok(builtins) = py.import("builtins") {
246 if let Ok(import_err_cls) = builtins.getattr("ImportError")
247 && err.is_instance(py, &import_err_cls)
248 {
249 return Some(Error::BackendError {
250 message: err.to_string(),
251 });
252 }
253 } else {
254 tracing::warn!("Failed to import Python builtins module, skipping ImportError check");
255 }
256 None
257}
258
259fn format_backend_error(py: Python<'_>, err: &PyErr) -> String {
261 let formatted = py
263 .import("traceback")
264 .and_then(|tb_mod| {
265 tb_mod.call_method1(
266 "format_exception",
267 (err.get_type(py), err.value(py), err.traceback(py)),
268 )
269 })
270 .and_then(|lines| lines.extract::<Vec<String>>());
271
272 match formatted {
273 Ok(lines) => lines.join(""),
274 Err(fmt_err) => {
275 tracing::warn!(error = %fmt_err, "Failed to format backend traceback, using fallback");
276 let traceback = err.traceback(py);
278 traceback.as_ref().map_or_else(
279 || err.to_string(),
280 |tb| {
281 tb.format().map_or_else(
282 |tb_fmt_err| {
283 tracing::warn!(error = %tb_fmt_err, "Failed to format Python traceback");
284 err.to_string()
285 },
286 |tb_str| format!("{}\nTraceback:\n{}", err.value(py), tb_str),
287 )
288 },
289 )
290 }
291 }
292}
293
294pub async fn with_timeout<F, T>(timeout: Duration, operation: &str, f: F) -> Result<T, Error>
302where
303 F: std::future::Future<Output = Result<T, Error>>,
304{
305 match tokio::time::timeout(timeout, f).await {
306 Ok(result) => result,
307 Err(_elapsed) => Err(Error::Timeout {
308 duration: timeout,
309 operation: operation.to_string(),
310 }),
311 }
312}
313
314pub async fn with_retry<F, Fut, T>(max_retries: u32, operation: &str, mut f: F) -> Result<T, Error>
331where
332 F: FnMut() -> Fut,
333 Fut: std::future::Future<Output = Result<T, Error>>,
334{
335 let mut attempt = 0u32;
336 loop {
337 match f().await {
338 Ok(val) => return Ok(val),
339 Err(Error::ConnectionError { ref message }) => {
340 attempt += 1;
341 if attempt > max_retries {
342 tracing::error!(
343 attempts = attempt,
344 operation,
345 "All retries exhausted for connection error: {message}"
346 );
347 return Err(Error::ConnectionError {
348 message: message.clone(),
349 });
350 }
351 let backoff = backoff_duration(attempt);
352 tracing::warn!(
353 attempt,
354 max_retries,
355 backoff_ms = u64::try_from(backoff.as_millis()).unwrap_or_else(|e| {
356 tracing::warn!("Int conversion failed: {}", e);
357 u64::MAX
358 }),
359 operation,
360 "Connection error, retrying: {message}"
361 );
362 tokio::time::sleep(backoff).await;
363 }
364 Err(other) => return Err(other),
365 }
366 }
367}
368
/// Cap, in seconds, on the base (pre-jitter) delay used by `backoff_duration`.
pub(crate) const MAX_BACKOFF_SECS: u64 = 2 * 60;

/// Growth factor of the exponential backoff: the pre-cap base delay for
/// attempt `n` is `2^n` seconds.
const BACKOFF_EXPONENT_BASE: u64 = 2;

/// Seconds-to-milliseconds conversion factor.
const MILLISECONDS_PER_SECOND: u64 = 1_000;

/// Width of the jitter window is `base / JITTER_TOTAL_SPREAD_DIVISOR`
/// (half the base delay).
const JITTER_TOTAL_SPREAD_DIVISOR: u64 = 2;

/// Lowest jittered delay is `base - base / JITTER_MIN_SUBTRACT_DIVISOR`
/// (three quarters of the base delay).
const JITTER_MIN_SUBTRACT_DIVISOR: u64 = 4;
379
380pub(crate) fn backoff_duration(attempt: u32) -> Duration {
389 let attempt = attempt.max(1);
390 let base_secs = BACKOFF_EXPONENT_BASE
391 .checked_shl(attempt.saturating_sub(1))
392 .unwrap_or(MAX_BACKOFF_SECS)
393 .min(MAX_BACKOFF_SECS);
394 let base_ms = base_secs.saturating_mul(MILLISECONDS_PER_SECOND);
395 let jitter_range = base_ms / JITTER_TOTAL_SPREAD_DIVISOR; let jitter_min = base_ms.saturating_sub(base_ms / JITTER_MIN_SUBTRACT_DIVISOR);
398 let jittered_ms = if jitter_range == 0 {
399 base_ms
400 } else {
401 let limit = u32::try_from(jitter_range).unwrap_or_else(|e| {
402 tracing::warn!("Int conversion failed: {}", e);
403 u32::MAX
404 });
405 jitter_min
406 + (fast_rands::StdRand::new().between(0, limit.saturating_sub(1) as usize) as u64)
407 };
408 Duration::from_millis(jittered_ms)
409}
410
#[cfg(test)]
mod tests {
    //! Unit tests for the error type, its conversions and classifiers, and the
    //! `with_timeout` / `with_retry` / `backoff_duration` helpers.

    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    // `From<StreamError>` routes messages mentioning safety or max tokens to
    // dedicated variants and wraps everything else in `Error::Stream`.
    #[test]
    fn test_stream_error_conversion() {
        let safety_err = StreamError {
            message: "Step error (status=ERROR): Candidate blocked by safety".to_string(),
        };
        let mapped_safety = Error::from(safety_err);
        assert!(matches!(mapped_safety, Error::Safety));

        let max_tokens_err = StreamError {
            message: "Step error (status=ERROR): Max tokens reached".to_string(),
        };
        let mapped_max_tokens = Error::from(max_tokens_err);
        assert!(matches!(mapped_max_tokens, Error::MaxTokens));

        let other_err = StreamError {
            message: "Some other connection issue".to_string(),
        };
        let mapped_other = Error::from(other_err);
        match mapped_other {
            Error::Stream(e) => {
                assert_eq!(e.message, "Some other connection issue");
            }
            other => panic!("Expected Error::Stream, got: {other:?}"),
        }
    }

    // A plain `ValueError` matches none of the specific classifiers, so it
    // falls through to `BackendError`. The message is the formatted
    // traceback, which includes the exception type and text.
    #[test]
    fn test_backend_error_from_pyerr() {
        Python::initialize();
        let err = Python::attach(|py| {
            let result: PyResult<()> = py.run(c"raise ValueError('test error 42')", None, None);
            result.unwrap_err()
        });

        let bridge_err: Error = err.into();
        match &bridge_err {
            Error::BackendError { message } => {
                assert!(
                    message.contains("ValueError"),
                    "Expected 'ValueError' in message, got: {message}"
                );
                assert!(
                    message.contains("test error 42"),
                    "Expected 'test error 42' in message, got: {message}"
                );
            }
            other => panic!("Expected BackendError, got: {other:?}"),
        }
    }

    // The 10 s sleep cannot finish inside the 50 ms budget, so the timeout fires
    // and records both the budget and the operation label.
    #[tokio::test]
    async fn test_timeout_triggers() {
        let short_timeout = Duration::from_millis(50);
        let result: Result<(), Error> = with_timeout(short_timeout, "test_op", async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(())
        })
        .await;

        match result {
            Err(Error::Timeout {
                duration,
                operation,
            }) => {
                assert_eq!(duration, short_timeout);
                assert_eq!(operation, "test_op");
            }
            other => panic!("Expected Timeout, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_timeout_succeeds_when_fast() {
        let result = with_timeout(Duration::from_secs(5), "fast_op", async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    // Two connection failures, then success on the third call.
    // NOTE(review): the tokio clock is not paused here, so this test really
    // sleeps for `backoff_duration(1)` + `backoff_duration(2)` (several seconds).
    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let counter = AtomicU32::new(0);
        let result = with_retry(3, "test_retry", || {
            let attempt = counter.fetch_add(1, Ordering::SeqCst);
            async move {
                if attempt < 2 {
                    Err(Error::ConnectionError {
                        message: "transient".to_string(),
                    })
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // `max_retries = 2` means one initial call plus two retries: three calls.
    // Like the test above, this waits out two real backoff sleeps.
    #[tokio::test]
    async fn test_retry_exhausted() {
        let counter = AtomicU32::new(0);
        let result: Result<i32, Error> = with_retry(2, "doomed", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async {
                Err(Error::ConnectionError {
                    message: "always fails".to_string(),
                })
            }
        })
        .await;

        assert!(matches!(result, Err(Error::ConnectionError { .. })));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_does_not_retry_non_connection_errors() {
        let counter = AtomicU32::new(0);
        let result: Result<i32, Error> = with_retry(5, "python_err", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async {
                Err(Error::BackendError {
                    message: "kaboom".to_string(),
                })
            }
        })
        .await;

        assert!(matches!(result, Err(Error::BackendError { .. })));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // Each jittered delay must stay within [0.75 × base, 1.25 × base].
    // Attempts past 6 are capped at `MAX_BACKOFF_SECS` (120 s).
    #[test]
    fn test_backoff_duration_progression() {
        let bases_ms: [(u32, u64); 6] = [
            (1, 2_000),
            (2, 4_000),
            (3, 8_000),
            (4, 16_000),
            (7, 120_000),
            (100, 120_000),
        ];
        for (attempt, base_ms) in bases_ms {
            let d = backoff_duration(attempt);
            let lo = base_ms * 3 / 4;
            let hi = base_ms * 5 / 4;
            assert!(
                d.as_millis() >= u128::from(lo) && d.as_millis() <= u128::from(hi),
                "backoff_duration({attempt}) = {d:?} outside [{lo}ms, {hi}ms]"
            );
        }
    }

    #[test]
    fn test_error_display_messages() {
        let err = Error::BackendError {
            message: "test".to_string(),
        };
        assert_eq!(format!("{err}"), "Backend error: test");

        let err = Error::ConnectionError {
            message: "lost".to_string(),
        };
        assert_eq!(format!("{err}"), "Connection error: lost");

        let err = Error::QuotaExceeded {
            retry_after: Duration::from_secs(5),
        };
        assert!(format!("{err}").contains("5s"));

        let err = Error::ChannelClosed {
            message: "cmd".to_string(),
        };
        assert_eq!(format!("{err}"), "Channel closed: cmd");

        let err = Error::Timeout {
            duration: Duration::from_secs(30),
            operation: "chat".to_string(),
        };
        assert!(format!("{err}").contains("chat"));
    }

    // Attempt 0 is clamped to 1, i.e. a 2 s base delay.
    #[test]
    fn test_backoff_duration_zero_attempt() {
        let d = backoff_duration(0);
        assert!(
            d.as_millis() >= 1500 && d.as_millis() <= 2500,
            "backoff_duration(0) = {d:?} outside [1500ms, 2500ms]"
        );
    }

    // Exponent overflow must fall back to the 120 s cap.
    #[test]
    fn test_backoff_duration_large_attempt_capped() {
        let d = backoff_duration(u32::MAX);
        assert!(
            d.as_millis() >= 90_000 && d.as_millis() <= 150_000,
            "backoff_duration(u32::MAX) = {d:?} outside [90s, 150s]"
        );
    }

    #[tokio::test]
    async fn test_timeout_propagates_inner_error() {
        let result: Result<(), Error> = with_timeout(Duration::from_secs(10), "inner_err", async {
            Err(Error::BackendError {
                message: "inner failure".to_string(),
            })
        })
        .await;

        match result {
            Err(Error::BackendError { message }) => {
                assert_eq!(message, "inner failure");
            }
            other => panic!("Expected BackendError, got: {other:?}"),
        }
    }

    // `max_retries = 0` still makes the initial call, but never retries.
    #[tokio::test]
    async fn test_retry_zero_max_retries_still_runs_once() {
        let counter = AtomicU32::new(0);
        let result: Result<i32, Error> = with_retry(0, "no_retries", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async {
                Err(Error::ConnectionError {
                    message: "fail".to_string(),
                })
            }
        })
        .await;

        assert!(matches!(result, Err(Error::ConnectionError { .. })));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let counter = AtomicU32::new(0);
        let result = with_retry(5, "instant_success", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async { Ok(99) }
        })
        .await;

        assert_eq!(result.unwrap(), 99);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // `with_retry` only retries `ConnectionError`, even though
    // `QuotaExceeded::is_retryable()` is true.
    #[tokio::test]
    async fn test_retry_quota_exceeded_not_retried() {
        let counter = AtomicU32::new(0);
        let result: Result<i32, Error> = with_retry(5, "quota", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async {
                Err(Error::QuotaExceeded {
                    retry_after: Duration::from_secs(1),
                })
            }
        })
        .await;

        assert!(matches!(result, Err(Error::QuotaExceeded { .. })));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_timeout_not_retried() {
        let counter = AtomicU32::new(0);
        let result: Result<i32, Error> = with_retry(5, "timeout", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async {
                Err(Error::Timeout {
                    duration: Duration::from_secs(10),
                    operation: "test".to_string(),
                })
            }
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_channel_closed_not_retried() {
        let counter = AtomicU32::new(0);
        let result: Result<i32, Error> = with_retry(5, "channel", || {
            counter.fetch_add(1, Ordering::SeqCst);
            async {
                Err(Error::ChannelClosed {
                    message: "gone".to_string(),
                })
            }
        })
        .await;

        assert!(matches!(result, Err(Error::ChannelClosed { .. })));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_error_debug_format() {
        let err = Error::BackendError {
            message: "debug test".to_string(),
        };
        let debug = format!("{err:?}");
        assert!(debug.contains("BackendError"));
        assert!(debug.contains("debug test"));
    }

    // Attempts 1..=8: the base doubles from 2 s and is capped at 120 s from attempt 7.
    #[test]
    fn test_backoff_duration_full_progression() {
        let base_secs: [u64; 8] = [2, 4, 8, 16, 32, 64, 120, 120];
        for (i, base) in base_secs.iter().enumerate() {
            let attempt = u32::try_from(i + 1).unwrap();
            let d = backoff_duration(attempt);
            let base_ms = base * 1000;
            let lo = base_ms * 3 / 4;
            let hi = base_ms * 5 / 4;
            assert!(
                d.as_millis() >= u128::from(lo) && d.as_millis() <= u128::from(hi),
                "backoff_duration({attempt}) = {d:?} outside [{lo}ms, {hi}ms]"
            );
        }
    }

    #[test]
    fn test_stream_error_from_conversion() {
        let stream_err = StreamError {
            message: "connection reset".to_string(),
        };
        let bridge_err = Error::from(stream_err);
        match &bridge_err {
            Error::Stream(inner) => {
                assert_eq!(inner.message, "connection reset");
            }
            other => panic!("Expected Stream variant, got: {other:?}"),
        }
    }

    // `Error::Stream` is `#[error(transparent)]`, so its Display is the inner error's.
    #[test]
    fn test_stream_error_display_through_bridge() {
        let stream_err = StreamError {
            message: "quota exceeded".to_string(),
        };
        let bridge_err = Error::from(stream_err);
        let display = format!("{bridge_err}");
        assert!(
            display.contains("quota exceeded"),
            "Expected 'quota exceeded' in display, got: {display}"
        );
    }

    #[test]
    fn test_is_retryable_connection_error() {
        let err = Error::ConnectionError {
            message: "timeout".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_quota_exceeded_is_retryable() {
        let err = Error::QuotaExceeded {
            retry_after: Duration::from_secs(5),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_backend_error() {
        let err = Error::BackendError {
            message: "kaboom".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_channel_closed() {
        let err = Error::ChannelClosed {
            message: "gone".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_timeout() {
        let err = Error::Timeout {
            duration: Duration::from_secs(30),
            operation: "chat".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_stream() {
        let err = Error::Stream(StreamError {
            message: "stream failed".to_string(),
        });
        assert!(!err.is_retryable());
    }

    // A backend error carrying HTTP 503 is retryable.
    #[test]
    fn test_is_retryable_503_backend_error() {
        let err = Error::BackendError {
            message: "request failed (code 503): high demand".to_string(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_is_quota_error_quota_exceeded() {
        let err = Error::QuotaExceeded {
            retry_after: Duration::from_secs(5),
        };
        assert!(err.is_quota_error());
    }

    #[test]
    fn test_is_quota_error_backend_429() {
        let err = Error::BackendError {
            message: "HTTP 429 Too Many Requests".to_string(),
        };
        assert!(err.is_quota_error());
    }

    #[test]
    fn test_is_quota_error_resource_exhausted() {
        let err = Error::BackendError {
            message: "RESOURCE_EXHAUSTED: quota exceeded".to_string(),
        };
        assert!(err.is_quota_error());
    }

    #[test]
    fn test_is_not_quota_error_connection() {
        let err = Error::ConnectionError {
            message: "timeout".to_string(),
        };
        assert!(!err.is_quota_error());
    }

    #[test]
    fn test_is_not_quota_error_normal_backend() {
        let err = Error::BackendError {
            message: "something else".to_string(),
        };
        assert!(!err.is_quota_error());
    }

    // 503 "high demand" responses are also treated as quota errors.
    #[test]
    fn test_is_quota_error_503_high_demand() {
        let err = Error::BackendError {
            message: "request failed (code 503): This model is currently experiencing high demand"
                .to_string(),
        };
        assert!(err.is_quota_error());
    }
}