use std::fmt;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use parking_lot::RwLock;
/// Errors raised while configuring TLS, loading key material, or performing a handshake.
#[derive(Debug, Clone)]
pub enum TlsError {
    /// A certificate file could not be read; carries the path and the underlying reason.
    CertificateNotFound(String),
    /// A private-key file could not be read; carries the path and the underlying reason.
    KeyNotFound(String),
    /// Certificate bytes were present but could not be interpreted.
    InvalidCertificate(String),
    /// Private-key bytes were present but could not be interpreted.
    InvalidKey(String),
    /// A certificate failed validation.
    CertificateValidation(String),
    /// The TLS handshake did not complete.
    HandshakeFailed(String),
    /// The certificate's validity period has ended.
    CertificateExpired,
    /// The certificate's validity period has not started yet.
    CertificateNotYetValid,
    /// The requested SNI hostname did not match; carries the hostname.
    SniMismatch(String),
    /// The client failed certificate authentication.
    ClientAuthFailed(String),
    /// The TLS configuration is inconsistent or incomplete.
    Configuration(String),
    /// An I/O failure, stored as text so the error stays `Clone`.
    Io(String),
}

impl fmt::Display for TlsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each message is a fixed prefix, optionally followed by ": <detail>".
        let (prefix, detail): (&str, Option<&str>) = match self {
            Self::CertificateNotFound(path) => ("Certificate not found", Some(path.as_str())),
            Self::KeyNotFound(path) => ("Private key not found", Some(path.as_str())),
            Self::InvalidCertificate(msg) => ("Invalid certificate", Some(msg.as_str())),
            Self::InvalidKey(msg) => ("Invalid private key", Some(msg.as_str())),
            Self::CertificateValidation(msg) => {
                ("Certificate validation failed", Some(msg.as_str()))
            }
            Self::HandshakeFailed(msg) => ("TLS handshake failed", Some(msg.as_str())),
            Self::CertificateExpired => ("Certificate has expired", None),
            Self::CertificateNotYetValid => ("Certificate is not yet valid", None),
            Self::SniMismatch(hostname) => ("SNI hostname mismatch", Some(hostname.as_str())),
            Self::ClientAuthFailed(msg) => ("Client authentication failed", Some(msg.as_str())),
            Self::Configuration(msg) => ("TLS configuration error", Some(msg.as_str())),
            Self::Io(msg) => ("I/O error", Some(msg.as_str())),
        };
        match detail {
            Some(detail) => write!(f, "{}: {}", prefix, detail),
            None => f.write_str(prefix),
        }
    }
}

impl std::error::Error for TlsError {}

/// Shorthand for results whose error type is [`TlsError`].
pub type TlsResult<T> = Result<T, TlsError>;
/// TLS protocol versions understood by this module.
///
/// Variants are declared oldest first, so the derived `Ord` orders them by protocol age
/// (`Tls12 < Tls13`). Callers can use this to check a `min_version`/`max_version` range.
/// `Default` is derived via `#[default]`, matching how `ClientAuth` declares its default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3; the default.
    #[default]
    Tls13,
}

impl TlsVersion {
    /// Human-readable protocol name: `"TLS 1.2"` or `"TLS 1.3"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Tls12 => "TLS 1.2",
            Self::Tls13 => "TLS 1.3",
        }
    }
}
/// Server-side policy for client certificates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ClientAuth {
    /// Client certificates are not verified (the default).
    #[default]
    None,
    /// Client certificates are verified but not mandatory.
    Optional,
    /// A client certificate is mandatory (mutual TLS).
    Required,
}

impl ClientAuth {
    /// `true` for `Optional` and `Required`, i.e. whenever client certificates are verified.
    pub fn verifies_client(&self) -> bool {
        match self {
            Self::None => false,
            Self::Optional | Self::Required => true,
        }
    }

    /// `true` only for `Required`.
    pub fn requires_client(&self) -> bool {
        *self == Self::Required
    }
}
/// Metadata extracted from an X.509 certificate.
#[derive(Debug, Clone)]
pub struct CertificateInfo {
    /// Subject common name, if present.
    pub subject_cn: Option<String>,
    /// Issuer common name, if present.
    pub issuer_cn: Option<String>,
    /// DNS names from the subjectAltName extension; may contain `*.` wildcards.
    pub san_dns: Vec<String>,
    /// IP addresses from the subjectAltName extension, compared as exact strings.
    pub san_ips: Vec<String>,
    /// Start of the validity period; `None` means unbounded.
    pub not_before: Option<SystemTime>,
    /// End of the validity period; `None` means unknown.
    pub not_after: Option<SystemTime>,
    /// Serial number as text (format not established here).
    pub serial_number: String,
    /// Whether the certificate is self-signed.
    pub is_self_signed: bool,
    /// Whether the certificate is a CA certificate.
    pub is_ca: bool,
    /// Key-usage names (format not established here).
    pub key_usage: Vec<String>,
    /// Extended-key-usage names (format not established here).
    pub extended_key_usage: Vec<String>,
}

impl CertificateInfo {
    /// Returns a record with no names, no validity bounds and all flags cleared.
    pub fn empty() -> Self {
        Self {
            subject_cn: None,
            issuer_cn: None,
            san_dns: Vec::new(),
            san_ips: Vec::new(),
            not_before: None,
            not_after: None,
            serial_number: String::new(),
            is_self_signed: false,
            is_ca: false,
            key_usage: Vec::new(),
            extended_key_usage: Vec::new(),
        }
    }

    /// Returns `true` if the current time lies within `[not_before, not_after]`.
    ///
    /// An unset bound does not restrict validity, so [`CertificateInfo::empty`] is valid.
    pub fn is_valid(&self) -> bool {
        let now = SystemTime::now();
        let started = self.not_before.map_or(true, |not_before| now >= not_before);
        let not_ended = self.not_after.map_or(true, |not_after| now <= not_after);
        started && not_ended
    }

    /// Time left until `not_after`.
    ///
    /// Returns `None` when `not_after` is unset or already in the past.
    pub fn remaining_validity(&self) -> Option<Duration> {
        let now = SystemTime::now();
        self.not_after
            .and_then(|not_after| not_after.duration_since(now).ok())
    }

    /// Returns `true` if less than `duration` of validity remains.
    ///
    /// Also returns `true` when the certificate has already expired or when `not_after` is
    /// unknown. Unknown expiry is treated conservatively as "expiring".
    pub fn expires_within(&self, duration: Duration) -> bool {
        self.remaining_validity()
            .map(|remaining| remaining < duration)
            .unwrap_or(true)
    }

    /// Checks whether this certificate covers `hostname`.
    ///
    /// DNS candidates are the subject CN and every SAN DNS entry. They match exactly or via a
    /// single-label `*.` wildcard. Comparison is ASCII case-insensitive and ignores a trailing
    /// root dot, per RFC 6125 §6.4. SAN IP entries must equal `hostname` exactly. An empty
    /// hostname never matches.
    pub fn matches_hostname(&self, hostname: &str) -> bool {
        // A fully-qualified name may carry the root dot ("example.com."); it is not significant.
        let host = hostname.strip_suffix('.').unwrap_or(hostname);
        if host.is_empty() {
            return false;
        }
        let dns_matches = |name: &str| {
            let name = name.strip_suffix('.').unwrap_or(name);
            name.eq_ignore_ascii_case(host) || Self::wildcard_match(name, host)
        };
        let dns_hit = self
            .subject_cn
            .iter()
            .chain(&self.san_dns)
            .any(|name| dns_matches(name.as_str()));
        dns_hit || self.san_ips.iter().any(|ip| ip == hostname)
    }

    /// Matches `hostname` against a `*.suffix` pattern.
    ///
    /// The wildcard stands for exactly one non-empty leftmost label. Comparison is ASCII
    /// case-insensitive. So `*.test.com` matches `api.test.com`, but not `.test.com`,
    /// `test.com` or `a.b.test.com`.
    fn wildcard_match(pattern: &str, hostname: &str) -> bool {
        let suffix = match pattern.strip_prefix("*.") {
            Some(suffix) if !suffix.is_empty() => suffix,
            _ => return false,
        };
        match hostname.split_once('.') {
            Some((label, rest)) => !label.is_empty() && rest.eq_ignore_ascii_case(suffix),
            None => false,
        }
    }
}
/// TLS settings shared by servers (`TlsAcceptor`) and clients (`TlsConnector`).
///
/// Prefer [`TlsConfig::server`] / [`TlsConfig::client`] to obtain a validated instance.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Inline certificate-chain bytes.
    pub cert_chain: Option<Vec<u8>>,
    /// Path to a certificate-chain file.
    pub cert_path: Option<PathBuf>,
    /// Inline private-key bytes.
    pub private_key: Option<Vec<u8>>,
    /// Path to a private-key file.
    pub key_path: Option<PathBuf>,
    /// Inline CA certificate bytes.
    pub ca_certs: Option<Vec<u8>>,
    /// Path to a CA certificate file.
    pub ca_path: Option<PathBuf>,
    /// Lowest protocol version allowed (default TLS 1.2).
    pub min_version: TlsVersion,
    /// Highest protocol version allowed (default TLS 1.3).
    pub max_version: TlsVersion,
    /// Client-certificate policy (default `ClientAuth::None`).
    pub client_auth: ClientAuth,
    /// ALPN protocol identifiers (default `["h2", "http/1.1"]`).
    pub alpn_protocols: Vec<String>,
    /// SNI hostnames configured through the builder.
    pub sni_hostnames: Vec<String>,
    /// Session-ticket lifetime (default 24 hours).
    pub session_ticket_lifetime: Duration,
    /// Whether OCSP stapling is enabled (default off).
    pub ocsp_stapling: bool,
    /// Reload-check interval handed to the acceptor's certificate store (default 1 hour).
    pub rotation_check_interval: Duration,
    /// `true` for server configurations, `false` for clients (default).
    pub is_server: bool,
}

impl Default for TlsConfig {
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        let alpn_protocols = ["h2", "http/1.1"].iter().map(|p| p.to_string()).collect();
        Self {
            is_server: false,
            min_version: TlsVersion::Tls12,
            max_version: TlsVersion::Tls13,
            client_auth: ClientAuth::None,
            alpn_protocols,
            sni_hostnames: Vec::new(),
            session_ticket_lifetime: ONE_DAY,
            ocsp_stapling: false,
            rotation_check_interval: ONE_HOUR,
            cert_chain: None,
            cert_path: None,
            private_key: None,
            key_path: None,
            ca_certs: None,
            ca_path: None,
        }
    }
}

impl TlsConfig {
    /// Starts a builder for a server configuration; `build` then demands a certificate and key.
    pub fn server() -> TlsConfigBuilder {
        let is_server = true;
        TlsConfigBuilder::new(is_server)
    }

    /// Starts a builder for a client configuration.
    pub fn client() -> TlsConfigBuilder {
        let is_server = false;
        TlsConfigBuilder::new(is_server)
    }

    /// `true` if this is a server configuration.
    pub fn is_server(&self) -> bool {
        self.is_server
    }

    /// `true` if this is a client configuration.
    pub fn is_client(&self) -> bool {
        !self.is_server()
    }

    /// `true` when client certificates are mandatory (`ClientAuth::Required`).
    pub fn is_mtls(&self) -> bool {
        let auth = self.client_auth;
        auth.requires_client()
    }
}
pub struct TlsConfigBuilder {
config: TlsConfig,
}
impl TlsConfigBuilder {
fn new(is_server: bool) -> Self {
Self {
config: TlsConfig {
is_server,
..Default::default()
},
}
}
pub fn with_cert(mut self, cert: Vec<u8>) -> Self {
self.config.cert_chain = Some(cert);
self
}
pub fn with_cert_file(mut self, path: impl Into<PathBuf>) -> Self {
self.config.cert_path = Some(path.into());
self
}
pub fn with_key(mut self, key: Vec<u8>) -> Self {
self.config.private_key = Some(key);
self
}
pub fn with_key_file(mut self, path: impl Into<PathBuf>) -> Self {
self.config.key_path = Some(path.into());
self
}
pub fn with_ca_cert(mut self, ca: Vec<u8>) -> Self {
self.config.ca_certs = Some(ca);
self
}
pub fn with_ca_file(mut self, path: impl Into<PathBuf>) -> Self {
self.config.ca_path = Some(path.into());
self
}
pub fn with_min_version(mut self, version: TlsVersion) -> Self {
self.config.min_version = version;
self
}
pub fn with_max_version(mut self, version: TlsVersion) -> Self {
self.config.max_version = version;
self
}
pub fn with_client_auth(mut self, auth: ClientAuth) -> Self {
self.config.client_auth = auth;
self
}
pub fn with_mtls(mut self) -> Self {
self.config.client_auth = ClientAuth::Required;
self
}
pub fn with_alpn(mut self, protocols: Vec<String>) -> Self {
self.config.alpn_protocols = protocols;
self
}
pub fn with_sni_hostname(mut self, hostname: impl Into<String>) -> Self {
self.config.sni_hostnames.push(hostname.into());
self
}
pub fn with_session_ticket_lifetime(mut self, lifetime: Duration) -> Self {
self.config.session_ticket_lifetime = lifetime;
self
}
pub fn with_ocsp_stapling(mut self, enable: bool) -> Self {
self.config.ocsp_stapling = enable;
self
}
pub fn with_rotation_check_interval(mut self, interval: Duration) -> Self {
self.config.rotation_check_interval = interval;
self
}
pub fn build(self) -> TlsResult<TlsConfig> {
if self.config.is_server {
if self.config.cert_chain.is_none() && self.config.cert_path.is_none() {
return Err(TlsError::Configuration(
"Server TLS requires a certificate".to_string(),
));
}
if self.config.private_key.is_none() && self.config.key_path.is_none() {
return Err(TlsError::Configuration(
"Server TLS requires a private key".to_string(),
));
}
}
Ok(self.config)
}
}
/// A certificate chain and private key together with parsed metadata and provenance.
///
/// The `Debug` output intentionally leaves out `cert_chain` and `private_key`, so key material
/// is never printed.
#[derive(Clone)]
pub struct CertificateEntry {
    /// Raw bytes of each certificate in the chain.
    pub cert_chain: Vec<Vec<u8>>,
    /// Raw private-key bytes.
    pub private_key: Vec<u8>,
    /// Metadata describing the leaf certificate.
    pub info: CertificateInfo,
    /// When this entry was created.
    pub loaded_at: Instant,
    /// Source file of the certificate, if loaded from disk.
    pub cert_path: Option<PathBuf>,
    /// Source file of the private key, if loaded from disk.
    pub key_path: Option<PathBuf>,
}

impl fmt::Debug for CertificateEntry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only non-secret fields are shown; raw chain and key bytes are skipped on purpose.
        let mut out = f.debug_struct("CertificateEntry");
        out.field("info", &self.info);
        out.field("loaded_at", &self.loaded_at);
        out.field("cert_path", &self.cert_path);
        out.field("key_path", &self.key_path);
        out.finish()
    }
}
/// Holds the active certificate entry, the entry it replaced, and hot-reload settings.
///
/// Every mutable field sits behind a `parking_lot::RwLock`. All methods therefore take `&self`,
/// and the store can be shared through an `Arc`.
pub struct CertificateStore {
    /// Entry currently in use.
    current: RwLock<Option<CertificateEntry>>,
    /// Entry displaced by the most recent installation, if any.
    previous: RwLock<Option<CertificateEntry>>,
    /// Minimum time between reloads performed by `check_reload`.
    check_interval: Duration,
    /// Instant at which `check_reload` last passed its interval check.
    last_check: RwLock<Instant>,
    /// Callback run with each newly installed entry.
    on_rotation: RwLock<Option<Arc<dyn Fn(&CertificateEntry) + Send + Sync>>>,
    /// `is_expiring_soon` reports `true` once remaining validity drops below this.
    expiry_warning_threshold: Duration,
}

impl CertificateStore {
    /// Creates an empty store with a one-hour reload interval and a seven-day expiry warning.
    pub fn new() -> Self {
        Self {
            current: RwLock::new(None),
            previous: RwLock::new(None),
            check_interval: Duration::from_secs(60 * 60),
            last_check: RwLock::new(Instant::now()),
            on_rotation: RwLock::new(None),
            expiry_warning_threshold: Duration::from_secs(7 * 24 * 60 * 60),
        }
    }

    /// Sets the minimum time between reloads attempted by [`check_reload`](Self::check_reload).
    pub fn with_check_interval(mut self, interval: Duration) -> Self {
        self.check_interval = interval;
        self
    }

    /// Sets the threshold used by [`is_expiring_soon`](Self::is_expiring_soon).
    pub fn with_expiry_warning(mut self, threshold: Duration) -> Self {
        self.expiry_warning_threshold = threshold;
        self
    }

    /// Registers a callback that runs after every certificate installation, replacing any
    /// previous callback.
    ///
    /// Installations include initial loads, manual loads and reloads triggered by
    /// `check_reload`.
    pub fn on_rotation<F>(mut self, callback: F) -> Self
    where
        F: Fn(&CertificateEntry) + Send + Sync + 'static,
    {
        // The builder owns `self`, so no other thread can reach the lock; skip acquiring it.
        *self.on_rotation.get_mut() = Some(Arc::new(callback));
        self
    }

    /// Installs certificate and key material supplied in memory.
    ///
    /// The resulting entry has no file paths, so `check_reload` will not refresh it.
    ///
    /// # Errors
    /// Propagates any error from `parse_pem`.
    pub fn load_pem(&self, cert_pem: &[u8], key_pem: &[u8]) -> TlsResult<()> {
        let entry = self.parse_pem(cert_pem, key_pem, None, None)?;
        self.set_certificate(entry);
        Ok(())
    }

    /// Reads certificate and key files and installs them, remembering the paths for reloads.
    ///
    /// # Errors
    /// Returns `CertificateNotFound` / `KeyNotFound` (including the path and I/O error) when a
    /// file cannot be read, or any error from `parse_pem`. On error the installed entry is left
    /// unchanged.
    pub fn load_files(&self, cert_path: &PathBuf, key_path: &PathBuf) -> TlsResult<()> {
        let cert_pem = std::fs::read(cert_path).map_err(|e| {
            TlsError::CertificateNotFound(format!("{}: {}", cert_path.display(), e))
        })?;
        let key_pem = std::fs::read(key_path)
            .map_err(|e| TlsError::KeyNotFound(format!("{}: {}", key_path.display(), e)))?;
        let entry = self.parse_pem(
            &cert_pem,
            &key_pem,
            Some(cert_path.clone()),
            Some(key_path.clone()),
        )?;
        self.set_certificate(entry);
        Ok(())
    }

    /// Builds a [`CertificateEntry`] from PEM bytes.
    ///
    /// NOTE(review): placeholder. The byte arguments are ignored, and the returned entry has
    /// an empty chain, an empty key and `CertificateInfo::empty()`. Real parsing still needs to
    /// be implemented.
    fn parse_pem(
        &self,
        _cert_pem: &[u8],
        _key_pem: &[u8],
        cert_path: Option<PathBuf>,
        key_path: Option<PathBuf>,
    ) -> TlsResult<CertificateEntry> {
        Ok(CertificateEntry {
            cert_chain: Vec::new(),
            private_key: Vec::new(),
            info: CertificateInfo::empty(),
            loaded_at: Instant::now(),
            cert_path,
            key_path,
        })
    }

    /// Makes `entry` current, moves the old current entry (if any) to `previous`, then runs the
    /// rotation callback.
    fn set_certificate(&self, entry: CertificateEntry) {
        {
            // Hold both locks, always acquired previous → current, so concurrent installs are
            // serialized. `previous` then always holds the entry the new one actually replaced.
            let mut previous = self.previous.write();
            let mut current = self.current.write();
            if let Some(old) = current.replace(entry.clone()) {
                *previous = Some(old);
            }
        }
        // Clone the callback handle out of the lock so user code never runs while it is held.
        let callback = self.on_rotation.read().clone();
        if let Some(callback) = callback {
            callback(&entry);
        }
    }

    /// Returns a clone of the current entry.
    pub fn current(&self) -> Option<CertificateEntry> {
        self.current.read().clone()
    }

    /// Returns a clone of the entry replaced by the most recent installation.
    pub fn previous(&self) -> Option<CertificateEntry> {
        self.previous.read().clone()
    }

    /// Re-reads the current entry's files once `check_interval` has elapsed since the last check.
    ///
    /// Returns `Ok(true)` if files were reloaded. Returns `Ok(false)` if the interval has not
    /// elapsed or the current entry has no file paths. A reload happens whether or not the files
    /// changed, so the rotation callback also fires for unchanged files.
    ///
    /// # Errors
    /// Propagates errors from `load_files`. The check timestamp has already advanced by then, so
    /// the next attempt waits a full interval.
    pub fn check_reload(&self) -> TlsResult<bool> {
        let now = Instant::now();
        {
            // Test and update under a single write lock. Otherwise concurrent callers could all
            // see the interval as elapsed and reload the same files several times.
            let mut last = self.last_check.write();
            // Saturating: another caller may have stored an instant later than our `now`.
            if now.saturating_duration_since(*last) < self.check_interval {
                return Ok(false);
            }
            *last = now;
        }
        let paths = self.current.read().as_ref().and_then(|entry| {
            match (&entry.cert_path, &entry.key_path) {
                (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
                _ => None,
            }
        });
        match paths {
            Some((cert_path, key_path)) => {
                self.load_files(&cert_path, &key_path)?;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Returns `true` if the current certificate expires within `expiry_warning_threshold`.
    ///
    /// Returns `false` when no certificate is loaded. A loaded certificate with unknown
    /// `not_after` counts as expiring (see `CertificateInfo::expires_within`).
    pub fn is_expiring_soon(&self) -> bool {
        self.current
            .read()
            .as_ref()
            .map(|e| e.info.expires_within(self.expiry_warning_threshold))
            .unwrap_or(false)
    }

    /// Remaining validity of the current certificate.
    ///
    /// Returns `None` if nothing is loaded, the expiry is unknown, or the certificate has
    /// already expired.
    pub fn remaining_validity(&self) -> Option<Duration> {
        self.current
            .read()
            .as_ref()
            .and_then(|e| e.info.remaining_validity())
    }
}

impl Default for CertificateStore {
    fn default() -> Self {
        Self::new()
    }
}
/// Chooses a [`CertificateStore`] for a connection from its SNI hostname.
///
/// Lookup order:
/// 1. an exact entry;
/// 2. a `*.parent` wildcard entry covering one non-empty leftmost label;
/// 3. the default store.
///
/// Hostnames are compared ASCII case-insensitively and a trailing root dot is ignored, as DNS
/// names are case-insensitive (RFC 4343).
pub struct SniResolver {
    default_store: Arc<CertificateStore>,
    /// Registered stores, keyed by normalized hostname (see `normalize`).
    hostname_stores: RwLock<std::collections::HashMap<String, Arc<CertificateStore>>>,
}

impl SniResolver {
    /// Creates a resolver with no per-hostname entries.
    pub fn new(default_store: Arc<CertificateStore>) -> Self {
        Self {
            default_store,
            hostname_stores: RwLock::new(std::collections::HashMap::new()),
        }
    }

    /// Canonical map key: trailing root dot removed, ASCII letters lower-cased.
    fn normalize(hostname: &str) -> String {
        hostname
            .strip_suffix('.')
            .unwrap_or(hostname)
            .to_ascii_lowercase()
    }

    /// Registers (or replaces) the store for `hostname`, which may be a `*.`-prefixed wildcard.
    pub fn add_hostname(&self, hostname: impl Into<String>, store: Arc<CertificateStore>) {
        let key = Self::normalize(&hostname.into());
        self.hostname_stores.write().insert(key, store);
    }

    /// Removes the entry for `hostname`. Returns `true` if an entry existed.
    pub fn remove_hostname(&self, hostname: &str) -> bool {
        self.hostname_stores
            .write()
            .remove(&Self::normalize(hostname))
            .is_some()
    }

    /// Returns the store that should serve `hostname`.
    pub fn resolve(&self, hostname: &str) -> Arc<CertificateStore> {
        let host = Self::normalize(hostname);
        let stores = self.hostname_stores.read();
        if let Some(store) = stores.get(&host) {
            return Arc::clone(store);
        }
        // A wildcard covers exactly one non-empty label, so ".example.com" must not hit it.
        if let Some((label, parent)) = host.split_once('.') {
            if !label.is_empty() && !parent.is_empty() {
                if let Some(store) = stores.get(&format!("*.{}", parent)) {
                    return Arc::clone(store);
                }
            }
        }
        Arc::clone(&self.default_store)
    }

    /// Registered hostnames, in normalized form and arbitrary order.
    pub fn hostnames(&self) -> Vec<String> {
        let stores = self.hostname_stores.read();
        stores.keys().cloned().collect()
    }

    /// Number of registered hostnames.
    pub fn hostname_count(&self) -> usize {
        self.hostname_stores.read().len()
    }
}
pub struct TlsAcceptor {
config: TlsConfig,
cert_store: Arc<CertificateStore>,
sni_resolver: Option<Arc<SniResolver>>,
stats: TlsAcceptorStats,
}
#[derive(Debug, Default)]
pub struct TlsAcceptorStats {
pub handshakes_attempted: std::sync::atomic::AtomicU64,
pub handshakes_succeeded: std::sync::atomic::AtomicU64,
pub handshakes_failed: std::sync::atomic::AtomicU64,
pub client_auth_failures: std::sync::atomic::AtomicU64,
pub sni_mismatches: std::sync::atomic::AtomicU64,
}
impl TlsAcceptor {
pub fn new(config: TlsConfig) -> TlsResult<Self> {
if !config.is_server {
return Err(TlsError::Configuration(
"TlsAcceptor requires a server configuration".to_string(),
));
}
let cert_store =
Arc::new(CertificateStore::new().with_check_interval(config.rotation_check_interval));
if let (Some(cert_path), Some(key_path)) = (&config.cert_path, &config.key_path) {
cert_store.load_files(cert_path, key_path)?;
} else if let (Some(cert_pem), Some(key_pem)) = (&config.cert_chain, &config.private_key) {
cert_store.load_pem(cert_pem, key_pem)?;
}
Ok(Self {
config,
cert_store,
sni_resolver: None,
stats: TlsAcceptorStats::default(),
})
}
pub fn with_sni_resolver(mut self, resolver: Arc<SniResolver>) -> Self {
self.sni_resolver = Some(resolver);
self
}
pub fn config(&self) -> &TlsConfig {
&self.config
}
pub fn cert_store(&self) -> &Arc<CertificateStore> {
&self.cert_store
}
pub fn stats(&self) -> TlsAcceptorStatsSnapshot {
TlsAcceptorStatsSnapshot {
handshakes_attempted: self
.stats
.handshakes_attempted
.load(std::sync::atomic::Ordering::Relaxed),
handshakes_succeeded: self
.stats
.handshakes_succeeded
.load(std::sync::atomic::Ordering::Relaxed),
handshakes_failed: self
.stats
.handshakes_failed
.load(std::sync::atomic::Ordering::Relaxed),
client_auth_failures: self
.stats
.client_auth_failures
.load(std::sync::atomic::Ordering::Relaxed),
sni_mismatches: self
.stats
.sni_mismatches
.load(std::sync::atomic::Ordering::Relaxed),
}
}
pub fn check_rotation(&self) -> TlsResult<bool> {
self.cert_store.check_reload()
}
pub fn is_cert_expiring_soon(&self) -> bool {
self.cert_store.is_expiring_soon()
}
}
/// Plain-value copy of an acceptor's handshake counters at one point in time.
#[derive(Debug, Clone)]
pub struct TlsAcceptorStatsSnapshot {
    pub handshakes_attempted: u64,
    pub handshakes_succeeded: u64,
    pub handshakes_failed: u64,
    pub client_auth_failures: u64,
    pub sni_mismatches: u64,
}

impl TlsAcceptorStatsSnapshot {
    /// Ratio of successful to attempted handshakes.
    ///
    /// Returns `1.0` when no handshakes have been attempted.
    pub fn success_rate(&self) -> f64 {
        match self.handshakes_attempted {
            0 => 1.0,
            attempted => self.handshakes_succeeded as f64 / attempted as f64,
        }
    }
}
/// Client-side TLS settings: trust anchors, an optional client certificate, and
/// verification overrides.
pub struct TlsConnector {
    /// Client configuration this connector was built from.
    config: TlsConfig,
    /// Trust-anchor material; `new` seeds it from the config's CA settings.
    root_certs: Vec<Vec<u8>>,
    /// Certificate presented for mutual TLS, if any.
    client_cert: Option<CertificateEntry>,
    /// Set by `danger_accept_invalid_certs`.
    allow_invalid_certs: bool,
    /// Set by `danger_accept_invalid_hostnames`.
    allow_invalid_hostnames: bool,
}

impl TlsConnector {
    /// Creates a connector from a client configuration.
    ///
    /// CA material from `config.ca_certs` and `config.ca_path` is kept in the root store rather
    /// than dropped. Each source becomes one opaque entry. NOTE(review): bundles are not split
    /// into individual certificates; confirm this matches whatever eventually consumes
    /// `root_certs`.
    ///
    /// # Errors
    /// - `Configuration` if `config` is a server configuration.
    /// - `CertificateNotFound` if `ca_path` is set but cannot be read.
    pub fn new(config: TlsConfig) -> TlsResult<Self> {
        if config.is_server {
            return Err(TlsError::Configuration(
                "TlsConnector requires a client configuration".to_string(),
            ));
        }
        let mut root_certs = Vec::new();
        if let Some(ca) = &config.ca_certs {
            root_certs.push(ca.clone());
        }
        if let Some(path) = &config.ca_path {
            let ca = std::fs::read(path).map_err(|e| {
                TlsError::CertificateNotFound(format!("{}: {}", path.display(), e))
            })?;
            root_certs.push(ca);
        }
        Ok(Self {
            config,
            root_certs,
            client_cert: None,
            allow_invalid_certs: false,
            allow_invalid_hostnames: false,
        })
    }

    /// Creates a connector from a default client configuration.
    ///
    /// NOTE(review): despite the name, no platform/native root certificates are loaded, and
    /// the root store starts empty. Do not rely on this for server verification until root
    /// loading is implemented.
    pub fn with_native_roots() -> TlsResult<Self> {
        Self::new(TlsConfig::client().build()?)
    }

    /// Sets the certificate presented for mutual TLS.
    pub fn with_client_cert(mut self, cert: CertificateEntry) -> Self {
        self.client_cert = Some(cert);
        self
    }

    /// Records whether invalid server certificates should be accepted.
    /// NOTE(review): the flag is only stored; no code visible here consults it yet.
    pub fn danger_accept_invalid_certs(mut self, allow: bool) -> Self {
        self.allow_invalid_certs = allow;
        self
    }

    /// Records whether hostname mismatches should be accepted.
    /// NOTE(review): the flag is only stored; no code visible here consults it yet.
    pub fn danger_accept_invalid_hostnames(mut self, allow: bool) -> Self {
        self.allow_invalid_hostnames = allow;
        self
    }

    /// The configuration this connector was built from.
    pub fn config(&self) -> &TlsConfig {
        &self.config
    }

    /// `true` if a client certificate has been set.
    pub fn has_client_cert(&self) -> bool {
        self.client_cert.is_some()
    }
}
/// Summary of an established TLS session.
///
/// NOTE(review): no code in this file constructs this type. The field descriptions below are
/// read from field names and types; confirm them against whatever produces it.
#[derive(Debug, Clone)]
pub struct TlsSessionInfo {
    /// Protocol version, presumably the negotiated one.
    pub version: TlsVersion,
    /// Cipher suite name (format not established here).
    pub cipher_suite: String,
    /// SNI hostname, if one was sent.
    pub sni_hostname: Option<String>,
    /// ALPN protocol, presumably the one selected during the handshake.
    pub alpn_protocol: Option<String>,
    /// Metadata of the server's certificate, if available.
    pub server_cert: Option<CertificateInfo>,
    /// Metadata of the client's certificate, if one was presented.
    pub client_cert: Option<CertificateInfo>,
    /// Whether the session was resumed rather than fully negotiated.
    pub resumed: bool,
    /// Time the handshake took.
    pub handshake_duration: Duration,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_version() {
        assert_eq!(TlsVersion::Tls12.as_str(), "TLS 1.2");
        assert_eq!(TlsVersion::Tls13.as_str(), "TLS 1.3");
        assert_eq!(TlsVersion::default(), TlsVersion::Tls13);
    }

    #[test]
    fn test_client_auth() {
        assert!(!ClientAuth::None.verifies_client());
        assert!(ClientAuth::Optional.verifies_client());
        assert!(ClientAuth::Required.verifies_client());
        assert!(!ClientAuth::None.requires_client());
        assert!(!ClientAuth::Optional.requires_client());
        assert!(ClientAuth::Required.requires_client());
    }

    #[test]
    fn test_tls_config_builder_server() {
        let config = TlsConfig::server()
            .with_cert(vec![1, 2, 3])
            .with_key(vec![4, 5, 6])
            .with_client_auth(ClientAuth::Required)
            .with_alpn(vec!["h2".to_string()])
            .with_sni_hostname("example.com")
            .build()
            .unwrap();
        assert!(config.is_server());
        assert!(config.is_mtls());
        assert_eq!(config.client_auth, ClientAuth::Required);
        assert_eq!(config.alpn_protocols, vec!["h2"]);
        assert!(config.sni_hostnames.contains(&"example.com".to_string()));
    }

    #[test]
    fn test_tls_config_builder_client() {
        let config = TlsConfig::client()
            .with_min_version(TlsVersion::Tls13)
            .build()
            .unwrap();
        assert!(config.is_client());
        assert!(!config.is_server());
        assert_eq!(config.min_version, TlsVersion::Tls13);
    }

    #[test]
    fn test_tls_config_server_requires_cert() {
        let result = TlsConfig::server().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_config_defaults() {
        let config = TlsConfig::default();
        assert!(config.is_client());
        assert_eq!(config.min_version, TlsVersion::Tls12);
        assert_eq!(config.max_version, TlsVersion::Tls13);
        assert_eq!(config.alpn_protocols, vec!["h2", "http/1.1"]);
        assert!(!config.is_mtls());
    }

    #[test]
    fn test_certificate_info_validity() {
        let mut info = CertificateInfo::empty();
        assert!(info.is_valid());
        info.not_before = Some(SystemTime::now() - Duration::from_secs(3600));
        info.not_after = Some(SystemTime::now() + Duration::from_secs(3600));
        assert!(info.is_valid());
        info.not_after = Some(SystemTime::now() - Duration::from_secs(1));
        assert!(!info.is_valid());
    }

    #[test]
    fn test_certificate_info_hostname_match() {
        let mut info = CertificateInfo::empty();
        info.subject_cn = Some("example.com".to_string());
        info.san_dns = vec!["www.example.com".to_string(), "*.test.com".to_string()];
        info.san_ips = vec!["192.168.1.1".to_string()];
        assert!(info.matches_hostname("example.com"));
        assert!(info.matches_hostname("www.example.com"));
        assert!(info.matches_hostname("api.test.com"));
        assert!(!info.matches_hostname("sub.api.test.com"));
        assert!(info.matches_hostname("192.168.1.1"));
        assert!(!info.matches_hostname("other.com"));
    }

    #[test]
    fn test_certificate_store() {
        let store = CertificateStore::new();
        assert!(store.current().is_none());
        assert!(store.previous().is_none());
        assert!(!store.is_expiring_soon());
    }

    #[test]
    fn test_certificate_store_rotation_keeps_previous() {
        let rotations = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let counter = Arc::clone(&rotations);
        let store = CertificateStore::new().on_rotation(move |_| {
            counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });
        store.load_pem(b"cert-a", b"key-a").unwrap();
        assert!(store.current().is_some());
        assert!(store.previous().is_none());
        store.load_pem(b"cert-b", b"key-b").unwrap();
        assert!(store.current().is_some());
        assert!(store.previous().is_some());
        assert_eq!(rotations.load(std::sync::atomic::Ordering::SeqCst), 2);
    }

    #[test]
    fn test_check_reload_respects_interval() {
        // Fresh store: the default one-hour interval has not elapsed.
        let store = CertificateStore::new();
        assert!(!store.check_reload().unwrap());
        // Zero interval: the check passes, but there are no file paths to reload from.
        let eager = CertificateStore::new().with_check_interval(Duration::from_secs(0));
        assert!(!eager.check_reload().unwrap());
    }

    #[test]
    fn test_sni_resolver() {
        let default_store = Arc::new(CertificateStore::new());
        let resolver = SniResolver::new(default_store);
        let example_store = Arc::new(CertificateStore::new());
        resolver.add_hostname("example.com", example_store);
        assert_eq!(resolver.hostname_count(), 1);
        assert!(resolver.hostnames().contains(&"example.com".to_string()));
        let _ = resolver.resolve("example.com");
        let _ = resolver.resolve("other.com");
        assert!(resolver.remove_hostname("example.com"));
        assert_eq!(resolver.hostname_count(), 0);
    }

    #[test]
    fn test_sni_resolver_wildcard_and_default() {
        let default_store = Arc::new(CertificateStore::new());
        let resolver = SniResolver::new(Arc::clone(&default_store));
        let wildcard_store = Arc::new(CertificateStore::new());
        resolver.add_hostname("*.example.com", Arc::clone(&wildcard_store));
        assert!(Arc::ptr_eq(&resolver.resolve("api.example.com"), &wildcard_store));
        assert!(Arc::ptr_eq(&resolver.resolve("other.org"), &default_store));
        assert!(!resolver.remove_hostname("missing.example.com"));
    }

    #[test]
    fn test_tls_acceptor_stats() {
        let config = TlsConfig::server()
            .with_cert(vec![1])
            .with_key(vec![2])
            .build()
            .unwrap();
        let acceptor = TlsAcceptor::new(config).unwrap();
        let stats = acceptor.stats();
        assert_eq!(stats.handshakes_attempted, 0);
        assert_eq!(stats.success_rate(), 1.0);
    }

    #[test]
    fn test_success_rate_with_attempts() {
        let snapshot = TlsAcceptorStatsSnapshot {
            handshakes_attempted: 4,
            handshakes_succeeded: 3,
            handshakes_failed: 1,
            client_auth_failures: 0,
            sni_mismatches: 0,
        };
        assert_eq!(snapshot.success_rate(), 0.75);
    }

    #[test]
    fn test_acceptor_and_connector_reject_wrong_side() {
        let client = TlsConfig::client().build().unwrap();
        assert!(matches!(
            TlsAcceptor::new(client),
            Err(TlsError::Configuration(_))
        ));
        let server = TlsConfig::server()
            .with_cert(vec![1])
            .with_key(vec![2])
            .build()
            .unwrap();
        assert!(matches!(
            TlsConnector::new(server),
            Err(TlsError::Configuration(_))
        ));
    }

    #[test]
    fn test_tls_connector() {
        let connector = TlsConnector::with_native_roots().unwrap();
        assert!(!connector.has_client_cert());
        assert!(connector.config().is_client());
    }

    #[test]
    fn test_tls_error_display() {
        let err = TlsError::CertificateNotFound("/path/to/cert".to_string());
        assert!(err.to_string().contains("/path/to/cert"));
        let err = TlsError::HandshakeFailed("protocol error".to_string());
        assert!(err.to_string().contains("handshake"));
    }

    #[test]
    fn test_mtls_config() {
        let config = TlsConfig::server()
            .with_cert(vec![1])
            .with_key(vec![2])
            .with_mtls()
            .with_ca_cert(vec![3])
            .build()
            .unwrap();
        assert!(config.is_mtls());
        assert_eq!(config.client_auth, ClientAuth::Required);
        assert!(config.ca_certs.is_some());
    }
}