hyperdriver/client/
builder.rs

1#[cfg(feature = "tls")]
2use std::sync::Arc;
3use std::time::Duration;
4
5use http::HeaderValue;
6use http_body::Body;
7#[cfg(feature = "tls")]
8use rustls::ClientConfig;
9use tower::layer::util::{Identity, Stack};
10use tower::ServiceBuilder;
11use tower_http::follow_redirect::policy;
12use tower_http::follow_redirect::FollowRedirectLayer;
13use tower_http::set_header::SetRequestHeaderLayer;
14
15use super::conn::protocol::auto;
16use super::conn::transport::tcp::TcpTransportConfig;
17use super::conn::transport::TransportExt;
18use super::conn::Connection;
19use super::conn::Protocol;
20use super::conn::Transport;
21use super::pool::{PoolableConnection, PoolableStream, UriKey};
22use super::ConnectionPoolLayer;
23use crate::service::RequestExecutor;
24use crate::service::{Http1ChecksLayer, Http2ChecksLayer, SetHostHeaderLayer};
25use crate::BoxError;
26
27use crate::client::conn::connection::ConnectionError;
28#[cfg(feature = "tls")]
29use crate::client::default_tls_config;
30use crate::client::{conn::protocol::auto::HttpConnectionBuilder, Client};
31use crate::info::HasConnectionInfo;
32use crate::service::IncomingResponseLayer;
33use crate::service::OptionLayerExt;
34use crate::service::SharedService;
35use crate::service::TimeoutLayer;
36
37pub trait BuildProtocol<IO, B>
38where
39    IO: HasConnectionInfo,
40{
41    type Target: Protocol<IO, B>;
42    fn build(self) -> Self::Target;
43}
44
45impl<P, IO, B> BuildProtocol<IO, B> for P
46where
47    P: Protocol<IO, B>,
48    IO: HasConnectionInfo,
49{
50    type Target = P;
51    fn build(self) -> Self::Target {
52        self
53    }
54}
55
56pub trait BuildTransport {
57    type Target: Transport;
58    fn build(self) -> Self::Target;
59}
60
61impl<T> BuildTransport for T
62where
63    T: Transport,
64{
65    type Target = T;
66    fn build(self) -> Self::Target {
67        self
68    }
69}
70
71/// A builder for a client.
72#[derive(Debug)]
73pub struct Builder<T, P, RP = policy::Standard, S = Identity, BIn = crate::Body, BOut = crate::Body>
74{
75    transport: T,
76    protocol: P,
77    builder: ServiceBuilder<S>,
78    user_agent: Option<String>,
79    redirect: Option<RP>,
80    timeout: Option<Duration>,
81    #[cfg(feature = "tls")]
82    tls: Option<ClientConfig>,
83    pool: Option<crate::client::pool::Config>,
84    body: std::marker::PhantomData<fn(BIn) -> BOut>,
85}
86
87impl Builder<(), (), policy::Standard> {
88    /// Create a new, empty builder
89    pub fn new() -> Self {
90        Self {
91            transport: (),
92            protocol: (),
93            builder: ServiceBuilder::new(),
94            user_agent: None,
95            redirect: None,
96            timeout: None,
97            #[cfg(feature = "tls")]
98            tls: None,
99            pool: None,
100            body: std::marker::PhantomData,
101        }
102    }
103}
104
105impl Default
106    for Builder<
107        TcpTransportConfig,
108        HttpConnectionBuilder<crate::Body>,
109        policy::Standard,
110        Identity,
111        crate::Body,
112        crate::Body,
113    >
114{
115    fn default() -> Self {
116        Self {
117            transport: Default::default(),
118            protocol: Default::default(),
119            builder: ServiceBuilder::new(),
120            user_agent: None,
121            redirect: Some(policy::Standard::default()),
122            timeout: Some(Duration::from_secs(30)),
123            #[cfg(feature = "tls")]
124            tls: Some(default_tls_config()),
125            pool: Some(Default::default()),
126            body: std::marker::PhantomData,
127        }
128    }
129}
130
131impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
132    /// Use the provided TCP configuration.
133    pub fn with_tcp(
134        self,
135        config: TcpTransportConfig,
136    ) -> Builder<TcpTransportConfig, P, RP, S, BIn, BOut> {
137        Builder {
138            transport: config,
139            protocol: self.protocol,
140            builder: self.builder,
141            user_agent: self.user_agent,
142            redirect: self.redirect,
143            timeout: self.timeout,
144            #[cfg(feature = "tls")]
145            tls: self.tls,
146            pool: self.pool,
147            body: self.body,
148        }
149    }
150
151    /// Provide a custom transport
152    pub fn with_transport<T2>(self, transport: T2) -> Builder<T2, P, RP, S, BIn, BOut> {
153        Builder {
154            transport,
155            protocol: self.protocol,
156            builder: self.builder,
157            user_agent: self.user_agent,
158            redirect: self.redirect,
159            timeout: self.timeout,
160            #[cfg(feature = "tls")]
161            tls: self.tls,
162            pool: self.pool,
163            body: self.body,
164        }
165    }
166
167    /// Get a mutable reference to the transport configuration
168    pub fn transport(&mut self) -> &mut T {
169        &mut self.transport
170    }
171}
172
#[cfg(feature = "tls")]
impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
    /// Turn TLS off; the transport will be built without a TLS configuration.
    pub fn without_tls(self) -> Self {
        Self { tls: None, ..self }
    }

    /// Use `config` as the TLS configuration.
    pub fn with_tls(self, config: ClientConfig) -> Self {
        Self {
            tls: Some(config),
            ..self
        }
    }

    /// Use the crate's default TLS configuration (`default_tls_config`).
    pub fn with_default_tls(self) -> Self {
        self.with_tls(default_tls_config())
    }

    /// Mutable access to the TLS configuration; `None` means TLS is disabled.
    pub fn tls(&mut self) -> &mut Option<ClientConfig> {
        &mut self.tls
    }
}
198
199#[cfg(not(feature = "tls"))]
200impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
201    /// Disable TLS
202    pub fn without_tls(self) -> Self {
203        self
204    }
205}
206
207impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
208    /// Connection pool configuration.
209    pub fn pool(&mut self) -> Option<&mut crate::client::pool::Config> {
210        self.pool.as_mut()
211    }
212
213    /// Use the provided connection pool configuration.
214    pub fn with_pool(mut self, pool: crate::client::pool::Config) -> Self {
215        self.pool = Some(pool);
216        self
217    }
218
219    /// Configure the default pool settings
220    pub fn with_default_pool(mut self) -> Self {
221        self.pool = Some(Default::default());
222        self
223    }
224
225    /// Disable connection pooling.
226    pub fn without_pool(mut self) -> Self {
227        self.pool = None;
228        self
229    }
230}
231
232impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
233    /// Use the auto-HTTP Protocol
234    pub fn with_auto_http(self) -> Builder<T, auto::HttpConnectionBuilder<BIn>, RP, S, BIn, BOut> {
235        Builder {
236            transport: self.transport,
237            protocol: auto::HttpConnectionBuilder::default(),
238            builder: self.builder,
239            user_agent: self.user_agent,
240            redirect: self.redirect,
241            timeout: self.timeout,
242            #[cfg(feature = "tls")]
243            tls: self.tls,
244            pool: self.pool,
245            body: self.body,
246        }
247    }
248
249    /// Use the provided HTTP connection configuration.
250    pub fn with_protocol<P2>(self, protocol: P2) -> Builder<T, P2, RP, S, BIn, BOut> {
251        Builder {
252            transport: self.transport,
253            protocol,
254            builder: self.builder,
255            user_agent: self.user_agent,
256            redirect: self.redirect,
257            timeout: self.timeout,
258            #[cfg(feature = "tls")]
259            tls: self.tls,
260            pool: self.pool,
261            body: self.body,
262        }
263    }
264
265    /// HTTP connection configuration.
266    pub fn protocol(&mut self) -> &mut P {
267        &mut self.protocol
268    }
269}
270
271impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
272    /// Set the User-Agent header.
273    pub fn with_user_agent(mut self, user_agent: String) -> Self {
274        self.user_agent = Some(user_agent);
275        self
276    }
277
278    /// Get the user agent currently configured
279    pub fn user_agent(&self) -> Option<&str> {
280        self.user_agent.as_deref()
281    }
282}
283
284impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
285    /// Set the redirect policy. See [`policy`] for more information.
286    pub fn with_redirect_policy<RP2>(self, policy: RP2) -> Builder<T, P, RP2, S, BIn, BOut> {
287        Builder {
288            transport: self.transport,
289            protocol: self.protocol,
290            builder: self.builder,
291            user_agent: self.user_agent,
292            redirect: Some(policy),
293            timeout: self.timeout,
294            #[cfg(feature = "tls")]
295            tls: self.tls,
296            pool: self.pool,
297            body: self.body,
298        }
299    }
300
301    /// Disable redirects.
302    pub fn without_redirects(self) -> Builder<T, P, policy::Standard, S, BIn, BOut> {
303        Builder {
304            transport: self.transport,
305            protocol: self.protocol,
306            user_agent: self.user_agent,
307            builder: self.builder,
308            redirect: None,
309            timeout: self.timeout,
310            #[cfg(feature = "tls")]
311            tls: self.tls,
312            pool: self.pool,
313            body: self.body,
314        }
315    }
316
317    /// Set the standard redirect policy. See [`policy::Standard`] for more information.
318    pub fn with_standard_redirect_policy(self) -> Builder<T, P, policy::Standard, S, BIn, BOut> {
319        Builder {
320            transport: self.transport,
321            protocol: self.protocol,
322            builder: self.builder,
323            user_agent: self.user_agent,
324            redirect: Some(policy::Standard::default()),
325            timeout: self.timeout,
326            #[cfg(feature = "tls")]
327            tls: self.tls,
328            pool: self.pool,
329            body: self.body,
330        }
331    }
332
333    /// Configured redirect policy.
334    pub fn redirect_policy(&mut self) -> Option<&mut RP> {
335        self.redirect.as_mut()
336    }
337}
338
339impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
340    /// Set the timeout for requests.
341    pub fn with_timeout(mut self, timeout: Duration) -> Self {
342        self.timeout = Some(timeout);
343        self
344    }
345
346    /// Get the timeout for requests.
347    pub fn timeout(&self) -> Option<Duration> {
348        self.timeout
349    }
350
351    /// Disable request timeouts.
352    pub fn without_timeout(mut self) -> Self {
353        self.timeout = None;
354        self
355    }
356
357    /// Set the timeout for requests with an Option.
358    pub fn with_optional_timeout(mut self, timeout: Option<Duration>) -> Self {
359        self.timeout = timeout;
360        self
361    }
362}
363
364impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
365    /// Add a layer to the service under construction
366    pub fn with_body<B2In, B2Out>(self) -> Builder<T, P, RP, S, B2In, B2Out> {
367        Builder {
368            transport: self.transport,
369            protocol: self.protocol,
370            builder: self.builder,
371            user_agent: self.user_agent,
372            redirect: self.redirect,
373            timeout: self.timeout,
374            #[cfg(feature = "tls")]
375            tls: self.tls,
376            pool: self.pool,
377            body: std::marker::PhantomData,
378        }
379    }
380}
381
382impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut> {
383    /// Add a layer to the service under construction
384    pub fn layer<L>(self, layer: L) -> Builder<T, P, RP, Stack<L, S>, BIn, BOut> {
385        Builder {
386            transport: self.transport,
387            protocol: self.protocol,
388            builder: self.builder.layer(layer),
389            user_agent: self.user_agent,
390            redirect: self.redirect,
391            timeout: self.timeout,
392            #[cfg(feature = "tls")]
393            tls: self.tls,
394            pool: self.pool,
395            body: self.body,
396        }
397    }
398}
399
400impl<T, P, RP, S, BIn, BOut> Builder<T, P, RP, S, BIn, BOut>
401where
402    T: BuildTransport,
403    <T as BuildTransport>::Target: Transport + Clone + Send + Sync + 'static,
404    <<T as BuildTransport>::Target as Transport>::IO:
405        PoolableStream + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
406    <<<T as BuildTransport>::Target as Transport>::IO as HasConnectionInfo>::Addr:
407        Unpin + Clone + Send,
408    P: BuildProtocol<
409        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
410        BIn,
411    >,
412    <P as BuildProtocol<
413        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
414        BIn,
415    >>::Target: Protocol<
416            super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
417            BIn,
418            Error = ConnectionError,
419        > + Clone
420        + Send
421        + Sync
422        + 'static,
423    <<P as BuildProtocol<
424        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
425        BIn,
426    >>::Target as Protocol<
427        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
428        BIn,
429    >>::Connection: Connection<BIn, ResBody = hyper::body::Incoming> + PoolableConnection<BIn>,
430
431    RP: policy::Policy<BIn, super::Error> + Clone + Send + Sync + 'static,
432    S: tower::Layer<SharedService<http::Request<BIn>, http::Response<BOut>, super::Error>>,
433    S::Service: tower::Service<http::Request<BIn>, Response = http::Response<BOut>, Error = super::Error>
434        + Clone
435        + Send
436        + Sync
437        + 'static,
438    <S::Service as tower::Service<http::Request<BIn>>>::Future: Send + 'static,
439    BIn: Default + Body + Unpin + Send + 'static,
440    <BIn as Body>::Data: Send,
441    <BIn as Body>::Error: Into<BoxError>,
442    BOut: From<hyper::body::Incoming> + Body + Unpin + Send + 'static,
443{
444    /// Build a client service with the configured layers
445    pub fn build_service(
446        self,
447    ) -> SharedService<http::Request<BIn>, http::Response<BOut>, super::Error> {
448        let user_agent = if let Some(ua) = self.user_agent {
449            HeaderValue::from_str(&ua).expect("user-agent should be a valid http header")
450        } else {
451            HeaderValue::from_static(concat!(
452                env!("CARGO_PKG_NAME"),
453                "/",
454                env!("CARGO_PKG_VERSION")
455            ))
456        };
457
458        #[cfg(feature = "tls")]
459        let transport = self
460            .transport
461            .build()
462            .with_optional_tls(self.tls.map(Arc::new));
463        #[cfg(not(feature = "tls"))]
464        let transport = self.transport.build().without_tls();
465
466        let service = self
467            .builder
468            .layer(SharedService::layer())
469            .optional(
470                self.timeout
471                    .map(|d| TimeoutLayer::new(|| super::Error::RequestTimeout, d)),
472            )
473            .optional(self.redirect.map(FollowRedirectLayer::with_policy))
474            .layer(SetRequestHeaderLayer::if_not_present(
475                http::header::USER_AGENT,
476                user_agent,
477            ))
478            .layer(IncomingResponseLayer::new())
479            .layer(
480                ConnectionPoolLayer::<_, _, _, UriKey>::new(transport, self.protocol.build())
481                    .with_optional_pool(self.pool.clone()),
482            )
483            .layer(SetHostHeaderLayer::new())
484            .layer(Http2ChecksLayer::new())
485            .layer(Http1ChecksLayer::new())
486            .service(RequestExecutor::new());
487
488        SharedService::new(service)
489    }
490}
491
492impl<T, P, RP, S> Builder<T, P, RP, S, crate::Body, crate::Body>
493where
494    T: BuildTransport,
495    <T as BuildTransport>::Target: Transport + Clone + Send + Sync + 'static,
496    <<T as BuildTransport>::Target as Transport>::IO:
497        PoolableStream + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
498    <<<T as BuildTransport>::Target as Transport>::IO as HasConnectionInfo>::Addr:
499        Unpin + Clone + Send,
500    P: BuildProtocol<
501        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
502        crate::Body,
503    >,
504    <P as BuildProtocol<
505        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
506        crate::Body,
507    >>::Target: Protocol<
508            super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
509            crate::Body,
510            Error = ConnectionError,
511        > + Clone
512        + Send
513        + Sync
514        + 'static,
515    <<P as BuildProtocol<
516        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
517        crate::Body,
518    >>::Target as Protocol<
519        super::conn::stream::Stream<<<T as BuildTransport>::Target as Transport>::IO>,
520        crate::Body,
521    >>::Connection:
522        Connection<crate::Body, ResBody = hyper::body::Incoming> + PoolableConnection<crate::Body>,
523
524    RP: policy::Policy<crate::Body, super::Error> + Clone + Send + Sync + 'static,
525    S: tower::Layer<
526        SharedService<http::Request<crate::Body>, http::Response<crate::Body>, super::Error>,
527    >,
528    S::Service: tower::Service<
529            http::Request<crate::Body>,
530            Response = http::Response<crate::Body>,
531            Error = super::Error,
532        > + Clone
533        + Send
534        + Sync
535        + 'static,
536    <S::Service as tower::Service<http::Request<crate::Body>>>::Future: Send + 'static,
537{
538    /// Build the client.
539    pub fn build(self) -> Client {
540        Client::new_from_service(self.build_service())
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::Builder;

    /// The fully-defaulted builder must satisfy every trait bound of `build`.
    #[test]
    fn build_default_compiles() {
        // The default builder includes a TLS config when the feature is on, so
        // install the test fixture's TLS defaults before building it.
        #[cfg(feature = "tls")]
        crate::fixtures::tls_install_default();

        let client = Builder::default().build();
        drop(client);
    }
}