openapi_snapshot/
fetch.rs1use std::thread;
2use std::time::Duration;
3
4use reqwest::blocking::Client;
5use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
6use serde_json::Value;
7
8use crate::config::Config;
9use crate::errors::AppError;
10
/// `User-Agent` value installed as a default header on every request:
/// `openapi-snapshot/<version>`, where the version is Cargo's `CARGO_PKG_VERSION`
/// captured at compile time.
const USER_AGENT: &str = concat!("openapi-snapshot/", env!("CARGO_PKG_VERSION"));
// --- retry policy -------------------------------------------------------

/// Total number of request attempts, the first one included, before giving up.
const MAX_RETRIES: usize = 3;
/// Delay before the first retry, in milliseconds. It doubles after each retry.
const BASE_BACKOFF_MS: u64 = 100;
/// Ceiling for the retry delay, in milliseconds.
const MAX_BACKOFF_MS: u64 = 2000;

// --- error reporting ----------------------------------------------------

/// Maximum number of characters of an error response body quoted in error messages.
const ERROR_SNIPPET_LIMIT: usize = 256;
16
17pub fn fetch_openapi(config: &Config) -> Result<Vec<u8>, AppError> {
18 let headers = build_headers(&config.headers)?;
19 let client = Client::builder()
20 .timeout(Duration::from_millis(config.timeout_ms))
21 .default_headers(headers)
22 .build()
23 .map_err(|err| AppError::Network(format!("client error: {err}")))?;
24
25 let mut backoff = BASE_BACKOFF_MS;
26 let mut attempt = 0;
27 loop {
28 attempt += 1;
29 match client.get(&config.url).send() {
30 Ok(response) => {
31 let status = response.status();
32 if !status.is_success() {
33 let snippet = body_snippet(response.text().unwrap_or_default());
34 let message = format!("HTTP {status}: {snippet}");
35 if should_retry_status(status) && attempt < MAX_RETRIES {
36 sleep(backoff);
37 backoff = next_backoff(backoff);
38 continue;
39 }
40 return Err(AppError::Network(message));
41 }
42
43 match response.bytes() {
44 Ok(bytes) => return Ok(bytes.to_vec()),
45 Err(err) => {
46 if is_retryable_error(&err) && attempt < MAX_RETRIES {
47 sleep(backoff);
48 backoff = next_backoff(backoff);
49 continue;
50 }
51 return Err(AppError::Network(format!("failed to read response: {err}")));
52 }
53 }
54 }
55 Err(err) => {
56 if is_retryable_error(&err) && attempt < MAX_RETRIES {
57 sleep(backoff);
58 backoff = next_backoff(backoff);
59 continue;
60 }
61 return Err(AppError::Network(format!("request failed: {err}")));
62 }
63 }
64 }
65}
66
67pub fn parse_json(bytes: &[u8]) -> Result<Value, AppError> {
68 serde_json::from_slice(bytes).map_err(|err| AppError::Json(format!("invalid JSON: {err}")))
69}
70
71fn build_headers(raw_headers: &[String]) -> Result<HeaderMap, AppError> {
72 let mut headers = HeaderMap::new();
73 headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
74 headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));
75
76 for raw in raw_headers {
77 let (name, value) = parse_header(raw)?;
78 headers.insert(name, value);
79 }
80 Ok(headers)
81}
82
83fn parse_header(raw: &str) -> Result<(HeaderName, HeaderValue), AppError> {
84 let mut split = raw.splitn(2, ':');
85 let name = split
86 .next()
87 .map(str::trim)
88 .filter(|value| !value.is_empty())
89 .ok_or_else(|| AppError::Usage(format!("invalid header format: {raw}")))?;
90 let value = split
91 .next()
92 .map(str::trim)
93 .ok_or_else(|| AppError::Usage(format!("invalid header format: {raw}")))?;
94 let header_name = HeaderName::from_bytes(name.as_bytes())
95 .map_err(|_| AppError::Usage(format!("invalid header name: {name}")))?;
96 let header_value = HeaderValue::from_str(value)
97 .map_err(|_| AppError::Usage(format!("invalid header value for: {name}")))?;
98 Ok((header_name, header_value))
99}
100
101fn is_retryable_error(err: &reqwest::Error) -> bool {
102 err.is_timeout() || err.is_connect() || err.is_body()
103}
104
105fn should_retry_status(status: reqwest::StatusCode) -> bool {
106 status.as_u16() == 429 || status.is_server_error()
107}
108
109fn next_backoff(current: u64) -> u64 {
110 (current.saturating_mul(2)).min(MAX_BACKOFF_MS)
111}
112
/// Blocks the current thread for `duration_ms` milliseconds.
fn sleep(duration_ms: u64) {
    let pause = Duration::from_millis(duration_ms);
    thread::sleep(pause);
}
116
117fn body_snippet(body: String) -> String {
118 let trimmed = body.trim();
119 if trimmed.is_empty() {
120 return String::from("<empty body>");
121 }
122 let snippet: String = trimmed.chars().take(ERROR_SNIPPET_LIMIT).collect();
123 if snippet.len() < trimmed.len() {
124 format!("{snippet}…")
125 } else {
126 snippet
127 }
128}
129
// Tests that drive `fetch_openapi` against a local `httpmock` server.
// Tests whose mock always answers with a retryable status (5xx) go through every attempt.
// They therefore incur the real backoff sleeps between attempts (100 ms, then 200 ms).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::OutputProfile;
    use crate::config::Config;
    use httpmock::prelude::*;

    // Builds a `Config` pointing at `url` with no custom headers and a 5 s timeout.
    // The remaining fields are fixed values that `fetch_openapi` does not read.
    fn base_config(url: String) -> Config {
        Config {
            url,
            url_from_default: false,
            out: None,
            outline_out: None,
            reduce: Vec::new(),
            profile: OutputProfile::Full,
            minify: false,
            timeout_ms: 5_000,
            headers: Vec::new(),
            stdout: true,
        }
    }

    // The mock only matches when all three headers are present. A single successful hit
    // therefore proves that both defaults (`accept`, `user-agent`) and the custom
    // `Authorization` header were sent.
    #[test]
    fn fetch_includes_default_and_custom_headers() {
        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(GET)
                .path("/openapi.json")
                .header("accept", "application/json")
                .header("user-agent", USER_AGENT)
                .header("authorization", "Bearer token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"openapi":"3.0.3","paths":{},"components":{}}"#);
        });

        let mut config = base_config(server.url("/openapi.json"));
        config
            .headers
            .push("Authorization: Bearer token".to_string());

        let bytes = fetch_openapi(&config).unwrap();
        let value: Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(value["openapi"], serde_json::json!("3.0.3"));
        mock.assert_hits(1);
    }

    // The first mock's custom matcher only matches while CALL_COUNT is still 0, because
    // `fetch_add` returns the previous value. That should make the first request get a 500
    // and a retry reach the 200 mock.
    // NOTE(review): this assumes httpmock evaluates the matcher about once per request and
    // prefers the first-declared matching mock. The assertions use `>=` rather than exact
    // counts, presumably to tolerate extra matcher evaluations — TODO confirm.
    #[test]
    fn retries_on_server_error_then_succeeds() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
        // Reset defensively so the test does not depend on the static's prior state.
        CALL_COUNT.store(0, Ordering::SeqCst);

        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET)
                .path("/openapi.json")
                .matches(|_| CALL_COUNT.fetch_add(1, Ordering::SeqCst) < 1);
            then.status(500).body("temporary");
        });

        let success = server.mock(|when, then| {
            when.method(GET).path("/openapi.json");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"openapi":"3.0.3","paths":{},"components":{}}"#);
        });

        let config = base_config(server.url("/openapi.json"));
        let bytes = fetch_openapi(&config).unwrap();
        let value: Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(value["openapi"], serde_json::json!("3.0.3"));
        assert!(CALL_COUNT.load(Ordering::SeqCst) >= 2);
        assert!(success.hits() >= 1);
    }

    // 502 is a server error and therefore retried. The message checked here is the one
    // produced after the last attempt.
    #[test]
    fn fetch_surfaces_status_and_body_snippet() {
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(GET).path("/openapi.json");
            then.status(502).body("gateway down");
        });

        let config = base_config(server.url("/openapi.json"));
        let err = fetch_openapi(&config).unwrap_err();
        match err {
            AppError::Network(msg) => {
                assert!(msg.contains("502"));
                assert!(msg.contains("gateway down"));
            }
            other => panic!("expected network error, got {other:?}"),
        }
    }

    // Checks the formatted (`Display`) error rather than the variant payload.
    // Exactly MAX_RETRIES hits shows that MAX_RETRIES counts total attempts, not extra
    // retries.
    #[test]
    fn returns_error_with_status_and_snippet_when_retries_exhausted() {
        let server = MockServer::start();
        let fail = server.mock(|when, then| {
            when.method(GET).path("/openapi.json");
            then.status(500).body("server exploded");
        });

        let config = base_config(server.url("/openapi.json"));
        let err = fetch_openapi(&config).unwrap_err();
        let message = format!("{err}");
        assert!(message.contains("HTTP 500"));
        assert!(message.contains("server exploded"));
        fail.assert_hits(MAX_RETRIES);
    }

    // Same attempt-count check as above, using 503.
    #[test]
    fn stops_after_max_retries_and_returns_error() {
        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(GET).path("/openapi.json");
            then.status(503).body("down");
        });

        let config = base_config(server.url("/openapi.json"));
        let err = fetch_openapi(&config).unwrap_err();
        assert!(format!("{err}").contains("HTTP 503"));
        mock.assert_hits(MAX_RETRIES);
    }

    // 400 is not a retryable status, so exactly one request is expected.
    #[test]
    fn error_includes_body_snippet() {
        let server = MockServer::start();
        let mock = server.mock(|when, then| {
            when.method(GET).path("/openapi.json");
            then.status(400).body("something went wrong in backend");
        });

        let config = base_config(server.url("/openapi.json"));
        let err = fetch_openapi(&config).unwrap_err();
        match err {
            AppError::Network(msg) => {
                assert!(msg.contains("400"));
                assert!(msg.contains("something went wrong in backend"));
            }
            other => panic!("expected network error, got {other:?}"),
        }
        mock.assert_hits(1);
    }
}
276}