1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use bytes::Bytes;
11use http_body_util::{BodyExt, Full};
12use hyper_rustls::HttpsConnector;
13use hyper_util::{
14 client::legacy::{Client, connect::HttpConnector},
15 rt::TokioExecutor,
16};
17use tower::Layer;
18use tower::util::BoxCloneService;
19use tower_service::Service;
20
21use crate::{
22 Error, Request, Response, Result,
23 config::{ClientConfig, ClientConfigBuilder},
24 connector::https_connector,
25};
26
27#[cfg(feature = "streaming")]
29use futures_util::TryStreamExt;
30#[cfg(feature = "streaming")]
31use http_body_util::BodyStream;
32#[cfg(feature = "streaming")]
33use pincer_core::StreamingBody;
34
35#[cfg(feature = "middleware-basic-auth")]
37use crate::middleware::BasicAuthLayer;
38#[cfg(feature = "middleware-bearer-auth")]
39use crate::middleware::BearerAuthLayer;
40#[cfg(feature = "middleware-decompression")]
41use crate::middleware::DecompressionLayer;
42#[cfg(feature = "middleware-follow-redirect")]
43use crate::middleware::FollowRedirectLayer;
44#[cfg(feature = "middleware-logging")]
45use crate::middleware::LoggingLayer;
46#[cfg(feature = "middleware-metrics")]
47use crate::middleware::MetricsLayer;
48#[cfg(feature = "middleware-rate-limit")]
49use crate::middleware::RateLimitLayer;
50#[cfg(feature = "middleware-retry")]
51use crate::middleware::RetryPolicy;
52#[cfg(feature = "middleware-circuit-breaker")]
53use crate::middleware::{CircuitBreakerConfig, CircuitBreakerLayer};
54#[cfg(feature = "middleware-concurrency")]
55use tower::limit::ConcurrencyLimitLayer;
56#[cfg(feature = "middleware-retry")]
57use tower::retry::RetryLayer;
58
59pub type BoxedService = BoxCloneService<Request<Bytes>, Response<Bytes>, Error>;
68
69pub type ServiceFuture = Pin<Box<dyn Future<Output = Result<Response<Bytes>>> + Send + 'static>>;
71
72#[derive(Clone)]
77struct SyncService {
78 inner: Arc<Mutex<BoxedService>>,
79}
80
81impl SyncService {
82 fn new(service: BoxedService) -> Self {
83 Self {
84 inner: Arc::new(Mutex::new(service)),
85 }
86 }
87
88 fn call(&self, request: Request<Bytes>) -> ServiceFuture {
89 let mut service = self
91 .inner
92 .lock()
93 .unwrap_or_else(std::sync::PoisonError::into_inner)
94 .clone();
95
96 Box::pin(async move { service.call(request).await })
97 }
98}
99
100#[derive(Clone)]
106struct RawHyperClient {
107 inner: Client<HttpsConnector<HttpConnector>, Full<Bytes>>,
108 config: ClientConfig,
109}
110
111impl RawHyperClient {
112 fn new(config: ClientConfig) -> Self {
113 let connector = https_connector();
114
115 let inner = Client::builder(TokioExecutor::new())
116 .pool_idle_timeout(config.pool_idle_timeout)
117 .pool_max_idle_per_host(config.pool_idle_per_host)
118 .build(connector);
119
120 Self { inner, config }
121 }
122
123 fn build_hyper_request(request: Request<Bytes>) -> Result<http::Request<Full<Bytes>>> {
125 let (method, url, headers, body, extensions) = request.into_parts();
126
127 let mut builder = http::Request::builder()
128 .method(http::Method::from(method))
129 .uri(url.as_str());
130
131 for (name, value) in &headers {
132 builder = builder.header(name.as_str(), value.as_str());
133 }
134
135 let body = body.map_or_else(Full::default, Full::new);
136 let mut http_request = builder
137 .body(body)
138 .map_err(|e| Error::invalid_request(e.to_string()))?;
139
140 *http_request.extensions_mut() = extensions;
142
143 Ok(http_request)
144 }
145
146 fn extract_headers(headers: &http::HeaderMap) -> HashMap<String, String> {
148 headers
149 .iter()
150 .filter_map(|(name, value)| {
151 value
152 .to_str()
153 .ok()
154 .map(|v| (name.to_string(), v.to_string()))
155 })
156 .collect()
157 }
158
159 async fn execute(&self, request: Request<Bytes>) -> Result<Response<Bytes>> {
160 let hyper_request = Self::build_hyper_request(request)?;
161
162 let response = tokio::time::timeout(self.config.timeout, self.inner.request(hyper_request))
163 .await
164 .map_err(|_| Error::Timeout)?
165 .map_err(Self::map_hyper_error)?;
166
167 let status = response.status().as_u16();
168 let response_headers = Self::extract_headers(response.headers());
169
170 let body = response
171 .into_body()
172 .collect()
173 .await
174 .map_err(|e| Error::connection(e.to_string()))?
175 .to_bytes();
176
177 Ok(Response::new(status, response_headers, body))
178 }
179
180 #[allow(clippy::needless_pass_by_value)]
181 fn map_hyper_error(err: hyper_util::client::legacy::Error) -> Error {
182 let msg = err.to_string();
183
184 if err.is_connect() {
185 return Error::connection(msg);
186 }
187
188 if msg.contains("ssl") || msg.contains("tls") || msg.contains("certificate") {
189 return Error::tls(msg);
190 }
191
192 Error::connection(msg)
193 }
194
195 #[cfg(feature = "streaming")]
197 async fn execute_streaming(
198 &self,
199 request: Request<Bytes>,
200 ) -> Result<pincer_core::StreamingResponse> {
201 let hyper_request = Self::build_hyper_request(request)?;
202
203 let response = tokio::time::timeout(self.config.timeout, self.inner.request(hyper_request))
204 .await
205 .map_err(|_| Error::Timeout)?
206 .map_err(Self::map_hyper_error)?;
207
208 let status = response.status().as_u16();
209 let response_headers = Self::extract_headers(response.headers());
210
211 let body_stream = BodyStream::new(response.into_body());
212 let streaming_body: StreamingBody = Box::pin(
213 body_stream
214 .map_ok(|frame| frame.into_data().unwrap_or_default())
215 .map_err(|e| Error::connection(e.to_string())),
216 );
217
218 Ok(pincer_core::StreamingResponse::new(
219 status,
220 response_headers,
221 streaming_body,
222 ))
223 }
224}
225
226impl Service<Request<Bytes>> for RawHyperClient {
227 type Response = Response<Bytes>;
228 type Error = Error;
229 type Future = Pin<Box<dyn Future<Output = Result<Self::Response>> + Send + 'static>>;
230
231 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
232 Poll::Ready(Ok(()))
233 }
234
235 fn call(&mut self, request: Request<Bytes>) -> Self::Future {
236 let client = self.clone();
237 Box::pin(async move { client.execute(request).await })
238 }
239}
240
241#[derive(Clone)]
263pub struct HyperClient {
264 service: SyncService,
265 config: ClientConfig,
266}
267
268impl std::fmt::Debug for HyperClient {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.debug_struct("HyperClient")
271 .field("config", &self.config)
272 .finish_non_exhaustive()
273 }
274}
275
276impl HyperClient {
277 #[must_use]
279 pub fn new() -> Self {
280 Self::with_config(ClientConfig::default())
281 }
282
283 #[must_use]
285 pub fn with_config(config: ClientConfig) -> Self {
286 let raw = RawHyperClient::new(config.clone());
287 Self {
288 service: SyncService::new(BoxCloneService::new(raw)),
289 config,
290 }
291 }
292
293 fn with_config_raw(config: ClientConfig) -> RawHyperClient {
295 RawHyperClient::new(config)
296 }
297
298 fn with_service(service: BoxedService, config: ClientConfig) -> Self {
300 Self {
301 service: SyncService::new(service),
302 config,
303 }
304 }
305
306 #[must_use]
308 pub fn builder() -> HyperClientBuilder {
309 HyperClientBuilder::default()
310 }
311
312 #[must_use]
314 pub const fn config(&self) -> &ClientConfig {
315 &self.config
316 }
317}
318
319impl Default for HyperClient {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl pincer_core::HttpClient for HyperClient {
326 async fn execute(&self, request: Request<Bytes>) -> Result<Response<Bytes>> {
327 self.service.call(request).await
328 }
329}
330
331#[cfg(feature = "streaming")]
336impl pincer_core::HttpClientStreaming for HyperClient {
337 async fn execute_streaming(
338 &self,
339 request: Request<Bytes>,
340 ) -> Result<pincer_core::StreamingResponse> {
341 let raw_client = RawHyperClient::new(self.config.clone());
343 raw_client.execute_streaming(request).await
344 }
345}
346
347impl Service<Request<Bytes>> for HyperClient {
352 type Response = Response<Bytes>;
353 type Error = Error;
354 type Future = ServiceFuture;
355
356 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
357 Poll::Ready(Ok(()))
359 }
360
361 fn call(&mut self, request: Request<Bytes>) -> Self::Future {
362 self.service.call(request)
363 }
364}
365
366#[derive(Default)]
389pub struct HyperClientBuilder {
390 config: ClientConfigBuilder,
391 layers: Vec<Arc<dyn Fn(BoxedService) -> BoxedService + Send + Sync>>,
392 use_defaults: bool,
393}
394
395impl std::fmt::Debug for HyperClientBuilder {
396 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 f.debug_struct("HyperClientBuilder")
398 .field("config", &self.config)
399 .field("layers_count", &self.layers.len())
400 .field("use_defaults", &self.use_defaults)
401 .finish()
402 }
403}
404
405impl HyperClientBuilder {
406 #[must_use]
412 pub fn timeout(mut self, timeout: Duration) -> Self {
413 self.config = self.config.timeout(timeout);
414 self
415 }
416
417 #[must_use]
419 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
420 self.config = self.config.connect_timeout(timeout);
421 self
422 }
423
424 #[must_use]
426 pub fn pool_idle_per_host(mut self, count: usize) -> Self {
427 self.config = self.config.pool_idle_per_host(count);
428 self
429 }
430
431 #[must_use]
433 pub fn pool_idle_timeout(mut self, timeout: Duration) -> Self {
434 self.config = self.config.pool_idle_timeout(timeout);
435 self
436 }
437
438 #[must_use]
458 pub fn layer<L>(mut self, layer: L) -> Self
459 where
460 L: Layer<BoxedService> + Send + Sync + 'static,
461 L::Service: Service<Request<Bytes>, Response = Response<Bytes>, Error = Error>
462 + Clone
463 + Send
464 + 'static,
465 <L::Service as Service<Request<Bytes>>>::Future: Send,
466 {
467 self.layers.push(Arc::new(move |service| {
468 BoxCloneService::new(layer.layer(service))
469 }));
470 self
471 }
472
473 #[must_use]
477 pub fn with<L>(self, layer: L) -> Self
478 where
479 L: Layer<BoxedService> + Send + Sync + 'static,
480 L::Service: Service<Request<Bytes>, Response = Response<Bytes>, Error = Error>
481 + Clone
482 + Send
483 + 'static,
484 <L::Service as Service<Request<Bytes>>>::Future: Send,
485 {
486 self.layer(layer)
487 }
488
489 #[must_use]
501 pub fn with_defaults(mut self) -> Self {
502 self.use_defaults = true;
503 self
504 }
505
506 #[must_use]
508 pub fn without_defaults(mut self) -> Self {
509 self.use_defaults = false;
510 self
511 }
512
513 #[cfg(feature = "middleware-retry")]
529 #[must_use]
530 pub fn with_retry(self, max_retries: u32) -> Self {
531 self.layer(RetryLayer::new(RetryPolicy::new(max_retries)))
532 }
533
534 #[cfg(feature = "middleware-bearer-auth")]
544 #[must_use]
545 pub fn with_bearer_auth(self, token: impl Into<String>) -> Self {
546 self.layer(BearerAuthLayer::new(token))
547 }
548
549 #[cfg(feature = "middleware-basic-auth")]
559 #[must_use]
560 pub fn with_basic_auth(self, username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
561 self.layer(BasicAuthLayer::new(username, password))
562 }
563
564 #[cfg(feature = "middleware-logging")]
574 #[must_use]
575 pub fn with_logging(self) -> Self {
576 self.layer(LoggingLayer::new())
577 }
578
579 #[cfg(feature = "middleware-logging")]
581 #[must_use]
582 pub fn with_debug_logging(self) -> Self {
583 self.layer(LoggingLayer::debug())
584 }
585
586 #[cfg(feature = "middleware-concurrency")]
596 #[must_use]
597 pub fn with_concurrency_limit(self, max: usize) -> Self {
598 self.layer(ConcurrencyLimitLayer::new(max))
599 }
600
601 #[cfg(feature = "middleware-rate-limit")]
611 #[must_use]
612 pub fn with_rate_limit_per_second(self, count: u32) -> Self {
613 self.layer(RateLimitLayer::per_second(count))
614 }
615
616 #[cfg(feature = "middleware-rate-limit")]
626 #[must_use]
627 pub fn with_rate_limit_per_minute(self, count: u32) -> Self {
628 self.layer(RateLimitLayer::per_minute(count))
629 }
630
631 #[cfg(feature = "middleware-circuit-breaker")]
643 #[must_use]
644 pub fn with_circuit_breaker(self) -> Self {
645 self.layer(CircuitBreakerLayer::new(CircuitBreakerConfig::default()))
646 }
647
648 #[cfg(feature = "middleware-circuit-breaker")]
665 #[must_use]
666 pub fn with_circuit_breaker_config(
667 self,
668 config: crate::middleware::CircuitBreakerConfig,
669 ) -> Self {
670 self.layer(CircuitBreakerLayer::new(config))
671 }
672
673 #[cfg(feature = "middleware-metrics")]
688 #[must_use]
689 pub fn with_metrics(self) -> Self {
690 self.layer(MetricsLayer::new())
691 }
692
693 #[cfg(feature = "middleware-follow-redirect")]
706 #[must_use]
707 pub fn with_follow_redirects(self) -> Self {
708 self.layer(FollowRedirectLayer::new())
709 }
710
711 #[cfg(feature = "middleware-follow-redirect")]
721 #[must_use]
722 pub fn with_follow_redirects_max(self, max_redirects: usize) -> Self {
723 self.layer(FollowRedirectLayer::with_max_redirects(max_redirects))
724 }
725
726 #[cfg(feature = "middleware-decompression")]
740 #[must_use]
741 pub fn with_decompression(self) -> Self {
742 self.layer(DecompressionLayer::new())
743 }
744
745 #[must_use]
751 pub fn build(self) -> HyperClient {
752 let config = self.config.build();
753 let base_client = HyperClient::with_config_raw(config.clone());
754
755 let mut service: BoxedService = BoxCloneService::new(base_client);
757
758 if self.use_defaults {
760 #[cfg(feature = "middleware-logging")]
761 {
762 service = BoxCloneService::new(LoggingLayer::new().layer(service));
763 }
764 }
765
766 for layer_fn in self.layers {
768 service = layer_fn(service);
769 }
770
771 HyperClient::with_service(service, config)
772 }
773}
774
#[cfg(test)]
mod tests {
    use super::*;

    /// A default client uses a 30-second timeout.
    #[test]
    fn client_default() {
        let timeout = HyperClient::new().config().timeout;
        assert_eq!(timeout, std::time::Duration::from_secs(30));
    }

    /// Builder setters are reflected in the built client's config.
    #[test]
    fn client_builder() {
        let built = HyperClient::builder()
            .timeout(std::time::Duration::from_secs(60))
            .pool_idle_per_host(16)
            .build();
        let cfg = built.config();

        assert_eq!(cfg.timeout, std::time::Duration::from_secs(60));
        assert_eq!(cfg.pool_idle_per_host, 16);
    }

    /// `HyperClient` implements `Clone`.
    #[test]
    fn client_is_clone() {
        let original = HyperClient::new();
        let _copy = original.clone();
    }

    /// `HyperClient`'s `Debug` output names the type.
    #[test]
    fn client_is_debug() {
        let rendered = format!("{:?}", HyperClient::new());
        assert!(rendered.contains("HyperClient"));
    }
}