use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
use crate::config::TlsConfig;
use crate::error::Result;
/// Listener that pairs a TCP listener with a TLS acceptor so that accepted
/// connections can be handed out as TLS streams.
pub struct TlsListener {
/// TCP listener that raw connections are accepted from.
tcp: TcpListener,
/// Acceptor that performs the server-side TLS handshake on accepted streams.
acceptor: TlsAcceptor,
}
impl TlsListener {
    /// Creates a listener from an already-bound TCP listener and a shared
    /// rustls server configuration.
    ///
    /// The configuration is converted into a [`TlsAcceptor`] once, here, and
    /// stored alongside the TCP listener.
    pub fn new(tcp: TcpListener, server_config: Arc<ServerConfig>) -> Self {
        let acceptor = TlsAcceptor::from(server_config);
        Self { tcp, acceptor }
    }
}
impl axum::serve::Listener for TlsListener {
type Io = TlsStream<TcpStream>;
type Addr = SocketAddr;
fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send {
let acceptor = self.acceptor.clone();
let tcp = &mut self.tcp;
async move {
loop {
let (stream, addr) = match TcpListener::accept(tcp).await {
Ok((stream, addr)) => (stream, addr),
Err(e) => {
tracing::error!("TCP accept error: {}", e);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
continue;
}
};
match acceptor.accept(stream).await {
Ok(tls_stream) => return (tls_stream, addr),
Err(e) => {
tracing::warn!("TLS handshake failed from {}: {}", addr, e);
continue;
}
}
}
}
}
fn local_addr(&self) -> io::Result<Self::Addr> {
self.tcp.local_addr()
}
}
pub fn load_server_config(tls_config: &TlsConfig) -> Result<Arc<ServerConfig>> {
use rustls_pemfile::{certs, private_key};
use std::fs::File;
use std::io::BufReader;
use tokio_rustls::rustls;
let cert_file = File::open(&tls_config.cert_path).map_err(|e| {
crate::error::Error::Internal(format!(
"Failed to open TLS cert file '{}': {}",
tls_config.cert_path.display(),
e
))
})?;
let mut cert_reader = BufReader::new(cert_file);
let cert_chain: Vec<rustls::pki_types::CertificateDer<'static>> = certs(&mut cert_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
crate::error::Error::Internal(format!("Failed to parse TLS certificates: {}", e))
})?;
if cert_chain.is_empty() {
return Err(crate::error::Error::Internal(
"TLS cert file contains no certificates".to_string(),
));
}
let key_file = File::open(&tls_config.key_path).map_err(|e| {
crate::error::Error::Internal(format!(
"Failed to open TLS key file '{}': {}",
tls_config.key_path.display(),
e
))
})?;
let mut key_reader = BufReader::new(key_file);
let key = private_key(&mut key_reader)
.map_err(|e| {
crate::error::Error::Internal(format!("Failed to parse TLS private key: {}", e))
})?
.ok_or_else(|| {
crate::error::Error::Internal("TLS key file contains no private key".to_string())
})?;
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, key)
.map_err(|e| {
crate::error::Error::Internal(format!("Failed to build TLS server config: {}", e))
})?;
Ok(Arc::new(config))
}