1use crate::config::{
2 HttpClientConfig, RedirectConfig, RetryConfig, TlsRootConfig, TransportSecurity,
3};
4use crate::error::HttpError;
5use crate::layers::{OtelLayer, RetryLayer, SecureRedirectPolicy, UserAgentLayer};
6use crate::response::ResponseBody;
7use crate::tls;
8use bytes::Bytes;
9use http::Response;
10use http_body_util::{BodyExt, Full};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::rt::{TokioExecutor, TokioTimer};
15use std::time::Duration;
16use tower::buffer::Buffer;
17use tower::limit::ConcurrencyLimitLayer;
18use tower::load_shed::LoadShedLayer;
19use tower::timeout::TimeoutLayer;
20use tower::util::BoxCloneService;
21use tower::{ServiceBuilder, ServiceExt};
22use tower_http::decompression::DecompressionLayer;
23use tower_http::follow_redirect::FollowRedirectLayer;
24
/// Type-erased, cloneable service that every optional layer in the builder
/// wraps.
///
/// It accepts requests whose body is a single in-memory [`Full<Bytes>`] chunk,
/// yields responses carrying a [`ResponseBody`], and reports failures as
/// [`HttpError`]. Using one concrete alias lets caller-supplied auth/metrics
/// closures and the built-in retry/limit/otel layers be stacked uniformly.
type InnerService =
    BoxCloneService<http::Request<Full<Bytes>>, http::Response<ResponseBody>, HttpError>;
28
29pub struct HttpClientBuilder {
31 config: HttpClientConfig,
32 auth_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
33 metrics_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
34}
35
36impl HttpClientBuilder {
37 #[must_use]
39 pub fn new() -> Self {
40 Self {
41 config: HttpClientConfig::default(),
42 auth_layer: None,
43 metrics_layer: None,
44 }
45 }
46
47 #[must_use]
49 pub fn with_config(config: HttpClientConfig) -> Self {
50 Self {
51 config,
52 auth_layer: None,
53 metrics_layer: None,
54 }
55 }
56
57 #[must_use]
62 pub fn timeout(mut self, timeout: Duration) -> Self {
63 self.config.request_timeout = timeout;
64 self
65 }
66
67 #[must_use]
73 pub fn total_timeout(mut self, timeout: Duration) -> Self {
74 self.config.total_timeout = Some(timeout);
75 self
76 }
77
78 #[must_use]
80 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
81 self.config.user_agent = user_agent.into();
82 self
83 }
84
85 #[must_use]
87 pub fn retry(mut self, retry: Option<RetryConfig>) -> Self {
88 self.config.retry = retry;
89 self
90 }
91
92 #[must_use]
94 pub fn max_body_size(mut self, size: usize) -> Self {
95 self.config.max_body_size = size;
96 self
97 }
98
99 #[must_use]
103 pub fn transport(mut self, transport: TransportSecurity) -> Self {
104 self.config.transport = transport;
105 self
106 }
107
108 #[must_use]
114 pub fn deny_insecure_http(mut self) -> Self {
115 tracing::debug!(
116 target: "toolkit_http::security",
117 "deny_insecure_http() called - enforcing TLS for all connections"
118 );
119 self.config.transport = TransportSecurity::TlsOnly;
120 self
121 }
122
123 #[must_use]
128 pub fn with_otel(mut self) -> Self {
129 self.config.otel = true;
130 self
131 }
132
133 #[must_use]
141 pub fn with_auth_layer(
142 mut self,
143 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
144 ) -> Self {
145 self.auth_layer = Some(Box::new(wrap));
146 self
147 }
148
149 #[must_use]
161 pub fn with_metrics_layer(
162 mut self,
163 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
164 ) -> Self {
165 self.metrics_layer = Some(Box::new(wrap));
166 self
167 }
168
169 #[must_use]
178 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
179 self.config.buffer_capacity = capacity.max(1);
181 self
182 }
183
184 #[must_use]
189 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
190 self.config.redirect.max_redirects = max_redirects;
191 self
192 }
193
194 #[must_use]
199 pub fn no_redirects(mut self) -> Self {
200 self.config.redirect = RedirectConfig::disabled();
201 self
202 }
203
204 #[must_use]
219 pub fn redirect(mut self, config: RedirectConfig) -> Self {
220 self.config.redirect = config;
221 self
222 }
223
224 #[must_use]
231 pub fn pool_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
232 self.config.pool_idle_timeout = timeout;
233 self
234 }
235
236 #[must_use]
244 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
245 self.config.pool_max_idle_per_host = max;
246 self
247 }
248
249 pub fn build(self) -> Result<crate::HttpClient, HttpError> {
258 #[cfg(feature = "fips")]
263 if self.config.transport == TransportSecurity::AllowInsecureHttp {
264 tracing::warn!(
265 target: "toolkit_http::security",
266 "rejecting AllowInsecureHttp under --features fips: returning HttpError::InsecureTransport"
267 );
268 return Err(HttpError::InsecureTransport);
269 }
270
271 let timeout = self.config.request_timeout;
272 let total_timeout = self.config.total_timeout;
273
274 let https = build_https_connector(self.config.tls_roots, self.config.transport)?;
276
277 let mut client_builder = Client::builder(TokioExecutor::new());
279
280 client_builder
283 .pool_timer(TokioTimer::new())
284 .pool_max_idle_per_host(self.config.pool_max_idle_per_host)
285 .http2_only(false); if let Some(idle_timeout) = self.config.pool_idle_timeout {
289 client_builder.pool_idle_timeout(idle_timeout);
290 }
291
292 let hyper_client = client_builder.build::<_, Full<Bytes>>(https);
293
294 let ua_layer = UserAgentLayer::try_new(&self.config.user_agent)?;
296
297 let redirect_policy = SecureRedirectPolicy::new(self.config.redirect.clone());
332
333 let service = ServiceBuilder::new()
335 .layer(TimeoutLayer::new(timeout))
336 .layer(ua_layer)
337 .layer(DecompressionLayer::new())
338 .layer(FollowRedirectLayer::with_policy(redirect_policy))
339 .service(hyper_client);
340
341 let service = service.map_response(map_decompression_response);
347
348 let service = service.map_err(move |e: tower::BoxError| map_tower_error(e, timeout));
350
351 let mut boxed_service = service.boxed_clone();
353
354 if let Some(wrap) = self.auth_layer {
357 boxed_service = wrap(boxed_service);
358 }
359
360 if let Some(ref retry_config) = self.config.retry {
372 let retry_layer = RetryLayer::with_total_timeout(retry_config.clone(), total_timeout);
373 let retry_service = ServiceBuilder::new()
374 .layer(retry_layer)
375 .service(boxed_service);
376 boxed_service = retry_service.boxed_clone();
377 }
378
379 if let Some(wrap) = self.metrics_layer {
382 boxed_service = wrap(boxed_service);
383 }
384
385 if let Some(rate_limit) = self.config.rate_limit
389 && rate_limit.max_concurrent_requests < usize::MAX
390 {
391 let limited_service = ServiceBuilder::new()
392 .layer(LoadShedLayer::new())
393 .layer(ConcurrencyLimitLayer::new(
394 rate_limit.max_concurrent_requests,
395 ))
396 .service(boxed_service);
397 let limited_service = limited_service.map_err(map_load_shed_error);
399 boxed_service = limited_service.boxed_clone();
400 }
401
402 if self.config.otel {
406 let otel_service = ServiceBuilder::new()
407 .layer(OtelLayer::new())
408 .service(boxed_service);
409 boxed_service = otel_service.boxed_clone();
410 }
411
412 let buffer_capacity = self.config.buffer_capacity.max(1);
416 let buffered_service: crate::client::BufferedService =
417 Buffer::new(boxed_service, buffer_capacity);
418
419 Ok(crate::HttpClient {
420 service: buffered_service,
421 max_body_size: self.config.max_body_size,
422 transport_security: self.config.transport,
423 })
424 }
425}
426
427#[cfg(test)]
428impl HttpClientBuilder {
429 fn build_with_inner_service(self, inner: InnerService) -> crate::HttpClient {
437 let mut boxed_service = inner;
438
439 if let Some(ref retry_config) = self.config.retry {
440 let retry_layer =
441 RetryLayer::with_total_timeout(retry_config.clone(), self.config.total_timeout);
442 let retry_service = ServiceBuilder::new()
443 .layer(retry_layer)
444 .service(boxed_service);
445 boxed_service = retry_service.boxed_clone();
446 }
447
448 if let Some(rate_limit) = self.config.rate_limit
449 && rate_limit.max_concurrent_requests < usize::MAX
450 {
451 let limited_service = ServiceBuilder::new()
452 .layer(LoadShedLayer::new())
453 .layer(ConcurrencyLimitLayer::new(
454 rate_limit.max_concurrent_requests,
455 ))
456 .service(boxed_service);
457 let limited_service = limited_service.map_err(map_load_shed_error);
458 boxed_service = limited_service.boxed_clone();
459 }
460
461 let buffer_capacity = self.config.buffer_capacity.max(1);
462 let buffered_service: crate::client::BufferedService =
463 Buffer::new(boxed_service, buffer_capacity);
464
465 crate::HttpClient {
466 service: buffered_service,
467 max_body_size: self.config.max_body_size,
468 transport_security: self.config.transport,
469 }
470 }
471}
472
473impl Default for HttpClientBuilder {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
479fn map_tower_error(err: tower::BoxError, timeout: Duration) -> HttpError {
485 if err.is::<tower::timeout::error::Elapsed>() {
486 return HttpError::Timeout(timeout);
487 }
488
489 match err.downcast::<HttpError>() {
491 Ok(http_err) => *http_err,
492 Err(other) => HttpError::Transport(other),
493 }
494}
495
496fn map_load_shed_error(err: tower::BoxError) -> HttpError {
498 if err.is::<tower::load_shed::error::Overloaded>() {
499 HttpError::Overloaded
500 } else {
501 match err.downcast::<HttpError>() {
503 Ok(http_err) => *http_err,
504 Err(err) => HttpError::Transport(err),
505 }
506 }
507}
508
509fn map_decompression_response<B>(response: Response<B>) -> Response<ResponseBody>
514where
515 B: hyper::body::Body<Data = Bytes> + Send + Sync + 'static,
516 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
517{
518 let (parts, body) = response.into_parts();
519 let boxed_body: ResponseBody = body.map_err(Into::into).boxed();
523 Response::from_parts(parts, boxed_body)
524}
525
526fn build_https_connector(
540 tls_roots: TlsRootConfig,
541 transport: TransportSecurity,
542) -> Result<HttpsConnector<HttpConnector>, HttpError> {
543 let allow_http = transport == TransportSecurity::AllowInsecureHttp;
544
545 let client_config = match tls_roots {
553 TlsRootConfig::WebPki => tls::webpki_roots_client_config(),
554 TlsRootConfig::Native => tls::native_roots_client_config(),
555 }
556 .map_err(|e| HttpError::Tls(Box::new(e)))?;
557
558 let builder = hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(client_config);
559 let connector = if allow_http {
560 builder.https_or_http().enable_all_versions().build()
561 } else {
562 builder.https_only().enable_all_versions().build()
563 };
564 Ok(connector)
565}
566
// Unit tests for the builder: setter plumbing, `build()` success and failure
// paths, the error-mapping helpers, load shedding, and cancellation through
// the layer stack.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::config::DEFAULT_USER_AGENT;

    // A fresh builder exposes the values of `HttpClientConfig::default()`.
    #[test]
    fn test_builder_default() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.request_timeout, Duration::from_secs(30));
        assert_eq!(builder.config.user_agent, DEFAULT_USER_AGENT);
        assert!(builder.config.retry.is_some());
        assert_eq!(builder.config.buffer_capacity, 1024);
    }

    // `with_config` keeps the supplied config as-is.
    #[test]
    fn test_builder_with_config() {
        let config = HttpClientConfig::minimal();
        let builder = HttpClientBuilder::with_config(config);
        assert_eq!(builder.config.request_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_timeout() {
        let builder = HttpClientBuilder::new().timeout(Duration::from_mins(1));
        assert_eq!(builder.config.request_timeout, Duration::from_mins(1));
    }

    #[test]
    fn test_builder_user_agent() {
        let builder = HttpClientBuilder::new().user_agent("custom/1.0");
        assert_eq!(builder.config.user_agent, "custom/1.0");
    }

    #[test]
    fn test_builder_retry() {
        let builder = HttpClientBuilder::new().retry(None);
        assert!(builder.config.retry.is_none());
    }

    #[test]
    fn test_builder_max_body_size() {
        let builder = HttpClientBuilder::new().max_body_size(1024);
        assert_eq!(builder.config.max_body_size, 1024);
    }

    // Covers `transport`, `deny_insecure_http`, and the feature-dependent default.
    #[test]
    fn test_builder_transport_security() {
        let builder = HttpClientBuilder::new().transport(TransportSecurity::TlsOnly);
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new().deny_insecure_http();
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new();
        // Without `fips` the default permits plain HTTP; with `fips` it is TLS-only.
        #[cfg(not(feature = "fips"))]
        assert_eq!(
            builder.config.transport,
            TransportSecurity::AllowInsecureHttp
        );
        #[cfg(feature = "fips")]
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);
    }

    #[test]
    fn test_builder_otel() {
        let builder = HttpClientBuilder::new().with_otel();
        assert!(builder.config.otel);
    }

    #[test]
    fn test_builder_buffer_capacity() {
        let builder = HttpClientBuilder::new().buffer_capacity(512);
        assert_eq!(builder.config.buffer_capacity, 512);
    }

    // The setter clamps zero up to one.
    #[test]
    fn test_builder_buffer_capacity_zero_clamped() {
        let builder = HttpClientBuilder::new().buffer_capacity(0);
        assert_eq!(
            builder.config.buffer_capacity, 1,
            "buffer_capacity=0 should be clamped to 1"
        );
    }

    // A zero capacity that bypasses the setter (via `with_config`) is clamped
    // again inside `build()`.
    #[tokio::test]
    async fn test_builder_buffer_capacity_zero_in_config_clamped() {
        let config = HttpClientConfig {
            buffer_capacity: 0,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();
        assert!(
            result.is_ok(),
            "build() should succeed with capacity clamped to 1"
        );
    }

    #[tokio::test]
    async fn test_builder_build_with_otel() {
        let client = HttpClientBuilder::new().with_otel().build();
        assert!(client.is_ok());
    }

    // An identity auth wrapper must not break the build.
    #[tokio::test]
    async fn test_builder_with_auth_layer() {
        let client = HttpClientBuilder::new()
            .with_auth_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    // An identity metrics wrapper must not break the build.
    #[tokio::test]
    async fn test_builder_with_metrics_layer() {
        let client = HttpClientBuilder::new()
            .with_metrics_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    // Registering a second metrics wrapper replaces the first: the first one
    // panics if invoked, and the second must run exactly once.
    #[tokio::test]
    async fn test_builder_with_metrics_layer_second_call_replaces_first() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count2 = call_count.clone();

        let client = HttpClientBuilder::new()
            .with_metrics_layer(|_svc| {
                panic!("first metrics layer should have been replaced");
            })
            .with_metrics_layer(move |svc| {
                call_count2.fetch_add(1, Ordering::SeqCst);
                svc
            })
            .build();

        assert!(client.is_ok());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "second metrics layer must be applied exactly once"
        );
    }

    #[tokio::test]
    async fn test_builder_build() {
        let client = HttpClientBuilder::new().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_deny_insecure_http() {
        let client = HttpClientBuilder::new().deny_insecure_http().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_sse_config() {
        use crate::config::HttpClientConfig;
        let config = HttpClientConfig::sse();
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok(), "SSE config should build successfully");
    }

    // A NUL byte is not a valid header value, so `UserAgentLayer::try_new` fails.
    #[tokio::test]
    async fn test_builder_build_invalid_user_agent() {
        let client = HttpClientBuilder::new()
            .user_agent("invalid\x00agent")
            .build();
        assert!(client.is_err());
    }

    #[tokio::test]
    async fn test_builder_default_uses_webpki_roots() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.tls_roots, TlsRootConfig::WebPki);
        let client = builder.build();
        assert!(client.is_ok());
    }

    // The native root store may be missing on the test machine, so either a
    // successful build or a certificate-related `HttpError::Tls` is accepted.
    #[tokio::test]
    async fn test_builder_native_roots() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();

        match &result {
            Ok(_) => {
                // Native roots were available on this machine.
            }
            Err(HttpError::Tls(err)) => {
                let msg = err.to_string();
                assert!(
                    msg.contains("native root") || msg.contains("certificate"),
                    "TLS error should mention certificates: {msg}"
                );
            }
            Err(other) => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[tokio::test]
    async fn test_builder_webpki_roots_https_only() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::WebPki,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok());
    }

    // Builds every root-store / transport combination. The connector always
    // calls `enable_all_versions()`, so each successful build has HTTP/2 enabled.
    #[tokio::test]
    async fn test_http2_enabled_for_all_configurations() {
        let client = HttpClientBuilder::new().build();
        assert!(
            client.is_ok(),
            "WebPki + default transport should build with HTTP/2 enabled"
        );

        let client = HttpClientBuilder::new()
            .transport(TransportSecurity::TlsOnly)
            .build();
        assert!(
            client.is_ok(),
            "WebPki + TlsOnly should build with HTTP/2 enabled"
        );

        // AllowInsecureHttp is rejected by `build()` under `fips`, so this case
        // only runs without that feature.
        #[cfg(not(feature = "fips"))]
        {
            let config = HttpClientConfig {
                tls_roots: TlsRootConfig::Native,
                transport: TransportSecurity::AllowInsecureHttp,
                ..Default::default()
            };
            let client = HttpClientBuilder::with_config(config).build();
            assert!(
                client.is_ok(),
                "Native + AllowInsecureHttp should build with HTTP/2 enabled"
            );
        }

        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(
            client.is_ok(),
            "Native + TlsOnly should build with HTTP/2 enabled"
        );
    }

    // With a concurrency limit of 1 and one request holding the only permit, a
    // second request must fail fast with `HttpError::Overloaded` instead of
    // waiting.
    #[tokio::test]
    async fn test_load_shedding_returns_overloaded_error() {
        use bytes::Bytes;
        use http::{Request, Response};
        use http_body_util::Full;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;
        use tower::ServiceExt;

        // Counts calls and returns a future that never resolves, so a request
        // that has been dispatched keeps its concurrency permit.
        #[derive(Clone)]
        struct SlotHoldingService {
            active: Arc<AtomicUsize>,
        }

        impl Service<Request<Full<Bytes>>> for SlotHoldingService {
            type Response = Response<Full<Bytes>>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: Request<Full<Bytes>>) -> Self::Future {
                self.active.fetch_add(1, Ordering::SeqCst);
                Box::pin(std::future::pending())
            }
        }

        let active = Arc::new(AtomicUsize::new(0));

        // Same layering as the rate-limit branch of `build()`.
        let service = tower::ServiceBuilder::new()
            .layer(LoadShedLayer::new())
            .layer(ConcurrencyLimitLayer::new(1))
            .service(SlotHoldingService {
                active: active.clone(),
            });

        let service = service.map_err(map_load_shed_error);

        let req1 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let mut svc1 = service.clone();

        // The returned future is kept alive (never polled) so it holds the
        // single permit for the rest of the test.
        let svc1_ready = svc1.ready().await.unwrap();
        let _pending_fut = svc1_ready.call(req1);

        // `SlotHoldingService::call` bumps the counter synchronously, so this
        // short sleep is not strictly required.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            active.load(Ordering::SeqCst),
            1,
            "First request should be active"
        );

        let req2 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();

        // Clones share the limiter's semaphore, so no permit is left for svc2.
        let mut svc2 = service.clone();

        // The 100ms timeout only guards against a hang; load shedding should
        // reject immediately.
        let result = tokio::time::timeout(Duration::from_millis(100), async {
            match svc2.ready().await {
                Ok(ready_svc) => ready_svc.call(req2).await,
                Err(e) => Err(e),
            }
        })
        .await;

        assert!(result.is_ok(), "Request should not hang");
        let err = result.unwrap().unwrap_err();
        assert!(
            matches!(err, HttpError::Overloaded),
            "Expected Overloaded error, got: {err:?}"
        );
    }

    // `map_tower_error` must pass an already-mapped `HttpError` through rather
    // than re-wrapping it as `Transport`.
    #[test]
    fn test_map_tower_error_preserves_overloaded() {
        let http_err = HttpError::Overloaded;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Overloaded),
            "Should preserve HttpError::Overloaded, got: {result:?}"
        );
    }

    #[test]
    fn test_map_tower_error_preserves_service_closed() {
        let http_err = HttpError::ServiceClosed;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::ServiceClosed),
            "Should preserve HttpError::ServiceClosed, got: {result:?}"
        );
    }

    // A pre-existing `HttpError::Timeout` keeps its own duration; only tower's
    // `Elapsed` is mapped to the configured timeout.
    #[test]
    fn test_map_tower_error_preserves_timeout_attempt() {
        let original_duration = Duration::from_secs(5);
        let http_err = HttpError::Timeout(original_duration);
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        match result {
            HttpError::Timeout(d) => {
                assert_eq!(
                    d, original_duration,
                    "Should preserve original timeout duration"
                );
            }
            other => panic!("Should preserve HttpError::Timeout, got: {other:?}"),
        }
    }

    #[test]
    fn test_map_tower_error_wraps_unknown_as_transport() {
        let other_err: tower::BoxError = Box::new(std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "connection refused",
        ));
        let result = map_tower_error(other_err, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Transport(_)),
            "Should wrap unknown errors as Transport, got: {result:?}"
        );
    }

    // Aborting the task that awaits `send()` must drop the inner service's
    // in-flight future, i.e. cancellation propagates through the buffer and
    // every layer `build_with_inner_service` applies.
    #[tokio::test]
    async fn test_cancellation_propagates_through_full_stack() {
        use crate::response::ResponseBody;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;

        // Inner service whose futures never finish; it signals when a call
        // starts and, via `FutureGuard`, when the future is dropped early.
        #[derive(Clone)]
        struct PendingService {
            completed: Arc<AtomicBool>,
            drop_notifier: Arc<tokio::sync::Notify>,
            started_notifier: Arc<tokio::sync::Notify>,
        }

        // Notifies `drop_notifier` if dropped before `completed` was set.
        struct FutureGuard {
            completed: Arc<AtomicBool>,
            drop_notifier: Arc<tokio::sync::Notify>,
        }

        impl Drop for FutureGuard {
            fn drop(&mut self) {
                if !self.completed.load(Ordering::SeqCst) {
                    self.drop_notifier.notify_one();
                }
            }
        }

        impl Service<http::Request<Full<Bytes>>> for PendingService {
            type Response = http::Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: http::Request<Full<Bytes>>) -> Self::Future {
                let completed = self.completed.clone();
                let drop_notifier = self.drop_notifier.clone();
                let started_notifier = self.started_notifier.clone();
                Box::pin(async move {
                    let _guard = FutureGuard {
                        completed: completed.clone(),
                        drop_notifier,
                    };
                    started_notifier.notify_one();
                    // Park forever; only cancellation (drop) can end this future.
                    std::future::pending::<()>().await;
                    completed.store(true, Ordering::SeqCst);
                    unreachable!()
                })
            }
        }

        let inner_completed = Arc::new(AtomicBool::new(false));
        let drop_notifier = Arc::new(tokio::sync::Notify::new());
        let started_notifier = Arc::new(tokio::sync::Notify::new());

        let inner = PendingService {
            completed: inner_completed.clone(),
            drop_notifier: drop_notifier.clone(),
            started_notifier: started_notifier.clone(),
        };

        // Retries are disabled so an abort cannot be masked by a re-attempt.
        let client = HttpClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .retry(None)
            .build_with_inner_service(inner.boxed_clone());

        let send_handle = tokio::spawn({
            let client = client.clone();
            async move { client.get("https://fake/slow").send().await }
        });

        // Wait until the request has actually reached the inner service.
        started_notifier.notified().await;

        send_handle.abort();

        tokio::time::timeout(Duration::from_secs(5), drop_notifier.notified())
            .await
            .expect(
                "Inner service future should have been dropped within 5s - \
                 the full toolkit-http stack must propagate cancellation",
            );

        assert!(
            !inner_completed.load(Ordering::SeqCst),
            "Inner service future should NOT have completed"
        );
    }
}