1use std::{
2 future::Future,
3 marker::PhantomData,
4 mem, net,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll},
8};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{
13 fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
14};
15use actix_utils::future::ready;
16use futures_core::{future::LocalBoxFuture, ready};
17use tracing::{error, trace};
18
19use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
20use crate::{
21 body::{BoxBody, MessageBody},
22 config::ServiceConfig,
23 error::DispatchError,
24 service::HttpFlow,
25 ConnectCallback, OnConnectData, Request, Response,
26};
27
/// Resolves the `TCP_NODELAY` preference to apply to accepted sockets.
///
/// Currently a pass-through of the configured value. `Some(flag)` means "force the option
/// to `flag`"; `None` means "leave the socket's current setting untouched".
#[inline]
fn desired_nodelay(configured: Option<bool>) -> Option<bool> {
    configured
}
32
33#[inline]
34fn set_nodelay(stream: &TcpStream, nodelay: bool) {
35 let _ = stream.set_nodelay(nodelay);
36}
37
/// `H2` service factory.
///
/// Pairs an inner request-service factory with the connection-level settings used to serve
/// HTTP/2 over an I/O stream of type `T`, producing response bodies of type `B`.
pub struct H2Service<T, S, B> {
    // Factory for the per-request service; `new_service` invokes it when building a handler.
    srv: S,
    // Connection configuration; cloned into each handler built by `new_service`.
    cfg: ServiceConfig,
    // Optional user connect callback, handed to each handler built by `new_service`.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    // Ties the stream type `T` and body type `B` to this struct without storing them.
    _phantom: PhantomData<(T, B)>,
}
46impl<T, S, B> H2Service<T, S, B>
47where
48 S: ServiceFactory<Request, Config = ()>,
49 S::Error: Into<Response<BoxBody>> + 'static,
50 S::Response: Into<Response<B>> + 'static,
51 <S::Service as Service<Request>>::Future: 'static,
52
53 B: MessageBody + 'static,
54{
55 pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
57 cfg: ServiceConfig,
58 service: F,
59 ) -> Self {
60 H2Service {
61 cfg,
62 on_connect_ext: None,
63 srv: service.into_factory(),
64 _phantom: PhantomData,
65 }
66 }
67
68 pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
70 self.on_connect_ext = f;
71 self
72 }
73}
74
75impl<S, B> H2Service<TcpStream, S, B>
76where
77 S: ServiceFactory<Request, Config = ()>,
78 S::Future: 'static,
79 S::Error: Into<Response<BoxBody>> + 'static,
80 S::Response: Into<Response<B>> + 'static,
81 <S::Service as Service<Request>>::Future: 'static,
82
83 B: MessageBody + 'static,
84{
85 pub fn tcp(
87 self,
88 ) -> impl ServiceFactory<
89 TcpStream,
90 Config = (),
91 Response = (),
92 Error = DispatchError,
93 InitError = S::InitError,
94 > {
95 let tcp_nodelay = desired_nodelay(self.cfg.tcp_nodelay());
96
97 fn_factory(move || {
98 ready(Ok::<_, S::InitError>(fn_service(move |io: TcpStream| {
99 if let Some(nodelay) = tcp_nodelay {
100 set_nodelay(&io, nodelay);
101 }
102 let peer_addr = io.peer_addr().ok();
103 ready(Ok::<_, DispatchError>((io, peer_addr)))
104 })))
105 })
106 .and_then(self)
107 }
108}
109
#[cfg(feature = "openssl")]
mod openssl {
    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates an OpenSSL-based HTTP/2 service factory.
        ///
        /// Raw TCP connections are passed through the TLS `Acceptor` built from `acceptor`.
        /// On each resulting stream, the configured `TCP_NODELAY` preference (if any) is applied
        /// to the underlying socket and the peer address is looked up. The stream is then
        /// handed to this `H2Service`.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_pref = desired_nodelay(self.cfg.tcp_nodelay());

            // Runs on every TLS stream produced by the acceptor.
            let prepare_stream = move |io: TlsStream<TcpStream>| {
                let tcp = io.get_ref();
                if let Some(flag) = nodelay_pref {
                    set_nodelay(tcp, flag);
                }
                let peer = tcp.peer_addr().ok();
                (io, peer)
            };

            Acceptor::new(acceptor)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(prepare_stream)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
162
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.20-based HTTP/2 service factory.
        ///
        /// `h2` is put at the front of `config`'s ALPN list, and any previously configured
        /// protocols follow in their original order. On each accepted TLS stream, the
        /// configured `TCP_NODELAY` preference (if any) is applied to the underlying socket and
        /// the peer address is looked up before the stream is handed to this `H2Service`.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_pref = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first; existing ALPN entries keep their relative order after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs on every TLS stream produced by the acceptor.
            let prepare_stream = move |io: TlsStream<TcpStream>| {
                let tcp = io.get_ref().0;
                if let Some(flag) = nodelay_pref {
                    set_nodelay(tcp, flag);
                }
                let peer = tcp.peer_addr().ok();
                (io, peer)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(prepare_stream)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
217
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.21-based HTTP/2 service factory.
        ///
        /// `h2` is put at the front of `config`'s ALPN list, and any previously configured
        /// protocols follow in their original order. On each accepted TLS stream, the
        /// configured `TCP_NODELAY` preference (if any) is applied to the underlying socket and
        /// the peer address is looked up before the stream is handed to this `H2Service`.
        pub fn rustls_021(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_pref = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first; existing ALPN entries keep their relative order after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs on every TLS stream produced by the acceptor.
            let prepare_stream = move |io: TlsStream<TcpStream>| {
                let tcp = io.get_ref().0;
                if let Some(flag) = nodelay_pref {
                    set_nodelay(tcp, flag);
                }
                let peer = tcp.peer_addr().ok();
                (io, peer)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(prepare_stream)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
272
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.22-based HTTP/2 service factory.
        ///
        /// `h2` is put at the front of `config`'s ALPN list, and any previously configured
        /// protocols follow in their original order. On each accepted TLS stream, the
        /// configured `TCP_NODELAY` preference (if any) is applied to the underlying socket and
        /// the peer address is looked up before the stream is handed to this `H2Service`.
        pub fn rustls_0_22(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_pref = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first; existing ALPN entries keep their relative order after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs on every TLS stream produced by the acceptor.
            let prepare_stream = move |io: TlsStream<TcpStream>| {
                let tcp = io.get_ref().0;
                if let Some(flag) = nodelay_pref {
                    set_nodelay(tcp, flag);
                }
                let peer = tcp.peer_addr().ok();
                (io, peer)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(prepare_stream)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
327
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.23-based HTTP/2 service factory.
        ///
        /// `h2` is put at the front of `config`'s ALPN list, and any previously configured
        /// protocols follow in their original order. On each accepted TLS stream, the
        /// configured `TCP_NODELAY` preference (if any) is applied to the underlying socket and
        /// the peer address is looked up before the stream is handed to this `H2Service`.
        pub fn rustls_0_23(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            let nodelay_pref = desired_nodelay(self.cfg.tcp_nodelay());

            // Advertise HTTP/2 first; existing ALPN entries keep their relative order after it.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // Runs on every TLS stream produced by the acceptor.
            let prepare_stream = move |io: TlsStream<TcpStream>| {
                let tcp = io.get_ref().0;
                if let Some(flag) = nodelay_pref {
                    set_nodelay(tcp, flag);
                }
                let peer = tcp.peer_addr().ok();
                (io, peer)
            };

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(prepare_stream)
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
382
383impl<T, S, B> ServiceFactory<(T, Option<net::SocketAddr>)> for H2Service<T, S, B>
384where
385 T: AsyncRead + AsyncWrite + Unpin + 'static,
386
387 S: ServiceFactory<Request, Config = ()>,
388 S::Future: 'static,
389 S::Error: Into<Response<BoxBody>> + 'static,
390 S::Response: Into<Response<B>> + 'static,
391 <S::Service as Service<Request>>::Future: 'static,
392
393 B: MessageBody + 'static,
394{
395 type Response = ();
396 type Error = DispatchError;
397 type Config = ();
398 type Service = H2ServiceHandler<T, S::Service, B>;
399 type InitError = S::InitError;
400 type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
401
402 fn new_service(&self, _: ()) -> Self::Future {
403 let service = self.srv.new_service(());
404 let cfg = self.cfg.clone();
405 let on_connect_ext = self.on_connect_ext.clone();
406
407 Box::pin(async move {
408 let service = service.await?;
409 Ok(H2ServiceHandler::new(cfg, on_connect_ext, service))
410 })
411 }
412}
413
/// `H2` service handler: a `Service` whose `call` serves one HTTP/2 connection.
pub struct H2ServiceHandler<T, S, B>
where
    S: Service<Request>,
{
    // Request-handling flow wrapping the inner service; `Rc`-cloned into every connection.
    // The two `()` parameters are presumably unused expect/upgrade slots — TODO confirm
    // against `HttpFlow`.
    flow: Rc<HttpFlow<S, (), ()>>,
    // Connection config; cloned per connection and passed to `handshake_with_timeout`.
    cfg: ServiceConfig,
    // Optional user callback, passed to `OnConnectData::from_io` for each connection.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    // Carries the response body type `B` without storing a value of it.
    _phantom: PhantomData<B>,
}
424
425impl<T, S, B> H2ServiceHandler<T, S, B>
426where
427 S: Service<Request>,
428 S::Error: Into<Response<BoxBody>> + 'static,
429 S::Future: 'static,
430 S::Response: Into<Response<B>> + 'static,
431 B: MessageBody + 'static,
432{
433 fn new(
434 cfg: ServiceConfig,
435 on_connect_ext: Option<Rc<ConnectCallback<T>>>,
436 service: S,
437 ) -> H2ServiceHandler<T, S, B> {
438 H2ServiceHandler {
439 flow: HttpFlow::new(service, (), None),
440 cfg,
441 on_connect_ext,
442 _phantom: PhantomData,
443 }
444 }
445}
446
447impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B>
448where
449 T: AsyncRead + AsyncWrite + Unpin,
450 S: Service<Request>,
451 S::Error: Into<Response<BoxBody>> + 'static,
452 S::Future: 'static,
453 S::Response: Into<Response<B>> + 'static,
454 B: MessageBody + 'static,
455{
456 type Response = ();
457 type Error = DispatchError;
458 type Future = H2ServiceHandlerResponse<T, S, B>;
459
460 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
461 self.flow.service.poll_ready(cx).map_err(|err| {
462 let err = err.into();
463 error!("Service readiness error: {:?}", err);
464 DispatchError::Service(err)
465 })
466 }
467
468 fn call(&self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
469 let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
470
471 H2ServiceHandlerResponse {
472 state: State::Handshake(
473 Some(Rc::clone(&self.flow)),
474 Some(self.cfg.clone()),
475 addr,
476 on_connect_data,
477 handshake_with_timeout(io, &self.cfg),
478 ),
479 }
480 }
481}
482
/// Progress of a single HTTP/2 connection future.
enum State<T, S: Service<Request>, B: MessageBody>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S::Future: 'static,
{
    /// Waiting for the HTTP/2 handshake to finish.
    ///
    /// The flow and config are wrapped in `Option` so they can be moved out with `take()`
    /// when the dispatcher is built.
    Handshake(
        // Shared request-handling flow.
        Option<Rc<HttpFlow<S, (), ()>>>,
        // Connection configuration.
        Option<ServiceConfig>,
        // Peer address, if it was known when the connection was accepted.
        Option<net::SocketAddr>,
        // Data gathered by the on-connect callback.
        OnConnectData,
        // The handshake future itself.
        HandshakeWithTimeout<T>,
    ),
    /// Handshake finished; all further polling is delegated to the dispatcher.
    Established(Dispatcher<T, S, B, (), ()>),
}
497
/// Future returned by `H2ServiceHandler::call`.
///
/// It resolves with the handshake error if the HTTP/2 handshake fails. Otherwise it resolves
/// with whatever the connection dispatcher returns.
pub struct H2ServiceHandlerResponse<T, S, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,
    B: MessageBody + 'static,
{
    // Current phase of the connection: handshaking or established.
    state: State<T, S, B>,
}
509
510impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
511where
512 T: AsyncRead + AsyncWrite + Unpin,
513 S: Service<Request>,
514 S::Error: Into<Response<BoxBody>> + 'static,
515 S::Future: 'static,
516 S::Response: Into<Response<B>> + 'static,
517 B: MessageBody,
518{
519 type Output = Result<(), DispatchError>;
520
521 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
522 match self.state {
523 State::Handshake(
524 ref mut srv,
525 ref mut config,
526 ref peer_addr,
527 ref mut conn_data,
528 ref mut handshake,
529 ) => match ready!(Pin::new(handshake).poll(cx)) {
530 Ok((conn, timer)) => {
531 let on_connect_data = mem::take(conn_data);
532
533 self.state = State::Established(Dispatcher::new(
534 conn,
535 srv.take().unwrap(),
536 config.take().unwrap(),
537 *peer_addr,
538 on_connect_data,
539 timer,
540 ));
541
542 self.poll(cx)
543 }
544
545 Err(err) => {
546 trace!("H2 handshake error: {}", err);
547 Poll::Ready(Err(err))
548 }
549 },
550
551 State::Established(ref mut disp) => Pin::new(disp).poll(cx),
552 }
553 }
554}