#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use rustls::{Certificate, PrivateKey, ServerConfig};
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use rustls_pemfile::{certs, pkcs8_private_keys};
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use std::fs::File;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use std::io::{self, BufReader, ErrorKind};
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use std::path::Path;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use std::sync::Arc;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use tokio_rustls::TlsAcceptor;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use tracing::{info, warn};
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
/// Settings from which a server-side [`TlsAcceptor`] is built.
///
/// Start from [`TlsConfig::new`] or [`Default`] and refine with the builder methods.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Path of the PEM file holding the certificate chain.
    pub cert_file: String,
    /// Path of the PEM file holding the private key.
    pub key_file: String,
    /// Raw ALPN protocol identifiers handed to rustls' `ServerConfig::alpn_protocols`.
    pub alpn_protocols: Vec<Vec<u8>>,
    /// Whether client certificate verification has been requested.
    pub client_cert_verification: bool,
    /// Lowest TLS protocol version requested.
    pub min_tls_version: TlsVersion,
    /// Highest TLS protocol version requested.
    pub max_tls_version: TlsVersion,
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
/// A TLS protocol version.
///
/// Variants are declared oldest-first, so the derived ordering gives
/// `TlsVersion::TlsV12 < TlsVersion::TlsV13`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TlsVersion {
    /// TLS 1.2.
    TlsV12,
    /// TLS 1.3.
    TlsV13,
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
impl Default for TlsConfig {
    /// Default settings: `cert.pem` / `key.pem` (relative paths), ALPN `h2` followed by
    /// `http/1.1`, no client certificate verification, and TLS 1.2 through TLS 1.3.
    fn default() -> Self {
        let alpn_protocols = [&b"h2"[..], &b"http/1.1"[..]]
            .iter()
            .map(|proto| proto.to_vec())
            .collect();
        Self {
            cert_file: String::from("cert.pem"),
            key_file: String::from("key.pem"),
            alpn_protocols,
            client_cert_verification: false,
            min_tls_version: TlsVersion::TlsV12,
            max_tls_version: TlsVersion::TlsV13,
        }
    }
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
impl TlsConfig {
    /// Creates a configuration for the given certificate and key PEM files,
    /// taking every other setting from [`TlsConfig::default`].
    pub fn new(cert_file: impl Into<String>, key_file: impl Into<String>) -> Self {
        Self {
            cert_file: cert_file.into(),
            key_file: key_file.into(),
            ..Default::default()
        }
    }

    /// Replaces the advertised ALPN protocols (e.g. `vec!["h2", "http/1.1"]`).
    ///
    /// An empty list leaves rustls' default ALPN list in place when building.
    #[must_use]
    pub fn with_alpn_protocols(mut self, protocols: Vec<&str>) -> Self {
        self.alpn_protocols = protocols
            .into_iter()
            .map(|p| p.as_bytes().to_vec())
            .collect();
        self
    }

    /// Requests client certificate verification.
    ///
    /// This type has no way to supply trusted client CA roots yet, so
    /// [`TlsConfig::build`] returns [`TlsError::Config`] while this is enabled
    /// rather than silently accepting unauthenticated clients.
    #[must_use]
    pub fn enable_client_cert_verification(mut self) -> Self {
        self.client_cert_verification = true;
        self
    }

    /// Restricts negotiation to the inclusive version range `min..=max`.
    #[must_use]
    pub fn tls_versions(mut self, min: TlsVersion, max: TlsVersion) -> Self {
        self.min_tls_version = min;
        self.max_tls_version = max;
        self
    }

    /// Builds a [`TlsAcceptor`] from this configuration.
    ///
    /// Loads the certificate chain and private key from disk, limits the
    /// negotiable protocol versions to `min_tls_version..=max_tls_version`, and
    /// installs the configured ALPN protocols.
    ///
    /// # Errors
    ///
    /// - [`TlsError::Config`] if client certificate verification was requested
    ///   (no client CA roots can be configured, so the acceptor would otherwise
    ///   skip client authentication), if the version range is inverted or not
    ///   available in this rustls build, or if rustls rejects the configuration.
    /// - [`TlsError::Io`], [`TlsError::CertParsing`] or [`TlsError::KeyParsing`]
    ///   if the certificate or key file cannot be read or parsed.
    pub fn build(&self) -> Result<TlsAcceptor, TlsError> {
        // Fail closed: previously this flag was ignored and every client was accepted.
        if self.client_cert_verification {
            return Err(TlsError::Config(
                "Client certificate verification was requested, but no trusted client CA \
                 roots can be configured; refusing to build an acceptor that would skip \
                 client authentication"
                    .to_string(),
            ));
        }
        let versions = self.protocol_versions()?;
        let certs = load_certs(&self.cert_file)?;
        let key = load_private_key(&self.key_file)?;
        // `with_safe_defaults()` would enable every default version regardless of the
        // configured range, so the pieces are spelled out to honour min/max.
        let mut config = ServerConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_protocol_versions(&versions)
            .map_err(|e| {
                TlsError::Config(format!("Failed to configure TLS protocol versions: {}", e))
            })?
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(|e| TlsError::Config(format!("Failed to configure TLS: {}", e)))?;
        if !self.alpn_protocols.is_empty() {
            config.alpn_protocols = self.alpn_protocols.clone();
        }
        info!("TLS configuration loaded successfully");
        info!(
            "TLS protocol versions: {:?}",
            versions.iter().map(|v| &v.version).collect::<Vec<_>>()
        );
        info!(
            "ALPN protocols: {:?}",
            self.alpn_protocols
                .iter()
                .map(|p| String::from_utf8_lossy(p))
                .collect::<Vec<_>>()
        );
        Ok(TlsAcceptor::from(Arc::new(config)))
    }

    /// Maps the inclusive `min_tls_version..=max_tls_version` range onto the
    /// protocol versions compiled into rustls.
    ///
    /// Returns [`TlsError::Config`] when the range is inverted or when none of
    /// the requested versions is available in this rustls build.
    fn protocol_versions(
        &self,
    ) -> Result<Vec<&'static rustls::SupportedProtocolVersion>, TlsError> {
        // Position of each variant in protocol order, so the bounds can be compared.
        fn ordinal(version: &TlsVersion) -> u8 {
            match version {
                TlsVersion::TlsV12 => 0,
                TlsVersion::TlsV13 => 1,
            }
        }
        let min = ordinal(&self.min_tls_version);
        let max = ordinal(&self.max_tls_version);
        if min > max {
            return Err(TlsError::Config(format!(
                "Invalid TLS version range: minimum {:?} is newer than maximum {:?}",
                self.min_tls_version, self.max_tls_version
            )));
        }
        let versions: Vec<&'static rustls::SupportedProtocolVersion> = rustls::ALL_VERSIONS
            .iter()
            .copied()
            .filter(|supported| {
                let rank: u8 = match supported.version {
                    rustls::ProtocolVersion::TLSv1_2 => 0,
                    rustls::ProtocolVersion::TLSv1_3 => 1,
                    _ => return false,
                };
                (min..=max).contains(&rank)
            })
            .collect();
        // ALL_VERSIONS omits TLS 1.2 when rustls is built without its `tls12` feature.
        if versions.is_empty() {
            return Err(TlsError::Config(format!(
                "None of the requested TLS versions ({:?} to {:?}) are enabled in this rustls build",
                self.min_tls_version, self.max_tls_version
            )));
        }
        Ok(versions)
    }

    /// Generates a self-signed certificate for `domain` and returns
    /// `(certificate_pem, private_key_pem)`.
    ///
    /// Both PEM strings are also written to `self_signed_cert.pem` and
    /// `self_signed_key.pem` (relative to the current working directory),
    /// replacing any existing files. Intended for development only.
    ///
    /// # Errors
    ///
    /// [`TlsError::CertGeneration`] if rcgen cannot build or serialize the
    /// certificate; [`TlsError::Io`] if either file cannot be written.
    #[cfg(feature = "self-signed")]
    #[cfg_attr(docsrs, doc(cfg(feature = "self-signed")))]
    pub fn generate_self_signed(domain: &str) -> Result<(String, String), TlsError> {
        use rcgen::{Certificate, CertificateParams, DistinguishedName};
        let mut params = CertificateParams::new(vec![domain.to_string()]);
        params.distinguished_name = DistinguishedName::new();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, domain);
        let cert = Certificate::from_params(params).map_err(|e| {
            TlsError::CertGeneration(format!("Failed to generate certificate: {}", e))
        })?;
        let cert_pem = cert.serialize_pem().map_err(|e| {
            TlsError::CertGeneration(format!("Failed to serialize certificate: {}", e))
        })?;
        let key_pem = cert.serialize_private_key_pem();
        std::fs::write("self_signed_cert.pem", &cert_pem).map_err(TlsError::Io)?;
        std::fs::write("self_signed_key.pem", &key_pem).map_err(TlsError::Io)?;
        warn!(
            "Generated self-signed certificate for '{}' - DO NOT USE IN PRODUCTION",
            domain
        );
        Ok((cert_pem, key_pem))
    }
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
/// Errors raised while loading TLS material or assembling a TLS configuration.
#[derive(thiserror::Error, Debug)]
pub enum TlsError {
    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// The TLS configuration could not be assembled.
    #[error("Configuration error: {0}")]
    Config(String),

    /// The certificate PEM could not be parsed or contained no certificates.
    #[error("Certificate parsing error: {0}")]
    CertParsing(String),

    /// The private key PEM could not be parsed or contained no usable key.
    #[error("Key parsing error: {0}")]
    KeyParsing(String),

    /// Self-signed certificate generation failed.
    #[cfg(feature = "self-signed")]
    #[error("Certificate generation error: {0}")]
    CertGeneration(String),
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
/// Reads every PEM-encoded certificate from the file at `path`.
///
/// # Errors
///
/// - [`TlsError::Io`] if the file cannot be opened. The original [`ErrorKind`]
///   is preserved, so a permission problem is no longer reported as "not found".
/// - [`TlsError::CertParsing`] if the PEM is malformed or holds no certificates.
fn load_certs<P: AsRef<Path>>(path: P) -> Result<Vec<Certificate>, TlsError> {
    let path = path.as_ref();
    let file = File::open(path).map_err(|e| {
        let message = match e.kind() {
            ErrorKind::NotFound => format!("Certificate file not found: {}", path.display()),
            _ => format!("Failed to open certificate file {}: {}", path.display(), e),
        };
        TlsError::Io(io::Error::new(e.kind(), message))
    })?;
    let mut reader = BufReader::new(file);
    let der_certs = certs(&mut reader)
        .map_err(|e| TlsError::CertParsing(format!("Failed to parse certificates: {}", e)))?;
    if der_certs.is_empty() {
        return Err(TlsError::CertParsing(format!(
            "No certificates found in {}",
            path.display()
        )));
    }
    Ok(der_certs.into_iter().map(Certificate).collect())
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
/// Loads the first private key from the PEM file at `path`.
///
/// PKCS#8 (`BEGIN PRIVATE KEY`) keys take priority. If the file contains none,
/// PKCS#1 RSA (`BEGIN RSA PRIVATE KEY`) keys are tried, then SEC1 EC
/// (`BEGIN EC PRIVATE KEY`) keys, so keys from common tooling load without
/// manual conversion.
///
/// # Errors
///
/// - [`TlsError::Io`] if the file cannot be read; the original [`ErrorKind`] is kept.
/// - [`TlsError::KeyParsing`] if the PEM is malformed or holds no supported key.
fn load_private_key<P: AsRef<Path>>(path: P) -> Result<PrivateKey, TlsError> {
    let path = path.as_ref();
    let pem = std::fs::read(path).map_err(|e| {
        let message = match e.kind() {
            ErrorKind::NotFound => format!("Private key file not found: {}", path.display()),
            _ => format!("Failed to read private key file {}: {}", path.display(), e),
        };
        TlsError::Io(io::Error::new(e.kind(), message))
    })?;
    let parse_error =
        |e: io::Error| TlsError::KeyParsing(format!("Failed to parse private key: {}", e));
    // Each parser only recognises its own PEM label, so try them in order of preference.
    let mut keys = pkcs8_private_keys(&mut pem.as_slice()).map_err(parse_error)?;
    if keys.is_empty() {
        keys = rustls_pemfile::rsa_private_keys(&mut pem.as_slice()).map_err(parse_error)?;
    }
    if keys.is_empty() {
        keys = rustls_pemfile::ec_private_keys(&mut pem.as_slice()).map_err(parse_error)?;
    }
    if keys.is_empty() {
        return Err(TlsError::KeyParsing(format!(
            "No private keys found in {}",
            path.display()
        )));
    }
    if keys.len() > 1 {
        warn!("Multiple private keys found, using the first one");
    }
    Ok(PrivateKey(keys.remove(0)))
}
#[cfg(not(feature = "tls"))]
/// Placeholder for the TLS configuration when the `tls` feature is disabled.
#[derive(Clone, Debug)]
pub struct TlsConfig;
#[cfg(not(feature = "tls"))]
/// A TLS protocol version (placeholder build without the `tls` feature).
///
/// Variants are declared oldest-first, so the derived ordering gives
/// `TlsVersion::TlsV12 < TlsVersion::TlsV13`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TlsVersion {
    /// TLS 1.2.
    TlsV12,
    /// TLS 1.3.
    TlsV13,
}
#[cfg(not(feature = "tls"))]
#[derive(Debug, thiserror::Error)]
pub enum TlsError {
#[error("TLS feature not enabled")]
NotEnabled,
}
#[cfg(not(feature = "tls"))]
impl TlsConfig {
    /// Placeholder constructor present when the crate is built without `tls`.
    ///
    /// # Panics
    ///
    /// Always panics; enable the `tls` feature to get a working configuration.
    pub fn new(_cert_file: impl Into<String>, _key_file: impl Into<String>) -> Self {
        // Fail loudly instead of handing back a stub that cannot serve TLS.
        panic!("TLS feature not enabled. Add 'tls' feature to your Cargo.toml")
    }
}