mod too_many_requests;
use crate::error::Error;
use crate::retry_result::RetryResult;
use crate::retry_state::RetryState;
use crate::throttle_result::ThrottleResult;
use std::sync::Arc;
use std::time::Duration;
pub use too_many_requests::TooManyRequests;
pub trait RetryPolicy: Send + Sync + std::fmt::Debug {
#[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
fn on_error(&self, state: &RetryState, error: Error) -> RetryResult;
#[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
fn on_throttle(&self, _state: &RetryState, error: Error) -> ThrottleResult {
ThrottleResult::Continue(error)
}
#[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
fn remaining_time(&self, _state: &RetryState) -> Option<Duration> {
None
}
}
/// A shared, cheaply clonable handle to any [`RetryPolicy`].
///
/// It can be built from a concrete policy or from an existing
/// `Arc<dyn RetryPolicy>`, and converted back into that `Arc`.
#[derive(Clone, Debug)]
pub struct RetryPolicyArg(Arc<dyn RetryPolicy>);

/// Moves a concrete policy into a freshly allocated `Arc`.
impl<T> From<T> for RetryPolicyArg
where
    T: RetryPolicy + 'static,
{
    fn from(value: T) -> Self {
        let shared: Arc<dyn RetryPolicy> = Arc::new(value);
        RetryPolicyArg(shared)
    }
}

/// Reuses an already shared policy as-is, without a new allocation.
impl From<Arc<dyn RetryPolicy>> for RetryPolicyArg {
    fn from(value: Arc<dyn RetryPolicy>) -> Self {
        RetryPolicyArg(value)
    }
}

/// Unwraps the handle back into the shared trait object.
impl From<RetryPolicyArg> for Arc<dyn RetryPolicy> {
    fn from(value: RetryPolicyArg) -> Arc<dyn RetryPolicy> {
        let RetryPolicyArg(policy) = value;
        policy
    }
}
/// Decorators available on every sized [`RetryPolicy`].
pub trait RetryPolicyExt: RetryPolicy + Sized {
    /// Wraps `self` so that retrying stops once `maximum_duration` has elapsed
    /// since `RetryState::start`.
    fn with_time_limit(self, maximum_duration: Duration) -> LimitedElapsedTime<Self> {
        LimitedElapsedTime::<Self>::custom(self, maximum_duration)
    }

    /// Wraps `self` so that retrying stops once the attempt count reaches
    /// `maximum_attempts`.
    fn with_attempt_limit(self, maximum_attempts: u32) -> LimitedAttemptCount<Self> {
        LimitedAttemptCount::<Self>::custom(self, maximum_attempts)
    }

    /// Wraps `self` in a [`TooManyRequests`] decorator.
    fn continue_on_too_many_requests(self) -> TooManyRequests<Self> {
        TooManyRequests::new(self)
    }
}

/// Every retry policy gets the extension methods for free.
impl<P: RetryPolicy> RetryPolicyExt for P {}
#[derive(Clone, Debug)]
pub struct Aip194Strict;
impl RetryPolicy for Aip194Strict {
fn on_error(&self, state: &RetryState, error: Error) -> RetryResult {
use crate::error::rpc::Code;
use http::StatusCode;
if error.is_transient_and_before_rpc() {
return RetryResult::Continue(error);
}
if !state.idempotent {
return RetryResult::Permanent(error);
}
if error.is_io() {
return RetryResult::Continue(error);
}
if error.status().is_some_and(|s| s.code == Code::Unavailable) {
return RetryResult::Continue(error);
}
if error
.http_status_code()
.is_some_and(|code| code == StatusCode::SERVICE_UNAVAILABLE.as_u16())
{
return RetryResult::Continue(error);
}
RetryResult::Permanent(error)
}
}
#[derive(Clone, Debug)]
pub struct AlwaysRetry;
impl RetryPolicy for AlwaysRetry {
fn on_error(&self, _state: &RetryState, error: Error) -> RetryResult {
RetryResult::Continue(error)
}
}
#[derive(Clone, Debug)]
pub struct NeverRetry;
impl RetryPolicy for NeverRetry {
fn on_error(&self, _state: &RetryState, error: Error) -> RetryResult {
RetryResult::Exhausted(error)
}
}
/// The error produced when [`LimitedElapsedTime`] gives up on a throttled attempt.
///
/// It records the configured time limit. Its `source()` is the error from the
/// last attempt.
#[derive(thiserror::Error, Debug)]
pub struct LimitedElapsedTimeError {
    /// The time limit that was exceeded.
    maximum_duration: Duration,
    /// The error observed on the last attempt.
    #[source]
    source: Error,
}

impl LimitedElapsedTimeError {
    pub(crate) fn new(maximum_duration: Duration, source: Error) -> Self {
        LimitedElapsedTimeError {
            source,
            maximum_duration,
        }
    }

    /// Returns the time limit the retry policy enforced.
    pub fn maximum_duration(&self) -> Duration {
        self.maximum_duration
    }
}

impl std::fmt::Display for LimitedElapsedTimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fractional seconds, e.g. `123.567s`.
        let message = format!(
            "retry policy is exhausted after {}s, the last retry attempt was throttled",
            self.maximum_duration.as_secs_f64()
        );
        f.write_str(&message)
    }
}
/// A retry policy decorator that stops retrying once a time limit is reached.
///
/// The limit is measured from `RetryState::start`. The inner policy `P`
/// (by default [`Aip194Strict`]) decides which errors are retryable at all.
#[derive(Debug)]
pub struct LimitedElapsedTime<P = Aip194Strict>
where
    P: RetryPolicy,
{
    /// Classifies errors before the time limit is applied.
    inner: P,
    /// How long after `RetryState::start` retries are still allowed.
    maximum_duration: Duration,
}

impl LimitedElapsedTime {
    /// Creates a policy that applies [`Aip194Strict`] and gives up once
    /// `maximum_duration` has elapsed.
    pub fn new(maximum_duration: Duration) -> Self {
        let inner = Aip194Strict;
        Self {
            inner,
            maximum_duration,
        }
    }
}
impl<P> LimitedElapsedTime<P>
where
P: RetryPolicy,
{
pub fn custom(inner: P, maximum_duration: Duration) -> Self {
Self {
inner,
maximum_duration,
}
}
fn error_if_exhausted(&self, state: &RetryState, error: Error) -> ThrottleResult {
let deadline = state.start + self.maximum_duration;
let now = tokio::time::Instant::now().into_std();
if now < deadline {
ThrottleResult::Continue(error)
} else {
ThrottleResult::Exhausted(Error::exhausted(LimitedElapsedTimeError::new(
self.maximum_duration,
error,
)))
}
}
}
impl<P> RetryPolicy for LimitedElapsedTime<P>
where
    P: RetryPolicy + 'static,
{
    /// Delegates to the inner policy. A `Continue` becomes `Exhausted` once
    /// `state.start + maximum_duration` has passed.
    fn on_error(&self, state: &RetryState, error: Error) -> RetryResult {
        match self.inner.on_error(state, error) {
            RetryResult::Permanent(e) => RetryResult::Permanent(e),
            RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
            RetryResult::Continue(e) => {
                // A deadline that overflows `Instant` (e.g. `Duration::MAX`) can
                // never be reached. Treat it as "no limit" instead of panicking.
                let expired = state
                    .start
                    .checked_add(self.maximum_duration)
                    .is_some_and(|deadline| tokio::time::Instant::now().into_std() >= deadline);
                if expired {
                    RetryResult::Exhausted(e)
                } else {
                    RetryResult::Continue(e)
                }
            }
        }
    }

    /// Delegates to the inner policy, then applies the time limit to a `Continue`.
    fn on_throttle(&self, state: &RetryState, error: Error) -> ThrottleResult {
        match self.inner.on_throttle(state, error) {
            ThrottleResult::Continue(e) => self.error_if_exhausted(state, e),
            ThrottleResult::Exhausted(e) => ThrottleResult::Exhausted(e),
        }
    }

    /// Returns the smaller of this policy's remaining time and the inner
    /// policy's, if the inner policy reports one.
    ///
    /// If the deadline overflows `Instant` this policy imposes no effective
    /// limit, so the inner policy's answer (possibly `None`) is returned as-is.
    fn remaining_time(&self, state: &RetryState) -> Option<Duration> {
        let inner = self.inner.remaining_time(state);
        let Some(deadline) = state.start.checked_add(self.maximum_duration) else {
            return inner;
        };
        let remaining = deadline.saturating_duration_since(tokio::time::Instant::now().into_std());
        Some(inner.map_or(remaining, |inner| inner.min(remaining)))
    }
}
/// A retry policy decorator that stops retrying after a number of attempts.
///
/// The inner policy `P` (by default [`Aip194Strict`]) decides which errors are
/// retryable at all. This wrapper turns its `Continue` into `Exhausted` once
/// `RetryState::attempt_count` reaches `maximum_attempts`.
#[derive(Debug)]
pub struct LimitedAttemptCount<P = Aip194Strict>
where
    P: RetryPolicy,
{
    /// Classifies errors before the attempt limit is applied.
    inner: P,
    /// Retrying stops once the attempt count reaches this value.
    maximum_attempts: u32,
}

impl LimitedAttemptCount {
    /// Creates a policy that applies [`Aip194Strict`] and gives up once
    /// `maximum_attempts` attempts have been made.
    pub fn new(maximum_attempts: u32) -> Self {
        Self::custom(Aip194Strict, maximum_attempts)
    }
}

impl<P> LimitedAttemptCount<P>
where
    P: RetryPolicy,
{
    /// Creates a policy that delegates to `inner` and gives up once
    /// `maximum_attempts` attempts have been made.
    pub fn custom(inner: P, maximum_attempts: u32) -> Self {
        Self {
            maximum_attempts,
            inner,
        }
    }
}
impl<P> RetryPolicy for LimitedAttemptCount<P>
where
    P: RetryPolicy,
{
    /// Delegates to the inner policy. A `Continue` becomes `Exhausted` once
    /// `state.attempt_count` has reached `maximum_attempts`.
    fn on_error(&self, state: &RetryState, error: Error) -> RetryResult {
        let limit_reached = state.attempt_count >= self.maximum_attempts;
        match self.inner.on_error(state, error) {
            RetryResult::Continue(e) if limit_reached => RetryResult::Exhausted(e),
            RetryResult::Continue(e) => RetryResult::Continue(e),
            RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
            RetryResult::Permanent(e) => RetryResult::Permanent(e),
        }
    }

    /// Forwards to the inner policy unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `state.attempt_count >= maximum_attempts`. NOTE(review):
    /// presumably the caller never reports a throttled attempt once this
    /// policy is exhausted — confirm against the retry loop.
    fn on_throttle(&self, state: &RetryState, error: Error) -> ThrottleResult {
        assert!(state.attempt_count < self.maximum_attempts);
        self.inner.on_throttle(state, error)
    }

    /// This wrapper adds no time limit; it reports whatever the inner policy does.
    fn remaining_time(&self, state: &RetryState) -> Option<Duration> {
        self.inner.remaining_time(state)
    }
}
/// Unit tests for the retry policies defined in this module.
#[cfg(test)]
pub mod tests {
    use super::*;
    use http::HeaderMap;
    use std::error::Error as StdError;
    use std::time::Instant;

    #[test]
    fn retry_policy_arg() {
        // Both a concrete policy and an existing trait object convert.
        let policy = LimitedAttemptCount::new(3);
        let _ = RetryPolicyArg::from(policy);

        let policy: Arc<dyn RetryPolicy> = Arc::new(LimitedAttemptCount::new(3));
        let _ = RetryPolicyArg::from(policy);
    }

    #[test]
    fn aip194_strict() {
        let p = Aip194Strict;
        let now = Instant::now();

        assert!(
            p.on_error(&idempotent_state(now), unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), unavailable())
                .is_permanent()
        );
        assert!(matches!(
            p.on_throttle(&idempotent_state(now), unavailable()),
            ThrottleResult::Continue(_)
        ));

        assert!(
            p.on_error(&idempotent_state(now), unknown_and_503())
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), unknown_and_503())
                .is_permanent()
        );
        assert!(matches!(
            p.on_throttle(&idempotent_state(now), unknown_and_503()),
            ThrottleResult::Continue(_)
        ));

        assert!(
            p.on_error(&idempotent_state(now), permission_denied())
                .is_permanent()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), permission_denied())
                .is_permanent()
        );

        assert!(
            p.on_error(&idempotent_state(now), http_unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), http_unavailable())
                .is_permanent()
        );
        assert!(matches!(
            p.on_throttle(&idempotent_state(now), http_unavailable()),
            ThrottleResult::Continue(_)
        ));

        assert!(
            p.on_error(&idempotent_state(now), http_permission_denied())
                .is_permanent()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), http_permission_denied())
                .is_permanent()
        );

        assert!(
            p.on_error(&idempotent_state(now), Error::io("err".to_string()))
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), Error::io("err".to_string()))
                .is_permanent()
        );

        // Pre-RPC transient failures are retried even for non-idempotent calls.
        assert!(
            p.on_error(&idempotent_state(now), pre_rpc_transient())
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), pre_rpc_transient())
                .is_continue()
        );

        assert!(
            p.on_error(&idempotent_state(now), Error::ser("err"))
                .is_permanent()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), Error::ser("err"))
                .is_permanent()
        );
        assert!(
            p.on_error(&idempotent_state(now), Error::deser("err"))
                .is_permanent()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), Error::deser("err"))
                .is_permanent()
        );

        assert!(
            p.remaining_time(&idempotent_state(now)).is_none(),
            "p={p:?}, now={now:?}"
        );
    }

    #[test]
    fn always_retry() {
        let p = AlwaysRetry;

        let now = Instant::now();
        assert!(
            p.remaining_time(&idempotent_state(now)).is_none(),
            "p={p:?}, now={now:?}"
        );
        assert!(
            p.on_error(&idempotent_state(now), http_unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), http_unavailable())
                .is_continue()
        );
        assert!(matches!(
            p.on_throttle(&idempotent_state(now), http_unavailable()),
            ThrottleResult::Continue(_)
        ));

        assert!(
            p.on_error(&idempotent_state(now), unavailable())
                .is_continue()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), unavailable())
                .is_continue()
        );
    }

    #[test_case::test_case(true, Error::io("err"))]
    #[test_case::test_case(true, pre_rpc_transient())]
    #[test_case::test_case(true, Error::ser("err"))]
    #[test_case::test_case(false, Error::io("err"))]
    #[test_case::test_case(false, pre_rpc_transient())]
    #[test_case::test_case(false, Error::ser("err"))]
    fn always_retry_error_kind(idempotent: bool, error: Error) {
        let p = AlwaysRetry;
        let now = Instant::now();
        let state = if idempotent {
            idempotent_state(now)
        } else {
            non_idempotent_state(now)
        };
        assert!(p.on_error(&state, error).is_continue());
    }

    #[test]
    fn never_retry() {
        let p = NeverRetry;

        let now = Instant::now();
        assert!(
            p.remaining_time(&idempotent_state(now)).is_none(),
            "p={p:?}, now={now:?}"
        );
        assert!(
            p.on_error(&idempotent_state(now), http_unavailable())
                .is_exhausted()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), http_unavailable())
                .is_exhausted()
        );
        assert!(matches!(
            p.on_throttle(&idempotent_state(now), http_unavailable()),
            ThrottleResult::Continue(_)
        ));

        assert!(
            p.on_error(&idempotent_state(now), unavailable())
                .is_exhausted()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), unavailable())
                .is_exhausted()
        );

        assert!(
            p.on_error(&idempotent_state(now), http_permission_denied())
                .is_exhausted()
        );
        assert!(
            p.on_error(&non_idempotent_state(now), http_permission_denied())
                .is_exhausted()
        );
    }

    #[test_case::test_case(true, Error::io("err"))]
    #[test_case::test_case(true, pre_rpc_transient())]
    #[test_case::test_case(true, Error::ser("err"))]
    #[test_case::test_case(false, Error::io("err"))]
    #[test_case::test_case(false, pre_rpc_transient())]
    #[test_case::test_case(false, Error::ser("err"))]
    fn never_retry_error_kind(idempotent: bool, error: Error) {
        let p = NeverRetry;
        let now = Instant::now();
        let state = if idempotent {
            idempotent_state(now)
        } else {
            non_idempotent_state(now)
        };
        assert!(p.on_error(&state, error).is_exhausted());
    }

    /// A credentials error marked transient, i.e. raised before the RPC was sent.
    fn pre_rpc_transient() -> Error {
        use crate::error::CredentialsError;
        Error::authentication(CredentialsError::from_msg(true, "err"))
    }

    fn http_unavailable() -> Error {
        Error::http(
            503_u16,
            HeaderMap::new(),
            bytes::Bytes::from_owner("SERVICE UNAVAILABLE".to_string()),
        )
    }

    fn http_permission_denied() -> Error {
        Error::http(
            403_u16,
            HeaderMap::new(),
            bytes::Bytes::from_owner("PERMISSION DENIED".to_string()),
        )
    }

    fn unavailable() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::Unavailable)
            .set_message("UNAVAILABLE");
        Error::service(status)
    }

    /// A service error whose status code is `Unknown` but whose HTTP status is 503.
    fn unknown_and_503() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::Unknown)
            .set_message("UNAVAILABLE");
        Error::service_full(status, Some(503), None, Some("source error".into()))
    }

    fn permission_denied() -> Error {
        use crate::error::rpc::Code;
        let status = crate::error::rpc::Status::default()
            .set_code(Code::PermissionDenied)
            .set_message("PERMISSION_DENIED");
        Error::service(status)
    }

    mockall::mock! {
        #[derive(Debug)]
        pub(crate) Policy {}
        impl RetryPolicy for Policy {
            fn on_error(&self, state: &RetryState, error: Error) -> RetryResult;
            fn on_throttle(&self, state: &RetryState, error: Error) -> ThrottleResult;
            fn remaining_time(&self, state: &RetryState) -> Option<Duration>;
        }
    }

    #[test]
    fn limited_elapsed_time_error() {
        let limit = Duration::from_secs(123) + Duration::from_millis(567);
        let err = LimitedElapsedTimeError::new(limit, unavailable());
        assert_eq!(err.maximum_duration(), limit);
        let fmt = err.to_string();
        assert!(fmt.contains("123.567s"), "display={fmt}, debug={err:?}");
        assert!(err.source().is_some(), "{err:?}");
    }

    #[test]
    fn test_limited_time_forwards() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));
        mock.expect_on_throttle()
            .times(1..)
            .returning(|_, e| ThrottleResult::Continue(e));
        mock.expect_remaining_time().times(1).returning(|_| None);

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(&idempotent_state(now), transient_error());
        assert!(rf.is_continue());

        let rt = policy.remaining_time(&idempotent_state(now));
        assert!(rt.is_some(), "policy={policy:?}, now={now:?}");

        let e = policy.on_throttle(&idempotent_state(now), transient_error());
        assert!(matches!(e, ThrottleResult::Continue(_)));
    }

    #[test]
    fn test_limited_time_on_throttle_continue() {
        let mut mock = MockPolicy::new();
        mock.expect_on_throttle()
            .times(1..)
            .returning(|_, e| ThrottleResult::Continue(e));

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        // Before the deadline the inner decision is kept.
        let rf = policy.on_throttle(
            &idempotent_state(now - Duration::from_secs(50)),
            unavailable(),
        );
        assert!(matches!(rf, ThrottleResult::Continue(_)), "{rf:?}");

        // Past the deadline the policy is exhausted.
        let rf = policy.on_throttle(
            &idempotent_state(now - Duration::from_secs(70)),
            unavailable(),
        );
        assert!(matches!(rf, ThrottleResult::Exhausted(_)), "{rf:?}");
    }

    #[test]
    fn test_limited_time_on_throttle_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_throttle()
            .times(1..)
            .returning(|_, e| ThrottleResult::Exhausted(e));

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_throttle(
            &idempotent_state(now - Duration::from_secs(50)),
            unavailable(),
        );
        assert!(matches!(rf, ThrottleResult::Exhausted(_)), "{rf:?}");
    }

    #[test]
    fn test_limited_time_inner_continues() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
        let rf = policy.on_error(
            &idempotent_state(now - Duration::from_secs(10)),
            transient_error(),
        );
        assert!(rf.is_continue());

        let rf = policy.on_error(
            &idempotent_state(now - Duration::from_secs(70)),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_time_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Permanent(e));

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_error(
            &non_idempotent_state(now - Duration::from_secs(10)),
            transient_error(),
        );
        assert!(rf.is_permanent());

        let rf = policy.on_error(
            &non_idempotent_state(now + Duration::from_secs(10)),
            transient_error(),
        );
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_time_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Exhausted(e));

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        let rf = policy.on_error(
            &non_idempotent_state(now - Duration::from_secs(10)),
            transient_error(),
        );
        assert!(rf.is_exhausted());

        let rf = policy.on_error(
            &non_idempotent_state(now + Duration::from_secs(10)),
            transient_error(),
        );
        assert!(rf.is_exhausted());
    }

    #[test]
    fn test_limited_time_remaining_inner_longer() {
        let mut mock = MockPolicy::new();
        mock.expect_remaining_time()
            .times(1)
            .returning(|_| Some(Duration::from_secs(30)));

        let now = Instant::now();
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        // 55s of the 60s budget are spent, so the outer limit (<= 5s) wins
        // over the inner 30s. `Option` orders `None` below every `Some`, so the
        // value is unwrapped to make sure a `None` result fails the test.
        let remaining = policy.remaining_time(&idempotent_state(now - Duration::from_secs(55)));
        assert!(
            remaining.is_some_and(|d| d <= Duration::from_secs(5)),
            "{remaining:?}"
        );
    }

    #[test]
    fn test_limited_time_remaining_inner_shorter() {
        let mut mock = MockPolicy::new();
        mock.expect_remaining_time()
            .times(1)
            .returning(|_| Some(Duration::from_secs(5)));
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        // Only 5s of the 60s budget are spent (~55s left), so the inner 5s is
        // the minimum and is returned exactly.
        let now = Instant::now();
        let remaining = policy.remaining_time(&idempotent_state(now - Duration::from_secs(5)));
        assert_eq!(remaining, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_limited_time_remaining_inner_is_none() {
        let mut mock = MockPolicy::new();
        mock.expect_remaining_time().times(1).returning(|_| None);
        let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));

        // The outer limit alone applies; it must still report `Some`.
        let now = Instant::now();
        let remaining = policy.remaining_time(&idempotent_state(now - Duration::from_secs(50)));
        assert!(
            remaining.is_some_and(|d| d <= Duration::from_secs(10)),
            "{remaining:?}"
        );
    }

    #[test]
    fn test_limited_attempt_count_on_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(1..)
            .returning(|_, e| RetryResult::Continue(e));

        let now = Instant::now();
        let policy = LimitedAttemptCount::custom(mock, 3);
        assert!(
            policy
                .on_error(
                    &idempotent_state(now).set_attempt_count(1_u32),
                    transient_error()
                )
                .is_continue()
        );
        assert!(
            policy
                .on_error(
                    &idempotent_state(now).set_attempt_count(2_u32),
                    transient_error()
                )
                .is_continue()
        );
        assert!(
            policy
                .on_error(
                    &idempotent_state(now).set_attempt_count(3_u32),
                    transient_error()
                )
                .is_exhausted()
        );
    }

    #[test]
    fn test_limited_attempt_count_on_throttle_continue() {
        let mut mock = MockPolicy::new();
        mock.expect_on_throttle()
            .times(1..)
            .returning(|_, e| ThrottleResult::Continue(e));

        let now = Instant::now();
        let policy = LimitedAttemptCount::custom(mock, 3);
        assert!(matches!(
            policy.on_throttle(
                &idempotent_state(now).set_attempt_count(2_u32),
                unavailable()
            ),
            ThrottleResult::Continue(_)
        ));
    }

    #[test]
    fn test_limited_attempt_count_on_throttle_error() {
        let mut mock = MockPolicy::new();
        mock.expect_on_throttle()
            .times(1..)
            .returning(|_, e| ThrottleResult::Exhausted(e));

        let now = Instant::now();
        let policy = LimitedAttemptCount::custom(mock, 3);
        assert!(matches!(
            policy.on_throttle(&idempotent_state(now), unavailable()),
            ThrottleResult::Exhausted(_)
        ));
    }

    #[test]
    fn test_limited_attempt_count_remaining_none() {
        let mut mock = MockPolicy::new();
        mock.expect_remaining_time().times(1).returning(|_| None);
        let policy = LimitedAttemptCount::custom(mock, 3);

        let now = Instant::now();
        assert!(
            policy.remaining_time(&idempotent_state(now)).is_none(),
            "policy={policy:?} now={now:?}"
        );
    }

    #[test]
    fn test_limited_attempt_count_remaining_some() {
        let mut mock = MockPolicy::new();
        mock.expect_remaining_time()
            .times(1)
            .returning(|_| Some(Duration::from_secs(123)));
        let policy = LimitedAttemptCount::custom(mock, 3);

        let now = Instant::now();
        assert_eq!(
            policy.remaining_time(&idempotent_state(now)),
            Some(Duration::from_secs(123))
        );
    }

    #[test]
    fn test_limited_attempt_count_inner_permanent() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Permanent(e));
        let policy = LimitedAttemptCount::custom(mock, 2);
        let now = Instant::now();

        let rf = policy.on_error(&non_idempotent_state(now), transient_error());
        assert!(rf.is_permanent());

        let rf = policy.on_error(&non_idempotent_state(now), transient_error());
        assert!(rf.is_permanent());
    }

    #[test]
    fn test_limited_attempt_count_inner_exhausted() {
        let mut mock = MockPolicy::new();
        mock.expect_on_error()
            .times(2)
            .returning(|_, e| RetryResult::Exhausted(e));
        let policy = LimitedAttemptCount::custom(mock, 2);
        let now = Instant::now();

        let rf = policy.on_error(&non_idempotent_state(now), transient_error());
        assert!(rf.is_exhausted());

        let rf = policy.on_error(&non_idempotent_state(now), transient_error());
        assert!(rf.is_exhausted());
    }

    fn transient_error() -> Error {
        use crate::error::rpc::{Code, Status};
        Error::service(
            Status::default()
                .set_code(Code::Unavailable)
                .set_message("try-again"),
        )
    }

    /// Builds a state for an idempotent operation that started at `now`.
    pub(crate) fn idempotent_state(now: Instant) -> RetryState {
        RetryState::new(true).set_start(now)
    }

    /// Builds a state for a non-idempotent operation that started at `now`.
    pub(crate) fn non_idempotent_state(now: Instant) -> RetryState {
        RetryState::new(false).set_start(now)
    }
}