use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::transport::{Certificate, Identity, ServerTlsConfig};
use x509_parser::prelude::*;
use crate::config::TlsConfig;
use crate::error::{Result, ServerError};
#[derive(Debug)]
pub struct TlsLoader {
config: TlsConfig,
current: Arc<RwLock<Option<LoadedTls>>>,
}
impl std::fmt::Debug for LoadedTls {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LoadedTls")
.field("identity", &"<redacted>")
.field("client_ca", &self.client_ca.is_some())
.field("require_client_cert", &self.require_client_cert)
.finish()
}
}
struct LoadedTls {
identity: Identity,
client_ca: Option<Certificate>,
require_client_cert: bool,
}
impl TlsLoader {
pub fn new(config: TlsConfig) -> Result<Self> {
let loader = Self {
config,
current: Arc::new(RwLock::new(None)),
};
loader.load_sync()?;
Ok(loader)
}
fn load_sync(&self) -> Result<LoadedTls> {
let cert = std::fs::read(&self.config.cert_path).map_err(|e| {
ServerError::Config(format!(
"failed to read certificate file {}: {e}",
self.config.cert_path.display()
))
})?;
let key = std::fs::read(&self.config.key_path).map_err(|e| {
ServerError::Config(format!(
"failed to read key file {}: {e}",
self.config.key_path.display()
))
})?;
Self::validate_certificate(&cert, &self.config.cert_path.display().to_string())?;
let identity = Identity::from_pem(&cert, &key);
let client_ca = if let Some(ref ca_path) = self.config.ca_cert_path {
let ca_cert = std::fs::read(ca_path).map_err(|e| {
ServerError::Config(format!(
"failed to read CA certificate file {}: {e}",
ca_path.display()
))
})?;
Self::validate_certificate(&ca_cert, &ca_path.display().to_string())?;
Some(Certificate::from_pem(ca_cert))
} else {
None
};
Ok(LoadedTls {
identity,
client_ca,
require_client_cert: self.config.require_client_cert,
})
}
fn validate_certificate(pem_data: &[u8], cert_name: &str) -> Result<()> {
let pem = match Pem::iter_from_buffer(pem_data).next() {
Some(Ok(pem)) => pem,
Some(Err(e)) => {
return Err(ServerError::Config(format!(
"invalid PEM format in {cert_name}: {e}"
)));
},
None => {
return Err(ServerError::Config(format!(
"no PEM data found in {cert_name}"
)));
},
};
let (_, cert) = X509Certificate::from_der(&pem.contents).map_err(|e| {
ServerError::Config(format!("invalid X.509 certificate in {cert_name}: {e}"))
})?;
let now_secs = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX))
.unwrap_or(0);
let now_asn1 = ASN1Time::from_timestamp(now_secs)
.map_err(|e| ServerError::Internal(format!("failed to get current time: {e}")))?;
let validity = cert.validity();
if now_asn1 > validity.not_after {
let expiry = validity.not_after.to_rfc2822().unwrap_or_default();
return Err(ServerError::Config(format!(
"certificate {cert_name} has expired (expiry: {expiry})"
)));
}
if now_asn1 < validity.not_before {
let not_before = validity.not_before.to_rfc2822().unwrap_or_default();
return Err(ServerError::Config(format!(
"certificate {cert_name} is not yet valid (not before: {not_before})"
)));
}
let subject = cert.subject().to_string();
let expiry = validity.not_after.to_rfc2822().unwrap_or_default();
tracing::debug!(
cert = %cert_name,
subject = %subject,
expires = %expiry,
"Certificate validated successfully"
);
Ok(())
}
pub fn load(&self) -> Result<ServerTlsConfig> {
let loaded = self.load_sync()?;
let mut tls_config = ServerTlsConfig::new().identity(loaded.identity.clone());
if let Some(ref ca) = loaded.client_ca {
tls_config = tls_config.client_ca_root(ca.clone());
if !loaded.require_client_cert {
tls_config = tls_config.client_auth_optional(true);
}
}
if let Ok(mut current) = self.current.try_write() {
*current = Some(loaded);
}
Ok(tls_config)
}
pub async fn reload(&self) -> Result<()> {
tracing::info!(
cert_path = %self.config.cert_path.display(),
key_path = %self.config.key_path.display(),
"Reloading TLS certificates"
);
let loaded = self.load_sync()?;
*self.current.write().await = Some(loaded);
tracing::info!("TLS certificates reloaded successfully");
Ok(())
}
#[allow(clippy::significant_drop_tightening)]
pub async fn current(&self) -> Result<ServerTlsConfig> {
let current = self.current.read().await;
let loaded = current.as_ref().ok_or_else(|| {
ServerError::Internal("TLS not loaded - call load() first".to_string())
})?;
let mut tls_config = ServerTlsConfig::new().identity(loaded.identity.clone());
if let Some(ref ca) = loaded.client_ca {
tls_config = tls_config.client_ca_root(ca.clone());
if !loaded.require_client_cert {
tls_config = tls_config.client_auth_optional(true);
}
}
Ok(tls_config)
}
#[must_use]
pub const fn config(&self) -> &TlsConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
fn fixtures_dir() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("fixtures")
}
fn valid_cert_paths() -> (PathBuf, PathBuf) {
let dir = fixtures_dir();
(dir.join("valid_cert.pem"), dir.join("valid_key.pem"))
}
fn expired_cert_paths() -> (PathBuf, PathBuf) {
let dir = fixtures_dir();
(dir.join("expired_cert.pem"), dir.join("expired_key.pem"))
}
fn not_yet_valid_cert_paths() -> (PathBuf, PathBuf) {
let dir = fixtures_dir();
(
dir.join("not_yet_valid_cert.pem"),
dir.join("not_yet_valid_key.pem"),
)
}
fn ca_cert_path() -> PathBuf {
fixtures_dir().join("ca_cert.pem")
}
fn invalid_cert_path() -> PathBuf {
fixtures_dir().join("invalid_cert.pem")
}
fn empty_cert_path() -> PathBuf {
fixtures_dir().join("empty_cert.pem")
}
#[test]
fn test_load_valid_certs() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let loader = TlsLoader::new(config).expect("should create loader");
let tls_config = loader.load();
assert!(tls_config.is_ok());
}
#[test]
fn test_load_missing_cert() {
let (_, key_path) = valid_cert_paths();
let config = TlsConfig::new(PathBuf::from("nonexistent.pem"), key_path);
let result = TlsLoader::new(config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("certificate"));
}
#[test]
fn test_load_missing_key() {
let (cert_path, _) = valid_cert_paths();
let config = TlsConfig::new(cert_path, PathBuf::from("nonexistent.pem"));
let result = TlsLoader::new(config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("key"));
}
#[test]
fn test_load_expired_cert() {
let (cert_path, key_path) = expired_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let result = TlsLoader::new(config);
drop(result);
}
#[test]
fn test_load_not_yet_valid_cert() {
let (cert_path, key_path) = not_yet_valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let result = TlsLoader::new(config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("not yet valid") || err.contains("not before"));
}
#[test]
fn test_load_invalid_cert_format() {
let (_, key_path) = valid_cert_paths();
let invalid_cert = invalid_cert_path();
let config = TlsConfig::new(invalid_cert, key_path);
let result = TlsLoader::new(config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("PEM") || err.contains("invalid"));
}
#[test]
fn test_load_empty_cert() {
let (_, key_path) = valid_cert_paths();
let empty_cert = empty_cert_path();
let config = TlsConfig::new(empty_cert, key_path);
let result = TlsLoader::new(config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("no PEM data") || err.contains("PEM"));
}
#[tokio::test]
async fn test_reload_success() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let loader = TlsLoader::new(config).expect("should create loader");
loader.load().expect("should load");
let result = loader.reload().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_reload_invalid_cert_rejected() {
use std::io::Write;
use tempfile::TempDir;
let temp_dir = TempDir::new().unwrap();
let cert_path = temp_dir.path().join("test_cert.pem");
let key_path = temp_dir.path().join("test_key.pem");
let (valid_cert, valid_key) = valid_cert_paths();
std::fs::copy(&valid_cert, &cert_path).unwrap();
std::fs::copy(&valid_key, &key_path).unwrap();
let config = TlsConfig::new(cert_path.clone(), key_path);
let loader = TlsLoader::new(config).expect("should create loader");
loader.load().expect("should load");
let mut file = std::fs::File::create(&cert_path).unwrap();
file.write_all(b"invalid certificate data").unwrap();
drop(file);
let result = loader.reload().await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("PEM") || err.contains("invalid"));
}
#[test]
fn test_validate_certificate_invalid_pem() {
let result = TlsLoader::validate_certificate(b"not valid pem", "test.pem");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("invalid") || err.contains("PEM") || err.contains("no PEM"));
}
#[test]
fn test_validate_certificate_empty() {
let result = TlsLoader::validate_certificate(b"", "test.pem");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("no PEM data"));
}
#[test]
fn test_validate_certificate_valid() {
let (cert_path, _) = valid_cert_paths();
let cert_data = std::fs::read(&cert_path).expect("should read valid cert");
let result = TlsLoader::validate_certificate(&cert_data, "valid_cert.pem");
assert!(result.is_ok());
}
#[test]
fn test_validate_certificate_not_yet_valid_fixture() {
let (cert_path, _) = not_yet_valid_cert_paths();
let cert_data = std::fs::read(&cert_path).expect("should read cert");
let result = TlsLoader::validate_certificate(&cert_data, "not_yet_valid.pem");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("not yet valid"));
}
#[test]
fn test_validate_certificate_invalid_x509() {
let invalid_pem = b"-----BEGIN CERTIFICATE-----\nTm90IHZhbGlkIGJhc2U2NCBjb250ZW50IGZvciBYNTA5\n-----END CERTIFICATE-----";
let result = TlsLoader::validate_certificate(invalid_pem, "invalid_x509.pem");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("X.509") || err.contains("invalid"));
}
#[tokio::test]
async fn test_current_after_load() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let loader = TlsLoader::new(config).expect("should create loader");
loader.load().expect("should load");
let result = loader.current().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_current_before_load() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let loader = TlsLoader {
config,
current: Arc::new(RwLock::new(None)),
};
let result = loader.current().await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("not loaded"));
}
#[tokio::test]
async fn test_mtls_ca_loading() {
let (cert_path, key_path) = valid_cert_paths();
let ca_path = ca_cert_path();
let config = TlsConfig::new(cert_path, key_path).with_ca_cert(ca_path);
let loader = TlsLoader::new(config).expect("should create loader");
let tls_config = loader.load();
assert!(tls_config.is_ok());
}
#[tokio::test]
async fn test_mtls_optional_client_cert() {
let (cert_path, key_path) = valid_cert_paths();
let ca_path = ca_cert_path();
let config = TlsConfig::new(cert_path, key_path)
.with_ca_cert(ca_path)
.with_required_client_cert(false);
let loader = TlsLoader::new(config).expect("should create loader");
let tls_config = loader.load().expect("should load");
drop(tls_config);
}
#[tokio::test]
async fn test_mtls_required_client_cert() {
let (cert_path, key_path) = valid_cert_paths();
let ca_path = ca_cert_path();
let config = TlsConfig::new(cert_path, key_path)
.with_ca_cert(ca_path)
.with_required_client_cert(true);
let loader = TlsLoader::new(config).expect("should create loader");
let tls_config = loader.load().expect("should load");
drop(tls_config);
}
#[tokio::test]
async fn test_mtls_missing_ca_cert() {
let (cert_path, key_path) = valid_cert_paths();
let missing_ca = PathBuf::from("nonexistent_ca.pem");
let config = TlsConfig::new(cert_path, key_path).with_ca_cert(missing_ca);
let result = TlsLoader::new(config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("CA certificate"));
}
#[test]
fn test_config_accessor() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path.clone(), key_path);
let loader = TlsLoader::new(config).expect("should create loader");
assert_eq!(loader.config().cert_path, cert_path);
}
#[tokio::test]
async fn test_concurrent_reload() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let loader = Arc::new(TlsLoader::new(config).expect("should create loader"));
loader.load().expect("should load");
let mut handles = vec![];
for _ in 0..10 {
let loader_clone = Arc::clone(&loader);
handles.push(tokio::spawn(async move { loader_clone.reload().await }));
}
for handle in handles {
let result = handle.await.expect("task should complete");
assert!(result.is_ok());
}
}
#[test]
fn test_debug_loaded_tls() {
let (cert_path, key_path) = valid_cert_paths();
let config = TlsConfig::new(cert_path, key_path);
let loader = TlsLoader::new(config).expect("should create loader");
let debug_str = format!("{loader:?}");
assert!(debug_str.contains("TlsLoader"));
}
}