Skip to main content

pi_ai/
retry.rs

//! Shared retry helper used by HTTP-backed providers.
//!
//! Retries transient failures (HTTP 5xx and 429) with exponential back-off. If
//! the response includes a `Retry-After` header, the delay honors it (capped at
//! [`RetryConfig::max_delay`]).
6
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tokio_util::sync::CancellationToken;
11
12use crate::error::{Error, Result};
13
/// Tuning knobs for [`with_retry`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Attempt budget, counting the first call. [`with_retry`] gives up once the
    /// 1-based attempt number reaches this value; the operation is always
    /// invoked at least once, even when this is `0`.
    pub max_attempts: u32,
    /// Unit of the exponential back-off; it is multiplied by a power of two
    /// that grows with the attempt number.
    pub base_delay: Duration,
    /// Upper bound on any single wait, whether it comes from the computed
    /// back-off or from a server-provided hint.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three attempts, a 500 ms back-off unit, and waits capped at 60 s.
    fn default() -> Self {
        let max_attempts = 3;
        let base_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(60);
        RetryConfig {
            max_attempts,
            base_delay,
            max_delay,
        }
    }
}
30
31/// Outcome of a single attempt.
32#[allow(clippy::large_enum_variant)]
33pub enum Attempt<T> {
34    Ok(T),
35    /// Permanent failure — return immediately.
36    Fatal(Error),
37    /// Transient failure — try again. `retry_after` is the server-hinted delay.
38    Retry {
39        error: Error,
40        retry_after: Option<Duration>,
41    },
42}
43
44pub async fn with_retry<T, F, Fut>(
45    cfg: &RetryConfig,
46    cancel: Option<&CancellationToken>,
47    mut f: F,
48) -> Result<T>
49where
50    F: FnMut(u32) -> Fut,
51    Fut: std::future::Future<Output = Attempt<T>>,
52{
53    let mut attempt: u32 = 0;
54    let mut last_err: Option<Error>;
55    loop {
56        if let Some(c) = cancel {
57            if c.is_cancelled() {
58                return Err(Error::Cancelled);
59            }
60        }
61        attempt += 1;
62        match f(attempt).await {
63            Attempt::Ok(v) => return Ok(v),
64            Attempt::Fatal(e) => return Err(e),
65            Attempt::Retry { error, retry_after } => {
66                last_err = Some(error);
67                let _ = &last_err;
68                if attempt >= cfg.max_attempts {
69                    break;
70                }
71                let backoff = cfg
72                    .base_delay
73                    .saturating_mul(1u32 << attempt.min(6))
74                    .min(cfg.max_delay);
75                let delay = retry_after.map(|d| d.min(cfg.max_delay)).unwrap_or(backoff);
76                tracing::warn!(?delay, attempt, "retrying after transient error");
77                tokio::select! {
78                    _ = sleep(delay) => {},
79                    _ = async {
80                        if let Some(c) = cancel { c.cancelled().await; }
81                        else { futures::future::pending::<()>().await; }
82                    } => return Err(Error::Cancelled),
83                }
84            }
85        }
86    }
87    Err(Error::RetryExhausted {
88        attempts: attempt,
89        source: Box::new(last_err.unwrap_or_else(|| Error::Other("retry exhausted".into()))),
90    })
91}
92
93/// Classify a status code into retry-worthy categories.
94pub fn classify_status(status: u16) -> Option<ClassifiedStatus> {
95    match status {
96        429 => Some(ClassifiedStatus::RateLimited),
97        500..=599 => Some(ClassifiedStatus::ServerError),
98        _ => None,
99    }
100}
101
/// Retry-worthy category of an HTTP status code, as produced by
/// [`classify_status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassifiedStatus {
    /// Status `429`.
    RateLimited,
    /// Any status in `500..=599`.
    ServerError,
}
107
/// Parses a `Retry-After` header value given as a whole number of seconds.
///
/// Surrounding whitespace is ignored. Returns `None` for anything that does
/// not parse as a `u64`; in particular the HTTP-date form of the header is not
/// recognised here.
pub fn parse_retry_after(value: &str) -> Option<Duration> {
    value
        .trim()
        .parse::<u64>()
        .ok()
        .map(Duration::from_secs)
}