use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::version::{TLS12, TLS13};
use rustls::{ClientConfig, RootCertStore, ServerConfig};
use std::sync::Arc;
use thiserror::Error;
use tracing::info;
use webpki_roots::TLS_SERVER_ROOTS;
/// Errors reported by the TLS configuration helpers in this file.
///
/// Every payload-carrying variant holds a human-readable description that is
/// interpolated into the `Display` message.
#[derive(Debug, Error)]
pub enum TlsConfigError {
/// A certificate was rejected. Not constructed in the visible part of this file.
#[error("Invalid certificate: {0}")]
InvalidCertificate(String),
/// A private key was rejected. Not constructed in the visible part of this file.
#[error("Invalid private key: {0}")]
InvalidPrivateKey(String),
/// rustls refused to build a configuration; carries the stringified rustls error.
#[error("TLS configuration error: {0}")]
ConfigError(String),
/// No usable cipher suite was available. Not constructed in the visible part of this file.
#[error("No supported cipher suites")]
NoCipherSuites,
/// A host name could not be converted into a `ServerName`.
#[error("Invalid server name: {0}")]
InvalidServerName(String),
}
/// Set of TLS protocol versions a configuration is allowed to negotiate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVersion {
    /// Allow both TLS 1.2 and TLS 1.3.
    Tls12And13,
    /// Allow TLS 1.3 only.
    Tls13Only,
}

impl Default for TlsVersion {
    /// Defaults to the permissive setting, [`TlsVersion::Tls12And13`].
    fn default() -> Self {
        TlsVersion::Tls12And13
    }
}
/// TLS configuration for the client-facing side of the connection.
///
/// Despite the name, this wraps a rustls `ServerConfig`: it is the
/// configuration used when this process acts as the TLS server towards
/// connecting clients.
pub struct ClientTlsConfig {
// Shared handle so the configuration can be handed out cheaply via `server_config()`.
config: Arc<ServerConfig>,
}
impl ClientTlsConfig {
    /// Builds a client-facing configuration using the default [`TlsVersion`].
    ///
    /// # Errors
    ///
    /// Returns [`TlsConfigError::ConfigError`] when rustls rejects the
    /// certificate chain / private key pair.
    pub fn new(
        cert_chain: Vec<CertificateDer<'static>>,
        private_key: PrivateKeyDer<'static>,
    ) -> Result<Self, TlsConfigError> {
        let tls_version = TlsVersion::default();
        Self::new_with_options(cert_chain, private_key, tls_version)
    }

    /// Builds a client-facing configuration restricted to `tls_version`.
    ///
    /// The resulting server configuration requests no client certificate,
    /// serves the single supplied certificate chain, and advertises ALPN
    /// `h2` followed by `http/1.1`.
    ///
    /// # Errors
    ///
    /// Returns [`TlsConfigError::ConfigError`] carrying the stringified rustls
    /// error when the certificate chain / private key pair is rejected.
    pub fn new_with_options(
        cert_chain: Vec<CertificateDer<'static>>,
        private_key: PrivateKeyDer<'static>,
        tls_version: TlsVersion,
    ) -> Result<Self, TlsConfigError> {
        let enabled_versions = match tls_version {
            TlsVersion::Tls13Only => vec![&TLS13],
            TlsVersion::Tls12And13 => vec![&TLS12, &TLS13],
        };

        let builder = ServerConfig::builder_with_protocol_versions(&enabled_versions)
            .with_no_client_auth();
        let mut server_config = match builder.with_single_cert(cert_chain, private_key) {
            Ok(built) => built,
            Err(err) => return Err(TlsConfigError::ConfigError(err.to_string())),
        };
        server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        info!(
            tls_version = ?tls_version,
            alpn = ?server_config.alpn_protocols,
            "Client-facing TLS config created"
        );

        let config = Arc::new(server_config);
        Ok(Self { config })
    }

    /// Returns another shared handle to the underlying rustls `ServerConfig`.
    pub fn server_config(&self) -> Arc<ServerConfig> {
        self.config.clone()
    }
}
/// TLS configuration for outgoing (upstream) connections.
///
/// Wraps a rustls `ClientConfig`: it is the configuration used when this
/// process acts as the TLS client.
pub struct UpstreamTlsConfig {
// Shared handle so the configuration can be handed out cheaply via `client_config()`.
config: Arc<ClientConfig>,
}
impl UpstreamTlsConfig {
    /// Builds an upstream configuration using the default [`TlsVersion`].
    ///
    /// # Errors
    ///
    /// See [`Self::new_with_options`].
    pub fn new() -> Result<Self, TlsConfigError> {
        let tls_version = TlsVersion::default();
        Self::new_with_options(tls_version)
    }

    /// Builds an upstream configuration restricted to `tls_version`.
    ///
    /// Server certificates are checked against the trust anchors bundled in
    /// `webpki_roots::TLS_SERVER_ROOTS`; no client certificate is presented.
    ///
    /// # Errors
    ///
    /// The current implementation has no failing path and always returns
    /// `Ok`; the `Result` is part of the signature for callers.
    pub fn new_with_options(tls_version: TlsVersion) -> Result<Self, TlsConfigError> {
        let enabled_versions = match tls_version {
            TlsVersion::Tls13Only => vec![&TLS13],
            TlsVersion::Tls12And13 => vec![&TLS12, &TLS13],
        };

        let mut trust_anchors = RootCertStore::empty();
        trust_anchors.extend(TLS_SERVER_ROOTS.iter().cloned());

        let client_config = ClientConfig::builder_with_protocol_versions(&enabled_versions)
            .with_root_certificates(trust_anchors)
            .with_no_client_auth();

        info!(
            tls_version = ?tls_version,
            roots_count = TLS_SERVER_ROOTS.len(),
            "Upstream TLS config created"
        );

        let config = Arc::new(client_config);
        Ok(Self { config })
    }

    /// Returns another shared handle to the underlying rustls `ClientConfig`.
    pub fn client_config(&self) -> Arc<ClientConfig> {
        self.config.clone()
    }
}
impl Default for UpstreamTlsConfig {
fn default() -> Self {
Self::new().expect("Failed to create default UpstreamTlsConfig")
}
}
/// Collects TLS options and builds either side's configuration from them.
///
/// The defaults set by `TlsConfigBuilder::new` are: TLS 1.2 and 1.3, ALPN
/// `h2` then `http/1.1`, SNI enabled, hostname verification enabled.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
// Protocol versions the built configurations may negotiate.
tls_version: TlsVersion,
// Requested ALPN protocol identifiers, each as raw bytes.
alpn_protocols: Vec<Vec<u8>>,
// Requested SNI setting.
enable_sni: bool,
// Requested hostname-verification setting.
verify_hostname: bool,
}
impl TlsConfigBuilder {
pub fn new() -> Self {
Self {
tls_version: TlsVersion::Tls12And13,
alpn_protocols: vec![b"h2".to_vec(), b"http/1.1".to_vec()],
enable_sni: true,
verify_hostname: true,
}
}
pub fn tls_version(mut self, version: TlsVersion) -> Self {
self.tls_version = version;
self
}
pub fn alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
self.alpn_protocols = protocols;
self
}
pub fn enable_sni(mut self, enable: bool) -> Self {
self.enable_sni = enable;
self
}
pub fn verify_hostname(mut self, verify: bool) -> Self {
self.verify_hostname = verify;
self
}
pub fn build_upstream(&self) -> Result<UpstreamTlsConfig, TlsConfigError> {
UpstreamTlsConfig::new_with_options(self.tls_version)
}
pub fn build_client_facing(
&self,
cert_chain: Vec<CertificateDer<'static>>,
private_key: PrivateKeyDer<'static>,
) -> Result<ClientTlsConfig, TlsConfigError> {
ClientTlsConfig::new_with_options(cert_chain, private_key, self.tls_version)
}
}
impl Default for TlsConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Bundle of TLS hardening switches.
///
/// NOTE(review): nothing in the visible part of this file reads these fields,
/// so the per-field notes below are inferred from the names only — confirm
/// against the code that enforces them.
#[derive(Debug, Clone)]
pub struct TlsHardeningConfig {
// Presumably: refuse protocol versions older than `min_version`.
pub reject_old_tls: bool,
// Protocol version set treated as the minimum.
pub min_version: TlsVersion,
// Presumably: require that ALPN is negotiated.
pub require_alpn: bool,
// Presumably: require certificate hostname verification.
pub enforce_hostname_verification: bool,
}
impl Default for TlsHardeningConfig {
    /// Default profile: old TLS rejected, TLS 1.2 and 1.3 as the version set,
    /// ALPN not required, hostname verification enforced.
    fn default() -> Self {
        let min_version = TlsVersion::Tls12And13;
        Self {
            min_version,
            reject_old_tls: true,
            enforce_hostname_verification: true,
            require_alpn: false,
        }
    }
}
impl TlsHardeningConfig {
    /// Strict profile: TLS 1.3 only, with every switch turned on.
    pub fn strict() -> Self {
        let min_version = TlsVersion::Tls13Only;
        Self {
            min_version,
            reject_old_tls: true,
            enforce_hostname_verification: true,
            require_alpn: true,
        }
    }

    /// Compatible profile; identical to the `Default` implementation.
    pub fn compatible() -> Self {
        <Self as Default>::default()
    }
}
/// Namespace for host-name helpers; carries no data.
pub struct SniUtils;
impl SniUtils {
/// Converts `hostname` into an owned `ServerName`.
///
/// # Errors
///
/// Returns [`TlsConfigError::InvalidServerName`] carrying the conversion
/// error's message when `hostname` is rejected.
pub fn parse_server_name(hostname: &str) -> Result<ServerName<'static>, TlsConfigError> {
    match ServerName::try_from(hostname.to_owned()) {
        Ok(name) => Ok(name),
        Err(err) => Err(TlsConfigError::InvalidServerName(err.to_string())),
    }
}
/// Returns `true` when `hostname` is a syntactically plausible host name or
/// an IP address literal.
///
/// A host name is accepted when all of the following hold:
///
/// * it is non-empty and at most 253 bytes long;
/// * it has no empty label, i.e. no leading, trailing or doubled `.`;
/// * every label is at most 63 bytes long;
/// * every label consists only of ASCII letters, digits, `-` and `_`
///   (underscores appear in real-world DNS names such as service records);
/// * no label starts or ends with `-`.
///
/// IPv4 and IPv6 literals (e.g. `192.168.1.1`, `::1`) are accepted as well.
/// This is a purely syntactic check: no DNS lookup is performed.
pub fn validate_hostname(hostname: &str) -> bool {
    const MAX_NAME_LEN: usize = 253;
    const MAX_LABEL_LEN: usize = 63;

    if hostname.is_empty() || hostname.len() > MAX_NAME_LEN {
        return false;
    }

    // IPv6 literals contain ':' and would fail the label rules below, so
    // recognise IP addresses up front.
    if hostname.parse::<std::net::IpAddr>().is_ok() {
        return true;
    }

    hostname.split('.').all(|label| {
        !label.is_empty()
            && label.len() <= MAX_LABEL_LEN
            && !label.starts_with('-')
            && !label.ends_with('-')
            && label
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // The default version set must stay permissive (TLS 1.2 and 1.3).
    #[test]
    fn test_tls_version_default() {
        let version = TlsVersion::default();
        assert_eq!(version, TlsVersion::Tls12And13);
    }

    // Default hardening profile.
    #[test]
    fn test_hardening_config_defaults() {
        let TlsHardeningConfig {
            reject_old_tls,
            min_version,
            enforce_hostname_verification,
            ..
        } = TlsHardeningConfig::default();
        assert!(reject_old_tls);
        assert_eq!(min_version, TlsVersion::Tls12And13);
        assert!(enforce_hostname_verification);
    }

    // Strict hardening profile turns everything on and pins TLS 1.3.
    #[test]
    fn test_hardening_config_strict() {
        let TlsHardeningConfig {
            reject_old_tls,
            min_version,
            require_alpn,
            enforce_hostname_verification,
        } = TlsHardeningConfig::strict();
        assert!(reject_old_tls);
        assert_eq!(min_version, TlsVersion::Tls13Only);
        assert!(require_alpn);
        assert!(enforce_hostname_verification);
    }

    // Table-driven check of accepted and rejected host names.
    #[test]
    fn test_sni_validate_hostname() {
        let accepted = ["example.com", "sub.example.com", "a.b.c.example.com"];
        for name in accepted.iter() {
            assert!(SniUtils::validate_hostname(name), "{:?} should be accepted", name);
        }

        let rejected = ["", ".example.com", "example.com."];
        for name in rejected.iter() {
            assert!(!SniUtils::validate_hostname(name), "{:?} should be rejected", name);
        }
    }

    // DNS names and IP literals parse; the empty string does not.
    #[test]
    fn test_sni_parse_server_name() {
        assert!(SniUtils::parse_server_name("example.com").is_ok());

        let ip_literal = SniUtils::parse_server_name("192.168.1.1");
        assert!(
            ip_literal.is_ok(),
            "rustls 0.22+ accepts IP addresses in ServerName"
        );

        let empty = SniUtils::parse_server_name("");
        assert!(empty.is_err(), "Empty hostname should fail");
    }

    // A fresh builder carries the documented defaults.
    #[test]
    fn test_tls_config_builder_defaults() {
        let TlsConfigBuilder {
            tls_version,
            alpn_protocols,
            enable_sni,
            verify_hostname,
        } = TlsConfigBuilder::new();
        assert_eq!(tls_version, TlsVersion::Tls12And13);
        assert_eq!(alpn_protocols.len(), 2);
        assert!(enable_sni);
        assert!(verify_hostname);
    }

    // Setters overwrite the corresponding fields.
    #[test]
    fn test_tls_config_builder_customization() {
        let customized = TlsConfigBuilder::new()
            .tls_version(TlsVersion::Tls13Only)
            .enable_sni(false)
            .verify_hostname(false);
        assert_eq!(customized.tls_version, TlsVersion::Tls13Only);
        assert!(!customized.enable_sni);
        assert!(!customized.verify_hostname);
    }

    // Both constructors of the upstream configuration succeed.
    #[test]
    fn test_upstream_tls_config_creation() {
        assert!(UpstreamTlsConfig::new().is_ok());
        assert!(UpstreamTlsConfig::new_with_options(TlsVersion::Tls13Only).is_ok());
    }
}