use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use crate::error::TypedError;
use crate::executor::Executor;
/// Settings for retrying a database operation with exponentially growing delays.
///
/// The struct is `#[non_exhaustive]`, so crates other than this one construct it through
/// [`RetryPolicy::new`] or `Default` rather than with a struct literal.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Number of retries allowed after the first attempt (total attempts = this + 1).
    pub max_retries: u32,
    /// Delay waited before the first retry; it is doubled for each later retry.
    pub initial_backoff: Duration,
    /// Cap applied to the delay as it grows between retries.
    pub max_backoff: Duration,
}
impl RetryPolicy {
    /// Creates a policy allowing `max_retries` retries, starting from `initial_backoff`.
    ///
    /// The delay cap defaults to 30 seconds; change it with [`RetryPolicy::with_max_backoff`].
    pub fn new(max_retries: u32, initial_backoff: Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            max_backoff: Duration::from_secs(30),
        }
    }

    /// Returns this policy with the delay cap replaced by `max`.
    pub fn with_max_backoff(mut self, max: Duration) -> Self {
        self.max_backoff = max;
        self
    }

    /// Runs `f(db)` and retries it while it fails with a transient error.
    ///
    /// At most `max_retries + 1` attempts are made. Before each retry the task sleeps via
    /// `tokio::time::sleep`. The delay starts at `initial_backoff`, doubles after every retry,
    /// and is never larger than `max_backoff`. A warning is logged through `tracing` for each
    /// transient failure that will be retried.
    ///
    /// `f` is invoked once per attempt, so the operation it builds should be safe to run more
    /// than once. NOTE(review): the same `db` is handed back to `f` after an I/O or
    /// connection-closed error. Whether that can succeed depends on the `Executor`
    /// implementation (e.g. a pool vs. a single connection) — confirm.
    ///
    /// # Errors
    ///
    /// Returns immediately with any error that is not classified as transient. Transient means
    /// a wire I/O error, a closed connection, or a server error with one of a fixed set of
    /// retryable SQLSTATE codes. Once the retries are exhausted, returns the error from the
    /// final attempt.
    pub async fn execute<'a, T, E: Executor + ?Sized>(
        &self,
        db: &'a E,
        f: impl Fn(&'a E) -> Pin<Box<dyn Future<Output = Result<T, TypedError>> + Send + 'a>>,
    ) -> Result<T, TypedError> {
        // Clamp the very first delay too, so `initial_backoff > max_backoff` cannot exceed
        // the cap.
        let mut backoff = self.initial_backoff.min(self.max_backoff);
        // Widen before adding one: `max_retries + 1` would overflow u32 at `u32::MAX`.
        let total_attempts = u64::from(self.max_retries) + 1;
        let mut attempt: u32 = 0;
        loop {
            match f(db).await {
                Ok(val) => return Ok(val),
                // Only retry while retries remain; the last attempt's error is returned as-is.
                Err(e) if is_transient(&e) && attempt < self.max_retries => {
                    tracing::warn!(
                        "Transient error on attempt {}/{}: {}",
                        u64::from(attempt) + 1,
                        total_attempts,
                        e,
                    );
                }
                Err(e) => return Err(e),
            }
            tokio::time::sleep(backoff).await;
            // `Duration * 2` panics on overflow; saturate instead, then apply the cap.
            backoff = backoff.saturating_mul(2).min(self.max_backoff);
            // Cannot overflow: the guard above ensured `attempt < max_retries`.
            attempt += 1;
        }
    }
}
impl Default for RetryPolicy {
    /// Three retries, an initial backoff of 100 ms, and the 30 s cap that
    /// [`RetryPolicy::new`] applies.
    fn default() -> Self {
        Self::new(3, Duration::from_millis(100))
    }
}
/// Decides whether `err` is worth retrying.
///
/// Only `TypedError::Wire` errors can qualify. An I/O failure or a closed connection always
/// counts as transient. A server error counts when `is_transient_pg_code` accepts its SQLSTATE
/// code. Everything else is treated as permanent.
fn is_transient(err: &TypedError) -> bool {
    if let TypedError::Wire(wire_err) = err {
        match wire_err.as_ref() {
            pg_wired::PgWireError::Io(_) | pg_wired::PgWireError::ConnectionClosed => true,
            pg_wired::PgWireError::Pg(pg_err) => is_transient_pg_code(&pg_err.code),
            _ => false,
        }
    } else {
        false
    }
}
/// Reports whether a Postgres SQLSTATE `code` identifies a retryable failure.
///
/// The match is exact and case-sensitive against a fixed table of codes.
fn is_transient_pg_code(code: &str) -> bool {
    const RETRYABLE_CODES: [&str; 10] = [
        // Class 08: connection exceptions.
        "08000", "08001", "08003", "08004", "08006",
        // Serialization failure and deadlock detected.
        "40001", "40P01",
        // Admin shutdown, crash shutdown, cannot connect now.
        "57P01", "57P02", "57P03",
    ];
    RETRYABLE_CODES.contains(&code)
}