use std::{future::Future, time::Duration};
use tonic::Status;
use crate::rpc::{RpcRetryConfig, deadline::DeadlineBudget};
/// Reports whether a failed call with `status` is eligible for another attempt.
///
/// Returns `true` only when retries are enabled in `config` and the status
/// code is listed in `config.retryable_codes`.
pub fn is_retryable_status(status: &Status, config: &RpcRetryConfig) -> bool {
    // A disabled policy never retries, regardless of the status code.
    if !config.enabled {
        return false;
    }
    let code = status.code();
    config.retryable_codes.contains(&code)
}
/// Computes the delay to wait before the retry numbered `retry_attempt`
/// (zero-based).
///
/// The delay is `config.initial_backoff * 2^retry_attempt`, with the exponent
/// clamped to 16 and the multiplication saturating, and the result is never
/// larger than `config.max_backoff`.
pub fn next_backoff(config: &RpcRetryConfig, retry_attempt: u32) -> Duration {
    // Clamping the exponent keeps the power-of-two factor within `u32`.
    const MAX_EXPONENT: u32 = 16;
    let exponent = if retry_attempt > MAX_EXPONENT {
        MAX_EXPONENT
    } else {
        retry_attempt
    };
    // 2^exponent; at most 2^16, so the shift cannot overflow.
    let factor = 1_u32 << exponent;
    let uncapped = config.initial_backoff.saturating_mul(factor);
    uncapped.min(config.max_backoff)
}
/// Invokes `operation` until it succeeds, fails with a non-retryable status,
/// or the attempt limit is reached.
///
/// The attempt limit is `config.max_attempts` (at least 1) when retries are
/// enabled, and exactly 1 otherwise. Between attempts the task sleeps for
/// [`next_backoff`] of the number of retries performed so far.
///
/// # Errors
///
/// Returns the `Status` of the last failed attempt.
pub async fn run_with_retry<T, F, Fut>(
    config: &RpcRetryConfig,
    mut operation: F,
) -> Result<T, Status>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, Status>>,
{
    let attempt_limit = match config.enabled {
        true => config.max_attempts.max(1),
        false => 1,
    };
    // Number of retries already performed; also the backoff exponent.
    let mut retries_done: u32 = 0;
    loop {
        let status = match operation().await {
            Ok(value) => return Ok(value),
            Err(status) => status,
        };
        let may_retry =
            retries_done + 1 < attempt_limit && is_retryable_status(&status, config);
        if !may_retry {
            return Err(status);
        }
        tokio::time::sleep(next_backoff(config, retries_done)).await;
        retries_done += 1;
    }
}
/// Invokes `operation` with retries, bounded by both the retry policy and an
/// overall deadline budget.
///
/// Before every attempt the time left in `budget` is read and passed to
/// `operation`, so the callee can apply it as a per-attempt timeout. The
/// attempt limit is `config.max_attempts` (at least 1) when retries are
/// enabled, and exactly 1 otherwise. The pause between attempts is
/// [`next_backoff`] clamped to the time left in `budget`.
///
/// A configured backoff of zero means "retry immediately"; it is not treated
/// as an exhausted deadline.
///
/// # Errors
///
/// * `DeadlineExceeded` if the budget is already expired on entry, or has no
///   time left before an attempt or before a backoff pause.
/// * Otherwise the `Status` of the last failed attempt.
pub async fn run_with_retry_budget<T, F, Fut>(
    config: &RpcRetryConfig,
    budget: &DeadlineBudget,
    mut operation: F,
) -> Result<T, Status>
where
    F: FnMut(std::time::Duration) -> Fut,
    Fut: Future<Output = Result<T, Status>>,
{
    if budget.expired() {
        return Err(Status::deadline_exceeded("rpc deadline exhausted"));
    }
    let max_attempts = if config.enabled {
        config.max_attempts.max(1)
    } else {
        1
    };
    let mut attempt = 0;
    loop {
        let remaining = budget.remaining();
        if remaining.is_zero() {
            return Err(Status::deadline_exceeded("rpc deadline exhausted"));
        }
        match operation(remaining).await {
            Ok(value) => return Ok(value),
            Err(status) if attempt + 1 < max_attempts && is_retryable_status(&status, config) => {
                // Check the budget itself rather than the clamped sleep: a
                // zero backoff from the config must not be mistaken for an
                // exhausted deadline.
                let left = budget.remaining();
                if left.is_zero() {
                    return Err(Status::deadline_exceeded("rpc deadline exhausted"));
                }
                // Never sleep past the deadline.
                let sleep_for = next_backoff(config, attempt).min(left);
                tokio::time::sleep(sleep_for).await;
                attempt += 1;
            }
            Err(status) => return Err(status),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tonic::{Code, Status};

    use super::{is_retryable_status, next_backoff};
    use crate::rpc::RpcRetryConfig;

    /// The first retry waits the initial backoff; a later retry is clamped to
    /// the configured maximum.
    #[test]
    fn retry_policy_caps_backoff() {
        let initial = Duration::from_millis(10);
        let ceiling = Duration::from_millis(25);
        let config = RpcRetryConfig {
            initial_backoff: initial,
            max_backoff: ceiling,
            ..RpcRetryConfig::production_defaults()
        };

        let first_delay = next_backoff(&config, 0);
        let later_delay = next_backoff(&config, 4);

        assert_eq!(first_delay, initial);
        assert_eq!(later_delay, ceiling);
    }

    /// With the production defaults, `Unavailable` is retried and
    /// `InvalidArgument` is not.
    #[test]
    fn retry_policy_only_retries_configured_codes() {
        let config = RpcRetryConfig::production_defaults();
        let unavailable = Status::new(Code::Unavailable, "down");
        let invalid_argument = Status::new(Code::InvalidArgument, "bad");

        assert!(is_retryable_status(&unavailable, &config));
        assert!(!is_retryable_status(&invalid_argument, &config));
    }
}