nblm_core/client/retry.rs

1use std::time::{Duration, SystemTime};
2
3use backon::{BackoffBuilder, ExponentialBuilder};
4use httpdate::parse_http_date;
5use reqwest::{header::RETRY_AFTER, StatusCode};
6use tokio::time::sleep;
7use tracing::warn;
8
9use crate::error::{Error, Result};
10
/// Default lower bound for the backoff delay, in milliseconds.
const DEFAULT_RETRY_MIN_DELAY_MS: u64 = 500;
/// Default upper bound for the backoff delay, in seconds.
const DEFAULT_RETRY_MAX_DELAY_SECS: u64 = 5;
/// Default number of retries made after the initial request.
const DEFAULT_RETRY_MAX_RETRIES: usize = 3;

/// Tuning parameters for [`Retryer`]'s exponential backoff.
///
/// Start from [`RetryConfig::default`] and adjust individual settings with the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Minimum backoff delay between retry attempts.
    pub min_delay: Duration,
    /// Maximum backoff delay between retry attempts.
    pub max_delay: Duration,
    /// Maximum number of retry attempts after the initial request.
    pub max_retries: usize,
    /// Whether random jitter is enabled on the backoff delays.
    pub jitter: bool,
}

impl RetryConfig {
    /// Returns this config with `min_delay` replaced by `delay`.
    #[must_use]
    pub fn with_min_delay(mut self, delay: Duration) -> Self {
        self.min_delay = delay;
        self
    }

    /// Returns this config with `max_delay` replaced by `delay`.
    #[must_use]
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Returns this config with `max_retries` replaced by `retries`.
    #[must_use]
    pub fn with_max_retries(mut self, retries: usize) -> Self {
        self.max_retries = retries;
        self
    }

    /// Returns this config with jitter enabled or disabled.
    #[must_use]
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.jitter = jitter;
        self
    }
}

impl Default for RetryConfig {
    /// 500 ms minimum delay, 5 s maximum delay, 3 retries, jitter enabled.
    fn default() -> Self {
        Self {
            min_delay: Duration::from_millis(DEFAULT_RETRY_MIN_DELAY_MS),
            max_delay: Duration::from_secs(DEFAULT_RETRY_MAX_DELAY_SECS),
            max_retries: DEFAULT_RETRY_MAX_RETRIES,
            jitter: true,
        }
    }
}
58
59#[derive(Debug, Clone)]
60pub struct Retryer {
61    config: RetryConfig,
62}
63
64impl Retryer {
65    pub fn new(config: RetryConfig) -> Self {
66        Self { config }
67    }
68
69    pub async fn run_with_retry<F, Fut>(&self, mut operation: F) -> Result<reqwest::Response>
70    where
71        F: FnMut() -> Fut,
72        Fut: std::future::Future<Output = std::result::Result<reqwest::Response, Error>>,
73    {
74        let mut builder = ExponentialBuilder::default()
75            .with_min_delay(self.config.min_delay)
76            .with_max_delay(self.config.max_delay)
77            .with_max_times(self.config.max_retries);
78        if self.config.jitter {
79            builder = builder.with_jitter();
80        }
81        let mut backoff = builder.build();
82        let mut attempts = 0usize;
83
84        loop {
85            match operation().await {
86                Ok(response) => {
87                    if should_retry_status(response.status()) {
88                        let status = response.status();
89                        let retry_after = retry_after_delay(&response);
90                        if attempts >= self.config.max_retries {
91                            let body = response.text().await.unwrap_or_default();
92                            return Err(Error::http(status, body));
93                        }
94                        attempts += 1;
95                        let max_delay = self.config.max_delay;
96                        let backoff_delay = backoff.next().map(|d| d.min(max_delay));
97                        let delay = retry_after
98                            .map(|d| d.min(max_delay))
99                            .or(backoff_delay)
100                            .unwrap_or(Duration::from_millis(0));
101                        let _ = response.bytes().await;
102                        warn!(
103                            %status,
104                            attempt = attempts,
105                            max_retries = self.config.max_retries,
106                            retry_after = ?delay,
107                            "retrying HTTP request due to status"
108                        );
109                        sleep(delay).await;
110                        continue;
111                    }
112                    return Ok(response);
113                }
114                Err(err) => {
115                    if is_retryable_error(&err) {
116                        if attempts >= self.config.max_retries {
117                            return Err(err);
118                        }
119                        attempts += 1;
120                        if let Some(delay) = backoff.next().map(|d| d.min(self.config.max_delay)) {
121                            warn!(
122                                ?err,
123                                attempt = attempts,
124                                max_retries = self.config.max_retries,
125                                retry_after = ?delay,
126                                "retrying HTTP request due to error"
127                            );
128                            sleep(delay).await;
129                            continue;
130                        }
131                    }
132                    return Err(err);
133                }
134            }
135        }
136    }
137}
138
139fn should_retry_status(status: StatusCode) -> bool {
140    matches!(
141        status,
142        StatusCode::TOO_MANY_REQUESTS | StatusCode::REQUEST_TIMEOUT
143    ) || status.is_server_error()
144}
145
146fn retry_after_delay(response: &reqwest::Response) -> Option<Duration> {
147    response
148        .headers()
149        .get(RETRY_AFTER)
150        .and_then(|value| parse_retry_after(value.to_str().ok()?, SystemTime::now()))
151}
152
153fn is_retryable_error(err: &Error) -> bool {
154    match err {
155        Error::Request(req_err) => req_err.is_connect() || req_err.is_timeout(),
156        Error::Http { status, .. } => should_retry_status(*status),
157        _ => false,
158    }
159}
160
161fn parse_retry_after(value: &str, now: SystemTime) -> Option<Duration> {
162    if let Ok(seconds) = value.parse::<u64>() {
163        return Some(Duration::from_secs(seconds));
164    }
165    if let Ok(date) = parse_http_date(value) {
166        if let Ok(dur) = date.duration_since(now) {
167            return Some(dur);
168        }
169    }
170    None
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_retry_after_seconds() {
        let now = SystemTime::now();
        assert_eq!(parse_retry_after("5", now), Some(Duration::from_secs(5)));
        assert_eq!(parse_retry_after("0", now), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_http_date() {
        let future = SystemTime::now() + Duration::from_secs(3);
        let header = httpdate::fmt_http_date(future);
        let delay = parse_retry_after(&header, SystemTime::now()).unwrap();
        // HTTP-dates have one-second resolution, so the parsed delay may be up
        // to a second shorter than requested, but never longer.
        assert!(delay <= Duration::from_secs(3));
        assert!(delay > Duration::from_secs(1));
    }

    #[test]
    fn parse_retry_after_past_http_date_is_none() {
        let past = SystemTime::now() - Duration::from_secs(60);
        let header = httpdate::fmt_http_date(past);
        assert!(parse_retry_after(&header, SystemTime::now()).is_none());
    }

    #[test]
    fn parse_retry_after_invalid() {
        let now = SystemTime::now();
        assert!(parse_retry_after("invalid", now).is_none());
        assert!(parse_retry_after("", now).is_none());
        assert!(parse_retry_after("-5", now).is_none());
    }

    #[test]
    fn should_retry_status_for_retryable_codes() {
        assert!(should_retry_status(StatusCode::TOO_MANY_REQUESTS));
        assert!(should_retry_status(StatusCode::REQUEST_TIMEOUT));
        assert!(should_retry_status(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(should_retry_status(StatusCode::BAD_GATEWAY));
        assert!(should_retry_status(StatusCode::SERVICE_UNAVAILABLE));
        assert!(should_retry_status(StatusCode::GATEWAY_TIMEOUT));
    }

    #[test]
    fn should_retry_status_for_non_retryable_codes() {
        assert!(!should_retry_status(StatusCode::OK));
        assert!(!should_retry_status(StatusCode::NOT_FOUND));
        assert!(!should_retry_status(StatusCode::BAD_REQUEST));
        assert!(!should_retry_status(StatusCode::UNAUTHORIZED));
        assert!(!should_retry_status(StatusCode::FORBIDDEN));
    }

    #[test]
    fn is_retryable_error_for_retryable_http_status() {
        // A reqwest::Error reporting is_connect()/is_timeout() cannot easily be
        // constructed here, so only the Error::Http path is covered.
        let err = Error::Http {
            status: StatusCode::TOO_MANY_REQUESTS,
            message: "test".to_string(),
            body: "test".to_string(),
        };
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_error_for_non_retryable_http_status() {
        let err = Error::Http {
            status: StatusCode::NOT_FOUND,
            message: "test".to_string(),
            body: "test".to_string(),
        };
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_error_for_non_retryable() {
        let err = Error::TokenProvider("test".to_string());
        assert!(!is_retryable_error(&err));

        let err = Error::Endpoint("test".to_string());
        assert!(!is_retryable_error(&err));
    }
}