axum_util/
tls_acceptor.rs

1use std::{
2    net::SocketAddr,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6    time::Duration,
7};
8
9use anyhow::Result;
10use futures::Stream;
11use hyper::server::{
12    accept::Accept,
13    conn::{AddrIncoming, AddrStream},
14};
15use log::{error, warn};
16use rustls::{server::Acceptor, ServerConfig};
17use tokio::sync::{mpsc, watch};
18use tokio_rustls::{server::TlsStream, LazyConfigAcceptor};
19use tokio_stream::{wrappers::ReceiverStream, StreamExt};
20
21pub struct TlsIncoming {
22    incoming: StreamWrapper,
23    tls_config: watch::Receiver<Option<Arc<ServerConfig>>>,
24}
25
26struct StreamWrapper(AddrIncoming);
27
28impl Stream for StreamWrapper {
29    type Item = Result<AddrStream, std::io::Error>;
30
31    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
32        Pin::new(&mut self.0).poll_accept(cx)
33    }
34}
35
36impl TlsIncoming {
37    pub fn new(
38        listen: SocketAddr,
39        nodelay: bool,
40        keepalive: Option<Duration>,
41        tls_config: watch::Receiver<Option<Arc<ServerConfig>>>,
42    ) -> Result<Self> {
43        let mut incoming = AddrIncoming::bind(&listen)?;
44        incoming.set_nodelay(nodelay);
45        incoming.set_keepalive(keepalive);
46
47        Ok(Self {
48            incoming: StreamWrapper(incoming),
49            tls_config,
50        })
51    }
52
53    pub fn start(mut self) -> impl Stream<Item = Result<TlsStream<AddrStream>, std::io::Error>> {
54        let (sender, receiver) = mpsc::channel::<Result<TlsStream<AddrStream>, std::io::Error>>(10);
55        tokio::spawn(async move {
56            loop {
57                let client = match self.incoming.next().await {
58                    Some(Ok(x)) => x,
59                    Some(Err(e)) => {
60                        error!("error during accepting TCP client: {e}");
61                        continue;
62                    }
63                    None => break,
64                };
65                let Some(server_config) = self.tls_config.borrow().clone() else {
66                    warn!("inbound TLS connection dropped (no certificates loaded, but were configured)");
67                    continue
68                };
69
70                let lazy = LazyConfigAcceptor::new(Acceptor::default(), client);
71                let sender = sender.clone();
72                tokio::spawn(async move {
73                    let accepted = match lazy.await {
74                        Ok(x) => x,
75                        Err(e) => {
76                            error!("error during TLS init: {e}");
77                            return;
78                        }
79                    };
80                    let tls_stream = accepted.into_stream(server_config).await;
81                    if sender.send(tls_stream).await.is_err() {
82                        error!("TLS acceptor hung");
83                    }
84                });
85            }
86        });
87        ReceiverStream::new(receiver)
88    }
89}