1use crate::config::{
2 HttpClientConfig, RedirectConfig, RetryConfig, TlsRootConfig, TransportSecurity,
3};
4use crate::error::HttpError;
5use crate::layers::{OtelLayer, RetryLayer, SecureRedirectPolicy, UserAgentLayer};
6use crate::response::ResponseBody;
7use crate::tls;
8use bytes::Bytes;
9use http::Response;
10use http_body_util::{BodyExt, Full};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::rt::{TokioExecutor, TokioTimer};
15use std::time::Duration;
16use tower::buffer::Buffer;
17use tower::limit::ConcurrencyLimitLayer;
18use tower::load_shed::LoadShedLayer;
19use tower::timeout::TimeoutLayer;
20use tower::util::BoxCloneService;
21use tower::{ServiceBuilder, ServiceExt};
22use tower_http::decompression::DecompressionLayer;
23use tower_http::follow_redirect::FollowRedirectLayer;
24
25type InnerService =
27 BoxCloneService<http::Request<Full<Bytes>>, http::Response<ResponseBody>, HttpError>;
28
29pub struct HttpClientBuilder {
31 config: HttpClientConfig,
32 auth_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
33 metrics_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
34}
35
36impl HttpClientBuilder {
37 #[must_use]
39 pub fn new() -> Self {
40 Self {
41 config: HttpClientConfig::default(),
42 auth_layer: None,
43 metrics_layer: None,
44 }
45 }
46
47 #[must_use]
49 pub fn with_config(config: HttpClientConfig) -> Self {
50 Self {
51 config,
52 auth_layer: None,
53 metrics_layer: None,
54 }
55 }
56
57 #[must_use]
62 pub fn timeout(mut self, timeout: Duration) -> Self {
63 self.config.request_timeout = timeout;
64 self
65 }
66
67 #[must_use]
73 pub fn total_timeout(mut self, timeout: Duration) -> Self {
74 self.config.total_timeout = Some(timeout);
75 self
76 }
77
78 #[must_use]
80 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
81 self.config.user_agent = user_agent.into();
82 self
83 }
84
85 #[must_use]
87 pub fn retry(mut self, retry: Option<RetryConfig>) -> Self {
88 self.config.retry = retry;
89 self
90 }
91
92 #[must_use]
94 pub fn max_body_size(mut self, size: usize) -> Self {
95 self.config.max_body_size = size;
96 self
97 }
98
99 #[must_use]
103 pub fn transport(mut self, transport: TransportSecurity) -> Self {
104 self.config.transport = transport;
105 self
106 }
107
108 #[must_use]
114 pub fn deny_insecure_http(mut self) -> Self {
115 tracing::debug!(
116 target: "modkit_http::security",
117 "deny_insecure_http() called - enforcing TLS for all connections"
118 );
119 self.config.transport = TransportSecurity::TlsOnly;
120 self
121 }
122
123 #[must_use]
128 pub fn with_otel(mut self) -> Self {
129 self.config.otel = true;
130 self
131 }
132
133 #[must_use]
141 pub fn with_auth_layer(
142 mut self,
143 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
144 ) -> Self {
145 self.auth_layer = Some(Box::new(wrap));
146 self
147 }
148
149 #[must_use]
161 pub fn with_metrics_layer(
162 mut self,
163 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
164 ) -> Self {
165 self.metrics_layer = Some(Box::new(wrap));
166 self
167 }
168
169 #[must_use]
178 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
179 self.config.buffer_capacity = capacity.max(1);
181 self
182 }
183
184 #[must_use]
189 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
190 self.config.redirect.max_redirects = max_redirects;
191 self
192 }
193
194 #[must_use]
199 pub fn no_redirects(mut self) -> Self {
200 self.config.redirect = RedirectConfig::disabled();
201 self
202 }
203
204 #[must_use]
219 pub fn redirect(mut self, config: RedirectConfig) -> Self {
220 self.config.redirect = config;
221 self
222 }
223
224 #[must_use]
231 pub fn pool_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
232 self.config.pool_idle_timeout = timeout;
233 self
234 }
235
236 #[must_use]
244 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
245 self.config.pool_max_idle_per_host = max;
246 self
247 }
248
249 pub fn build(self) -> Result<crate::HttpClient, HttpError> {
254 let timeout = self.config.request_timeout;
255 let total_timeout = self.config.total_timeout;
256
257 let https = build_https_connector(self.config.tls_roots, self.config.transport)?;
259
260 let mut client_builder = Client::builder(TokioExecutor::new());
262
263 client_builder
266 .pool_timer(TokioTimer::new())
267 .pool_max_idle_per_host(self.config.pool_max_idle_per_host)
268 .http2_only(false); if let Some(idle_timeout) = self.config.pool_idle_timeout {
272 client_builder.pool_idle_timeout(idle_timeout);
273 }
274
275 let hyper_client = client_builder.build::<_, Full<Bytes>>(https);
276
277 let ua_layer = UserAgentLayer::try_new(&self.config.user_agent)?;
279
280 let redirect_policy = SecureRedirectPolicy::new(self.config.redirect.clone());
315
316 let service = ServiceBuilder::new()
318 .layer(TimeoutLayer::new(timeout))
319 .layer(ua_layer)
320 .layer(DecompressionLayer::new())
321 .layer(FollowRedirectLayer::with_policy(redirect_policy))
322 .service(hyper_client);
323
324 let service = service.map_response(map_decompression_response);
330
331 let service = service.map_err(move |e: tower::BoxError| map_tower_error(e, timeout));
333
334 let mut boxed_service = service.boxed_clone();
336
337 if let Some(wrap) = self.auth_layer {
340 boxed_service = wrap(boxed_service);
341 }
342
343 if let Some(ref retry_config) = self.config.retry {
355 let retry_layer = RetryLayer::with_total_timeout(retry_config.clone(), total_timeout);
356 let retry_service = ServiceBuilder::new()
357 .layer(retry_layer)
358 .service(boxed_service);
359 boxed_service = retry_service.boxed_clone();
360 }
361
362 if let Some(wrap) = self.metrics_layer {
365 boxed_service = wrap(boxed_service);
366 }
367
368 if let Some(rate_limit) = self.config.rate_limit
372 && rate_limit.max_concurrent_requests < usize::MAX
373 {
374 let limited_service = ServiceBuilder::new()
375 .layer(LoadShedLayer::new())
376 .layer(ConcurrencyLimitLayer::new(
377 rate_limit.max_concurrent_requests,
378 ))
379 .service(boxed_service);
380 let limited_service = limited_service.map_err(map_load_shed_error);
382 boxed_service = limited_service.boxed_clone();
383 }
384
385 if self.config.otel {
389 let otel_service = ServiceBuilder::new()
390 .layer(OtelLayer::new())
391 .service(boxed_service);
392 boxed_service = otel_service.boxed_clone();
393 }
394
395 let buffer_capacity = self.config.buffer_capacity.max(1);
399 let buffered_service: crate::client::BufferedService =
400 Buffer::new(boxed_service, buffer_capacity);
401
402 Ok(crate::HttpClient {
403 service: buffered_service,
404 max_body_size: self.config.max_body_size,
405 transport_security: self.config.transport,
406 })
407 }
408}
409
410#[cfg(test)]
411impl HttpClientBuilder {
412 fn build_with_inner_service(self, inner: InnerService) -> crate::HttpClient {
420 let mut boxed_service = inner;
421
422 if let Some(ref retry_config) = self.config.retry {
423 let retry_layer =
424 RetryLayer::with_total_timeout(retry_config.clone(), self.config.total_timeout);
425 let retry_service = ServiceBuilder::new()
426 .layer(retry_layer)
427 .service(boxed_service);
428 boxed_service = retry_service.boxed_clone();
429 }
430
431 if let Some(rate_limit) = self.config.rate_limit
432 && rate_limit.max_concurrent_requests < usize::MAX
433 {
434 let limited_service = ServiceBuilder::new()
435 .layer(LoadShedLayer::new())
436 .layer(ConcurrencyLimitLayer::new(
437 rate_limit.max_concurrent_requests,
438 ))
439 .service(boxed_service);
440 let limited_service = limited_service.map_err(map_load_shed_error);
441 boxed_service = limited_service.boxed_clone();
442 }
443
444 let buffer_capacity = self.config.buffer_capacity.max(1);
445 let buffered_service: crate::client::BufferedService =
446 Buffer::new(boxed_service, buffer_capacity);
447
448 crate::HttpClient {
449 service: buffered_service,
450 max_body_size: self.config.max_body_size,
451 transport_security: self.config.transport,
452 }
453 }
454}
455
456impl Default for HttpClientBuilder {
457 fn default() -> Self {
458 Self::new()
459 }
460}
461
462fn map_tower_error(err: tower::BoxError, timeout: Duration) -> HttpError {
468 if err.is::<tower::timeout::error::Elapsed>() {
469 return HttpError::Timeout(timeout);
470 }
471
472 match err.downcast::<HttpError>() {
474 Ok(http_err) => *http_err,
475 Err(other) => HttpError::Transport(other),
476 }
477}
478
479fn map_load_shed_error(err: tower::BoxError) -> HttpError {
481 if err.is::<tower::load_shed::error::Overloaded>() {
482 HttpError::Overloaded
483 } else {
484 match err.downcast::<HttpError>() {
486 Ok(http_err) => *http_err,
487 Err(err) => HttpError::Transport(err),
488 }
489 }
490}
491
492fn map_decompression_response<B>(response: Response<B>) -> Response<ResponseBody>
497where
498 B: hyper::body::Body<Data = Bytes> + Send + Sync + 'static,
499 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
500{
501 let (parts, body) = response.into_parts();
502 let boxed_body: ResponseBody = body.map_err(Into::into).boxed();
506 Response::from_parts(parts, boxed_body)
507}
508
509fn build_https_connector(
523 tls_roots: TlsRootConfig,
524 transport: TransportSecurity,
525) -> Result<HttpsConnector<HttpConnector>, HttpError> {
526 let allow_http = transport == TransportSecurity::AllowInsecureHttp;
527
528 match tls_roots {
529 TlsRootConfig::WebPki => {
530 let provider = tls::get_crypto_provider();
531 let builder = hyper_rustls::HttpsConnectorBuilder::new()
532 .with_provider_and_webpki_roots(provider)
533 .map_err(|e| HttpError::Tls(Box::new(e)))?;
536 let connector = if allow_http {
537 builder.https_or_http().enable_all_versions().build()
538 } else {
539 builder.https_only().enable_all_versions().build()
540 };
541 Ok(connector)
542 }
543 TlsRootConfig::Native => {
544 let client_config = tls::native_roots_client_config()
545 .map_err(|e| HttpError::Tls(e.into()))?;
547 let builder = hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(client_config);
548 let connector = if allow_http {
549 builder.https_or_http().enable_all_versions().build()
550 } else {
551 builder.https_only().enable_all_versions().build()
552 };
553 Ok(connector)
554 }
555 }
556}
557
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::config::DEFAULT_USER_AGENT;

    // Defaults exposed by `HttpClientBuilder::new()`.
    #[test]
    fn test_builder_default() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.request_timeout, Duration::from_secs(30));
        assert_eq!(builder.config.user_agent, DEFAULT_USER_AGENT);
        assert!(builder.config.retry.is_some());
        assert_eq!(builder.config.buffer_capacity, 1024);
    }

    // `with_config` keeps the supplied config (the minimal preset uses a 10s timeout).
    #[test]
    fn test_builder_with_config() {
        let config = HttpClientConfig::minimal();
        let builder = HttpClientBuilder::with_config(config);
        assert_eq!(builder.config.request_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_timeout() {
        let builder = HttpClientBuilder::new().timeout(Duration::from_mins(1));
        assert_eq!(builder.config.request_timeout, Duration::from_mins(1));
    }

    #[test]
    fn test_builder_user_agent() {
        let builder = HttpClientBuilder::new().user_agent("custom/1.0");
        assert_eq!(builder.config.user_agent, "custom/1.0");
    }

    // `retry(None)` clears the default retry policy.
    #[test]
    fn test_builder_retry() {
        let builder = HttpClientBuilder::new().retry(None);
        assert!(builder.config.retry.is_none());
    }

    #[test]
    fn test_builder_max_body_size() {
        let builder = HttpClientBuilder::new().max_body_size(1024);
        assert_eq!(builder.config.max_body_size, 1024);
    }

    // Explicit `transport`, the `deny_insecure_http` shorthand, and the default
    // (which allows plain HTTP).
    #[test]
    fn test_builder_transport_security() {
        let builder = HttpClientBuilder::new().transport(TransportSecurity::TlsOnly);
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new().deny_insecure_http();
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new();
        assert_eq!(
            builder.config.transport,
            TransportSecurity::AllowInsecureHttp
        );
    }

    #[test]
    fn test_builder_otel() {
        let builder = HttpClientBuilder::new().with_otel();
        assert!(builder.config.otel);
    }

    #[test]
    fn test_builder_buffer_capacity() {
        let builder = HttpClientBuilder::new().buffer_capacity(512);
        assert_eq!(builder.config.buffer_capacity, 512);
    }

    // The setter clamps a zero capacity up to 1.
    #[test]
    fn test_builder_buffer_capacity_zero_clamped() {
        let builder = HttpClientBuilder::new().buffer_capacity(0);
        assert_eq!(
            builder.config.buffer_capacity, 1,
            "buffer_capacity=0 should be clamped to 1"
        );
    }

    // A zero capacity set directly on the config bypasses the setter; `build()` must
    // still clamp it rather than fail.
    #[tokio::test]
    async fn test_builder_buffer_capacity_zero_in_config_clamped() {
        let config = HttpClientConfig {
            buffer_capacity: 0,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();
        assert!(
            result.is_ok(),
            "build() should succeed with capacity clamped to 1"
        );
    }

    #[tokio::test]
    async fn test_builder_build_with_otel() {
        let client = HttpClientBuilder::new().with_otel().build();
        assert!(client.is_ok());
    }

    // Identity wrappers must not break the build.
    #[tokio::test]
    async fn test_builder_with_auth_layer() {
        let client = HttpClientBuilder::new()
            .with_auth_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_with_metrics_layer() {
        let client = HttpClientBuilder::new()
            .with_metrics_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    // Registering a metrics wrapper twice keeps only the second one. The first panics if
    // it is ever invoked; the second counts its invocations.
    #[tokio::test]
    async fn test_builder_with_metrics_layer_second_call_replaces_first() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count2 = call_count.clone();

        let client = HttpClientBuilder::new()
            .with_metrics_layer(|_svc| {
                panic!("first metrics layer should have been replaced");
            })
            .with_metrics_layer(move |svc| {
                call_count2.fetch_add(1, Ordering::SeqCst);
                svc
            })
            .build();

        assert!(client.is_ok());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "second metrics layer must be applied exactly once"
        );
    }

    #[tokio::test]
    async fn test_builder_build() {
        let client = HttpClientBuilder::new().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_deny_insecure_http() {
        let client = HttpClientBuilder::new().deny_insecure_http().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_sse_config() {
        use crate::config::HttpClientConfig;
        let config = HttpClientConfig::sse();
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok(), "SSE config should build successfully");
    }

    // A NUL byte is not a valid header value, so `build()` must return an error.
    #[tokio::test]
    async fn test_builder_build_invalid_user_agent() {
        let client = HttpClientBuilder::new()
            .user_agent("invalid\x00agent")
            .build();
        assert!(client.is_err());
    }

    #[tokio::test]
    async fn test_builder_default_uses_webpki_roots() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.tls_roots, TlsRootConfig::WebPki);
        let client = builder.build();
        assert!(client.is_ok());
    }

    // Native roots depend on the host's certificate store. Either the build succeeds, or it
    // fails with a TLS error whose message mentions certificates/native roots.
    #[tokio::test]
    async fn test_builder_native_roots() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();

        match &result {
            Ok(_) => {
                // Native roots were available in this environment; nothing more to check.
            }
            Err(HttpError::Tls(err)) => {
                let msg = err.to_string();
                assert!(
                    msg.contains("native root") || msg.contains("certificate"),
                    "TLS error should mention certificates: {msg}"
                );
            }
            Err(other) => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[tokio::test]
    async fn test_builder_webpki_roots_https_only() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::WebPki,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok());
    }

    // Exercises all four root-store × transport combinations of `build_https_connector`,
    // each of which calls `enable_all_versions()`.
    // NOTE(review): these only assert that `build()` succeeds; they do not observe the
    // negotiated protocol.
    #[tokio::test]
    async fn test_http2_enabled_for_all_configurations() {
        let client = HttpClientBuilder::new().build();
        assert!(
            client.is_ok(),
            "WebPki + AllowInsecureHttp should build with HTTP/2 enabled"
        );

        let client = HttpClientBuilder::new()
            .transport(TransportSecurity::TlsOnly)
            .build();
        assert!(
            client.is_ok(),
            "WebPki + TlsOnly should build with HTTP/2 enabled"
        );

        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            transport: TransportSecurity::AllowInsecureHttp,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(
            client.is_ok(),
            "Native + AllowInsecureHttp should build with HTTP/2 enabled"
        );

        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(
            client.is_ok(),
            "Native + TlsOnly should build with HTTP/2 enabled"
        );
    }

    // With a concurrency limit of 1 and one in-flight request, a second request must be
    // rejected immediately with `HttpError::Overloaded` rather than waiting for a slot.
    #[tokio::test]
    async fn test_load_shedding_returns_overloaded_error() {
        use bytes::Bytes;
        use http::{Request, Response};
        use http_body_util::Full;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;
        use tower::ServiceExt;

        // Always ready. Each call bumps `active` and returns a future that never resolves,
        // so an in-flight request keeps occupying its slot.
        #[derive(Clone)]
        struct SlotHoldingService {
            active: Arc<AtomicUsize>,
        }

        impl Service<Request<Full<Bytes>>> for SlotHoldingService {
            type Response = Response<Full<Bytes>>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: Request<Full<Bytes>>) -> Self::Future {
                self.active.fetch_add(1, Ordering::SeqCst);
                Box::pin(std::future::pending())
            }
        }

        let active = Arc::new(AtomicUsize::new(0));

        // Same LoadShed-outside-ConcurrencyLimit ordering that `build()` installs.
        let service = tower::ServiceBuilder::new()
            .layer(LoadShedLayer::new())
            .layer(ConcurrencyLimitLayer::new(1))
            .service(SlotHoldingService {
                active: active.clone(),
            });

        let service = service.map_err(map_load_shed_error);

        let req1 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let mut svc1 = service.clone();

        // Take the only slot. The returned future is kept alive (and never polled) so the
        // slot stays occupied for the rest of the test.
        let svc1_ready = svc1.ready().await.unwrap();
        let _pending_fut = svc1_ready.call(req1);

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            active.load(Ordering::SeqCst),
            1,
            "First request should be active"
        );

        let req2 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let mut svc2 = service.clone();

        // The timeout guards against the second request queueing instead of being shed.
        let result = tokio::time::timeout(Duration::from_millis(100), async {
            match svc2.ready().await {
                Ok(ready_svc) => ready_svc.call(req2).await,
                Err(e) => Err(e),
            }
        })
        .await;

        assert!(result.is_ok(), "Request should not hang");
        let err = result.unwrap().unwrap_err();
        assert!(
            matches!(err, HttpError::Overloaded),
            "Expected Overloaded error, got: {err:?}"
        );
    }

    // An `HttpError` that already went through the stack must not be re-wrapped as
    // `Transport`.
    #[test]
    fn test_map_tower_error_preserves_overloaded() {
        let http_err = HttpError::Overloaded;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Overloaded),
            "Should preserve HttpError::Overloaded, got: {result:?}"
        );
    }

    #[test]
    fn test_map_tower_error_preserves_service_closed() {
        let http_err = HttpError::ServiceClosed;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::ServiceClosed),
            "Should preserve HttpError::ServiceClosed, got: {result:?}"
        );
    }

    // An inner `HttpError::Timeout` keeps its own duration instead of being replaced by the
    // `timeout` argument, which is only used for tower's `Elapsed`.
    #[test]
    fn test_map_tower_error_preserves_timeout_attempt() {
        let original_duration = Duration::from_secs(5);
        let http_err = HttpError::Timeout(original_duration);
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        match result {
            HttpError::Timeout(d) => {
                assert_eq!(
                    d, original_duration,
                    "Should preserve original timeout duration"
                );
            }
            other => panic!("Should preserve HttpError::Timeout, got: {other:?}"),
        }
    }

    #[test]
    fn test_map_tower_error_wraps_unknown_as_transport() {
        let other_err: tower::BoxError = Box::new(std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "connection refused",
        ));
        let result = map_tower_error(other_err, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Transport(_)),
            "Should wrap unknown errors as Transport, got: {result:?}"
        );
    }

    // Aborting the task that awaits `send()` must drop the inner service's in-flight future.
    // The inner future never completes on its own, so the only way its guard fires
    // `drop_notifier` is by being dropped.
    #[tokio::test]
    async fn test_cancellation_propagates_through_full_stack() {
        use crate::response::ResponseBody;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;

        // Inner service whose response future announces that it started and then never ends.
        #[derive(Clone)]
        struct PendingService {
            completed: Arc<AtomicBool>,
            drop_notifier: Arc<tokio::sync::Notify>,
            started_notifier: Arc<tokio::sync::Notify>,
        }

        // Lives inside the inner future. On drop it signals `drop_notifier` unless the
        // future had marked itself completed.
        struct FutureGuard {
            completed: Arc<AtomicBool>,
            drop_notifier: Arc<tokio::sync::Notify>,
        }

        impl Drop for FutureGuard {
            fn drop(&mut self) {
                if !self.completed.load(Ordering::SeqCst) {
                    self.drop_notifier.notify_one();
                }
            }
        }

        impl Service<http::Request<Full<Bytes>>> for PendingService {
            type Response = http::Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: http::Request<Full<Bytes>>) -> Self::Future {
                let completed = self.completed.clone();
                let drop_notifier = self.drop_notifier.clone();
                let started_notifier = self.started_notifier.clone();
                Box::pin(async move {
                    let _guard = FutureGuard {
                        completed: completed.clone(),
                        drop_notifier,
                    };
                    started_notifier.notify_one();
                    // Never resolves: the future can only end by being dropped.
                    std::future::pending::<()>().await;
                    completed.store(true, Ordering::SeqCst);
                    unreachable!()
                })
            }
        }

        let inner_completed = Arc::new(AtomicBool::new(false));
        let drop_notifier = Arc::new(tokio::sync::Notify::new());
        let started_notifier = Arc::new(tokio::sync::Notify::new());

        let inner = PendingService {
            completed: inner_completed.clone(),
            drop_notifier: drop_notifier.clone(),
            started_notifier: started_notifier.clone(),
        };

        // Retries are disabled, so `build_with_inner_service` adds no retry layer around
        // `inner`.
        let client = HttpClientBuilder::new()
            .timeout(Duration::from_secs(30))
            .retry(None)
            .build_with_inner_service(inner.boxed_clone());

        let send_handle = tokio::spawn({
            let client = client.clone();
            async move { client.get("http://fake/slow").send().await }
        });

        // Wait until the request has actually reached the inner service.
        started_notifier.notified().await;

        send_handle.abort();

        tokio::time::timeout(Duration::from_secs(5), drop_notifier.notified())
            .await
            .expect(
                "Inner service future should have been dropped within 5s - \
                 the full modkit-http stack must propagate cancellation",
            );

        assert!(
            !inner_completed.load(Ordering::SeqCst),
            "Inner service future should NOT have completed"
        );
    }
}