mermaid_cli/effect/
middleware.rs1use std::time::Duration;
10
11use crate::models::{BackendError, ModelError, Result};
12
/// Total number of calls (the first attempt plus retries) made by `retry_transient_http`.
pub const DEFAULT_MAX_ATTEMPTS: usize = 3;

/// Backoff before the first retry, in milliseconds; doubled after each retry.
const DEFAULT_INITIAL_DELAY_MS: u64 = 500;
/// Ceiling applied to the doubled backoff, in milliseconds.
const MAX_DELAY_MS: u64 = 3000;
20
21pub async fn retry_transient_http<F, Fut>(mut build_and_send: F) -> Result<reqwest::Response>
30where
31 F: FnMut() -> Fut,
32 Fut: std::future::Future<Output = Result<reqwest::Response>>,
33{
34 retry_transient_http_with(
35 RetryPolicy {
36 max_attempts: DEFAULT_MAX_ATTEMPTS,
37 },
38 &mut build_and_send,
39 )
40 .await
41}
42
43async fn retry_transient_http_with<F, Fut>(
44 policy: RetryPolicy,
45 build_and_send: &mut F,
46) -> Result<reqwest::Response>
47where
48 F: FnMut() -> Fut,
49 Fut: std::future::Future<Output = Result<reqwest::Response>>,
50{
51 let mut attempt: usize = 1;
52 let mut delay_ms = DEFAULT_INITIAL_DELAY_MS;
53
54 loop {
55 let result = build_and_send().await;
56 let transience = classify(&result);
57
58 if !transience.is_transient() || attempt >= policy.max_attempts {
59 if transience.is_transient() {
60 tracing::warn!(
61 attempts = attempt,
62 reason = transience.reason(),
63 "middleware: transient upstream failure — retries exhausted"
64 );
65 }
66 return result;
67 }
68
69 tracing::warn!(
70 attempt,
71 max = policy.max_attempts,
72 delay_ms,
73 reason = transience.reason(),
74 "middleware: retrying transient upstream failure"
75 );
76 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
77 attempt += 1;
78 delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
79 }
80}
81
/// Bounds how many times the retry loop invokes the request closure.
#[derive(Clone, Copy, Debug)]
struct RetryPolicy {
    /// Maximum number of calls, counting the first one.
    max_attempts: usize,
}
86
/// Retry verdict for the outcome of a single attempt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Transience {
    /// An `Ok` response that should be handed back as-is (not retried).
    Success,
    /// An error that should not be retried.
    Terminal,
    /// A transient failure worth retrying; carries a short tag used in log fields.
    Retryable(&'static str),
}
93
94impl Transience {
95 fn is_transient(self) -> bool {
96 matches!(self, Transience::Retryable(_))
97 }
98
99 fn reason(self) -> &'static str {
100 match self {
101 Transience::Success => "success",
102 Transience::Terminal => "terminal",
103 Transience::Retryable(r) => r,
104 }
105 }
106}
107
108fn classify(result: &Result<reqwest::Response>) -> Transience {
109 match result {
110 Ok(resp) => {
111 let status = resp.status().as_u16();
112 if status == 429 {
113 Transience::Retryable("http_429")
114 } else if (500..=599).contains(&status) {
115 Transience::Retryable("http_5xx")
116 } else {
117 Transience::Success
118 }
119 },
120 Err(ModelError::Backend(BackendError::ConnectionFailed { .. })) => {
121 Transience::Retryable("connection_failed")
122 },
123 Err(_) => Transience::Terminal,
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Starts a one-shot local HTTP server that answers with `status`, then returns
    /// the client-side response to a GET against it.
    async fn fake_response(status: u16) -> reqwest::Response {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = listener.local_addr().expect("local_addr");

        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let body = format!(
                    "HTTP/1.1 {status} X\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
                );
                let _ = sock.write_all(body.as_bytes()).await;
            }
        });

        let url = format!("http://{}/x", addr);
        reqwest::get(url).await.expect("send")
    }

    /// Builds the connection-failure error that the retry loop treats as transient.
    fn connection_failed(reason: &str) -> ModelError {
        ModelError::Backend(BackendError::ConnectionFailed {
            backend: "test".to_string(),
            url: "http://nope".to_string(),
            reason: reason.to_string(),
        })
    }

    #[tokio::test]
    async fn retries_5xx_then_succeeds() {
        let calls = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&calls);
        let result = retry_transient_http_with(RetryPolicy { max_attempts: 3 }, &mut move || {
            let c = Arc::clone(&cc);
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                let status = if n < 2 { 500 } else { 200 };
                Ok(fake_response(status).await)
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status().as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn does_not_retry_4xx_client_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&calls);
        let result = retry_transient_http_with(RetryPolicy { max_attempts: 3 }, &mut move || {
            let c = Arc::clone(&cc);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok(fake_response(400).await)
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status().as_u16(), 400);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_429() {
        let calls = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&calls);
        let result = retry_transient_http_with(RetryPolicy { max_attempts: 2 }, &mut move || {
            let c = Arc::clone(&cc);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok(fake_response(429).await)
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status().as_u16(), 429);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retries_connection_failed_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&calls);
        let result = retry_transient_http_with(RetryPolicy { max_attempts: 3 }, &mut move || {
            let c = Arc::clone(&cc);
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(connection_failed("dns"))
                } else {
                    Ok(fake_response(200).await)
                }
            }
        })
        .await;
        assert!(result.is_ok());
        // The successful third attempt's response must be the one handed back.
        assert_eq!(result.unwrap().status().as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn returns_last_error_when_connection_retries_exhausted() {
        let calls = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&calls);
        let result = retry_transient_http_with(RetryPolicy { max_attempts: 2 }, &mut move || {
            let c = Arc::clone(&cc);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                // Annotate the Ok type: this future never produces a response.
                Err::<reqwest::Response, _>(connection_failed("refused"))
            }
        })
        .await;
        assert!(matches!(
            result,
            Err(ModelError::Backend(BackendError::ConnectionFailed { .. }))
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}