openapi_snapshot/
fetch.rs

1use std::thread;
2use std::time::Duration;
3
4use reqwest::blocking::Client;
5use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
6use serde_json::Value;
7
8use crate::config::Config;
9use crate::errors::AppError;
10
/// `User-Agent` header value sent with every request: the tool name followed by
/// the crate version taken from `CARGO_PKG_VERSION` at compile time.
const USER_AGENT: &str = concat!("openapi-snapshot/", env!("CARGO_PKG_VERSION"));
/// Upper bound on the total number of attempts `fetch_openapi` makes,
/// counting the first request (the retry check is `attempt < MAX_RETRIES`).
const MAX_RETRIES: usize = 3;
/// Delay, in milliseconds, slept before the first retry.
const BASE_BACKOFF_MS: u64 = 100;
/// Ceiling, in milliseconds, that the doubling backoff delay never exceeds.
const MAX_BACKOFF_MS: u64 = 2_000;
/// Maximum number of characters of a response body quoted in an error message.
const ERROR_SNIPPET_LIMIT: usize = 256;
16
17pub fn fetch_openapi(config: &Config) -> Result<Vec<u8>, AppError> {
18    let headers = build_headers(&config.headers)?;
19    let client = Client::builder()
20        .timeout(Duration::from_millis(config.timeout_ms))
21        .default_headers(headers)
22        .build()
23        .map_err(|err| AppError::Network(format!("client error: {err}")))?;
24
25    let mut backoff = BASE_BACKOFF_MS;
26    let mut attempt = 0;
27    loop {
28        attempt += 1;
29        match client.get(&config.url).send() {
30            Ok(response) => {
31                let status = response.status();
32                if !status.is_success() {
33                    let snippet = body_snippet(response.text().unwrap_or_default());
34                    let message = format!("HTTP {status}: {snippet}");
35                    if should_retry_status(status) && attempt < MAX_RETRIES {
36                        sleep(backoff);
37                        backoff = next_backoff(backoff);
38                        continue;
39                    }
40                    return Err(AppError::Network(message));
41                }
42
43                match response.bytes() {
44                    Ok(bytes) => return Ok(bytes.to_vec()),
45                    Err(err) => {
46                        if is_retryable_error(&err) && attempt < MAX_RETRIES {
47                            sleep(backoff);
48                            backoff = next_backoff(backoff);
49                            continue;
50                        }
51                        return Err(AppError::Network(format!("failed to read response: {err}")));
52                    }
53                }
54            }
55            Err(err) => {
56                if is_retryable_error(&err) && attempt < MAX_RETRIES {
57                    sleep(backoff);
58                    backoff = next_backoff(backoff);
59                    continue;
60                }
61                return Err(AppError::Network(format!("request failed: {err}")));
62            }
63        }
64    }
65}
66
67pub fn parse_json(bytes: &[u8]) -> Result<Value, AppError> {
68    serde_json::from_slice(bytes).map_err(|err| AppError::Json(format!("invalid JSON: {err}")))
69}
70
71fn build_headers(raw_headers: &[String]) -> Result<HeaderMap, AppError> {
72    let mut headers = HeaderMap::new();
73    headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
74    headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));
75
76    for raw in raw_headers {
77        let (name, value) = parse_header(raw)?;
78        headers.insert(name, value);
79    }
80    Ok(headers)
81}
82
83fn parse_header(raw: &str) -> Result<(HeaderName, HeaderValue), AppError> {
84    let mut split = raw.splitn(2, ':');
85    let name = split
86        .next()
87        .map(str::trim)
88        .filter(|value| !value.is_empty())
89        .ok_or_else(|| AppError::Usage(format!("invalid header format: {raw}")))?;
90    let value = split
91        .next()
92        .map(str::trim)
93        .ok_or_else(|| AppError::Usage(format!("invalid header format: {raw}")))?;
94    let header_name = HeaderName::from_bytes(name.as_bytes())
95        .map_err(|_| AppError::Usage(format!("invalid header name: {name}")))?;
96    let header_value = HeaderValue::from_str(value)
97        .map_err(|_| AppError::Usage(format!("invalid header value for: {name}")))?;
98    Ok((header_name, header_value))
99}
100
101fn is_retryable_error(err: &reqwest::Error) -> bool {
102    err.is_timeout() || err.is_connect() || err.is_body()
103}
104
105fn should_retry_status(status: reqwest::StatusCode) -> bool {
106    status.as_u16() == 429 || status.is_server_error()
107}
108
109fn next_backoff(current: u64) -> u64 {
110    (current.saturating_mul(2)).min(MAX_BACKOFF_MS)
111}
112
/// Blocks the current thread for `duration_ms` milliseconds.
fn sleep(duration_ms: u64) {
    let pause = Duration::from_millis(duration_ms);
    thread::sleep(pause);
}
116
117fn body_snippet(body: String) -> String {
118    let trimmed = body.trim();
119    if trimmed.is_empty() {
120        return String::from("<empty body>");
121    }
122    let snippet: String = trimmed.chars().take(ERROR_SNIPPET_LIMIT).collect();
123    if snippet.len() < trimmed.len() {
124        format!("{snippet}…")
125    } else {
126        snippet
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::OutputProfile;
    use crate::config::Config;
    use httpmock::prelude::*;

    /// Path served by every mock in this module.
    const SPEC_PATH: &str = "/openapi.json";
    /// Minimal document returned by the mocks that answer successfully.
    const SPEC_BODY: &str = r#"{"openapi":"3.0.3","paths":{},"components":{}}"#;

    /// Configuration pointing at `url` with a 5 s timeout and no custom headers.
    fn base_config(url: String) -> Config {
        Config {
            url,
            url_from_default: false,
            out: None,
            outline_out: None,
            reduce: Vec::new(),
            profile: OutputProfile::Full,
            minify: false,
            timeout_ms: 5_000,
            headers: Vec::new(),
            stdout: true,
        }
    }

    /// Registers a mock that answers every GET on `SPEC_PATH` with `status` and `body`.
    fn always_respond<'a>(server: &'a MockServer, status: u16, body: &str) -> httpmock::Mock<'a> {
        server.mock(|when, then| {
            when.method(GET).path(SPEC_PATH);
            then.status(status).body(body);
        })
    }

    /// Fetches with `config`, expecting success, and returns the parsed document.
    fn fetch_document(config: &Config) -> Value {
        let bytes = fetch_openapi(config).unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Unwraps the message of an `AppError::Network`, panicking on any other variant.
    fn network_message(err: AppError) -> String {
        match err {
            AppError::Network(msg) => msg,
            other => panic!("expected network error, got {other:?}"),
        }
    }

    #[test]
    fn fetch_includes_default_and_custom_headers() {
        let server = MockServer::start();
        // The mock only matches when all three headers are present.
        let mock = server.mock(|when, then| {
            when.method(GET)
                .path(SPEC_PATH)
                .header("accept", "application/json")
                .header("user-agent", USER_AGENT)
                .header("authorization", "Bearer token");
            then.status(200)
                .header("content-type", "application/json")
                .body(SPEC_BODY);
        });

        let mut config = base_config(server.url(SPEC_PATH));
        config
            .headers
            .push("Authorization: Bearer token".to_string());

        let document = fetch_document(&config);
        assert_eq!(document["openapi"], serde_json::json!("3.0.3"));
        mock.assert_hits(1);
    }

    #[test]
    fn retries_on_server_error_then_succeeds() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
        CALL_COUNT.store(0, Ordering::SeqCst);

        let server = MockServer::start();

        // Failing mock: its matcher accepts only while the counter is still zero.
        server.mock(|when, then| {
            when.method(GET)
                .path(SPEC_PATH)
                .matches(|_| CALL_COUNT.fetch_add(1, Ordering::SeqCst) < 1);
            then.status(500).body("temporary");
        });

        // Fallback mock: serves the document once the failing mock stops matching.
        let success = server.mock(|when, then| {
            when.method(GET).path(SPEC_PATH);
            then.status(200)
                .header("content-type", "application/json")
                .body(SPEC_BODY);
        });

        let config = base_config(server.url(SPEC_PATH));
        let document = fetch_document(&config);
        assert_eq!(document["openapi"], serde_json::json!("3.0.3"));
        assert!(CALL_COUNT.load(Ordering::SeqCst) >= 2);
        assert!(success.hits() >= 1);
    }

    #[test]
    fn fetch_surfaces_status_and_body_snippet() {
        let server = MockServer::start();
        always_respond(&server, 502, "gateway down");

        let config = base_config(server.url(SPEC_PATH));
        let msg = network_message(fetch_openapi(&config).unwrap_err());
        assert!(msg.contains("502"));
        assert!(msg.contains("gateway down"));
    }

    #[test]
    fn returns_error_with_status_and_snippet_when_retries_exhausted() {
        let server = MockServer::start();
        let fail = always_respond(&server, 500, "server exploded");

        let config = base_config(server.url(SPEC_PATH));
        let err = fetch_openapi(&config).unwrap_err();
        let message = format!("{err}");
        assert!(message.contains("HTTP 500"));
        assert!(message.contains("server exploded"));
        fail.assert_hits(MAX_RETRIES);
    }

    #[test]
    fn stops_after_max_retries_and_returns_error() {
        let server = MockServer::start();
        let mock = always_respond(&server, 503, "down");

        let config = base_config(server.url(SPEC_PATH));
        let err = fetch_openapi(&config).unwrap_err();
        assert!(format!("{err}").contains("HTTP 503"));
        mock.assert_hits(MAX_RETRIES);
    }

    #[test]
    fn error_includes_body_snippet() {
        let server = MockServer::start();
        let mock = always_respond(&server, 400, "something went wrong in backend");

        let config = base_config(server.url(SPEC_PATH));
        let msg = network_message(fetch_openapi(&config).unwrap_err());
        assert!(msg.contains("400"));
        assert!(msg.contains("something went wrong in backend"));
        // A 400 is not retryable, so exactly one request must have been made.
        mock.assert_hits(1);
    }
}
276}