1use std::{
2 fmt,
3 future::Future,
4 marker::PhantomData,
5 net,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11use actix_codec::{AsyncRead, AsyncWrite, Framed};
12use actix_rt::net::TcpStream;
13use actix_service::{
14 fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
15};
16use futures_core::{future::LocalBoxFuture, ready};
17use pin_project_lite::pin_project;
18use tracing::error;
19
20use crate::{
21 body::{BoxBody, MessageBody},
22 builder::HttpServiceBuilder,
23 error::DispatchError,
24 h1, ConnectCallback, OnConnectData, Protocol, Request, Response, ServiceConfig,
25};
26
/// Resolves the `TCP_NODELAY` choice to apply to an accepted socket.
///
/// `Some(flag)` means "set the option to `flag`"; `None` means "leave the socket untouched".
/// The configured value is passed through unchanged.
#[inline]
fn desired_nodelay(configured: Option<bool>) -> Option<bool> {
    configured
}
31
32#[inline]
33fn set_nodelay(stream: &TcpStream, nodelay: bool) {
34 let _ = stream.set_nodelay(nodelay);
35}
36
/// A `ServiceFactory` for HTTP/1.1 and HTTP/2 connections.
///
/// Use [`build`](Self::build) to begin constructing the service. Also see
/// [`HttpServiceBuilder`].
///
/// Type parameters: `T` is the connection I/O type, `S` the application service factory, `B`
/// the response body type, `X` the expect-handler factory (default [`h1::ExpectHandler`]) and
/// `U` the upgrade-handler factory (default [`h1::UpgradeHandler`]).
pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler> {
    // Factory for the application service that handles each `Request`.
    srv: S,
    // Service configuration; cloned into every connection handler and also read for the
    // `tcp_nodelay` setting by the stream adapters below.
    cfg: ServiceConfig,
    // Factory for the expect service (bounded elsewhere as `Request -> Request`).
    expect: X,
    // Optional factory for the upgrade service, which receives the request together with the
    // framed h1 I/O.
    upgrade: Option<U>,
    // Optional callback passed to `OnConnectData::from_io` for each accepted connection.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    // Ties the body type `B` to this struct without storing a value of it.
    _phantom: PhantomData<B>,
}
78
79impl<T, S, B> HttpService<T, S, B>
80where
81 S: ServiceFactory<Request, Config = ()>,
82 S::Error: Into<Response<BoxBody>> + 'static,
83 S::InitError: fmt::Debug,
84 S::Response: Into<Response<B>> + 'static,
85 <S::Service as Service<Request>>::Future: 'static,
86 B: MessageBody + 'static,
87{
88 pub fn build() -> HttpServiceBuilder<T, S> {
90 HttpServiceBuilder::default()
91 }
92}
93
94impl<T, S, B> HttpService<T, S, B>
95where
96 S: ServiceFactory<Request, Config = ()>,
97 S::Error: Into<Response<BoxBody>> + 'static,
98 S::InitError: fmt::Debug,
99 S::Response: Into<Response<B>> + 'static,
100 <S::Service as Service<Request>>::Future: 'static,
101 B: MessageBody + 'static,
102{
103 pub fn new<F: IntoServiceFactory<S, Request>>(service: F) -> Self {
105 HttpService {
106 cfg: ServiceConfig::default(),
107 srv: service.into_factory(),
108 expect: h1::ExpectHandler,
109 upgrade: None,
110 on_connect_ext: None,
111 _phantom: PhantomData,
112 }
113 }
114
115 pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
117 cfg: ServiceConfig,
118 service: F,
119 ) -> Self {
120 HttpService {
121 cfg,
122 srv: service.into_factory(),
123 expect: h1::ExpectHandler,
124 upgrade: None,
125 on_connect_ext: None,
126 _phantom: PhantomData,
127 }
128 }
129}
130
131impl<T, S, B, X, U> HttpService<T, S, B, X, U>
132where
133 S: ServiceFactory<Request, Config = ()>,
134 S::Error: Into<Response<BoxBody>> + 'static,
135 S::InitError: fmt::Debug,
136 S::Response: Into<Response<B>> + 'static,
137 <S::Service as Service<Request>>::Future: 'static,
138 B: MessageBody,
139{
140 pub fn expect<X1>(self, expect: X1) -> HttpService<T, S, B, X1, U>
145 where
146 X1: ServiceFactory<Request, Config = (), Response = Request>,
147 X1::Error: Into<Response<BoxBody>>,
148 X1::InitError: fmt::Debug,
149 {
150 HttpService {
151 expect,
152 cfg: self.cfg,
153 srv: self.srv,
154 upgrade: self.upgrade,
155 on_connect_ext: self.on_connect_ext,
156 _phantom: PhantomData,
157 }
158 }
159
160 pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, S, B, X, U1>
165 where
166 U1: ServiceFactory<(Request, Framed<T, h1::Codec>), Config = (), Response = ()>,
167 U1::Error: fmt::Display,
168 U1::InitError: fmt::Debug,
169 {
170 HttpService {
171 upgrade,
172 cfg: self.cfg,
173 srv: self.srv,
174 expect: self.expect,
175 on_connect_ext: self.on_connect_ext,
176 _phantom: PhantomData,
177 }
178 }
179
180 pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
182 self.on_connect_ext = f;
183 self
184 }
185}
186
187impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
188where
189 S: ServiceFactory<Request, Config = ()>,
190 S::Future: 'static,
191 S::Error: Into<Response<BoxBody>> + 'static,
192 S::InitError: fmt::Debug,
193 S::Response: Into<Response<B>> + 'static,
194 <S::Service as Service<Request>>::Future: 'static,
195
196 B: MessageBody + 'static,
197
198 X: ServiceFactory<Request, Config = (), Response = Request>,
199 X::Future: 'static,
200 X::Error: Into<Response<BoxBody>>,
201 X::InitError: fmt::Debug,
202
203 U: ServiceFactory<(Request, Framed<TcpStream, h1::Codec>), Config = (), Response = ()>,
204 U::Future: 'static,
205 U::Error: fmt::Display + Into<Response<BoxBody>>,
206 U::InitError: fmt::Debug,
207{
208 pub fn tcp(
212 self,
213 ) -> impl ServiceFactory<TcpStream, Config = (), Response = (), Error = DispatchError, InitError = ()>
214 {
215 let tcp_nodelay = self.cfg.tcp_nodelay();
216
217 fn_service(move |io: TcpStream| async move {
218 if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
219 set_nodelay(&io, nodelay);
220 }
221
222 let peer_addr = io.peer_addr().ok();
223 Ok((io, Protocol::Http1, peer_addr))
224 })
225 .and_then(self)
226 }
227
228 #[cfg(feature = "http2")]
231 pub fn tcp_auto_h2c(
232 self,
233 ) -> impl ServiceFactory<TcpStream, Config = (), Response = (), Error = DispatchError, InitError = ()>
234 {
235 let tcp_nodelay = self.cfg.tcp_nodelay();
236
237 fn_service(move |io: TcpStream| async move {
238 const H2_PREFACE: &[u8] = b"PRI * HTTP/2";
243
244 let mut buf = [0; 12];
245
246 io.peek(&mut buf).await?;
247
248 let proto = if buf == H2_PREFACE {
249 Protocol::Http2
250 } else {
251 Protocol::Http1
252 };
253
254 if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
255 set_nodelay(&io, nodelay);
256 }
257
258 let peer_addr = io.peer_addr().ok();
259 Ok((io, proto, peer_addr))
260 })
261 .and_then(self)
262 }
263}
264
/// Configuration options used when accepting TLS connections.
#[cfg(feature = "__tls")]
#[derive(Debug, Default)]
pub struct TlsAcceptorConfig {
    pub(crate) handshake_timeout: Option<std::time::Duration>,
}

#[cfg(feature = "__tls")]
impl TlsAcceptorConfig {
    /// Sets the TLS handshake timeout duration and returns the updated configuration.
    pub fn handshake_timeout(mut self, dur: std::time::Duration) -> Self {
        self.handshake_timeout = Some(dur);
        self
    }
}
282
#[cfg(feature = "openssl")]
mod openssl {
    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Creates an OpenSSL-based service using the default [`TlsAcceptorConfig`].
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = (),
        > {
            self.openssl_with_config(acceptor, TlsAcceptorConfig::default())
        }

        /// Creates an OpenSSL-based service with a custom TLS acceptor configuration.
        ///
        /// After the handshake, the connection is served as HTTP/2 when ALPN selected exactly
        /// `h2`, and as HTTP/1.x otherwise. ALPN itself must be configured on `acceptor`.
        pub fn openssl_with_config(
            self,
            acceptor: SslAcceptor,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();
            let mut acceptor = Acceptor::new(acceptor);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(move |io: TlsStream<TcpStream>| {
                    // ALPN selects exactly one protocol identifier, so compare for equality; a
                    // byte-window search would also match identifiers such as `h2c`.
                    let proto = match io.ssl().selected_alpn_protocol() {
                        Some(selected) if selected == b"h2" => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref(), nodelay);
                    }

                    let peer_addr = io.get_ref().peer_addr().ok();
                    (io, proto, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
381
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_20::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Creates a Rustls (v0.20) based service using the default [`TlsAcceptorConfig`].
        pub fn rustls(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_with_config(config, TlsAcceptorConfig::default())
        }

        /// Creates a Rustls (v0.20) based service with a custom TLS acceptor configuration.
        ///
        /// `h2` (only when the `http2` feature is enabled) and `http/1.1` are prepended to any
        /// ALPN protocols already present in `config`. After the handshake, the connection is
        /// served as HTTP/2 when ALPN selected exactly `h2`, and as HTTP/1.x otherwise.
        pub fn rustls_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();

            // Only advertise `h2` when this build can serve HTTP/2. Otherwise a client that
            // negotiates it would reach the "HTTP/2 support is disabled" panic in `call`.
            let mut protos = Vec::with_capacity(config.alpn_protocols.len() + 2);
            if cfg!(feature = "http2") {
                protos.push(b"h2".to_vec());
            }
            protos.push(b"http/1.1".to_vec());
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN selects exactly one protocol identifier, so compare for equality; a
                    // byte-window search would also match identifiers such as `h2c`.
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(selected) if selected == b"h2" => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
483
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Creates a Rustls (v0.21) based service using the default [`TlsAcceptorConfig`].
        pub fn rustls_021(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_021_with_config(config, TlsAcceptorConfig::default())
        }

        /// Creates a Rustls (v0.21) based service with a custom TLS acceptor configuration.
        ///
        /// `h2` (only when the `http2` feature is enabled) and `http/1.1` are prepended to any
        /// ALPN protocols already present in `config`. After the handshake, the connection is
        /// served as HTTP/2 when ALPN selected exactly `h2`, and as HTTP/1.x otherwise.
        pub fn rustls_021_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();

            // Only advertise `h2` when this build can serve HTTP/2. Otherwise a client that
            // negotiates it would reach the "HTTP/2 support is disabled" panic in `call`.
            let mut protos = Vec::with_capacity(config.alpn_protocols.len() + 2);
            if cfg!(feature = "http2") {
                protos.push(b"h2".to_vec());
            }
            protos.push(b"http/1.1".to_vec());
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN selects exactly one protocol identifier, so compare for equality; a
                    // byte-window search would also match identifiers such as `h2c`.
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(selected) if selected == b"h2" => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
585
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Creates a Rustls (v0.22) based service using the default [`TlsAcceptorConfig`].
        pub fn rustls_0_22(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_0_22_with_config(config, TlsAcceptorConfig::default())
        }

        /// Creates a Rustls (v0.22) based service with a custom TLS acceptor configuration.
        ///
        /// `h2` (only when the `http2` feature is enabled) and `http/1.1` are prepended to any
        /// ALPN protocols already present in `config`. After the handshake, the connection is
        /// served as HTTP/2 when ALPN selected exactly `h2`, and as HTTP/1.x otherwise.
        pub fn rustls_0_22_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();

            // Only advertise `h2` when this build can serve HTTP/2. Otherwise a client that
            // negotiates it would reach the "HTTP/2 support is disabled" panic in `call`.
            let mut protos = Vec::with_capacity(config.alpn_protocols.len() + 2);
            if cfg!(feature = "http2") {
                protos.push(b"h2".to_vec());
            }
            protos.push(b"http/1.1".to_vec());
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN selects exactly one protocol identifier, so compare for equality; a
                    // byte-window search would also match identifiers such as `h2c`.
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(selected) if selected == b"h2" => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
687
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::InitError: fmt::Debug,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,

        X: ServiceFactory<Request, Config = (), Response = Request>,
        X::Future: 'static,
        X::Error: Into<Response<BoxBody>>,
        X::InitError: fmt::Debug,

        U: ServiceFactory<
            (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
            Config = (),
            Response = (),
        >,
        U::Future: 'static,
        U::Error: fmt::Display + Into<Response<BoxBody>>,
        U::InitError: fmt::Debug,
    {
        /// Creates a Rustls (v0.23) based service using the default [`TlsAcceptorConfig`].
        pub fn rustls_0_23(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            self.rustls_0_23_with_config(config, TlsAcceptorConfig::default())
        }

        /// Creates a Rustls (v0.23) based service with a custom TLS acceptor configuration.
        ///
        /// `h2` (only when the `http2` feature is enabled) and `http/1.1` are prepended to any
        /// ALPN protocols already present in `config`. After the handshake, the connection is
        /// served as HTTP/2 when ALPN selected exactly `h2`, and as HTTP/1.x otherwise.
        pub fn rustls_0_23_with_config(
            self,
            mut config: ServerConfig,
            tls_acceptor_config: TlsAcceptorConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = (),
        > {
            let tcp_nodelay = self.cfg.tcp_nodelay();

            // Only advertise `h2` when this build can serve HTTP/2. Otherwise a client that
            // negotiates it would reach the "HTTP/2 support is disabled" panic in `call`.
            let mut protos = Vec::with_capacity(config.alpn_protocols.len() + 2);
            if cfg!(feature = "http2") {
                protos.push(b"h2".to_vec());
            }
            protos.push(b"http/1.1".to_vec());
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            let mut acceptor = Acceptor::new(config);

            if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
                acceptor.set_handshake_timeout(handshake_timeout);
            }

            acceptor
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .and_then(move |io: TlsStream<TcpStream>| async move {
                    // ALPN selects exactly one protocol identifier, so compare for equality; a
                    // byte-window search would also match identifiers such as `h2c`.
                    let proto = match io.get_ref().1.alpn_protocol() {
                        Some(selected) if selected == b"h2" => Protocol::Http2,
                        _ => Protocol::Http1,
                    };

                    if let Some(nodelay) = desired_nodelay(tcp_nodelay) {
                        set_nodelay(io.get_ref().0, nodelay);
                    }

                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    Ok((io, proto, peer_addr))
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
789
790impl<T, S, B, X, U> ServiceFactory<(T, Protocol, Option<net::SocketAddr>)>
791 for HttpService<T, S, B, X, U>
792where
793 T: AsyncRead + AsyncWrite + Unpin + 'static,
794
795 S: ServiceFactory<Request, Config = ()>,
796 S::Future: 'static,
797 S::Error: Into<Response<BoxBody>> + 'static,
798 S::InitError: fmt::Debug,
799 S::Response: Into<Response<B>> + 'static,
800 <S::Service as Service<Request>>::Future: 'static,
801
802 B: MessageBody + 'static,
803
804 X: ServiceFactory<Request, Config = (), Response = Request>,
805 X::Future: 'static,
806 X::Error: Into<Response<BoxBody>>,
807 X::InitError: fmt::Debug,
808
809 U: ServiceFactory<(Request, Framed<T, h1::Codec>), Config = (), Response = ()>,
810 U::Future: 'static,
811 U::Error: fmt::Display + Into<Response<BoxBody>>,
812 U::InitError: fmt::Debug,
813{
814 type Response = ();
815 type Error = DispatchError;
816 type Config = ();
817 type Service = HttpServiceHandler<T, S::Service, B, X::Service, U::Service>;
818 type InitError = ();
819 type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
820
821 fn new_service(&self, _: ()) -> Self::Future {
822 let service = self.srv.new_service(());
823 let expect = self.expect.new_service(());
824 let upgrade = self.upgrade.as_ref().map(|s| s.new_service(()));
825 let on_connect_ext = self.on_connect_ext.clone();
826 let cfg = self.cfg.clone();
827
828 Box::pin(async move {
829 let expect = expect.await.map_err(|err| {
830 tracing::error!("Initialization of HTTP expect service error: {err:?}");
831 })?;
832
833 let upgrade = match upgrade {
834 Some(upgrade) => {
835 let upgrade = upgrade.await.map_err(|err| {
836 tracing::error!("Initialization of HTTP upgrade service error: {err:?}");
837 })?;
838 Some(upgrade)
839 }
840 None => None,
841 };
842
843 let service = service.await.map_err(|err| {
844 tracing::error!("Initialization of HTTP service error: {err:?}");
845 })?;
846
847 Ok(HttpServiceHandler::new(
848 cfg,
849 service,
850 expect,
851 upgrade,
852 on_connect_ext,
853 ))
854 })
855 }
856}
857
/// `Service` implementation for HTTP/1 and HTTP/2 transport.
///
/// Produced by `HttpService`'s `ServiceFactory::new_service`; each `call` turns one accepted
/// connection into a protocol dispatcher future.
pub struct HttpServiceHandler<T, S, B, X, U>
where
    S: Service<Request>,
    X: Service<Request>,
    U: Service<(Request, Framed<T, h1::Codec>)>,
{
    // Application, expect and upgrade services; the `Rc` is cloned into every dispatcher
    // created by `call`.
    pub(super) flow: Rc<HttpFlow<S, X, U>>,
    // Configuration cloned into each dispatcher / HTTP/2 handshake.
    pub(super) cfg: ServiceConfig,
    // Optional callback passed to `OnConnectData::from_io` for each connection.
    pub(super) on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    // Ties the body type `B` to this struct without storing a value of it.
    _phantom: PhantomData<B>,
}
870
871impl<T, S, B, X, U> HttpServiceHandler<T, S, B, X, U>
872where
873 S: Service<Request>,
874 S::Error: Into<Response<BoxBody>>,
875 X: Service<Request>,
876 X::Error: Into<Response<BoxBody>>,
877 U: Service<(Request, Framed<T, h1::Codec>)>,
878 U::Error: Into<Response<BoxBody>>,
879{
880 pub(super) fn new(
881 cfg: ServiceConfig,
882 service: S,
883 expect: X,
884 upgrade: Option<U>,
885 on_connect_ext: Option<Rc<ConnectCallback<T>>>,
886 ) -> HttpServiceHandler<T, S, B, X, U> {
887 HttpServiceHandler {
888 cfg,
889 on_connect_ext,
890 flow: HttpFlow::new(service, expect, upgrade),
891 _phantom: PhantomData,
892 }
893 }
894
895 pub(super) fn _poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Response<BoxBody>>> {
896 ready!(self.flow.expect.poll_ready(cx).map_err(Into::into))?;
897
898 ready!(self.flow.service.poll_ready(cx).map_err(Into::into))?;
899
900 if let Some(ref upg) = self.flow.upgrade {
901 ready!(upg.poll_ready(cx).map_err(Into::into))?;
902 };
903
904 Poll::Ready(Ok(()))
905 }
906}
907
/// The services a connection dispatcher needs.
///
/// `HttpFlow::new` places this behind an `Rc` so one instance can be shared by all
/// dispatchers created from the same handler.
pub(super) struct HttpFlow<S, X, U> {
    // Application service that handles requests.
    pub(super) service: S,
    // Expect service (bounded elsewhere in this file as `Request -> Request`).
    pub(super) expect: X,
    // Optional upgrade service.
    pub(super) upgrade: Option<U>,
}
914
915impl<S, X, U> HttpFlow<S, X, U> {
916 pub(super) fn new(service: S, expect: X, upgrade: Option<U>) -> Rc<Self> {
917 Rc::new(Self {
918 service,
919 expect,
920 upgrade,
921 })
922 }
923}
924
925impl<T, S, B, X, U> Service<(T, Protocol, Option<net::SocketAddr>)>
926 for HttpServiceHandler<T, S, B, X, U>
927where
928 T: AsyncRead + AsyncWrite + Unpin,
929
930 S: Service<Request>,
931 S::Error: Into<Response<BoxBody>> + 'static,
932 S::Future: 'static,
933 S::Response: Into<Response<B>> + 'static,
934
935 B: MessageBody + 'static,
936
937 X: Service<Request, Response = Request>,
938 X::Error: Into<Response<BoxBody>>,
939
940 U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
941 U::Error: fmt::Display + Into<Response<BoxBody>>,
942{
943 type Response = ();
944 type Error = DispatchError;
945 type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
946
947 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
948 self._poll_ready(cx).map_err(|err| {
949 error!("HTTP service readiness error: {:?}", err);
950 DispatchError::Service(err)
951 })
952 }
953
954 fn call(&self, (io, proto, peer_addr): (T, Protocol, Option<net::SocketAddr>)) -> Self::Future {
955 let conn_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
956
957 match proto {
958 #[cfg(feature = "http2")]
959 Protocol::Http2 => HttpServiceHandlerResponse {
960 state: State::H2Handshake {
961 handshake: Some((
962 crate::h2::handshake_with_timeout(io, &self.cfg),
963 self.cfg.clone(),
964 Rc::clone(&self.flow),
965 conn_data,
966 peer_addr,
967 )),
968 },
969 },
970
971 #[cfg(not(feature = "http2"))]
972 Protocol::Http2 => {
973 panic!("HTTP/2 support is disabled (enable with the `http2` feature flag)")
974 }
975
976 Protocol::Http1 => HttpServiceHandlerResponse {
977 state: State::H1 {
978 dispatcher: h1::Dispatcher::new(
979 io,
980 Rc::clone(&self.flow),
981 self.cfg.clone(),
982 peer_addr,
983 conn_data,
984 ),
985 },
986 },
987
988 proto => unimplemented!("Unsupported HTTP version: {:?}.", proto),
989 }
990 }
991}
992
// Per-connection state when HTTP/2 support is compiled out: only the HTTP/1 dispatcher exists.
#[cfg(not(feature = "http2"))]
pin_project! {
    #[project = StateProj]
    enum State<T, S, B, X, U>
    where
        T: AsyncRead,
        T: AsyncWrite,
        T: Unpin,

        S: Service<Request>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // HTTP/1.x connection, driven by the pinned h1 dispatcher.
        H1 { #[pin] dispatcher: h1::Dispatcher<T, S, B, X, U> },
    }
}
1017
// Per-connection state with HTTP/2 support: HTTP/1 dispatching, a pending HTTP/2 handshake, or
// HTTP/2 dispatching once the handshake has completed.
#[cfg(feature = "http2")]
pin_project! {
    #[project = StateProj]
    enum State<T, S, B, X, U>
    where
        T: AsyncRead,
        T: AsyncWrite,
        T: Unpin,

        S: Service<Request>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // HTTP/1.x connection, driven by the pinned h1 dispatcher.
        H1 { #[pin] dispatcher: h1::Dispatcher<T, S, B, X, U> },

        // HTTP/2 connection after a successful handshake.
        H2 { #[pin] dispatcher: crate::h2::Dispatcher<T, S, B, X, U> },

        // HTTP/2 handshake in progress. The tuple holds the handshake future plus everything
        // needed to build the H2 dispatcher; it is `take`n when the handshake succeeds.
        H2Handshake {
            handshake: Option<(
                crate::h2::HandshakeWithTimeout<T>,
                ServiceConfig,
                Rc<HttpFlow<S, X, U>>,
                OnConnectData,
                Option<net::SocketAddr>,
            )>,
        },
    }
}
1054
pin_project! {
    /// Future returned by `HttpServiceHandler::call`; resolves when the connection has been
    /// fully handled or has failed.
    pub struct HttpServiceHandlerResponse<T, S, B, X, U>
    where
        T: AsyncRead,
        T: AsyncWrite,
        T: Unpin,

        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,
        S::Error: 'static,
        S::Future: 'static,
        S::Response: Into<Response<B>>,
        S::Response: 'static,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // Current protocol state of the connection (pinned because dispatchers are polled
        // in place).
        #[pin]
        state: State<T, S, B, X, U>,
    }
}
1081
/// Drives the connection by delegating to the active dispatcher.
///
/// For HTTP/2 the handshake future is polled first. Once it succeeds, the state is replaced with
/// an `H2` dispatcher, which is polled within the same call.
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,

    B: MessageBody + 'static,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    type Output = Result<(), DispatchError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.as_mut().project().state.project() {
            // HTTP/1.x: the dispatcher's result is the connection's result.
            StateProj::H1 { dispatcher } => dispatcher.poll(cx),

            // HTTP/2 after the handshake: likewise delegate to the h2 dispatcher.
            #[cfg(feature = "http2")]
            StateProj::H2 { dispatcher } => dispatcher.poll(cx),

            #[cfg(feature = "http2")]
            StateProj::H2Handshake { handshake: data } => {
                // `data` is only taken on the success path below, immediately before the state
                // is replaced with `H2`, so it is `Some` whenever this arm is reached.
                match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) {
                    Ok((conn, timer)) => {
                        // Move the saved context out of the handshake state to build the
                        // HTTP/2 dispatcher.
                        let (_, config, flow, conn_data, peer_addr) = data.take().unwrap();

                        self.as_mut().project().state.set(State::H2 {
                            dispatcher: crate::h2::Dispatcher::new(
                                conn, flow, config, peer_addr, conn_data, timer,
                            ),
                        });
                        // Poll the freshly created dispatcher now instead of returning without
                        // having polled it.
                        self.poll(cx)
                    }
                    Err(err) => {
                        tracing::trace!("H2 handshake error: {}", err);
                        Poll::Ready(Err(err))
                    }
                }
            }
        }
    }
}