Skip to main content

mermaid_cli/effect/
middleware.rs

1//! Cross-cutting wrappers over effect handlers.
2//!
3//! Retry-on-5xx, tracing, rate-limiting — all concerns that would
4//! otherwise be re-implemented per-adapter. Living here means any
5//! new effect handler picks them up uniformly; 500ms→3s exponential
6//! backoff, 3-attempt cap, same classification function for every
7//! provider.
8
9use std::time::Duration;
10
11use crate::models::{BackendError, ModelError, Result};
12
/// Total attempts (initial + retries). 3 attempts means up to 2
/// retries on top of the first request, costing at most ~1.5s of
/// extra latency on the worst path (500ms + 1000ms backoff).
pub const DEFAULT_MAX_ATTEMPTS: usize = 3;

/// Pause before the first retry, in milliseconds. The retry loop
/// doubles this value after every retry.
const DEFAULT_INITIAL_DELAY_MS: u64 = 500;
/// Ceiling, in milliseconds, applied to the doubled pause so the
/// backoff never grows past it.
const MAX_DELAY_MS: u64 = 3_000;
20
21/// Retry a closure whose output is `Result<reqwest::Response>`
22/// whenever the response was a transient upstream failure (5xx / 429
23/// / connection failed). Returns the first non-transient response, or
24/// the last result after attempts are exhausted.
25///
26/// The closure takes nothing and must rebuild the request internally
27/// because `reqwest::RequestBuilder::send` consumes the builder — so
28/// each attempt needs a fresh one.
29pub async fn retry_transient_http<F, Fut>(mut build_and_send: F) -> Result<reqwest::Response>
30where
31    F: FnMut() -> Fut,
32    Fut: std::future::Future<Output = Result<reqwest::Response>>,
33{
34    retry_transient_http_with(
35        RetryPolicy {
36            max_attempts: DEFAULT_MAX_ATTEMPTS,
37        },
38        &mut build_and_send,
39    )
40    .await
41}
42
43async fn retry_transient_http_with<F, Fut>(
44    policy: RetryPolicy,
45    build_and_send: &mut F,
46) -> Result<reqwest::Response>
47where
48    F: FnMut() -> Fut,
49    Fut: std::future::Future<Output = Result<reqwest::Response>>,
50{
51    let mut attempt: usize = 1;
52    let mut delay_ms = DEFAULT_INITIAL_DELAY_MS;
53
54    loop {
55        let result = build_and_send().await;
56        let transience = classify(&result);
57
58        if !transience.is_transient() || attempt >= policy.max_attempts {
59            if transience.is_transient() {
60                tracing::warn!(
61                    attempts = attempt,
62                    reason = transience.reason(),
63                    "middleware: transient upstream failure — retries exhausted"
64                );
65            }
66            return result;
67        }
68
69        tracing::warn!(
70            attempt,
71            max = policy.max_attempts,
72            delay_ms,
73            reason = transience.reason(),
74            "middleware: retrying transient upstream failure"
75        );
76        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
77        attempt += 1;
78        delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
79    }
80}
81
/// Settings that bound a single retry run.
#[derive(Debug, Clone, Copy)]
struct RetryPolicy {
    /// Maximum number of calls to the send closure, counting the
    /// initial one. The retry loop always makes the first call, so a
    /// value of 0 behaves the same as 1.
    max_attempts: usize,
}
86
/// Verdict on one attempt's result: whether it is worth retrying.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Transience {
    /// Not a failure that warrants a retry.
    Success,
    /// A failure that a retry is not attempted for.
    Terminal,
    /// A failure worth retrying; carries a short label used in logs.
    Retryable(&'static str),
}

impl Transience {
    /// `true` only for the `Retryable` variant.
    fn is_transient(self) -> bool {
        match self {
            Transience::Retryable(_) => true,
            Transience::Success | Transience::Terminal => false,
        }
    }

    /// Short label for log fields: the carried label for `Retryable`,
    /// otherwise a fixed name for the variant.
    fn reason(self) -> &'static str {
        match self {
            Transience::Retryable(label) => label,
            Transience::Terminal => "terminal",
            Transience::Success => "success",
        }
    }
}
107
108fn classify(result: &Result<reqwest::Response>) -> Transience {
109    match result {
110        Ok(resp) => {
111            let status = resp.status().as_u16();
112            if status == 429 {
113                Transience::Retryable("http_429")
114            } else if (500..=599).contains(&status) {
115                Transience::Retryable("http_5xx")
116            } else {
117                Transience::Success
118            }
119        },
120        Err(ModelError::Backend(BackendError::ConnectionFailed { .. })) => {
121            Transience::Retryable("connection_failed")
122        },
123        Err(_) => Transience::Terminal,
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Obtain a genuine `reqwest::Response` with the given status by
    /// serving a single canned reply from a throwaway listener bound
    /// to an OS-assigned loopback port.
    async fn fake_response(status: u16) -> reqwest::Response {
        let server = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let bound = server.local_addr().expect("local_addr");

        tokio::spawn(async move {
            match server.accept().await {
                Ok((mut conn, _)) => {
                    // One read of the incoming request before replying;
                    // the bytes themselves are not inspected.
                    let mut scratch = [0u8; 1024];
                    let _ = conn.read(&mut scratch).await;
                    let reply = format!(
                        "HTTP/1.1 {status} X\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
                    );
                    let _ = conn.write_all(reply.as_bytes()).await;
                },
                Err(_) => {},
            }
        });

        reqwest::get(format!("http://{}/x", bound))
            .await
            .expect("send")
    }

    /// Two 500s followed by a 200: the 200 is returned after exactly
    /// three calls.
    #[tokio::test]
    async fn retries_5xx_then_succeeds() {
        let counter = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&counter);
        let policy = RetryPolicy { max_attempts: 3 };
        let outcome = retry_transient_http_with(policy, &mut move || {
            let hits = Arc::clone(&shared);
            async move {
                let status = match hits.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => 500,
                    _ => 200,
                };
                Ok(fake_response(status).await)
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().status().as_u16(), 200);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A 400 is handed back after a single call.
    #[tokio::test]
    async fn does_not_retry_4xx_client_errors() {
        let counter = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&counter);
        let policy = RetryPolicy { max_attempts: 3 };
        let outcome = retry_transient_http_with(policy, &mut move || {
            let hits = Arc::clone(&shared);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
                Ok(fake_response(400).await)
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().status().as_u16(), 400);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// A persistent 429 uses up the whole two-attempt budget and the
    /// final 429 response is what comes back.
    #[tokio::test]
    async fn retries_429() {
        let counter = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&counter);
        let policy = RetryPolicy { max_attempts: 2 };
        let outcome = retry_transient_http_with(policy, &mut move || {
            let hits = Arc::clone(&shared);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
                Ok(fake_response(429).await)
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().status().as_u16(), 429);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    /// Two connection-failed errors followed by a 200: the call
    /// succeeds on the third attempt.
    #[tokio::test]
    async fn retries_connection_failed_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&counter);
        let policy = RetryPolicy { max_attempts: 3 };
        let outcome = retry_transient_http_with(policy, &mut move || {
            let hits = Arc::clone(&shared);
            async move {
                match hits.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(ModelError::Backend(BackendError::ConnectionFailed {
                        backend: "test".to_string(),
                        url: "http://nope".to_string(),
                        reason: "dns".to_string(),
                    })),
                    _ => Ok(fake_response(200).await),
                }
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }
}
229}