1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{net, rc};
6
7use requiem_codec::{AsyncRead, AsyncWrite};
8use requiem_rt::net::TcpStream;
9use requiem_service::{
10 fn_factory, fn_service, pipeline_factory, IntoServiceFactory, Service,
11 ServiceFactory,
12};
13use bytes::Bytes;
14use futures_core::ready;
15use futures_util::future::ok;
16use h2::server::{self, Handshake};
17use log::error;
18
19use crate::body::MessageBody;
20use crate::cloneable::CloneableService;
21use crate::config::ServiceConfig;
22use crate::error::{DispatchError, Error};
23use crate::helpers::DataFactory;
24use crate::request::Request;
25use crate::response::Response;
26
27use super::dispatcher::Dispatcher;
28
29pub struct H2Service<T, S, B> {
31 srv: S,
32 cfg: ServiceConfig,
33 on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
34 _t: PhantomData<(T, B)>,
35}
36
37impl<T, S, B> H2Service<T, S, B>
38where
39 S: ServiceFactory<Config = (), Request = Request>,
40 S::Error: Into<Error> + 'static,
41 S::Response: Into<Response<B>> + 'static,
42 <S::Service as Service>::Future: 'static,
43 B: MessageBody + 'static,
44{
45 pub(crate) fn with_config<F: IntoServiceFactory<S>>(
47 cfg: ServiceConfig,
48 service: F,
49 ) -> Self {
50 H2Service {
51 cfg,
52 on_connect: None,
53 srv: service.into_factory(),
54 _t: PhantomData,
55 }
56 }
57
58 pub(crate) fn on_connect(
60 mut self,
61 f: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
62 ) -> Self {
63 self.on_connect = f;
64 self
65 }
66}
67
68impl<S, B> H2Service<TcpStream, S, B>
69where
70 S: ServiceFactory<Config = (), Request = Request>,
71 S::Error: Into<Error> + 'static,
72 S::Response: Into<Response<B>> + 'static,
73 <S::Service as Service>::Future: 'static,
74 B: MessageBody + 'static,
75{
76 pub fn tcp(
78 self,
79 ) -> impl ServiceFactory<
80 Config = (),
81 Request = TcpStream,
82 Response = (),
83 Error = DispatchError,
84 InitError = S::InitError,
85 > {
86 pipeline_factory(fn_factory(|| {
87 async {
88 Ok::<_, S::InitError>(fn_service(|io: TcpStream| {
89 let peer_addr = io.peer_addr().ok();
90 ok::<_, DispatchError>((io, peer_addr))
91 }))
92 }
93 }))
94 .and_then(self)
95 }
96}
97
#[cfg(feature = "openssl")]
mod openssl {
    use requiem_service::{fn_factory, fn_service};
    use requiem_tls::openssl::{Acceptor, SslAcceptor, SslStream};
    use requiem_tls::{openssl::HandshakeError, SslError};

    use super::*;

    impl<S, B> H2Service<SslStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Config = (), Request = Request>,
        S::Error: Into<Error> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create an OpenSSL-backed HTTP/2 service factory that accepts raw `TcpStream`s.
        ///
        /// The pipeline has three stages:
        /// 1. the TLS handshake, run by `Acceptor`; its errors are wrapped in `SslError::Ssl`;
        /// 2. the peer address is read from the underlying TCP socket;
        /// 3. the encrypted stream is handed to this `H2Service`; its errors are wrapped in
        ///    `SslError::Service`.
        ///
        /// The acceptor's init error is turned into a panic.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            Config = (),
            Request = TcpStream,
            Response = (),
            Error = SslError<HandshakeError<TcpStream>, DispatchError>,
            InitError = S::InitError,
        > {
            let tls = Acceptor::new(acceptor)
                .map_err(SslError::Ssl)
                .map_init_err(|_| panic!());

            let with_peer_addr = fn_factory(|| {
                ok::<_, S::InitError>(fn_service(|stream: SslStream<TcpStream>| {
                    let addr = stream.get_ref().peer_addr().ok();
                    ok((stream, addr))
                }))
            });

            pipeline_factory(tls)
                .and_then(with_peer_addr)
                .and_then(self.map_err(SslError::Service))
        }
    }
}
140
#[cfg(feature = "rustls")]
mod rustls {
    use super::*;
    use requiem_tls::rustls::{Acceptor, ServerConfig, TlsStream};
    use requiem_tls::SslError;
    use std::io;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Config = (), Request = Request>,
        S::Error: Into<Error> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Create a rustls-backed HTTP/2 service factory that accepts raw `TcpStream`s.
        ///
        /// Before `config` is handed to the TLS acceptor, its protocol list is set to `h2`.
        ///
        /// The pipeline then has three stages:
        /// 1. the TLS handshake; its errors are wrapped in `SslError::Ssl`;
        /// 2. the peer address is read from the inner TCP socket;
        /// 3. the stream is passed to this `H2Service`; its errors are wrapped in
        ///    `SslError::Service`.
        ///
        /// The acceptor's init error is turned into a panic.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            Config = (),
            Request = TcpStream,
            Response = (),
            Error = SslError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let alpn = vec!["h2".to_string().into()];
            config.set_protocols(&alpn);

            let tls = Acceptor::new(config)
                .map_err(SslError::Ssl)
                .map_init_err(|_| panic!());

            let with_peer_addr = fn_factory(|| {
                ok::<_, S::InitError>(fn_service(|stream: TlsStream<TcpStream>| {
                    // `get_ref()` yields a tuple whose first element is the TCP socket.
                    let addr = stream.get_ref().0.peer_addr().ok();
                    ok((stream, addr))
                }))
            });

            pipeline_factory(tls)
                .and_then(with_peer_addr)
                .and_then(self.map_err(SslError::Service))
        }
    }
}
185
186impl<T, S, B> ServiceFactory for H2Service<T, S, B>
187where
188 T: AsyncRead + AsyncWrite + Unpin,
189 S: ServiceFactory<Config = (), Request = Request>,
190 S::Error: Into<Error> + 'static,
191 S::Response: Into<Response<B>> + 'static,
192 <S::Service as Service>::Future: 'static,
193 B: MessageBody + 'static,
194{
195 type Config = ();
196 type Request = (T, Option<net::SocketAddr>);
197 type Response = ();
198 type Error = DispatchError;
199 type InitError = S::InitError;
200 type Service = H2ServiceHandler<T, S::Service, B>;
201 type Future = H2ServiceResponse<T, S, B>;
202
203 fn new_service(&self, _: ()) -> Self::Future {
204 H2ServiceResponse {
205 fut: self.srv.new_service(()),
206 cfg: Some(self.cfg.clone()),
207 on_connect: self.on_connect.clone(),
208 _t: PhantomData,
209 }
210 }
211}
212
213#[doc(hidden)]
214#[pin_project::pin_project]
215pub struct H2ServiceResponse<T, S: ServiceFactory, B> {
216 #[pin]
217 fut: S::Future,
218 cfg: Option<ServiceConfig>,
219 on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
220 _t: PhantomData<(T, B)>,
221}
222
223impl<T, S, B> Future for H2ServiceResponse<T, S, B>
224where
225 T: AsyncRead + AsyncWrite + Unpin,
226 S: ServiceFactory<Config = (), Request = Request>,
227 S::Error: Into<Error> + 'static,
228 S::Response: Into<Response<B>> + 'static,
229 <S::Service as Service>::Future: 'static,
230 B: MessageBody + 'static,
231{
232 type Output = Result<H2ServiceHandler<T, S::Service, B>, S::InitError>;
233
234 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235 let this = self.as_mut().project();
236
237 Poll::Ready(ready!(this.fut.poll(cx)).map(|service| {
238 let this = self.as_mut().project();
239 H2ServiceHandler::new(
240 this.cfg.take().unwrap(),
241 this.on_connect.clone(),
242 service,
243 )
244 }))
245 }
246}
247
248pub struct H2ServiceHandler<T, S: Service, B> {
250 srv: CloneableService<S>,
251 cfg: ServiceConfig,
252 on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
253 _t: PhantomData<(T, B)>,
254}
255
256impl<T, S, B> H2ServiceHandler<T, S, B>
257where
258 S: Service<Request = Request>,
259 S::Error: Into<Error> + 'static,
260 S::Future: 'static,
261 S::Response: Into<Response<B>> + 'static,
262 B: MessageBody + 'static,
263{
264 fn new(
265 cfg: ServiceConfig,
266 on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
267 srv: S,
268 ) -> H2ServiceHandler<T, S, B> {
269 H2ServiceHandler {
270 cfg,
271 on_connect,
272 srv: CloneableService::new(srv),
273 _t: PhantomData,
274 }
275 }
276}
277
278impl<T, S, B> Service for H2ServiceHandler<T, S, B>
279where
280 T: AsyncRead + AsyncWrite + Unpin,
281 S: Service<Request = Request>,
282 S::Error: Into<Error> + 'static,
283 S::Future: 'static,
284 S::Response: Into<Response<B>> + 'static,
285 B: MessageBody + 'static,
286{
287 type Request = (T, Option<net::SocketAddr>);
288 type Response = ();
289 type Error = DispatchError;
290 type Future = H2ServiceHandlerResponse<T, S, B>;
291
292 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
293 self.srv.poll_ready(cx).map_err(|e| {
294 let e = e.into();
295 error!("Service readiness error: {:?}", e);
296 DispatchError::Service(e)
297 })
298 }
299
300 fn call(&mut self, (io, addr): Self::Request) -> Self::Future {
301 let on_connect = if let Some(ref on_connect) = self.on_connect {
302 Some(on_connect(&io))
303 } else {
304 None
305 };
306
307 H2ServiceHandlerResponse {
308 state: State::Handshake(
309 Some(self.srv.clone()),
310 Some(self.cfg.clone()),
311 addr,
312 on_connect,
313 server::handshake(io),
314 ),
315 }
316 }
317}
318
319enum State<T, S: Service<Request = Request>, B: MessageBody>
320where
321 T: AsyncRead + AsyncWrite + Unpin,
322 S::Future: 'static,
323{
324 Incoming(Dispatcher<T, S, B>),
325 Handshake(
326 Option<CloneableService<S>>,
327 Option<ServiceConfig>,
328 Option<net::SocketAddr>,
329 Option<Box<dyn DataFactory>>,
330 Handshake<T, Bytes>,
331 ),
332}
333
334pub struct H2ServiceHandlerResponse<T, S, B>
335where
336 T: AsyncRead + AsyncWrite + Unpin,
337 S: Service<Request = Request>,
338 S::Error: Into<Error> + 'static,
339 S::Future: 'static,
340 S::Response: Into<Response<B>> + 'static,
341 B: MessageBody + 'static,
342{
343 state: State<T, S, B>,
344}
345
346impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
347where
348 T: AsyncRead + AsyncWrite + Unpin,
349 S: Service<Request = Request>,
350 S::Error: Into<Error> + 'static,
351 S::Future: 'static,
352 S::Response: Into<Response<B>> + 'static,
353 B: MessageBody,
354{
355 type Output = Result<(), DispatchError>;
356
357 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358 match self.state {
359 State::Incoming(ref mut disp) => Pin::new(disp).poll(cx),
360 State::Handshake(
361 ref mut srv,
362 ref mut config,
363 ref peer_addr,
364 ref mut on_connect,
365 ref mut handshake,
366 ) => match Pin::new(handshake).poll(cx) {
367 Poll::Ready(Ok(conn)) => {
368 self.state = State::Incoming(Dispatcher::new(
369 srv.take().unwrap(),
370 conn,
371 on_connect.take(),
372 config.take().unwrap(),
373 None,
374 *peer_addr,
375 ));
376 self.poll(cx)
377 }
378 Poll::Ready(Err(err)) => {
379 trace!("H2 handshake error: {}", err);
380 Poll::Ready(Err(err.into()))
381 }
382 Poll::Pending => Poll::Pending,
383 },
384 }
385 }
386}