flexible_hyper_server_tls/
accept.rs

1use std::net::SocketAddr;
2
3use hyper::body::{Body, Incoming};
4use hyper::server::conn::http1;
5use hyper::service::HttpService;
6use hyper_util::rt::TokioIo;
7use thiserror::Error;
8use tokio::net::{TcpListener, TcpStream};
9use tokio_rustls::TlsAcceptor;
10
11use crate::stream::HttpOrHttpsStream;
12
13/// Accept either an HTTP or HTTPS connection using Hyper
14pub struct HttpOrHttpsAcceptor {
15    listener: TcpListener,
16    tls: Option<TlsAcceptor>,
17}
18
19impl HttpOrHttpsAcceptor {
20    /// Creates a new [`HttpOrHttpsAcceptor`] configured to only serve HTTP
21    pub const fn new(listener: TcpListener) -> Self {
22        Self {
23            listener,
24            tls: None,
25        }
26    }
27
28    /// Configures this [`HttpOrHttpsAcceptor`] to serve HTTPS using the provided [`TlsAcceptor`]
29    ///
30    /// If you need to create a [`TlsAcceptor`], see the helper functions in [`rustls_helpers`](crate::rustls_helpers)
31    #[must_use]
32    pub fn with_tls(mut self, tls: TlsAcceptor) -> Self {
33        self.tls = Some(tls);
34        self
35    }
36
37    /// Accepts a singular connection.
38    /// Returns a the peer address of the connected client and a future that MUST be spawned to serve the connection.
39    ///
40    /// # Errors
41    /// The function will return an error if the TCP connection fails, the returned future will return an error if the TLS handshake or Hyper service fails.
42    pub async fn accept<S>(
43        &self,
44        service: S,
45    ) -> Result<
46        (
47            SocketAddr,
48            impl Future<Output = Result<(), AcceptorError>> + use<S>,
49        ),
50        AcceptorError,
51    >
52    where
53        S: HttpService<Incoming> + 'static,
54        <S::ResBody as Body>::Error: std::error::Error + Send + Sync,
55    {
56        match self.listener.accept().await {
57            Ok((stream, peer_addr)) => {
58                // The TlsAcceptor is a wrapper around an Arc, so this is relatively cheap
59                let cloned_tls = self.tls.clone();
60
61                let conn_fut = handle_conn(stream, cloned_tls, service);
62                Ok((peer_addr, conn_fut))
63            }
64            Err(e) => Err(AcceptorError::TcpConnect(e)),
65        }
66    }
67}
68
69async fn handle_conn<S>(
70    stream: TcpStream,
71    tls: Option<TlsAcceptor>,
72    handler: S,
73) -> Result<(), AcceptorError>
74where
75    S: HttpService<Incoming>,
76    S::ResBody: 'static,
77    <S::ResBody as Body>::Error: std::error::Error + Send + Sync,
78{
79    let client = match tls {
80        None => HttpOrHttpsStream::Http(stream),
81        Some(tls) => {
82            let tls_stream = tls
83                .accept(stream)
84                .await
85                .map_err(AcceptorError::TlsHandshake)?;
86            HttpOrHttpsStream::Https(tls_stream)
87        }
88    };
89
90    // Use `with_upgrades` to allow usage of websockets in client code
91    http1::Builder::new()
92        .serve_connection(TokioIo::new(client), handler)
93        .with_upgrades()
94        .await
95        .map_err(AcceptorError::Hyper)
96}
97
98/// Error when accepting connections
99#[derive(Error, Debug)]
100pub enum AcceptorError {
101    /// Failed to connect to client over TCP
102    #[error("TCP connection to client failed")]
103    TcpConnect(#[source] std::io::Error),
104    /// Failed to make TLS handshake with client
105    #[error("TLS handshake with client failed")]
106    TlsHandshake(#[source] std::io::Error),
107    /// Hyper failed to serve connection
108    #[error("Failed to serve HTTP connection")]
109    Hyper(#[source] hyper::Error),
110}