use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls_pemfile::{certs, private_key};
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Errors produced while loading TLS material from disk or assembling a rustls
/// configuration from it.
#[derive(Debug, thiserror::Error)]
pub enum TlsError {
    /// A configured certificate path (leaf or CA) does not exist; carries the path.
    #[error("TLS certificate file not found: {0}")]
    CertNotFound(String),

    /// The configured private-key path does not exist; carries the path.
    #[error("TLS private key file not found: {0}")]
    KeyNotFound(String),

    /// A certificate file could not be read; carries a descriptive message.
    #[error("Failed to read certificate: {0}")]
    CertReadError(String),

    /// A private-key file could not be read; carries a descriptive message.
    #[error("Failed to read private key: {0}")]
    KeyReadError(String),

    /// The certificate file was readable but contained no certificate sections.
    #[error("No certificates found in certificate file")]
    NoCertificates,

    /// The key file was readable but contained no private key section.
    #[error("No private key found in key file")]
    NoPrivateKey,

    /// rustls rejected the assembled configuration (CA cert, verifier, or cert/key pair).
    #[error("TLS configuration error: {0}")]
    ConfigError(String),
}
/// Filesystem locations of the PEM files used to build a TLS configuration.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// PEM file containing this endpoint's certificate chain.
    pub cert_path: PathBuf,
    /// PEM file containing the private key that matches `cert_path`.
    pub key_path: PathBuf,
    /// Optional PEM file of CA certificates. The server builder uses it to verify
    /// client certificates; the client builder uses it instead of the bundled web roots.
    pub ca_path: Option<PathBuf>,
}

impl TlsConfig {
    /// Creates a configuration from a certificate path and a key path, with no CA file.
    pub fn new(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        let cert = cert_path.into();
        let key = key_path.into();
        Self {
            cert_path: cert,
            key_path: key,
            ca_path: None,
        }
    }

    /// Returns this configuration with `ca_path` set, replacing any previous value.
    pub fn with_ca(self, ca_path: impl Into<PathBuf>) -> Self {
        Self {
            ca_path: Some(ca_path.into()),
            ..self
        }
    }
}
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
let file = File::open(path)
.map_err(|e| TlsError::CertReadError(format!("{}: {}", path.display(), e)))?;
let mut reader = BufReader::new(file);
let certs_result: Vec<CertificateDer<'static>> =
certs(&mut reader).filter_map(|c| c.ok()).collect();
if certs_result.is_empty() {
return Err(TlsError::NoCertificates);
}
Ok(certs_result)
}
fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
let file = File::open(path)
.map_err(|e| TlsError::KeyReadError(format!("{}: {}", path.display(), e)))?;
let mut reader = BufReader::new(file);
private_key(&mut reader)
.map_err(|e| TlsError::KeyReadError(e.to_string()))?
.ok_or(TlsError::NoPrivateKey)
}
/// Builds a rustls `ServerConfig` from the PEM files referenced by `config`.
///
/// The server presents the certificate chain at `config.cert_path` with the key at
/// `config.key_path`. When `config.ca_path` is set, client certificates are checked
/// by a `WebPkiClientVerifier` whose trust anchors are the CA certificates in that
/// file. Otherwise client authentication is disabled.
///
/// The ring crypto provider is used for both the handshake and client-certificate
/// verification. As a side effect, the function also tries to install ring as the
/// process-wide default provider. Any error from that attempt (for example, a default
/// already being installed) is ignored.
///
/// # Errors
///
/// * [`TlsError::CertNotFound`] if the certificate or CA path does not exist.
/// * [`TlsError::KeyNotFound`] if the key path does not exist.
/// * Any error from reading the certificate, key, or CA PEM files.
/// * [`TlsError::ConfigError`] if a CA certificate is rejected, the client verifier
///   cannot be built, or rustls rejects the protocol versions or the cert/key pair.
pub fn build_server_tls_config(
    config: &TlsConfig,
) -> Result<Arc<tokio_rustls::rustls::ServerConfig>, TlsError> {
    if !config.cert_path.exists() {
        return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
    }
    if !config.key_path.exists() {
        return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
    }
    let certs_vec = load_certs(&config.cert_path)?;
    let key = load_private_key(&config.key_path)?;

    let provider = rustls::crypto::ring::default_provider();
    // Best-effort: keep installing ring as the process-wide default for other code
    // that may rely on it. This config does not depend on the result.
    let _ = provider.clone().install_default();
    // Share one provider between the client verifier and the ServerConfig so both
    // agree on algorithms, whatever the process default happens to be.
    let provider = Arc::new(provider);

    let client_verifier = match &config.ca_path {
        Some(ca_path) => {
            if !ca_path.exists() {
                return Err(TlsError::CertNotFound(format!(
                    "CA certificate: {}",
                    ca_path.display()
                )));
            }
            let mut root_store = rustls::RootCertStore::empty();
            for cert in load_certs(ca_path)? {
                root_store
                    .add(cert)
                    .map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
            }
            // `WebPkiClientVerifier::builder` would take its algorithms from the
            // process-default provider, which need not be ring. Pass ours explicitly.
            rustls::server::WebPkiClientVerifier::builder_with_provider(
                Arc::new(root_store),
                Arc::clone(&provider),
            )
            .build()
            .map_err(|e| {
                TlsError::ConfigError(format!("Failed to create client verifier: {}", e))
            })?
        }
        // Same verifier that `ConfigBuilder::with_no_client_auth` installs.
        None => rustls::server::WebPkiClientVerifier::no_client_auth(),
    };

    let server_config = rustls::ServerConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| TlsError::ConfigError(e.to_string()))?
        .with_client_cert_verifier(client_verifier)
        .with_single_cert(certs_vec, key)
        .map_err(|e| TlsError::ConfigError(e.to_string()))?;
    Ok(Arc::new(server_config))
}
pub fn build_client_tls_config(
config: &TlsConfig,
) -> Result<Arc<tokio_rustls::rustls::ClientConfig>, TlsError> {
if !config.cert_path.exists() {
return Err(TlsError::CertNotFound(config.cert_path.display().to_string()));
}
if !config.key_path.exists() {
return Err(TlsError::KeyNotFound(config.key_path.display().to_string()));
}
let certs_vec = load_certs(&config.cert_path)?;
let key = load_private_key(&config.key_path)?;
let provider = rustls::crypto::ring::default_provider();
let mut root_store = rustls::RootCertStore::empty();
if let Some(ca_path) = &config.ca_path {
if !ca_path.exists() {
return Err(TlsError::CertNotFound(format!("CA certificate: {}", ca_path.display())));
}
let ca_certs = load_certs(ca_path)?;
for cert in ca_certs {
root_store
.add(cert)
.map_err(|e| TlsError::ConfigError(format!("Failed to add CA cert: {}", e)))?;
}
} else {
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}
let client_config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()
.map_err(|e| TlsError::ConfigError(e.to_string()))?
.with_root_certificates(root_store)
.with_client_auth_cert(certs_vec, key)
.map_err(|e| TlsError::ConfigError(e.to_string()))?;
Ok(Arc::new(client_config))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a self-signed certificate for `localhost`, writes it and its key
    /// (both PEM-encoded) into `dir`, and returns `(cert_path, key_path)`.
    fn write_test_cert_and_key(dir: &tempfile::TempDir) -> (PathBuf, PathBuf) {
        let params = rcgen::CertificateParams::new(vec!["localhost".to_string()])
            .expect("Failed to create cert params");
        let signing_key = rcgen::KeyPair::generate().expect("Failed to generate key pair");
        let cert = params
            .self_signed(&signing_key)
            .expect("Failed to self-sign cert");

        let paths = (dir.path().join("cert.pem"), dir.path().join("key.pem"));
        std::fs::write(&paths.0, cert.pem()).unwrap();
        std::fs::write(&paths.1, signing_key.serialize_pem()).unwrap();
        paths
    }

    /// Copies the certificate at `cert_path` to `ca.pem` inside `dir`, so a
    /// self-signed certificate can serve as its own trust anchor.
    fn copy_as_ca(dir: &tempfile::TempDir, cert_path: &Path) -> PathBuf {
        let ca_path = dir.path().join("ca.pem");
        std::fs::copy(cert_path, &ca_path).unwrap();
        ca_path
    }

    /// Fails the calling test, showing the error's `Debug` form, if `result` is `Err`.
    #[track_caller]
    fn assert_built<T>(result: Result<T, TlsError>) {
        if let Err(err) = result {
            panic!("Expected Ok, got: {:?}", Some(err));
        }
    }

    #[test]
    fn test_tls_config_new() {
        let TlsConfig {
            cert_path,
            key_path,
            ca_path,
        } = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem");
        assert_eq!(cert_path, PathBuf::from("/tmp/cert.pem"));
        assert_eq!(key_path, PathBuf::from("/tmp/key.pem"));
        assert!(ca_path.is_none());
    }

    #[test]
    fn test_tls_config_with_ca() {
        let config = TlsConfig::new("/tmp/cert.pem", "/tmp/key.pem").with_ca("/tmp/ca.pem");
        assert_eq!(config.ca_path.as_deref(), Some(Path::new("/tmp/ca.pem")));
    }

    #[test]
    fn test_tls_error_display() {
        let cases = [
            (TlsError::CertNotFound("/path/to/cert.pem".to_string()), "/path/to/cert.pem"),
            (TlsError::NoCertificates, "No certificates"),
            (TlsError::NoPrivateKey, "No private key"),
            (TlsError::ConfigError("bad config".to_string()), "bad config"),
        ];
        for (err, needle) in cases {
            let rendered = err.to_string();
            assert!(rendered.contains(needle), "{rendered:?} should contain {needle:?}");
        }
    }

    #[test]
    fn test_build_server_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        assert!(matches!(
            build_server_tls_config(&config),
            Err(TlsError::CertNotFound(_))
        ));
    }

    #[test]
    fn test_build_server_tls_config_key_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        std::fs::write(&cert_path, "placeholder").unwrap();
        let outcome = build_server_tls_config(&TlsConfig::new(&cert_path, "/nonexistent/key.pem"));
        assert!(matches!(outcome, Err(TlsError::KeyNotFound(_))));
    }

    #[test]
    fn test_build_server_tls_config_empty_cert() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        for empty_file in [&cert_path, &key_path] {
            std::fs::write(empty_file, "").unwrap();
        }
        let outcome = build_server_tls_config(&TlsConfig::new(&cert_path, &key_path));
        assert!(matches!(outcome, Err(TlsError::NoCertificates)));
    }

    #[test]
    fn test_build_server_tls_config_valid() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        assert_built(build_server_tls_config(&TlsConfig::new(&cert_path, &key_path)));
    }

    #[test]
    fn test_build_server_tls_config_with_client_auth() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        let ca_path = copy_as_ca(&dir, &cert_path);
        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        assert_built(build_server_tls_config(&config));
    }

    #[test]
    fn test_build_server_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        assert!(matches!(
            build_server_tls_config(&config),
            Err(TlsError::CertNotFound(_))
        ));
    }

    #[test]
    fn test_build_client_tls_config_cert_not_found() {
        let config = TlsConfig::new("/nonexistent/cert.pem", "/nonexistent/key.pem");
        assert!(matches!(
            build_client_tls_config(&config),
            Err(TlsError::CertNotFound(_))
        ));
    }

    #[test]
    fn test_build_client_tls_config_valid_with_ca() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        let ca_path = copy_as_ca(&dir, &cert_path);
        let config = TlsConfig::new(&cert_path, &key_path).with_ca(&ca_path);
        assert_built(build_client_tls_config(&config));
    }

    #[test]
    fn test_build_client_tls_config_valid_default_roots() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        assert_built(build_client_tls_config(&TlsConfig::new(&cert_path, &key_path)));
    }

    #[test]
    fn test_build_client_tls_config_ca_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_path, key_path) = write_test_cert_and_key(&dir);
        let config = TlsConfig::new(&cert_path, &key_path).with_ca("/nonexistent/ca.pem");
        assert!(matches!(
            build_client_tls_config(&config),
            Err(TlsError::CertNotFound(_))
        ));
    }
}