use axum::extract::ConnectInfo;
use axum::http::Request;
use axum::{Extension, Router};
use futures_util::FutureExt;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use rustls_acme::is_tls_alpn_challenge;
use rustls_acme::rustls::ServerConfig;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::watch::{Receiver, Sender};
use tokio_rustls::StartHandshake;
use tokio_rustls::server::TlsStream;
use tower::{Layer, Service};
use tracing::{Instrument, Span};
/// Serves HTTP (HTTP/1 or HTTP/2, with upgrades) over an established TLS stream
/// until the connection ends.
///
/// Every request is routed through `tower_service` wrapped in an
/// `Extension(ConnectInfo(addr))` layer, so handlers can extract the peer address.
///
/// Graceful shutdown of the connection is requested at most once, when either:
/// - every receiver of `signal_tx` has been dropped (`Sender::closed` resolves), or
/// - `terminate_rx` is `Some` and reports a change, or its sender is dropped.
///
/// After shutdown is requested, the connection keeps being driven until hyper
/// completes it. The final outcome is logged inside `span`; errors are logged,
/// not returned.
///
/// A clone of `close_rx` is held for the whole lifetime of the connection and
/// dropped on return. NOTE(review): presumably this lets the owner of the
/// matching sender await `closed()` to wait for in-flight connections — confirm.
pub async fn handle_stream(
    span: &Span,
    tower_service: Router,
    addr: SocketAddr,
    stream: TokioIo<TlsStream<TcpStream>>,
    signal_tx: &Sender<()>,
    close_rx: &Receiver<()>,
    terminate_rx: Option<&mut Receiver<bool>>,
) {
    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
        let mut tower_service = Extension(ConnectInfo(addr)).layer(tower_service.clone());
        tower_service.call(request)
    });
    let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
    // Enable the HTTP/2 extended CONNECT protocol (RFC 8441).
    builder.http2().enable_connect_protocol();
    // Owned handles kept alive for as long as this connection is being served.
    let signal_tx = signal_tx.clone();
    let close_rx = close_rx.clone();
    let mut terminate_rx = terminate_rx;
    let mut conn = pin!(builder.serve_connection_with_upgrades(stream, hyper_service));
    // Fused so that re-polling it after it has resolved yields `Pending`
    // instead of misbehaving.
    let mut signal_closed = pin!(signal_tx.closed().fuse());
    // Once graceful shutdown has been requested, both shutdown triggers are
    // disabled. Without this, a `terminate_rx` whose sender was dropped makes
    // `changed()` fail immediately on every iteration, busy-looping (and
    // re-logging "shutdown") until the connection finishes.
    let mut shutting_down = false;
    loop {
        tokio::select! {
            result = conn.as_mut() => {
                match result {
                    Ok(()) => span.in_scope(|| tracing::info!("closed")),
                    Err(err) => span.in_scope(|| tracing::warn!(%err, "connection")),
                }
                break;
            }
            () = &mut signal_closed, if !shutting_down => {
                span.in_scope(|| tracing::warn!("shutdown"));
                conn.as_mut().graceful_shutdown();
                shutting_down = true;
            }
            () = wait_for_terminate_signal(terminate_rx.as_deref_mut()), if !shutting_down => {
                span.in_scope(|| tracing::warn!("shutdown"));
                conn.as_mut().graceful_shutdown();
                shutting_down = true;
            }
        }
    }
    drop(close_rx);
}

/// Resolves when `terminate_rx` reports a change or its sender is dropped;
/// never resolves when `terminate_rx` is `None`.
async fn wait_for_terminate_signal(terminate_rx: Option<&mut Receiver<bool>>) {
    match terminate_rx {
        // Both `Ok` (new value) and `Err` (sender dropped) are treated as a
        // shutdown trigger; the value itself is not inspected.
        Some(rx) => {
            let _ = rx.changed().await;
        }
        None => std::future::pending::<()>().await,
    }
}
/// Finishes the TLS handshake for an accepted TCP connection and dispatches it.
///
/// * If the ClientHello is a TLS-ALPN-01 challenge (per `is_tls_alpn_challenge`),
///   the handshake is completed with `challenge_rustls_config` and the TLS
///   stream is shut down right away; this path is logged under `acme_span`.
/// * Otherwise the handshake is completed with `default_rustls_config` and the
///   stream is served by [`handle_stream`], logged under `tls_span`.
///
/// # Errors
///
/// Returns an error if the TLS handshake fails, or if shutting down the
/// challenge stream fails. Errors while serving a regular connection are
/// logged by [`handle_stream`] rather than returned.
// All of this listener state is needed per connection; bundling it into a
// struct would change the public signature, so the lint is silenced instead.
#[allow(clippy::too_many_arguments)]
pub async fn handle_handshake(
    tls_span: &Span,
    acme_span: &Span,
    start_handshake: StartHandshake<TcpStream>,
    tower_service: Router,
    addr: SocketAddr,
    challenge_rustls_config: Arc<ServerConfig>,
    default_rustls_config: Arc<ServerConfig>,
    signal_tx: &Sender<()>,
    close_rx: &Receiver<()>,
    terminate_rx: Option<&mut Receiver<bool>>,
) -> anyhow::Result<()> {
    let is_acme_challenge = is_tls_alpn_challenge(&start_handshake.client_hello());

    if !is_acme_challenge {
        tls_span.in_scope(|| tracing::info!("connecting"));
        let tls_stream = start_handshake.into_stream(default_rustls_config).await?;
        handle_stream(
            tls_span,
            tower_service,
            addr,
            TokioIo::new(tls_stream),
            signal_tx,
            close_rx,
            terminate_rx,
        )
        .await;
        return Ok(());
    }

    // For TLS-ALPN-01 the handshake with the challenge config is the answer
    // itself, so no application data is exchanged before shutting down.
    let respond_to_challenge = async move {
        tracing::info!(evt = %"validation", "TLS-ALPN-01");
        let mut challenge_stream = start_handshake.into_stream(challenge_rustls_config).await?;
        challenge_stream.shutdown().await?;
        anyhow::Ok(())
    };
    respond_to_challenge.instrument(acme_span.clone()).await
}