use std::{fs::File, io::BufReader, sync::Arc};
use tokio_rustls::{
TlsAcceptor,
rustls::{
RootCertStore, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer},
server::WebPkiClientVerifier,
},
};
use crate::{WaeError, WaeResult};
/// ALPN protocol identifiers offered by the acceptors built in this module.
pub mod alpn {
    /// ALPN identifier for HTTP/1.1.
    pub const HTTP_1_1: &[u8] = "http/1.1".as_bytes();
    /// ALPN identifier for HTTP/2 over TLS.
    pub const HTTP_2: &[u8] = "h2".as_bytes();
}
/// Builds a TLS acceptor that advertises only HTTP/1.1 and does not request client certificates.
///
/// This is shorthand for [`create_tls_acceptor_with_http2`] with HTTP/2 disabled.
///
/// # Errors
/// Propagates any error from loading the certificate chain or private key, or from building
/// the rustls server configuration.
pub fn create_tls_acceptor(cert_path: &str, key_path: &str) -> WaeResult<TlsAcceptor> {
    let enable_http2 = false;
    create_tls_acceptor_with_http2(cert_path, key_path, enable_http2)
}
pub fn create_tls_acceptor_with_http2(cert_path: &str, key_path: &str, enable_http2: bool) -> WaeResult<TlsAcceptor> {
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
let alpn_protocols =
if enable_http2 { vec![alpn::HTTP_2.to_vec(), alpn::HTTP_1_1.to_vec()] } else { vec![alpn::HTTP_1_1.to_vec()] };
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| WaeError::internal(format!("Failed to create TLS config: {}", e)))?;
let mut config = Arc::new(config);
Arc::get_mut(&mut config).expect("Config should be unique").alpn_protocols = alpn_protocols;
Ok(TlsAcceptor::from(config))
}
pub fn create_tls_acceptor_with_client_auth(
cert_path: &str,
key_path: &str,
ca_path: &str,
enable_http2: bool,
) -> WaeResult<TlsAcceptor> {
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
let ca_certs = load_certs(ca_path)?;
let mut root_cert_store = RootCertStore::empty();
for cert in ca_certs {
root_cert_store.add(cert).map_err(|e| WaeError::internal(format!("Failed to add CA cert: {}", e)))?;
}
let client_verifier = WebPkiClientVerifier::builder(Arc::new(root_cert_store))
.build()
.map_err(|e| WaeError::internal(format!("Failed to create client verifier: {}", e)))?;
let alpn_protocols =
if enable_http2 { vec![alpn::HTTP_2.to_vec(), alpn::HTTP_1_1.to_vec()] } else { vec![alpn::HTTP_1_1.to_vec()] };
let config = ServerConfig::builder()
.with_client_cert_verifier(client_verifier)
.with_single_cert(certs, key)
.map_err(|e| WaeError::internal(format!("Failed to create TLS config: {}", e)))?;
let mut config = Arc::new(config);
Arc::get_mut(&mut config).expect("Config should be unique").alpn_protocols = alpn_protocols;
Ok(TlsAcceptor::from(config))
}
/// Reads every PEM-encoded certificate from the file at `path`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the file cannot be opened;
/// - a certificate section fails to parse;
/// - the file contains no certificates at all.
///
/// The last case is rejected here so callers get a clear message naming the file. Otherwise a
/// later, less specific rustls error (empty chain or empty trust store) would surface instead.
fn load_certs(path: &str) -> WaeResult<Vec<CertificateDer<'static>>> {
    let file = File::open(path).map_err(|e| WaeError::internal(format!("Failed to open cert file {}: {}", path, e)))?;
    let mut reader = BufReader::new(file);
    let certs = rustls_pemfile::certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| WaeError::internal(format!("Failed to parse cert file {}: {}", path, e)))?;
    if certs.is_empty() {
        return Err(WaeError::internal(format!("No certificates found in {}", path)));
    }
    Ok(certs)
}
/// Reads the first PEM-encoded private key from the file at `path`.
///
/// Key-format support (PKCS#1, PKCS#8, SEC1) is whatever `rustls_pemfile::private_key` accepts.
///
/// # Errors
/// Returns an error if the file cannot be opened, if parsing fails, or if no private key section
/// is present.
fn load_private_key(path: &str) -> WaeResult<PrivateKeyDer<'static>> {
    let file = File::open(path).map_err(|e| WaeError::internal(format!("Failed to open key file {}: {}", path, e)))?;
    let mut reader = BufReader::new(file);
    // `private_key` already yields `Option<PrivateKeyDer>`; map `None` straight to an error.
    rustls_pemfile::private_key(&mut reader)
        .map_err(|e| WaeError::internal(format!("Failed to parse key file {}: {}", path, e)))?
        .ok_or_else(|| WaeError::internal(format!("No private key found in {}", path)))
}
/// Fluent builder for a [`TlsAcceptor`].
///
/// `cert_path` and `key_path` are required. Setting `ca_path` switches to
/// [`create_tls_acceptor_with_client_auth`], which verifies client certificates against that CA
/// bundle.
///
/// Note: HTTP/2 ALPN is **enabled** by default here. By contrast, [`create_tls_acceptor`] only
/// offers HTTP/1.1.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// Path to the PEM certificate chain (required).
    cert_path: Option<String>,
    /// Path to the PEM private key (required).
    key_path: Option<String>,
    /// Optional path to a PEM CA bundle; when set, client certificates are verified.
    ca_path: Option<String>,
    /// Whether to advertise `h2` via ALPN in addition to `http/1.1`.
    enable_http2: bool,
}

impl TlsConfigBuilder {
    /// Creates an empty builder with HTTP/2 enabled and no client authentication.
    #[must_use]
    pub fn new() -> Self {
        Self { cert_path: None, key_path: None, ca_path: None, enable_http2: true }
    }

    /// Sets the path to the PEM-encoded server certificate chain.
    #[must_use]
    pub fn cert_path(mut self, path: impl Into<String>) -> Self {
        self.cert_path = Some(path.into());
        self
    }

    /// Sets the path to the PEM-encoded server private key.
    #[must_use]
    pub fn key_path(mut self, path: impl Into<String>) -> Self {
        self.key_path = Some(path.into());
        self
    }

    /// Sets the path to a PEM CA bundle used to verify client certificates.
    #[must_use]
    pub fn ca_path(mut self, path: impl Into<String>) -> Self {
        self.ca_path = Some(path.into());
        self
    }

    /// Enables or disables advertising HTTP/2 (`h2`) via ALPN.
    #[must_use]
    pub fn enable_http2(mut self, enable: bool) -> Self {
        self.enable_http2 = enable;
        self
    }

    /// Consumes the builder and constructs the acceptor.
    ///
    /// # Errors
    /// Returns an error if `cert_path` or `key_path` was not set. Also propagates any error from
    /// loading the configured files or building the rustls configuration.
    pub fn build(self) -> WaeResult<TlsAcceptor> {
        let cert_path = self.cert_path.ok_or_else(|| WaeError::internal("Certificate path is required"))?;
        let key_path = self.key_path.ok_or_else(|| WaeError::internal("Key path is required"))?;
        match self.ca_path {
            Some(ca_path) => create_tls_acceptor_with_client_auth(&cert_path, &key_path, &ca_path, self.enable_http2),
            None => create_tls_acceptor_with_http2(&cert_path, &key_path, self.enable_http2),
        }
    }
}

impl Default for TlsConfigBuilder {
    /// Equivalent to [`TlsConfigBuilder::new`] (HTTP/2 enabled, no paths set).
    fn default() -> Self {
        Self::new()
    }
}