flexible_hyper_server_tls/
accept.rs1use std::net::SocketAddr;
2
3use hyper::body::{Body, Incoming};
4use hyper::server::conn::http1;
5use hyper::service::HttpService;
6use hyper_util::rt::TokioIo;
7use thiserror::Error;
8use tokio::net::{TcpListener, TcpStream};
9use tokio_rustls::TlsAcceptor;
10
11use crate::stream::HttpOrHttpsStream;
12
13pub struct HttpOrHttpsAcceptor {
15 listener: TcpListener,
16 tls: Option<TlsAcceptor>,
17}
18
19impl HttpOrHttpsAcceptor {
20 pub const fn new(listener: TcpListener) -> Self {
22 Self {
23 listener,
24 tls: None,
25 }
26 }
27
28 #[must_use]
32 pub fn with_tls(mut self, tls: TlsAcceptor) -> Self {
33 self.tls = Some(tls);
34 self
35 }
36
37 pub async fn accept<S>(
43 &self,
44 service: S,
45 ) -> Result<
46 (
47 SocketAddr,
48 impl Future<Output = Result<(), AcceptorError>> + use<S>,
49 ),
50 AcceptorError,
51 >
52 where
53 S: HttpService<Incoming> + 'static,
54 <S::ResBody as Body>::Error: std::error::Error + Send + Sync,
55 {
56 match self.listener.accept().await {
57 Ok((stream, peer_addr)) => {
58 let cloned_tls = self.tls.clone();
60
61 let conn_fut = handle_conn(stream, cloned_tls, service);
62 Ok((peer_addr, conn_fut))
63 }
64 Err(e) => Err(AcceptorError::TcpConnect(e)),
65 }
66 }
67}
68
69async fn handle_conn<S>(
70 stream: TcpStream,
71 tls: Option<TlsAcceptor>,
72 handler: S,
73) -> Result<(), AcceptorError>
74where
75 S: HttpService<Incoming>,
76 S::ResBody: 'static,
77 <S::ResBody as Body>::Error: std::error::Error + Send + Sync,
78{
79 let client = match tls {
80 None => HttpOrHttpsStream::Http(stream),
81 Some(tls) => {
82 let tls_stream = tls
83 .accept(stream)
84 .await
85 .map_err(AcceptorError::TlsHandshake)?;
86 HttpOrHttpsStream::Https(tls_stream)
87 }
88 };
89
90 http1::Builder::new()
92 .serve_connection(TokioIo::new(client), handler)
93 .with_upgrades()
94 .await
95 .map_err(AcceptorError::Hyper)
96}
97
98#[derive(Error, Debug)]
100pub enum AcceptorError {
101 #[error("TCP connection to client failed")]
103 TcpConnect(#[source] std::io::Error),
104 #[error("TLS handshake with client failed")]
106 TlsHandshake(#[source] std::io::Error),
107 #[error("Failed to serve HTTP connection")]
109 Hyper(#[source] hyper::Error),
110}