use crate::error::Error;
use crate::polling_state::PollingState;
use crate::retry_result::RetryResult;
use std::sync::Arc;
/// Decides how a polling loop for a long-running operation reacts to errors
/// and to "still in progress" results.
///
/// Implementations must be `Send + Sync + Debug`. Policies are usually shared
/// behind an `Arc` (see `PollingErrorPolicyArg`).
pub trait PollingErrorPolicy: Send + Sync + std::fmt::Debug {
/// Classifies an `error` returned while polling, given the current `state`.
///
/// Returns [RetryResult::Continue] if polling should go on,
/// [RetryResult::Permanent] if the error is not retryable, or
/// [RetryResult::Exhausted] if a policy limit has been reached. The error is
/// returned inside the chosen variant.
#[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
fn on_error(&self, state: &PollingState, error: Error) -> RetryResult;
/// Called when a poll reports that `operation_name` is still in progress.
///
/// Returning `Err` signals that polling should not continue. The limit
/// policies in this module return an `Exhausted`-wrapped error here; the caller
/// presumably surfaces it to the user. The default implementation never
/// objects.
#[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
fn on_in_progress(&self, _state: &PollingState, _operation_name: &str) -> Result<(), Error> {
// No limit by default: every in-progress result is accepted.
Ok(())
}
}
/// A type-erased, reference-counted [PollingErrorPolicy].
///
/// It can be built from any concrete policy, which is moved into a new [Arc].
/// It can also be built from an existing `Arc<dyn PollingErrorPolicy>`, which
/// is reused without re-wrapping. Cloning shares the same underlying policy.
#[derive(Clone)]
pub struct PollingErrorPolicyArg(pub(crate) Arc<dyn PollingErrorPolicy>);

impl<T> std::convert::From<T> for PollingErrorPolicyArg
where
    T: PollingErrorPolicy + 'static,
{
    fn from(value: T) -> Self {
        // Erase the concrete policy type behind a shared trait object.
        let shared: Arc<dyn PollingErrorPolicy> = Arc::new(value);
        PollingErrorPolicyArg(shared)
    }
}

impl std::convert::From<Arc<dyn PollingErrorPolicy>> for PollingErrorPolicyArg {
    fn from(value: Arc<dyn PollingErrorPolicy>) -> Self {
        // Already type-erased and shared; just wrap it.
        PollingErrorPolicyArg(value)
    }
}
/// Combinators that decorate any [PollingErrorPolicy] with a limit.
///
/// Implemented for every `PollingErrorPolicy`, so expressions like
/// `AlwaysContinue.with_attempt_limit(3)` work directly.
pub trait PollingErrorPolicyExt: PollingErrorPolicy + Sized {
    /// Wraps `self` in a [LimitedElapsedTime] that allows at most
    /// `maximum_duration` of polling.
    fn with_time_limit(self, maximum_duration: std::time::Duration) -> LimitedElapsedTime<Self> {
        LimitedElapsedTime::<Self>::custom(self, maximum_duration)
    }

    /// Wraps `self` in a [LimitedAttemptCount] that allows fewer than
    /// `maximum_attempts` attempts.
    fn with_attempt_limit(self, maximum_attempts: u32) -> LimitedAttemptCount<Self> {
        LimitedAttemptCount::<Self>::custom(self, maximum_attempts)
    }
}

impl<T> PollingErrorPolicyExt for T where T: PollingErrorPolicy {}
#[derive(Clone, Debug)]
pub struct Aip194Strict;
impl PollingErrorPolicy for Aip194Strict {
fn on_error(&self, _state: &PollingState, error: Error) -> RetryResult {
if error.is_transient_and_before_rpc() {
return RetryResult::Continue(error);
}
if error.is_io() {
return RetryResult::Continue(error);
}
if let Some(status) = error.status() {
return if status.code == crate::error::rpc::Code::Unavailable {
RetryResult::Continue(error)
} else {
RetryResult::Permanent(error)
};
}
match error.http_status_code() {
Some(code) if code == http::StatusCode::SERVICE_UNAVAILABLE.as_u16() => {
RetryResult::Continue(error)
}
_ => RetryResult::Permanent(error),
}
}
}
#[derive(Clone, Debug)]
pub struct AlwaysContinue;
impl PollingErrorPolicy for AlwaysContinue {
fn on_error(&self, _state: &PollingState, error: Error) -> RetryResult {
RetryResult::Continue(error)
}
}
/// A polling error policy that limits how long a polling loop may run.
///
/// Errors are first classified by the `inner` policy (by default
/// [Aip194Strict]). A retryable error becomes [RetryResult::Exhausted] once
/// `maximum_duration` has passed since the polling state's `start`. Likewise,
/// in-progress polls are rejected after that point.
#[derive(Debug)]
pub struct LimitedElapsedTime<P = Aip194Strict>
where
    P: PollingErrorPolicy,
{
    /// The policy that classifies errors before the time limit is applied.
    inner: P,
    /// The longest time, measured from the polling state's `start`, that the
    /// loop may keep polling.
    maximum_duration: std::time::Duration,
}

impl LimitedElapsedTime {
    /// Creates a time-limited policy that wraps [Aip194Strict].
    pub fn new(maximum_duration: std::time::Duration) -> Self {
        Self::custom(Aip194Strict, maximum_duration)
    }
}
impl<P> LimitedElapsedTime<P>
where
    P: PollingErrorPolicy,
{
    /// Creates a policy that wraps `inner` with an elapsed-time limit of
    /// `maximum_duration`.
    pub fn custom(inner: P, maximum_duration: std::time::Duration) -> Self {
        Self {
            inner,
            maximum_duration,
        }
    }

    /// Returns `Ok(())` while less than `maximum_duration` has elapsed since
    /// `start`. Otherwise returns an error wrapping an [Exhausted] that
    /// describes the `"elapsed time"` limit for `operation_name`.
    fn in_progress_impl(
        &self,
        start: std::time::Instant,
        operation_name: &str,
    ) -> Result<(), Error> {
        let now = std::time::Instant::now();
        // `Instant + Duration` panics when the result is not representable,
        // e.g. for `Duration::MAX`. A deadline that cannot be represented is
        // never reached, so treat overflow as "still within the limit".
        let deadline_reached = start
            .checked_add(self.maximum_duration)
            .is_some_and(|deadline| now >= deadline);
        if !deadline_reached {
            return Ok(());
        }
        Err(Error::exhausted(Exhausted::new(
            operation_name,
            "elapsed time",
            // Here `now >= deadline >= start`, so this is the exact elapsed
            // time; the saturating form just avoids an `unwrap()`.
            format!("{:?}", now.saturating_duration_since(start)),
            format!("{:?}", self.maximum_duration),
        )))
    }
}
impl<P> PollingErrorPolicy for LimitedElapsedTime<P>
where
    P: PollingErrorPolicy,
{
    /// Delegates to the inner policy. A retryable result is converted to
    /// [RetryResult::Exhausted] once `maximum_duration` has elapsed since
    /// `state.start`. Permanent and exhausted results pass through unchanged.
    fn on_error(&self, state: &PollingState, error: Error) -> RetryResult {
        match self.inner.on_error(state, error) {
            RetryResult::Permanent(e) => RetryResult::Permanent(e),
            RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
            RetryResult::Continue(e) => {
                // `Instant + Duration` panics on overflow (e.g. a
                // `Duration::MAX` limit); an unrepresentable deadline is
                // never reached.
                let deadline_reached = state
                    .start
                    .checked_add(self.maximum_duration)
                    .is_some_and(|deadline| std::time::Instant::now() >= deadline);
                if deadline_reached {
                    RetryResult::Exhausted(e)
                } else {
                    RetryResult::Continue(e)
                }
            }
        }
    }

    /// Consults the inner policy first. If it accepts, rejects the poll once
    /// the elapsed-time limit has been reached.
    fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error> {
        self.inner.on_in_progress(state, operation_name)?;
        self.in_progress_impl(state.start, operation_name)
    }
}
/// A polling error policy that limits the number of polling attempts.
///
/// Errors are first classified by the `inner` policy (by default
/// [Aip194Strict]). A retryable error becomes [RetryResult::Exhausted] once
/// the polling state's `attempt_count` reaches `maximum_attempts`. Likewise,
/// in-progress polls are rejected from that attempt on.
#[derive(Debug)]
pub struct LimitedAttemptCount<P = Aip194Strict>
where
    P: PollingErrorPolicy,
{
    /// The policy that classifies errors before the attempt limit is applied.
    inner: P,
    /// The attempt count at which polling is considered exhausted.
    maximum_attempts: u32,
}

impl LimitedAttemptCount {
    /// Creates an attempt-limited policy that wraps [Aip194Strict].
    pub fn new(maximum_attempts: u32) -> Self {
        Self::custom(Aip194Strict, maximum_attempts)
    }
}
impl<P> LimitedAttemptCount<P>
where
P: PollingErrorPolicy,
{
pub fn custom(inner: P, maximum_attempts: u32) -> Self {
Self {
inner,
maximum_attempts,
}
}
fn in_progress_impl(&self, count: u32, operation_name: &str) -> Result<(), Error> {
if count < self.maximum_attempts {
return Ok(());
}
Err(Error::exhausted(Exhausted::new(
operation_name,
"attempt count",
count.to_string(),
self.maximum_attempts.to_string(),
)))
}
}
impl<P> PollingErrorPolicy for LimitedAttemptCount<P>
where
    P: PollingErrorPolicy,
{
    /// Delegates to the inner policy. A retryable result is converted to
    /// [RetryResult::Exhausted] once `state.attempt_count` reaches
    /// `maximum_attempts`. Permanent and exhausted results pass through
    /// unchanged.
    fn on_error(&self, state: &PollingState, error: Error) -> RetryResult {
        match self.inner.on_error(state, error) {
            RetryResult::Continue(e) if state.attempt_count >= self.maximum_attempts => {
                RetryResult::Exhausted(e)
            }
            RetryResult::Continue(e) => RetryResult::Continue(e),
            RetryResult::Permanent(e) => RetryResult::Permanent(e),
            RetryResult::Exhausted(e) => RetryResult::Exhausted(e),
        }
    }

    /// Consults the inner policy first. If it accepts, rejects the poll once
    /// the attempt limit has been reached.
    fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error> {
        self.inner.on_in_progress(state, operation_name)?;
        self.in_progress_impl(state.attempt_count, operation_name)
    }
}
/// Describes a polling loop that stopped because a policy limit was reached.
///
/// Holds the operation name, a static label for the limit (e.g.
/// `"attempt count"` or `"elapsed time"`), and the observed value and the
/// limit, both already formatted as strings.
#[derive(Debug)]
pub struct Exhausted {
    operation_name: String,
    limit_name: &'static str,
    value: String,
    limit: String,
}

impl Exhausted {
    /// Creates a new description. `operation_name` is copied into an owned
    /// `String`.
    pub fn new(
        operation_name: &str,
        limit_name: &'static str,
        value: String,
        limit: String,
    ) -> Self {
        let operation_name = String::from(operation_name);
        Self {
            operation_name,
            limit_name,
            value,
            limit,
        }
    }
}

impl std::fmt::Display for Exhausted {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            operation_name,
            limit_name,
            value,
            limit,
        } = self;
        write!(
            f,
            "polling loop for {} exhausted, {} value ({}) exceeds limit ({})",
            operation_name, limit_name, value, limit
        )
    }
}

/// `Exhausted` is a leaf error: it has no `source()`.
impl std::error::Error for Exhausted {}
// Unit tests for the polling error policies and the `Exhausted` error type.
#[cfg(test)]
mod tests {
use super::*;
use crate::error::{CredentialsError, Error};
use http::HeaderMap;
use std::error::Error as _;
use std::time::{Duration, Instant};
// Mock policy used to verify that the limit decorators forward to (and respect
// the results of) their inner policy.
mockall::mock! {
#[derive(Debug)]
Policy {}
impl PollingErrorPolicy for Policy {
fn on_error(&self, state: &PollingState, error: Error) -> RetryResult;
fn on_in_progress(&self, state: &PollingState, operation_name: &str) -> Result<(), Error>;
}
}
// Both `From` conversions into `PollingErrorPolicyArg` must compile and succeed.
#[test]
fn polling_policy_arg() {
let policy = LimitedAttemptCount::new(3);
let _ = PollingErrorPolicyArg::from(policy);
let policy: Arc<dyn PollingErrorPolicy> = Arc::new(LimitedAttemptCount::new(3));
let _ = PollingErrorPolicyArg::from(policy);
}
// Only "unavailable"-style errors, I/O errors and transient-before-RPC errors
// are retryable under AIP-194; everything else is permanent.
#[test]
fn aip194_strict() -> anyhow::Result<()> {
let p = Aip194Strict;
p.on_in_progress(&PollingState::default(), "unused")?;
assert!(
p.on_error(&PollingState::default(), unavailable())
.is_continue()
);
assert!(
p.on_error(&PollingState::default(), permission_denied())
.is_permanent()
);
assert!(
p.on_error(&PollingState::default(), http_unavailable())
.is_continue()
);
assert!(
p.on_error(&PollingState::default(), http_permission_denied())
.is_permanent()
);
assert!(
p.on_error(&PollingState::default(), Error::io("err".to_string()))
.is_continue()
);
assert!(
p.on_error(
&PollingState::default(),
Error::authentication(CredentialsError::from_msg(true, "err"))
)
.is_continue()
);
assert!(
p.on_error(&PollingState::default(), Error::ser("err".to_string()))
.is_permanent()
);
Ok(())
}
#[test]
fn always_continue() {
let p = AlwaysContinue;
let result = p.on_in_progress(&PollingState::default(), "unused");
assert!(result.is_ok(), "{result:?}");
assert!(
p.on_error(&PollingState::default(), http_unavailable())
.is_continue()
);
assert!(
p.on_error(&PollingState::default(), unavailable())
.is_continue()
);
}
#[test_case::test_case(Error::io("err"))]
#[test_case::test_case(Error::authentication(CredentialsError::from_msg(true, "err")))]
#[test_case::test_case(Error::ser("err"))]
fn always_continue_error_kind(error: Error) {
let p = AlwaysContinue;
assert!(p.on_error(&PollingState::default(), error).is_continue());
}
// `with_time_limit` turns a retryable result into `Exhausted` once the
// deadline measured from `start` has passed, even for errors that would
// otherwise be permanent under a stricter policy.
#[test]
fn with_time_limit() {
let policy = AlwaysContinue.with_time_limit(Duration::from_secs(10));
assert!(
policy
.on_error(
&PollingState::default()
.set_start(Instant::now() - Duration::from_secs(1))
.set_attempt_count(1_u32),
permission_denied()
)
.is_continue(),
"{policy:?}"
);
assert!(
policy
.on_error(
&PollingState::default()
.set_start(Instant::now() - Duration::from_secs(20))
.set_attempt_count(1_u32),
permission_denied()
)
.is_exhausted(),
"{policy:?}"
);
}
#[test]
fn with_attempt_limit() {
let policy = AlwaysContinue.with_attempt_limit(3);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(1_u32),
permission_denied()
)
.is_continue(),
"{policy:?}"
);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(5_u32),
permission_denied()
)
.is_exhausted(),
"{policy:?}"
);
}
// Builds an HTTP error whose JSON payload carries `code` and `message`.
fn http_error(code: u16, message: &str) -> Error {
let error = serde_json::json!({"error": {
"code": code,
"message": message,
}});
let payload = bytes::Bytes::from_owner(serde_json::to_string(&error).unwrap());
Error::http(code, HeaderMap::new(), payload)
}
fn http_unavailable() -> Error {
http_error(503, "SERVICE UNAVAILABLE")
}
fn http_permission_denied() -> Error {
http_error(403, "PERMISSION DENIED")
}
// Service (RPC status) error with code `Unavailable`.
fn unavailable() -> Error {
use crate::error::rpc::Code;
let status = crate::error::rpc::Status::default()
.set_code(Code::Unavailable)
.set_message("UNAVAILABLE");
Error::service(status)
}
// Service (RPC status) error with code `PermissionDenied`.
fn permission_denied() -> Error {
use crate::error::rpc::Code;
let status = crate::error::rpc::Status::default()
.set_code(Code::PermissionDenied)
.set_message("PERMISSION_DENIED");
Error::service(status)
}
#[test]
fn test_limited_elapsed_time_on_error() {
let policy = LimitedElapsedTime::new(Duration::from_secs(20));
assert!(
policy
.on_error(
&PollingState::default()
.set_start(Instant::now() - Duration::from_secs(10))
.set_attempt_count(1_u32),
unavailable()
)
.is_continue(),
"{policy:?}"
);
assert!(
policy
.on_error(
&PollingState::default()
.set_start(Instant::now() - Duration::from_secs(30))
.set_attempt_count(1_u32),
unavailable()
)
.is_exhausted(),
"{policy:?}"
);
}
// Past the deadline, `on_in_progress` fails with an error whose source is
// an `Exhausted`.
#[test]
fn test_limited_elapsed_time_in_progress() {
let policy = LimitedElapsedTime::new(Duration::from_secs(20));
let result = policy.on_in_progress(
&PollingState::default()
.set_start(Instant::now() - Duration::from_secs(10))
.set_attempt_count(1_u32),
"unused",
);
assert!(result.is_ok(), "{result:?}");
let err = policy
.on_in_progress(
&PollingState::default()
.set_start(Instant::now() - Duration::from_secs(30))
.set_attempt_count(1_u32),
"test-operation-name",
)
.unwrap_err();
let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
assert!(exhausted.is_some(), "{err:?}");
}
#[test]
fn test_limited_time_forwards_on_error() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(1..)
.returning(|_, e| RetryResult::Continue(e));
let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
let rf = policy.on_error(&PollingState::default(), transient_error());
assert!(rf.is_continue());
}
#[test]
fn test_limited_time_forwards_in_progress() {
let mut mock = MockPolicy::new();
mock.expect_on_in_progress()
.times(3)
.returning(|_, _| Ok(()));
let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(1_u32),
"test-op-name"
)
.is_ok()
);
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(2_u32),
"test-op-name"
)
.is_ok()
);
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(3_u32),
"test-op-name"
)
.is_ok()
);
}
// An error from the inner policy's `on_in_progress` is propagated as-is.
#[test]
fn test_limited_time_in_progress_returns_inner() {
let mut mock = MockPolicy::new();
mock.expect_on_in_progress()
.times(1)
.returning(|_, _| Err(transient_error()));
let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(1_u32),
"test-op-name"
)
.is_err()
);
}
#[test]
fn test_limited_time_inner_continues() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(1..)
.returning(|_, e| RetryResult::Continue(e));
let now = std::time::Instant::now();
let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
let rf = policy.on_error(
&PollingState::default()
.set_start(now - Duration::from_secs(10))
.set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_continue());
let rf = policy.on_error(
&PollingState::default()
.set_start(now - Duration::from_secs(70))
.set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_exhausted());
}
// A `Permanent` result from the inner policy passes through unchanged, even
// when `start` lies in the future (`now + 10s`).
#[test]
fn test_limited_time_inner_permanent() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(2)
.returning(|_, e| RetryResult::Permanent(e));
let now = std::time::Instant::now();
let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
let rf = policy.on_error(
&PollingState::default()
.set_start(now - Duration::from_secs(10))
.set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_permanent());
let rf = policy.on_error(
&PollingState::default()
.set_start(now + Duration::from_secs(10))
.set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_permanent());
}
// An `Exhausted` result from the inner policy passes through unchanged.
#[test]
fn test_limited_time_inner_exhausted() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(2)
.returning(|_, e| RetryResult::Exhausted(e));
let now = std::time::Instant::now();
let policy = LimitedElapsedTime::custom(mock, Duration::from_secs(60));
let rf = policy.on_error(
&PollingState::default()
.set_start(now - Duration::from_secs(10))
.set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_exhausted());
let rf = policy.on_error(
&PollingState::default()
.set_start(now + Duration::from_secs(10))
.set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_exhausted());
}
#[test]
fn test_limited_attempt_count_on_error() {
let policy = LimitedAttemptCount::new(20);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(10_u32),
unavailable()
)
.is_continue(),
"{policy:?}"
);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(30_u32),
unavailable()
)
.is_exhausted(),
"{policy:?}"
);
}
// Past the attempt limit, `on_in_progress` fails with an error whose source
// is an `Exhausted`.
#[test]
fn test_limited_attempt_count_in_progress() {
let policy = LimitedAttemptCount::new(20);
let result =
policy.on_in_progress(&PollingState::default().set_attempt_count(10_u32), "unused");
assert!(result.is_ok(), "{result:?}");
let err = policy
.on_in_progress(
&PollingState::default().set_attempt_count(30_u32),
"test-operation-name",
)
.unwrap_err();
let exhausted = err.source().and_then(|e| e.downcast_ref::<Exhausted>());
assert!(exhausted.is_some(), "{err:?}");
}
// Note the boundary: an attempt count equal to the limit (3 of 3) is already
// exhausted.
#[test]
fn test_limited_attempt_count_forwards_on_error() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(1..)
.returning(|_, e| RetryResult::Continue(e));
let policy = LimitedAttemptCount::custom(mock, 3);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(1_u32),
transient_error()
)
.is_continue()
);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(2_u32),
transient_error()
)
.is_continue()
);
assert!(
policy
.on_error(
&PollingState::default().set_attempt_count(3_u32),
transient_error()
)
.is_exhausted()
);
}
#[test]
fn test_limited_attempt_count_forwards_in_progress() {
let mut mock = MockPolicy::new();
mock.expect_on_in_progress()
.times(3)
.returning(|_, _| Ok(()));
let policy = LimitedAttemptCount::custom(mock, 5);
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(1_u32),
"test-op-name"
)
.is_ok()
);
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(2_u32),
"test-op-name"
)
.is_ok()
);
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(3_u32),
"test-op-name"
)
.is_ok()
);
}
#[test]
fn test_limited_attempt_count_in_progress_returns_inner() {
let mut mock = MockPolicy::new();
mock.expect_on_in_progress()
.times(1)
.returning(|_, _| Err(unavailable()));
let policy = LimitedAttemptCount::custom(mock, 5);
assert!(
policy
.on_in_progress(
&PollingState::default().set_attempt_count(1_u32),
"test-op-name"
)
.is_err()
);
}
#[test]
fn test_limited_attempt_count_inner_permanent() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(2)
.returning(|_, e| RetryResult::Permanent(e));
let policy = LimitedAttemptCount::custom(mock, 2);
let rf = policy.on_error(
&PollingState::default().set_attempt_count(1_u32),
Error::ser("err"),
);
assert!(rf.is_permanent());
let rf = policy.on_error(
&PollingState::default().set_attempt_count(1_u32),
Error::ser("err"),
);
assert!(rf.is_permanent());
}
#[test]
fn test_limited_attempt_count_inner_exhausted() {
let mut mock = MockPolicy::new();
mock.expect_on_error()
.times(2)
.returning(|_, e| RetryResult::Exhausted(e));
let policy = LimitedAttemptCount::custom(mock, 2);
let rf = policy.on_error(
&PollingState::default().set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_exhausted());
let rf = policy.on_error(
&PollingState::default().set_attempt_count(1_u32),
transient_error(),
);
assert!(rf.is_exhausted());
}
// The `Display` output mentions every field of `Exhausted`.
#[test]
fn test_exhausted_fmt() {
let exhausted = Exhausted::new(
"op-name",
"limit-name",
"test-value".to_string(),
"test-limit".to_string(),
);
let fmt = format!("{exhausted}");
assert!(fmt.contains("op-name"), "{fmt}");
assert!(fmt.contains("limit-name"), "{fmt}");
assert!(fmt.contains("test-value"), "{fmt}");
assert!(fmt.contains("test-limit"), "{fmt}");
}
// A retryable service error (code `Unavailable`) used as a generic transient
// failure in the decorator tests.
fn transient_error() -> Error {
use crate::error::rpc::{Code, Status};
Error::service(
Status::default()
.set_code(Code::Unavailable)
.set_message("try-again"),
)
}
}