1use crate::error::Error;
45use crate::polling_state::PollingState;
46use crate::retry_result::RetryResult;
47use std::sync::Arc;
48
49pub trait PollingErrorPolicy: Send + Sync + std::fmt::Debug {
54 #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
60 fn on_error(&self, state: &PollingState, error: Error) -> RetryResult;
61
62 #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
68 fn on_in_progress(&self, _state: &PollingState, _operation_name: &str) -> Result<(), Error> {
69 Ok(())
70 }
71}
72
73#[derive(Clone)]
75pub struct PollingErrorPolicyArg(pub(crate) Arc<dyn PollingErrorPolicy>);
76
77impl<T> std::convert::From<T> for PollingErrorPolicyArg
78where
79 T: PollingErrorPolicy + 'static,
80{
81 fn from(value: T) -> Self {
82 Self(Arc::new(value))
83 }
84}
85
86impl std::convert::From<Arc<dyn PollingErrorPolicy>> for PollingErrorPolicyArg {
87 fn from(value: Arc<dyn PollingErrorPolicy>) -> Self {
88 Self(value)
89 }
90}
91
92pub trait PollingErrorPolicyExt: PollingErrorPolicy + Sized {
94 fn with_time_limit(self, maximum_duration: std::time::Duration) -> LimitedElapsedTime<Self> {
117 LimitedElapsedTime::custom(self, maximum_duration)
118 }
119
120 fn with_attempt_limit(self, maximum_attempts: u32) -> LimitedAttemptCount<Self> {
148 LimitedAttemptCount::custom(self, maximum_attempts)
149 }
150}
151
152impl<T: PollingErrorPolicy> PollingErrorPolicyExt for T {}
153
154#[derive(Clone, Debug)]
177pub struct Aip194Strict;
178
179impl PollingErrorPolicy for Aip194Strict {
180 fn on_error(&self, _state: &PollingState, error: Error) -> RetryResult {
181 if error.is_transient_and_before_rpc() {
182 return RetryResult::Continue(error);
183 }
184 if error.is_io() {
185 return RetryResult::Continue(error);
186 }
187 if let Some(status) = error.status() {
188 return if status.code == crate::error::rpc::Code::Unavailable {
189 RetryResult::Continue(error)
190 } else {
191 RetryResult::Permanent(error)
192 };
193 }
194
195 match error.http_status_code() {
196 Some(code) if code == http::StatusCode::SERVICE_UNAVAILABLE.as_u16() => {
197 RetryResult::Continue(error)
198 }
199 _ => RetryResult::Permanent(error),
200 }
201 }
202}
203
204#[derive(Clone, Debug)]
225pub struct AlwaysContinue;
226
227impl PollingErrorPolicy for AlwaysContinue {
228 fn on_error(&self, _state: &PollingState, error: Error) -> RetryResult {
229 RetryResult::Continue(error)
230 }
231}
232
233#[derive(Debug)]
249pub struct LimitedElapsedTime<P = Aip194Strict>
250where
251 P: PollingErrorPolicy,
252{
253 inner: P,
254 maximum_duration: std::time::Duration,
255}
256
257impl LimitedElapsedTime {
258 pub fn new(maximum_duration: std::time::Duration) -> Self {
274 Self {
275 inner: Aip194Strict,
276 maximum_duration,
277 }
278 }
279}
280
281impl<P> LimitedElapsedTime<P>
282where
283 P: PollingErrorPolicy,
284{
285 pub fn custom(inner: P, maximum_duration: std::time::Duration) -> Self {
301 Self {
302 inner,
303 maximum_duration,
304 }
305 }
306
307 fn in_progress_impl(
308 &self,
309 start: std::time::Instant,
310 operation_name: &str,
311 ) -> Result<(), Error> {
312 let now = std::time::Instant::now();
313 if now < start + self.maximum_duration {
314 return Ok(());
315 }
316 Err(Error::exhausted(Exhausted::new(
317 operation_name,
318 "elapsed time",
319 format!("{:?}", now.checked_duration_since(start).unwrap()),
320 format!("{:?}", self.maximum_duration),
321 )))
322 }
323}
324
325impl<P> PollingErrorPolicy for LimitedElapsedTime<P>
326where
327 P: PollingErrorPolicy + 'static,
328{
329 fn on_error(&self, state: &PollingState, error: Error) -> RetryResult {
330 match self.inner.on_error(state, error) {
331 RetryResult::Permanent(e) => RetryResult::Permanent(e),
332 RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
333 RetryResult::Continue(e) => {
334 if std::time::Instant::now() >= state.start + self.maximum_duration {
335 RetryResult::Exhausted(e)
336 } else {
337 RetryResult::Continue(e)
338 }
339 }
340 }
341 }
342
343 fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error> {
344 self.inner
345 .on_in_progress(state, operation_name)
346 .and_then(|_| self.in_progress_impl(state.start, operation_name))
347 }
348}
349
350#[derive(Debug)]
364pub struct LimitedAttemptCount<P = Aip194Strict>
365where
366 P: PollingErrorPolicy,
367{
368 inner: P,
369 maximum_attempts: u32,
370}
371
372impl LimitedAttemptCount {
373 pub fn new(maximum_attempts: u32) -> Self {
389 Self {
390 inner: Aip194Strict,
391 maximum_attempts,
392 }
393 }
394}
395
396impl<P> LimitedAttemptCount<P>
397where
398 P: PollingErrorPolicy,
399{
400 pub fn custom(inner: P, maximum_attempts: u32) -> Self {
415 Self {
416 inner,
417 maximum_attempts,
418 }
419 }
420
421 fn in_progress_impl(&self, count: u32, operation_name: &str) -> Result<(), Error> {
422 if count < self.maximum_attempts {
423 return Ok(());
424 }
425 Err(Error::exhausted(Exhausted::new(
426 operation_name,
427 "attempt count",
428 count.to_string(),
429 self.maximum_attempts.to_string(),
430 )))
431 }
432}
433
434impl<P> PollingErrorPolicy for LimitedAttemptCount<P>
435where
436 P: PollingErrorPolicy,
437{
438 fn on_error(&self, state: &PollingState, error: Error) -> RetryResult {
439 match self.inner.on_error(state, error) {
440 RetryResult::Permanent(e) => RetryResult::Permanent(e),
441 RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
442 RetryResult::Continue(e) => {
443 if state.attempt_count >= self.maximum_attempts {
444 RetryResult::Exhausted(e)
445 } else {
446 RetryResult::Continue(e)
447 }
448 }
449 }
450 }
451
452 fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error> {
453 self.inner
454 .on_in_progress(state, operation_name)
455 .and_then(|_| self.in_progress_impl(state.attempt_count, operation_name))
456 }
457}
458
/// Error details reported when a polling loop exceeds one of its limits.
///
/// Captures the operation being polled, the name of the limit that was hit,
/// the observed value, and the configured limit. All four appear in the
/// `Display` output.
#[derive(Debug)]
pub struct Exhausted {
    operation_name: String,
    limit_name: &'static str,
    value: String,
    limit: String,
}

impl Exhausted {
    /// Builds the error description.
    ///
    /// `value` and `limit` are already formatted as strings, which lets
    /// different kinds of limits (counts, durations) share this type.
    pub fn new(
        operation_name: &str,
        limit_name: &'static str,
        value: String,
        limit: String,
    ) -> Self {
        let operation_name = String::from(operation_name);
        Exhausted {
            operation_name,
            limit_name,
            value,
            limit,
        }
    }
}

impl std::fmt::Display for Exhausted {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            operation_name,
            limit_name,
            value,
            limit,
        } = self;
        write!(
            f,
            "polling loop for {operation_name} exhausted, {limit_name} value ({value}) exceeds limit ({limit})"
        )
    }
}

/// `Exhausted` has no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for Exhausted {}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{CredentialsError, Error};
    use http::HeaderMap;
    use std::error::Error as _;
    use std::time::{Duration, Instant};

    // A mock inner policy, used to verify that the limit decorators forward
    // calls to (and respect the results of) the policy they wrap.
    mockall::mock! {
        #[derive(Debug)]
        Policy {}
        impl PollingErrorPolicy for Policy {
            fn on_error(&self, state: &PollingState, error: Error) -> RetryResult;
            fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error>;
        }
    }

    // Both a concrete policy and a shared trait object convert into the
    // argument type.
    #[test]
    fn polling_policy_arg() {
        let policy = LimitedAttemptCount::new(3);
        let _ = PollingErrorPolicyArg::from(policy);

        let policy: Arc<dyn PollingErrorPolicy> = Arc::new(LimitedAttemptCount::new(3));
        let _ = PollingErrorPolicyArg::from(policy);
    }

    #[test]
    fn aip194_strict() -> anyhow::Result<()> {
        let p = Aip194Strict;
        p.on_in_progress(&PollingState::default(), "unused")?;
        assert!(
            p.on_error(&PollingState::default(), unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&PollingState::default(), permission_denied())
                .is_permanent()
        );
        assert!(
            p.on_error(&PollingState::default(), http_unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&PollingState::default(), http_permission_denied())
                .is_permanent()
        );

        assert!(
            p.on_error(&PollingState::default(), Error::io("err".to_string()))
                .is_continue()
        );

        assert!(
            p.on_error(
                &PollingState::default(),
                Error::authentication(CredentialsError::from_msg(true, "err"))
            )
            .is_continue()
        );

        assert!(
            p.on_error(&PollingState::default(), Error::ser("err".to_string()))
                .is_permanent()
        );
        Ok(())
    }

    #[test]
    fn always_continue() {
        let p = AlwaysContinue;

        let result = p.on_in_progress(&PollingState::default(), "unused");
        assert!(result.is_ok(), "{result:?}");
        assert!(
            p.on_error(&PollingState::default(), http_unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&PollingState::default(), unavailable())
                .is_continue()
        );
    }

    #[test_case::test_case(Error::io("err"))]
    #[test_case::test_case(Error::authentication(CredentialsError::from_msg(true, "err")))]
    #[test_case::test_case(Error::ser("err"))]
    fn always_continue_error_kind(error: Error) {
        let p = AlwaysContinue;
        assert!(p.on_error(&PollingState::default(), error).is_continue());
    }

    #[test]
    fn with_time_limit() {
        let policy = AlwaysContinue.with_time_limit(Duration::from_secs(10));
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(1))
                        .set_attempt_count(1_u32),
                    permission_denied()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(20))
                        .set_attempt_count(1_u32),
                    permission_denied()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    #[test]
    fn with_attempt_limit() {
        let policy = AlwaysContinue.with_attempt_limit(3);
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(1_u32),
                    permission_denied()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(5_u32),
                    permission_denied()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    // Builds an HTTP error whose body is a JSON error payload with `code`
    // and `message`.
    fn http_error(code: u16, message: &str) -> Error {
        let error = serde_json::json!({"error": {
            "code": code,
            "message": message,
        }});
        let payload = bytes::Bytes::from_owner(serde_json::to_string(&error).unwrap());
        Error::http(code, HeaderMap::new(), payload)
    }

    fn http_unavailable() -> Error {
        http_error(503, "SERVICE UNAVAILABLE")
    }

    fn http_permission_denied() -> Error {
        http_error(403, "PERMISSION DENIED")
    }

    // A service error with status code `Unavailable`.
    fn unavailable() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::Unavailable)
            .set_message("UNAVAILABLE");
        Error::service(status)
    }

    // A service error with status code `PermissionDenied`.
    fn permission_denied() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::PermissionDenied)
            .set_message("PERMISSION_DENIED");
        Error::service(status)
    }

    #[test]
    fn test_limited_elapsed_time_on_error() {
        let policy = LimitedElapsedTime::new(Duration::from_secs(20));
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(10))
                        .set_attempt_count(1_u32),
                    unavailable()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default()
                        .set_start(Instant::now() - Duration::from_secs(30))
                        .set_attempt_count(1_u32),
                    unavailable()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    #[test]
    fn test_limited_elapsed_time_in_progress() {
        let policy = LimitedElapsedTime::new(Duration::from_secs(20));
        let result = policy.on_in_progress(
            &PollingState::default()
                .set_start(Instant::now() - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            "unused",
        );
        assert!(result.is_ok(), "{result:?}");
        let err = policy
            .on_in_progress(
                &PollingState::default()
                    .set_start(Instant::now() - Duration::from_secs(30))
                    .set_attempt_count(1_u32),
                "test-operation-name",
            )
            .unwrap_err();
        // The limit violation is reported as an `Exhausted` source error.
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some(), "{err:?}");
    }

    #[test]
    fn test_limited_time_forwards_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(&PollingState::default(), transient_error());
        assert!(rf.is_continue());
    }

    #[test]
    fn test_limited_time_forwards_in_progress() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(3)
            .returning(|_, _| Ok(()));

        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(2_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(3_u32),
                    "test-op-name"
                )
                .is_ok()
        );
    }

    #[test]
    fn test_limited_time_in_progress_returns_inner() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(1)
            .returning(|_, _| Err(transient_error()));

        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_err()
        );
    }

    #[test]
    fn test_limited_time_inner_continues() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_continue());

        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(70))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_time_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Permanent(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        // Within the time limit.
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_permanent());

        // Past the 60s limit: a permanent error must stay permanent.
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(70))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_time_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Exhausted(e));

        let now = std::time::Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        // Within the time limit: the inner verdict is preserved.
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(10))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());

        // Past the 60s limit: still exhausted.
        let rf = policy.on_error(
            &PollingState::default()
                .set_start(now - Duration::from_secs(70))
                .set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_attempt_count_on_error() {
        let policy = LimitedAttemptCount::new(20);
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(10_u32),
                    unavailable()
                )
                .is_continue(),
            "{policy:?}"
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(30_u32),
                    unavailable()
                )
                .is_exhausted(),
            "{policy:?}"
        );
    }

    #[test]
    fn test_limited_attempt_count_in_progress() {
        let policy = LimitedAttemptCount::new(20);
        let result =
            policy.on_in_progress(&PollingState::default().set_attempt_count(10_u32), "unused");
        assert!(result.is_ok(), "{result:?}");
        let err = policy
            .on_in_progress(
                &PollingState::default().set_attempt_count(30_u32),
                "test-operation-name",
            )
            .unwrap_err();
        // The limit violation is reported as an `Exhausted` source error.
        let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
        assert!(exhausted.is_some(), "{err:?}");
    }

    #[test]
    fn test_limited_attempt_count_forwards_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let policy = LimitedAttemptCount::custom(mock, 3);
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(1_u32),
                    transient_error()
                )
                .is_continue()
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(2_u32),
                    transient_error()
                )
                .is_continue()
        );
        assert!(
            policy
                .on_error(
                    &PollingState::default().set_attempt_count(3_u32),
                    transient_error()
                )
                .is_exhausted()
        );
    }

    #[test]
    fn test_limited_attempt_count_forwards_in_progress() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(3)
            .returning(|_, _| Ok(()));

        let policy = LimitedAttemptCount::custom(mock, 5);
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(2_u32),
                    "test-op-name"
                )
                .is_ok()
        );
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(3_u32),
                    "test-op-name"
                )
                .is_ok()
        );
    }

    #[test]
    fn test_limited_attempt_count_in_progress_returns_inner() {
        let mut mock = MockPolicy::new();
        mock.expect_on_in_progress()
            .times(1)
            .returning(|_, _| Err(unavailable()));

        let policy = LimitedAttemptCount::custom(mock, 5);
        assert!(
            policy
                .on_in_progress(
                    &PollingState::default().set_attempt_count(1_u32),
                    "test-op-name"
                )
                .is_err()
        );
    }

    #[test]
    fn test_limited_attempt_count_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Permanent(e));
        let policy = LimitedAttemptCount::custom(mock, 2);

        // Below the attempt limit.
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(1_u32),
            Error::ser("err"),
        );
        assert!(rf.is_permanent());

        // Above the attempt limit of 2: a permanent error must stay permanent.
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(3_u32),
            Error::ser("err"),
        );
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_attempt_count_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Exhausted(e));
        let policy = LimitedAttemptCount::custom(mock, 2);

        // Below the attempt limit: the inner verdict is preserved.
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(1_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());

        // Above the attempt limit of 2: still exhausted.
        let rf = policy.on_error(
            &PollingState::default().set_attempt_count(3_u32),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_exhausted_fmt() {
        let exhausted = Exhausted::new(
            "op-name",
            "limit-name",
            "test-value".to_string(),
            "test-limit".to_string(),
        );
        let fmt = format!("{exhausted}");
        assert!(fmt.contains("op-name"), "{fmt}");
        assert!(fmt.contains("limit-name"), "{fmt}");
        assert!(fmt.contains("test-value"), "{fmt}");
        assert!(fmt.contains("test-limit"), "{fmt}");
    }

    // A service error that `Aip194Strict` treats as retryable (`Unavailable`).
    fn transient_error() -> Error {
        use crate::error::rpc::{Code, Status};
        Error::service(
            Status::default()
                .set_code(Code::Unavailable)
                .set_message("try-again"),
        )
    }
}