1use crate::error::Error;
46use crate::retry_result::RetryResult;
47use std::sync::Arc;
48
49pub trait PollingErrorPolicy: Send + Sync + std::fmt::Debug {
54 #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
63 fn on_error(
64 &self,
65 loop_start: std::time::Instant,
66 attempt_count: u32,
67 error: Error,
68 ) -> RetryResult;
69
70 #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
73 fn on_in_progress(
74 &self,
75 _loop_start: std::time::Instant,
76 _attempt_count: u32,
77 _operation_name: &str,
78 ) -> Option<Error> {
79 None
80 }
81}
82
83#[derive(Clone)]
85pub struct PollingErrorPolicyArg(pub(crate) Arc<dyn PollingErrorPolicy>);
86
87impl<T> std::convert::From<T> for PollingErrorPolicyArg
88where
89 T: PollingErrorPolicy + 'static,
90{
91 fn from(value: T) -> Self {
92 Self(Arc::new(value))
93 }
94}
95
96impl std::convert::From<Arc<dyn PollingErrorPolicy>> for PollingErrorPolicyArg {
97 fn from(value: Arc<dyn PollingErrorPolicy>) -> Self {
98 Self(value)
99 }
100}
101
102pub trait PollingErrorPolicyExt: PollingErrorPolicy + Sized {
104 fn with_time_limit(self, maximum_duration: std::time::Duration) -> LimitedElapsedTime<Self> {
126 LimitedElapsedTime::custom(self, maximum_duration)
127 }
128
129 fn with_attempt_limit(self, maximum_attempts: u32) -> LimitedAttemptCount<Self> {
157 LimitedAttemptCount::custom(self, maximum_attempts)
158 }
159}
160
161impl<T: PollingErrorPolicy> PollingErrorPolicyExt for T {}
162
/// A polling error policy that follows AIP-194 strictly.
///
/// Errors flagged as transient and occurring before the RPC, I/O errors,
/// service errors with code `UNAVAILABLE`, and HTTP 503 responses are
/// retried. Every other error is permanent.
#[derive(Clone, Debug)]
pub struct Aip194Strict;
187
188impl PollingErrorPolicy for Aip194Strict {
189 fn on_error(
190 &self,
191 _loop_start: std::time::Instant,
192 _attempt_count: u32,
193 error: Error,
194 ) -> RetryResult {
195 if error.is_transient_and_before_rpc() {
196 return RetryResult::Continue(error);
197 }
198 if error.is_io() {
199 return RetryResult::Continue(error);
200 }
201 if let Some(status) = error.status() {
202 return if status.code == crate::error::rpc::Code::Unavailable {
203 RetryResult::Continue(error)
204 } else {
205 RetryResult::Permanent(error)
206 };
207 }
208
209 match error.http_status_code() {
210 Some(code) if code == http::StatusCode::SERVICE_UNAVAILABLE.as_u16() => {
211 RetryResult::Continue(error)
212 }
213 _ => RetryResult::Permanent(error),
214 }
215 }
216}
217
/// A polling error policy that treats every error as retryable.
///
/// On its own it never stops a polling loop. Combine it with
/// `with_time_limit` or `with_attempt_limit` to bound the loop.
#[derive(Clone, Debug)]
pub struct AlwaysContinue;
240
241impl PollingErrorPolicy for AlwaysContinue {
242 fn on_error(
243 &self,
244 _loop_start: std::time::Instant,
245 _attempt_count: u32,
246 error: Error,
247 ) -> RetryResult {
248 RetryResult::Continue(error)
249 }
250}
251
252#[derive(Debug)]
268pub struct LimitedElapsedTime<P = Aip194Strict>
269where
270 P: PollingErrorPolicy,
271{
272 inner: P,
273 maximum_duration: std::time::Duration,
274}
275
276impl LimitedElapsedTime {
277 pub fn new(maximum_duration: std::time::Duration) -> Self {
292 Self {
293 inner: Aip194Strict,
294 maximum_duration,
295 }
296 }
297}
298
299impl<P> LimitedElapsedTime<P>
300where
301 P: PollingErrorPolicy,
302{
303 pub fn custom(inner: P, maximum_duration: std::time::Duration) -> Self {
318 Self {
319 inner,
320 maximum_duration,
321 }
322 }
323
324 fn in_progress_impl(&self, start: std::time::Instant, operation_name: &str) -> Option<Error> {
325 let now = std::time::Instant::now();
326 if now < start + self.maximum_duration {
327 return None;
328 }
329 Some(Error::exhausted(Exhausted::new(
330 operation_name,
331 "elapsed time",
332 format!("{:?}", now.checked_duration_since(start).unwrap()),
333 format!("{:?}", self.maximum_duration),
334 )))
335 }
336}
337
338impl<P> PollingErrorPolicy for LimitedElapsedTime<P>
339where
340 P: PollingErrorPolicy + 'static,
341{
342 fn on_error(&self, start: std::time::Instant, count: u32, error: Error) -> RetryResult {
343 match self.inner.on_error(start, count, error) {
344 RetryResult::Permanent(e) => RetryResult::Permanent(e),
345 RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
346 RetryResult::Continue(e) => {
347 if std::time::Instant::now() >= start + self.maximum_duration {
348 RetryResult::Exhausted(e)
349 } else {
350 RetryResult::Continue(e)
351 }
352 }
353 }
354 }
355
356 fn on_in_progress(
357 &self,
358 start: std::time::Instant,
359 count: u32,
360 operation_name: &str,
361 ) -> Option<Error> {
362 self.inner
363 .on_in_progress(start, count, operation_name)
364 .or_else(|| self.in_progress_impl(start, operation_name))
365 }
366}
367
368#[derive(Debug)]
382pub struct LimitedAttemptCount<P = Aip194Strict>
383where
384 P: PollingErrorPolicy,
385{
386 inner: P,
387 maximum_attempts: u32,
388}
389
390impl LimitedAttemptCount {
391 pub fn new(maximum_attempts: u32) -> Self {
406 Self {
407 inner: Aip194Strict,
408 maximum_attempts,
409 }
410 }
411}
412
413impl<P> LimitedAttemptCount<P>
414where
415 P: PollingErrorPolicy,
416{
417 pub fn custom(inner: P, maximum_attempts: u32) -> Self {
432 Self {
433 inner,
434 maximum_attempts,
435 }
436 }
437
438 fn in_progress_impl(&self, count: u32, operation_name: &str) -> Option<Error> {
439 if count < self.maximum_attempts {
440 return None;
441 }
442 Some(Error::exhausted(Exhausted::new(
443 operation_name,
444 "attempt count",
445 count.to_string(),
446 self.maximum_attempts.to_string(),
447 )))
448 }
449}
450
451impl<P> PollingErrorPolicy for LimitedAttemptCount<P>
452where
453 P: PollingErrorPolicy,
454{
455 fn on_error(&self, start: std::time::Instant, count: u32, error: Error) -> RetryResult {
456 match self.inner.on_error(start, count, error) {
457 RetryResult::Permanent(e) => RetryResult::Permanent(e),
458 RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
459 RetryResult::Continue(e) => {
460 if count >= self.maximum_attempts {
461 RetryResult::Exhausted(e)
462 } else {
463 RetryResult::Continue(e)
464 }
465 }
466 }
467 }
468
469 fn on_in_progress(
470 &self,
471 start: std::time::Instant,
472 count: u32,
473 operation_name: &str,
474 ) -> Option<Error> {
475 self.inner
476 .on_in_progress(start, count, operation_name)
477 .or_else(|| self.in_progress_impl(count, operation_name))
478 }
479}
480
/// Error describing a polling loop that ran out of its allowed budget.
///
/// It records the operation being polled, the name of the limit that was hit,
/// the observed value, and the configured limit. Both the value and the limit
/// are stored as pre-formatted text.
#[derive(Debug)]
pub struct Exhausted {
    /// Name of the operation being polled.
    operation_name: String,
    /// Human-readable limit name, e.g. `"attempt count"` or `"elapsed time"`.
    limit_name: &'static str,
    /// Observed value, already rendered for display.
    value: String,
    /// Configured limit, already rendered for display.
    limit: String,
}

impl Exhausted {
    /// Builds an exhaustion record; `operation_name` is copied into an owned `String`.
    pub fn new(
        operation_name: &str,
        limit_name: &'static str,
        value: String,
        limit: String,
    ) -> Self {
        let operation_name = String::from(operation_name);
        Exhausted {
            operation_name,
            limit_name,
            value,
            limit,
        }
    }
}

impl std::fmt::Display for Exhausted {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Exhausted {
            operation_name,
            limit_name,
            value,
            limit,
        } = self;
        write!(
            f,
            "polling loop for {} exhausted, {} value ({}) exceeds limit ({})",
            operation_name, limit_name, value, limit
        )
    }
}

// No underlying cause: `source()` keeps its default `None`.
impl std::error::Error for Exhausted {}
517
// Unit tests for the polling error policies and the `Exhausted` error.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{CredentialsError, Error};
    use http::HeaderMap;
    use std::error::Error as _;
    use std::time::{Duration, Instant};

    // Mock inner policy. It verifies that the limit decorators forward calls
    // to the wrapped policy and respect its verdicts.
    mockall::mock! {
        #[derive(Debug)]
        Policy {}
        impl PollingErrorPolicy for Policy {
            fn on_error(&self, loop_start: std::time::Instant, attempt_count: u32, error: Error) -> RetryResult;
            fn on_in_progress(&self, loop_start: std::time::Instant, attempt_count: u32, operation_name: &str) -> Option<Error>;
        }
    }

    // Both `From` conversions into `PollingErrorPolicyArg` are usable.
    #[test]
    fn polling_policy_arg() {
        let policy = LimitedAttemptCount::new(3);
        let _ = PollingErrorPolicyArg::from(policy);

        let policy: Arc<dyn PollingErrorPolicy> = Arc::new(LimitedAttemptCount::new(3));
        let _ = PollingErrorPolicyArg::from(policy);
    }

    // `Aip194Strict` retries UNAVAILABLE / 503, I/O and transient pre-RPC
    // errors, and treats everything else (e.g. serialization) as permanent.
    #[test]
    fn aip194_strict() {
        let p = Aip194Strict;

        let now = std::time::Instant::now();
        assert!(p.on_in_progress(now, 0, "unused").is_none());
        assert!(p.on_error(now, 0, unavailable()).is_continue());
        assert!(p.on_error(now, 0, permission_denied()).is_permanent());
        assert!(p.on_error(now, 0, http_unavailable()).is_continue());
        assert!(p.on_error(now, 0, http_permission_denied()).is_permanent());

        assert!(
            p.on_error(now, 0, Error::io("err".to_string()))
                .is_continue()
        );

        assert!(
            p.on_error(
                now,
                0,
                Error::authentication(CredentialsError::from_msg(true, "err"))
            )
            .is_continue()
        );

        assert!(
            p.on_error(now, 0, Error::ser("err".to_string()))
                .is_permanent()
        );
    }

    // `AlwaysContinue` never stops the loop, for service or HTTP errors alike.
    #[test]
    fn always_continue() {
        let p = AlwaysContinue;

        let now = std::time::Instant::now();
        assert!(p.on_in_progress(now, 0, "unused").is_none());
        assert!(p.on_error(now, 0, http_unavailable()).is_continue());
        assert!(p.on_error(now, 0, unavailable()).is_continue());
    }

    // `AlwaysContinue` also continues for non-service error kinds.
    #[test_case::test_case(Error::io("err"))]
    #[test_case::test_case(Error::authentication(CredentialsError::from_msg(true, "err")))]
    #[test_case::test_case(Error::ser("err"))]
    fn always_continue_error_kind(error: Error) {
        let p = AlwaysContinue;
        let now = std::time::Instant::now();
        assert!(p.on_error(now, 0, error).is_continue());
    }

    // The `with_time_limit` combinator exhausts only after the deadline,
    // even for errors the inner `AlwaysContinue` would retry.
    #[test]
    fn with_time_limit() {
        let policy = AlwaysContinue.with_time_limit(Duration::from_secs(10));
        assert!(
            policy
                .on_error(
                    Instant::now() - Duration::from_secs(1),
                    1,
                    permission_denied()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    Instant::now() - Duration::from_secs(20),
                    1,
                    permission_denied()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    // The `with_attempt_limit` combinator exhausts once the attempt count
    // reaches the limit.
    #[test]
    fn with_attempt_limit() {
        let policy = AlwaysContinue.with_attempt_limit(3);
        assert!(
            policy
                .on_error(Instant::now(), 1, permission_denied())
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(Instant::now(), 5, permission_denied())
                .is_exhausted(),
            "{policy:?}"
        );
    }

    // Builds an HTTP error whose JSON payload mimics a service error body.
    fn http_error(code: u16, message: &str) -> Error {
        let error = serde_json::json!({"error": {
            "code": code,
            "message": message,
        }});
        let payload = bytes::Bytes::from_owner(serde_json::to_string(&error).unwrap());
        Error::http(code, HeaderMap::new(), payload)
    }

    fn http_unavailable() -> Error {
        http_error(503, "SERVICE UNAVAILABLE")
    }

    fn http_permission_denied() -> Error {
        http_error(403, "PERMISSION DENIED")
    }

    // Service error carrying an UNAVAILABLE status (retryable under AIP-194).
    fn unavailable() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::Unavailable)
            .set_message("UNAVAILABLE");
        Error::service(status)
    }

    // Service error carrying a PERMISSION_DENIED status (permanent under AIP-194).
    fn permission_denied() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::PermissionDenied)
            .set_message("PERMISSION_DENIED");
        Error::service(status)
    }

    #[test]
    fn test_limited_elapsed_time_on_error() {
        let policy = LimitedElapsedTime::new(Duration::from_secs(20));
        assert!(
            policy
                .on_error(Instant::now() - Duration::from_secs(10), 1, unavailable())
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(Instant::now() - Duration::from_secs(30), 1, unavailable())
                .is_exhausted(),
            "{policy:?}"
        );
    }

    // Past the deadline, `on_in_progress` returns an error whose source is `Exhausted`.
    #[test]
    fn test_limited_elapsed_time_in_progress() {
        let policy = LimitedElapsedTime::new(Duration::from_secs(20));
        let err = policy.on_in_progress(Instant::now() - Duration::from_secs(10), 1, "unused");
        assert!(err.is_none(), "{err:?}");
        let err = policy
            .on_in_progress(
                Instant::now() - Duration::from_secs(30),
                1,
                "test-operation-name",
            )
            .unwrap();
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some());
    }

    #[test]
    fn test_limited_time_forwards_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, _, e| RetryResult::Continue(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(now, 0, transient_error());
        assert!(rf.is_continue());
    }

    // The inner policy must be consulted on every in-progress poll (exactly 3 calls).
    #[test]
    fn test_limited_time_forwards_in_progress() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(3)
            .returning(|_, _, _| None);

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        assert!(policy.on_in_progress(now, 1, "test-op-name").is_none());
        assert!(policy.on_in_progress(now, 2, "test-op-name").is_none());
        assert!(policy.on_in_progress(now, 3, "test-op-name").is_none());
    }

    // An error from the inner policy is returned even if the time limit is not reached.
    #[test]
    fn test_limited_time_in_progress_returns_inner() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(1)
            .returning(|_, _, _| Some(transient_error()));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        assert!(policy.on_in_progress(now, 1, "test-op-name").is_some());
    }

    #[test]
    fn test_limited_time_inner_continues() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, _, e| RetryResult::Continue(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(now - Duration::from_secs(10), 1, transient_error());
        assert!(rf.is_continue());

        let rf = policy.on_error(now - Duration::from_secs(70), 1, transient_error());
        assert!(rf.is_exhausted());
    }

    // A `Permanent` verdict from the inner policy passes through regardless of time.
    #[test]
    fn test_limited_time_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, _, e| RetryResult::Permanent(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_error(now - Duration::from_secs(10), 1, transient_error());
        assert!(rf.is_permanent());

        let rf = policy.on_error(now + Duration::from_secs(10), 1, transient_error());
        assert!(rf.is_permanent());
    }

    // An `Exhausted` verdict from the inner policy passes through regardless of time.
    #[test]
    fn test_limited_time_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, _, e| RetryResult::Exhausted(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_error(now - Duration::from_secs(10), 1, transient_error());
        assert!(rf.is_exhausted());

        let rf = policy.on_error(now + Duration::from_secs(10), 1, transient_error());
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_attempt_count_on_error() {
        let policy = LimitedAttemptCount::new(20);
        assert!(
            policy
                .on_error(Instant::now(), 10, unavailable())
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(Instant::now(), 30, unavailable())
                .is_exhausted(),
            "{policy:?}"
        );
    }

    // Past the limit, `on_in_progress` returns an error whose source is `Exhausted`.
    #[test]
    fn test_limited_attempt_count_in_progress() {
        let policy = LimitedAttemptCount::new(20);
        let err = policy.on_in_progress(Instant::now(), 10, "unused");
        assert!(err.is_none(), "{err:?}");
        let err = policy
            .on_in_progress(Instant::now(), 30, "test-operation-name")
            .unwrap();
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some());
    }

    // With a limit of 3, attempt 3 is already exhausted (the check is `count >= limit`).
    #[test]
    fn test_limited_attempt_count_forwards_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, _, e| RetryResult::Continue(e));

        let now = std::time::Instant::now();
        let policy = LimitedAttemptCount::custom(mock, 3);
        assert!(policy.on_error(now, 1, transient_error()).is_continue());
        assert!(policy.on_error(now, 2, transient_error()).is_continue());
        assert!(policy.on_error(now, 3, transient_error()).is_exhausted());
    }

    #[test]
    fn test_limited_attempt_count_forwards_in_progress() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(3)
            .returning(|_, _, _| None);

        let now = std::time::Instant::now();
        let policy = LimitedAttemptCount::custom(mock, 5);
        assert!(policy.on_in_progress(now, 1, "test-op-name").is_none());
        assert!(policy.on_in_progress(now, 2, "test-op-name").is_none());
        assert!(policy.on_in_progress(now, 3, "test-op-name").is_none());
    }

    #[test]
    fn test_limited_attempt_count_in_progress_returns_inner() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(1)
            .returning(|_, _, _| Some(unavailable()));

        let now = std::time::Instant::now();
        let policy = LimitedAttemptCount::custom(mock, 5);
        assert!(policy.on_in_progress(now, 1, "test-op-name").is_some());
    }

    #[test]
    fn test_limited_attempt_count_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, _, e| RetryResult::Permanent(e));
        let policy = LimitedAttemptCount::custom(mock, 2);
        let now = std::time::Instant::now();
        let rf = policy.on_error(now, 1, Error::ser("err"));
        assert!(rf.is_permanent());

        let rf = policy.on_error(now, 1, Error::ser("err"));
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_attempt_count_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, _, e| RetryResult::Exhausted(e));
        let policy = LimitedAttemptCount::custom(mock, 2);
        let now = std::time::Instant::now();

        let rf = policy.on_error(now, 1, transient_error());
        assert!(rf.is_exhausted());

        let rf = policy.on_error(now, 1, transient_error());
        assert!(rf.is_exhausted());
    }

    // The `Display` output mentions every field of `Exhausted`.
    #[test]
    fn test_exhausted_fmt() {
        let exhausted = Exhausted::new(
            "op-name",
            "limit-name",
            "test-value".to_string(),
            "test-limit".to_string(),
        );
        let fmt = format!("{exhausted}");
        assert!(fmt.contains("op-name"), "{fmt}");
        assert!(fmt.contains("limit-name"), "{fmt}");
        assert!(fmt.contains("test-value"), "{fmt}");
        assert!(fmt.contains("test-limit"), "{fmt}");
    }

    // A retryable (UNAVAILABLE) service error used as generic input for the mocks.
    fn transient_error() -> Error {
        use crate::error::rpc::{Code, Status};
        Error::service(
            Status::default()
                .set_code(Code::Unavailable)
                .set_message("try-again"),
        )
    }
}