requiem_http/h2/
service.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{net, rc};
6
7use requiem_codec::{AsyncRead, AsyncWrite};
8use requiem_rt::net::TcpStream;
9use requiem_service::{
10    fn_factory, fn_service, pipeline_factory, IntoServiceFactory, Service,
11    ServiceFactory,
12};
13use bytes::Bytes;
14use futures_core::ready;
15use futures_util::future::ok;
16use h2::server::{self, Handshake};
17use log::error;
18
19use crate::body::MessageBody;
20use crate::cloneable::CloneableService;
21use crate::config::ServiceConfig;
22use crate::error::{DispatchError, Error};
23use crate::helpers::DataFactory;
24use crate::request::Request;
25use crate::response::Response;
26
27use super::dispatcher::Dispatcher;
28
29/// `ServiceFactory` implementation for HTTP2 transport
30pub struct H2Service<T, S, B> {
31    srv: S,
32    cfg: ServiceConfig,
33    on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
34    _t: PhantomData<(T, B)>,
35}
36
37impl<T, S, B> H2Service<T, S, B>
38where
39    S: ServiceFactory<Config = (), Request = Request>,
40    S::Error: Into<Error> + 'static,
41    S::Response: Into<Response<B>> + 'static,
42    <S::Service as Service>::Future: 'static,
43    B: MessageBody + 'static,
44{
45    /// Create new `HttpService` instance with config.
46    pub(crate) fn with_config<F: IntoServiceFactory<S>>(
47        cfg: ServiceConfig,
48        service: F,
49    ) -> Self {
50        H2Service {
51            cfg,
52            on_connect: None,
53            srv: service.into_factory(),
54            _t: PhantomData,
55        }
56    }
57
58    /// Set on connect callback.
59    pub(crate) fn on_connect(
60        mut self,
61        f: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
62    ) -> Self {
63        self.on_connect = f;
64        self
65    }
66}
67
68impl<S, B> H2Service<TcpStream, S, B>
69where
70    S: ServiceFactory<Config = (), Request = Request>,
71    S::Error: Into<Error> + 'static,
72    S::Response: Into<Response<B>> + 'static,
73    <S::Service as Service>::Future: 'static,
74    B: MessageBody + 'static,
75{
76    /// Create simple tcp based service
77    pub fn tcp(
78        self,
79    ) -> impl ServiceFactory<
80        Config = (),
81        Request = TcpStream,
82        Response = (),
83        Error = DispatchError,
84        InitError = S::InitError,
85    > {
86        pipeline_factory(fn_factory(|| {
87            async {
88                Ok::<_, S::InitError>(fn_service(|io: TcpStream| {
89                    let peer_addr = io.peer_addr().ok();
90                    ok::<_, DispatchError>((io, peer_addr))
91                }))
92            }
93        }))
94        .and_then(self)
95    }
96}
97
#[cfg(feature = "openssl")]
mod openssl {
    use requiem_service::{fn_factory, fn_service};
    use requiem_tls::openssl::{Acceptor, SslAcceptor, SslStream};
    use requiem_tls::{openssl::HandshakeError, SslError};

    use super::*;

    impl<S, B> H2Service<SslStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Config = (), Request = Request>,
        S::Error: Into<Error> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create ssl based service
        ///
        /// The returned factory takes a raw `TcpStream`, runs it through
        /// the openssl `Acceptor`, extracts the peer address from the
        /// underlying stream and then hands `(stream, peer_addr)` to this
        /// `H2Service`. Acceptor errors are reported as `SslError::Ssl`,
        /// errors of this service as `SslError::Service`.
        ///
        /// # Panics
        ///
        /// Panics if creating the acceptor service reports an init error.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            Config = (),
            Request = TcpStream,
            Response = (),
            Error = SslError<HandshakeError<TcpStream>, DispatchError>,
            InitError = S::InitError,
        > {
            // Stage 1: TLS accept.
            let tls = Acceptor::new(acceptor)
                .map_err(SslError::Ssl)
                .map_init_err(|_| panic!());

            // Stage 2: pair the TLS stream with the peer address of the
            // TCP stream it wraps.
            let with_peer_addr = fn_factory(|| {
                ok::<_, S::InitError>(fn_service(|io: SslStream<TcpStream>| {
                    let peer_addr = io.get_ref().peer_addr().ok();
                    ok((io, peer_addr))
                }))
            });

            // Stage 3: the HTTP/2 service itself.
            pipeline_factory(tls)
                .and_then(with_peer_addr)
                .and_then(self.map_err(SslError::Service))
        }
    }
}
140
#[cfg(feature = "rustls")]
mod rustls {
    use super::*;
    use requiem_tls::rustls::{Acceptor, ServerConfig, TlsStream};
    use requiem_tls::SslError;
    use std::io;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Config = (), Request = Request>,
        S::Error: Into<Error> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create rustls based service
        ///
        /// The protocol list of `config` is overwritten with the single
        /// entry `h2` before the acceptor is built. The returned factory
        /// takes a raw `TcpStream`, runs it through the rustls `Acceptor`,
        /// extracts the peer address from the underlying stream and hands
        /// `(stream, peer_addr)` to this `H2Service`. Acceptor errors are
        /// reported as `SslError::Ssl`, errors of this service as
        /// `SslError::Service`.
        ///
        /// # Panics
        ///
        /// Panics if creating the acceptor service reports an init error.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            Config = (),
            Request = TcpStream,
            Response = (),
            Error = SslError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Build the protocol id directly as bytes instead of going
            // through an intermediate `String`.
            let protos = vec![b"h2".to_vec()];
            config.set_protocols(&protos);

            pipeline_factory(
                Acceptor::new(config)
                    .map_err(SslError::Ssl)
                    .map_init_err(|_| panic!()),
            )
            .and_then(fn_factory(|| {
                ok::<_, S::InitError>(fn_service(|io: TlsStream<TcpStream>| {
                    // `get_ref().0` is the wrapped TCP stream.
                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    ok((io, peer_addr))
                }))
            }))
            .and_then(self.map_err(SslError::Service))
        }
    }
}
185
186impl<T, S, B> ServiceFactory for H2Service<T, S, B>
187where
188    T: AsyncRead + AsyncWrite + Unpin,
189    S: ServiceFactory<Config = (), Request = Request>,
190    S::Error: Into<Error> + 'static,
191    S::Response: Into<Response<B>> + 'static,
192    <S::Service as Service>::Future: 'static,
193    B: MessageBody + 'static,
194{
195    type Config = ();
196    type Request = (T, Option<net::SocketAddr>);
197    type Response = ();
198    type Error = DispatchError;
199    type InitError = S::InitError;
200    type Service = H2ServiceHandler<T, S::Service, B>;
201    type Future = H2ServiceResponse<T, S, B>;
202
203    fn new_service(&self, _: ()) -> Self::Future {
204        H2ServiceResponse {
205            fut: self.srv.new_service(()),
206            cfg: Some(self.cfg.clone()),
207            on_connect: self.on_connect.clone(),
208            _t: PhantomData,
209        }
210    }
211}
212
213#[doc(hidden)]
214#[pin_project::pin_project]
215pub struct H2ServiceResponse<T, S: ServiceFactory, B> {
216    #[pin]
217    fut: S::Future,
218    cfg: Option<ServiceConfig>,
219    on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
220    _t: PhantomData<(T, B)>,
221}
222
223impl<T, S, B> Future for H2ServiceResponse<T, S, B>
224where
225    T: AsyncRead + AsyncWrite + Unpin,
226    S: ServiceFactory<Config = (), Request = Request>,
227    S::Error: Into<Error> + 'static,
228    S::Response: Into<Response<B>> + 'static,
229    <S::Service as Service>::Future: 'static,
230    B: MessageBody + 'static,
231{
232    type Output = Result<H2ServiceHandler<T, S::Service, B>, S::InitError>;
233
234    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235        let this = self.as_mut().project();
236
237        Poll::Ready(ready!(this.fut.poll(cx)).map(|service| {
238            let this = self.as_mut().project();
239            H2ServiceHandler::new(
240                this.cfg.take().unwrap(),
241                this.on_connect.clone(),
242                service,
243            )
244        }))
245    }
246}
247
248/// `Service` implementation for http/2 transport
249pub struct H2ServiceHandler<T, S: Service, B> {
250    srv: CloneableService<S>,
251    cfg: ServiceConfig,
252    on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
253    _t: PhantomData<(T, B)>,
254}
255
256impl<T, S, B> H2ServiceHandler<T, S, B>
257where
258    S: Service<Request = Request>,
259    S::Error: Into<Error> + 'static,
260    S::Future: 'static,
261    S::Response: Into<Response<B>> + 'static,
262    B: MessageBody + 'static,
263{
264    fn new(
265        cfg: ServiceConfig,
266        on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
267        srv: S,
268    ) -> H2ServiceHandler<T, S, B> {
269        H2ServiceHandler {
270            cfg,
271            on_connect,
272            srv: CloneableService::new(srv),
273            _t: PhantomData,
274        }
275    }
276}
277
278impl<T, S, B> Service for H2ServiceHandler<T, S, B>
279where
280    T: AsyncRead + AsyncWrite + Unpin,
281    S: Service<Request = Request>,
282    S::Error: Into<Error> + 'static,
283    S::Future: 'static,
284    S::Response: Into<Response<B>> + 'static,
285    B: MessageBody + 'static,
286{
287    type Request = (T, Option<net::SocketAddr>);
288    type Response = ();
289    type Error = DispatchError;
290    type Future = H2ServiceHandlerResponse<T, S, B>;
291
292    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
293        self.srv.poll_ready(cx).map_err(|e| {
294            let e = e.into();
295            error!("Service readiness error: {:?}", e);
296            DispatchError::Service(e)
297        })
298    }
299
300    fn call(&mut self, (io, addr): Self::Request) -> Self::Future {
301        let on_connect = if let Some(ref on_connect) = self.on_connect {
302            Some(on_connect(&io))
303        } else {
304            None
305        };
306
307        H2ServiceHandlerResponse {
308            state: State::Handshake(
309                Some(self.srv.clone()),
310                Some(self.cfg.clone()),
311                addr,
312                on_connect,
313                server::handshake(io),
314            ),
315        }
316    }
317}
318
319enum State<T, S: Service<Request = Request>, B: MessageBody>
320where
321    T: AsyncRead + AsyncWrite + Unpin,
322    S::Future: 'static,
323{
324    Incoming(Dispatcher<T, S, B>),
325    Handshake(
326        Option<CloneableService<S>>,
327        Option<ServiceConfig>,
328        Option<net::SocketAddr>,
329        Option<Box<dyn DataFactory>>,
330        Handshake<T, Bytes>,
331    ),
332}
333
334pub struct H2ServiceHandlerResponse<T, S, B>
335where
336    T: AsyncRead + AsyncWrite + Unpin,
337    S: Service<Request = Request>,
338    S::Error: Into<Error> + 'static,
339    S::Future: 'static,
340    S::Response: Into<Response<B>> + 'static,
341    B: MessageBody + 'static,
342{
343    state: State<T, S, B>,
344}
345
346impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
347where
348    T: AsyncRead + AsyncWrite + Unpin,
349    S: Service<Request = Request>,
350    S::Error: Into<Error> + 'static,
351    S::Future: 'static,
352    S::Response: Into<Response<B>> + 'static,
353    B: MessageBody,
354{
355    type Output = Result<(), DispatchError>;
356
357    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358        match self.state {
359            State::Incoming(ref mut disp) => Pin::new(disp).poll(cx),
360            State::Handshake(
361                ref mut srv,
362                ref mut config,
363                ref peer_addr,
364                ref mut on_connect,
365                ref mut handshake,
366            ) => match Pin::new(handshake).poll(cx) {
367                Poll::Ready(Ok(conn)) => {
368                    self.state = State::Incoming(Dispatcher::new(
369                        srv.take().unwrap(),
370                        conn,
371                        on_connect.take(),
372                        config.take().unwrap(),
373                        None,
374                        *peer_addr,
375                    ));
376                    self.poll(cx)
377                }
378                Poll::Ready(Err(err)) => {
379                    trace!("H2 handshake error: {}", err);
380                    Poll::Ready(Err(err.into()))
381                }
382                Poll::Pending => Poll::Pending,
383            },
384        }
385    }
386}