1use crate::error::{Error, ErrorBody, Result};
9use reqwest::{header::HeaderMap, Method, RequestBuilder, Response, StatusCode};
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12use std::time::Duration;
13
/// Ceiling for any single pause between retry attempts, whether the pause was
/// computed by exponential backoff or taken from a server `retry-after` hint.
pub(crate) const MAX_BACKOFF: Duration = Duration::from_millis(10_000);

/// Request header through which the API key is sent.
pub(crate) const API_KEY_HEADER: &str = "X-API-KEY";
/// Response header parsed into `RateLimitInfo::remaining`.
pub(crate) const RATE_LIMIT_REMAINING_HDR: &str = "x-ratelimit-remaining";
/// Response header parsed into `RateLimitInfo::limit`.
pub(crate) const RATE_LIMIT_LIMIT_HDR: &str = "x-ratelimit-limit";
/// Response header parsed into `RateLimitInfo::reset_in`.
pub(crate) const RATE_LIMIT_RESET_HDR: &str = "x-ratelimit-reset";
/// Response header copied verbatim into `RateLimitInfo::limit_type` and into
/// rate-limit errors.
pub(crate) const RATE_LIMIT_TYPE_HDR: &str = "x-ratelimit-type";
/// Standard `Retry-After` response header (only the integer-seconds form is used).
pub(crate) const RETRY_AFTER_HDR: &str = "retry-after";
22
23#[derive(Debug, Clone, Default, PartialEq, Eq)]
29pub struct RateLimitInfo {
30 pub remaining: Option<i64>,
32 pub limit: Option<i64>,
34 pub reset_in: Option<i64>,
36 pub retry_after: Option<i64>,
38 pub limit_type: Option<String>,
40}
41
42impl RateLimitInfo {
43 pub(crate) fn from_headers(h: &HeaderMap) -> Self {
44 Self {
45 remaining: parse_int_header(h, RATE_LIMIT_REMAINING_HDR),
46 limit: parse_int_header(h, RATE_LIMIT_LIMIT_HDR),
47 reset_in: parse_int_header(h, RATE_LIMIT_RESET_HDR),
48 retry_after: parse_int_header(h, RETRY_AFTER_HDR),
49 limit_type: h
50 .get(RATE_LIMIT_TYPE_HDR)
51 .and_then(|v| v.to_str().ok())
52 .map(str::to_string),
53 }
54 }
55}
56
57fn parse_int_header(h: &HeaderMap, key: &str) -> Option<i64> {
58 h.get(key)
59 .and_then(|v| v.to_str().ok())
60 .map(str::trim)
61 .and_then(|s| s.parse::<i64>().ok())
62}
63
64pub(crate) fn parse_retry_after(h: &HeaderMap) -> Duration {
68 let Some(raw) = h.get(RETRY_AFTER_HDR).and_then(|v| v.to_str().ok()) else {
69 return Duration::ZERO;
70 };
71 let raw = raw.trim();
72 if let Ok(secs) = raw.parse::<u64>() {
73 return Duration::from_secs(secs).min(MAX_BACKOFF);
74 }
75 Duration::ZERO
79}
80
81pub(crate) fn extract_validation_message(body: Option<&Value>) -> String {
88 const FALLBACK: &str = "invalid request parameters";
89 let Some(Value::Object(map)) = body else {
90 return FALLBACK.to_string();
91 };
92 for key in ["detail", "message", "error"] {
93 if let Some(Value::String(s)) = map.get(key) {
94 if !s.is_empty() {
95 return format!("invalid request parameters: {s}");
96 }
97 }
98 }
99 let mut keys: Vec<&String> = map.keys().collect();
101 keys.sort();
102 for k in keys {
103 match map.get(k) {
104 Some(Value::Array(arr)) if !arr.is_empty() => {
105 if let Some(Value::String(s)) = arr.first() {
106 if !s.is_empty() {
107 return format!("invalid request parameters: {s}");
108 }
109 }
110 }
111 Some(Value::String(s)) if !s.is_empty() => {
112 return format!("invalid request parameters: {s}");
113 }
114 _ => {}
115 }
116 }
117 FALLBACK.to_string()
118}
119
/// Request payload accepted by [`send_with_retries`].
#[derive(Debug)]
pub(crate) enum Body<'a> {
    /// No request body is sent.
    None,
    /// A JSON value. `send_with_retries` serializes it once, and every attempt
    /// sends it with `Content-Type: application/json`.
    Json(&'a Value),
}
127
128pub(crate) async fn send_with_retries(
129 inner: &crate::client::ClientInner,
130 method: Method,
131 url: reqwest::Url,
132 body: Body<'_>,
133) -> Result<Vec<u8>> {
134 let max_attempts = inner.retries.saturating_add(1);
135 let mut attempt: u32 = 0;
136
137 let body_bytes = match body {
139 Body::None => None,
140 Body::Json(v) => Some(serde_json::to_vec(v)?),
141 };
142
143 loop {
144 let err =
145 match attempt_once(inner, method.clone(), url.clone(), body_bytes.as_deref()).await {
146 Ok(bytes) => return Ok(bytes),
147 Err(e) => e,
148 };
149
150 if !err.is_retryable() || attempt + 1 >= max_attempts {
151 return Err(err);
152 }
153
154 let wait = if let Error::RateLimit { retry_after, .. } = &err {
155 let r = u64::from(*retry_after);
156 if r > 0 {
157 Duration::from_secs(r).min(MAX_BACKOFF)
158 } else {
159 backoff_for(inner.retry_backoff, attempt)
160 }
161 } else {
162 backoff_for(inner.retry_backoff, attempt)
163 };
164
165 tokio::time::sleep(wait).await;
166 attempt += 1;
167 }
168}
169
170fn backoff_for(base: Duration, attempt: u32) -> Duration {
171 let mult = 1u32 << attempt.min(6); base.saturating_mul(mult).min(MAX_BACKOFF)
173}
174
175async fn attempt_once(
176 inner: &crate::client::ClientInner,
177 method: Method,
178 url: reqwest::Url,
179 body_bytes: Option<&[u8]>,
180) -> Result<Vec<u8>> {
181 let mut req: RequestBuilder = inner.http.request(method, url);
182 req = req.header(reqwest::header::ACCEPT, "application/json");
183 if !inner.api_key.is_empty() {
184 req = req.header(API_KEY_HEADER, &inner.api_key);
185 }
186 if !inner.user_agent.is_empty() {
187 req = req.header(reqwest::header::USER_AGENT, &inner.user_agent);
188 }
189 if let Some(bytes) = body_bytes {
190 req = req
191 .header(reqwest::header::CONTENT_TYPE, "application/json")
192 .body(bytes.to_vec());
193 }
194 if !inner.timeout.is_zero() {
195 req = req.timeout(inner.timeout);
196 }
197
198 let resp_result = req.send().await;
199 let resp: Response = match resp_result {
200 Ok(r) => r,
201 Err(e) => {
202 if e.is_timeout() {
203 return Err(Error::Timeout {
204 timeout: inner.timeout,
205 });
206 }
207 return Err(Error::Transport(e));
208 }
209 };
210
211 let status = resp.status();
212 let headers = resp.headers().clone();
213 inner.set_last_response(&headers);
215
216 let bytes = match resp.bytes().await {
217 Ok(b) => b.to_vec(),
218 Err(e) => return Err(Error::Transport(e)),
219 };
220
221 if status.is_success() {
222 return Ok(bytes);
223 }
224
225 Err(decode_error(status, &headers, &bytes))
226}
227
228fn decode_error(status: StatusCode, headers: &HeaderMap, body: &[u8]) -> Error {
229 let parsed_value: Option<Value> = if body.is_empty() {
230 None
231 } else {
232 serde_json::from_slice(body).ok()
233 };
234 let body_message = parsed_value.as_ref().and_then(extract_top_level_message);
235 let response = parsed_value.as_ref().map(|v| ErrorBody {
236 message: body_message.clone().unwrap_or_default(),
237 raw: Some(v.clone()),
238 });
239 match status.as_u16() {
240 401 => Error::Auth { response },
241 404 => Error::NotFound { response },
242 400 => Error::Validation {
243 message: extract_validation_message(parsed_value.as_ref()),
244 response,
245 },
246 429 => {
247 let retry_after = parse_retry_after(headers).as_secs();
248 Error::RateLimit {
249 retry_after: u32::try_from(retry_after).unwrap_or(u32::MAX),
250 limit_type: headers
251 .get(RATE_LIMIT_TYPE_HDR)
252 .and_then(|v| v.to_str().ok())
253 .map(str::to_string),
254 response,
255 }
256 }
257 code => Error::Api {
258 status: code,
259 message: body_message
260 .unwrap_or_else(|| format!("API request failed with status {code}")),
261 response,
262 },
263 }
264}
265
266fn extract_top_level_message(v: &Value) -> Option<String> {
267 let Value::Object(map) = v else { return None };
268 for key in ["detail", "message", "error"] {
269 if let Some(Value::String(s)) = map.get(key) {
270 if !s.is_empty() {
271 return Some(s.clone());
272 }
273 }
274 }
275 None
276}
277
278pub(crate) fn decode_json<T: DeserializeOwned>(bytes: &[u8]) -> Result<T> {
281 serde_json::from_slice(bytes).map_err(Error::Decode)
282}
283
284pub(crate) fn decode_json_or_default<T>(bytes: &[u8]) -> Result<T>
287where
288 T: DeserializeOwned + Default,
289{
290 if bytes.is_empty() {
291 return Ok(T::default());
292 }
293 decode_json(bytes)
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderValue;
    use serde_json::json;

    #[test]
    fn extract_envelope_detail() {
        let body = json!({"detail": "no soup for you"});
        assert_eq!(
            extract_validation_message(Some(&body)),
            "invalid request parameters: no soup for you"
        );
    }

    #[test]
    fn extract_envelope_message() {
        let body = json!({"message": "bad input"});
        assert_eq!(
            extract_validation_message(Some(&body)),
            "invalid request parameters: bad input"
        );
    }

    #[test]
    fn extract_field_errors_sorted() {
        let body = json!({
            "zebra": ["last alphabetically"],
            "apple": ["first alphabetically"],
        });
        // Field errors are visited in sorted key order, whatever the map's order.
        assert_eq!(
            extract_validation_message(Some(&body)),
            "invalid request parameters: first alphabetically"
        );
    }

    #[test]
    fn extract_falls_back() {
        let body = json!({});
        assert_eq!(
            extract_validation_message(Some(&body)),
            "invalid request parameters"
        );
        assert_eq!(
            extract_validation_message(None),
            "invalid request parameters"
        );
    }

    #[test]
    fn extract_prefers_envelope_over_field() {
        let body = json!({
            "detail": "envelope wins",
            "apple": ["field loses"],
        });
        assert_eq!(
            extract_validation_message(Some(&body)),
            "invalid request parameters: envelope wins"
        );
    }

    #[test]
    fn extract_string_field_error() {
        let body = json!({"piid": "must be present"});
        assert_eq!(
            extract_validation_message(Some(&body)),
            "invalid request parameters: must be present"
        );
    }

    #[test]
    fn top_level_message_skips_empty_and_non_objects() {
        assert_eq!(
            extract_top_level_message(&json!({"detail": "", "message": "m"})),
            Some("m".to_string())
        );
        assert_eq!(extract_top_level_message(&json!({"error": 5})), None);
        assert_eq!(extract_top_level_message(&json!(["detail"])), None);
    }

    #[test]
    fn retry_after_parses_seconds_and_caps() {
        let mut h = HeaderMap::new();
        assert_eq!(parse_retry_after(&h), Duration::ZERO);

        h.insert(RETRY_AFTER_HDR, HeaderValue::from_static("3"));
        assert_eq!(parse_retry_after(&h), Duration::from_secs(3));

        h.insert(RETRY_AFTER_HDR, HeaderValue::from_static("3600"));
        assert_eq!(parse_retry_after(&h), MAX_BACKOFF);

        // HTTP-date form is not supported and yields zero.
        h.insert(
            RETRY_AFTER_HDR,
            HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT"),
        );
        assert_eq!(parse_retry_after(&h), Duration::ZERO);
    }

    #[test]
    fn rate_limit_info_from_headers() {
        let mut h = HeaderMap::new();
        h.insert(RATE_LIMIT_REMAINING_HDR, HeaderValue::from_static("7"));
        h.insert(RATE_LIMIT_LIMIT_HDR, HeaderValue::from_static("100"));
        h.insert(RATE_LIMIT_RESET_HDR, HeaderValue::from_static("not-a-number"));
        h.insert(RATE_LIMIT_TYPE_HDR, HeaderValue::from_static("burst"));
        assert_eq!(
            RateLimitInfo::from_headers(&h),
            RateLimitInfo {
                remaining: Some(7),
                limit: Some(100),
                reset_in: None,
                retry_after: None,
                limit_type: Some("burst".to_string()),
            }
        );
        assert_eq!(
            RateLimitInfo::from_headers(&HeaderMap::new()),
            RateLimitInfo::default()
        );
    }

    #[test]
    fn decode_or_default_handles_empty_and_invalid() {
        let empty = decode_json_or_default::<Vec<i32>>(b"");
        assert!(matches!(empty.as_deref(), Ok([])));

        let full = decode_json_or_default::<Vec<i32>>(b"[1,2]");
        assert!(matches!(full.as_deref(), Ok([1, 2])));

        let bad = decode_json_or_default::<Vec<i32>>(b"not json");
        assert!(matches!(bad, Err(Error::Decode(_))));
    }

    #[test]
    fn backoff_doubles_and_caps() {
        let base = Duration::from_millis(250);
        assert_eq!(backoff_for(base, 0), Duration::from_millis(250));
        assert_eq!(backoff_for(base, 1), Duration::from_millis(500));
        assert_eq!(backoff_for(base, 2), Duration::from_secs(1));
        assert_eq!(backoff_for(base, 3), Duration::from_secs(2));
        assert_eq!(backoff_for(base, 4), Duration::from_secs(4));
        assert_eq!(backoff_for(base, 5), Duration::from_secs(8));
        assert_eq!(backoff_for(base, 6), MAX_BACKOFF);
        assert_eq!(backoff_for(base, 50), MAX_BACKOFF);
    }
}