use mockforge_core::config::HttpTlsConfig;
use mockforge_core::Result;
use std::sync::Arc;
use std::sync::Once;
use tokio_rustls::TlsAcceptor;
use tracing::info;
/// Guards the one-time installation performed by [`init_crypto_provider`].
static CRYPTO_INIT: Once = Once::new();

/// Install rustls' `ring` crypto provider as the process-wide default.
///
/// Only the first call does any work; later calls return immediately. The
/// result of the installation attempt is deliberately discarded.
pub fn init_crypto_provider() {
    CRYPTO_INIT.call_once(|| {
        let ring_provider = rustls::crypto::ring::default_provider();
        // A failed install (e.g. a default already present) is intentionally ignored.
        let _ = ring_provider.install_default();
    });
}
/// Start a rustls `ServerConfig` builder.
///
/// When `tls13_only` is set the builder is restricted to TLS 1.3; otherwise
/// rustls' default protocol versions are used.
fn tls_config_builder(
    tls13_only: bool,
) -> rustls::ConfigBuilder<rustls::server::ServerConfig, rustls::WantsVerifier> {
    use rustls::server::ServerConfig;

    if !tls13_only {
        return ServerConfig::builder();
    }
    ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
}
/// Decide whether `min_version` asks for a TLS 1.3-only server.
///
/// `"1.3"` yields `true`. `"1.2"` and the empty string yield `false`. Any other
/// value is logged as a warning and also yields `false`, so the default
/// protocol versions stay in effect.
fn is_tls13_only(min_version: &str) -> bool {
    let tls13_only = min_version == "1.3";
    if tls13_only {
        info!("Enforcing TLS 1.3 only (min_version=1.3)");
    } else if !matches!(min_version, "1.2" | "") {
        tracing::warn!("Unsupported TLS min_version '{}', using defaults (TLS 1.2+)", min_version);
    }
    tls13_only
}
pub fn load_tls_acceptor(config: &HttpTlsConfig) -> Result<TlsAcceptor> {
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::fs::File;
use std::io::BufReader;
init_crypto_provider();
info!(
"Loading TLS certificate from {} and key from {}",
config.cert_file, config.key_file
);
let cert_file = File::open(&config.cert_file).map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to open certificate file {}: {}",
config.cert_file, e
))
})?;
let mut cert_reader = BufReader::new(cert_file);
let server_certs: Vec<rustls::pki_types::CertificateDer<'static>> = certs(&mut cert_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to parse certificate file {}: {}",
config.cert_file, e
))
})?;
if server_certs.is_empty() {
return Err(mockforge_core::Error::internal(format!(
"No certificates found in {}",
config.cert_file
)));
}
let key_file = File::open(&config.key_file).map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to open private key file {}: {}",
config.key_file, e
))
})?;
let mut key_reader = BufReader::new(key_file);
let pkcs8_keys: Vec<rustls::pki_types::PrivatePkcs8KeyDer<'static>> =
pkcs8_private_keys(&mut key_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to parse private key file {}: {}",
config.key_file, e
))
})?;
let mut keys: Vec<rustls::pki_types::PrivateKeyDer<'static>> =
pkcs8_keys.into_iter().map(rustls::pki_types::PrivateKeyDer::Pkcs8).collect();
if keys.is_empty() {
return Err(mockforge_core::Error::internal(format!(
"No private keys found in {}",
config.key_file
)));
}
let mtls_mode = if !config.mtls_mode.is_empty() && config.mtls_mode != "off" {
config.mtls_mode.as_str()
} else if config.require_client_cert {
"required"
} else {
"off"
};
let tls13_only = is_tls13_only(&config.min_version);
let server_config = match mtls_mode {
"required" => {
if let Some(ref ca_file_path) = config.ca_file {
let ca_file = File::open(ca_file_path).map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to open CA certificate file {}: {}",
ca_file_path, e
))
})?;
let mut ca_reader = BufReader::new(ca_file);
let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
|e| {
mockforge_core::Error::internal(format!(
"Failed to parse CA certificate file {}: {}",
ca_file_path, e
))
},
)?;
let mut root_store = rustls::RootCertStore::empty();
for cert in &ca_certs {
root_store.add(cert.clone()).map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to add CA certificate to root store: {}",
e
))
})?;
}
let client_verifier =
rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to build client verifier: {}",
e
))
})?;
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_client_cert_verifier(client_verifier)
.with_single_cert(server_certs, key)
.map_err(|e| {
mockforge_core::Error::internal(format!(
"TLS config error (mTLS required): {}",
e
))
})?
} else {
return Err(mockforge_core::Error::internal(
"mTLS mode 'required' requires --tls-ca (CA certificate file)",
));
}
}
"optional" => {
if let Some(ref ca_file_path) = config.ca_file {
let ca_file = File::open(ca_file_path).map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to open CA certificate file {}: {}",
ca_file_path, e
))
})?;
let mut ca_reader = BufReader::new(ca_file);
let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
|e| {
mockforge_core::Error::internal(format!(
"Failed to parse CA certificate file {}: {}",
ca_file_path, e
))
},
)?;
let mut root_store = rustls::RootCertStore::empty();
for cert in &ca_certs {
root_store.add(cert.clone()).map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to add CA certificate to root store: {}",
e
))
})?;
}
let client_verifier =
rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| {
mockforge_core::Error::internal(format!(
"Failed to build client verifier: {}",
e
))
})?;
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_client_cert_verifier(client_verifier)
.with_single_cert(server_certs, key)
.map_err(|e| {
mockforge_core::Error::internal(format!(
"TLS config error (mTLS optional): {}",
e
))
})?
} else {
info!("mTLS optional mode specified but no CA file provided, using standard TLS");
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_no_client_auth()
.with_single_cert(server_certs, key)
.map_err(|e| {
mockforge_core::Error::internal(format!("TLS config error: {}", e))
})?
}
}
_ => {
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_no_client_auth()
.with_single_cert(server_certs, key)
.map_err(|e| mockforge_core::Error::internal(format!("TLS config error: {}", e)))?
}
};
if !config.cipher_suites.is_empty() {
info!(
"Custom cipher suites specified: {:?}. Note: rustls enforces safe cipher suites; \
for fine-grained control, configure the rustls CryptoProvider.",
config.cipher_suites
);
}
info!("TLS acceptor configured successfully");
Ok(TlsAcceptor::from(Arc::new(server_config)))
}
pub fn load_tls_server_config(
config: &HttpTlsConfig,
) -> std::result::Result<Arc<rustls::server::ServerConfig>, Box<dyn std::error::Error + Send + Sync>>
{
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
init_crypto_provider();
info!(
"Loading TLS certificate from {} and key from {}",
config.cert_file, config.key_file
);
let cert_file = File::open(&config.cert_file)
.map_err(|e| format!("Failed to open certificate file {}: {}", config.cert_file, e))?;
let mut cert_reader = BufReader::new(cert_file);
let server_certs: Vec<rustls::pki_types::CertificateDer<'static>> = certs(&mut cert_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| format!("Failed to parse certificate file {}: {}", config.cert_file, e))?;
if server_certs.is_empty() {
return Err(format!("No certificates found in {}", config.cert_file).into());
}
let key_file = File::open(&config.key_file)
.map_err(|e| format!("Failed to open private key file {}: {}", config.key_file, e))?;
let mut key_reader = BufReader::new(key_file);
let pkcs8_keys: Vec<rustls::pki_types::PrivatePkcs8KeyDer<'static>> =
pkcs8_private_keys(&mut key_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| format!("Failed to parse private key file {}: {}", config.key_file, e))?;
let mut keys: Vec<rustls::pki_types::PrivateKeyDer<'static>> =
pkcs8_keys.into_iter().map(rustls::pki_types::PrivateKeyDer::Pkcs8).collect();
if keys.is_empty() {
return Err(format!("No private keys found in {}", config.key_file).into());
}
let tls13_only = is_tls13_only(&config.min_version);
let mtls_mode = if !config.mtls_mode.is_empty() && config.mtls_mode != "off" {
config.mtls_mode.as_str()
} else if config.require_client_cert {
"required"
} else {
"off"
};
let server_config = match mtls_mode {
"required" => {
if let Some(ref ca_file_path) = config.ca_file {
let ca_file = File::open(ca_file_path).map_err(|e| {
format!("Failed to open CA certificate file {}: {}", ca_file_path, e)
})?;
let mut ca_reader = BufReader::new(ca_file);
let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
|e| format!("Failed to parse CA certificate file {}: {}", ca_file_path, e),
)?;
let mut root_store = rustls::RootCertStore::empty();
for cert in &ca_certs {
root_store.add(cert.clone()).map_err(|e| {
format!("Failed to add CA certificate to root store: {}", e)
})?;
}
let client_verifier =
rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| format!("Failed to build client verifier: {}", e))?;
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_client_cert_verifier(client_verifier)
.with_single_cert(server_certs, key)
.map_err(|e| format!("TLS config error (mTLS required): {}", e))?
} else {
return Err("mTLS mode 'required' requires CA certificate file".to_string().into());
}
}
"optional" => {
if let Some(ref ca_file_path) = config.ca_file {
let ca_file = File::open(ca_file_path).map_err(|e| {
format!("Failed to open CA certificate file {}: {}", ca_file_path, e)
})?;
let mut ca_reader = BufReader::new(ca_file);
let ca_certs: Vec<rustls::pki_types::CertificateDer<'static>> =
certs(&mut ca_reader).collect::<std::result::Result<Vec<_>, _>>().map_err(
|e| format!("Failed to parse CA certificate file {}: {}", ca_file_path, e),
)?;
let mut root_store = rustls::RootCertStore::empty();
for cert in &ca_certs {
root_store.add(cert.clone()).map_err(|e| {
format!("Failed to add CA certificate to root store: {}", e)
})?;
}
let client_verifier =
rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| format!("Failed to build client verifier: {}", e))?;
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_client_cert_verifier(client_verifier)
.with_single_cert(server_certs, key)
.map_err(|e| format!("TLS config error (mTLS optional): {}", e))?
} else {
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_no_client_auth()
.with_single_cert(server_certs, key)
.map_err(|e| format!("TLS config error: {}", e))?
}
}
_ => {
let key = keys.remove(0);
tls_config_builder(tls13_only)
.with_no_client_auth()
.with_single_cert(server_certs, key)
.map_err(|e| format!("TLS config error: {}", e))?
}
};
Ok(Arc::new(server_config))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a certificate/key pair of PEM files whose armour is valid but whose
    /// payload (`TEST`) is not a real DER certificate or key.
    fn create_test_cert() -> (NamedTempFile, NamedTempFile) {
        let cert = NamedTempFile::new().unwrap();
        let key = NamedTempFile::new().unwrap();
        writeln!(cert.as_file(), "-----BEGIN CERTIFICATE-----").unwrap();
        writeln!(cert.as_file(), "TEST").unwrap();
        writeln!(cert.as_file(), "-----END CERTIFICATE-----").unwrap();
        writeln!(key.as_file(), "-----BEGIN PRIVATE KEY-----").unwrap();
        writeln!(key.as_file(), "TEST").unwrap();
        writeln!(key.as_file(), "-----END PRIVATE KEY-----").unwrap();
        (cert, key)
    }

    /// Plain-TLS config (min version 1.2, mTLS off) for the given cert and key files.
    fn base_config(cert: &NamedTempFile, key: &NamedTempFile) -> HttpTlsConfig {
        HttpTlsConfig {
            enabled: true,
            cert_file: cert.path().to_string_lossy().to_string(),
            key_file: key.path().to_string_lossy().to_string(),
            ca_file: None,
            min_version: "1.2".to_string(),
            cipher_suites: Vec::new(),
            require_client_cert: false,
            mtls_mode: "off".to_string(),
        }
    }

    /// Path of a file that no longer exists: the temp file is deleted when dropped.
    fn missing_path() -> String {
        let file = NamedTempFile::new().unwrap();
        file.path().to_string_lossy().to_string()
    }

    #[test]
    fn test_tls_config_validation() {
        init_crypto_provider();
        let (cert, key) = create_test_cert();
        // The PEM payloads are not valid DER, so building the config must fail.
        let result = load_tls_acceptor(&base_config(&cert, &key));
        assert!(result.is_err());
    }

    #[test]
    fn test_mtls_requires_ca() {
        init_crypto_provider();
        let (cert, key) = create_test_cert();
        let config = HttpTlsConfig {
            require_client_cert: true,
            mtls_mode: "required".to_string(),
            ..base_config(&cert, &key)
        };
        let result = load_tls_acceptor(&config);
        assert!(result.is_err());
        let err_msg = format!("{}", result.err().unwrap());
        assert!(
            err_msg.contains("CA") || err_msg.contains("--tls-ca"),
            "Expected error message about CA certificate, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_missing_cert_file_is_reported() {
        init_crypto_provider();
        let (cert, key) = create_test_cert();
        let config = HttpTlsConfig { cert_file: missing_path(), ..base_config(&cert, &key) };

        let acceptor_err =
            load_tls_acceptor(&config).err().expect("missing certificate file must fail");
        assert!(acceptor_err.to_string().contains("Failed to open certificate file"));

        let config_err =
            load_tls_server_config(&config).err().expect("missing certificate file must fail");
        assert!(config_err.to_string().contains("Failed to open certificate file"));
    }

    #[test]
    fn test_empty_cert_file_is_rejected() {
        init_crypto_provider();
        let (_cert, key) = create_test_cert();
        let empty = NamedTempFile::new().unwrap();
        let config = base_config(&empty, &key);

        let acceptor_err =
            load_tls_acceptor(&config).err().expect("empty certificate file must fail");
        assert!(acceptor_err.to_string().contains("No certificates found"));

        let config_err =
            load_tls_server_config(&config).err().expect("empty certificate file must fail");
        assert!(config_err.to_string().contains("No certificates found"));
    }

    #[test]
    fn test_min_version_parsing() {
        assert!(is_tls13_only("1.3"));
        assert!(!is_tls13_only("1.2"));
        assert!(!is_tls13_only(""));
        assert!(!is_tls13_only("1.1"));
    }
}