1use crate::config::{
2 HttpClientConfig, RedirectConfig, RetryConfig, TlsRootConfig, TransportSecurity,
3};
4use crate::error::HttpError;
5use crate::layers::{OtelLayer, RetryLayer, SecureRedirectPolicy, UserAgentLayer};
6use crate::response::ResponseBody;
7use crate::tls;
8use bytes::Bytes;
9use http::Response;
10use http_body_util::{BodyExt, Full};
11use hyper_rustls::HttpsConnector;
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::rt::{TokioExecutor, TokioTimer};
15use std::time::Duration;
16use tower::buffer::Buffer;
17use tower::limit::ConcurrencyLimitLayer;
18use tower::load_shed::LoadShedLayer;
19use tower::timeout::TimeoutLayer;
20use tower::util::BoxCloneService;
21use tower::{ServiceBuilder, ServiceExt};
22use tower_http::decompression::DecompressionLayer;
23use tower_http::follow_redirect::FollowRedirectLayer;
24
25type InnerService =
27 BoxCloneService<http::Request<Full<Bytes>>, http::Response<ResponseBody>, HttpError>;
28
29pub struct HttpClientBuilder {
31 config: HttpClientConfig,
32 auth_layer: Option<Box<dyn FnOnce(InnerService) -> InnerService + Send>>,
33}
34
35impl HttpClientBuilder {
36 #[must_use]
38 pub fn new() -> Self {
39 Self {
40 config: HttpClientConfig::default(),
41 auth_layer: None,
42 }
43 }
44
45 #[must_use]
47 pub fn with_config(config: HttpClientConfig) -> Self {
48 Self {
49 config,
50 auth_layer: None,
51 }
52 }
53
54 #[must_use]
59 pub fn timeout(mut self, timeout: Duration) -> Self {
60 self.config.request_timeout = timeout;
61 self
62 }
63
64 #[must_use]
70 pub fn total_timeout(mut self, timeout: Duration) -> Self {
71 self.config.total_timeout = Some(timeout);
72 self
73 }
74
75 #[must_use]
77 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
78 self.config.user_agent = user_agent.into();
79 self
80 }
81
82 #[must_use]
84 pub fn retry(mut self, retry: Option<RetryConfig>) -> Self {
85 self.config.retry = retry;
86 self
87 }
88
89 #[must_use]
91 pub fn max_body_size(mut self, size: usize) -> Self {
92 self.config.max_body_size = size;
93 self
94 }
95
96 #[must_use]
100 pub fn transport(mut self, transport: TransportSecurity) -> Self {
101 self.config.transport = transport;
102 self
103 }
104
105 #[must_use]
111 pub fn deny_insecure_http(mut self) -> Self {
112 tracing::debug!(
113 target: "modkit_http::security",
114 "deny_insecure_http() called - enforcing TLS for all connections"
115 );
116 self.config.transport = TransportSecurity::TlsOnly;
117 self
118 }
119
120 #[must_use]
125 pub fn with_otel(mut self) -> Self {
126 self.config.otel = true;
127 self
128 }
129
130 #[must_use]
138 pub fn with_auth_layer(
139 mut self,
140 wrap: impl FnOnce(InnerService) -> InnerService + Send + 'static,
141 ) -> Self {
142 self.auth_layer = Some(Box::new(wrap));
143 self
144 }
145
146 #[must_use]
155 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
156 self.config.buffer_capacity = capacity.max(1);
158 self
159 }
160
161 #[must_use]
166 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
167 self.config.redirect.max_redirects = max_redirects;
168 self
169 }
170
171 #[must_use]
176 pub fn no_redirects(mut self) -> Self {
177 self.config.redirect = RedirectConfig::disabled();
178 self
179 }
180
181 #[must_use]
196 pub fn redirect(mut self, config: RedirectConfig) -> Self {
197 self.config.redirect = config;
198 self
199 }
200
201 #[must_use]
208 pub fn pool_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
209 self.config.pool_idle_timeout = timeout;
210 self
211 }
212
213 #[must_use]
221 pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
222 self.config.pool_max_idle_per_host = max;
223 self
224 }
225
226 pub fn build(self) -> Result<crate::HttpClient, HttpError> {
231 let timeout = self.config.request_timeout;
232 let total_timeout = self.config.total_timeout;
233
234 let https = build_https_connector(self.config.tls_roots, self.config.transport)?;
236
237 let mut client_builder = Client::builder(TokioExecutor::new());
239
240 client_builder
243 .pool_timer(TokioTimer::new())
244 .pool_max_idle_per_host(self.config.pool_max_idle_per_host)
245 .http2_only(false); if let Some(idle_timeout) = self.config.pool_idle_timeout {
249 client_builder.pool_idle_timeout(idle_timeout);
250 }
251
252 let hyper_client = client_builder.build::<_, Full<Bytes>>(https);
253
254 let ua_layer = UserAgentLayer::try_new(&self.config.user_agent)?;
256
257 let redirect_policy = SecureRedirectPolicy::new(self.config.redirect.clone());
289
290 let service = ServiceBuilder::new()
292 .layer(TimeoutLayer::new(timeout))
293 .layer(ua_layer)
294 .layer(DecompressionLayer::new())
295 .layer(FollowRedirectLayer::with_policy(redirect_policy))
296 .service(hyper_client);
297
298 let service = service.map_response(map_decompression_response);
304
305 let service = service.map_err(move |e: tower::BoxError| map_tower_error(e, timeout));
307
308 let mut boxed_service = service.boxed_clone();
310
311 if let Some(wrap) = self.auth_layer {
314 boxed_service = wrap(boxed_service);
315 }
316
317 if let Some(ref retry_config) = self.config.retry {
329 let retry_layer = RetryLayer::with_total_timeout(retry_config.clone(), total_timeout);
330 let retry_service = ServiceBuilder::new()
331 .layer(retry_layer)
332 .service(boxed_service);
333 boxed_service = retry_service.boxed_clone();
334 }
335
336 if let Some(rate_limit) = self.config.rate_limit
340 && rate_limit.max_concurrent_requests < usize::MAX
341 {
342 let limited_service = ServiceBuilder::new()
343 .layer(LoadShedLayer::new())
344 .layer(ConcurrencyLimitLayer::new(
345 rate_limit.max_concurrent_requests,
346 ))
347 .service(boxed_service);
348 let limited_service = limited_service.map_err(map_load_shed_error);
350 boxed_service = limited_service.boxed_clone();
351 }
352
353 if self.config.otel {
357 let otel_service = ServiceBuilder::new()
358 .layer(OtelLayer::new())
359 .service(boxed_service);
360 boxed_service = otel_service.boxed_clone();
361 }
362
363 let buffer_capacity = self.config.buffer_capacity.max(1);
367 let buffered_service: crate::client::BufferedService =
368 Buffer::new(boxed_service, buffer_capacity);
369
370 Ok(crate::HttpClient {
371 service: buffered_service,
372 max_body_size: self.config.max_body_size,
373 transport_security: self.config.transport,
374 })
375 }
376}
377
378impl Default for HttpClientBuilder {
379 fn default() -> Self {
380 Self::new()
381 }
382}
383
384fn map_tower_error(err: tower::BoxError, timeout: Duration) -> HttpError {
390 if err.is::<tower::timeout::error::Elapsed>() {
391 return HttpError::Timeout(timeout);
392 }
393
394 match err.downcast::<HttpError>() {
396 Ok(http_err) => *http_err,
397 Err(other) => HttpError::Transport(other),
398 }
399}
400
401fn map_load_shed_error(err: tower::BoxError) -> HttpError {
403 if err.is::<tower::load_shed::error::Overloaded>() {
404 HttpError::Overloaded
405 } else {
406 match err.downcast::<HttpError>() {
408 Ok(http_err) => *http_err,
409 Err(err) => HttpError::Transport(err),
410 }
411 }
412}
413
414fn map_decompression_response<B>(response: Response<B>) -> Response<ResponseBody>
419where
420 B: hyper::body::Body<Data = Bytes> + Send + Sync + 'static,
421 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
422{
423 let (parts, body) = response.into_parts();
424 let boxed_body: ResponseBody = body.map_err(Into::into).boxed();
428 Response::from_parts(parts, boxed_body)
429}
430
431fn build_https_connector(
445 tls_roots: TlsRootConfig,
446 transport: TransportSecurity,
447) -> Result<HttpsConnector<HttpConnector>, HttpError> {
448 let allow_http = transport == TransportSecurity::AllowInsecureHttp;
449
450 match tls_roots {
451 TlsRootConfig::WebPki => {
452 let provider = tls::get_crypto_provider();
453 let builder = hyper_rustls::HttpsConnectorBuilder::new()
454 .with_provider_and_webpki_roots(provider)
455 .map_err(|e| HttpError::Tls(Box::new(e)))?;
458 let connector = if allow_http {
459 builder.https_or_http().enable_all_versions().build()
460 } else {
461 builder.https_only().enable_all_versions().build()
462 };
463 Ok(connector)
464 }
465 TlsRootConfig::Native => {
466 let client_config = tls::native_roots_client_config()
467 .map_err(|e| HttpError::Tls(e.into()))?;
469 let builder = hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(client_config);
470 let connector = if allow_http {
471 builder.https_or_http().enable_all_versions().build()
472 } else {
473 builder.https_only().enable_all_versions().build()
474 };
475 Ok(connector)
476 }
477 }
478}
479
// Unit tests for `HttpClientBuilder` and the error-mapping helpers in this file.
// Tests that call `build()` use `#[tokio::test]`: `build()` creates a tower `Buffer`,
// which spawns its worker task on the current Tokio runtime.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::config::DEFAULT_USER_AGENT;

    // Pins the defaults a fresh builder inherits from `HttpClientConfig::default()`.
    #[test]
    fn test_builder_default() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.request_timeout, Duration::from_secs(30));
        assert_eq!(builder.config.user_agent, DEFAULT_USER_AGENT);
        assert!(builder.config.retry.is_some());
        assert_eq!(builder.config.buffer_capacity, 1024);
    }

    // `with_config` keeps the supplied config; the minimal preset uses a 10s request timeout.
    #[test]
    fn test_builder_with_config() {
        let config = HttpClientConfig::minimal();
        let builder = HttpClientBuilder::with_config(config);
        assert_eq!(builder.config.request_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_timeout() {
        let builder = HttpClientBuilder::new().timeout(Duration::from_secs(60));
        assert_eq!(builder.config.request_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder_user_agent() {
        let builder = HttpClientBuilder::new().user_agent("custom/1.0");
        assert_eq!(builder.config.user_agent, "custom/1.0");
    }

    // Passing `None` disables retries.
    #[test]
    fn test_builder_retry() {
        let builder = HttpClientBuilder::new().retry(None);
        assert!(builder.config.retry.is_none());
    }

    #[test]
    fn test_builder_max_body_size() {
        let builder = HttpClientBuilder::new().max_body_size(1024);
        assert_eq!(builder.config.max_body_size, 1024);
    }

    // `transport(TlsOnly)` and `deny_insecure_http()` are equivalent;
    // a default builder allows plain HTTP.
    #[test]
    fn test_builder_transport_security() {
        let builder = HttpClientBuilder::new().transport(TransportSecurity::TlsOnly);
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new().deny_insecure_http();
        assert_eq!(builder.config.transport, TransportSecurity::TlsOnly);

        let builder = HttpClientBuilder::new();
        assert_eq!(
            builder.config.transport,
            TransportSecurity::AllowInsecureHttp
        );
    }

    #[test]
    fn test_builder_otel() {
        let builder = HttpClientBuilder::new().with_otel();
        assert!(builder.config.otel);
    }

    #[test]
    fn test_builder_buffer_capacity() {
        let builder = HttpClientBuilder::new().buffer_capacity(512);
        assert_eq!(builder.config.buffer_capacity, 512);
    }

    // The `buffer_capacity` setter clamps 0 up to 1.
    #[test]
    fn test_builder_buffer_capacity_zero_clamped() {
        let builder = HttpClientBuilder::new().buffer_capacity(0);
        assert_eq!(
            builder.config.buffer_capacity, 1,
            "buffer_capacity=0 should be clamped to 1"
        );
    }

    // A zero capacity that bypasses the setter (via `with_config`) must still build,
    // because `build()` clamps it again before creating the `Buffer`.
    #[tokio::test]
    async fn test_builder_buffer_capacity_zero_in_config_clamped() {
        let config = HttpClientConfig {
            buffer_capacity: 0,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();
        assert!(
            result.is_ok(),
            "build() should succeed with capacity clamped to 1"
        );
    }

    #[tokio::test]
    async fn test_builder_build_with_otel() {
        let client = HttpClientBuilder::new().with_otel().build();
        assert!(client.is_ok());
    }

    // An identity wrapper exercises the auth-layer hook without changing behavior.
    #[tokio::test]
    async fn test_builder_with_auth_layer() {
        let client = HttpClientBuilder::new()
            .with_auth_layer(|svc| svc)
            .build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build() {
        let client = HttpClientBuilder::new().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_deny_insecure_http() {
        let client = HttpClientBuilder::new().deny_insecure_http().build();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_builder_build_with_sse_config() {
        use crate::config::HttpClientConfig;
        let config = HttpClientConfig::sse();
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok(), "SSE config should build successfully");
    }

    // A NUL byte cannot appear in a header value, so `build()` must reject this user agent.
    #[tokio::test]
    async fn test_builder_build_invalid_user_agent() {
        let client = HttpClientBuilder::new()
            .user_agent("invalid\x00agent")
            .build();
        assert!(client.is_err());
    }

    #[tokio::test]
    async fn test_builder_default_uses_webpki_roots() {
        let builder = HttpClientBuilder::new();
        assert_eq!(builder.config.tls_roots, TlsRootConfig::WebPki);
        let client = builder.build();
        assert!(client.is_ok());
    }

    // Native roots depend on the host. A failure is tolerated only if it is a TLS error
    // whose message mentions certificates / native roots.
    #[tokio::test]
    async fn test_builder_native_roots() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            ..Default::default()
        };
        let result = HttpClientBuilder::with_config(config).build();

        match &result {
            Ok(_) => {
                // Native root store loaded; nothing more to check.
            }
            Err(HttpError::Tls(err)) => {
                let msg = err.to_string();
                assert!(
                    msg.contains("native root") || msg.contains("certificate"),
                    "TLS error should mention certificates: {msg}"
                );
            }
            Err(other) => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[tokio::test]
    async fn test_builder_webpki_roots_https_only() {
        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::WebPki,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(client.is_ok());
    }

    // Builds every (root store x transport) combination. Each must succeed with the
    // connector's `enable_all_versions()` path.
    #[tokio::test]
    async fn test_http2_enabled_for_all_configurations() {
        let client = HttpClientBuilder::new().build();
        assert!(
            client.is_ok(),
            "WebPki + AllowInsecureHttp should build with HTTP/2 enabled"
        );

        let client = HttpClientBuilder::new()
            .transport(TransportSecurity::TlsOnly)
            .build();
        assert!(
            client.is_ok(),
            "WebPki + TlsOnly should build with HTTP/2 enabled"
        );

        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            transport: TransportSecurity::AllowInsecureHttp,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(
            client.is_ok(),
            "Native + AllowInsecureHttp should build with HTTP/2 enabled"
        );

        let config = HttpClientConfig {
            tls_roots: TlsRootConfig::Native,
            transport: TransportSecurity::TlsOnly,
            ..Default::default()
        };
        let client = HttpClientBuilder::with_config(config).build();
        assert!(
            client.is_ok(),
            "Native + TlsOnly should build with HTTP/2 enabled"
        );
    }

    // Reproduces the LoadShed + ConcurrencyLimit(1) stack from `build()` over a service
    // whose calls never complete. A second request must fail fast with
    // `HttpError::Overloaded` instead of waiting for a permit.
    #[tokio::test]
    async fn test_load_shedding_returns_overloaded_error() {
        use bytes::Bytes;
        use http::{Request, Response};
        use http_body_util::Full;
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::task::{Context, Poll};
        use tower::Service;
        use tower::ServiceExt;

        // Counts calls and returns a future that never resolves, so each accepted request
        // keeps its concurrency permit for as long as its future is alive.
        #[derive(Clone)]
        struct SlotHoldingService {
            active: Arc<AtomicUsize>,
        }

        impl Service<Request<Full<Bytes>>> for SlotHoldingService {
            type Response = Response<Full<Bytes>>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: Request<Full<Bytes>>) -> Self::Future {
                self.active.fetch_add(1, Ordering::SeqCst);
                Box::pin(std::future::pending())
            }
        }

        let active = Arc::new(AtomicUsize::new(0));

        let service = tower::ServiceBuilder::new()
            .layer(LoadShedLayer::new())
            .layer(ConcurrencyLimitLayer::new(1))
            .service(SlotHoldingService {
                active: active.clone(),
            });

        let service = service.map_err(map_load_shed_error);

        let req1 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let mut svc1 = service.clone();

        // Keep the first response future alive (bound, never dropped) so the single
        // permit stays occupied for the rest of the test.
        let svc1_ready = svc1.ready().await.unwrap();
        let _pending_fut = svc1_ready.call(req1);

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(
            active.load(Ordering::SeqCst),
            1,
            "First request should be active"
        );

        let req2 = Request::builder()
            .uri("http://test")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let mut svc2 = service.clone();

        // The outer timeout only guards against a hang; load shedding should answer immediately.
        let result = tokio::time::timeout(Duration::from_millis(100), async {
            match svc2.ready().await {
                Ok(ready_svc) => ready_svc.call(req2).await,
                Err(e) => Err(e),
            }
        })
        .await;

        assert!(result.is_ok(), "Request should not hang");
        let err = result.unwrap().unwrap_err();
        assert!(
            matches!(err, HttpError::Overloaded),
            "Expected Overloaded error, got: {err:?}"
        );
    }

    // `map_tower_error` must pass through `HttpError`s that inner layers already boxed.
    #[test]
    fn test_map_tower_error_preserves_overloaded() {
        let http_err = HttpError::Overloaded;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Overloaded),
            "Should preserve HttpError::Overloaded, got: {result:?}"
        );
    }

    #[test]
    fn test_map_tower_error_preserves_service_closed() {
        let http_err = HttpError::ServiceClosed;
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::ServiceClosed),
            "Should preserve HttpError::ServiceClosed, got: {result:?}"
        );
    }

    // An already-mapped `HttpError::Timeout` keeps its own duration rather than being
    // replaced by the configured request timeout passed to `map_tower_error`.
    #[test]
    fn test_map_tower_error_preserves_timeout_attempt() {
        let original_duration = Duration::from_secs(5);
        let http_err = HttpError::Timeout(original_duration);
        let boxed: tower::BoxError = Box::new(http_err);
        let result = map_tower_error(boxed, Duration::from_secs(30));

        match result {
            HttpError::Timeout(d) => {
                assert_eq!(
                    d, original_duration,
                    "Should preserve original timeout duration"
                );
            }
            other => panic!("Should preserve HttpError::Timeout, got: {other:?}"),
        }
    }

    // Errors that are neither `Elapsed` nor `HttpError` become `HttpError::Transport`.
    #[test]
    fn test_map_tower_error_wraps_unknown_as_transport() {
        let other_err: tower::BoxError = Box::new(std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "connection refused",
        ));
        let result = map_tower_error(other_err, Duration::from_secs(30));

        assert!(
            matches!(result, HttpError::Transport(_)),
            "Should wrap unknown errors as Transport, got: {result:?}"
        );
    }
}
879}