requiem_http/client/
connector.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::time::Duration;
4
5use requiem_codec::{AsyncRead, AsyncWrite};
6use requiem_connect::{
7    default_connector, Connect as TcpConnect, Connection as TcpConnection,
8};
9use requiem_rt::net::TcpStream;
10use requiem_service::{apply_fn, Service};
11use requiem_utils::timeout::{TimeoutError, TimeoutService};
12use http::Uri;
13
14use super::connection::Connection;
15use super::error::ConnectError;
16use super::pool::{ConnectionPool, Protocol};
17use super::Connect;
18
19#[cfg(feature = "openssl")]
20use requiem_connect::ssl::openssl::SslConnector as OpensslConnector;
21
22#[cfg(feature = "rustls")]
23use requiem_connect::ssl::rustls::ClientConfig;
24#[cfg(feature = "rustls")]
25use std::sync::Arc;
26
/// TLS backend configuration carried by the `Connector` builder.
///
/// The enum only exists when at least one TLS feature is enabled, and each
/// variant is compiled in only together with its own feature.
#[cfg(any(feature = "openssl", feature = "rustls"))]
enum SslConnector {
    /// OpenSSL connector (`openssl` feature).
    #[cfg(feature = "openssl")]
    Openssl(OpensslConnector),
    /// Shared rustls client configuration (`rustls` feature).
    #[cfg(feature = "rustls")]
    Rustls(Arc<ClientConfig>),
}
/// Placeholder used when neither the `openssl` nor the `rustls` feature is
/// enabled, so that `Connector` keeps the same set of fields in every build.
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
type SslConnector = ();
36
37/// Manages http client network connectivity
38/// The `Connector` type uses a builder-like combinator pattern for service
39/// construction that finishes by calling the `.finish()` method.
40///
41/// ```rust,ignore
42/// use std::time::Duration;
43/// use requiem_http::client::Connector;
44///
45/// let connector = Connector::new()
46///      .timeout(Duration::from_secs(5))
47///      .finish();
48/// ```
49pub struct Connector<T, U> {
50    connector: T,
51    timeout: Duration,
52    conn_lifetime: Duration,
53    conn_keep_alive: Duration,
54    disconnect_timeout: Duration,
55    limit: usize,
56    #[allow(dead_code)]
57    ssl: SslConnector,
58    _t: PhantomData<U>,
59}
60
61trait Io: AsyncRead + AsyncWrite + Unpin {}
62impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
63
64impl Connector<(), ()> {
65    #[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
66    pub fn new() -> Connector<
67        impl Service<
68                Request = TcpConnect<Uri>,
69                Response = TcpConnection<Uri, TcpStream>,
70                Error = requiem_connect::ConnectError,
71            > + Clone,
72        TcpStream,
73    > {
74        let ssl = {
75            #[cfg(feature = "openssl")]
76            {
77                use requiem_connect::ssl::openssl::SslMethod;
78
79                let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
80                let _ = ssl
81                    .set_alpn_protos(b"\x02h2\x08http/1.1")
82                    .map_err(|e| error!("Can not set alpn protocol: {:?}", e));
83                SslConnector::Openssl(ssl.build())
84            }
85            #[cfg(all(not(feature = "openssl"), feature = "rustls"))]
86            {
87                let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
88                let mut config = ClientConfig::new();
89                config.set_protocols(&protos);
90                config
91                    .root_store
92                    .add_server_trust_anchors(&requiem_tls::rustls::TLS_SERVER_ROOTS);
93                SslConnector::Rustls(Arc::new(config))
94            }
95            #[cfg(not(any(feature = "openssl", feature = "rustls")))]
96            {}
97        };
98
99        Connector {
100            ssl,
101            connector: default_connector(),
102            timeout: Duration::from_secs(1),
103            conn_lifetime: Duration::from_secs(75),
104            conn_keep_alive: Duration::from_secs(15),
105            disconnect_timeout: Duration::from_millis(3000),
106            limit: 100,
107            _t: PhantomData,
108        }
109    }
110}
111
112impl<T, U> Connector<T, U> {
113    /// Use custom connector.
114    pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1>
115    where
116        U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
117        T1: Service<
118                Request = TcpConnect<Uri>,
119                Response = TcpConnection<Uri, U1>,
120                Error = requiem_connect::ConnectError,
121            > + Clone,
122    {
123        Connector {
124            connector,
125            timeout: self.timeout,
126            conn_lifetime: self.conn_lifetime,
127            conn_keep_alive: self.conn_keep_alive,
128            disconnect_timeout: self.disconnect_timeout,
129            limit: self.limit,
130            ssl: self.ssl,
131            _t: PhantomData,
132        }
133    }
134}
135
136impl<T, U> Connector<T, U>
137where
138    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
139    T: Service<
140            Request = TcpConnect<Uri>,
141            Response = TcpConnection<Uri, U>,
142            Error = requiem_connect::ConnectError,
143        > + Clone
144        + 'static,
145{
146    /// Connection timeout, i.e. max time to connect to remote host including dns name resolution.
147    /// Set to 1 second by default.
148    pub fn timeout(mut self, timeout: Duration) -> Self {
149        self.timeout = timeout;
150        self
151    }
152
153    #[cfg(feature = "openssl")]
154    /// Use custom `SslConnector` instance.
155    pub fn ssl(mut self, connector: OpensslConnector) -> Self {
156        self.ssl = SslConnector::Openssl(connector);
157        self
158    }
159
160    #[cfg(feature = "rustls")]
161    pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self {
162        self.ssl = SslConnector::Rustls(connector);
163        self
164    }
165
166    /// Set total number of simultaneous connections per type of scheme.
167    ///
168    /// If limit is 0, the connector has no limit.
169    /// The default limit size is 100.
170    pub fn limit(mut self, limit: usize) -> Self {
171        self.limit = limit;
172        self
173    }
174
175    /// Set keep-alive period for opened connection.
176    ///
177    /// Keep-alive period is the period between connection usage. If
178    /// the delay between repeated usages of the same connection
179    /// exceeds this period, the connection is closed.
180    /// Default keep-alive period is 15 seconds.
181    pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
182        self.conn_keep_alive = dur;
183        self
184    }
185
186    /// Set max lifetime period for connection.
187    ///
188    /// Connection lifetime is max lifetime of any opened connection
189    /// until it is closed regardless of keep-alive period.
190    /// Default lifetime period is 75 seconds.
191    pub fn conn_lifetime(mut self, dur: Duration) -> Self {
192        self.conn_lifetime = dur;
193        self
194    }
195
196    /// Set server connection disconnect timeout in milliseconds.
197    ///
198    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
199    /// within this time, the socket get dropped. This timeout affects only secure connections.
200    ///
201    /// To disable timeout set value to 0.
202    ///
203    /// By default disconnect timeout is set to 3000 milliseconds.
204    pub fn disconnect_timeout(mut self, dur: Duration) -> Self {
205        self.disconnect_timeout = dur;
206        self
207    }
208
209    /// Finish configuration process and create connector service.
210    /// The Connector builder always concludes by calling `finish()` last in
211    /// its combinator chain.
212    pub fn finish(
213        self,
214    ) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
215           + Clone {
216        #[cfg(not(any(feature = "openssl", feature = "rustls")))]
217        {
218            let connector = TimeoutService::new(
219                self.timeout,
220                apply_fn(self.connector, |msg: Connect, srv| {
221                    srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
222                })
223                .map_err(ConnectError::from)
224                .map(|stream| (stream.into_parts().0, Protocol::Http1)),
225            )
226            .map_err(|e| match e {
227                TimeoutError::Service(e) => e,
228                TimeoutError::Timeout => ConnectError::Timeout,
229            });
230
231            connect_impl::InnerConnector {
232                tcp_pool: ConnectionPool::new(
233                    connector,
234                    self.conn_lifetime,
235                    self.conn_keep_alive,
236                    None,
237                    self.limit,
238                ),
239            }
240        }
241        #[cfg(any(feature = "openssl", feature = "rustls"))]
242        {
243            const H2: &[u8] = b"h2";
244            #[cfg(feature = "openssl")]
245            use requiem_connect::ssl::openssl::OpensslConnector;
246            #[cfg(feature = "rustls")]
247            use requiem_connect::ssl::rustls::{RustlsConnector, Session};
248            use requiem_service::{boxed::service, pipeline};
249
250            let ssl_service = TimeoutService::new(
251                self.timeout,
252                pipeline(
253                    apply_fn(self.connector.clone(), |msg: Connect, srv| {
254                        srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
255                    })
256                    .map_err(ConnectError::from),
257                )
258                .and_then(match self.ssl {
259                    #[cfg(feature = "openssl")]
260                    SslConnector::Openssl(ssl) => service(
261                        OpensslConnector::service(ssl)
262                            .map(|stream| {
263                                let sock = stream.into_parts().0;
264                                let h2 = sock
265                                    .ssl()
266                                    .selected_alpn_protocol()
267                                    .map(|protos| protos.windows(2).any(|w| w == H2))
268                                    .unwrap_or(false);
269                                if h2 {
270                                    (Box::new(sock) as Box<dyn Io>, Protocol::Http2)
271                                } else {
272                                    (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
273                                }
274                            })
275                            .map_err(ConnectError::from),
276                    ),
277                    #[cfg(feature = "rustls")]
278                    SslConnector::Rustls(ssl) => service(
279                        RustlsConnector::service(ssl)
280                            .map_err(ConnectError::from)
281                            .map(|stream| {
282                                let sock = stream.into_parts().0;
283                                let h2 = sock
284                                    .get_ref()
285                                    .1
286                                    .get_alpn_protocol()
287                                    .map(|protos| protos.windows(2).any(|w| w == H2))
288                                    .unwrap_or(false);
289                                if h2 {
290                                    (Box::new(sock) as Box<dyn Io>, Protocol::Http2)
291                                } else {
292                                    (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
293                                }
294                            }),
295                    ),
296                }),
297            )
298            .map_err(|e| match e {
299                TimeoutError::Service(e) => e,
300                TimeoutError::Timeout => ConnectError::Timeout,
301            });
302
303            let tcp_service = TimeoutService::new(
304                self.timeout,
305                apply_fn(self.connector, |msg: Connect, srv| {
306                    srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
307                })
308                .map_err(ConnectError::from)
309                .map(|stream| (stream.into_parts().0, Protocol::Http1)),
310            )
311            .map_err(|e| match e {
312                TimeoutError::Service(e) => e,
313                TimeoutError::Timeout => ConnectError::Timeout,
314            });
315
316            connect_impl::InnerConnector {
317                tcp_pool: ConnectionPool::new(
318                    tcp_service,
319                    self.conn_lifetime,
320                    self.conn_keep_alive,
321                    None,
322                    self.limit,
323                ),
324                ssl_pool: ConnectionPool::new(
325                    ssl_service,
326                    self.conn_lifetime,
327                    self.conn_keep_alive,
328                    Some(self.disconnect_timeout),
329                    self.limit,
330                ),
331            }
332        }
333    }
334}
335
336#[cfg(not(any(feature = "openssl", feature = "rustls")))]
337mod connect_impl {
338    use std::task::{Context, Poll};
339
340    use futures_util::future::{err, Either, Ready};
341
342    use super::*;
343    use crate::client::connection::IoConnection;
344
345    pub(crate) struct InnerConnector<T, Io>
346    where
347        Io: AsyncRead + AsyncWrite + Unpin + 'static,
348        T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
349            + 'static,
350    {
351        pub(crate) tcp_pool: ConnectionPool<T, Io>,
352    }
353
354    impl<T, Io> Clone for InnerConnector<T, Io>
355    where
356        Io: AsyncRead + AsyncWrite + Unpin + 'static,
357        T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
358            + 'static,
359    {
360        fn clone(&self) -> Self {
361            InnerConnector {
362                tcp_pool: self.tcp_pool.clone(),
363            }
364        }
365    }
366
367    impl<T, Io> Service for InnerConnector<T, Io>
368    where
369        Io: AsyncRead + AsyncWrite + Unpin + 'static,
370        T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
371            + 'static,
372    {
373        type Request = Connect;
374        type Response = IoConnection<Io>;
375        type Error = ConnectError;
376        type Future = Either<
377            <ConnectionPool<T, Io> as Service>::Future,
378            Ready<Result<IoConnection<Io>, ConnectError>>,
379        >;
380
381        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
382            self.tcp_pool.poll_ready(cx)
383        }
384
385        fn call(&mut self, req: Connect) -> Self::Future {
386            match req.uri.scheme_str() {
387                Some("https") | Some("wss") => {
388                    Either::Right(err(ConnectError::SslIsNotSupported))
389                }
390                _ => Either::Left(self.tcp_pool.call(req)),
391            }
392        }
393    }
394}
395
/// Connector service used when at least one TLS backend is compiled in.
#[cfg(any(feature = "openssl", feature = "rustls"))]
mod connect_impl {
    use std::future::Future;
    use std::marker::PhantomData;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_core::ready;
    use futures_util::future::Either;

    use super::*;
    use crate::client::connection::EitherConnection;

    /// Connector that keeps one pool for plain connections and one for TLS
    /// connections and routes each request by URI scheme.
    pub(crate) struct InnerConnector<T1, T2, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>,
        T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>,
    {
        pub(crate) tcp_pool: ConnectionPool<T1, Io1>,
        pub(crate) ssl_pool: ConnectionPool<T2, Io2>,
    }

    impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
        T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
    {
        /// Clones the connector by cloning both pool handles.
        fn clone(&self) -> Self {
            InnerConnector {
                tcp_pool: self.tcp_pool.clone(),
                ssl_pool: self.ssl_pool.clone(),
            }
        }
    }

    impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
        T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
    {
        type Request = Connect;
        type Response = EitherConnection<Io1, Io2>;
        type Error = ConnectError;
        type Future = Either<
            InnerConnectorResponseA<T1, Io1, Io2>,
            InnerConnectorResponseB<T2, Io1, Io2>,
        >;

        /// Readiness is delegated to the plain-TCP pool.
        // NOTE(review): `ssl_pool` readiness is never polled here even though
        // `call` may dispatch to it — confirm that this is intentional.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.tcp_pool.poll_ready(cx)
        }

        /// Routes `https`/`wss` requests to the TLS pool and everything else
        /// to the plain-TCP pool.
        fn call(&mut self, req: Connect) -> Self::Future {
            match req.uri.scheme_str() {
                Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB {
                    fut: self.ssl_pool.call(req),
                    _t: PhantomData,
                }),
                _ => Either::Left(InnerConnectorResponseA {
                    fut: self.tcp_pool.call(req),
                    _t: PhantomData,
                }),
            }
        }
    }

    /// Future of a plain-TCP pool request; wraps the resulting connection in
    /// `EitherConnection::A`.
    #[pin_project::pin_project]
    pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
    {
        #[pin]
        fut: <ConnectionPool<T, Io1> as Service>::Future,
        _t: PhantomData<Io2>,
    }

    impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
    where
        T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
    {
        type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Use the generated pin projection rather than `get_mut()`, so
            // the pooled future is not required to be `Unpin`.
            let res = ready!(self.project().fut.poll(cx));
            Poll::Ready(res.map(EitherConnection::A))
        }
    }

    /// Future of a TLS pool request; wraps the resulting connection in
    /// `EitherConnection::B`.
    #[pin_project::pin_project]
    pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
    where
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
    {
        #[pin]
        fut: <ConnectionPool<T, Io2> as Service>::Future,
        _t: PhantomData<Io1>,
    }

    impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>
    where
        T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
    {
        type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Use the generated pin projection rather than `get_mut()`, so
            // the pooled future is not required to be `Unpin`.
            let res = ready!(self.project().fut.poll(cx));
            Poll::Ready(res.map(EitherConnection::B))
        }
    }
}