Skip to main content

actix_http/
service.rs

1use std::{
2    fmt,
3    future::Future,
4    marker::PhantomData,
5    net,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11use actix_codec::{AsyncRead, AsyncWrite, Framed};
12use actix_rt::net::TcpStream;
13use actix_service::{
14    fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
15};
16use futures_core::{future::LocalBoxFuture, ready};
17use pin_project_lite::pin_project;
18use tracing::error;
19
20use crate::{
21    body::{BoxBody, MessageBody},
22    builder::HttpServiceBuilder,
23    error::DispatchError,
24    h1, ConnectCallback, OnConnectData, Protocol, Request, Response, ServiceConfig,
25};
26
/// Resolves the `TCP_NODELAY` value that should be applied to an accepted socket.
///
/// A `None` result means no explicit value was configured, and callers then leave the socket
/// option untouched.
#[inline]
const fn desired_nodelay(tcp_nodelay: Option<bool>) -> Option<bool> {
    tcp_nodelay
}
31
32#[inline]
33fn set_nodelay(stream: &TcpStream, nodelay: bool) {
34    let _ = stream.set_nodelay(nodelay);
35}
36
37/// A [`ServiceFactory`] for HTTP/1.1 and HTTP/2 connections.
38///
39/// Use [`build`](Self::build) to begin constructing service. Also see [`HttpServiceBuilder`].
40///
41/// # Automatic HTTP Version Selection
42/// There are two ways to select the HTTP version of an incoming connection:
43/// - One is to rely on the ALPN information that is provided when using TLS (HTTPS); both versions
44///   are supported automatically when using either of the `.rustls()` or `.openssl()` finalizing
45///   methods.
46/// - The other is to read the first few bytes of the TCP stream. This is the only viable approach
47///   for supporting H2C, which allows the HTTP/2 protocol to work over plaintext connections. Use
48///   the `.tcp_auto_h2c()` finalizing method to enable this behavior.
49///
50/// # Examples
51/// ```
52/// # use std::convert::Infallible;
53/// use actix_http::{HttpService, Request, Response, StatusCode};
54///
55/// // this service would constructed in an actix_server::Server
56///
57/// # actix_rt::System::new().block_on(async {
58/// HttpService::build()
59///     // the builder finalizing method, other finalizers would not return an `HttpService`
60///     .finish(|_req: Request| async move {
61///         Ok::<_, Infallible>(
62///             Response::build(StatusCode::OK).body("Hello!")
63///         )
64///     })
65///     // the service finalizing method method
66///     // you can use `.tcp_auto_h2c()`, `.rustls()`, or `.openssl()` instead of `.tcp()`
67///     .tcp();
68/// # })
69/// ```
70pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler> {
71    srv: S,
72    cfg: ServiceConfig,
73    expect: X,
74    upgrade: Option<U>,
75    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
76    _phantom: PhantomData<B>,
77}
78
79impl<T, S, B> HttpService<T, S, B>
80where
81    S: ServiceFactory<Request, Config = ()>,
82    S::Error: Into<Response<BoxBody>> + 'static,
83    S::InitError: fmt::Debug,
84    S::Response: Into<Response<B>> + 'static,
85    <S::Service as Service<Request>>::Future: 'static,
86    B: MessageBody + 'static,
87{
88    /// Constructs builder for `HttpService` instance.
89    pub fn build() -> HttpServiceBuilder<T, S> {
90        HttpServiceBuilder::default()
91    }
92}
93
94impl<T, S, B> HttpService<T, S, B>
95where
96    S: ServiceFactory<Request, Config = ()>,
97    S::Error: Into<Response<BoxBody>> + 'static,
98    S::InitError: fmt::Debug,
99    S::Response: Into<Response<B>> + 'static,
100    <S::Service as Service<Request>>::Future: 'static,
101    B: MessageBody + 'static,
102{
103    /// Constructs new `HttpService` instance from service with default config.
104    pub fn new<F: IntoServiceFactory<S, Request>>(service: F) -> Self {
105        HttpService {
106            cfg: ServiceConfig::default(),
107            srv: service.into_factory(),
108            expect: h1::ExpectHandler,
109            upgrade: None,
110            on_connect_ext: None,
111            _phantom: PhantomData,
112        }
113    }
114
115    /// Constructs new `HttpService` instance from config and service.
116    pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
117        cfg: ServiceConfig,
118        service: F,
119    ) -> Self {
120        HttpService {
121            cfg,
122            srv: service.into_factory(),
123            expect: h1::ExpectHandler,
124            upgrade: None,
125            on_connect_ext: None,
126            _phantom: PhantomData,
127        }
128    }
129}
130
131impl<T, S, B, X, U> HttpService<T, S, B, X, U>
132where
133    S: ServiceFactory<Request, Config = ()>,
134    S::Error: Into<Response<BoxBody>> + 'static,
135    S::InitError: fmt::Debug,
136    S::Response: Into<Response<B>> + 'static,
137    <S::Service as Service<Request>>::Future: 'static,
138    B: MessageBody,
139{
140    /// Sets service for `Expect: 100-Continue` handling.
141    ///
142    /// An expect service is called with requests that contain an `Expect` header. A successful
143    /// response type is also a request which will be forwarded to the main service.
144    pub fn expect<X1>(self, expect: X1) -> HttpService<T, S, B, X1, U>
145    where
146        X1: ServiceFactory<Request, Config = (), Response = Request>,
147        X1::Error: Into<Response<BoxBody>>,
148        X1::InitError: fmt::Debug,
149    {
150        HttpService {
151            expect,
152            cfg: self.cfg,
153            srv: self.srv,
154            upgrade: self.upgrade,
155            on_connect_ext: self.on_connect_ext,
156            _phantom: PhantomData,
157        }
158    }
159
160    /// Sets service for custom `Connection: Upgrade` handling.
161    ///
162    /// If service is provided then normal requests handling get halted and this service get called
163    /// with original request and framed object.
164    pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, S, B, X, U1>
165    where
166        U1: ServiceFactory<(Request, Framed<T, h1::Codec>), Config = (), Response = ()>,
167        U1::Error: fmt::Display,
168        U1::InitError: fmt::Debug,
169    {
170        HttpService {
171            upgrade,
172            cfg: self.cfg,
173            srv: self.srv,
174            expect: self.expect,
175            on_connect_ext: self.on_connect_ext,
176            _phantom: PhantomData,
177        }
178    }
179
180    /// Set connect callback with mutable access to request data container.
181    pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
182        self.on_connect_ext = f;
183        self
184    }
185}
186
187impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
188where
189    S: ServiceFactory<Request, Config = ()>,
190    S::Future: 'static,
191    S::Error: Into<Response<BoxBody>> + 'static,
192    S::InitError: fmt::Debug,
193    S::Response: Into<Response<B>> + 'static,
194    <S::Service as Service<Request>>::Future: 'static,
195
196    B: MessageBody + 'static,
197
198    X: ServiceFactory<Request, Config = (), Response = Request>,
199    X::Future: 'static,
200    X::Error: Into<Response<BoxBody>>,
201    X::InitError: fmt::Debug,
202
203    U: ServiceFactory<(Request, Framed<TcpStream, h1::Codec>), Config = (), Response = ()>,
204    U::Future: 'static,
205    U::Error: fmt::Display + Into<Response<BoxBody>>,
206    U::InitError: fmt::Debug,
207{
208    /// Creates TCP stream service from HTTP service.
209    ///
210    /// The resulting service only supports HTTP/1.x.
211    pub fn tcp(
212        self,
213    ) -> impl ServiceFactory<TcpStream, Config = (), Response = (), Error = DispatchError, InitError = ()>
214    {
215        let tcp_nodelay = self.cfg.tcp_nodelay();
216
217        fn_service(move |io: TcpStream| async move {
218            if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
219                set_nodelay(&io, nodelay);
220            }
221
222            let peer_addr = io.peer_addr().ok();
223            Ok((io, Protocol::Http1, peer_addr))
224        })
225        .and_then(self)
226    }
227
228    /// Creates TCP stream service from HTTP service that automatically selects HTTP/1.x or HTTP/2
229    /// on plaintext connections.
230    #[cfg(feature = "http2")]
231    pub fn tcp_auto_h2c(
232        self,
233    ) -> impl ServiceFactory<TcpStream, Config = (), Response = (), Error = DispatchError, InitError = ()>
234    {
235        let tcp_nodelay = self.cfg.tcp_nodelay();
236
237        fn_service(move |io: TcpStream| async move {
238            // subset of HTTP/2 preface defined by RFC 9113 ยง3.4
239            // this subset was chosen to maximize likelihood that peeking only once will allow us to
240            // reliably determine version or else it should fallback to h1 and fail quickly if data
241            // on the wire is junk
242            const H2_PREFACE: &[u8] = b"PRI * HTTP/2";
243
244            let mut buf = [0; 12];
245
246            io.peek(&mut buf).await?;
247
248            let proto = if buf == H2_PREFACE {
249                Protocol::Http2
250            } else {
251                Protocol::Http1
252            };
253
254            if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
255                set_nodelay(&io, nodelay);
256            }
257
258            let peer_addr = io.peer_addr().ok();
259            Ok((io, proto, peer_addr))
260        })
261        .and_then(self)
262    }
263}
264
/// Configuration options used when accepting TLS connections.
///
/// Construct with [`Default::default`] and adjust with the builder-style setters.
#[cfg(feature = "__tls")]
#[derive(Debug, Clone, Default)]
pub struct TlsAcceptorConfig {
    /// Maximum duration allowed for the TLS handshake.
    ///
    /// When `None`, no timeout is set on the acceptor, so it keeps its own default.
    pub(crate) handshake_timeout: Option<std::time::Duration>,
}
271
#[cfg(feature = "__tls")]
impl TlsAcceptorConfig {
    /// Returns this configuration with the TLS handshake timeout set to `dur`.
    ///
    /// Any previously configured timeout is replaced.
    pub fn handshake_timeout(mut self, dur: std::time::Duration) -> Self {
        self.handshake_timeout = Some(dur);
        self
    }
}
282
#[cfg(feature = "openssl")]
mod openssl {
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Create OpenSSL based service.
        ///
        /// Equivalent to [`openssl_with_config`](Self::openssl_with_config) with a default
        /// [`TlsAcceptorConfig`].
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = (),
        > {
            self.openssl_with_config(acceptor, TlsAcceptorConfig::default())
        }

        /// Create OpenSSL based service with custom TLS acceptor configuration.
        ///
        /// This method does not add any ALPN protocols itself. HTTP/2 is used only when the
        /// protocol negotiated by `acceptor` is exactly `h2`; any other protocol, or none, selects
        /// HTTP/1.x.
        pub fn openssl_with_config(
            self,
            acceptor: SslAcceptor,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();
            let mut acceptor = Acceptor::new(acceptor);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // ALPN negotiates a single protocol ID; match it exactly instead of searching
                    // for an `h2` substring, which would misclassify other IDs containing it
                    let proto = match io.ssl().selected_alpn_protocol() {
                        Some(b"h2") => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    // this closure only runs once the acceptor has completed the TLS handshake
                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref(), nodelay);
                    }

                    let peer_addr = io.get_ref().peer_addr().ok();
                    (io, proto, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
380
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_20::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Create Rustls v0.20 based service.
        ///
        /// Equivalent to [`rustls_with_config`](Self::rustls_with_config) with a default
        /// [`TlsAcceptorConfig`].
        pub fn rustls(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_with_config(config, TlsAcceptorConfig::default())
        }

        /// Create Rustls v0.20 based service with custom TLS acceptor configuration.
        ///
        /// `h2` and `http/1.1` are placed ahead of any ALPN protocols already present in
        /// `config`. A negotiated protocol of exactly `h2` selects HTTP/2; any other protocol,
        /// or none, selects HTTP/1.x.
        pub fn rustls_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();
            let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN negotiates a single protocol ID; match it exactly instead of searching
                    // for an `h2` substring, which would misclassify other IDs containing it
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(b"h2") => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    // this closure only runs once the acceptor has completed the TLS handshake
                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
481
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Create Rustls v0.21 based service.
        ///
        /// Equivalent to [`rustls_021_with_config`](Self::rustls_021_with_config) with a default
        /// [`TlsAcceptorConfig`].
        pub fn rustls_021(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_021_with_config(config, TlsAcceptorConfig::default())
        }

        /// Create Rustls v0.21 based service with custom TLS acceptor configuration.
        ///
        /// `h2` and `http/1.1` are placed ahead of any ALPN protocols already present in
        /// `config`. A negotiated protocol of exactly `h2` selects HTTP/2; any other protocol,
        /// or none, selects HTTP/1.x.
        pub fn rustls_021_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();
            let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN negotiates a single protocol ID; match it exactly instead of searching
                    // for an `h2` substring, which would misclassify other IDs containing it
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(b"h2") => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    // this closure only runs once the acceptor has completed the TLS handshake
                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
582
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Create Rustls v0.22 based service.
        ///
        /// Equivalent to [`rustls_0_22_with_config`](Self::rustls_0_22_with_config) with a
        /// default [`TlsAcceptorConfig`].
        pub fn rustls_0_22(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_0_22_with_config(config, TlsAcceptorConfig::default())
        }

        /// Create Rustls v0.22 based service with custom TLS acceptor configuration.
        ///
        /// `h2` and `http/1.1` are placed ahead of any ALPN protocols already present in
        /// `config`. A negotiated protocol of exactly `h2` selects HTTP/2; any other protocol,
        /// or none, selects HTTP/1.x.
        pub fn rustls_0_22_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();
            let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN negotiates a single protocol ID; match it exactly instead of searching
                    // for an `h2` substring, which would misclassify other IDs containing it
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(b"h2") => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    // this closure only runs once the acceptor has completed the TLS handshake
                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
683
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Create Rustls v0.23 based service.
        ///
        /// Equivalent to [`rustls_0_23_with_config`](Self::rustls_0_23_with_config) with a
        /// default [`TlsAcceptorConfig`].
        pub fn rustls_0_23(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_0_23_with_config(config, TlsAcceptorConfig::default())
        }

        /// Create Rustls v0.23 based service with custom TLS acceptor configuration.
        ///
        /// `h2` and `http/1.1` are placed ahead of any ALPN protocols already present in
        /// `config`. A negotiated protocol of exactly `h2` selects HTTP/2; any other protocol,
        /// or none, selects HTTP/1.x.
        pub fn rustls_0_23_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();
            let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN negotiates a single protocol ID; match it exactly instead of searching
                    // for an `h2` substring, which would misclassify other IDs containing it
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(b"h2") => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    // this closure only runs once the acceptor has completed the TLS handshake
                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
784
785impl<T, S, B, X, U> ServiceFactory<(T, Protocol, Option<net::SocketAddr>)>
786    for HttpService<T, S, B, X, U>
787where
788    T: AsyncRead + AsyncWrite + Unpin + 'static,
789
790    S: ServiceFactory<Request, Config = ()>,
791    S::Future: 'static,
792    S::Error: Into<Response<BoxBody>> + 'static,
793    S::InitError: fmt::Debug,
794    S::Response: Into<Response<B>> + 'static,
795    <S::Service as Service<Request>>::Future: 'static,
796
797    B: MessageBody + 'static,
798
799    X: ServiceFactory<Request, Config = (), Response = Request>,
800    X::Future: 'static,
801    X::Error: Into<Response<BoxBody>>,
802    X::InitError: fmt::Debug,
803
804    U: ServiceFactory<(Request, Framed<T, h1::Codec>), Config = (), Response = ()>,
805    U::Future: 'static,
806    U::Error: fmt::Display + Into<Response<BoxBody>>,
807    U::InitError: fmt::Debug,
808{
809    type Response = ();
810    type Error = DispatchError;
811    type Config = ();
812    type Service = HttpServiceHandler<T, S::Service, B, X::Service, U::Service>;
813    type InitError = ();
814    type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
815
816    fn new_service(&self, _: ()) -> Self::Future {
817        let service = self.srv.new_service(());
818        let expect = self.expect.new_service(());
819        let upgrade = self.upgrade.as_ref().map(|s| s.new_service(()));
820        let on_connect_ext = self.on_connect_ext.clone();
821        let cfg = self.cfg.clone();
822
823        Box::pin(async move {
824            let expect = expect.await.map_err(|err| {
825                tracing::error!("Initialization of HTTP expect service error: {err:?}");
826            })?;
827
828            let upgrade = match upgrade {
829                Some(upgrade) => {
830                    let upgrade = upgrade.await.map_err(|err| {
831                        tracing::error!("Initialization of HTTP upgrade service error: {err:?}");
832                    })?;
833                    Some(upgrade)
834                }
835                None => None,
836            };
837
838            let service = service.await.map_err(|err| {
839                tracing::error!("Initialization of HTTP service error: {err:?}");
840            })?;
841
842            Ok(HttpServiceHandler::new(
843                cfg,
844                service,
845                expect,
846                upgrade,
847                on_connect_ext,
848            ))
849        })
850    }
851}
852
853/// `Service` implementation for HTTP/1 and HTTP/2 transport
854pub struct HttpServiceHandler<T, S, B, X, U>
855where
856    S: Service<Request>,
857    X: Service<Request>,
858    U: Service<(Request, Framed<T, h1::Codec>)>,
859{
860    pub(super) flow: Rc<HttpFlow<S, X, U>>,
861    pub(super) cfg: ServiceConfig,
862    pub(super) on_connect_ext: Option<Rc<ConnectCallback<T>>>,
863    _phantom: PhantomData<B>,
864}
865
866impl<T, S, B, X, U> HttpServiceHandler<T, S, B, X, U>
867where
868    S: Service<Request>,
869    S::Error: Into<Response<BoxBody>>,
870    X: Service<Request>,
871    X::Error: Into<Response<BoxBody>>,
872    U: Service<(Request, Framed<T, h1::Codec>)>,
873    U::Error: Into<Response<BoxBody>>,
874{
875    pub(super) fn new(
876        cfg: ServiceConfig,
877        service: S,
878        expect: X,
879        upgrade: Option<U>,
880        on_connect_ext: Option<Rc<ConnectCallback<T>>>,
881    ) -> HttpServiceHandler<T, S, B, X, U> {
882        HttpServiceHandler {
883            cfg,
884            on_connect_ext,
885            flow: HttpFlow::new(service, expect, upgrade),
886            _phantom: PhantomData,
887        }
888    }
889
890    pub(super) fn _poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Response<BoxBody>>> {
891        ready!(self.flow.expect.poll_ready(cx).map_err(Into::into))?;
892
893        ready!(self.flow.service.poll_ready(cx).map_err(Into::into))?;
894
895        if let Some(ref upg) = self.flow.upgrade {
896            ready!(upg.poll_ready(cx).map_err(Into::into))?;
897        };
898
899        Poll::Ready(Ok(()))
900    }
901}
902
903/// A collection of services that describe an HTTP request flow.
904pub(super) struct HttpFlow<S, X, U> {
905    pub(super) service: S,
906    pub(super) expect: X,
907    pub(super) upgrade: Option<U>,
908}
909
910impl<S, X, U> HttpFlow<S, X, U> {
911    pub(super) fn new(service: S, expect: X, upgrade: Option<U>) -> Rc<Self> {
912        Rc::new(Self {
913            service,
914            expect,
915            upgrade,
916        })
917    }
918}
919
920impl<T, S, B, X, U> Service<(T, Protocol, Option<net::SocketAddr>)>
921    for HttpServiceHandler<T, S, B, X, U>
922where
923    T: AsyncRead + AsyncWrite + Unpin,
924
925    S: Service<Request>,
926    S::Error: Into<Response<BoxBody>> + 'static,
927    S::Future: 'static,
928    S::Response: Into<Response<B>> + 'static,
929
930    B: MessageBody + 'static,
931
932    X: Service<Request, Response = Request>,
933    X::Error: Into<Response<BoxBody>>,
934
935    U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
936    U::Error: fmt::Display + Into<Response<BoxBody>>,
937{
938    type Response = ();
939    type Error = DispatchError;
940    type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
941
942    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
943        self._poll_ready(cx).map_err(|err| {
944            error!("HTTP service readiness error: {:?}", err);
945            DispatchError::Service(err)
946        })
947    }
948
949    fn call(&self, (io, proto, peer_addr): (T, Protocol, Option<net::SocketAddr>)) -> Self::Future {
950        let conn_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
951
952        match proto {
953            #[cfg(feature = "http2")]
954            Protocol::Http2 => HttpServiceHandlerResponse {
955                state: State::H2Handshake {
956                    handshake: Some((
957                        crate::h2::handshake_with_timeout(io, &self.cfg),
958                        self.cfg.clone(),
959                        Rc::clone(&self.flow),
960                        conn_data,
961                        peer_addr,
962                    )),
963                },
964            },
965
966            #[cfg(not(feature = "http2"))]
967            Protocol::Http2 => {
968                panic!("HTTP/2 support is disabled (enable with the `http2` feature flag)")
969            }
970
971            Protocol::Http1 => HttpServiceHandlerResponse {
972                state: State::H1 {
973                    dispatcher: h1::Dispatcher::new(
974                        io,
975                        Rc::clone(&self.flow),
976                        self.cfg.clone(),
977                        peer_addr,
978                        conn_data,
979                    ),
980                },
981            },
982
983            proto => unimplemented!("Unsupported HTTP version: {:?}.", proto),
984        }
985    }
986}
987
// Per-connection state of `HttpServiceHandlerResponse` when the `http2` feature is disabled.
// Only HTTP/1.x can be served in this configuration, so `H1` is the sole variant.
//
// `StateProj` is the pin projection generated by `pin_project!`. It is what
// `HttpServiceHandlerResponse::poll` matches on.
#[cfg(not(feature = "http2"))]
pin_project! {
    #[project = StateProj]
    enum State<T, S, B, X, U>
    where
        T: AsyncRead,
        T: AsyncWrite,
        T: Unpin,

        S: Service<Request>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // HTTP/1.x connection. The dispatcher is structurally pinned so the projection yields a
        // `Pin<&mut h1::Dispatcher>` that can be polled in place.
        H1 { #[pin] dispatcher: h1::Dispatcher<T, S, B, X, U> },
    }
}
1012
// Per-connection state of `HttpServiceHandlerResponse` when the `http2` feature is enabled.
//
// An HTTP/2 connection starts in `H2Handshake`. When the handshake succeeds, `poll` replaces it
// with `H2`. HTTP/1.x connections start (and stay) in `H1`.
//
// `StateProj` is the pin projection generated by `pin_project!`. It is what
// `HttpServiceHandlerResponse::poll` matches on.
#[cfg(feature = "http2")]
pin_project! {
    #[project = StateProj]
    enum State<T, S, B, X, U>
    where
        T: AsyncRead,
        T: AsyncWrite,
        T: Unpin,

        S: Service<Request>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // HTTP/1.x connection, driven by a pinned `h1::Dispatcher`.
        H1 { #[pin] dispatcher: h1::Dispatcher<T, S, B, X, U> },

        // Established HTTP/2 connection, driven by a pinned `h2::Dispatcher`.
        H2 { #[pin] dispatcher: crate::h2::Dispatcher<T, S, B, X, U> },

        // HTTP/2 handshake in progress. The tuple holds:
        //   0: the future returned by `crate::h2::handshake_with_timeout`,
        //   1..=4: config, shared flow, on-connect data and peer address, which are needed to
        //          build the `H2` dispatcher.
        // It is wrapped in `Option` so `poll` can move the values out with `take()` once the
        // handshake succeeds. The field is not `#[pin]`, so `poll` re-pins the handshake future
        // with `Pin::new`, which relies on that future being `Unpin`.
        H2Handshake {
            handshake: Option<(
                crate::h2::HandshakeWithTimeout<T>,
                ServiceConfig,
                Rc<HttpFlow<S, X, U>>,
                OnConnectData,
                Option<net::SocketAddr>,
            )>,
        },
    }
}
1049
pin_project! {
    /// Future returned by [`HttpServiceHandler`]'s `Service::call`.
    ///
    /// It drives a single connection to completion and resolves to `Result<(), DispatchError>`.
    pub struct HttpServiceHandlerResponse<T, S, B, X, U>
    where
        T: AsyncRead,
        T: AsyncWrite,
        T: Unpin,

        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,
        S::Error: 'static,
        S::Future: 'static,
        S::Response: Into<Response<B>>,
        S::Response: 'static,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // Current protocol state. It is pinned because the dispatchers inside are polled in place,
        // and `poll` may replace it (e.g. handshake -> HTTP/2 dispatcher) via `Pin::set`.
        #[pin]
        state: State<T, S, B, X, U>,
    }
}
1076
1077impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
1078where
1079    T: AsyncRead + AsyncWrite + Unpin,
1080
1081    S: Service<Request>,
1082    S::Error: Into<Response<BoxBody>> + 'static,
1083    S::Future: 'static,
1084    S::Response: Into<Response<B>> + 'static,
1085
1086    B: MessageBody + 'static,
1087
1088    X: Service<Request, Response = Request>,
1089    X::Error: Into<Response<BoxBody>>,
1090
1091    U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
1092    U::Error: fmt::Display,
1093{
1094    type Output = Result<(), DispatchError>;
1095
1096    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1097        match self.as_mut().project().state.project() {
1098            StateProj::H1 { dispatcher } => dispatcher.poll(cx),
1099
1100            #[cfg(feature = "http2")]
1101            StateProj::H2 { dispatcher } => dispatcher.poll(cx),
1102
1103            #[cfg(feature = "http2")]
1104            StateProj::H2Handshake { handshake: data } => {
1105                match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) {
1106                    Ok((conn, timer)) => {
1107                        let (_, config, flow, conn_data, peer_addr) = data.take().unwrap();
1108
1109                        self.as_mut().project().state.set(State::H2 {
1110                            dispatcher: crate::h2::Dispatcher::new(
1111                                conn, flow, config, peer_addr, conn_data, timer,
1112                            ),
1113                        });
1114                        self.poll(cx)
1115                    }
1116                    Err(err) => {
1117                        tracing::trace!("H2 handshake error: {}", err);
1118                        Poll::Ready(Err(err))
1119                    }
1120                }
1121            }
1122        }
1123    }
1124}