1use core::future::Future;
42
43use api_bones::{ApiError, RateLimitInfo};
44use reqwest::{RequestBuilder, Response};
45
/// Extension methods for a request builder that attach headers commonly used
/// by the API: a request identifier, an idempotency key and a bearer token.
///
/// Every method consumes the builder and returns it, so calls can be chained.
pub trait RequestBuilderExt: Sized {
    /// Attaches a request identifier to the outgoing request.
    ///
    /// The `reqwest` implementation in this module sends it as `x-request-id`.
    #[must_use]
    fn with_request_id(self, id: impl AsRef<str>) -> Self;

    /// Attaches an idempotency key to the outgoing request.
    ///
    /// The `reqwest` implementation in this module sends it as `idempotency-key`.
    #[must_use]
    fn with_idempotency_key(self, key: impl AsRef<str>) -> Self;

    /// Authenticates the request with a bearer token.
    ///
    /// The `reqwest` implementation in this module sends
    /// `authorization: Bearer <token>`.
    #[must_use]
    fn with_bearer_token(self, token: impl AsRef<str>) -> Self;
}
94
95impl RequestBuilderExt for RequestBuilder {
96 fn with_request_id(self, id: impl AsRef<str>) -> Self {
97 self.header("x-request-id", id.as_ref())
98 }
99
100 fn with_idempotency_key(self, key: impl AsRef<str>) -> Self {
101 self.header("idempotency-key", key.as_ref())
102 }
103
104 fn with_bearer_token(self, token: impl AsRef<str>) -> Self {
105 self.header("authorization", format!("Bearer {}", token.as_ref()))
106 }
107}
108
/// Extension methods for inspecting and decoding API responses.
pub trait ResponseExt {
    /// Extracts rate-limit metadata from the response headers.
    ///
    /// Returns `None` unless `x-ratelimit-limit`, `x-ratelimit-remaining` and
    /// `x-ratelimit-reset` are all present and parse as integers.
    /// `retry-after` is optional and is reported when it parses as an integer.
    #[must_use]
    fn rate_limit_info(&self) -> Option<RateLimitInfo>;

    /// Returns the target URL of the `rel="next"` entry in the `Link`
    /// header(s), or `None` when there is no such entry.
    #[must_use]
    fn next_page_url(&self) -> Option<String>;

    /// Consumes the response and deserializes its JSON body as `T`.
    ///
    /// For client- or server-error statuses (4xx/5xx) the body is not
    /// decoded as `T`. An [`ApiError`] is produced instead, built from the
    /// `application/problem+json` body when the response declares one.
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] for 4xx/5xx responses and when the body cannot
    /// be read or deserialized.
    fn problem_json_or_json<T: serde::de::DeserializeOwned>(
        self,
    ) -> impl Future<Output = Result<T, ApiError>> + Send;
}
191
192impl ResponseExt for Response {
193 fn rate_limit_info(&self) -> Option<RateLimitInfo> {
194 let parse = |name: &str| -> Option<u64> {
195 self.headers()
196 .get(name)
197 .and_then(|v| v.to_str().ok())
198 .and_then(|s| s.parse().ok())
199 };
200 let limit = parse("x-ratelimit-limit")?;
201 let remaining = parse("x-ratelimit-remaining")?;
202 let reset = parse("x-ratelimit-reset")?;
203 let retry_after = parse("retry-after");
204 Some(RateLimitInfo {
205 limit,
206 remaining,
207 reset,
208 retry_after,
209 })
210 }
211
212 fn next_page_url(&self) -> Option<String> {
213 for link_val in self.headers().get_all("link") {
214 let Ok(s) = link_val.to_str() else {
215 continue;
216 };
217 for entry in s.split(',') {
218 let entry = entry.trim();
219 if let Some(url) = parse_link_next(entry) {
220 return Some(url);
221 }
222 }
223 }
224 None
225 }
226
227 async fn problem_json_or_json<T: serde::de::DeserializeOwned>(self) -> Result<T, ApiError> {
228 let status = self.status();
229 if status.is_client_error() || status.is_server_error() {
230 let is_problem = self
231 .headers()
232 .get("content-type")
233 .and_then(|v| v.to_str().ok())
234 .is_some_and(|ct| ct.contains("application/problem+json"));
235
236 if is_problem {
237 let body: serde_json::Value = self
238 .json()
239 .await
240 .map_err(|e| ApiError::bad_request(e.to_string()))?;
241 let detail = body
242 .get("detail")
243 .and_then(|v| v.as_str())
244 .unwrap_or("unknown error")
245 .to_owned();
246 let code_status = body
247 .get("status")
248 .and_then(serde_json::Value::as_u64)
249 .and_then(|s| u16::try_from(s).ok())
250 .unwrap_or(status.as_u16());
251 return Err(map_status_to_api_error(code_status, detail));
252 }
253
254 return Err(ApiError::new(
255 api_bones::ErrorCode::InternalServerError,
256 format!("HTTP {}", status.as_u16()),
257 ));
258 }
259
260 self.json::<T>()
261 .await
262 .map_err(|e| ApiError::bad_request(e.to_string()))
263 }
264}
265
/// Parses a single link-value from a `Link` header. Returns its target URI
/// if the link's relation types include `next`.
///
/// The entry must look like `<URI>; param=value; ...`. Following RFC 8288:
/// * the URI is everything between `<` and the first `>`, so it may itself
///   contain `;` (e.g. matrix parameters);
/// * parameter names are matched case-insensitively, with optional whitespace
///   around `=`;
/// * the `rel` value may be a bare token or a quoted string holding a
///   whitespace-separated list of relation types, each compared
///   case-insensitively (e.g. `rel="next last"`).
///
/// Returns `None` when:
/// * the entry is malformed (missing `<`/`>`, or stray text between `>` and
///   the first `;`), or
/// * there is no `next` relation.
fn parse_link_next(entry: &str) -> Option<String> {
    let (url, params) = entry.trim().strip_prefix('<')?.split_once('>')?;

    let mut params = params.split(';');
    // Only whitespace may sit between the closing `>` and the first `;`.
    if !params.next()?.trim().is_empty() {
        return None;
    }

    let is_next = params.any(|param| {
        let Some((name, value)) = param.split_once('=') else {
            return false;
        };
        if !name.trim().eq_ignore_ascii_case("rel") {
            return false;
        }
        let value = value.trim();
        let value = value
            .strip_prefix('"')
            .and_then(|v| v.strip_suffix('"'))
            .unwrap_or(value);
        value
            .split_ascii_whitespace()
            .any(|rel| rel.eq_ignore_ascii_case("next"))
    });

    is_next.then(|| url.to_owned())
}
284
285fn map_status_to_api_error(status: u16, detail: String) -> ApiError {
286 use api_bones::ErrorCode;
287 let code = match status {
288 400 => ErrorCode::BadRequest,
289 401 => ErrorCode::Unauthorized,
290 403 => ErrorCode::Forbidden,
291 404 => ErrorCode::ResourceNotFound,
292 409 => ErrorCode::Conflict,
293 422 => ErrorCode::UnprocessableEntity,
294 429 => ErrorCode::RateLimited,
295 500 => ErrorCode::InternalServerError,
296 502 => ErrorCode::BadGateway,
297 503 => ErrorCode::ServiceUnavailable,
298 504 => ErrorCode::GatewayTimeout,
299 _ if status >= 500 => ErrorCode::InternalServerError,
300 _ => ErrorCode::BadRequest,
301 };
302 ApiError::new(code, detail)
303}
304
305pub async fn from_response(resp: reqwest::Response) -> api_bones::ApiError {
317 use api_bones::{ApiError, ErrorCode, ProblemJson};
318
319 let http_status = resp.status().as_u16();
320
321 let is_problem_json = resp
322 .headers()
323 .get(reqwest::header::CONTENT_TYPE)
324 .and_then(|v| v.to_str().ok())
325 .is_some_and(|ct| ct.contains("application/problem+json"));
326
327 if is_problem_json {
328 match resp.json::<ProblemJson>().await {
329 Ok(p) => {
330 let code = ErrorCode::from_type_uri(&p.r#type)
331 .unwrap_or_else(|| map_status_to_api_error(http_status, String::new()).code);
332 let mut err = ApiError::new(code, p.detail);
333 err.title = p.title;
334 err.status = p.status;
335 if let Some(inst) = p.instance {
336 #[cfg(feature = "uuid")]
337 if let Some(hex) = inst.strip_prefix("urn:uuid:")
338 && let Ok(id) = hex.parse::<uuid::Uuid>()
339 {
340 err.request_id = Some(id);
341 }
342 #[cfg(not(feature = "uuid"))]
343 let _ = inst;
344 }
345 err.extensions = p.extensions;
346 err
347 }
348 Err(_) => ApiError::new(
349 map_status_to_api_error(http_status, String::new()).code,
350 "failed to parse problem+json response",
351 ),
352 }
353 } else {
354 let detail = resp
355 .text()
356 .await
357 .unwrap_or_else(|_| "upstream error".to_owned());
358 map_status_to_api_error(http_status, detail)
359 }
360}
361
#[cfg(test)]
// NOTE(review): presumably allowed because tests deliberately keep mock handles
// (`_mock`) and servers alive until the end of each test — confirm.
#[allow(clippy::significant_drop_tightening)]
mod tests {
    //! Unit tests for the request/response extension traits, `Link` parsing and
    //! HTTP-status → `ApiError` mapping.
    //!
    //! Most HTTP interactions use `mockito` mock servers. Two tests instead write
    //! raw HTTP/1.1 responses over a TCP socket: one with a non-UTF-8 `Link`
    //! value, one with a body shorter than its `Content-Length`.
    use super::*;

    #[test]
    fn map_status_401() {
        let err = map_status_to_api_error(401, "unauth".into());
        assert_eq!(err.status, 401);
    }

    #[test]
    fn map_status_403() {
        let err = map_status_to_api_error(403, "forbidden".into());
        assert_eq!(err.status, 403);
    }

    #[test]
    fn map_status_409() {
        let err = map_status_to_api_error(409, "conflict".into());
        assert_eq!(err.status, 409);
    }

    #[test]
    fn map_status_422() {
        let err = map_status_to_api_error(422, "unprocessable".into());
        assert_eq!(err.status, 422);
    }

    #[test]
    fn map_status_429() {
        let err = map_status_to_api_error(429, "rate limited".into());
        assert_eq!(err.status, 429);
    }

    #[test]
    fn map_status_500() {
        let err = map_status_to_api_error(500, "ise".into());
        assert_eq!(err.status, 500);
    }

    #[test]
    fn map_status_502() {
        let err = map_status_to_api_error(502, "bad gateway".into());
        assert_eq!(err.status, 502);
    }

    #[test]
    fn map_status_503() {
        let err = map_status_to_api_error(503, "unavailable".into());
        assert_eq!(err.status, 503);
    }

    #[test]
    fn map_status_504() {
        let err = map_status_to_api_error(504, "timeout".into());
        assert_eq!(err.status, 504);
    }

    // The mock only matches when the header is present, and `assert_async`
    // fails the test if the mock was never hit.
    #[tokio::test]
    async fn request_builder_with_request_id() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/")
            .match_header("x-request-id", "req-abc")
            .with_status(200)
            .with_body("{}")
            .create_async()
            .await;

        let client = reqwest::Client::new();
        let resp = client
            .get(server.url())
            .with_request_id("req-abc")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status().as_u16(), 200);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn request_builder_with_idempotency_key() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/")
            .match_header("idempotency-key", "idem-123")
            .with_status(201)
            .with_body("{}")
            .create_async()
            .await;

        let client = reqwest::Client::new();
        let resp = client
            .post(server.url())
            .with_idempotency_key("idem-123")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status().as_u16(), 201);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn request_builder_with_bearer_token() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/")
            .match_header("authorization", "Bearer my.token")
            .with_status(200)
            .with_body("{}")
            .create_async()
            .await;

        let client = reqwest::Client::new();
        let resp = client
            .get(server.url())
            .with_bearer_token("my.token")
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status().as_u16(), 200);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn rate_limit_info_present() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("x-ratelimit-limit", "100")
            .with_header("x-ratelimit-remaining", "42")
            .with_header("x-ratelimit-reset", "1700000000")
            .with_header("retry-after", "5")
            .with_body("{}")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let rl = resp.rate_limit_info().unwrap();
        assert_eq!(rl.limit, 100);
        assert_eq!(rl.remaining, 42);
        assert_eq!(rl.reset, 1_700_000_000);
        assert_eq!(rl.retry_after, Some(5));
    }

    #[tokio::test]
    async fn rate_limit_info_missing_headers_returns_none() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_body("{}")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        assert!(resp.rate_limit_info().is_none());
    }

    #[tokio::test]
    async fn rate_limit_info_without_retry_after() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("x-ratelimit-limit", "50")
            .with_header("x-ratelimit-remaining", "10")
            .with_header("x-ratelimit-reset", "9999")
            .with_body("{}")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let rl = resp.rate_limit_info().unwrap();
        assert_eq!(rl.retry_after, None);
    }

    #[tokio::test]
    async fn next_page_url_present() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header(
                "link",
                r#"<https://api.example.com/items?after=xyz>; rel="next""#,
            )
            .with_body("[]")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        assert_eq!(
            resp.next_page_url(),
            Some("https://api.example.com/items?after=xyz".to_owned())
        );
    }

    #[tokio::test]
    async fn next_page_url_absent() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_body("[]")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        assert!(resp.next_page_url().is_none());
    }

    #[tokio::test]
    async fn problem_json_or_json_success() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"value": 42}"#)
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let body: serde_json::Value = resp.problem_json_or_json().await.unwrap();
        assert_eq!(body["value"], 42);
    }

    #[tokio::test]
    async fn problem_json_or_json_problem_response() {
        let mut server = mockito::Server::new_async().await;
        let problem_body =
            r#"{"type":"about:blank","title":"Not Found","status":404,"detail":"item missing"}"#;
        let _mock = server
            .mock("GET", "/")
            .with_status(404)
            .with_header("content-type", "application/problem+json")
            .with_body(problem_body)
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err: api_bones::ApiError = resp
            .problem_json_or_json::<serde_json::Value>()
            .await
            .unwrap_err();
        assert_eq!(err.status, 404);
    }

    #[tokio::test]
    async fn problem_json_or_json_non_problem_error_response() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(500)
            .with_header("content-type", "text/plain")
            .with_body("Internal Server Error")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err: api_bones::ApiError = resp
            .problem_json_or_json::<serde_json::Value>()
            .await
            .unwrap_err();
        assert_eq!(err.status, 500);
    }

    // Statuses without a dedicated mapping below 500 fall back to BadRequest.
    #[test]
    fn map_status_418_defaults_to_bad_request() {
        let err = map_status_to_api_error(418, "teapot".into());
        assert_eq!(err.status, 400);
    }

    // A hand-written server sends a `Link` value containing the byte 0xFF,
    // which `HeaderValue::to_str` rejects, so the header must be skipped.
    // `if let Ok` keeps the test passing if the client refuses the response
    // outright.
    #[tokio::test]
    async fn next_page_url_non_utf8_link_header_is_skipped() {
        use tokio::io::AsyncWriteExt;
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4096];
            let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf).await;
            let response: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nLink: \xff\r\n\r\n[]";
            let _ = stream.write_all(response).await;
        });

        let url = format!("http://{addr}/");
        if let Ok(resp) = reqwest::get(&url).await {
            assert!(resp.next_page_url().is_none());
        }
    }

    #[tokio::test]
    async fn next_page_url_with_only_prev_link() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header(
                "link",
                r#"<https://api.example.com/items?before=abc>; rel="prev""#,
            )
            .with_body("[]")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        assert!(resp.next_page_url().is_none());
    }

    #[test]
    fn parse_link_next_empty_entry_returns_none() {
        assert!(parse_link_next("").is_none());
    }

    #[test]
    fn parse_link_next_malformed_url_no_closing_angle_returns_none() {
        let entry = "<https://example.com; rel=\"next\"";
        assert!(parse_link_next(entry).is_none());
    }

    // An undecodable problem+json body surfaces as `ApiError::bad_request` (400).
    #[tokio::test]
    async fn problem_json_or_json_problem_response_invalid_json_body() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(404)
            .with_header("content-type", "application/problem+json")
            .with_body("not json at all")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err: api_bones::ApiError = resp
            .problem_json_or_json::<serde_json::Value>()
            .await
            .unwrap_err();
        assert_eq!(err.status, 400);
    }

    // An undecodable success body is likewise reported as `ApiError::bad_request` (400).
    #[tokio::test]
    async fn problem_json_or_json_success_invalid_json_body() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not json")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err: api_bones::ApiError = resp
            .problem_json_or_json::<serde_json::Value>()
            .await
            .unwrap_err();
        assert_eq!(err.status, 400);
    }

    #[test]
    fn parse_link_next_basic() {
        let entry = r#"<https://api.example.com/items?after=abc>; rel="next""#;
        assert_eq!(
            parse_link_next(entry),
            Some("https://api.example.com/items?after=abc".to_owned())
        );
    }

    #[test]
    fn parse_link_next_no_match() {
        let entry = r#"<https://api.example.com/items?before=abc>; rel="prev""#;
        assert!(parse_link_next(entry).is_none());
    }

    #[test]
    fn parse_link_next_unquoted_rel() {
        let entry = "<https://example.com/next>; rel=next";
        assert_eq!(
            parse_link_next(entry),
            Some("https://example.com/next".to_owned())
        );
    }

    #[test]
    fn map_status_400() {
        let err = map_status_to_api_error(400, "bad".into());
        assert_eq!(err.status, 400);
    }

    #[test]
    fn map_status_404() {
        let err = map_status_to_api_error(404, "not found".into());
        assert_eq!(err.status, 404);
    }

    // Unlisted statuses of 500 and above collapse to InternalServerError.
    #[test]
    fn map_status_unknown_5xx() {
        let err = map_status_to_api_error(599, "oops".into());
        assert_eq!(err.status, 500);
    }

    #[tokio::test]
    async fn from_response_parses_problem_json() {
        let mut server = mockito::Server::new_async().await;
        let body = r#"{"type":"urn:api-bones:error:resource-not-found","title":"Not Found","status":404,"detail":"gone","extensions":{}}"#;
        let _mock = server
            .mock("GET", "/")
            .with_status(404)
            .with_header("content-type", "application/problem+json")
            .with_body(body)
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err = from_response(resp).await;
        assert_eq!(err.status, 404);
        assert_eq!(err.detail, "gone");
    }

    #[tokio::test]
    async fn from_response_plain_text_fallback() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(503)
            .with_header("content-type", "text/plain")
            .with_body("service down")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err = from_response(resp).await;
        assert_eq!(err.status, 503);
        assert_eq!(err.detail, "service down");
    }

    #[tokio::test]
    async fn from_response_unknown_type_uri_falls_back_to_status_code() {
        let body = r#"{"type":"urn:unknown:error:whatever","title":"Oops","status":422,"detail":"bad input","extensions":{}}"#;
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(422)
            .with_header("content-type", "application/problem+json")
            .with_body(body)
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err = from_response(resp).await;
        assert_eq!(err.status, 422);
        assert_eq!(err.detail, "bad input");
    }

    #[tokio::test]
    async fn from_response_problem_json_parse_error_fallback() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(400)
            .with_header("content-type", "application/problem+json")
            .with_body("not valid json")
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err = from_response(resp).await;
        assert_eq!(err.status, 400);
        assert_eq!(err.detail, "failed to parse problem+json response");
    }

    #[cfg(feature = "uuid")]
    #[tokio::test]
    async fn from_response_extracts_uuid_instance() {
        let id = uuid::Uuid::nil();
        let body = format!(
            r#"{{"type":"urn:api-bones:error:bad-request","title":"Bad Request","status":400,"detail":"bad","instance":"urn:uuid:{id}","extensions":{{}}}}"#
        );
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/")
            .with_status(400)
            .with_header("content-type", "application/problem+json")
            .with_body(body)
            .create_async()
            .await;

        let resp = reqwest::get(server.url()).await.unwrap();
        let err = from_response(resp).await;
        assert_eq!(err.request_id, Some(id));
    }

    // The hand-written server advertises `content-length: 100` but sends no body,
    // then drops the stream when the thread ends. `text()` therefore fails and the
    // "upstream error" fallback detail is used.
    #[tokio::test]
    async fn from_response_text_read_error_falls_back_to_upstream_error() {
        use std::io::Write as _;
        use std::net::TcpListener as StdTcpListener;

        let std_listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
        std_listener.set_nonblocking(false).unwrap();
        let addr = std_listener.local_addr().unwrap();

        std::thread::spawn(move || {
            use std::io::Read as _;
            let (mut stream, _) = std_listener.accept().unwrap();
            let mut buf = [0u8; 4096];
            let _ = stream.read(&mut buf);
            stream
                .write_all(b"HTTP/1.1 500 Internal Server Error\r\ncontent-length: 100\r\n\r\n")
                .unwrap();
        });

        let resp = reqwest::get(format!("http://{addr}/")).await.unwrap();
        let err = from_response(resp).await;
        assert_eq!(err.status, 500);
        assert_eq!(err.detail, "upstream error");
    }
}