Skip to main content

actix_http/h2/
service.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    mem, net,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{
13    fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
14};
15use actix_utils::future::ready;
16use futures_core::{future::LocalBoxFuture, ready};
17use tracing::{error, trace};
18
19use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
20use crate::{
21    body::{BoxBody, MessageBody},
22    config::ServiceConfig,
23    error::DispatchError,
24    service::HttpFlow,
25    ConnectCallback, OnConnectData, Request, Response,
26};
27
/// Decides which `TCP_NODELAY` value, if any, should be applied to accepted sockets.
///
/// The configured value is passed through unchanged: `None` means the acceptors below
/// leave the socket untouched, `Some(flag)` means they explicitly apply `flag`.
#[inline]
fn desired_nodelay(configured: Option<bool>) -> Option<bool> {
    configured
}
32
33#[inline]
34fn set_nodelay(stream: &TcpStream, nodelay: bool) {
35    let _ = stream.set_nodelay(nodelay);
36}
37
/// `ServiceFactory` implementation for HTTP/2 transport
///
/// Holds a factory for the HTTP request service (`srv`). Via its `ServiceFactory` impl it
/// produces an `H2ServiceHandler`, which performs the HTTP/2 handshake and dispatch for each
/// connection of I/O type `T`. `B` is the response body type.
pub struct H2Service<T, S, B> {
    // Factory for the request service; `new_service` is called on it once per handler.
    srv: S,
    // HTTP configuration, cloned into every handler this factory creates.
    cfg: ServiceConfig,
    // Optional callback handed to `OnConnectData::from_io` for each new connection.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    // `T` and `B` are only used in type positions, so they are anchored here.
    _phantom: PhantomData<(T, B)>,
}
45
46impl<T, S, B> H2Service<T, S, B>
47where
48    S: ServiceFactory<Request, Config = ()>,
49    S::Error: Into<Response<BoxBody>> + 'static,
50    S::Response: Into<Response<B>> + 'static,
51    <S::Service as Service<Request>>::Future: 'static,
52
53    B: MessageBody + 'static,
54{
55    /// Create new `H2Service` instance with config.
56    pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
57        cfg: ServiceConfig,
58        service: F,
59    ) -> Self {
60        H2Service {
61            cfg,
62            on_connect_ext: None,
63            srv: service.into_factory(),
64            _phantom: PhantomData,
65        }
66    }
67
68    /// Set on connect callback.
69    pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
70        self.on_connect_ext = f;
71        self
72    }
73}
74
75impl<S, B> H2Service<TcpStream, S, B>
76where
77    S: ServiceFactory<Request, Config = ()>,
78    S::Future: 'static,
79    S::Error: Into<Response<BoxBody>> + 'static,
80    S::Response: Into<Response<B>> + 'static,
81    <S::Service as Service<Request>>::Future: 'static,
82
83    B: MessageBody + 'static,
84{
85    /// Create plain TCP based service
86    pub fn tcp(
87        self,
88    ) -> impl ServiceFactory<
89        TcpStream,
90        Config = (),
91        Response = (),
92        Error = DispatchError,
93        InitError = S::InitError,
94    > {
95        let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());
96
97        fn_factory(move || {
98            ready(Ok::<_, S::InitError>(fn_service(move |io: TcpStream| {
99                if let Some(nodelay) = tcp_nodelay {
100                    set_nodelay(&io, nodelay);
101                }
102                let peer_addr = io.peer_addr().ok();
103                ready(Ok::<_, DispatchError>((io, peer_addr)))
104            })))
105        })
106        .and_then(self)
107    }
108}
109
#[cfg(feature = "openssl")]
mod openssl {
    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Create OpenSSL based service.
        ///
        /// The given `acceptor` is used as configured by the caller; ALPN is not modified
        /// here. After the TLS handshake, the configured `TCP_NODELAY` value (if any) is
        /// applied to the underlying socket, and its peer address is passed on together with
        /// the TLS stream.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_setting = desired_nodelay(self.cfg.tcp_nodelay());

            // Runs once per successfully established TLS session.
            let on_tls_established = move |tls: TlsStream<TcpStream>| {
                let tcp = tls.get_ref();
                if let Some(flag) = nodelay_setting {
                    set_nodelay(tcp, flag);
                }
                let addr = tcp.peer_addr().ok();
                (tls, addr)
            };

            Acceptor::new(acceptor)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(on_tls_established)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
162
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Create Rustls v0.20 based service.
        ///
        /// `h2` is always advertised via ALPN and listed before any protocols already
        /// present in `config`. After the TLS handshake, the configured `TCP_NODELAY` value
        /// (if any) is applied to the underlying socket, and its peer address is passed on
        /// together with the TLS stream.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_setting = desired_nodelay(self.cfg.tcp_nodelay());

            // Put `h2` at the front of the caller's ALPN list.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs once per successfully established TLS session.
            let on_tls_established = move |tls: TlsStream<TcpStream>| {
                let (tcp, _) = tls.get_ref();
                if let Some(flag) = nodelay_setting {
                    set_nodelay(tcp, flag);
                }
                let addr = tcp.peer_addr().ok();
                (tls, addr)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(on_tls_established)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
217
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Create Rustls v0.21 based service.
        ///
        /// `h2` is always advertised via ALPN and listed before any protocols already
        /// present in `config`. After the TLS handshake, the configured `TCP_NODELAY` value
        /// (if any) is applied to the underlying socket, and its peer address is passed on
        /// together with the TLS stream.
        pub fn rustls_021(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_setting = desired_nodelay(self.cfg.tcp_nodelay());

            // Put `h2` at the front of the caller's ALPN list.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs once per successfully established TLS session.
            let on_tls_established = move |tls: TlsStream<TcpStream>| {
                let (tcp, _) = tls.get_ref();
                if let Some(flag) = nodelay_setting {
                    set_nodelay(tcp, flag);
                }
                let addr = tcp.peer_addr().ok();
                (tls, addr)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(on_tls_established)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
272
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Create Rustls v0.22 based service.
        ///
        /// `h2` is always advertised via ALPN and listed before any protocols already
        /// present in `config`. After the TLS handshake, the configured `TCP_NODELAY` value
        /// (if any) is applied to the underlying socket, and its peer address is passed on
        /// together with the TLS stream.
        pub fn rustls_0_22(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_setting = desired_nodelay(self.cfg.tcp_nodelay());

            // Put `h2` at the front of the caller's ALPN list.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs once per successfully established TLS session.
            let on_tls_established = move |tls: TlsStream<TcpStream>| {
                let (tcp, _) = tls.get_ref();
                if let Some(flag) = nodelay_setting {
                    set_nodelay(tcp, flag);
                }
                let addr = tcp.peer_addr().ok();
                (tls, addr)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(on_tls_established)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
327
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Create Rustls v0.23 based service.
        ///
        /// `h2` is always advertised via ALPN and listed before any protocols already
        /// present in `config`. After the TLS handshake, the configured `TCP_NODELAY` value
        /// (if any) is applied to the underlying socket, and its peer address is passed on
        /// together with the TLS stream.
        pub fn rustls_0_23(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_setting = desired_nodelay(self.cfg.tcp_nodelay());

            // Put `h2` at the front of the caller's ALPN list.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs once per successfully established TLS session.
            let on_tls_established = move |tls: TlsStream<TcpStream>| {
                let (tcp, _) = tls.get_ref();
                if let Some(flag) = nodelay_setting {
                    set_nodelay(tcp, flag);
                }
                let addr = tcp.peer_addr().ok();
                (tls, addr)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(on_tls_established)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
382
383impl<T, S, B> ServiceFactory<(T, Option<net::SocketAddr>)> for H2Service<T, S, B>
384where
385    T: AsyncRead + AsyncWrite + Unpin + 'static,
386
387    S: ServiceFactory<Request, Config = ()>,
388    S::Future: 'static,
389    S::Error: Into<Response<BoxBody>> + 'static,
390    S::Response: Into<Response<B>> + 'static,
391    <S::Service as Service<Request>>::Future: 'static,
392
393    B: MessageBody + 'static,
394{
395    type Response = ();
396    type Error = DispatchError;
397    type Config = ();
398    type Service = H2ServiceHandler<T, S::Service, B>;
399    type InitError = S::InitError;
400    type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
401
402    fn new_service(&self, _: ()) -> Self::Future {
403        let service = self.srv.new_service(());
404        let cfg = self.cfg.clone();
405        let on_connect_ext = self.on_connect_ext.clone();
406
407        Box::pin(async move {
408            let service = service.await?;
409            Ok(H2ServiceHandler::new(cfg, on_connect_ext, service))
410        })
411    }
412}
413
/// `Service` implementation for HTTP/2 transport
///
/// Produced by `H2Service::new_service`. Each `call` takes one connection as
/// `(io, peer_addr)` and returns a future that performs the HTTP/2 handshake and then
/// dispatches requests to the wrapped service.
pub struct H2ServiceHandler<T, S, B>
where
    S: Service<Request>,
{
    // Request service wrapped in `HttpFlow`; the `Rc` is cloned into every connection.
    flow: Rc<HttpFlow<S, (), ()>>,
    // HTTP configuration; cloned for each connection's handshake and dispatcher.
    cfg: ServiceConfig,
    // Optional callback passed to `OnConnectData::from_io` for each connection.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    // The response body type `B` is not stored, only tracked at the type level.
    _phantom: PhantomData<B>,
}
424
425impl<T, S, B> H2ServiceHandler<T, S, B>
426where
427    S: Service<Request>,
428    S::Error: Into<Response<BoxBody>> + 'static,
429    S::Future: 'static,
430    S::Response: Into<Response<B>> + 'static,
431    B: MessageBody + 'static,
432{
433    fn new(
434        cfg: ServiceConfig,
435        on_connect_ext: Option<Rc<ConnectCallback<T>>>,
436        service: S,
437    ) -> H2ServiceHandler<T, S, B> {
438        H2ServiceHandler {
439            flow: HttpFlow::new(service, (), None),
440            cfg,
441            on_connect_ext,
442            _phantom: PhantomData,
443        }
444    }
445}
446
447impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B>
448where
449    T: AsyncRead + AsyncWrite + Unpin,
450    S: Service<Request>,
451    S::Error: Into<Response<BoxBody>> + 'static,
452    S::Future: 'static,
453    S::Response: Into<Response<B>> + 'static,
454    B: MessageBody + 'static,
455{
456    type Response = ();
457    type Error = DispatchError;
458    type Future = H2ServiceHandlerResponse<T, S, B>;
459
460    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
461        self.flow.service.poll_ready(cx).map_err(|err| {
462            let err = err.into();
463            error!("Service readiness error: {:?}", err);
464            DispatchError::Service(err)
465        })
466    }
467
468    fn call(&self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
469        let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
470
471        H2ServiceHandlerResponse {
472            state: State::Handshake(
473                Some(Rc::clone(&self.flow)),
474                Some(self.cfg.clone()),
475                addr,
476                on_connect_data,
477                handshake_with_timeout(io, &self.cfg),
478            ),
479        }
480    }
481}
482
/// Progress of a single HTTP/2 connection future.
enum State<T, S: Service<Request>, B: MessageBody>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S::Future: 'static,
{
    /// Waiting for the HTTP/2 handshake, which is built by `handshake_with_timeout`.
    ///
    /// Fields, in order: the shared request flow, the connection config, the peer address,
    /// the on-connect data, and the pending handshake. The flow and config are wrapped in
    /// `Option` so that `poll` can move them out with `take()` when it switches to
    /// `Established`.
    Handshake(
        Option<Rc<HttpFlow<S, (), ()>>>,
        Option<ServiceConfig>,
        Option<net::SocketAddr>,
        OnConnectData,
        HandshakeWithTimeout<T>,
    ),
    /// The handshake succeeded; the dispatcher now drives the connection.
    Established(Dispatcher<T, S, B, (), ()>),
}
497
/// Future returned by `H2ServiceHandler::call` for one HTTP/2 connection.
///
/// Resolves with the handshake error if the handshake fails. Otherwise it resolves with
/// whatever the connection's `Dispatcher` returns when it completes.
pub struct H2ServiceHandlerResponse<T, S, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,
    B: MessageBody + 'static,
{
    // Current phase of the connection: handshaking or established.
    state: State<T, S, B>,
}
509
impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,
    B: MessageBody,
{
    type Output = Result<(), DispatchError>;

    /// Drives one HTTP/2 connection: first the handshake, then the dispatcher.
    ///
    /// Returns the handshake error (logged at trace level) if the handshake fails.
    /// Otherwise it forwards the `Dispatcher`'s own poll result.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.state {
            State::Handshake(
                ref mut srv,
                ref mut config,
                ref peer_addr,
                ref mut conn_data,
                ref mut handshake,
            ) => match ready!(Pin::new(handshake).poll(cx)) {
                // `ready!` has already returned `Pending` if the handshake is still in flight.
                Ok((conn, timer)) => {
                    // Move the on-connect data out, leaving a default value in the old state.
                    let on_connect_data = mem::take(conn_data);

                    // `srv` and `config` are only emptied by this transition, and the state is
                    // replaced immediately afterwards, so the `unwrap`s are not expected to fail.
                    self.state = State::Established(Dispatcher::new(
                        conn,
                        srv.take().unwrap(),
                        config.take().unwrap(),
                        *peer_addr,
                        on_connect_data,
                        timer,
                    ));

                    // Poll the new dispatcher right away. Returning `Pending` without polling it
                    // would leave no waker registered for the established connection.
                    self.poll(cx)
                }

                Err(err) => {
                    trace!("H2 handshake error: {}", err);
                    Poll::Ready(Err(err))
                }
            },

            State::Established(ref mut disp) => Pin::new(disp).poll(cx),
        }
    }
}