use std::path::PathBuf;
/// TLS settings: certificate/key/CA file locations, minimum protocol version,
/// client-certificate policy, and certificate reload interval.
///
/// Derives serde `Serialize`/`Deserialize`. The `Default` impl in this file
/// supplies paths under `/etc/nodedb/tls/`, `TlsVersion::Tls13`, client
/// certificates required, a 3600-second reload interval, and no `crl_path`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TlsConfig {
/// Path to the certificate file (default `/etc/nodedb/tls/node.crt`).
pub cert_path: PathBuf,
/// Path to the private-key file (default `/etc/nodedb/tls/node.key`).
pub key_path: PathBuf,
/// Path to the CA certificate file (default `/etc/nodedb/tls/ca.crt`).
pub ca_cert_path: PathBuf,
/// Minimum TLS protocol version (default `TlsVersion::Tls13`).
/// NOTE(review): enforcement happens outside this file; confirm where.
pub min_tls_version: TlsVersion,
/// Whether a client certificate is required (default `true`).
pub require_client_cert: bool,
/// Certificate reload interval in seconds (default `3600`).
/// NOTE(review): the reload scheduling itself is not shown in this file.
pub cert_reload_interval_secs: u64,
/// Optional path, presumably to a certificate revocation list (TODO confirm).
/// `#[serde(default)]` lets the key be omitted when deserializing, in which
/// case the field is `None`.
#[serde(default)]
pub crl_path: Option<PathBuf>,
}
/// TLS protocol version selector, used by `TlsConfig::min_tls_version`.
///
/// Derives serde `Serialize`/`Deserialize` with no rename attributes, so the
/// variant names are what appears in serialized form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TlsVersion {
/// Presumably TLS 1.2.
Tls12,
/// Presumably TLS 1.3; this is the value `TlsConfig::default()` selects.
Tls13,
}
impl Default for TlsConfig {
fn default() -> Self {
Self {
cert_path: PathBuf::from("/etc/nodedb/tls/node.crt"),
key_path: PathBuf::from("/etc/nodedb/tls/node.key"),
ca_cert_path: PathBuf::from("/etc/nodedb/tls/ca.crt"),
min_tls_version: TlsVersion::Tls13,
require_client_cert: true,
cert_reload_interval_secs: 3600,
crl_path: None,
}
}
}
/// Identity attributes of a peer, carried by `CertValidation::Valid`.
///
/// NOTE(review): how these fields are extracted from a certificate is not
/// shown in this file; the field notes below are inferred from the names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeerIdentity {
/// Presumably the certificate subject's Common Name (CN) — confirm.
pub common_name: String,
/// Presumably the DNS entries from the Subject Alternative Name extension — confirm.
pub san_dns: Vec<String>,
/// Flag distinguishing a node peer (`true`) from a non-node peer such as a
/// client (`false`), as the tests in this file use it.
pub is_node: bool,
}
/// Outcome of validating a peer certificate.
///
/// Only [`CertValidation::Valid`] carries a [`PeerIdentity`]; every other
/// variant describes a reason for rejection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CertValidation {
    /// The certificate was accepted; holds the peer's identity.
    Valid(PeerIdentity),
    /// The certificate has expired. `expired_at` is a free-form string; its
    /// format is not fixed by this file.
    Expired { cn: String, expired_at: String },
    /// The certificate was not issued by a trusted CA.
    UntrustedCa,
    /// No certificate was presented.
    NoCertificate,
    /// The certificate has been revoked.
    Revoked { cn: String },
}

impl CertValidation {
    /// Returns `true` only for [`CertValidation::Valid`].
    pub fn is_valid(&self) -> bool {
        self.identity().is_some()
    }

    /// Returns the peer identity when validation succeeded, `None` otherwise.
    pub fn identity(&self) -> Option<&PeerIdentity> {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Valid(peer) => Some(peer),
            Self::Expired { .. }
            | Self::UntrustedCa
            | Self::NoCertificate
            | Self::Revoked { .. } => None,
        }
    }
}
/// Bookkeeping for the currently loaded certificate and its rotation history.
#[derive(Debug, Clone)]
pub struct CertRotationState {
    /// Wall-clock time of the most recent successful load, in microseconds
    /// since the Unix epoch (0 if the system clock reads before the epoch).
    pub loaded_at_us: u64,
    /// Serial of the currently loaded certificate, as supplied by the caller.
    pub serial: String,
    /// Number of successful rotations recorded via [`CertRotationState::rotated`].
    pub rotations: u64,
    /// Message from the most recent failed rotation; cleared by the next
    /// successful one.
    pub last_error: Option<String>,
}

impl CertRotationState {
    /// Current wall-clock time in microseconds since the Unix epoch.
    ///
    /// Yields 0 when the system clock is set before the epoch, and clamps to
    /// `u64::MAX` rather than silently truncating the `u128` microsecond count.
    fn now_us() -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        u64::try_from(since_epoch.as_micros()).unwrap_or(u64::MAX)
    }

    /// Creates the state for a freshly loaded certificate with the given
    /// `serial`: zero rotations, no recorded error, load time set to now.
    pub fn new(serial: String) -> Self {
        Self {
            loaded_at_us: Self::now_us(),
            serial,
            rotations: 0,
            last_error: None,
        }
    }

    /// Records a successful rotation to the certificate with `new_serial`.
    ///
    /// Updates the load time to now, replaces the serial, increments the
    /// rotation counter, and clears any previously recorded error.
    pub fn rotated(&mut self, new_serial: String) {
        self.loaded_at_us = Self::now_us();
        self.serial = new_serial;
        // Saturate instead of overflowing: a pegged counter is preferable to
        // a panic (overflow-checked builds) or a wrap to zero (release).
        self.rotations = self.rotations.saturating_add(1);
        self.last_error = None;
    }

    /// Records a failed rotation attempt.
    ///
    /// Only `last_error` changes; the serial, load time, and rotation count
    /// keep describing the certificate that is still loaded.
    pub fn rotation_failed(&mut self, error: String) {
        self.last_error = Some(error);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `TlsConfig::default()` yields the documented paths and policy values.
    #[test]
    fn tls_config_defaults() {
        let cfg = TlsConfig::default();
        assert_eq!(
            cfg.cert_path,
            std::path::PathBuf::from("/etc/nodedb/tls/node.crt")
        );
        assert_eq!(
            cfg.key_path,
            std::path::PathBuf::from("/etc/nodedb/tls/node.key")
        );
        assert_eq!(
            cfg.ca_cert_path,
            std::path::PathBuf::from("/etc/nodedb/tls/ca.crt")
        );
        assert_eq!(cfg.min_tls_version, TlsVersion::Tls13);
        assert!(cfg.require_client_cert);
        assert_eq!(cfg.cert_reload_interval_secs, 3600);
        assert!(cfg.crl_path.is_none());
    }

    /// Only `Valid` reports success and exposes an identity; every rejection
    /// variant, including `Revoked`, reports neither.
    #[test]
    fn cert_validation_states() {
        let valid = CertValidation::Valid(PeerIdentity {
            common_name: "node-1.nodedb.local".into(),
            san_dns: vec!["node-1.nodedb.local".into()],
            is_node: true,
        });
        assert!(valid.is_valid());
        assert_eq!(valid.identity().unwrap().common_name, "node-1.nodedb.local");

        let rejected = [
            CertValidation::Expired {
                cn: "node-2".into(),
                expired_at: "2024-01-01".into(),
            },
            CertValidation::UntrustedCa,
            CertValidation::NoCertificate,
            CertValidation::Revoked {
                cn: "node-3".into(),
            },
        ];
        for outcome in &rejected {
            assert!(!outcome.is_valid(), "{outcome:?} must not be valid");
            assert!(
                outcome.identity().is_none(),
                "{outcome:?} must not expose an identity"
            );
        }
    }

    /// Successful rotations bump the counter and clear errors; failures only
    /// record the error and leave the loaded certificate's data untouched.
    #[test]
    fn cert_rotation_lifecycle() {
        let mut state = CertRotationState::new("SERIAL-001".into());
        assert_eq!(state.serial, "SERIAL-001");
        assert_eq!(state.rotations, 0);
        assert!(state.last_error.is_none());

        state.rotated("SERIAL-002".into());
        assert_eq!(state.rotations, 1);
        assert_eq!(state.serial, "SERIAL-002");

        state.rotation_failed("file not found".into());
        assert_eq!(state.last_error.as_deref(), Some("file not found"));
        // A failure must not disturb the record of the certificate in use.
        assert_eq!(state.rotations, 1);
        assert_eq!(state.serial, "SERIAL-002");

        state.rotated("SERIAL-003".into());
        assert_eq!(state.rotations, 2);
        assert_eq!(state.serial, "SERIAL-003");
        assert!(state.last_error.is_none());
    }

    /// The `is_node` flag distinguishes node peers from client peers.
    #[test]
    fn peer_identity_node_vs_client() {
        let node = PeerIdentity {
            common_name: "node-1".into(),
            san_dns: vec![],
            is_node: true,
        };
        assert!(node.is_node);

        let client = PeerIdentity {
            common_name: "app-client".into(),
            san_dns: vec![],
            is_node: false,
        };
        assert!(!client.is_node);
        assert_ne!(node, client);
    }
}