1use crate::error::Error;
45use crate::polling_state::PollingState;
46use crate::retry_result::RetryResult;
47use std::sync::Arc;
48
49pub trait PollingErrorPolicy: Send + Sync + std::fmt::Debug {
54 #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
60 fn on_error(&self, state: &PollingState, error: Error) -> RetryResult;
61
62 #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
68 fn on_in_progress(&self, _state: &PollingState, _operation_name: &str) -> Result<(), Error> {
69 Ok(())
70 }
71}
72
73#[derive(Clone)]
75pub struct PollingErrorPolicyArg(pub(crate) Arc<dyn PollingErrorPolicy>);
76
77impl<T> std::convert::From<T> for PollingErrorPolicyArg
78where
79 T: PollingErrorPolicy + 'static,
80{
81 fn from(value: T) -> Self {
82 Self(Arc::new(value))
83 }
84}
85
86impl std::convert::From<Arc<dyn PollingErrorPolicy>> for PollingErrorPolicyArg {
87 fn from(value: Arc<dyn PollingErrorPolicy>) -> Self {
88 Self(value)
89 }
90}
91
92pub trait PollingErrorPolicyExt: PollingErrorPolicy + Sized {
94 fn with_time_limit(self, maximum_duration: std::time::Duration) -> LimitedElapsedTime<Self> {
117 LimitedElapsedTime::custom(self, maximum_duration)
118 }
119
120 fn with_attempt_limit(self, maximum_attempts: u32) -> LimitedAttemptCount<Self> {
148 LimitedAttemptCount::custom(self, maximum_attempts)
149 }
150}
151
152impl<T: PollingErrorPolicy> PollingErrorPolicyExt for T {}
153
154#[derive(Clone, Debug)]
177pub struct Aip194Strict;
178
179impl PollingErrorPolicy for Aip194Strict {
180 fn on_error(&self, _state: &PollingState, error: Error) -> RetryResult {
181 if error.is_transient_and_before_rpc() {
182 return RetryResult::Continue(error);
183 }
184 if error.is_io() {
185 return RetryResult::Continue(error);
186 }
187 if let Some(status) = error.status() {
188 return if status.code == crate::error::rpc::Code::Unavailable {
189 RetryResult::Continue(error)
190 } else {
191 RetryResult::Permanent(error)
192 };
193 }
194
195 match error.http_status_code() {
196 Some(code) if code == http::StatusCode::SERVICE_UNAVAILABLE.as_u16() => {
197 RetryResult::Continue(error)
198 }
199 _ => RetryResult::Permanent(error),
200 }
201 }
202}
203
204#[derive(Clone, Debug)]
225pub struct AlwaysContinue;
226
227impl PollingErrorPolicy for AlwaysContinue {
228 fn on_error(&self, _state: &PollingState, error: Error) -> RetryResult {
229 RetryResult::Continue(error)
230 }
231}
232
233#[derive(Debug)]
249pub struct LimitedElapsedTime<P = Aip194Strict>
250where
251 P: PollingErrorPolicy,
252{
253 inner: P,
254 maximum_duration: std::time::Duration,
255}
256
257impl LimitedElapsedTime {
258 pub fn new(maximum_duration: std::time::Duration) -> Self {
274 Self {
275 inner: Aip194Strict,
276 maximum_duration,
277 }
278 }
279}
280
281impl<P> LimitedElapsedTime<P>
282where
283 P: PollingErrorPolicy,
284{
285 pub fn custom(inner: P, maximum_duration: std::time::Duration) -> Self {
301 Self {
302 inner,
303 maximum_duration,
304 }
305 }
306
307 fn in_progress_impl(
308 &self,
309 start: std::time::Instant,
310 operation_name: &str,
311 ) -> Result<(), Error> {
312 let now = std::time::Instant::now();
313 if now < start + self.maximum_duration {
314 return Ok(());
315 }
316 Err(Error::exhausted(Exhausted::new(
317 operation_name,
318 "elapsed time",
319 format!("{:?}", now.checked_duration_since(start).unwrap()),
320 format!("{:?}", self.maximum_duration),
321 )))
322 }
323}
324
325impl<P> PollingErrorPolicy for LimitedElapsedTime<P>
326where
327 P: PollingErrorPolicy + 'static,
328{
329 fn on_error(&self, state: &PollingState, error: Error) -> RetryResult {
330 match self.inner.on_error(state, error) {
331 RetryResult::Permanent(e) => RetryResult::Permanent(e),
332 RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
333 RetryResult::Continue(e) => {
334 if std::time::Instant::now() >= state.start + self.maximum_duration {
335 RetryResult::Exhausted(e)
336 } else {
337 RetryResult::Continue(e)
338 }
339 }
340 }
341 }
342
343 fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error> {
344 self.inner
345 .on_in_progress(state, operation_name)
346 .and_then(|_| self.in_progress_impl(state.start, operation_name))
347 }
348}
349
350#[derive(Debug)]
364pub struct LimitedAttemptCount<P = Aip194Strict>
365where
366 P: PollingErrorPolicy,
367{
368 inner: P,
369 maximum_attempts: u32,
370}
371
372impl LimitedAttemptCount {
373 pub fn new(maximum_attempts: u32) -> Self {
389 Self {
390 inner: Aip194Strict,
391 maximum_attempts,
392 }
393 }
394}
395
396impl<P> LimitedAttemptCount<P>
397where
398 P: PollingErrorPolicy,
399{
400 pub fn custom(inner: P, maximum_attempts: u32) -> Self {
415 Self {
416 inner,
417 maximum_attempts,
418 }
419 }
420
421 fn in_progress_impl(&self, count: u32, operation_name: &str) -> Result<(), Error> {
422 if count < self.maximum_attempts {
423 return Ok(());
424 }
425 Err(Error::exhausted(Exhausted::new(
426 operation_name,
427 "attempt count",
428 count.to_string(),
429 self.maximum_attempts.to_string(),
430 )))
431 }
432}
433
434impl<P> PollingErrorPolicy for LimitedAttemptCount<P>
435where
436 P: PollingErrorPolicy,
437{
438 fn on_error(&self, state: &PollingState, error: Error) -> RetryResult {
439 match self.inner.on_error(state, error) {
440 RetryResult::Permanent(e) => RetryResult::Permanent(e),
441 RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
442 RetryResult::Continue(e) => {
443 if state.attempt_count >= self.maximum_attempts {
444 RetryResult::Exhausted(e)
445 } else {
446 RetryResult::Continue(e)
447 }
448 }
449 }
450 }
451
452 fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error> {
453 self.inner
454 .on_in_progress(state, operation_name)
455 .and_then(|_| self.in_progress_impl(state.attempt_count, operation_name))
456 }
457}
458
/// Describes which limit stopped a polling loop.
///
/// `LimitedElapsedTime` and `LimitedAttemptCount` wrap this value with
/// `Error::exhausted` when their limit is reached.
#[derive(Debug)]
pub struct Exhausted {
    /// The name of the long-running operation being polled.
    operation_name: String,
    /// A short description of the limit, e.g. "elapsed time" or "attempt count".
    limit_name: &'static str,
    /// The observed value, already formatted for display.
    value: String,
    /// The configured limit, already formatted for display.
    limit: String,
}

impl Exhausted {
    /// Creates a new description of an exhausted polling loop.
    ///
    /// # Parameters
    /// * `operation_name` - the name of the long-running operation.
    /// * `limit_name` - a short description of the limit that was reached.
    /// * `value` - the observed value, formatted for display.
    /// * `limit` - the configured limit, formatted for display.
    pub fn new(
        operation_name: &str,
        limit_name: &'static str,
        value: String,
        limit: String,
    ) -> Self {
        let operation_name = String::from(operation_name);
        Self {
            operation_name,
            limit_name,
            value,
            limit,
        }
    }
}

impl std::fmt::Display for Exhausted {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            operation_name,
            limit_name,
            value,
            limit,
        } = self;
        write!(
            f,
            "polling loop for {} exhausted, {} value ({}) exceeds limit ({})",
            operation_name, limit_name, value, limit
        )
    }
}

// No underlying cause: `source()` keeps its default `None`.
impl std::error::Error for Exhausted {}
495
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{CredentialsError, Error};
    use http::HeaderMap;
    use std::error::Error as _;
    use std::time::{Duration, Instant};

    mockall::mock! {
        #[derive(Debug)]
        Policy {}
        impl PollingErrorPolicy for Policy {
            fn on_error(&self, state: &PollingState, error: Error) -> RetryResult;
            fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error>;
        }
    }

    // Both `From` conversions into `PollingErrorPolicyArg` must be available.
    #[test]
    fn polling_policy_arg() {
        let policy = LimitedAttemptCount::new(3);
        let _ = PollingErrorPolicyArg::from(policy);

        let policy: Arc<dyn PollingErrorPolicy> = Arc::new(LimitedAttemptCount::new(3));
        let _ = PollingErrorPolicyArg::from(policy);
    }

    #[test]
    fn aip194_strict() -> anyhow::Result<()> {
        let p = Aip194Strict;
        p.on_in_progress(&PollingState::default(), "unused")?;
        assert!(
            p.on_error(&PollingState::default(), unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&PollingState::default(), permission_denied())
                .is_permanent()
        );
        assert!(
            p.on_error(&PollingState::default(), http_unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&PollingState::default(), http_permission_denied())
                .is_permanent()
        );

        assert!(
            p.on_error(&PollingState::default(), Error::io("err".to_string()))
                .is_continue()
        );

        assert!(
            p.on_error(
                &PollingState::default(),
                Error::authentication(CredentialsError::from_msg(true, "err"))
            )
            .is_continue()
        );

        assert!(
            p.on_error(&PollingState::default(), Error::ser("err".to_string()))
                .is_permanent()
        );
        Ok(())
    }

    #[test]
    fn always_continue() {
        let p = AlwaysContinue;

        let result = p.on_in_progress(&PollingState::default(), "unused");
        assert!(result.is_ok(), "{result:?}");
        assert!(
            p.on_error(&PollingState::default(), http_unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&PollingState::default(), unavailable())
                .is_continue()
        );
    }

    #[test_case::test_case(Error::io("err"))]
    #[test_case::test_case(Error::authentication(CredentialsError::from_msg(true, "err")))]
    #[test_case::test_case(Error::ser("err"))]
    fn always_continue_error_kind(error: Error) {
        let p = AlwaysContinue;
        assert!(p.on_error(&PollingState::default(), error).is_continue());
    }

    #[test]
    fn with_time_limit() {
        let policy = AlwaysContinue.with_time_limit(Duration::from_secs(10));
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(1))
                        .set_attempt_count(1_u32),
                    permission_denied()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(20))
                        .set_attempt_count(1_u32),
                    permission_denied()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    #[test]
    fn with_attempt_limit() {
        let policy = AlwaysContinue.with_attempt_limit(3);
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(1_u32),
                    permission_denied()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(5_u32),
                    permission_denied()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    fn http_error(code: u16, message: &str) -> Error {
        let error = serde_json::json!({"error": {
            "code": code,
            "message": message,
        }});
        let payload = bytes::Bytes::from_owner(serde_json::to_string(&error).unwrap());
        Error::http(code, HeaderMap::new(), payload)
    }

    fn http_unavailable() -> Error {
        http_error(503, "SERVICE UNAVAILABLE")
    }

    fn http_permission_denied() -> Error {
        http_error(403, "PERMISSION DENIED")
    }

    fn unavailable() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::Unavailable)
            .set_message("UNAVAILABLE");
        Error::service(status)
    }

    fn permission_denied() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::PermissionDenied)
            .set_message("PERMISSION_DENIED");
        Error::service(status)
    }

    #[test]
    fn test_limited_elapsed_time_on_error() {
        let policy = LimitedElapsedTime::new(Duration::from_secs(20));
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(10))
                        .set_attempt_count(1_u32),
                    unavailable()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(30))
                        .set_attempt_count(1_u32),
                    unavailable()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    #[test]
    fn test_limited_elapsed_time_in_progress() {
        let policy = LimitedElapsedTime::new(Duration::from_secs(20));
        let result = policy.on_in_progress(
            &PollingState::default()
                .set_start(Instant::now() - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            "unused",
        );
        assert!(result.is_ok(), "{result:?}");
        let err = policy
            .on_in_progress(
                &PollingState::default()
                    .set_start(Instant::now() - Duration::from_secs(30))
                    .set_attempt_count(1_u32),
                "test-operation-name",
            )
            .unwrap_err();
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some(), "{err:?}");
    }

    #[test]
    fn test_limited_time_forwards_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(&PollingState::default(), transient_error());
        assert!(rf.is_continue());
    }

    #[test]
    fn test_limited_time_forwards_in_progress() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(3)
            .returning(|_, _| Ok(()));

        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(2_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(3_u32),
                    "test-op-name"
                )
                .is_ok()
        );
    }

    #[test]
    fn test_limited_time_in_progress_returns_inner() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(1)
            .returning(|_, _| Err(transient_error()));

        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_err()
        );
    }

    #[test]
    fn test_limited_time_inner_continues() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_continue());

        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(70))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_time_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(3)
            .returning(|_, e| RetryResult::Permanent(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_permanent());

        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now + Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_permanent());

        // Past the deadline the inner `Permanent` must still win over `Exhausted`.
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(70))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_time_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(3)
            .returning(|_, e| RetryResult::Exhausted(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());

        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now + Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());

        // Past the deadline the inner `Exhausted` is passed through unchanged.
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(70))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_attempt_count_on_error() {
        let policy = LimitedAttemptCount::new(20);
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(10_u32),
                    unavailable()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(30_u32),
                    unavailable()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    #[test]
    fn test_limited_attempt_count_in_progress() {
        let policy = LimitedAttemptCount::new(20);
        let result =
            policy.on_in_progress(&PollingState::default().set_attempt_count(10_u32), "unused");
        assert!(result.is_ok(), "{result:?}");
        let err = policy
            .on_in_progress(
                &PollingState::default().set_attempt_count(30_u32),
                "test-operation-name",
            )
            .unwrap_err();
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some(), "{err:?}");
    }

    // Reaching the limit exactly (count == maximum_attempts) is already exhausted.
    #[test]
    fn test_limited_attempt_count_in_progress_boundary() {
        let policy = LimitedAttemptCount::new(3);
        let result =
            policy.on_in_progress(&PollingState::default().set_attempt_count(2_u32), "unused");
        assert!(result.is_ok(), "{result:?}");
        let err = policy
            .on_in_progress(
                &PollingState::default().set_attempt_count(3_u32),
                "test-operation-name",
            )
            .unwrap_err();
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some(), "{err:?}");
    }

    #[test]
    fn test_limited_attempt_count_forwards_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let policy = LimitedAttemptCount::custom(mock, 3);
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(1_u32),
                    transient_error()
                )
                .is_continue()
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(2_u32),
                    transient_error()
                )
                .is_continue()
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(3_u32),
                    transient_error()
                )
                .is_exhausted()
        );
    }

    #[test]
    fn test_limited_attempt_count_forwards_in_progress() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(3)
            .returning(|_, _| Ok(()));

        let policy = LimitedAttemptCount::custom(mock, 5);
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(2_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(3_u32),
                    "test-op-name"
                )
                .is_ok()
        );
    }

    #[test]
    fn test_limited_attempt_count_in_progress_returns_inner() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(1)
            .returning(|_, _| Err(unavailable()));

        let policy = LimitedAttemptCount::custom(mock, 5);
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_err()
        );
    }

    #[test]
    fn test_limited_attempt_count_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Permanent(e));
        let policy = LimitedAttemptCount::custom(mock, 2);
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(1_u32),
            Error::ser("err"),
        );
        assert!(rf.is_permanent());

        // Past the attempt limit the inner `Permanent` must still win over `Exhausted`.
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(3_u32),
            Error::ser("err"),
        );
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_attempt_count_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Exhausted(e));
        let policy = LimitedAttemptCount::custom(mock, 2);

        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());

        // Past the attempt limit the inner `Exhausted` is passed through unchanged.
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(3_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_exhausted_fmt() {
        let exhausted = Exhausted::new(
            "op-name",
            "limit-name",
            "test-value".to_string(),
            "test-limit".to_string(),
        );
        let fmt = format!("{exhausted}");
        assert!(fmt.contains("op-name"), "{fmt}");
        assert!(fmt.contains("limit-name"), "{fmt}");
        assert!(fmt.contains("test-value"), "{fmt}");
        assert!(fmt.contains("test-limit"), "{fmt}");
    }

    fn transient_error() -> Error {
        use crate::error::rpc::{Code, Status};
        Error::service(
            Status::default()
                .set_code(Code::Unavailable)
                .set_message("try-again"),
        )
    }
}