Skip to main content

actix_http/h2/
service.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    mem, net,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{
13    fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
14};
15use actix_utils::future::ready;
16use futures_core::{future::LocalBoxFuture, ready};
17use tracing::{error, trace};
18
19use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
20use crate::{
21    body::{BoxBody, MessageBody},
22    config::ServiceConfig,
23    error::DispatchError,
24    service::HttpFlow,
25    ConnectCallback, OnConnectData, Request, Response,
26};
27
/// Resolves the `TCP_NODELAY` preference to apply to accepted sockets.
///
/// The configured value is returned unchanged: `Some(flag)` requests that the option be set to
/// `flag`, and `None` means the socket is left as it is.
#[inline]
fn desired_nodelay(tcp_nodelay: Option<bool>) -> Option<bool> {
    // Pure pass-through; callers decide what to do with `None`.
    tcp_nodelay
}
32
33#[inline]
34fn set_nodelay(stream: &TcpStream, nodelay: bool) {
35    let _ = stream.set_nodelay(nodelay);
36}
37
38/// `ServiceFactory` implementation for HTTP/2 transport
39pub struct H2Service<T, S, B> {
40    srv: S,
41    cfg: ServiceConfig,
42    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
43    _phantom: PhantomData<(T, B)>,
44}
45
46impl<T, S, B> H2Service<T, S, B>
47where
48    S: ServiceFactory<Request, Config = ()>,
49    S::Error: Into<Response<BoxBody>> + 'static,
50    S::Response: Into<Response<B>> + 'static,
51    <S::Service as Service<Request>>::Future: 'static,
52
53    B: MessageBody + 'static,
54{
55    /// Create new `H2Service` instance with config.
56    pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
57        cfg: ServiceConfig,
58        service: F,
59    ) -> Self {
60        H2Service {
61            cfg,
62            on_connect_ext: None,
63            srv: service.into_factory(),
64            _phantom: PhantomData,
65        }
66    }
67
68    /// Set on connect callback.
69    pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
70        self.on_connect_ext = f;
71        self
72    }
73}
74
75impl<S, B> H2Service<TcpStream, S, B>
76where
77    S: ServiceFactory<Request, Config = ()>,
78    S::Future: 'static,
79    S::Error: Into<Response<BoxBody>> + 'static,
80    S::Response: Into<Response<B>> + 'static,
81    <S::Service as Service<Request>>::Future: 'static,
82
83    B: MessageBody + 'static,
84{
85    /// Create plain TCP based service
86    pub fn tcp(
87        self,
88    ) -> impl ServiceFactory<
89        TcpStream,
90        Config = (),
91        Response = (),
92        Error = DispatchError,
93        InitError = S::InitError,
94    > {
95        let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());
96
97        fn_factory(move || {
98            ready(Ok::<_, S::InitError>(fn_service(move |io: TcpStream| {
99                if let Some(nodelay) = tcp_nodelay {
100                    set_nodelay(&io, nodelay);
101                }
102                let peer_addr = io.peer_addr().ok();
103                ready(Ok::<_, DispatchError>((io, peer_addr)))
104            })))
105        })
106        .and_then(self)
107    }
108}
109
#[cfg(feature = "openssl")]
mod openssl {
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create OpenSSL based service.
        ///
        /// Incoming TCP streams go through `acceptor` first. The resulting TLS stream gets the
        /// configured `TCP_NODELAY` setting applied to its underlying socket (when one is
        /// configured) and is paired with the peer address before reaching this service.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = S::InitError,
        > {
            // Read the setting up front: `self` is moved into the combined factory below.
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            let tls_accept = Acceptor::new(acceptor)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // Borrow the underlying TCP stream once for both socket operations.
                    let tcp = io.get_ref();
                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }
                    let peer_addr = tcp.peer_addr().ok();

                    (io, peer_addr)
                });

            tls_accept.and_then(self.map_err(TlsError::Service))
        }
    }
}
161
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_tls::accept::{
        rustls::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create Rustls v0.20 based service.
        ///
        /// `h2` is placed at the front of the ALPN protocol list, ahead of whatever `config`
        /// already advertises. Accepted TLS streams get the configured `TCP_NODELAY` setting
        /// applied to their underlying socket (when one is configured) and are paired with the
        /// peer address before reaching this service.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Read the setting up front: `self` is moved into the combined factory below.
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first, keeping the caller's protocols after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            let tls_accept = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // First element of `get_ref()` is the underlying TCP stream.
                    let (tcp, _) = io.get_ref();
                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }
                    let peer_addr = tcp.peer_addr().ok();

                    (io, peer_addr)
                });

            tls_accept.and_then(self.map_err(TlsError::Service))
        }
    }
}
215
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create Rustls v0.21 based service.
        ///
        /// `h2` is placed at the front of the ALPN protocol list, ahead of whatever `config`
        /// already advertises. Accepted TLS streams get the configured `TCP_NODELAY` setting
        /// applied to their underlying socket (when one is configured) and are paired with the
        /// peer address before reaching this service.
        pub fn rustls_021(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Read the setting up front: `self` is moved into the combined factory below.
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first, keeping the caller's protocols after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            let tls_accept = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // First element of `get_ref()` is the underlying TCP stream.
                    let (tcp, _) = io.get_ref();
                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }
                    let peer_addr = tcp.peer_addr().ok();

                    (io, peer_addr)
                });

            tls_accept.and_then(self.map_err(TlsError::Service))
        }
    }
}
269
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create Rustls v0.22 based service.
        ///
        /// `h2` is placed at the front of the ALPN protocol list, ahead of whatever `config`
        /// already advertises. Accepted TLS streams get the configured `TCP_NODELAY` setting
        /// applied to their underlying socket (when one is configured) and are paired with the
        /// peer address before reaching this service.
        pub fn rustls_0_22(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Read the setting up front: `self` is moved into the combined factory below.
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first, keeping the caller's protocols after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            let tls_accept = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // First element of `get_ref()` is the underlying TCP stream.
                    let (tcp, _) = io.get_ref();
                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }
                    let peer_addr = tcp.peer_addr().ok();

                    (io, peer_addr)
                });

            tls_accept.and_then(self.map_err(TlsError::Service))
        }
    }
}
323
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create Rustls v0.23 based service.
        ///
        /// `h2` is placed at the front of the ALPN protocol list, ahead of whatever `config`
        /// already advertises. Accepted TLS streams get the configured `TCP_NODELAY` setting
        /// applied to their underlying socket (when one is configured) and are paired with the
        /// peer address before reaching this service.
        pub fn rustls_0_23(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Read the setting up front: `self` is moved into the combined factory below.
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first, keeping the caller's protocols after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            let tls_accept = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // First element of `get_ref()` is the underlying TCP stream.
                    let (tcp, _) = io.get_ref();
                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }
                    let peer_addr = tcp.peer_addr().ok();

                    (io, peer_addr)
                });

            tls_accept.and_then(self.map_err(TlsError::Service))
        }
    }
}
377
378impl<T, S, B> ServiceFactory<(T, Option<net::SocketAddr>)> for H2Service<T, S, B>
379where
380    T: AsyncRead + AsyncWrite + Unpin + 'static,
381
382    S: ServiceFactory<Request, Config = ()>,
383    S::Future: 'static,
384    S::Error: Into<Response<BoxBody>> + 'static,
385    S::Response: Into<Response<B>> + 'static,
386    <S::Service as Service<Request>>::Future: 'static,
387
388    B: MessageBody + 'static,
389{
390    type Response = ();
391    type Error = DispatchError;
392    type Config = ();
393    type Service = H2ServiceHandler<T, S::Service, B>;
394    type InitError = S::InitError;
395    type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
396
397    fn new_service(&self, _: ()) -> Self::Future {
398        let service = self.srv.new_service(());
399        let cfg = self.cfg.clone();
400        let on_connect_ext = self.on_connect_ext.clone();
401
402        Box::pin(async move {
403            let service = service.await?;
404            Ok(H2ServiceHandler::new(cfg, on_connect_ext, service))
405        })
406    }
407}
408
409/// `Service` implementation for HTTP/2 transport
410pub struct H2ServiceHandler<T, S, B>
411where
412    S: Service<Request>,
413{
414    flow: Rc<HttpFlow<S, (), ()>>,
415    cfg: ServiceConfig,
416    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
417    _phantom: PhantomData<B>,
418}
419
420impl<T, S, B> H2ServiceHandler<T, S, B>
421where
422    S: Service<Request>,
423    S::Error: Into<Response<BoxBody>> + 'static,
424    S::Future: 'static,
425    S::Response: Into<Response<B>> + 'static,
426    B: MessageBody + 'static,
427{
428    fn new(
429        cfg: ServiceConfig,
430        on_connect_ext: Option<Rc<ConnectCallback<T>>>,
431        service: S,
432    ) -> H2ServiceHandler<T, S, B> {
433        H2ServiceHandler {
434            flow: HttpFlow::new(service, (), None),
435            cfg,
436            on_connect_ext,
437            _phantom: PhantomData,
438        }
439    }
440}
441
442impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B>
443where
444    T: AsyncRead + AsyncWrite + Unpin,
445    S: Service<Request>,
446    S::Error: Into<Response<BoxBody>> + 'static,
447    S::Future: 'static,
448    S::Response: Into<Response<B>> + 'static,
449    B: MessageBody + 'static,
450{
451    type Response = ();
452    type Error = DispatchError;
453    type Future = H2ServiceHandlerResponse<T, S, B>;
454
455    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
456        self.flow.service.poll_ready(cx).map_err(|err| {
457            let err = err.into();
458            error!("Service readiness error: {:?}", err);
459            DispatchError::Service(err)
460        })
461    }
462
463    fn call(&self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
464        let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
465
466        H2ServiceHandlerResponse {
467            state: State::Handshake(
468                Some(Rc::clone(&self.flow)),
469                Some(self.cfg.clone()),
470                addr,
471                on_connect_data,
472                handshake_with_timeout(io, &self.cfg),
473            ),
474        }
475    }
476}
477
478enum State<T, S: Service<Request>, B: MessageBody>
479where
480    T: AsyncRead + AsyncWrite + Unpin,
481    S::Future: 'static,
482{
483    Handshake(
484        Option<Rc<HttpFlow<S, (), ()>>>,
485        Option<ServiceConfig>,
486        Option<net::SocketAddr>,
487        OnConnectData,
488        HandshakeWithTimeout<T>,
489    ),
490    Established(Dispatcher<T, S, B, (), ()>),
491}
492
493pub struct H2ServiceHandlerResponse<T, S, B>
494where
495    T: AsyncRead + AsyncWrite + Unpin,
496    S: Service<Request>,
497    S::Error: Into<Response<BoxBody>> + 'static,
498    S::Future: 'static,
499    S::Response: Into<Response<B>> + 'static,
500    B: MessageBody + 'static,
501{
502    state: State<T, S, B>,
503}
504
505impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
506where
507    T: AsyncRead + AsyncWrite + Unpin,
508    S: Service<Request>,
509    S::Error: Into<Response<BoxBody>> + 'static,
510    S::Future: 'static,
511    S::Response: Into<Response<B>> + 'static,
512    B: MessageBody,
513{
514    type Output = Result<(), DispatchError>;
515
516    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
517        match self.state {
518            State::Handshake(
519                ref mut srv,
520                ref mut config,
521                ref peer_addr,
522                ref mut conn_data,
523                ref mut handshake,
524            ) => match ready!(Pin::new(handshake).poll(cx)) {
525                Ok((conn, timer)) => {
526                    let on_connect_data = mem::take(conn_data);
527
528                    self.state = State::Established(Dispatcher::new(
529                        conn,
530                        srv.take().unwrap(),
531                        config.take().unwrap(),
532                        *peer_addr,
533                        on_connect_data,
534                        timer,
535                    ));
536
537                    self.poll(cx)
538                }
539
540                Err(err) => {
541                    trace!("H2 handshake error: {}", err);
542                    Poll::Ready(Err(err))
543                }
544            },
545
546            State::Established(ref mut disp) => Pin::new(disp).poll(cx),
547        }
548    }
549}