awc/client/
connector.rs

1use std::{
2    fmt,
3    future::Future,
4    net::IpAddr,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use actix_http::Protocol;
12use actix_rt::{
13    net::{ActixStream, TcpStream},
14    time::{sleep, Sleep},
15};
16use actix_service::Service;
17use actix_tls::connect::{
18    ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection,
19    Connector as TcpConnector, Resolver,
20};
21use futures_core::{future::LocalBoxFuture, ready};
22use http::Uri;
23use pin_project_lite::pin_project;
24
25use super::{
26    config::ConnectorConfig,
27    connection::{Connection, ConnectionIo},
28    error::ConnectError,
29    pool::ConnectionPool,
30    Connect,
31};
32
/// TLS backend selected for a [`Connector`].
///
/// Which variants exist depends on the enabled cargo features; `None` is always present and
/// means that no TLS configuration is attached to the connector.
enum OurTlsConnector {
    /// No TLS configuration.
    #[allow(dead_code)] // only dead when no TLS feature is enabled
    None,

    // #[cfg(feature = "openssl")]
    // Openssl(actix_tls::connect::openssl::reexports::SslConnector),

    // /// Provided because building the OpenSSL context on newer versions can be very slow.
    // /// This prevents unnecessary calls to `.build()` while constructing the client connector.
    // #[cfg(feature = "openssl")]
    // #[allow(dead_code)] // false positive; used in build_ssl
    // OpensslBuilder(actix_tls::connect::openssl::reexports::SslConnectorBuilder),

    /// Shared rustls client configuration.
    #[cfg(feature = "rustls")]
    Rustls(std::sync::Arc<actix_tls::connect::rustls::reexports::ClientConfig>),
}
49
50/// Manages HTTP client network connectivity.
51///
52/// The `Connector` type uses a builder-like combinator pattern for service
53/// construction that finishes by calling the `.finish()` method.
54///
55/// ```ignore
56/// use std::time::Duration;
57/// use actix_http::client::Connector;
58///
59/// let connector = Connector::new()
60///      .timeout(Duration::from_secs(5))
61///      .finish();
62/// ```
63pub struct Connector<T> {
64    connector: T,
65    config: ConnectorConfig,
66
67    #[allow(dead_code)] // only dead when no TLS feature is enabled
68    tls: OurTlsConnector,
69}
70
71impl Connector<()> {
72    #[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
73    pub fn new() -> Connector<
74        impl Service<
75                ConnectInfo<Uri>,
76                Response = TcpConnection<Uri, TcpStream>,
77                Error = actix_tls::connect::ConnectError,
78            > + Clone,
79    > {
80        Connector {
81            connector: TcpConnector::new(resolver::resolver()).service(),
82            config: ConnectorConfig::default(),
83            tls: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
84        }
85    }
86
87    /// Provides an empty TLS connector when no TLS feature is enabled.
88    #[cfg(not(any(feature = "openssl", feature = "rustls")))]
89    fn build_ssl(_: Vec<Vec<u8>>) -> OurTlsConnector {
90        OurTlsConnector::None
91    }
92
93    /// Build TLS connector with rustls, based on supplied ALPN protocols
94    ///
95    /// Note that if both `openssl` and `rustls` features are enabled, rustls will be used.
96    #[cfg(feature = "rustls")]
97    fn build_ssl(protocols: Vec<Vec<u8>>) -> OurTlsConnector {
98        use actix_tls::connect::rustls::{reexports::ClientConfig, webpki_roots_cert_store};
99
100        let mut config = ClientConfig::builder()
101            .with_safe_defaults()
102            .with_root_certificates(webpki_roots_cert_store())
103            .with_no_client_auth();
104
105        config.alpn_protocols = protocols;
106
107        OurTlsConnector::Rustls(std::sync::Arc::new(config))
108    }
109
110    // /// Build TLS connector with openssl, based on supplied ALPN protocols
111    // #[cfg(all(feature = "openssl", not(feature = "rustls")))]
112    // fn build_ssl(protocols: Vec<Vec<u8>>) -> OurTlsConnector {
113    //     use actix_tls::connect::openssl::reexports::{SslConnector, SslMethod};
114    //     use bytes::{BufMut, BytesMut};
115
116    //     let mut alpn = BytesMut::with_capacity(20);
117    //     for proto in &protocols {
118    //         alpn.put_u8(proto.len() as u8);
119    //         alpn.put(proto.as_slice());
120    //     }
121
122    //     let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
123    //     if let Err(err) = ssl.set_alpn_protos(&alpn) {
124    //         log::error!("Can not set ALPN protocol: {:?}", err);
125    //     }
126
127    //     OurTlsConnector::OpensslBuilder(ssl)
128    // }
129}
130
131impl<S> Connector<S> {
132    /// Use custom connector.
133    pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1>
134    where
135        Io1: ActixStream + fmt::Debug + 'static,
136        S1: Service<
137                ConnectInfo<Uri>,
138                Response = TcpConnection<Uri, Io1>,
139                Error = TcpConnectError,
140            > + Clone,
141    {
142        Connector {
143            connector,
144            config: self.config,
145            tls: self.tls,
146        }
147    }
148}
149
150impl<S, IO> Connector<S>
151where
152    // Note:
153    // Input Io type is bound to ActixStream trait but internally in client module they
154    // are bound to ConnectionIo trait alias. And latter is the trait exposed to public
155    // in the form of Box<dyn ConnectionIo> type.
156    //
157    // This remap is to hide ActixStream's trait methods. They are not meant to be called
158    // from user code.
159    IO: ActixStream + fmt::Debug + 'static,
160    S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, IO>, Error = TcpConnectError>
161        + Clone
162        + 'static,
163{
164    /// Tcp connection timeout, i.e. max time to connect to remote host including dns name
165    /// resolution. Set to 5 second by default.
166    pub fn timeout(mut self, timeout: Duration) -> Self {
167        self.config.timeout = timeout;
168        self
169    }
170
171    /// Tls handshake timeout, i.e. max time to do tls handshake with remote host after tcp
172    /// connection established. Set to 5 second by default.
173    pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
174        self.config.handshake_timeout = timeout;
175        self
176    }
177
178    // /// Use custom OpenSSL `SslConnector` instance.
179    // #[cfg(feature = "openssl")]
180    // pub fn openssl(
181    //     mut self,
182    //     connector: actix_tls::connect::openssl::reexports::SslConnector,
183    // ) -> Self {
184    //     self.tls = OurTlsConnector::Openssl(connector);
185    //     self
186    // }
187
188    // /// See docs for [`Connector::openssl`].
189    // #[doc(hidden)]
190    // #[cfg(feature = "openssl")]
191    // #[deprecated(since = "3.0.0", note = "Renamed to `Connector::openssl`.")]
192    // pub fn ssl(
193    //     mut self,
194    //     connector: actix_tls::connect::openssl::reexports::SslConnector,
195    // ) -> Self {
196    //     self.tls = OurTlsConnector::Openssl(connector);
197    //     self
198    // }
199
200    /// Use custom Rustls `ClientConfig` instance.
201    #[cfg(feature = "rustls")]
202    pub fn rustls(
203        mut self,
204        connector: std::sync::Arc<actix_tls::connect::rustls::reexports::ClientConfig>,
205    ) -> Self {
206        self.tls = OurTlsConnector::Rustls(connector);
207        self
208    }
209
210    /// Sets maximum supported HTTP major version.
211    ///
212    /// Supported versions are HTTP/1.1 and HTTP/2.
213    pub fn max_http_version(mut self, val: http::Version) -> Self {
214        let versions = match val {
215            http::Version::HTTP_11 => vec![b"http/1.1".to_vec()],
216            http::Version::HTTP_2 => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
217            _ => {
218                unimplemented!("actix-http client only supports versions http/1.1 & http/2")
219            }
220        };
221        self.tls = Connector::build_ssl(versions);
222        self
223    }
224
225    /// Sets the initial window size (in octets) for HTTP/2 stream-level flow control for
226    /// received data.
227    ///
228    /// The default value is 65,535 and is good for APIs, but not for big objects.
229    pub fn initial_window_size(mut self, size: u32) -> Self {
230        self.config.stream_window_size = size;
231        self
232    }
233
234    /// Sets the initial window size (in octets) for HTTP/2 connection-level flow control for
235    /// received data.
236    ///
237    /// The default value is 65,535 and is good for APIs, but not for big objects.
238    pub fn initial_connection_window_size(mut self, size: u32) -> Self {
239        self.config.conn_window_size = size;
240        self
241    }
242
243    /// Set total number of simultaneous connections per type of scheme.
244    ///
245    /// If limit is 0, the connector has no limit.
246    ///
247    /// The default limit size is 100.
248    pub fn limit(mut self, limit: usize) -> Self {
249        if limit == 0 {
250            self.config.limit = u32::MAX as usize;
251        } else {
252            self.config.limit = limit;
253        }
254
255        self
256    }
257
258    /// Set keep-alive period for opened connection.
259    ///
260    /// Keep-alive period is the period between connection usage. If
261    /// the delay between repeated usages of the same connection
262    /// exceeds this period, the connection is closed.
263    /// Default keep-alive period is 15 seconds.
264    pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
265        self.config.conn_keep_alive = dur;
266        self
267    }
268
269    /// Set max lifetime period for connection.
270    ///
271    /// Connection lifetime is max lifetime of any opened connection
272    /// until it is closed regardless of keep-alive period.
273    /// Default lifetime period is 75 seconds.
274    pub fn conn_lifetime(mut self, dur: Duration) -> Self {
275        self.config.conn_lifetime = dur;
276        self
277    }
278
279    /// Set server connection disconnect timeout in milliseconds.
280    ///
281    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
282    /// within this time, the socket get dropped. This timeout affects only secure connections.
283    ///
284    /// To disable timeout set value to 0.
285    ///
286    /// By default disconnect timeout is set to 3000 milliseconds.
287    pub fn disconnect_timeout(mut self, dur: Duration) -> Self {
288        self.config.disconnect_timeout = Some(dur);
289        self
290    }
291
292    /// Set local IP Address the connector would use for establishing connection.
293    pub fn local_address(mut self, addr: IpAddr) -> Self {
294        self.config.local_address = Some(addr);
295        self
296    }
297
298    /// Finish configuration process and create connector service.
299    ///
300    /// The `Connector` builder always concludes by calling `finish()` last in its combinator chain.
301    pub fn finish(self) -> ConnectorService<S, IO> {
302        let local_address = self.config.local_address;
303        let timeout = self.config.timeout;
304
305        let tcp_service_inner =
306            TcpConnectorInnerService::new(self.connector, timeout, local_address);
307
308        #[allow(clippy::redundant_clone)]
309        let tcp_service = TcpConnectorService {
310            service: tcp_service_inner.clone(),
311        };
312
313        let tls = match self.tls {
314            // #[cfg(feature = "openssl")]
315            // OurTlsConnector::OpensslBuilder(builder) => {
316            //     OurTlsConnector::Openssl(builder.build())
317            // }
318            tls => tls,
319        };
320
321        let tls_service = match tls {
322            OurTlsConnector::None => {
323                #[cfg(not(feature = "dangerous-h2c"))]
324                {
325                    None
326                }
327
328                #[cfg(feature = "dangerous-h2c")]
329                {
330                    use std::io;
331
332                    use actix_tls::connect::Connection;
333                    use actix_utils::future::{ready, Ready};
334
335                    impl IntoConnectionIo for TcpConnection<Uri, Box<dyn ConnectionIo>> {
336                        fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
337                            let io = self.into_parts().0;
338                            (io, Protocol::Http2)
339                        }
340                    }
341
342                    /// With the `dangerous-h2c` feature enabled, this connector uses a no-op TLS
343                    /// connection service that passes through plain TCP as a TLS connection.
344                    ///
345                    /// The protocol version of this fake TLS connection is set to be HTTP/2.
346                    #[derive(Clone)]
347                    struct NoOpTlsConnectorService;
348
349                    impl<R, IO> Service<Connection<R, IO>> for NoOpTlsConnectorService
350                    where
351                        IO: ActixStream + 'static,
352                    {
353                        type Response = Connection<R, Box<dyn ConnectionIo>>;
354                        type Error = io::Error;
355                        type Future = Ready<Result<Self::Response, Self::Error>>;
356
357                        actix_service::always_ready!();
358
359                        fn call(&self, connection: Connection<R, IO>) -> Self::Future {
360                            let (io, connection) = connection.replace_io(());
361                            let (_, connection) = connection.replace_io(Box::new(io) as _);
362
363                            ready(Ok(connection))
364                        }
365                    }
366
367                    let handshake_timeout = self.config.handshake_timeout;
368
369                    let tls_service = TlsConnectorService {
370                        tcp_service: tcp_service_inner,
371                        tls_service: NoOpTlsConnectorService,
372                        timeout: handshake_timeout,
373                    };
374
375                    Some(actix_service::boxed::rc_service(tls_service))
376                }
377            }
378
379            // #[cfg(feature = "openssl")]
380            // OurTlsConnector::Openssl(tls) => {
381            //     const H2: &[u8] = b"h2";
382
383            //     use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector};
384
385            //     impl<IO: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncSslStream<IO>> {
386            //         fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
387            //             let sock = self.into_parts().0;
388            //             let h2 = sock
389            //                 .ssl()
390            //                 .selected_alpn_protocol()
391            //                 .map_or(false, |protos| protos.windows(2).any(|w| w == H2));
392            //             if h2 {
393            //                 (Box::new(sock), Protocol::Http2)
394            //             } else {
395            //                 (Box::new(sock), Protocol::Http1)
396            //             }
397            //         }
398            //     }
399
400            //     let handshake_timeout = self.config.handshake_timeout;
401
402            //     let tls_service = TlsConnectorService {
403            //         tcp_service: tcp_service_inner,
404            //         tls_service: TlsConnector::service(tls),
405            //         timeout: handshake_timeout,
406            //     };
407
408            //     Some(actix_service::boxed::rc_service(tls_service))
409            // }
410
411            // #[cfg(feature = "openssl")]
412            // OurTlsConnector::OpensslBuilder(_) => {
413            //     unreachable!("OpenSSL builder is built before this match.");
414            // }
415
416            #[cfg(feature = "rustls")]
417            OurTlsConnector::Rustls(tls) => {
418                const H2: &[u8] = b"h2";
419
420                use actix_tls::connect::rustls::{reexports::AsyncTlsStream, TlsConnector};
421
422                impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, AsyncTlsStream<Io>> {
423                    fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
424                        let sock = self.into_parts().0;
425                        let h2 = sock
426                            .get_ref()
427                            .1
428                            .alpn_protocol()
429                            .map_or(false, |protos| protos.windows(2).any(|w| w == H2));
430                        if h2 {
431                            (Box::new(sock), Protocol::Http2)
432                        } else {
433                            (Box::new(sock), Protocol::Http1)
434                        }
435                    }
436                }
437
438                let handshake_timeout = self.config.handshake_timeout;
439
440                let tls_service = TlsConnectorService {
441                    tcp_service: tcp_service_inner,
442                    tls_service: TlsConnector::service(tls),
443                    timeout: handshake_timeout,
444                };
445
446                Some(actix_service::boxed::rc_service(tls_service))
447            }
448        };
449
450        let tcp_config = self.config.no_disconnect_timeout();
451
452        let tcp_pool = ConnectionPool::new(tcp_service, tcp_config);
453
454        let tls_config = self.config;
455        let tls_pool =
456            tls_service.map(move |tls_service| ConnectionPool::new(tls_service, tls_config));
457
458        ConnectorServicePriv { tcp_pool, tls_pool }
459    }
460}
461
/// Service adapter that maps a `TcpConnection<Uri, Io>` response to an `(Io, Protocol)` pair.
#[derive(Clone)]
pub struct TcpConnectorService<S: Clone> {
    /// Wrapped service that produces the `TcpConnection`.
    service: S,
}
467
468impl<S, Io> Service<Connect> for TcpConnectorService<S>
469where
470    S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError>
471        + Clone
472        + 'static,
473{
474    type Response = (Io, Protocol);
475    type Error = ConnectError;
476    type Future = TcpConnectorFuture<S::Future>;
477
478    actix_service::forward_ready!(service);
479
480    fn call(&self, req: Connect) -> Self::Future {
481        TcpConnectorFuture {
482            fut: self.service.call(req),
483        }
484    }
485}
486
487pin_project! {
488    #[project = TcpConnectorFutureProj]
489    pub struct TcpConnectorFuture<Fut> {
490        #[pin]
491        fut: Fut,
492    }
493}
494
495impl<Fut, Io> Future for TcpConnectorFuture<Fut>
496where
497    Fut: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
498{
499    type Output = Result<(Io, Protocol), ConnectError>;
500
501    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
502        self.project()
503            .fut
504            .poll(cx)
505            .map_ok(|res| (res.into_parts().0, Protocol::Http1))
506    }
507}
508
/// Service that establishes a TCP connection and then performs the client TLS handshake.
///
/// The handshake is canceled once the configured timeout elapses.
struct TlsConnectorService<Tcp, Tls> {
    /// Establishes the TCP connection; canceled on `TcpConnectorInnerService`'s own timeout.
    tcp_service: Tcp,

    /// Performs the TLS handshake; canceled on this service's `timeout`.
    tls_service: Tls,

    /// Maximum duration allowed for the TLS handshake.
    timeout: Duration,
}
520
521impl<Tcp, Tls, IO> Service<Connect> for TlsConnectorService<Tcp, Tls>
522where
523    Tcp: Service<Connect, Response = TcpConnection<Uri, IO>, Error = ConnectError>
524        + Clone
525        + 'static,
526    Tls: Service<TcpConnection<Uri, IO>, Error = std::io::Error> + Clone + 'static,
527    Tls::Response: IntoConnectionIo,
528    IO: ConnectionIo,
529{
530    type Response = (Box<dyn ConnectionIo>, Protocol);
531    type Error = ConnectError;
532    type Future = TlsConnectorFuture<Tls, Tcp::Future, Tls::Future>;
533
534    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
535        ready!(self.tcp_service.poll_ready(cx))?;
536        ready!(self.tls_service.poll_ready(cx))?;
537        Poll::Ready(Ok(()))
538    }
539
540    fn call(&self, req: Connect) -> Self::Future {
541        let fut = self.tcp_service.call(req);
542        let tls_service = self.tls_service.clone();
543        let timeout = self.timeout;
544
545        TlsConnectorFuture::TcpConnect {
546            fut,
547            tls_service: Some(tls_service),
548            timeout,
549        }
550    }
551}
552
553pin_project! {
554    #[project = TlsConnectorProj]
555    #[allow(clippy::large_enum_variant)]
556    enum TlsConnectorFuture<S, Fut1, Fut2> {
557        TcpConnect {
558            #[pin]
559            fut: Fut1,
560            tls_service: Option<S>,
561            timeout: Duration,
562        },
563        TlsConnect {
564            #[pin]
565            fut: Fut2,
566            #[pin]
567            timeout: Sleep,
568        },
569    }
570
571}
572/// helper trait for generic over different TlsStream types between tls crates.
573trait IntoConnectionIo {
574    fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol);
575}
576
577impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2>
578where
579    S: Service<TcpConnection<Uri, Io>, Response = Res, Error = std::io::Error, Future = Fut2>,
580    S::Response: IntoConnectionIo,
581    Fut1: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
582    Fut2: Future<Output = Result<S::Response, S::Error>>,
583    Io: ConnectionIo,
584{
585    type Output = Result<(Box<dyn ConnectionIo>, Protocol), ConnectError>;
586
587    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
588        match self.as_mut().project() {
589            TlsConnectorProj::TcpConnect {
590                fut,
591                tls_service,
592                timeout,
593            } => {
594                let res = ready!(fut.poll(cx))?;
595                let fut = tls_service
596                    .take()
597                    .expect("TlsConnectorFuture polled after complete")
598                    .call(res);
599                let timeout = sleep(*timeout);
600                self.set(TlsConnectorFuture::TlsConnect { fut, timeout });
601                self.poll(cx)
602            }
603            TlsConnectorProj::TlsConnect { fut, timeout } => match fut.poll(cx)? {
604                Poll::Ready(res) => Poll::Ready(Ok(res.into_connection_io())),
605                Poll::Pending => timeout.poll(cx).map(|_| Err(ConnectError::Timeout)),
606            },
607        }
608    }
609}
610
/// Service that establishes a TCP connection.
///
/// The operation is canceled once the configured timeout elapses.
#[derive(Clone)]
pub struct TcpConnectorInnerService<S: Clone> {
    /// Wrapped connector service.
    service: S,
    /// Maximum duration allowed for a connect attempt.
    timeout: Duration,
    /// Local address to set on outgoing connect requests, if any.
    local_address: Option<IpAddr>,
}
619
620impl<S: Clone> TcpConnectorInnerService<S> {
621    fn new(service: S, timeout: Duration, local_address: Option<std::net::IpAddr>) -> Self {
622        Self {
623            service,
624            timeout,
625            local_address,
626        }
627    }
628}
629
630impl<S, Io> Service<Connect> for TcpConnectorInnerService<S>
631where
632    S: Service<ConnectInfo<Uri>, Response = TcpConnection<Uri, Io>, Error = TcpConnectError>
633        + Clone
634        + 'static,
635{
636    type Response = S::Response;
637    type Error = ConnectError;
638    type Future = TcpConnectorInnerFuture<S::Future>;
639
640    actix_service::forward_ready!(service);
641
642    fn call(&self, req: Connect) -> Self::Future {
643        let mut req = ConnectInfo::new(req.uri).set_addr(req.addr);
644
645        if let Some(local_addr) = self.local_address {
646            req = req.set_local_addr(local_addr);
647        }
648
649        TcpConnectorInnerFuture {
650            fut: self.service.call(req),
651            timeout: sleep(self.timeout),
652        }
653    }
654}
655
656pin_project! {
657    #[project = TcpConnectorInnerFutureProj]
658    pub struct TcpConnectorInnerFuture<Fut> {
659        #[pin]
660        fut: Fut,
661        #[pin]
662        timeout: Sleep,
663    }
664}
665
666impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut>
667where
668    Fut: Future<Output = Result<TcpConnection<Uri, Io>, TcpConnectError>>,
669{
670    type Output = Result<TcpConnection<Uri, Io>, ConnectError>;
671
672    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
673        let this = self.project();
674        match this.fut.poll(cx) {
675            Poll::Ready(res) => Poll::Ready(res.map_err(ConnectError::from)),
676            Poll::Pending => this.timeout.poll(cx).map(|_| Err(ConnectError::Timeout)),
677        }
678    }
679}
680
681/// Connector service for pooled Plain/Tls Tcp connections.
682pub type ConnectorService<Svc, IO> = ConnectorServicePriv<
683    TcpConnectorService<TcpConnectorInnerService<Svc>>,
684    Rc<
685        dyn Service<
686            Connect,
687            Response = (Box<dyn ConnectionIo>, Protocol),
688            Error = ConnectError,
689            Future = LocalBoxFuture<
690                'static,
691                Result<(Box<dyn ConnectionIo>, Protocol), ConnectError>,
692            >,
693        >,
694    >,
695    IO,
696    Box<dyn ConnectionIo>,
697>;
698
699pub struct ConnectorServicePriv<S1, S2, Io1, Io2>
700where
701    S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>,
702    S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>,
703    Io1: ConnectionIo,
704    Io2: ConnectionIo,
705{
706    tcp_pool: ConnectionPool<S1, Io1>,
707    tls_pool: Option<ConnectionPool<S2, Io2>>,
708}
709
710impl<S1, S2, Io1, Io2> Service<Connect> for ConnectorServicePriv<S1, S2, Io1, Io2>
711where
712    S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + Clone + 'static,
713    S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + Clone + 'static,
714    Io1: ConnectionIo,
715    Io2: ConnectionIo,
716{
717    type Response = Connection<Io1, Io2>;
718    type Error = ConnectError;
719    type Future = ConnectorServiceFuture<S1, S2, Io1, Io2>;
720
721    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
722        ready!(self.tcp_pool.poll_ready(cx))?;
723        if let Some(ref tls_pool) = self.tls_pool {
724            ready!(tls_pool.poll_ready(cx))?;
725        }
726        Poll::Ready(Ok(()))
727    }
728
729    fn call(&self, req: Connect) -> Self::Future {
730        match req.uri.scheme_str() {
731            Some("https") | Some("wss") => match self.tls_pool {
732                None => ConnectorServiceFuture::SslIsNotSupported,
733                Some(ref pool) => ConnectorServiceFuture::Tls {
734                    fut: pool.call(req),
735                },
736            },
737            _ => ConnectorServiceFuture::Tcp {
738                fut: self.tcp_pool.call(req),
739            },
740        }
741    }
742}
743
744pin_project! {
745    #[project = ConnectorServiceFutureProj]
746    pub enum ConnectorServiceFuture<S1, S2, Io1, Io2>
747    where
748        S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>,
749        S1: Clone,
750        S1: 'static,
751        S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>,
752        S2: Clone,
753        S2: 'static,
754        Io1: ConnectionIo,
755        Io2: ConnectionIo,
756    {
757        Tcp {
758            #[pin]
759            fut: <ConnectionPool<S1, Io1> as Service<Connect>>::Future
760        },
761        Tls {
762            #[pin]
763            fut:  <ConnectionPool<S2, Io2> as Service<Connect>>::Future
764        },
765        SslIsNotSupported
766    }
767}
768
769impl<S1, S2, Io1, Io2> Future for ConnectorServiceFuture<S1, S2, Io1, Io2>
770where
771    S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + Clone + 'static,
772    S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + Clone + 'static,
773    Io1: ConnectionIo,
774    Io2: ConnectionIo,
775{
776    type Output = Result<Connection<Io1, Io2>, ConnectError>;
777
778    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
779        match self.project() {
780            ConnectorServiceFutureProj::Tcp { fut } => fut.poll(cx).map_ok(Connection::Tcp),
781            ConnectorServiceFutureProj::Tls { fut } => fut.poll(cx).map_ok(Connection::Tls),
782            ConnectorServiceFutureProj::SslIsNotSupported => {
783                Poll::Ready(Err(ConnectError::SslIsNotSupported))
784            }
785        }
786    }
787}
788
789#[cfg(not(feature = "trust-dns"))]
790mod resolver {
791    use super::*;
792
793    pub(super) fn resolver() -> Resolver {
794        Resolver::default()
795    }
796}
797
#[cfg(feature = "trust-dns")]
mod resolver {
    use std::{cell::RefCell, net::SocketAddr};

    use actix_tls::connect::Resolve;
    use futures_core::future::LocalBoxFuture;
    use trust_dns_resolver::{
        config::{ResolverConfig, ResolverOpts},
        system_conf::read_system_conf,
        TokioAsyncResolver,
    };

    use super::*;

    /// Returns a resolver backed by trust-dns.
    ///
    /// The resolver is created once per thread and cached in a thread local, so later calls
    /// on the same thread receive a clone of the cached instance. If the system DNS
    /// configuration cannot be read, an error is logged and the trust-dns default
    /// configuration is used instead.
    pub(super) fn resolver() -> Resolver {
        // Newtype so the `Resolve` trait can be implemented for `TokioAsyncResolver`.
        struct TrustDnsResolver(TokioAsyncResolver);

        impl Resolve for TrustDnsResolver {
            fn lookup<'a>(
                &'a self,
                host: &'a str,
                port: u16,
            ) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>>
            {
                Box::pin(async move {
                    let lookup = self.0.lookup_ip(host).await?;
                    let addrs = lookup
                        .iter()
                        .map(|ip| SocketAddr::new(ip, port))
                        .collect();
                    Ok(addrs)
                })
            }
        }

        // Per-thread cache so new clients can reuse the existing resolver instance.
        thread_local! {
            static TRUST_DNS_RESOLVER: RefCell<Option<Resolver>> = RefCell::new(None);
        }

        TRUST_DNS_RESOLVER.with(|slot| {
            // Fast path: hand out a clone of the cached resolver.
            let cached = slot.borrow().clone();
            if let Some(existing) = cached {
                return existing;
            }

            let (cfg, opts) = read_system_conf().unwrap_or_else(|e| {
                log::error!("TRust-DNS can not load system config: {}", e);
                (ResolverConfig::default(), ResolverOpts::default())
            });

            let dns = TokioAsyncResolver::tokio(cfg, opts).unwrap();

            // Wrap the trust-dns resolver and remember it for later calls on this thread.
            let shared = Resolver::custom(TrustDnsResolver(dns));
            *slot.borrow_mut() = Some(shared.clone());

            shared
        })
    }
}
869
#[cfg(feature = "dangerous-h2c")]
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use actix_http::{HttpService, Request, Response, Version};
    use actix_http_test::test_server;
    use actix_service::ServiceFactoryExt as _;

    use super::*;
    use crate::Client;

    /// A connector without TLS must talk HTTP/2 over plain TCP to an h2-only test server.
    #[actix_rt::test]
    async fn h2c_connector() {
        let mut server = test_server(|| {
            HttpService::build()
                .h2(|_req: Request| async { Ok::<_, Infallible>(Response::ok()) })
                .tcp()
                .map_err(|_| ())
        })
        .await;

        // Built by hand so the TLS backend is forced to `None`.
        let plain_connector = Connector {
            connector: TcpConnector::new(resolver::resolver()).service(),
            config: ConnectorConfig::default(),
            tls: OurTlsConnector::None,
        };

        let client = Client::builder().connector(plain_connector).finish();

        let response = client.get(server.surl("/")).send().await.unwrap();
        assert!(response.status().is_success());
        assert_eq!(response.version(), Version::HTTP_2);

        server.stop().await;
    }
}
905        srv.stop().await;
906    }
907}