1use std::time::Duration;
4
5use fast_rands::Rand;
6use pyo3::prelude::*;
7
8use crate::streaming::StreamError;
9
10#[non_exhaustive]
12#[derive(Debug, Clone, thiserror::Error)]
13pub enum Error {
14 #[error("Agent is not started or has been shut down")]
16 AgentNotStarted,
17 #[error("Backend error: {message}")]
19 BackendError {
20 message: String,
22 },
23
24 #[error("Connection error: {message}")]
26 ConnectionError {
27 message: String,
29 },
30
31 #[error("Quota exceeded, retry after {retry_after:?}")]
33 QuotaExceeded {
34 retry_after: Duration,
36 },
37
38 #[error("Channel closed: {message}")]
40 ChannelClosed {
41 message: String,
43 },
44
45 #[error("Blocked by safety filter")]
47 Safety,
48
49 #[error("Max tokens reached")]
51 MaxTokens,
52
53 #[error("Connection permanently closed: {message}")]
55 ConnectionClosed {
56 message: String,
58 },
59
60 #[error("Timeout after {duration:?}: {operation}")]
62 Timeout {
63 duration: Duration,
65 operation: String,
67 },
68
69 #[error(transparent)]
71 Stream(StreamError),
72
73 #[error("Invalid configuration: {message}")]
75 InvalidConfig {
76 message: String,
78 },
79
80 #[error("I/O error: {message}")]
82 Io {
83 message: String,
85 kind: std::io::ErrorKind,
87 },
88}
89
90impl Error {
91 #[must_use]
103 pub fn is_retryable(&self) -> bool {
104 match self {
105 Self::ConnectionError { .. } | Self::QuotaExceeded { .. } => true,
106 Self::BackendError { message } => message.contains("503"),
107 _ => false,
108 }
109 }
110
111 #[must_use]
117 pub fn is_quota_error(&self) -> bool {
118 match self {
119 Self::QuotaExceeded { .. } => true,
120 Self::BackendError { message } => {
121 message.contains("429")
122 || message.contains("503")
123 || message.contains("RESOURCE_EXHAUSTED")
124 }
125 _ => false,
126 }
127 }
128}
129
130impl From<std::io::Error> for Error {
140 fn from(err: std::io::Error) -> Self {
141 Self::Io {
142 message: err.to_string(),
143 kind: err.kind(),
144 }
145 }
146}
147
148impl From<StreamError> for Error {
149 fn from(err: StreamError) -> Self {
150 let msg = err.message.to_lowercase();
151 if msg.contains("safety") {
152 Self::Safety
153 } else if msg.contains("max tokens") || msg.contains("token limit") {
154 Self::MaxTokens
155 } else {
156 Self::Stream(err)
157 }
158 }
159}
160
161#[doc(hidden)]
162impl From<PyErr> for Error {
163 fn from(err: PyErr) -> Self {
164 Python::with_gil(|py| classify_py_error(py, &err))
165 }
166}
167
168#[doc(hidden)]
169impl From<Error> for PyErr {
170 fn from(err: Error) -> Self {
171 pyo3::exceptions::PyRuntimeError::new_err(err.to_string())
172 }
173}
174
175pub(crate) fn classify_py_error(py: Python<'_>, err: &PyErr) -> Error {
181 if let Some(classified) = check_antigravity_error(py, err) {
182 return classified;
183 }
184 if let Some(classified) = check_pydantic_error(py, err) {
185 return classified;
186 }
187 if let Some(classified) = check_builtin_error(py, err) {
188 return classified;
189 }
190
191 let message = format_backend_error(py, err);
192 Error::BackendError { message }
193}
194
195fn check_antigravity_error(py: Python<'_>, err: &PyErr) -> Option<Error> {
196 if let Ok(types_mod) = py.import_bound("google.antigravity.types") {
197 if let Ok(conn_err_cls) = types_mod.getattr("AntigravityConnectionError")
198 && err.is_instance_bound(py, &conn_err_cls)
199 {
200 return Some(Error::ConnectionError {
201 message: err.to_string(),
202 });
203 }
204 if let Ok(val_err_cls) = types_mod.getattr("AntigravityValidationError")
205 && err.is_instance_bound(py, &val_err_cls)
206 {
207 return Some(Error::BackendError {
208 message: err.to_string(),
209 });
210 }
211 }
212 None
213}
214
215fn check_pydantic_error(py: Python<'_>, err: &PyErr) -> Option<Error> {
216 if let Ok(pydantic) = py.import_bound("pydantic")
217 && let Ok(validation_err_cls) = pydantic.getattr("ValidationError")
218 && err.is_instance_bound(py, &validation_err_cls)
219 {
220 return Some(Error::BackendError {
221 message: err.to_string(),
222 });
223 }
224 None
225}
226
227fn check_builtin_error(py: Python<'_>, err: &PyErr) -> Option<Error> {
228 if let Ok(builtins) = py.import_bound("builtins") {
229 if let Ok(import_err_cls) = builtins.getattr("ImportError")
230 && err.is_instance_bound(py, &import_err_cls)
231 {
232 return Some(Error::BackendError {
233 message: err.to_string(),
234 });
235 }
236 } else {
237 tracing::warn!("Failed to import Python builtins module, skipping ImportError check");
238 }
239 None
240}
241
242fn format_backend_error(py: Python<'_>, err: &PyErr) -> String {
244 let formatted = py
246 .import_bound("traceback")
247 .and_then(|tb_mod| {
248 tb_mod.call_method1(
249 "format_exception",
250 (
251 err.get_type_bound(py),
252 err.value_bound(py),
253 err.traceback_bound(py),
254 ),
255 )
256 })
257 .and_then(|lines| lines.extract::<Vec<String>>());
258
259 match formatted {
260 Ok(lines) => lines.join(""),
261 Err(fmt_err) => {
262 tracing::warn!(error = %fmt_err, "Failed to format backend traceback, using fallback");
263 let traceback = err.traceback_bound(py);
265 traceback.as_ref().map_or_else(
266 || err.to_string(),
267 |tb| {
268 tb.format().map_or_else(
269 |tb_fmt_err| {
270 tracing::warn!(error = %tb_fmt_err, "Failed to format Python traceback");
271 err.to_string()
272 },
273 |tb_str| format!("{}\nTraceback:\n{}", err.value_bound(py), tb_str),
274 )
275 },
276 )
277 }
278 }
279}
280
281pub async fn with_timeout<F, T>(timeout: Duration, operation: &str, f: F) -> Result<T, Error>
289where
290 F: std::future::Future<Output = Result<T, Error>>,
291{
292 match tokio::time::timeout(timeout, f).await {
293 Ok(result) => result,
294 Err(_elapsed) => Err(Error::Timeout {
295 duration: timeout,
296 operation: operation.to_string(),
297 }),
298 }
299}
300
301pub async fn with_retry<F, Fut, T>(max_retries: u32, operation: &str, mut f: F) -> Result<T, Error>
318where
319 F: FnMut() -> Fut,
320 Fut: std::future::Future<Output = Result<T, Error>>,
321{
322 let mut attempt = 0u32;
323 loop {
324 match f().await {
325 Ok(val) => return Ok(val),
326 Err(Error::ConnectionError { ref message }) => {
327 attempt += 1;
328 if attempt > max_retries {
329 tracing::error!(
330 attempts = attempt,
331 operation,
332 "All retries exhausted for connection error: {message}"
333 );
334 return Err(Error::ConnectionError {
335 message: message.clone(),
336 });
337 }
338 let backoff = backoff_duration(attempt);
339 tracing::warn!(
340 attempt,
341 max_retries,
342 backoff_ms = u64::try_from(backoff.as_millis()).unwrap_or_else(|e| {
343 tracing::warn!("Int conversion failed: {}", e);
344 u64::MAX
345 }),
346 operation,
347 "Connection error, retrying: {message}"
348 );
349 tokio::time::sleep(backoff).await;
350 }
351 Err(other) => return Err(other),
352 }
353 }
354}
355
356pub(crate) const MAX_BACKOFF_SECS: u64 = 120;
357
358const BACKOFF_EXPONENT_BASE: u64 = 2;
360const MILLISECONDS_PER_SECOND: u64 = 1000;
362const JITTER_TOTAL_SPREAD_DIVISOR: u64 = 2;
364const JITTER_MIN_SUBTRACT_DIVISOR: u64 = 4;
366
367fn backoff_duration(attempt: u32) -> Duration {
376 let attempt = attempt.max(1);
377 let base_secs = BACKOFF_EXPONENT_BASE
378 .checked_shl(attempt.saturating_sub(1))
379 .unwrap_or(MAX_BACKOFF_SECS)
380 .min(MAX_BACKOFF_SECS);
381 let base_ms = base_secs.saturating_mul(MILLISECONDS_PER_SECOND);
382 let jitter_range = base_ms / JITTER_TOTAL_SPREAD_DIVISOR; let jitter_min = base_ms.saturating_sub(base_ms / JITTER_MIN_SUBTRACT_DIVISOR);
385 let jittered_ms = if jitter_range == 0 {
386 base_ms
387 } else {
388 let limit = u32::try_from(jitter_range).unwrap_or_else(|e| {
389 tracing::warn!("Int conversion failed: {}", e);
390 u32::MAX
391 });
392 jitter_min
393 + (fast_rands::StdRand::new().between(0, limit.saturating_sub(1) as usize) as u64)
394 };
395 Duration::from_millis(jittered_ms)
396}
397
#[cfg(test)]
mod tests {
    //! Unit tests for error conversion, classification, timeout, retry and
    //! backoff behavior.

    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// Builds a [`StreamError`] carrying `message`.
    fn stream_error(message: &str) -> StreamError {
        StreamError {
            message: message.to_string(),
        }
    }

    /// Builds an [`Error::BackendError`] carrying `message`.
    fn backend_error(message: &str) -> Error {
        Error::BackendError {
            message: message.to_string(),
        }
    }

    /// Builds an [`Error::ConnectionError`] carrying `message`.
    fn connection_error(message: &str) -> Error {
        Error::ConnectionError {
            message: message.to_string(),
        }
    }

    /// Asserts that `backoff_duration(attempt)` lies within ±25% of `base_ms`.
    fn assert_backoff_near(attempt: u32, base_ms: u64) {
        let d = backoff_duration(attempt);
        let lo = base_ms * 3 / 4;
        let hi = base_ms * 5 / 4;
        assert!(
            d.as_millis() >= u128::from(lo) && d.as_millis() <= u128::from(hi),
            "backoff_duration({attempt}) = {d:?} outside [{lo}ms, {hi}ms]"
        );
    }

    /// Drives `with_retry` with an operation that always fails with the
    /// error produced by `make_err`; returns the outcome and the call count.
    async fn run_always_failing<E>(
        max_retries: u32,
        operation: &str,
        make_err: E,
    ) -> (Result<i32, Error>, u32)
    where
        E: Fn() -> Error,
    {
        let calls = AtomicU32::new(0);
        let outcome = with_retry(max_retries, operation, || {
            calls.fetch_add(1, Ordering::SeqCst);
            let failure = make_err();
            async move { Err(failure) }
        })
        .await;
        (outcome, calls.load(Ordering::SeqCst))
    }

    // ---- StreamError conversion ----

    #[test]
    fn test_stream_error_conversion() {
        let mapped_safety = Error::from(stream_error(
            "Step error (status=ERROR): Candidate blocked by safety",
        ));
        assert!(matches!(mapped_safety, Error::Safety));

        let mapped_max_tokens =
            Error::from(stream_error("Step error (status=ERROR): Max tokens reached"));
        assert!(matches!(mapped_max_tokens, Error::MaxTokens));

        match Error::from(stream_error("Some other connection issue")) {
            Error::Stream(e) => {
                assert_eq!(e.message, "Some other connection issue");
            }
            other => panic!("Expected Error::Stream, got: {other:?}"),
        }
    }

    #[test]
    fn test_stream_error_from_conversion() {
        let bridge_err = Error::from(stream_error("connection reset"));
        match &bridge_err {
            Error::Stream(inner) => {
                assert_eq!(inner.message, "connection reset");
            }
            other => panic!("Expected Stream variant, got: {other:?}"),
        }
    }

    #[test]
    fn test_stream_error_display_through_bridge() {
        let display = Error::from(stream_error("quota exceeded")).to_string();
        assert!(
            display.contains("quota exceeded"),
            "Expected 'quota exceeded' in display, got: {display}"
        );
    }

    // ---- Python exception conversion ----

    #[test]
    fn test_backend_error_from_pyerr() {
        pyo3::prepare_freethreaded_python();
        let py_err = Python::with_gil(|py| {
            let raised: PyResult<()> =
                py.run_bound("raise ValueError('test error 42')", None, None);
            raised.unwrap_err()
        });

        let bridge_err = Error::from(py_err);
        match &bridge_err {
            Error::BackendError { message } => {
                assert!(
                    message.contains("ValueError"),
                    "Expected 'ValueError' in message, got: {message}"
                );
                assert!(
                    message.contains("test error 42"),
                    "Expected 'test error 42' in message, got: {message}"
                );
            }
            other => panic!("Expected BackendError, got: {other:?}"),
        }
    }

    // ---- with_timeout ----

    #[tokio::test]
    async fn test_timeout_triggers() {
        let short_timeout = Duration::from_millis(50);
        let slow_operation = async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(())
        };
        let result: Result<(), Error> =
            with_timeout(short_timeout, "test_op", slow_operation).await;

        match result {
            Err(Error::Timeout {
                duration,
                operation,
            }) => {
                assert_eq!(duration, short_timeout);
                assert_eq!(operation, "test_op");
            }
            other => panic!("Expected Timeout, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_timeout_succeeds_when_fast() {
        let value = with_timeout(Duration::from_secs(5), "fast_op", async { Ok(42) })
            .await
            .unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_timeout_propagates_inner_error() {
        let failing_operation = async { Err(backend_error("inner failure")) };
        let result: Result<(), Error> =
            with_timeout(Duration::from_secs(10), "inner_err", failing_operation).await;

        match result {
            Err(Error::BackendError { message }) => {
                assert_eq!(message, "inner failure");
            }
            other => panic!("Expected BackendError, got: {other:?}"),
        }
    }

    // ---- with_retry ----

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let calls = AtomicU32::new(0);
        let result = with_retry(3, "test_retry", || {
            let prior_calls = calls.fetch_add(1, Ordering::SeqCst);
            async move {
                if prior_calls >= 2 {
                    Ok(42)
                } else {
                    Err(connection_error("transient"))
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let calls = AtomicU32::new(0);
        let result = with_retry(5, "instant_success", || {
            calls.fetch_add(1, Ordering::SeqCst);
            async { Ok(99) }
        })
        .await;

        assert_eq!(result.unwrap(), 99);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let (result, calls) =
            run_always_failing(2, "doomed", || connection_error("always fails")).await;

        assert!(matches!(result, Err(Error::ConnectionError { .. })));
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_retry_zero_max_retries_still_runs_once() {
        let (result, calls) =
            run_always_failing(0, "no_retries", || connection_error("fail")).await;

        assert!(matches!(result, Err(Error::ConnectionError { .. })));
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_does_not_retry_non_connection_errors() {
        let (result, calls) =
            run_always_failing(5, "python_err", || backend_error("kaboom")).await;

        assert!(matches!(result, Err(Error::BackendError { .. })));
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_quota_exceeded_not_retried() {
        let (result, calls) = run_always_failing(5, "quota", || Error::QuotaExceeded {
            retry_after: Duration::from_secs(1),
        })
        .await;

        assert!(matches!(result, Err(Error::QuotaExceeded { .. })));
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_timeout_not_retried() {
        let (result, calls) = run_always_failing(5, "timeout", || Error::Timeout {
            duration: Duration::from_secs(10),
            operation: "test".to_string(),
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout { .. })));
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_channel_closed_not_retried() {
        let (result, calls) = run_always_failing(5, "channel", || Error::ChannelClosed {
            message: "gone".to_string(),
        })
        .await;

        assert!(matches!(result, Err(Error::ChannelClosed { .. })));
        assert_eq!(calls, 1);
    }

    // ---- backoff_duration ----

    #[test]
    fn test_backoff_duration_progression() {
        let bases_ms: [(u32, u64); 6] = [
            (1, 2_000),
            (2, 4_000),
            (3, 8_000),
            (4, 16_000),
            (7, 120_000),
            (100, 120_000),
        ];
        for (attempt, base_ms) in bases_ms {
            assert_backoff_near(attempt, base_ms);
        }
    }

    #[test]
    fn test_backoff_duration_full_progression() {
        let base_secs: [u64; 8] = [2, 4, 8, 16, 32, 64, 120, 120];
        for (index, base) in base_secs.iter().enumerate() {
            let attempt = u32::try_from(index + 1).unwrap();
            assert_backoff_near(attempt, base * 1000);
        }
    }

    #[test]
    fn test_backoff_duration_zero_attempt() {
        let d = backoff_duration(0);
        let millis = d.as_millis();
        assert!(
            millis >= 1500 && millis <= 2500,
            "backoff_duration(0) = {d:?} outside [1500ms, 2500ms]"
        );
    }

    #[test]
    fn test_backoff_duration_large_attempt_capped() {
        let d = backoff_duration(u32::MAX);
        let millis = d.as_millis();
        assert!(
            millis >= 90_000 && millis <= 150_000,
            "backoff_duration(u32::MAX) = {d:?} outside [90s, 150s]"
        );
    }

    // ---- Display / Debug ----

    #[test]
    fn test_error_display_messages() {
        assert_eq!(backend_error("test").to_string(), "Backend error: test");
        assert_eq!(
            connection_error("lost").to_string(),
            "Connection error: lost"
        );

        let quota = Error::QuotaExceeded {
            retry_after: Duration::from_secs(5),
        };
        assert!(quota.to_string().contains("5s"));

        let closed = Error::ChannelClosed {
            message: "cmd".to_string(),
        };
        assert_eq!(closed.to_string(), "Channel closed: cmd");

        let timed_out = Error::Timeout {
            duration: Duration::from_secs(30),
            operation: "chat".to_string(),
        };
        assert!(timed_out.to_string().contains("chat"));
    }

    #[test]
    fn test_error_debug_format() {
        let debug = format!("{:?}", backend_error("debug test"));
        assert!(debug.contains("BackendError"));
        assert!(debug.contains("debug test"));
    }

    // ---- is_retryable ----

    #[test]
    fn test_is_retryable_connection_error() {
        assert!(connection_error("timeout").is_retryable());
    }

    #[test]
    fn test_quota_exceeded_is_retryable() {
        let err = Error::QuotaExceeded {
            retry_after: Duration::from_secs(5),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_backend_error() {
        assert!(!backend_error("kaboom").is_retryable());
    }

    #[test]
    fn test_is_not_retryable_channel_closed() {
        let err = Error::ChannelClosed {
            message: "gone".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_timeout() {
        let err = Error::Timeout {
            duration: Duration::from_secs(30),
            operation: "chat".to_string(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_not_retryable_stream() {
        let err = Error::Stream(stream_error("stream failed"));
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_is_retryable_503_backend_error() {
        assert!(backend_error("request failed (code 503): high demand").is_retryable());
    }

    // ---- is_quota_error ----

    #[test]
    fn test_is_quota_error_quota_exceeded() {
        let err = Error::QuotaExceeded {
            retry_after: Duration::from_secs(5),
        };
        assert!(err.is_quota_error());
    }

    #[test]
    fn test_is_quota_error_backend_429() {
        assert!(backend_error("HTTP 429 Too Many Requests").is_quota_error());
    }

    #[test]
    fn test_is_quota_error_resource_exhausted() {
        assert!(backend_error("RESOURCE_EXHAUSTED: quota exceeded").is_quota_error());
    }

    #[test]
    fn test_is_not_quota_error_connection() {
        assert!(!connection_error("timeout").is_quota_error());
    }

    #[test]
    fn test_is_not_quota_error_normal_backend() {
        assert!(!backend_error("something else").is_quota_error());
    }

    #[test]
    fn test_is_quota_error_503_high_demand() {
        let err = backend_error(
            "request failed (code 503): This model is currently experiencing high demand",
        );
        assert!(err.is_quota_error());
    }
}