1use std::{
2 future::Future,
3 marker::PhantomData,
4 mem, net,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll},
8};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{
13 fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
14};
15use actix_utils::future::ready;
16use futures_core::{future::LocalBoxFuture, ready};
17use tracing::{error, trace};
18
19use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
20use crate::{
21 body::{BoxBody, MessageBody},
22 config::ServiceConfig,
23 error::DispatchError,
24 service::HttpFlow,
25 ConnectCallback, OnConnectData, Request, Response,
26};
27
/// Decides which `TCP_NODELAY` value, if any, should be applied to accepted sockets.
///
/// The configured value is forwarded unchanged; `None` means the socket option is left alone.
#[inline]
fn desired_nodelay(configured: Option<bool>) -> Option<bool> {
    configured
}
32
33#[inline]
34fn set_nodelay(stream: &TcpStream, nodelay: bool) {
35 let _ = stream.set_nodelay(nodelay);
36}
37
/// Service factory for HTTP/2 connections.
///
/// `T` is the I/O stream type the connections arrive on, `S` is the factory for the
/// request-handling service, and `B` is the response body type.
pub struct H2Service<T, S, B> {
    /// Factory for the user's request-handling service.
    srv: S,
    /// Service configuration, cloned into each created handler.
    cfg: ServiceConfig,
    /// Optional callback invoked with each new connection's I/O object.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    /// `T` and `B` only appear in bounds / associated types, not in stored fields.
    _phantom: PhantomData<(T, B)>,
}
45
46impl<T, S, B> H2Service<T, S, B>
47where
48 S: ServiceFactory<Request, Config = ()>,
49 S::Error: Into<Response<BoxBody>> + 'static,
50 S::Response: Into<Response<B>> + 'static,
51 <S::Service as Service<Request>>::Future: 'static,
52
53 B: MessageBody + 'static,
54{
55 pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
57 cfg: ServiceConfig,
58 service: F,
59 ) -> Self {
60 H2Service {
61 cfg,
62 on_connect_ext: None,
63 srv: service.into_factory(),
64 _phantom: PhantomData,
65 }
66 }
67
68 pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
70 self.on_connect_ext = f;
71 self
72 }
73}
74
75impl<S, B> H2Service<TcpStream, S, B>
76where
77 S: ServiceFactory<Request, Config = ()>,
78 S::Future: 'static,
79 S::Error: Into<Response<BoxBody>> + 'static,
80 S::Response: Into<Response<B>> + 'static,
81 <S::Service as Service<Request>>::Future: 'static,
82
83 B: MessageBody + 'static,
84{
85 pub fn tcp(
87 self,
88 ) -> impl ServiceFactory<
89 TcpStream,
90 Config = (),
91 Response = (),
92 Error = DispatchError,
93 InitError = S::InitError,
94 > {
95 let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());
96
97 fn_factory(move || {
98 ready(Ok::<_, S::InitError>(fn_service(move |io: TcpStream| {
99 if let Some(nodelay) = tcp_nodelay {
100 set_nodelay(&io, nodelay);
101 }
102 let peer_addr = io.peer_addr().ok();
103 ready(Ok::<_, DispatchError>((io, peer_addr)))
104 })))
105 })
106 .and_then(self)
107 }
108}
109
#[cfg(feature = "openssl")]
mod openssl {
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates an OpenSSL based HTTP/2 service factory.
        ///
        /// `acceptor` performs the TLS handshake on each accepted TCP stream. After that, the
        /// configured `TCP_NODELAY` value (if any) is applied to the underlying socket and the
        /// peer address is captured, and the stream is handed to this HTTP/2 service.
        /// This method does not modify `acceptor`'s ALPN settings.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = S::InitError,
        > {
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            Acceptor::new(acceptor)
                .map_init_err(|_| unreachable!("TLS acceptor service factory does not error on init"))
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    let tcp = io.get_ref();

                    if let Some(flag) = tcp_nodelay {
                        set_nodelay(tcp, flag);
                    }

                    let peer_addr = tcp.peer_addr().ok();
                    (io, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
161
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_tls::accept::{
        rustls::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls (v0.20) based HTTP/2 service factory.
        ///
        /// `h2` is moved to the front of `config.alpn_protocols`. If the caller already listed
        /// `h2`, that entry is removed first so the protocol does not appear twice. All other
        /// protocols keep their relative order.
        ///
        /// After the TLS handshake, the configured `TCP_NODELAY` value (if any) is applied to the
        /// underlying socket and the peer address is captured. The stream is then handed to this
        /// HTTP/2 service.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Prefer HTTP/2 without duplicating an `h2` entry the caller may have configured.
            config
                .alpn_protocols
                .retain(|proto| proto.as_slice() != b"h2");
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    let (tcp, _) = io.get_ref();

                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }

                    let peer_addr = tcp.peer_addr().ok();
                    (io, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
215
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls (v0.21) based HTTP/2 service factory.
        ///
        /// `h2` is moved to the front of `config.alpn_protocols`. If the caller already listed
        /// `h2`, that entry is removed first so the protocol does not appear twice. All other
        /// protocols keep their relative order.
        ///
        /// After the TLS handshake, the configured `TCP_NODELAY` value (if any) is applied to the
        /// underlying socket and the peer address is captured. The stream is then handed to this
        /// HTTP/2 service.
        pub fn rustls_021(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Prefer HTTP/2 without duplicating an `h2` entry the caller may have configured.
            config
                .alpn_protocols
                .retain(|proto| proto.as_slice() != b"h2");
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    let (tcp, _) = io.get_ref();

                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }

                    let peer_addr = tcp.peer_addr().ok();
                    (io, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
269
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls (v0.22) based HTTP/2 service factory.
        ///
        /// `h2` is moved to the front of `config.alpn_protocols`. If the caller already listed
        /// `h2`, that entry is removed first so the protocol does not appear twice. All other
        /// protocols keep their relative order.
        ///
        /// After the TLS handshake, the configured `TCP_NODELAY` value (if any) is applied to the
        /// underlying socket and the peer address is captured. The stream is then handed to this
        /// HTTP/2 service.
        pub fn rustls_0_22(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Prefer HTTP/2 without duplicating an `h2` entry the caller may have configured.
            config
                .alpn_protocols
                .retain(|proto| proto.as_slice() != b"h2");
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    let (tcp, _) = io.get_ref();

                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }

                    let peer_addr = tcp.peer_addr().ok();
                    (io, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
323
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls (v0.23) based HTTP/2 service factory.
        ///
        /// `h2` is moved to the front of `config.alpn_protocols`. If the caller already listed
        /// `h2`, that entry is removed first so the protocol does not appear twice. All other
        /// protocols keep their relative order.
        ///
        /// After the TLS handshake, the configured `TCP_NODELAY` value (if any) is applied to the
        /// underlying socket and the peer address is captured. The stream is then handed to this
        /// HTTP/2 service.
        pub fn rustls_0_23(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());

            // Prefer HTTP/2 without duplicating an `h2` entry the caller may have configured.
            config
                .alpn_protocols
                .retain(|proto| proto.as_slice() != b"h2");
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    let (tcp, _) = io.get_ref();

                    if let Some(nodelay) = tcp_nodelay {
                        set_nodelay(tcp, nodelay);
                    }

                    let peer_addr = tcp.peer_addr().ok();
                    (io, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
377
378impl<T, S, B> ServiceFactory<(T, Option<net::SocketAddr>)> for H2Service<T, S, B>
379where
380 T: AsyncRead + AsyncWrite + Unpin + 'static,
381
382 S: ServiceFactory<Request, Config = ()>,
383 S::Future: 'static,
384 S::Error: Into<Response<BoxBody>> + 'static,
385 S::Response: Into<Response<B>> + 'static,
386 <S::Service as Service<Request>>::Future: 'static,
387
388 B: MessageBody + 'static,
389{
390 type Response = ();
391 type Error = DispatchError;
392 type Config = ();
393 type Service = H2ServiceHandler<T, S::Service, B>;
394 type InitError = S::InitError;
395 type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
396
397 fn new_service(&self, _: ()) -> Self::Future {
398 let service = self.srv.new_service(());
399 let cfg = self.cfg.clone();
400 let on_connect_ext = self.on_connect_ext.clone();
401
402 Box::pin(async move {
403 let service = service.await?;
404 Ok(H2ServiceHandler::new(cfg, on_connect_ext, service))
405 })
406 }
407}
408
/// `Service` implementation for HTTP/2 transport.
///
/// Created by `H2Service`'s `ServiceFactory::new_service`. Each call takes an I/O stream plus
/// optional peer address and returns a future that performs the HTTP/2 handshake and then
/// dispatches requests.
pub struct H2ServiceHandler<T, S, B>
where
    S: Service<Request>,
{
    /// Request-handling service, shared (via `Rc`) with every connection's dispatcher.
    flow: Rc<HttpFlow<S, (), ()>>,
    /// Service configuration, cloned for each connection.
    cfg: ServiceConfig,
    /// Optional callback invoked with each new connection's I/O object.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    /// `B` only appears in bounds / associated types, not in stored fields.
    _phantom: PhantomData<B>,
}
419
420impl<T, S, B> H2ServiceHandler<T, S, B>
421where
422 S: Service<Request>,
423 S::Error: Into<Response<BoxBody>> + 'static,
424 S::Future: 'static,
425 S::Response: Into<Response<B>> + 'static,
426 B: MessageBody + 'static,
427{
428 fn new(
429 cfg: ServiceConfig,
430 on_connect_ext: Option<Rc<ConnectCallback<T>>>,
431 service: S,
432 ) -> H2ServiceHandler<T, S, B> {
433 H2ServiceHandler {
434 flow: HttpFlow::new(service, (), None),
435 cfg,
436 on_connect_ext,
437 _phantom: PhantomData,
438 }
439 }
440}
441
442impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B>
443where
444 T: AsyncRead + AsyncWrite + Unpin,
445 S: Service<Request>,
446 S::Error: Into<Response<BoxBody>> + 'static,
447 S::Future: 'static,
448 S::Response: Into<Response<B>> + 'static,
449 B: MessageBody + 'static,
450{
451 type Response = ();
452 type Error = DispatchError;
453 type Future = H2ServiceHandlerResponse<T, S, B>;
454
455 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
456 self.flow.service.poll_ready(cx).map_err(|err| {
457 let err = err.into();
458 error!("Service readiness error: {:?}", err);
459 DispatchError::Service(err)
460 })
461 }
462
463 fn call(&self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
464 let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
465
466 H2ServiceHandlerResponse {
467 state: State::Handshake(
468 Some(Rc::clone(&self.flow)),
469 Some(self.cfg.clone()),
470 addr,
471 on_connect_data,
472 handshake_with_timeout(io, &self.cfg),
473 ),
474 }
475 }
476}
477
/// Per-connection progress of an `H2ServiceHandlerResponse`.
enum State<T, S: Service<Request>, B: MessageBody>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S::Future: 'static,
{
    /// HTTP/2 handshake in progress.
    ///
    /// Fields: shared request flow, service config, peer address, on-connect data, and the
    /// handshake future. The flow and config are `Option`s so they can be moved out (`take`n)
    /// when transitioning to `Established`.
    Handshake(
        Option<Rc<HttpFlow<S, (), ()>>>,
        Option<ServiceConfig>,
        Option<net::SocketAddr>,
        OnConnectData,
        HandshakeWithTimeout<T>,
    ),
    /// Handshake succeeded; the dispatcher is serving the connection.
    Established(Dispatcher<T, S, B, (), ()>),
}
492
/// Future returned by `H2ServiceHandler::call`.
///
/// Resolves once the connection's HTTP/2 handshake fails or its dispatcher completes.
pub struct H2ServiceHandlerResponse<T, S, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,
    B: MessageBody + 'static,
{
    /// Current phase: handshaking or established.
    state: State<T, S, B>,
}
504
impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,
    B: MessageBody,
{
    type Output = Result<(), DispatchError>;

    /// Drives one connection. It first polls the handshake future built by
    /// `handshake_with_timeout`, then polls the dispatcher created from its result.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.state {
            State::Handshake(
                ref mut srv,
                ref mut config,
                ref peer_addr,
                ref mut conn_data,
                ref mut handshake,
            ) => match ready!(Pin::new(handshake).poll(cx)) {
                // Handshake finished: move the connection, flow, config and timer into a dispatcher.
                Ok((conn, timer)) => {
                    let on_connect_data = mem::take(conn_data);

                    // `srv` and `config` are created as `Some` in `H2ServiceHandler::call` and are
                    // only taken here, at the single transition out of `Handshake`.
                    self.state = State::Established(Dispatcher::new(
                        conn,
                        srv.take().unwrap(),
                        config.take().unwrap(),
                        *peer_addr,
                        on_connect_data,
                        timer,
                    ));

                    // Poll the new dispatcher right away so it registers for wakeups; returning
                    // `Pending` without polling it could stall the connection.
                    self.poll(cx)
                }

                Err(err) => {
                    trace!("H2 handshake error: {}", err);
                    Poll::Ready(Err(err))
                }
            },

            State::Established(ref mut disp) => Pin::new(disp).poll(cx),
        }
    }
}
549}