use crate::{Result, WireError};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::RootCertStore;
use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme};
use rustls_pemfile::Item;
use std::fmt::Debug;
use std::fs;
use std::sync::Arc;
#[derive(Clone)]
pub struct TlsConfig {
ca_cert_path: Option<String>,
verify_hostname: bool,
danger_accept_invalid_certs: bool,
danger_accept_invalid_hostnames: bool,
client_config: Arc<ClientConfig>,
}
impl TlsConfig {
pub fn builder() -> TlsConfigBuilder {
TlsConfigBuilder::default()
}
pub fn client_config(&self) -> Arc<ClientConfig> {
self.client_config.clone()
}
pub const fn verify_hostname(&self) -> bool {
self.verify_hostname
}
pub const fn danger_accept_invalid_certs(&self) -> bool {
self.danger_accept_invalid_certs
}
pub const fn danger_accept_invalid_hostnames(&self) -> bool {
self.danger_accept_invalid_hostnames
}
}
impl std::fmt::Debug for TlsConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TlsConfig")
.field("ca_cert_path", &self.ca_cert_path)
.field("verify_hostname", &self.verify_hostname)
.field(
"danger_accept_invalid_certs",
&self.danger_accept_invalid_certs,
)
.field(
"danger_accept_invalid_hostnames",
&self.danger_accept_invalid_hostnames,
)
.field("client_config", &"<ClientConfig>")
.finish()
}
}
/// Builder for [`TlsConfig`].
///
/// The defaults are the secure choice: hostname verification requested, no
/// certificate or hostname bypass, and no custom CA file.
#[must_use = "call .build() to construct the final value"]
pub struct TlsConfigBuilder {
    /// Optional PEM file with the trusted root certificates.
    ca_cert_path: Option<String>,
    /// Whether hostname verification is requested.
    verify_hostname: bool,
    /// Disables server-certificate verification when `true`.
    danger_accept_invalid_certs: bool,
    /// Requests tolerance of hostname mismatches when `true`.
    danger_accept_invalid_hostnames: bool,
}

impl Default for TlsConfigBuilder {
    fn default() -> Self {
        // Only hostname verification is on by default; every bypass is off.
        Self {
            verify_hostname: true,
            ca_cert_path: None,
            danger_accept_invalid_hostnames: false,
            danger_accept_invalid_certs: false,
        }
    }
}
impl TlsConfigBuilder {
pub fn ca_cert_path(mut self, path: impl Into<String>) -> Self {
self.ca_cert_path = Some(path.into());
self
}
pub const fn verify_hostname(mut self, verify: bool) -> Self {
self.verify_hostname = verify;
self
}
pub const fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
self.danger_accept_invalid_certs = accept;
self
}
pub const fn danger_accept_invalid_hostnames(mut self, accept: bool) -> Self {
self.danger_accept_invalid_hostnames = accept;
self
}
pub fn build(self) -> Result<TlsConfig> {
validate_tls_security(self.danger_accept_invalid_certs)?;
let client_config = if self.danger_accept_invalid_certs {
let verifier = Arc::new(NoVerifier);
Arc::new(
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)
.with_no_client_auth(),
)
} else {
let root_store = if let Some(ca_path) = &self.ca_cert_path {
self.load_custom_ca(ca_path)?
} else {
let result = rustls_native_certs::load_native_certs();
let mut store = RootCertStore::empty();
for cert in result.certs {
let _ = store.add_parsable_certificates(std::iter::once(cert));
}
if !result.errors.is_empty() && store.is_empty() {
return Err(WireError::Config(
"Failed to load any system root certificates".to_string(),
));
}
store
};
Arc::new(
ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
)
};
Ok(TlsConfig {
ca_cert_path: self.ca_cert_path,
verify_hostname: self.verify_hostname,
danger_accept_invalid_certs: self.danger_accept_invalid_certs,
danger_accept_invalid_hostnames: self.danger_accept_invalid_hostnames,
client_config,
})
}
fn load_custom_ca(&self, ca_path: &str) -> Result<RootCertStore> {
let ca_cert_data = fs::read(ca_path).map_err(|e| {
WireError::Config(format!(
"Failed to read CA certificate file '{}': {}",
ca_path, e
))
})?;
let mut reader = std::io::Cursor::new(&ca_cert_data);
let mut root_store = RootCertStore::empty();
let mut found_certs = 0;
loop {
match rustls_pemfile::read_one(&mut reader) {
Ok(Some(Item::X509Certificate(cert))) => {
let _ = root_store.add_parsable_certificates(std::iter::once(cert));
found_certs += 1;
}
Ok(Some(_)) => {
}
Ok(None) => {
break;
}
Err(_) => {
return Err(WireError::Config(format!(
"Failed to parse CA certificate from '{}'",
ca_path
)));
}
}
}
if found_certs == 0 {
return Err(WireError::Config(format!(
"No valid certificates found in '{}'",
ca_path
)));
}
Ok(root_store)
}
}
fn validate_tls_security(danger_accept_invalid_certs: bool) -> Result<()> {
if danger_accept_invalid_certs {
#[cfg(not(debug_assertions))]
return Err(WireError::Config(
"TLS certificate validation bypass not permitted in release builds".into(),
));
#[cfg(debug_assertions)]
{
tracing::warn!("TLS certificate validation is DISABLED (development only)");
tracing::warn!("This mode is only for development with self-signed certificates");
}
}
Ok(())
}
/// Validates `hostname` as a DNS name for TLS and returns it with any trailing
/// dots removed.
///
/// An accepted name satisfies all of the following:
/// - It is 1–253 bytes long after trimming.
/// - It consists of dot-separated labels.
/// - Each label is 1–63 ASCII letters, digits or `-`.
/// - No label starts or ends with `-`.
///
/// Non-ASCII (internationalised) names must be passed in punycode (`xn--…`)
/// form. A `host:port` string and IPv6 literals are rejected because `:` is not
/// allowed. IPv4 literals pass because their octets are valid labels.
///
/// # Errors
///
/// Returns `WireError::Config` when the name violates any of the rules above.
pub fn parse_server_name(hostname: &str) -> Result<String> {
    const MAX_NAME_LEN: usize = 253;
    const MAX_LABEL_LEN: usize = 63;

    let hostname = hostname.trim_end_matches('.');
    let label_is_valid = |label: &str| {
        !label.is_empty()
            && label.len() <= MAX_LABEL_LEN
            && !label.starts_with('-')
            && !label.ends_with('-')
            // ASCII only: `char::is_alphanumeric` would also admit letters such
            // as 'é' that are not valid in a DNS name.
            && label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
    };
    let is_valid = !hostname.is_empty()
        && hostname.len() <= MAX_NAME_LEN
        && hostname.split('.').all(label_is_valid);
    if !is_valid {
        return Err(WireError::Config(format!(
            "Invalid hostname for TLS: '{}'",
            hostname
        )));
    }
    Ok(hostname.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs the ring crypto provider as the process default. The error
    /// returned when another test already installed one is deliberately ignored.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    #[test]
    fn test_tls_config_builder_defaults() {
        let tls = TlsConfigBuilder::default();
        assert!(!tls.danger_accept_invalid_certs);
        assert!(!tls.danger_accept_invalid_hostnames);
        assert!(tls.verify_hostname);
        assert!(tls.ca_cert_path.is_none());
    }

    #[test]
    fn test_tls_config_builder_with_hostname_verification() {
        install_crypto_provider();
        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");
        assert!(tls.verify_hostname());
        assert!(!tls.danger_accept_invalid_certs());
    }

    #[test]
    fn test_tls_config_builder_with_custom_ca() {
        install_crypto_provider();

        // An unreadable CA path must surface as an error.
        let missing = TlsConfig::builder()
            .ca_cert_path("/nonexistent/wire-tls-test/missing-ca.pem")
            .build();
        assert!(missing.is_err(), "a missing CA file must be rejected");

        // A readable file that holds no PEM certificate must also be rejected.
        let path = std::env::temp_dir().join(format!(
            "wire-tls-test-no-certs-{}.pem",
            std::process::id()
        ));
        std::fs::write(&path, b"this file contains no PEM certificates\n")
            .expect("failed to write temporary CA file");
        let empty = TlsConfig::builder()
            .ca_cert_path(path.to_string_lossy())
            .build();
        let _ = std::fs::remove_file(&path);
        assert!(
            empty.is_err(),
            "a CA file without certificates must be rejected"
        );
    }

    #[test]
    fn test_parse_server_name_valid() {
        for name in ["localhost", "example.com", "db.internal.example.com"] {
            let parsed = parse_server_name(name)
                .unwrap_or_else(|e| panic!("{name} should be a valid server name: {e:?}"));
            assert_eq!(parsed, name);
        }
    }

    #[test]
    fn test_parse_server_name_trailing_dot() {
        let name = parse_server_name("example.com.")
            .expect("trailing dot should be accepted as valid server name");
        assert_eq!(name, "example.com", "trailing dot must be stripped");
    }

    #[test]
    fn test_parse_server_name_with_port() {
        assert!(
            parse_server_name("example.com:5432").is_err(),
            "a host:port string is not a server name"
        );
    }

    #[test]
    fn test_tls_config_debug() {
        install_crypto_provider();
        let tls = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("Failed to build TLS config");
        let debug_str = format!("{:?}", tls);
        assert!(debug_str.contains("TlsConfig"));
        assert!(debug_str.contains("verify_hostname"));
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn test_danger_mode_returns_error_in_release_build() {
        // No crypto provider needed: validation fails before any ClientConfig is built.
        let result = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build();
        assert!(
            result.is_err(),
            "danger mode must be rejected in release builds"
        );
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("not permitted in release builds"),
            "error message must explain the restriction",
        );
    }

    #[test]
    // Without this gate the test fails under `cargo test --release`, where
    // danger mode is (correctly) rejected.
    #[cfg(debug_assertions)]
    fn test_danger_mode_allowed_in_debug_build() {
        install_crypto_provider();
        let config = TlsConfig::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("danger mode should be allowed in debug builds");
        assert!(config.danger_accept_invalid_certs());
    }

    #[test]
    fn test_normal_tls_config_works() {
        install_crypto_provider();
        let config = TlsConfig::builder()
            .verify_hostname(true)
            .build()
            .expect("normal TLS config should build successfully");
        assert!(!config.danger_accept_invalid_certs());
    }
}
/// Server-certificate verifier that accepts every certificate and every
/// handshake signature without checking anything.
///
/// It is installed by `TlsConfigBuilder::build` only when
/// `danger_accept_invalid_certs` is enabled, a mode that
/// `validate_tls_security` restricts to builds with `debug_assertions`.
#[derive(Debug)]
struct NoVerifier;

impl ServerCertVerifier for NoVerifier {
    /// Accepts the presented certificate chain unconditionally.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
        let verified = ServerCertVerified::assertion();
        Ok(verified)
    }

    /// Accepts the TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
        let valid = HandshakeSignatureValid::assertion();
        Ok(valid)
    }

    /// Accepts the TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
        let valid = HandshakeSignatureValid::assertion();
        Ok(valid)
    }

    /// Signature schemes this verifier reports as supported, in a fixed order.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        const SCHEMES: [SignatureScheme; 10] = [
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ED25519,
        ];
        SCHEMES.to_vec()
    }
}