nblm_core/client/
retry.rs1use std::time::{Duration, SystemTime};
2
3use backon::{BackoffBuilder, ExponentialBuilder};
4use httpdate::parse_http_date;
5use reqwest::{header::RETRY_AFTER, StatusCode};
6use tokio::time::sleep;
7use tracing::warn;
8
9use crate::error::{Error, Result};
10
/// Minimum backoff delay, in milliseconds, used by `RetryConfig::default()`.
const DEFAULT_RETRY_MIN_DELAY_MS: u64 = 500;
/// Maximum delay for a single retry, in seconds, used by `RetryConfig::default()`.
const DEFAULT_RETRY_MAX_DELAY_SECS: u64 = 5;
/// Number of retries allowed after the initial attempt, used by `RetryConfig::default()`.
const DEFAULT_RETRY_MAX_RETRIES: usize = 3;
14
/// Tuning knobs for the retry loop run by [`Retryer`].
///
/// Start from [`RetryConfig::default`] and adjust with the `with_*` builder methods.
/// Equality is derived so configurations can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Minimum delay handed to the exponential backoff builder.
    pub min_delay: Duration,
    /// Upper bound on any single sleep between attempts; server-supplied
    /// `Retry-After` values are also capped to this.
    pub max_delay: Duration,
    /// Number of retries allowed after the initial attempt before giving up.
    pub max_retries: usize,
    /// Whether jitter is enabled on the backoff schedule.
    pub jitter: bool,
}
25
26impl RetryConfig {
27 pub fn with_min_delay(mut self, delay: Duration) -> Self {
28 self.min_delay = delay;
29 self
30 }
31
32 pub fn with_max_delay(mut self, delay: Duration) -> Self {
33 self.max_delay = delay;
34 self
35 }
36
37 pub fn with_max_retries(mut self, retries: usize) -> Self {
38 self.max_retries = retries;
39 self
40 }
41
42 pub fn with_jitter(mut self, jitter: bool) -> Self {
43 self.jitter = jitter;
44 self
45 }
46}
47
48impl Default for RetryConfig {
49 fn default() -> Self {
50 Self {
51 min_delay: Duration::from_millis(DEFAULT_RETRY_MIN_DELAY_MS),
52 max_delay: Duration::from_secs(DEFAULT_RETRY_MAX_DELAY_SECS),
53 max_retries: DEFAULT_RETRY_MAX_RETRIES,
54 jitter: true,
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
60pub struct Retryer {
61 config: RetryConfig,
62}
63
64impl Retryer {
65 pub fn new(config: RetryConfig) -> Self {
66 Self { config }
67 }
68
69 pub async fn run_with_retry<F, Fut>(&self, mut operation: F) -> Result<reqwest::Response>
70 where
71 F: FnMut() -> Fut,
72 Fut: std::future::Future<Output = std::result::Result<reqwest::Response, Error>>,
73 {
74 let mut builder = ExponentialBuilder::default()
75 .with_min_delay(self.config.min_delay)
76 .with_max_delay(self.config.max_delay)
77 .with_max_times(self.config.max_retries);
78 if self.config.jitter {
79 builder = builder.with_jitter();
80 }
81 let mut backoff = builder.build();
82 let mut attempts = 0usize;
83
84 loop {
85 match operation().await {
86 Ok(response) => {
87 if should_retry_status(response.status()) {
88 let status = response.status();
89 let retry_after = retry_after_delay(&response);
90 if attempts >= self.config.max_retries {
91 let body = response.text().await.unwrap_or_default();
92 return Err(Error::http(status, body));
93 }
94 attempts += 1;
95 let max_delay = self.config.max_delay;
96 let backoff_delay = backoff.next().map(|d| d.min(max_delay));
97 let delay = retry_after
98 .map(|d| d.min(max_delay))
99 .or(backoff_delay)
100 .unwrap_or(Duration::from_millis(0));
101 let _ = response.bytes().await;
102 warn!(
103 %status,
104 attempt = attempts,
105 max_retries = self.config.max_retries,
106 retry_after = ?delay,
107 "retrying HTTP request due to status"
108 );
109 sleep(delay).await;
110 continue;
111 }
112 return Ok(response);
113 }
114 Err(err) => {
115 if is_retryable_error(&err) {
116 if attempts >= self.config.max_retries {
117 return Err(err);
118 }
119 attempts += 1;
120 if let Some(delay) = backoff.next().map(|d| d.min(self.config.max_delay)) {
121 warn!(
122 ?err,
123 attempt = attempts,
124 max_retries = self.config.max_retries,
125 retry_after = ?delay,
126 "retrying HTTP request due to error"
127 );
128 sleep(delay).await;
129 continue;
130 }
131 }
132 return Err(err);
133 }
134 }
135 }
136 }
137}
138
139fn should_retry_status(status: StatusCode) -> bool {
140 matches!(
141 status,
142 StatusCode::TOO_MANY_REQUESTS | StatusCode::REQUEST_TIMEOUT
143 ) || status.is_server_error()
144}
145
146fn retry_after_delay(response: &reqwest::Response) -> Option<Duration> {
147 response
148 .headers()
149 .get(RETRY_AFTER)
150 .and_then(|value| parse_retry_after(value.to_str().ok()?, SystemTime::now()))
151}
152
153fn is_retryable_error(err: &Error) -> bool {
154 match err {
155 Error::Request(req_err) => req_err.is_connect() || req_err.is_timeout(),
156 Error::Http { status, .. } => should_retry_status(*status),
157 _ => false,
158 }
159}
160
161fn parse_retry_after(value: &str, now: SystemTime) -> Option<Duration> {
162 if let Ok(seconds) = value.parse::<u64>() {
163 return Some(Duration::from_secs(seconds));
164 }
165 if let Ok(date) = parse_http_date(value) {
166 if let Ok(dur) = date.duration_since(now) {
167 return Some(dur);
168 }
169 }
170 None
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed whole-second instant so HTTP-date round trips are exact and the
    /// tests do not depend on the wall clock.
    fn fixed_now() -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_000)
    }

    #[test]
    fn parse_retry_after_seconds() {
        let now = SystemTime::now();
        let delay = parse_retry_after("5", now).unwrap();
        assert_eq!(delay, Duration::from_secs(5));
    }

    #[test]
    fn parse_retry_after_http_date() {
        let now = fixed_now();
        let header = httpdate::fmt_http_date(now + Duration::from_secs(3));
        assert_eq!(parse_retry_after(&header, now), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_past_http_date_is_none() {
        let now = fixed_now();
        let header = httpdate::fmt_http_date(now - Duration::from_secs(10));
        assert!(parse_retry_after(&header, now).is_none());
    }

    #[test]
    fn parse_retry_after_invalid() {
        let now = SystemTime::now();
        assert!(parse_retry_after("invalid", now).is_none());
    }

    #[test]
    fn should_retry_status_for_retryable_codes() {
        assert!(should_retry_status(StatusCode::TOO_MANY_REQUESTS));
        assert!(should_retry_status(StatusCode::REQUEST_TIMEOUT));
        assert!(should_retry_status(StatusCode::INTERNAL_SERVER_ERROR));
        assert!(should_retry_status(StatusCode::BAD_GATEWAY));
        assert!(should_retry_status(StatusCode::SERVICE_UNAVAILABLE));
        assert!(should_retry_status(StatusCode::GATEWAY_TIMEOUT));
    }

    #[test]
    fn should_retry_status_for_non_retryable_codes() {
        assert!(!should_retry_status(StatusCode::OK));
        assert!(!should_retry_status(StatusCode::NOT_FOUND));
        assert!(!should_retry_status(StatusCode::BAD_REQUEST));
        assert!(!should_retry_status(StatusCode::UNAUTHORIZED));
    }

    #[test]
    fn is_retryable_error_for_retryable_http_status() {
        let err = Error::Http {
            status: StatusCode::TOO_MANY_REQUESTS,
            message: "test".to_string(),
            body: "test".to_string(),
        };
        assert!(is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_error_for_non_retryable_http_status() {
        let err = Error::Http {
            status: StatusCode::NOT_FOUND,
            message: "test".to_string(),
            body: "test".to_string(),
        };
        assert!(!is_retryable_error(&err));
    }

    #[test]
    fn is_retryable_error_for_non_retryable() {
        let err = Error::TokenProvider("test".to_string());
        assert!(!is_retryable_error(&err));

        let err = Error::Endpoint("test".to_string());
        assert!(!is_retryable_error(&err));
    }
}
235}