axum_util/
tls_acceptor.rs1use std::{
2 net::SocketAddr,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use anyhow::Result;
10use futures::Stream;
11use hyper::server::{
12 accept::Accept,
13 conn::{AddrIncoming, AddrStream},
14};
15use log::{error, warn};
16use rustls::{server::Acceptor, ServerConfig};
17use tokio::sync::{mpsc, watch};
18use tokio_rustls::{server::TlsStream, LazyConfigAcceptor};
19use tokio_stream::{wrappers::ReceiverStream, StreamExt};
20
/// A TLS-terminating TCP listener.
///
/// Accepts TCP connections and runs the server-side TLS handshake on each one, using whatever
/// configuration is current on `tls_config` at the moment the connection is accepted.
pub struct TlsIncoming {
    /// Raw TCP accept stream (hyper's `AddrIncoming`, adapted to `Stream`).
    incoming: StreamWrapper,
    /// Live server TLS configuration. While it holds `None` (no certificates loaded), inbound
    /// connections are dropped with a warning instead of being handshaken.
    tls_config: watch::Receiver<Option<Arc<ServerConfig>>>,
}
25
/// Newtype around hyper's `AddrIncoming` so it can be given a `futures::Stream` implementation
/// (presumably needed because neither the trait nor the type is local to this crate).
struct StreamWrapper(AddrIncoming);
27
28impl Stream for StreamWrapper {
29 type Item = Result<AddrStream, std::io::Error>;
30
31 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
32 Pin::new(&mut self.0).poll_accept(cx)
33 }
34}
35
36impl TlsIncoming {
37 pub fn new(
38 listen: SocketAddr,
39 nodelay: bool,
40 keepalive: Option<Duration>,
41 tls_config: watch::Receiver<Option<Arc<ServerConfig>>>,
42 ) -> Result<Self> {
43 let mut incoming = AddrIncoming::bind(&listen)?;
44 incoming.set_nodelay(nodelay);
45 incoming.set_keepalive(keepalive);
46
47 Ok(Self {
48 incoming: StreamWrapper(incoming),
49 tls_config,
50 })
51 }
52
53 pub fn start(mut self) -> impl Stream<Item = Result<TlsStream<AddrStream>, std::io::Error>> {
54 let (sender, receiver) = mpsc::channel::<Result<TlsStream<AddrStream>, std::io::Error>>(10);
55 tokio::spawn(async move {
56 loop {
57 let client = match self.incoming.next().await {
58 Some(Ok(x)) => x,
59 Some(Err(e)) => {
60 error!("error during accepting TCP client: {e}");
61 continue;
62 }
63 None => break,
64 };
65 let Some(server_config) = self.tls_config.borrow().clone() else {
66 warn!("inbound TLS connection dropped (no certificates loaded, but were configured)");
67 continue
68 };
69
70 let lazy = LazyConfigAcceptor::new(Acceptor::default(), client);
71 let sender = sender.clone();
72 tokio::spawn(async move {
73 let accepted = match lazy.await {
74 Ok(x) => x,
75 Err(e) => {
76 error!("error during TLS init: {e}");
77 return;
78 }
79 };
80 let tls_stream = accepted.into_stream(server_config).await;
81 if sender.send(tls_stream).await.is_err() {
82 error!("TLS acceptor hung");
83 }
84 });
85 }
86 });
87 ReceiverStream::new(receiver)
88 }
89}