1use crate::config::{
2 ClientAuthConfig, HttpClientConfig, RedirectConfig, RetryConfig, TlsConfig, TlsRootConfig,
3 TlsVersion, TransportSecurity,
4};
5use crate::error::HttpError;
6use crate::layers::{OtelLayer, RetryLayer, SecureRedirectPolicy, UserAgentLayer};
7use crate::response::ResponseBody;
8use crate::tls;
9use bytes::Bytes;
10use http::Response;
11use http_body_util::{BodyExt, Full};
12use hyper_rustls::HttpsConnector;
13use hyper_util::client::legacy::Client;
14use hyper_util::client::legacy::connect::HttpConnector;
15use hyper_util::rt::{TokioExecutor, TokioTimer};
16use std::time::Duration;
17use tower::buffer::Buffer;
18use tower::limit::ConcurrencyLimitLayer;
19use tower::load_shed::LoadShedLayer;
20use tower::timeout::TimeoutLayer;
21use tower::util::BoxCloneService;
22use tower::{ServiceBuilder, ServiceExt};
23use tower_http::decompression::DecompressionLayer;
24use tower_http::follow_redirect::FollowRedirectLayer;
25
26type InnerService =
28 BoxCloneService<http::Request<Full<Bytes>>, http::Response<ResponseBody>, HttpError>;
29
30pub struct HttpClientBuilder {
32 config: HttpClientConfig,
33 auth_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
34 metrics_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
35}
36
37impl HttpClientBuilder {
38 #[must_use]
40 pub fn new() -> Self {
41 Self {
42 config: HttpClientConfig::default(),
43 auth_layer: None,
44 metrics_layer: None,
45 }
46 }
47
48 #[must_use]
50 pub fn with_config(config: HttpClientConfig) -> Self {
51 Self {
52 config,
53 auth_layer: None,
54 metrics_layer: None,
55 }
56 }
57
58 #[must_use]
63 pub fn timeout(mut self, timeout: Duration) -> Self {
64 self.config.request_timeout = timeout;
65 self
66 }
67
68 #[must_use]
74 pub fn total_timeout(mut self, timeout: Duration) -> Self {
75 self.config.total_timeout = Some(timeout);
76 self
77 }
78
79 #[must_use]
81 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
82 self.config.user_agent = user_agent.into();
83 self
84 }
85
86 #[must_use]
88 pub fn retry(mut self, retry: Option<RetryConfig>) -> Self {
89 self.config.retry = retry;
90 self
91 }
92
93 #[must_use]
95 pub fn max_body_size(mut self, size: usize) -> Self {
96 self.config.max_body_size = size;
97 self
98 }
99
100 #[must_use]
104 pub fn transport(mut self, transport: TransportSecurity) -> Self {
105 self.config.transport = transport;
106 self
107 }
108
109 #[must_use]
115 pub fn deny_insecure_http(mut self) -> Self {
116 tracing::debug!(
117 target: "toolkit_http::security",
118 "deny_insecure_http() called - enforcing TLS for all connections"
119 );
120 self.config.transport = TransportSecurity::TlsOnly;
121 self
122 }
123
124 #[must_use]
137 pub fn tls(mut self, tls: TlsConfig) -> Self {
138 self.config.tls = tls;
139 self
140 }
141
142 #[must_use]
150 pub fn tls_min_version(mut self, min_version: TlsVersion) -> Self {
151 self.config.tls.min_version = min_version;
152 self
153 }
154
155 #[must_use]
167 pub fn client_auth(mut self, client_auth: ClientAuthConfig) -> Self {
168 self.config.tls.client_auth = Some(client_auth);
169 self
170 }
171
172 #[must_use]
177 pub fn with_otel(mut self) -> Self {
178 self.config.otel = true;
179 self
180 }
181
182 #[must_use]
190 pub fn with_auth_layer(
191 mut self,
192 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
193 ) -> Self {
194 self.auth_layer = Some(Box::new(wrap));
195 self
196 }
197
198 #[must_use]
210 pub fn with_metrics_layer(
211 mut self,
212 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
213 ) -> Self {
214 self.metrics_layer = Some(Box::new(wrap));
215 self
216 }
217
218 #[must_use]
227 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
228 self.config.buffer_capacity = capacity.max(1);
230 self
231 }
232
233 #[must_use]
238 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
239 self.config.redirect.max_redirects = max_redirects;
240 self
241 }
242
243 #[must_use]
248 pub fn no_redirects(mut self) -> Self {
249 self.config.redirect = RedirectConfig::disabled();
250 self
251 }
252
253 #[must_use]
268 pub fn redirect(mut self, config: RedirectConfig) -> Self {
269 self.config.redirect = config;
270 self
271 }
272
273 #[must_use]
280 pub fn pool_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
281 self.config.pool_idle_timeout = timeout;
282 self
283 }
284
285 #[must_use]
293 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
294 self.config.pool_max_idle_per_host = max;
295 self
296 }
297
298 pub fn build(self) -> Result<crate::HttpClient, HttpError> {
307 #[cfg(feature = "fips")]
312 if self.config.transport == TransportSecurity::AllowInsecureHttp {
313 tracing::warn!(
314 target: "toolkit_http::security",
315 "rejecting AllowInsecureHttp under --features fips: returning HttpError::InsecureTransport"
316 );
317 return Err(HttpError::InsecureTransport);
318 }
319
320 let timeout = self.config.request_timeout;
321 let total_timeout = self.config.total_timeout;
322
323 let https = build_https_connector(
325 self.config.tls_roots,
326 self.config.transport,
327 &self.config.tls,
328 )?;
329
330 let mut client_builder = Client::builder(TokioExecutor::new());
332
333 client_builder
336 .pool_timer(TokioTimer::new())
337 .pool_max_idle_per_host(self.config.pool_max_idle_per_host)
338 .http2_only(false); if let Some(idle_timeout) = self.config.pool_idle_timeout {
342 client_builder.pool_idle_timeout(idle_timeout);
343 }
344
345 let hyper_client = client_builder.build::<_, Full<Bytes>>(https);
346
347 let ua_layer = UserAgentLayer::try_new(&self.config.user_agent)?;
349
350 let redirect_policy = SecureRedirectPolicy::new(self.config.redirect.clone());
385
386 let service = ServiceBuilder::new()
388 .layer(TimeoutLayer::new(timeout))
389 .layer(ua_layer)
390 .layer(DecompressionLayer::new())
391 .layer(FollowRedirectLayer::with_policy(redirect_policy))
392 .service(hyper_client);
393
394 let service = service.map_response(map_decompression_response);
400
401 let service = service.map_err(move |e: tower::BoxError| map_tower_error(e, timeout));
403
404 let mut boxed_service = service.boxed_clone();
406
407 if let Some(wrap) = self.auth_layer {
410 boxed_service = wrap(boxed_service);
411 }
412
413 if let Some(ref retry_config) = self.config.retry {
425 let retry_layer = RetryLayer::with_total_timeout(retry_config.clone(), total_timeout);
426 let retry_service = ServiceBuilder::new()
427 .layer(retry_layer)
428 .service(boxed_service);
429 boxed_service = retry_service.boxed_clone();
430 }
431
432 if let Some(wrap) = self.metrics_layer {
435 boxed_service = wrap(boxed_service);
436 }
437
438 if let Some(rate_limit) = self.config.rate_limit
442 && rate_limit.max_concurrent_requests < usize::MAX
443 {
444 let limited_service = ServiceBuilder::new()
445 .layer(LoadShedLayer::new())
446 .layer(ConcurrencyLimitLayer::new(
447 rate_limit.max_concurrent_requests,
448 ))
449 .service(boxed_service);
450 let limited_service = limited_service.map_err(map_load_shed_error);
452 boxed_service = limited_service.boxed_clone();
453 }
454
455 if self.config.otel {
459 let otel_service = ServiceBuilder::new()
460 .layer(OtelLayer::new())
461 .service(boxed_service);
462 boxed_service = otel_service.boxed_clone();
463 }
464
465 let buffer_capacity = self.config.buffer_capacity.max(1);
469 let buffered_service: crate::client::BufferedService =
470 Buffer::new(boxed_service, buffer_capacity);
471
472 Ok(crate::HttpClient {
473 service: buffered_service,
474 max_body_size: self.config.max_body_size,
475 transport_security: self.config.transport,
476 })
477 }
478}
479
480#[cfg(test)]
481impl HttpClientBuilder {
482 fn build_with_inner_service(self, inner: InnerService) -> crate::HttpClient {
490 let mut boxed_service = inner;
491
492 if let Some(ref retry_config) = self.config.retry {
493 let retry_layer =
494 RetryLayer::with_total_timeout(retry_config.clone(), self.config.total_timeout);
495 let retry_service = ServiceBuilder::new()
496 .layer(retry_layer)
497 .service(boxed_service);
498 boxed_service = retry_service.boxed_clone();
499 }
500
501 if let Some(rate_limit) = self.config.rate_limit
502 && rate_limit.max_concurrent_requests < usize::MAX
503 {
504 let limited_service = ServiceBuilder::new()
505 .layer(LoadShedLayer::new())
506 .layer(ConcurrencyLimitLayer::new(
507 rate_limit.max_concurrent_requests,
508 ))
509 .service(boxed_service);
510 let limited_service = limited_service.map_err(map_load_shed_error);
511 boxed_service = limited_service.boxed_clone();
512 }
513
514 let buffer_capacity = self.config.buffer_capacity.max(1);
515 let buffered_service: crate::client::BufferedService =
516 Buffer::new(boxed_service, buffer_capacity);
517
518 crate::HttpClient {
519 service: buffered_service,
520 max_body_size: self.config.max_body_size,
521 transport_security: self.config.transport,
522 }
523 }
524}
525
526impl Default for HttpClientBuilder {
527 fn default() -> Self {
528 Self::new()
529 }
530}
531
532fn map_tower_error(err: tower::BoxError, timeout: Duration) -> HttpError {
538 if err.is::<tower::timeout::error::Elapsed>() {
539 return HttpError::Timeout(timeout);
540 }
541
542 match err.downcast::<HttpError>() {
544 Ok(http_err) => *http_err,
545 Err(other) => HttpError::Transport(other),
546 }
547}
548
549fn map_load_shed_error(err: tower::BoxError) -> HttpError {
551 if err.is::<tower::load_shed::error::Overloaded>() {
552 HttpError::Overloaded
553 } else {
554 match err.downcast::<HttpError>() {
556 Ok(http_err) => *http_err,
557 Err(err) => HttpError::Transport(err),
558 }
559 }
560}
561
562fn map_decompression_response<B>(response: Response<B>) -> Response<ResponseBody>
567where
568 B: hyper::body::Body<Data = Bytes> + Send + Sync + 'static,
569 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
570{
571 let (parts, body) = response.into_parts();
572 let boxed_body: ResponseBody = body.map_err(Into::into).boxed();
576 Response::from_parts(parts, boxed_body)
577}
578
579fn build_https_connector(
593 tls_roots: TlsRootConfig,
594 transport: TransportSecurity,
595 tls: &TlsConfig,
596) -> Result<HttpsConnector<HttpConnector>, HttpError> {
597 let allow_http = transport == TransportSecurity::AllowInsecureHttp;
598
599 let client_config = match tls_roots {
607 TlsRootConfig::WebPki => tls::webpki_roots_client_config(tls),
608 TlsRootConfig::Native => tls::native_roots_client_config(tls),
609 }
610 .map_err(|e| HttpError::Tls(Box::new(e)))?;
611
612 let builder = hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(client_config);
613 let connector = if allow_http {
614 builder.https_or_http().enable_all_versions().build()
615 } else {
616 builder.https_only().enable_all_versions().build()
617 };
618 Ok(connector)
619}
620
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::config::DEFAULT_USER_AGENT;

    // `new()` starts from `HttpClientConfig::default()`; pin the defaults.
    #[test]
    fn test_builder_default() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.request_timeout, Duration::from_secs(30));
        assert_eq!(builder.config.user_agent, DEFAULT_USER_AGENT);
        assert!(builder.config.retry.is_some());
        assert_eq!(builder.config.buffer_capacity, 1024);
    }

    // `with_config` must keep the supplied configuration as-is.
    #[test]
    fn test_builder_with_config() {
        let config = HttpClientConfig::minimal();
        let builder = HttpClientBuilder::with_config(config);
        assert_eq!(builder.config.request_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_timeout() {
        let builder = HttpClientBuilder::new().timeout(Duration::from_mins(1));
        assert_eq!(builder.config.request_timeout, Duration::from_mins(1));
    }

    #[test]
    fn test_builder_user_agent() {
        let builder = HttpClientBuilder::new().user_agent("custom/1.0");
        assert_eq!(builder.config.user_agent, "custom/1.0");
    }

    #[test]
    fn test_builder_retry() {
        let builder = HttpClientBuilder::new().retry(None);
        assert!(builder.config.retry.is_none());
    }

    #[test]
    fn test_builder_max_body_size() {
        let builder = HttpClientBuilder::new().max_body_size(1024);
        assert_eq!(builder.config.max_body_size, 1024);
    }

    // Both explicit TLS-only setters agree. The default transport differs
    // between FIPS and non-FIPS builds.
    #[test]
    fn test_builder_transport_security() {
        let builder = HttpClientBuilder::new().transport(TransportSecurity::TlsOnly);
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new().deny_insecure_http();
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new();
        #[cfg(not(feature = "fips"))]
        assert_eq!(
            builder.config.transport,
            TransportSecurity::AllowInsecureHttp
        );
        #[cfg(feature = "fips")]
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);
    }

    #[test]
    fn test_builder_otel() {
        let builder = HttpClientBuilder::new().with_otel();
        assert!(builder.config.otel);
    }

    #[test]
    fn test_builder_buffer_capacity() {
        let builder = HttpClientBuilder::new().buffer_capacity(512);
        assert_eq!(builder.config.buffer_capacity, 512);
    }

    // The setter clamps a zero capacity to 1.
    #[test]
    fn test_builder_buffer_capacity_zero_clamped() {
        let builder = HttpClientBuilder::new().buffer_capacity(0);
        assert_eq!(
            builder.config.buffer_capacity, 1,
            "buffer_capacity=0 should be clamped to 1"
        );
    }

    // A zero capacity that bypasses the setter (via `with_config`) is clamped in
    // `build()`.
    #[tokio::test]
    async fn test_builder_buffer_capacity_zero_in_config_clamped() {
        let config = HttpClientConfig {
            buffer_capacity: 0,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();
        assert!(
            result.is_ok(),
            "build() should succeed with capacity clamped to 1"
        );
    }

    #[tokio::test]
    async fn test_builder_build_with_otel() {
        let client = HttpClientBuilder::new().with_otel().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_with_auth_layer() {
        let client = HttpClientBuilder::new()
            .with_auth_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_with_metrics_layer() {
        let client = HttpClientBuilder::new()
            .with_metrics_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    // A second `with_metrics_layer` call replaces the first hook: the first
    // must never run, and the second must run exactly once.
    #[tokio::test]
    async fn test_builder_with_metrics_layer_second_call_replaces_first() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count2 = call_count.clone();

        let client = HttpClientBuilder::new()
            .with_metrics_layer(|_svc| {
                panic!("first metrics layer should have been replaced");
            })
            .with_metrics_layer(move |svc| {
                call_count2.fetch_add(1, Ordering::SeqCst);
                svc
            })
            .build();

        assert!(client.is_ok());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "second metrics layer must be applied exactly once"
        );
    }

    #[tokio::test]
    async fn test_builder_build() {
        let client = HttpClientBuilder::new().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_deny_insecure_http() {
        let client = HttpClientBuilder::new().deny_insecure_http().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_sse_config() {
        use crate::config::HttpClientConfig;
        let config = HttpClientConfig::sse();
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok(), "SSE config should build successfully");
    }

    // A NUL byte cannot appear in a header value, so `build()` must fail.
    #[tokio::test]
    async fn test_builder_build_invalid_user_agent() {
        let client = HttpClientBuilder::new()
            .user_agent("invalid\x00agent")
            .build();
        assert!(client.is_err());
    }

    #[tokio::test]
    async fn test_builder_default_uses_webpki_roots() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.tls_roots, TlsRootConfig::WebPki);
        let client = builder.build();
        assert!(client.is_ok());
    }

    // Native roots depend on the host trust store. Loading them may fail, in
    // which case the only acceptable error is a descriptive `HttpError::Tls`.
    #[tokio::test]
    async fn test_builder_native_roots() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();

        match &result {
            Ok(_) => {
                // Host trust store was available.
            }
            Err(HttpError::Tls(err)) => {
                let msg = err.to_string();
                assert!(
                    msg.contains("native root") || msg.contains("certificate"),
                    "TLS error should mention certificates: {msg}"
                );
            }
            Err(other) => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[tokio::test]
    async fn test_builder_webpki_roots_https_only() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::WebPki,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok());
    }

    // Every root/transport combination builds; the connector always enables
    // all HTTP versions.
    #[tokio::test]
    async fn test_http2_enabled_for_all_configurations() {
        // Native roots come from the host trust store, which may be missing.
        // Consistent with `test_builder_native_roots`, a TLS setup error is
        // acceptable for Native builds; any other error is a failure.
        fn native_build_acceptable(result: &Result<crate::HttpClient, HttpError>) -> bool {
            matches!(result, Ok(_) | Err(HttpError::Tls(_)))
        }

        let client = HttpClientBuilder::new().build();
        assert!(
            client.is_ok(),
            "WebPki + default transport should build with HTTP/2 enabled"
        );

        let client = HttpClientBuilder::new()
            .transport(TransportSecurity::TlsOnly)
            .build();
        assert!(
            client.is_ok(),
            "WebPki + TlsOnly should build with HTTP/2 enabled"
        );

        // FIPS builds reject AllowInsecureHttp, so this combination is skipped there.
        #[cfg(not(feature = "fips"))]
        {
            let config = HttpClientConfig {
                tls_roots: TlsRootConfig::Native,
                transport: TransportSecurity::AllowInsecureHttp,
                ..Default::default()
            };
            let client = HttpClientBuilder::with_config(config).build();
            assert!(
                native_build_acceptable(&client),
                "Native + AllowInsecureHttp should build with HTTP/2 enabled"
            );
        }

        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(
            native_build_acceptable(&client),
            "Native + TlsOnly should build with HTTP/2 enabled"
        );
    }

    // With a concurrency limit of 1 and the only permit held by an in-flight
    // request, a second request must be rejected immediately with
    // `HttpError::Overloaded` instead of waiting.
    #[tokio::test]
    async fn test_load_shedding_returns_overloaded_error() {
        use bytes::Bytes;
        use http::{Request, Response};
        use http_body_util::Full;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;
        use tower::ServiceExt;

        // Always ready, counts calls, and returns a future that never resolves.
        // While the test keeps that future alive, the request stays in flight.
        #[derive(Clone)]
        struct SlotHoldingService {
            active: Arc<AtomicUsize>,
        }

        impl Service<Request<Full<Bytes>>> for SlotHoldingService {
            type Response = Response<Full<Bytes>>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: Request<Full<Bytes>>) -> Self::Future {
                self.active.fetch_add(1, Ordering::SeqCst);
                Box::pin(std::future::pending())
            }
        }

        let active = Arc::new(AtomicUsize::new(0));

        // Same layering and error mapping as the production rate-limit branch.
        let service = tower::ServiceBuilder::new()
            .layer(LoadShedLayer::new())
            .layer(ConcurrencyLimitLayer::new(1))
            .service(SlotHoldingService {
                active: active.clone(),
            });

        let service = service.map_err(map_load_shed_error);

        let req1 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let mut svc1 = service.clone();

        // Keep the first response future alive so its slot stays occupied.
        let svc1_ready = svc1.ready().await.unwrap();
        let _pending_fut = svc1_ready.call(req1);

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            active.load(Ordering::SeqCst),
            1,
            "First request should be active"
        );

        let req2 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let mut svc2 = service.clone();

        // The outer timeout only guards against a hang; load shedding should
        // fail the request well before it fires.
        let result = tokio::time::timeout(Duration::from_millis(100), async {
            match svc2.ready().await {
                Ok(ready_svc) => ready_svc.call(req2).await,
                Err(e) => Err(e),
            }
        })
        .await;

        assert!(result.is_ok(), "Request should not hang");
        let err = result.unwrap().unwrap_err();
        assert!(
            matches!(err, HttpError::Overloaded),
            "Expected Overloaded error, got: {err:?}"
        );
    }

    // `map_tower_error` must not re-wrap errors that are already `HttpError`.
    #[test]
    fn test_map_tower_error_preserves_overloaded() {
        let http_err = HttpError::Overloaded;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Overloaded),
            "Should preserve HttpError::Overloaded, got: {result:?}"
        );
    }

    #[test]
    fn test_map_tower_error_preserves_service_closed() {
        let http_err = HttpError::ServiceClosed;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::ServiceClosed),
            "Should preserve HttpError::ServiceClosed, got: {result:?}"
        );
    }

    // An inner `HttpError::Timeout` keeps its own duration. It is not replaced
    // by the timeout passed to `map_tower_error`, which only applies to a
    // tower `Elapsed` error.
    #[test]
    fn test_map_tower_error_preserves_timeout_attempt() {
        let original_duration = Duration::from_secs(5);
        let http_err = HttpError::Timeout(original_duration);
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        match result {
            HttpError::Timeout(d) => {
                assert_eq!(
                    d, original_duration,
                    "Should preserve original timeout duration"
                );
            }
            other => panic!("Should preserve HttpError::Timeout, got: {other:?}"),
        }
    }

    #[test]
    fn test_map_tower_error_wraps_unknown_as_transport() {
        let other_err: tower::BoxError = Box::new(std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "connection refused",
        ));
        let result = map_tower_error(other_err, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Transport(_)),
            "Should wrap unknown errors as Transport, got: {result:?}"
        );
    }

    // Aborting the task that awaits `send()` must drop the innermost service
    // future through every layer (buffer, load shedding, ...), not leave it
    // running in the background.
    #[tokio::test]
    async fn test_cancellation_propagates_through_full_stack() {
        use crate::response::ResponseBody;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;

        // Inner service whose futures signal when they start and never finish.
        #[derive(Clone)]
        struct PendingService {
            completed: Arc<AtomicBool>,
            drop_notifier: Arc<tokio::sync::Notify>,
            started_notifier: Arc<tokio::sync::Notify>,
        }

        // Notifies `drop_notifier` when the owning future is dropped before
        // completing, which is the observable effect of cancellation.
        struct FutureGuard {
            completed: Arc<AtomicBool>,
            drop_notifier: Arc<tokio::sync::Notify>,
        }

        impl Drop for FutureGuard {
            fn drop(&mut self) {
                if !self.completed.load(Ordering::SeqCst) {
                    self.drop_notifier.notify_one();
                }
            }
        }

        impl Service<http::Request<Full<Bytes>>> for PendingService {
            type Response = http::Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: http::Request<Full<Bytes>>) -> Self::Future {
                let completed = self.completed.clone();
                let drop_notifier = self.drop_notifier.clone();
                let started_notifier = self.started_notifier.clone();
                Box::pin(async move {
                    let _guard = FutureGuard {
                        completed: completed.clone(),
                        drop_notifier,
                    };
                    started_notifier.notify_one();
                    std::future::pending::<()>().await;
                    completed.store(true, Ordering::SeqCst);
                    unreachable!()
                })
            }
        }

        let inner_completed = Arc::new(AtomicBool::new(false));
        let drop_notifier = Arc::new(tokio::sync::Notify::new());
        let started_notifier = Arc::new(tokio::sync::Notify::new());

        let inner = PendingService {
            completed: inner_completed.clone(),
            drop_notifier: drop_notifier.clone(),
            started_notifier: started_notifier.clone(),
        };

        let client = HttpClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .retry(None)
            .build_with_inner_service(inner.boxed_clone());

        let send_handle = tokio::spawn({
            let client = client.clone();
            async move { client.get("https://fake/slow").send().await }
        });

        // `Notify` stores a permit, so this cannot miss an early `notify_one`.
        started_notifier.notified().await;

        send_handle.abort();

        tokio::time::timeout(Duration::from_secs(5), drop_notifier.notified())
            .await
            .expect(
                "Inner service future should have been dropped within 5s - \
                 the full toolkit-http stack must propagate cancellation",
            );

        assert!(
            !inner_completed.load(Ordering::SeqCst),
            "Inner service future should NOT have completed"
        );
    }
}
1188}