1use std::time::Duration;
2
3use backoff::{backoff::Backoff, ExponentialBackoff};
4use reqwest::{Client, RequestBuilder, Response, StatusCode};
5use tokio::time::timeout;
6use tracing::{debug, instrument};
7
8use crate::{
9 api::OdosApiErrorResponse,
10 error::{OdosError, Result},
11 error_code::OdosErrorCode,
12};
13
14#[derive(Debug, Clone)]
44pub struct RetryConfig {
45 pub max_retries: u32,
47
48 pub initial_backoff_ms: u64,
50
51 pub retry_server_errors: bool,
53
54 pub retry_predicate: Option<fn(&OdosError) -> bool>,
59}
60
61impl Default for RetryConfig {
62 fn default() -> Self {
63 Self {
64 max_retries: 3,
65 initial_backoff_ms: 100,
66 retry_server_errors: true,
67 retry_predicate: None,
68 }
69 }
70}
71
72impl RetryConfig {
73 pub fn no_retries() -> Self {
78 Self {
79 max_retries: 0,
80 ..Default::default()
81 }
82 }
83
84 pub fn conservative() -> Self {
90 Self {
91 max_retries: 2,
92 retry_server_errors: false,
93 ..Default::default()
94 }
95 }
96}
97
98#[derive(Debug, Clone)]
102pub struct ClientConfig {
103 pub timeout: Duration,
105 pub connect_timeout: Duration,
107 pub retry_config: RetryConfig,
109 pub max_connections: usize,
111 pub pool_idle_timeout: Duration,
113}
114
115impl Default for ClientConfig {
116 fn default() -> Self {
117 Self {
118 timeout: Duration::from_secs(30),
119 connect_timeout: Duration::from_secs(10),
120 retry_config: RetryConfig::default(),
121 max_connections: 20,
122 pool_idle_timeout: Duration::from_secs(90),
123 }
124 }
125}
126
127impl ClientConfig {
128 pub fn no_retries() -> Self {
132 Self {
133 retry_config: RetryConfig::no_retries(),
134 ..Default::default()
135 }
136 }
137
138 pub fn conservative() -> Self {
142 Self {
143 retry_config: RetryConfig::conservative(),
144 ..Default::default()
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
151pub struct OdosHttpClient {
152 client: Client,
153 config: ClientConfig,
154}
155
156impl OdosHttpClient {
157 pub fn new() -> Result<Self> {
159 Self::with_config(ClientConfig::default())
160 }
161
162 pub fn with_config(config: ClientConfig) -> Result<Self> {
164 let client = Client::builder()
165 .timeout(config.timeout)
166 .connect_timeout(config.connect_timeout)
167 .pool_max_idle_per_host(config.max_connections)
168 .pool_idle_timeout(config.pool_idle_timeout)
169 .build()
170 .map_err(OdosError::Http)?;
171
172 Ok(Self { client, config })
173 }
174
175 #[instrument(skip(self, request_builder_fn), level = "debug")]
177 pub async fn execute_with_retry<F>(&self, request_builder_fn: F) -> Result<Response>
178 where
179 F: Fn() -> RequestBuilder + Clone,
180 {
181 let initial_backoff_duration =
182 Duration::from_millis(self.config.retry_config.initial_backoff_ms);
183 let mut backoff = ExponentialBackoff {
184 initial_interval: initial_backoff_duration,
185 max_interval: Duration::from_secs(30), max_elapsed_time: Some(self.config.timeout),
187 ..Default::default()
188 };
189
190 let mut attempt = 0;
191
192 loop {
193 attempt += 1;
194
195 let request = match request_builder_fn().build() {
196 Ok(req) => req,
197 Err(e) => return Err(OdosError::Http(e)),
198 };
199
200 let last_error = match timeout(self.config.timeout, self.client.execute(request)).await
201 {
202 Ok(Ok(response)) if response.status().is_success() => {
203 return Ok(response);
204 }
205 Ok(Ok(response)) => {
206 let status = response.status();
207
208 if status == StatusCode::TOO_MANY_REQUESTS {
209 let retry_after = extract_retry_after(&response);
210
211 let (message, _code, _trace_id) = parse_error_response(response).await;
213
214 let error =
215 OdosError::rate_limit_error_with_retry_after(message, retry_after);
216
217 if !self.should_retry(&error, attempt) {
219 return Err(error);
220 }
221
222 if let Some(delay) = retry_after {
223 if !delay.is_zero() {
225 debug!(attempt, retry_after_secs = delay.as_secs());
226 tokio::time::sleep(delay).await;
227 continue;
228 }
229 }
230 error
231 } else {
232 let (message, code, trace_id) = parse_error_response(response).await;
234
235 let error = OdosError::api_error_with_code(status, message, code, trace_id);
236
237 if !self.should_retry(&error, attempt) {
238 return Err(error);
239 }
240
241 error
242 }
243 }
244 Ok(Err(e)) => {
245 let error = OdosError::Http(e);
246
247 if !self.should_retry(&error, attempt) {
248 return Err(error);
249 }
250 debug!(attempt, error = %error);
251 error
252 }
253 Err(_) => {
254 let error = OdosError::timeout_error("Request timed out");
255 debug!(attempt, timeout = ?self.config.timeout);
256 error
257 }
258 };
259
260 if attempt >= self.config.retry_config.max_retries {
262 return Err(last_error);
263 }
264
265 if let Some(delay) = backoff.next_backoff() {
266 tokio::time::sleep(delay).await;
267 } else {
268 return Err(last_error);
269 }
270 }
271 }
272
273 pub fn inner(&self) -> &Client {
275 &self.client
276 }
277
278 pub fn config(&self) -> &ClientConfig {
280 &self.config
281 }
282
283 fn should_retry(&self, error: &OdosError, attempts: u32) -> bool {
301 let retry_config = &self.config.retry_config;
302
303 if attempts >= retry_config.max_retries {
305 return false;
306 }
307
308 if let Some(predicate) = retry_config.retry_predicate {
310 return predicate(error);
311 }
312
313 match error {
315 OdosError::RateLimit { .. } => false,
317
318 OdosError::Api { status, .. } if status.is_client_error() => false,
320
321 OdosError::Api { status, .. } if status.is_server_error() => {
323 retry_config.retry_server_errors
324 }
325
326 OdosError::Http(err) => err.is_timeout() || err.is_connect() || err.is_request(),
328
329 OdosError::Timeout(_) => true,
331
332 _ => false,
334 }
335 }
336}
337
338fn extract_retry_after(response: &Response) -> Option<Duration> {
340 response
341 .headers()
342 .get("retry-after")
343 .and_then(|v| v.to_str().ok())
344 .and_then(|s| s.parse::<u64>().ok())
345 .map(Duration::from_secs)
346}
347
348async fn parse_error_response(
354 response: Response,
355) -> (
356 String,
357 Option<OdosErrorCode>,
358 Option<crate::error_code::TraceId>,
359) {
360 let body_text = match response.text().await {
362 Ok(text) => text,
363 Err(e) => return (format!("Failed to read response body: {}", e), None, None),
364 };
365
366 match serde_json::from_str::<OdosApiErrorResponse>(&body_text) {
368 Ok(error_response) => {
369 let error_code = OdosErrorCode::from(error_response.error_code);
371 (
372 error_response.detail,
373 Some(error_code),
374 Some(error_response.trace_id),
375 )
376 }
377 Err(_) => {
378 (body_text, None, None)
380 }
381 }
382}
383
384impl Default for OdosHttpClient {
385 fn default() -> Self {
397 Self::new().expect("Failed to create default HTTP client")
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    /// Responder that answers `first_status` with `first_body` while fewer than
    /// `success_after` requests have arrived, and `200 Success` from then on.
    fn create_retry_mock(
        first_status: u16,
        first_body: String,
        success_after: usize,
    ) -> impl Fn(&Request) -> ResponseTemplate {
        let attempt_count = Arc::new(Mutex::new(0));
        move |_req: &Request| {
            let mut count = attempt_count.lock().unwrap();
            *count += 1;

            if *count < success_after {
                ResponseTemplate::new(first_status).set_body_string(&first_body)
            } else {
                ResponseTemplate::new(200).set_body_string("Success")
            }
        }
    }

    /// Client with a 10 ms initial backoff and the given retry budget and timeout.
    fn create_test_client(max_retries: u32, timeout_ms: u64) -> OdosHttpClient {
        let config = ClientConfig {
            timeout: Duration::from_millis(timeout_ms),
            retry_config: RetryConfig {
                max_retries,
                initial_backoff_ms: 10,
                ..Default::default()
            },
            ..Default::default()
        };
        OdosHttpClient::with_config(config).unwrap()
    }

    /// Sends `GET {server}/test` through `execute_with_retry`.
    async fn get_test_endpoint(client: &OdosHttpClient, server: &MockServer) -> Result<Response> {
        let url = format!("{}/test", server.uri());
        client
            .execute_with_retry(|| client.inner().get(url.as_str()))
            .await
    }

    /// Asserts `response` is a rate-limit error whose message mentions the limit and
    /// whose `retry_after` equals `expected`.
    fn assert_rate_limited(response: Result<Response>, expected: Option<Duration>) {
        match response {
            Err(OdosError::RateLimit {
                message,
                retry_after,
            }) => {
                assert!(message.contains("Rate limit"));
                assert_eq!(retry_after, expected);
            }
            other => panic!("Expected RateLimit error, got: {other:?}"),
        }
    }

    /// Builds a bare `429` response carrying the given raw `Retry-After` bytes.
    fn response_with_retry_after(value: &[u8]) -> reqwest::Response {
        reqwest::Response::from(
            http::Response::builder()
                .status(429)
                .header("retry-after", value)
                .body("")
                .unwrap(),
        )
    }

    /// Serves `status` until the `success_after`-th request and asserts that
    /// `execute_with_retry` keeps retrying until it receives the `200`.
    async fn assert_recovers_from(status: u16, body: &str, success_after: usize) {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(create_retry_mock(status, body.to_string(), success_after))
            .mount(&mock_server)
            .await;

        let client = create_test_client(3, 30000);
        let response = get_test_endpoint(&client, &mock_server).await;

        assert!(
            response.is_ok(),
            "{status} error should be retried and succeed, got: {response:?}"
        );
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(config.timeout, Duration::from_secs(30));
        assert_eq!(config.retry_config.max_retries, 3);
        assert_eq!(config.max_connections, 20);
    }

    #[test]
    fn test_retry_config_presets() {
        let none = RetryConfig::no_retries();
        assert_eq!(none.max_retries, 0);
        assert_eq!(none.initial_backoff_ms, 100);
        assert!(none.retry_server_errors);
        assert!(none.retry_predicate.is_none());

        let conservative = RetryConfig::conservative();
        assert_eq!(conservative.max_retries, 2);
        assert!(!conservative.retry_server_errors);

        assert_eq!(ClientConfig::no_retries().retry_config.max_retries, 0);
        assert_eq!(ClientConfig::conservative().retry_config.max_retries, 2);
    }

    #[tokio::test]
    async fn test_client_creation() {
        assert!(OdosHttpClient::new().is_ok());
    }

    #[tokio::test]
    async fn test_client_with_custom_config() {
        let config = ClientConfig {
            timeout: Duration::from_secs(60),
            retry_config: RetryConfig {
                max_retries: 5,
                ..Default::default()
            },
            ..Default::default()
        };

        let client = OdosHttpClient::with_config(config).expect("custom config should build");
        assert_eq!(client.config().timeout, Duration::from_secs(60));
        assert_eq!(client.config().retry_config.max_retries, 5);
    }

    #[tokio::test]
    async fn test_rate_limit_with_retry_after() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(
                ResponseTemplate::new(429)
                    .set_body_string("Rate limit exceeded")
                    .insert_header("retry-after", "1"),
            )
            // Rate limits are not retried by default: exactly one request.
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = create_test_client(3, 30000);
        let response = get_test_endpoint(&client, &mock_server).await;

        assert_rate_limited(response, Some(Duration::from_secs(1)));
    }

    #[tokio::test]
    async fn test_rate_limit_without_retry_after() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(429).set_body_string("Rate limit exceeded"))
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = create_test_client(3, 30000);
        let response = get_test_endpoint(&client, &mock_server).await;

        assert_rate_limited(response, None);
    }

    #[tokio::test]
    async fn test_rate_limit_with_retry_after_zero() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(
                ResponseTemplate::new(429)
                    .set_body_string("Rate limit exceeded")
                    .insert_header("retry-after", "0"),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = create_test_client(3, 30000);
        let response = get_test_endpoint(&client, &mock_server).await;

        assert_rate_limited(response, Some(Duration::from_secs(0)));
    }

    #[tokio::test]
    async fn test_non_retryable_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(400).set_body_string("Bad request"))
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = OdosHttpClient::with_config(ClientConfig::default()).unwrap();
        let error = get_test_endpoint(&client, &mock_server)
            .await
            .expect_err("a 400 response must be returned as an error");

        assert!(!error.is_retryable());
    }

    #[tokio::test]
    async fn test_retry_predicate_overrides_default_policy() {
        fn never_retry(_: &OdosError) -> bool {
            false
        }

        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Server error"))
            // The predicate vetoes the retry a 5xx would normally get.
            .expect(1)
            .mount(&mock_server)
            .await;

        let config = ClientConfig {
            retry_config: RetryConfig {
                max_retries: 3,
                initial_backoff_ms: 10,
                retry_predicate: Some(never_retry as fn(&OdosError) -> bool),
                ..Default::default()
            },
            ..Default::default()
        };
        let client = OdosHttpClient::with_config(config).unwrap();

        let error = get_test_endpoint(&client, &mock_server)
            .await
            .expect_err("a vetoed 500 must be returned as an error");
        assert!(
            matches!(error, OdosError::Api { status, .. } if status == StatusCode::INTERNAL_SERVER_ERROR)
        );
    }

    #[tokio::test]
    async fn test_retry_exhaustion_returns_last_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(503).set_body_string("Service unavailable"))
            .mount(&mock_server)
            .await;

        let client = create_test_client(2, 30000);
        let error = get_test_endpoint(&client, &mock_server)
            .await
            .expect_err("a persistent 503 must eventually fail");

        assert!(
            matches!(error, OdosError::Api { status, .. } if status == StatusCode::SERVICE_UNAVAILABLE),
            "Expected 503 API error, got: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_timeout_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("Success")
                    .set_delay(Duration::from_secs(5)),
            )
            .mount(&mock_server)
            .await;

        let client = create_test_client(2, 100);
        let error = get_test_endpoint(&client, &mock_server)
            .await
            .expect_err("a response slower than the timeout must fail");

        // Either the reqwest timeout or the outer tokio timeout may fire first.
        let is_timeout = matches!(error, OdosError::Timeout(_))
            || matches!(error, OdosError::Http(ref err) if err.is_timeout());
        assert!(is_timeout, "Expected timeout error, got: {error:?}");
    }

    #[tokio::test]
    async fn test_invalid_request_builder_fails_immediately() {
        let client = OdosHttpClient::default();

        // A 100 000-byte header name is expected to make `RequestBuilder::build` fail.
        let bad_builder = || {
            client
                .inner()
                .get("http://localhost")
                .header("x".repeat(100000), "value")
        };

        let error = client
            .execute_with_retry(bad_builder)
            .await
            .expect_err("an invalid request must not be sent");
        assert!(matches!(error, OdosError::Http(_)));
    }

    #[tokio::test]
    async fn test_retryable_500_error() {
        assert_recovers_from(500, "Internal server error", 2).await;
    }

    #[tokio::test]
    async fn test_retryable_502_bad_gateway() {
        assert_recovers_from(502, "Bad gateway", 2).await;
    }

    #[tokio::test]
    async fn test_retryable_503_service_unavailable() {
        assert_recovers_from(503, "Service unavailable", 3).await;
    }

    #[tokio::test]
    async fn test_retryable_504_gateway_timeout() {
        assert_recovers_from(504, "Gateway timeout", 2).await;
    }

    #[tokio::test]
    async fn test_network_error_retryable() {
        let client = create_test_client(2, 100);

        let error = client
            .execute_with_retry(|| client.inner().get("http://localhost:1"))
            .await
            .expect_err("nothing listens on port 1");
        assert!(matches!(error, OdosError::Http(_)));
    }

    #[test]
    fn test_accessor_methods() {
        let config = ClientConfig {
            timeout: Duration::from_secs(45),
            retry_config: RetryConfig {
                max_retries: 5,
                ..Default::default()
            },
            ..Default::default()
        };
        let client = OdosHttpClient::with_config(config).unwrap();

        assert_eq!(client.config().timeout, Duration::from_secs(45));
        assert_eq!(client.config().retry_config.max_retries, 5);

        let _inner: &reqwest::Client = client.inner();
    }

    #[test]
    fn test_default_client() {
        let client = OdosHttpClient::default();

        assert_eq!(client.config().timeout, Duration::from_secs(30));
        assert_eq!(client.config().retry_config.max_retries, 3);
    }

    #[test]
    fn test_extract_retry_after_valid_numeric() {
        let response = response_with_retry_after(b"30");
        assert_eq!(extract_retry_after(&response), Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_extract_retry_after_missing_header() {
        let response =
            reqwest::Response::from(http::Response::builder().status(429).body("").unwrap());
        assert_eq!(extract_retry_after(&response), None);
    }

    #[test]
    fn test_extract_retry_after_malformed_value() {
        let response = response_with_retry_after(b"not-a-number");
        assert_eq!(extract_retry_after(&response), None);
    }

    #[test]
    fn test_extract_retry_after_zero_value() {
        let response = response_with_retry_after(b"0");
        assert_eq!(extract_retry_after(&response), Some(Duration::from_secs(0)));
    }

    #[test]
    fn test_extract_retry_after_large_value() {
        let response = response_with_retry_after(b"3600");
        assert_eq!(extract_retry_after(&response), Some(Duration::from_secs(3600)));
    }

    #[test]
    fn test_extract_retry_after_invalid_utf8() {
        let response = response_with_retry_after(&[0xff, 0xfe]);
        assert_eq!(extract_retry_after(&response), None);
    }

    #[tokio::test]
    async fn test_max_retries_zero() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Server error"))
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = create_test_client(0, 30000);
        let error = get_test_endpoint(&client, &mock_server)
            .await
            .expect_err("a 500 without retries must fail");

        assert!(
            matches!(error, OdosError::Api { status, .. } if status == StatusCode::INTERNAL_SERVER_ERROR)
        );
    }

    #[tokio::test]
    async fn test_parse_structured_error_response() {
        use crate::error_code::OdosErrorCode;

        let error_json = r#"{
            "detail": "Error getting quote, please try again",
            "traceId": "10becdc8-a021-4491-8201-a17b657204e0",
            "errorCode": 2999
        }"#;

        let http_response = http::Response::builder()
            .status(500)
            .body(error_json)
            .unwrap();
        let response = reqwest::Response::from(http_response);

        let (message, code, trace_id) = parse_error_response(response).await;

        assert_eq!(message, "Error getting quote, please try again");
        assert_eq!(code, Some(OdosErrorCode::AlgoInternal));
        assert_eq!(
            trace_id.expect("trace id should be parsed").to_string(),
            "10becdc8-a021-4491-8201-a17b657204e0"
        );
    }

    #[tokio::test]
    async fn test_parse_unstructured_error_response() {
        let http_response = http::Response::builder()
            .status(500)
            .body("Internal server error")
            .unwrap();
        let response = reqwest::Response::from(http_response);

        let (message, code, trace_id) = parse_error_response(response).await;

        assert_eq!(message, "Internal server error");
        assert!(code.is_none());
        assert!(trace_id.is_none());
    }

    #[tokio::test]
    async fn test_api_error_with_structured_response() {
        let mock_server = MockServer::start().await;

        let error_json = r#"{
            "detail": "Invalid chain ID",
            "traceId": "a0b1c2d3-e4f5-6789-0abc-def123456789",
            "errorCode": 4001
        }"#;

        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(400).set_body_string(error_json))
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = create_test_client(0, 30000);
        let error = get_test_endpoint(&client, &mock_server)
            .await
            .expect_err("a 400 response must be returned as an error");

        assert!(matches!(error, OdosError::Api { .. }));
        let error_code = error
            .error_code()
            .expect("structured body should yield an error code");
        assert!(error_code.is_invalid_chain_id());
        assert!(error.trace_id().is_some());
    }

    #[tokio::test]
    async fn test_client_config_failure() {
        // An extreme idle-pool size must either be accepted verbatim or surface as
        // an `Http` error; it must never be silently altered.
        let config = ClientConfig {
            max_connections: usize::MAX,
            ..Default::default()
        };

        match OdosHttpClient::with_config(config) {
            Ok(client) => assert_eq!(client.config().max_connections, usize::MAX),
            Err(e) => assert!(matches!(e, OdosError::Http(_))),
        }
    }
}