use std::fs;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use parking_lot::RwLock;
use rcgen::{CertificateParams, DistinguishedName, DnType, Issuer, KeyPair, SanType};
use rustls::RootCertStore;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::sync::watch;
use tracing::{debug, error, info, warn};
use x509_parser::prelude::*;
use crate::error::{NetError, NetResult};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertificateFormat {
Pem,
Der,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrivateKeyType {
Rsa,
Ecdsa,
Ed25519,
Pkcs8,
}
#[derive(Debug, Clone)]
pub struct CertificateInfo {
pub common_name: Option<String>,
pub subject_alt_names: Vec<String>,
pub issuer: Option<String>,
pub serial_number: String,
pub not_before: SystemTime,
pub not_after: SystemTime,
pub is_ca: bool,
pub key_usage: Vec<String>,
pub extended_key_usage: Vec<String>,
pub fingerprint_sha256: String,
}
impl CertificateInfo {
pub fn is_valid(&self) -> bool {
let now = SystemTime::now();
now >= self.not_before && now <= self.not_after
}
pub fn time_to_expiry(&self) -> Option<Duration> {
SystemTime::now()
.duration_since(self.not_after)
.ok()
.map(|_| Duration::ZERO)
.or_else(|| self.not_after.duration_since(SystemTime::now()).ok())
}
pub fn expires_within(&self, duration: Duration) -> bool {
self.time_to_expiry()
.is_some_and(|remaining| remaining <= duration)
}
}
#[derive(Debug, Clone)]
pub struct CertificateLoader {
validate_on_load: bool,
}
impl Default for CertificateLoader {
fn default() -> Self {
Self::new()
}
}
impl CertificateLoader {
pub fn new() -> Self {
Self {
validate_on_load: true,
}
}
pub fn without_validation() -> Self {
Self {
validate_on_load: false,
}
}
pub fn load_pem_file<P: AsRef<Path>>(
&self,
path: P,
) -> NetResult<Vec<CertificateDer<'static>>> {
let path = path.as_ref();
debug!(path = %path.display(), "Loading PEM certificates from file");
let file = fs::File::open(path)
.map_err(|e| NetError::InvalidCertificate(format!("Failed to open PEM file: {e}")))?;
let mut reader = BufReader::new(file);
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
.filter_map(|result| result.ok())
.collect();
if certs.is_empty() {
return Err(NetError::InvalidCertificate(
"No certificates found in PEM file".to_string(),
));
}
if self.validate_on_load {
for cert in &certs {
self.validate_certificate_der(cert)?;
}
}
info!(count = certs.len(), "Loaded certificates from PEM file");
Ok(certs)
}
pub fn load_der_file<P: AsRef<Path>>(&self, path: P) -> NetResult<CertificateDer<'static>> {
let path = path.as_ref();
debug!(path = %path.display(), "Loading DER certificate from file");
let der_data = fs::read(path)
.map_err(|e| NetError::InvalidCertificate(format!("Failed to read DER file: {e}")))?;
let cert = CertificateDer::from(der_data);
if self.validate_on_load {
self.validate_certificate_der(&cert)?;
}
info!("Loaded DER certificate from file");
Ok(cert)
}
pub fn load_pem_bytes(&self, pem_data: &[u8]) -> NetResult<Vec<CertificateDer<'static>>> {
let mut reader = BufReader::new(pem_data);
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
.filter_map(|result| result.ok())
.collect();
if certs.is_empty() {
return Err(NetError::InvalidCertificate(
"No certificates found in PEM data".to_string(),
));
}
if self.validate_on_load {
for cert in &certs {
self.validate_certificate_der(cert)?;
}
}
Ok(certs)
}
pub fn load_der_bytes(&self, der_data: &[u8]) -> NetResult<CertificateDer<'static>> {
let cert = CertificateDer::from(der_data.to_vec());
if self.validate_on_load {
self.validate_certificate_der(&cert)?;
}
Ok(cert)
}
fn validate_certificate_der(&self, cert: &CertificateDer<'_>) -> NetResult<()> {
let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
})?;
let now = ASN1Time::now();
if parsed.validity().not_before > now {
return Err(NetError::InvalidCertificate(
"Certificate is not yet valid".to_string(),
));
}
if parsed.validity().not_after < now {
return Err(NetError::InvalidCertificate(
"Certificate has expired".to_string(),
));
}
Ok(())
}
pub fn get_certificate_info(&self, cert: &CertificateDer<'_>) -> NetResult<CertificateInfo> {
let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
})?;
let common_name = parsed
.subject()
.iter_common_name()
.next()
.and_then(|cn| cn.as_str().ok())
.map(String::from);
let issuer = parsed
.issuer()
.iter_common_name()
.next()
.and_then(|cn| cn.as_str().ok())
.map(String::from);
let mut subject_alt_names = Vec::new();
if let Ok(Some(san)) = parsed.subject_alternative_name() {
for name in san.value.general_names.iter() {
match name {
GeneralName::DNSName(dns) => subject_alt_names.push(dns.to_string()),
GeneralName::IPAddress(ip) => {
if ip.len() == 4 {
subject_alt_names
.push(format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]));
} else if ip.len() == 16 {
let mut parts = Vec::with_capacity(8);
for i in 0..8 {
let val = u16::from_be_bytes([ip[i * 2], ip[i * 2 + 1]]);
parts.push(format!("{val:x}"));
}
subject_alt_names.push(parts.join(":"));
}
}
GeneralName::RFC822Name(email) => subject_alt_names.push(email.to_string()),
GeneralName::URI(uri) => subject_alt_names.push(uri.to_string()),
_ => {}
}
}
}
let serial_number = format!("{:x}", parsed.serial);
let not_before = asn1_time_to_system_time(&parsed.validity().not_before);
let not_after = asn1_time_to_system_time(&parsed.validity().not_after);
let is_ca = parsed.is_ca();
let mut key_usage = Vec::new();
if let Ok(Some(ku)) = parsed.key_usage() {
let flags = ku.value;
if flags.digital_signature() {
key_usage.push("digitalSignature".to_string());
}
if flags.non_repudiation() {
key_usage.push("nonRepudiation".to_string());
}
if flags.key_encipherment() {
key_usage.push("keyEncipherment".to_string());
}
if flags.data_encipherment() {
key_usage.push("dataEncipherment".to_string());
}
if flags.key_agreement() {
key_usage.push("keyAgreement".to_string());
}
if flags.key_cert_sign() {
key_usage.push("keyCertSign".to_string());
}
if flags.crl_sign() {
key_usage.push("cRLSign".to_string());
}
}
let mut extended_key_usage = Vec::new();
if let Ok(Some(eku)) = parsed.extended_key_usage() {
for oid in eku.value.other.iter() {
extended_key_usage.push(oid.to_string());
}
if eku.value.any {
extended_key_usage.push("anyExtendedKeyUsage".to_string());
}
if eku.value.server_auth {
extended_key_usage.push("serverAuth".to_string());
}
if eku.value.client_auth {
extended_key_usage.push("clientAuth".to_string());
}
if eku.value.code_signing {
extended_key_usage.push("codeSigning".to_string());
}
if eku.value.email_protection {
extended_key_usage.push("emailProtection".to_string());
}
if eku.value.time_stamping {
extended_key_usage.push("timeStamping".to_string());
}
if eku.value.ocsp_signing {
extended_key_usage.push("ocspSigning".to_string());
}
}
use std::fmt::Write;
let fingerprint_sha256 = cert
.as_ref()
.iter()
.take(32) .fold(String::new(), |mut s, b| {
let _ = write!(&mut s, "{b:02x}");
s
});
Ok(CertificateInfo {
common_name,
subject_alt_names,
issuer,
serial_number,
not_before,
not_after,
is_ca,
key_usage,
extended_key_usage,
fingerprint_sha256,
})
}
}
fn asn1_time_to_system_time(time: &ASN1Time) -> SystemTime {
let timestamp = time.timestamp();
if timestamp >= 0 {
SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp as u64)
} else {
SystemTime::UNIX_EPOCH
}
}
#[derive(Debug, Clone)]
pub struct PrivateKeyLoader;
impl Default for PrivateKeyLoader {
fn default() -> Self {
Self::new()
}
}
impl PrivateKeyLoader {
pub fn new() -> Self {
Self
}
pub fn load_pem_file<P: AsRef<Path>>(&self, path: P) -> NetResult<PrivateKeyDer<'static>> {
let path = path.as_ref();
debug!(path = %path.display(), "Loading private key from PEM file");
let file = fs::File::open(path)
.map_err(|e| NetError::InvalidCertificate(format!("Failed to open key file: {e}")))?;
let mut reader = BufReader::new(file);
self.load_from_reader(&mut reader)
}
pub fn load_pem_bytes(&self, pem_data: &[u8]) -> NetResult<PrivateKeyDer<'static>> {
let mut reader = BufReader::new(pem_data);
self.load_from_reader(&mut reader)
}
pub fn load_der_file<P: AsRef<Path>>(
&self,
path: P,
key_type: PrivateKeyType,
) -> NetResult<PrivateKeyDer<'static>> {
let path = path.as_ref();
debug!(path = %path.display(), "Loading private key from DER file");
let der_data = fs::read(path)
.map_err(|e| NetError::InvalidCertificate(format!("Failed to read key file: {e}")))?;
self.load_der_bytes(&der_data, key_type)
}
pub fn load_der_bytes(
&self,
der_data: &[u8],
key_type: PrivateKeyType,
) -> NetResult<PrivateKeyDer<'static>> {
let key = match key_type {
PrivateKeyType::Rsa => PrivateKeyDer::Pkcs1(der_data.to_vec().into()),
PrivateKeyType::Ecdsa | PrivateKeyType::Ed25519 => {
PrivateKeyDer::Sec1(der_data.to_vec().into())
}
PrivateKeyType::Pkcs8 => {
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der_data.to_vec()))
}
};
Ok(key)
}
pub fn load_encrypted_pem_file<P: AsRef<Path>>(
&self,
path: P,
password: &str,
) -> NetResult<PrivateKeyDer<'static>> {
let path = path.as_ref();
debug!(path = %path.display(), "Loading potentially encrypted private key");
let pem_data = fs::read(path)
.map_err(|e| NetError::InvalidCertificate(format!("Failed to read key file: {e}")))?;
let pem_str = std::str::from_utf8(&pem_data).map_err(|e| {
NetError::InvalidCertificate(format!("Key file is not valid UTF-8: {e}"))
})?;
let enc_format = detect_encrypted_pem(pem_str);
match enc_format {
EncryptedPemFormat::NotEncrypted => {
debug!("Key is not encrypted, loading directly");
self.load_pem_bytes(&pem_data)
}
EncryptedPemFormat::Pkcs8Encrypted => {
let effective_password = resolve_password(password)?;
decrypt_pkcs8_encrypted_pem(pem_str, &effective_password)
}
EncryptedPemFormat::LegacyEncrypted => {
let effective_password = resolve_password(password)?;
decrypt_legacy_encrypted_pem(pem_str, &effective_password)
}
}
}
pub fn load_encrypted_pem_file_env<P: AsRef<Path>>(
&self,
path: P,
) -> NetResult<PrivateKeyDer<'static>> {
self.load_encrypted_pem_file(path, "")
}
fn load_from_reader<R: std::io::BufRead>(
&self,
reader: &mut R,
) -> NetResult<PrivateKeyDer<'static>> {
let mut original_data: Vec<u8> = Vec::new();
reader
.read_to_end(&mut original_data)
.map_err(|e| NetError::InvalidCertificate(format!("Failed to read key data: {e}")))?;
let mut cursor = std::io::Cursor::new(&original_data);
if let Some(Ok(key)) = rustls_pemfile::pkcs8_private_keys(&mut cursor).next() {
info!("Loaded PKCS#8 private key");
return Ok(PrivateKeyDer::Pkcs8(key));
}
let mut cursor = std::io::Cursor::new(&original_data);
if let Some(Ok(key)) = rustls_pemfile::rsa_private_keys(&mut cursor).next() {
info!("Loaded RSA private key");
return Ok(PrivateKeyDer::Pkcs1(key));
}
let mut cursor = std::io::Cursor::new(&original_data);
if let Some(Ok(key)) = rustls_pemfile::ec_private_keys(&mut cursor).next() {
info!("Loaded EC private key");
return Ok(PrivateKeyDer::Sec1(key));
}
Err(NetError::InvalidCertificate(
"No valid private key found in PEM data (tried PKCS#8, RSA, EC formats)".to_string(),
))
}
}
pub use crate::tls_crypto::{
EncryptedPemFormat, detect_encrypted_pem, parse_dek_info, pbkdf2_hmac_sha1, pbkdf2_hmac_sha256,
};
use crate::tls_crypto::{
decrypt_legacy_encrypted_pem, decrypt_pkcs8_encrypted_pem, resolve_password,
};
#[derive(Debug, Clone)]
pub struct SelfSignedGenerator {
subject_alt_names: Vec<String>,
common_name: String,
organization: Option<String>,
validity_days: u32,
is_ca: bool,
}
impl SelfSignedGenerator {
pub fn new(common_name: impl Into<String>) -> Self {
Self {
common_name: common_name.into(),
subject_alt_names: vec!["localhost".to_string()],
organization: None,
validity_days: 365,
is_ca: false,
}
}
pub fn with_san(mut self, san: impl Into<String>) -> Self {
self.subject_alt_names.push(san.into());
self
}
pub fn with_sans<I, S>(mut self, sans: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.subject_alt_names
.extend(sans.into_iter().map(|s| s.into()));
self
}
pub fn with_organization(mut self, org: impl Into<String>) -> Self {
self.organization = Some(org.into());
self
}
pub fn with_validity_days(mut self, days: u32) -> Self {
self.validity_days = days;
self
}
pub fn as_ca(mut self) -> Self {
self.is_ca = true;
self
}
pub fn generate(&self) -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
let mut params = CertificateParams::default();
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, &self.common_name);
if let Some(ref org) = self.organization {
dn.push(DnType::OrganizationName, org);
}
params.distinguished_name = dn;
params.subject_alt_names = self
.subject_alt_names
.iter()
.map(|name| {
if let Ok(ip) = name.parse::<std::net::IpAddr>() {
SanType::IpAddress(ip)
} else {
SanType::DnsName(name.clone().try_into().unwrap_or_else(|_| {
"localhost"
.to_string()
.try_into()
.expect("localhost is valid DNS name")
}))
}
})
.collect();
params.not_before = rcgen::date_time_ymd(
chrono::Utc::now().year(),
chrono::Utc::now().month() as u8,
chrono::Utc::now().day() as u8,
);
let future = chrono::Utc::now() + chrono::Duration::days(self.validity_days as i64);
params.not_after =
rcgen::date_time_ymd(future.year(), future.month() as u8, future.day() as u8);
if self.is_ca {
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
}
let key_pair = KeyPair::generate().map_err(|e| {
NetError::InvalidCertificate(format!("Failed to generate key pair: {e}"))
})?;
let cert = params.self_signed(&key_pair).map_err(|e| {
NetError::InvalidCertificate(format!("Failed to generate certificate: {e}"))
})?;
let cert_der = CertificateDer::from(cert.der().to_vec());
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
info!(
common_name = %self.common_name,
is_ca = self.is_ca,
validity_days = self.validity_days,
"Generated self-signed certificate"
);
Ok((cert_der, key_der))
}
pub fn generate_signed_by_keypair(
&self,
ca_key_pair: &KeyPair,
ca_common_name: &str,
) -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
let mut params = CertificateParams::default();
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, &self.common_name);
if let Some(ref org) = self.organization {
dn.push(DnType::OrganizationName, org);
}
params.distinguished_name = dn;
params.subject_alt_names = self
.subject_alt_names
.iter()
.map(|name| {
if let Ok(ip) = name.parse::<std::net::IpAddr>() {
SanType::IpAddress(ip)
} else {
SanType::DnsName(name.clone().try_into().unwrap_or_else(|_| {
"localhost"
.to_string()
.try_into()
.expect("localhost is valid DNS name")
}))
}
})
.collect();
params.not_before = rcgen::date_time_ymd(
chrono::Utc::now().year(),
chrono::Utc::now().month() as u8,
chrono::Utc::now().day() as u8,
);
let future = chrono::Utc::now() + chrono::Duration::days(self.validity_days as i64);
params.not_after =
rcgen::date_time_ymd(future.year(), future.month() as u8, future.day() as u8);
let key_pair = KeyPair::generate().map_err(|e| {
NetError::InvalidCertificate(format!("Failed to generate key pair: {e}"))
})?;
let mut ca_params = CertificateParams::default();
ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
let mut issuer_dn = DistinguishedName::new();
issuer_dn.push(DnType::CommonName, ca_common_name);
ca_params.distinguished_name = issuer_dn;
let issuer = Issuer::from_params(&ca_params, ca_key_pair);
let signed_cert = params.signed_by(&key_pair, &issuer).map_err(|e| {
NetError::InvalidCertificate(format!("Failed to sign certificate: {e}"))
})?;
let cert_der = CertificateDer::from(signed_cert.der().to_vec());
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
info!(
common_name = %self.common_name,
"Generated CA-signed certificate"
);
Ok((cert_der, key_der))
}
}
use chrono::Datelike;
#[derive(Debug)]
pub struct CertificateStore {
roots: Arc<RwLock<RootCertStore>>,
cert_chain: Arc<RwLock<Vec<CertificateDer<'static>>>>,
cert_info: Arc<RwLock<Vec<CertificateInfo>>>,
}
impl Default for CertificateStore {
fn default() -> Self {
Self::new()
}
}
impl Clone for CertificateStore {
fn clone(&self) -> Self {
Self {
roots: Arc::new(RwLock::new((*self.roots.read()).clone())),
cert_chain: Arc::new(RwLock::new(self.cert_chain.read().clone())),
cert_info: Arc::new(RwLock::new(self.cert_info.read().clone())),
}
}
}
impl CertificateStore {
pub fn new() -> Self {
Self {
roots: Arc::new(RwLock::new(RootCertStore::empty())),
cert_chain: Arc::new(RwLock::new(Vec::new())),
cert_info: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn add_system_roots(&mut self) -> NetResult<usize> {
let mut roots = self.roots.write();
let count_before = roots.len();
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let added = roots.len() - count_before;
info!(count = added, "Added system root certificates");
Ok(added)
}
pub fn add_certificate(&mut self, cert: CertificateDer<'static>) -> NetResult<()> {
let loader = CertificateLoader::new();
let info = loader.get_certificate_info(&cert)?;
if !info.is_ca {
warn!(common_name = ?info.common_name, "Adding non-CA certificate to root store");
}
{
let mut roots = self.roots.write();
roots.add(cert.clone()).map_err(|e| {
NetError::InvalidCertificate(format!("Failed to add certificate: {e}"))
})?;
}
{
let mut chain = self.cert_chain.write();
chain.push(cert);
}
{
let mut infos = self.cert_info.write();
infos.push(info);
}
Ok(())
}
pub fn add_certificates_from_file<P: AsRef<Path>>(&mut self, path: P) -> NetResult<usize> {
let loader = CertificateLoader::new();
let certs = loader.load_pem_file(path)?;
let count = certs.len();
for cert in certs {
self.add_certificate(cert)?;
}
Ok(count)
}
pub fn get_root_store(&self) -> RootCertStore {
self.roots.read().clone()
}
pub fn get_cert_chain(&self) -> Vec<CertificateDer<'static>> {
self.cert_chain.read().clone()
}
pub fn len(&self) -> usize {
self.roots.read().len()
}
pub fn is_empty(&self) -> bool {
self.roots.read().is_empty()
}
pub fn get_certificate_infos(&self) -> Vec<CertificateInfo> {
self.cert_info.read().clone()
}
pub fn check_expiring(&self, within: Duration) -> Vec<CertificateInfo> {
self.cert_info
.read()
.iter()
.filter(|info| info.expires_within(within))
.cloned()
.collect()
}
}
#[derive(Debug, Clone)]
enum PrivateKeyData {
Pkcs8(Vec<u8>),
Pkcs1(Vec<u8>),
Sec1(Vec<u8>),
}
impl PrivateKeyData {
fn from_key(key: &PrivateKeyDer<'_>) -> Self {
match key {
PrivateKeyDer::Pkcs8(k) => Self::Pkcs8(k.secret_pkcs8_der().to_vec()),
PrivateKeyDer::Pkcs1(k) => Self::Pkcs1(k.secret_pkcs1_der().to_vec()),
PrivateKeyDer::Sec1(k) => Self::Sec1(k.secret_sec1_der().to_vec()),
_ => Self::Pkcs8(Vec::new()), }
}
fn to_key(&self) -> PrivateKeyDer<'static> {
match self {
Self::Pkcs8(data) => PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(data.clone())),
Self::Pkcs1(data) => PrivateKeyDer::Pkcs1(data.clone().into()),
Self::Sec1(data) => PrivateKeyDer::Sec1(data.clone().into()),
}
}
}
pub struct HotReloadableCertificates {
cert_chain: Arc<RwLock<Vec<CertificateDer<'static>>>>,
private_key_data: Arc<RwLock<Option<PrivateKeyData>>>,
update_tx: watch::Sender<u64>,
version: Arc<RwLock<u64>>,
cert_path: Arc<RwLock<Option<std::path::PathBuf>>>,
key_path: Arc<RwLock<Option<std::path::PathBuf>>>,
}
impl std::fmt::Debug for HotReloadableCertificates {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HotReloadableCertificates")
.field("version", &*self.version.read())
.field("cert_count", &self.cert_chain.read().len())
.field("has_key", &self.private_key_data.read().is_some())
.finish()
}
}
impl Default for HotReloadableCertificates {
fn default() -> Self {
Self::new()
}
}
impl HotReloadableCertificates {
pub fn new() -> Self {
let (update_tx, _) = watch::channel(0u64);
Self {
cert_chain: Arc::new(RwLock::new(Vec::new())),
private_key_data: Arc::new(RwLock::new(None)),
update_tx,
version: Arc::new(RwLock::new(0)),
cert_path: Arc::new(RwLock::new(None)),
key_path: Arc::new(RwLock::new(None)),
}
}
pub fn load_from_files<P: AsRef<Path>>(&self, cert_path: P, key_path: P) -> NetResult<()> {
let cert_path = cert_path.as_ref();
let key_path = key_path.as_ref();
let loader = CertificateLoader::new();
let key_loader = PrivateKeyLoader::new();
let certs = loader.load_pem_file(cert_path)?;
let key = key_loader.load_pem_file(key_path)?;
{
let mut chain = self.cert_chain.write();
*chain = certs;
}
{
let mut pk = self.private_key_data.write();
*pk = Some(PrivateKeyData::from_key(&key));
}
{
let mut cp = self.cert_path.write();
*cp = Some(cert_path.to_path_buf());
}
{
let mut kp = self.key_path.write();
*kp = Some(key_path.to_path_buf());
}
self.increment_version();
info!(
cert_path = %cert_path.display(),
key_path = %key_path.display(),
"Loaded certificates from files"
);
Ok(())
}
pub fn reload(&self) -> NetResult<()> {
let cert_path = self.cert_path.read().clone();
let key_path = self.key_path.read().clone();
match (cert_path, key_path) {
(Some(cp), Some(kp)) => {
self.load_from_files(&cp, &kp)?;
info!("Reloaded certificates");
Ok(())
}
_ => Err(NetError::InvalidCertificate(
"No certificate paths configured for reload".to_string(),
)),
}
}
pub fn set_certificates(
&self,
certs: Vec<CertificateDer<'static>>,
key: PrivateKeyDer<'static>,
) {
{
let mut chain = self.cert_chain.write();
*chain = certs;
}
{
let mut pk = self.private_key_data.write();
*pk = Some(PrivateKeyData::from_key(&key));
}
self.increment_version();
}
pub fn get_cert_chain(&self) -> Vec<CertificateDer<'static>> {
self.cert_chain.read().clone()
}
pub fn get_private_key(&self) -> Option<PrivateKeyDer<'static>> {
self.private_key_data.read().as_ref().map(|k| k.to_key())
}
pub fn get_version(&self) -> u64 {
*self.version.read()
}
pub fn subscribe(&self) -> watch::Receiver<u64> {
self.update_tx.subscribe()
}
fn increment_version(&self) {
let mut version = self.version.write();
*version += 1;
let _ = self.update_tx.send(*version);
}
pub fn start_file_watcher(
self: Arc<Self>,
check_interval: Duration,
) -> NetResult<tokio::task::JoinHandle<()>> {
let cert_path = self.cert_path.read().clone();
let key_path = self.key_path.read().clone();
let (cert_path, key_path) = match (cert_path, key_path) {
(Some(cp), Some(kp)) => (cp, kp),
_ => {
return Err(NetError::InvalidCertificate(
"No certificate paths configured for file watching".to_string(),
));
}
};
let handle = tokio::spawn(async move {
let mut last_cert_modified = get_file_modified(&cert_path);
let mut last_key_modified = get_file_modified(&key_path);
loop {
tokio::time::sleep(check_interval).await;
let cert_modified = get_file_modified(&cert_path);
let key_modified = get_file_modified(&key_path);
let cert_changed = cert_modified != last_cert_modified;
let key_changed = key_modified != last_key_modified;
if cert_changed || key_changed {
info!(
cert_changed = cert_changed,
key_changed = key_changed,
"Detected certificate file change, reloading"
);
match self.reload() {
Ok(()) => {
last_cert_modified = cert_modified;
last_key_modified = key_modified;
}
Err(e) => {
error!(error = %e, "Failed to reload certificates");
}
}
}
}
});
Ok(handle)
}
}
fn get_file_modified<P: AsRef<Path>>(path: P) -> Option<SystemTime> {
fs::metadata(path.as_ref())
.ok()
.and_then(|m| m.modified().ok())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tls_crypto::{
AES_SBOX, aes_cbc_decrypt, aes_key_expansion, base64_decode_pure, evp_bytes_to_key_md5,
gf_mul, hex_decode, md5, parse_asn1_length, remove_pkcs7_padding, sha1, sha256,
};
use std::env::temp_dir;
#[test]
fn test_self_signed_generator() {
let generator = SelfSignedGenerator::new("test.example.com")
.with_san("localhost")
.with_san("127.0.0.1")
.with_organization("Test Org")
.with_validity_days(30);
let result = generator.generate();
assert!(result.is_ok());
let (cert, key) = result.expect("Should generate certificate");
assert!(!cert.as_ref().is_empty());
let loader = CertificateLoader::new();
let info = loader
.get_certificate_info(&cert)
.expect("Should parse certificate");
assert_eq!(info.common_name.as_deref(), Some("test.example.com"));
assert!(info.is_valid());
}
#[test]
fn test_ca_certificate_generation() {
let ca_generator = SelfSignedGenerator::new("Test CA")
.as_ca()
.with_validity_days(365);
let (ca_cert, ca_key) = ca_generator.generate().expect("Should generate CA");
let loader = CertificateLoader::new();
let ca_info = loader
.get_certificate_info(&ca_cert)
.expect("Should parse CA certificate");
assert!(ca_info.is_ca);
assert_eq!(ca_info.common_name.as_deref(), Some("Test CA"));
}
#[test]
fn test_certificate_store() {
let mut store = CertificateStore::new();
let generator = SelfSignedGenerator::new("test").as_ca();
let (cert, _) = generator.generate().expect("Should generate certificate");
assert!(store.is_empty());
store.add_certificate(cert).expect("Should add certificate");
assert!(!store.is_empty());
assert_eq!(store.len(), 1);
}
#[test]
fn test_certificate_store_system_roots() {
let mut store = CertificateStore::new();
let added = store.add_system_roots().expect("Should add system roots");
assert!(added > 0);
assert!(!store.is_empty());
}
#[test]
fn test_certificate_info_validity() {
let generator = SelfSignedGenerator::new("test").with_validity_days(30);
let (cert, _) = generator.generate().expect("Should generate certificate");
let loader = CertificateLoader::new();
let info = loader.get_certificate_info(&cert).expect("Should get info");
assert!(info.is_valid());
assert!(!info.expires_within(Duration::from_secs(0)));
assert!(info.expires_within(Duration::from_secs(31 * 24 * 60 * 60)));
}
#[test]
fn test_hot_reloadable_certificates() {
let hot_certs = HotReloadableCertificates::new();
let generator = SelfSignedGenerator::new("test");
let (cert, key) = generator.generate().expect("Should generate certificate");
assert_eq!(hot_certs.get_version(), 0);
hot_certs.set_certificates(vec![cert], key);
assert_eq!(hot_certs.get_version(), 1);
assert!(!hot_certs.get_cert_chain().is_empty());
assert!(hot_certs.get_private_key().is_some());
}
#[test]
fn test_pem_certificate_loading() {
let generator = SelfSignedGenerator::new("test");
let (cert, _) = generator.generate().expect("Should generate certificate");
let pem_content = format!(
"-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----\n",
base64_encode(cert.as_ref())
);
let temp_path = temp_dir().join("test_cert.pem");
fs::write(&temp_path, &pem_content).expect("Should write temp file");
let loader = CertificateLoader::new();
let result = loader.load_pem_file(&temp_path);
let _ = fs::remove_file(&temp_path);
assert!(result.is_ok());
}
#[test]
fn test_der_certificate_loading() {
let generator = SelfSignedGenerator::new("test");
let (cert, _) = generator.generate().expect("Should generate certificate");
let temp_path = temp_dir().join("test_cert.der");
fs::write(&temp_path, cert.as_ref()).expect("Should write temp file");
let loader = CertificateLoader::new();
let result = loader.load_der_file(&temp_path);
let _ = fs::remove_file(&temp_path);
assert!(result.is_ok());
}
fn base64_encode(data: &[u8]) -> String {
const ALPHABET: &[u8; 64] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::new();
let mut i = 0;
while i < data.len() {
let b1 = data[i];
let b2 = data.get(i + 1).copied().unwrap_or(0);
let b3 = data.get(i + 2).copied().unwrap_or(0);
result.push(ALPHABET[(b1 >> 2) as usize] as char);
result.push(ALPHABET[(((b1 & 0x03) << 4) | (b2 >> 4)) as usize] as char);
if i + 1 < data.len() {
result.push(ALPHABET[(((b2 & 0x0f) << 2) | (b3 >> 6)) as usize] as char);
} else {
result.push('=');
}
if i + 2 < data.len() {
result.push(ALPHABET[(b3 & 0x3f) as usize] as char);
} else {
result.push('=');
}
i += 3;
}
let mut formatted = String::new();
for (idx, ch) in result.chars().enumerate() {
if idx > 0 && idx % 64 == 0 {
formatted.push('\n');
}
formatted.push(ch);
}
formatted
}
#[test]
fn test_detect_encrypted_pem_unencrypted() {
let pem = "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----\n";
assert_eq!(detect_encrypted_pem(pem), EncryptedPemFormat::NotEncrypted);
}
#[test]
fn test_detect_encrypted_pem_pkcs8() {
let pem =
"-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIIE...\n-----END ENCRYPTED PRIVATE KEY-----\n";
assert_eq!(
detect_encrypted_pem(pem),
EncryptedPemFormat::Pkcs8Encrypted
);
}
#[test]
fn test_detect_encrypted_pem_legacy() {
let pem = "-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\nDEK-Info: AES-256-CBC,AABB\n\nMIIE...\n-----END RSA PRIVATE KEY-----\n";
assert_eq!(
detect_encrypted_pem(pem),
EncryptedPemFormat::LegacyEncrypted
);
}
#[test]
fn test_load_unencrypted_passthrough() {
let generator = SelfSignedGenerator::new("test-passthrough");
let (_cert, key) = generator.generate().expect("Should generate certificate");
let key_der = match &key {
PrivateKeyDer::Pkcs8(k) => k.secret_pkcs8_der().to_vec(),
_ => panic!("Expected PKCS#8 key"),
};
let pem_content = format!(
"-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----\n",
base64_encode(&key_der)
);
let temp_path = temp_dir().join("test_unencrypted_passthrough.pem");
fs::write(&temp_path, &pem_content).expect("Should write temp file");
let loader = PrivateKeyLoader::new();
let result = loader.load_encrypted_pem_file(&temp_path, "any_password");
let _ = fs::remove_file(&temp_path);
assert!(
result.is_ok(),
"Unencrypted key should load with any password: {:?}",
result.err()
);
}
#[test]
fn test_encrypted_key_no_password() {
let pem = "-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIIFHDBOBgkqhkiG...\n-----END ENCRYPTED PRIVATE KEY-----\n";
let temp_path = temp_dir().join("test_no_password.pem");
fs::write(&temp_path, pem).expect("Should write temp file");
unsafe { std::env::remove_var("AMATERS_KEY_PASSWORD") };
let loader = PrivateKeyLoader::new();
let result = loader.load_encrypted_pem_file(&temp_path, "");
let _ = fs::remove_file(&temp_path);
assert!(result.is_err());
let err_msg = format!("{}", result.expect_err("Should be an error"));
assert!(
err_msg.contains("no password"),
"Error should mention no password: {err_msg}"
);
}
#[test]
fn test_encrypted_key_empty_password_triggers_env_check() {
unsafe { std::env::remove_var("AMATERS_KEY_PASSWORD") };
let result = resolve_password("");
assert!(result.is_err());
unsafe { std::env::set_var("AMATERS_KEY_PASSWORD", "test_env_pw") };
let result = resolve_password("");
assert!(result.is_ok());
assert_eq!(result.expect("Should succeed"), "test_env_pw");
unsafe { std::env::remove_var("AMATERS_KEY_PASSWORD") };
}
#[test]
fn test_encrypted_key_env_variable() {
unsafe { std::env::set_var("AMATERS_KEY_PASSWORD", "env_password_123") };
let result = resolve_password("");
assert!(result.is_ok());
assert_eq!(result.expect("Should resolve from env"), "env_password_123");
let result = resolve_password("direct_pw");
assert!(result.is_ok());
assert_eq!(result.expect("Should use direct pw"), "direct_pw");
unsafe { std::env::remove_var("AMATERS_KEY_PASSWORD") };
}
#[test]
fn test_parse_dek_info_header_aes256() {
let pem = "-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\nDEK-Info: AES-256-CBC,AABBCCDD11223344AABBCCDD11223344\n\nbase64data\n-----END RSA PRIVATE KEY-----\n";
let result = parse_dek_info(pem);
assert!(result.is_ok(), "Should parse DEK-Info: {:?}", result.err());
let (algo, iv) = result.expect("Should succeed");
assert_eq!(algo, "AES-256-CBC");
assert_eq!(iv.len(), 16);
assert_eq!(iv[0], 0xAA);
assert_eq!(iv[1], 0xBB);
}
#[test]
fn test_parse_dek_info_header_aes128() {
let pem = "-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\nDEK-Info: AES-128-CBC,00112233445566778899AABBCCDDEEFF\n\nbase64data\n-----END RSA PRIVATE KEY-----\n";
let result = parse_dek_info(pem);
assert!(result.is_ok());
let (algo, iv) = result.expect("Should succeed");
assert_eq!(algo, "AES-128-CBC");
assert_eq!(iv.len(), 16);
}
#[test]
fn test_parse_dek_info_missing() {
let pem = "-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\n\nbase64data\n-----END RSA PRIVATE KEY-----\n";
let result = parse_dek_info(pem);
assert!(result.is_err());
}
#[test]
fn test_parse_encrypted_pkcs8_format() {
let pem =
"-----BEGIN ENCRYPTED PRIVATE KEY-----\ndata\n-----END ENCRYPTED PRIVATE KEY-----\n";
assert_eq!(
detect_encrypted_pem(pem),
EncryptedPemFormat::Pkcs8Encrypted
);
}
#[test]
fn test_legacy_encrypted_format_detection() {
let pem = "-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\nDEK-Info: AES-256-CBC,0011223344556677\n\nSomeBase64Data\n-----END RSA PRIVATE KEY-----\n";
assert_eq!(
detect_encrypted_pem(pem),
EncryptedPemFormat::LegacyEncrypted
);
}
#[test]
fn test_key_derivation_pbkdf2_sha256() {
let password = b"password";
let salt = b"salt";
let iterations = 1;
let key_len = 32;
let derived = pbkdf2_hmac_sha256(password, salt, iterations, key_len);
assert_eq!(derived.len(), key_len);
let expected: [u8; 32] = [
0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4,
0xf8, 0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48, 0x08, 0x05, 0x98, 0x7c,
0xb7, 0x0b, 0xe1, 0x7b,
];
assert_eq!(derived, expected, "PBKDF2-HMAC-SHA256 test vector mismatch");
}
#[test]
fn test_key_derivation_deterministic() {
let password = b"my_secret";
let salt = b"random_salt_12345678";
let iterations = 100;
let key1 = pbkdf2_hmac_sha256(password, salt, iterations, 32);
let key2 = pbkdf2_hmac_sha256(password, salt, iterations, 32);
assert_eq!(key1, key2, "Same inputs should produce same derived key");
let key3 = pbkdf2_hmac_sha256(b"different", salt, iterations, 32);
assert_ne!(
key1, key3,
"Different passwords should produce different keys"
);
}
#[test]
fn test_sha256_known_vectors() {
let empty_hash = sha256(b"");
assert_eq!(empty_hash[0], 0xe3);
assert_eq!(empty_hash[1], 0xb0);
assert_eq!(empty_hash[31], 0x55);
let abc_hash = sha256(b"abc");
assert_eq!(abc_hash[0], 0xba);
assert_eq!(abc_hash[1], 0x78);
assert_eq!(abc_hash[31], 0xad);
}
#[test]
fn test_sha1_known_vectors() {
let empty_hash = sha1(b"");
assert_eq!(empty_hash[0], 0xda);
assert_eq!(empty_hash[1], 0x39);
assert_eq!(empty_hash[19], 0x09);
let abc_hash = sha1(b"abc");
assert_eq!(abc_hash[0], 0xa9);
assert_eq!(abc_hash[1], 0x99);
assert_eq!(abc_hash[19], 0x9d);
}
#[test]
fn test_md5_known_vectors() {
let empty_hash = md5(b"");
assert_eq!(empty_hash[0], 0xd4);
assert_eq!(empty_hash[1], 0x1d);
assert_eq!(empty_hash[15], 0x7e);
let abc_hash = md5(b"abc");
assert_eq!(abc_hash[0], 0x90);
assert_eq!(abc_hash[1], 0x01);
assert_eq!(abc_hash[15], 0x72);
}
#[test]
fn test_aes_cbc_roundtrip() {
let key = [0x00u8; 32]; let iv = [0x00u8; 16];
let mut plaintext = vec![0x41u8; 16]; plaintext.extend_from_slice(&[0x10u8; 16]);
let round_keys = aes_key_expansion(&key).expect("Key expansion should work");
let mut ciphertext = Vec::new();
let mut prev_block = iv;
for chunk in plaintext.chunks_exact(16) {
let mut block = [0u8; 16];
for i in 0..16 {
block[i] = chunk[i] ^ prev_block[i];
}
let encrypted = aes_encrypt_block_for_test(&block, &round_keys);
ciphertext.extend_from_slice(&encrypted);
prev_block = encrypted;
}
let decrypted = aes_cbc_decrypt(&ciphertext, &key, &iv).expect("Decryption should work");
let unpadded = remove_pkcs7_padding(&decrypted).expect("Padding removal should work");
assert_eq!(unpadded, &[0x41u8; 16]);
}
fn aes_encrypt_block_for_test(block: &[u8; 16], round_keys: &[[u8; 4]]) -> [u8; 16] {
let nr = round_keys.len() / 4 - 1;
let mut state = [[0u8; 4]; 4];
for c in 0..4 {
for r in 0..4 {
state[r][c] = block[c * 4 + r];
}
}
for c in 0..4 {
for r in 0..4 {
state[r][c] ^= round_keys[c][r];
}
}
for round in 1..nr {
for row in state.iter_mut() {
for val in row.iter_mut() {
*val = AES_SBOX[*val as usize];
}
}
shift_rows_for_test(&mut state);
mix_columns_for_test(&mut state);
for c in 0..4 {
for r in 0..4 {
state[r][c] ^= round_keys[round * 4 + c][r];
}
}
}
for row in state.iter_mut() {
for val in row.iter_mut() {
*val = AES_SBOX[*val as usize];
}
}
shift_rows_for_test(&mut state);
for c in 0..4 {
for r in 0..4 {
state[r][c] ^= round_keys[nr * 4 + c][r];
}
}
let mut output = [0u8; 16];
for c in 0..4 {
for r in 0..4 {
output[c * 4 + r] = state[r][c];
}
}
output
}
fn shift_rows_for_test(state: &mut [[u8; 4]; 4]) {
let tmp = state[1][0];
state[1][0] = state[1][1];
state[1][1] = state[1][2];
state[1][2] = state[1][3];
state[1][3] = tmp;
let (t0, t1) = (state[2][0], state[2][1]);
state[2][0] = state[2][2];
state[2][1] = state[2][3];
state[2][2] = t0;
state[2][3] = t1;
let tmp = state[3][3];
state[3][3] = state[3][2];
state[3][2] = state[3][1];
state[3][1] = state[3][0];
state[3][0] = tmp;
}
#[allow(clippy::needless_range_loop)]
fn mix_columns_for_test(state: &mut [[u8; 4]; 4]) {
for c in 0..4 {
let s0 = state[0][c];
let s1 = state[1][c];
let s2 = state[2][c];
let s3 = state[3][c];
state[0][c] = gf_mul(s0, 2) ^ gf_mul(s1, 3) ^ s2 ^ s3;
state[1][c] = s0 ^ gf_mul(s1, 2) ^ gf_mul(s2, 3) ^ s3;
state[2][c] = s0 ^ s1 ^ gf_mul(s2, 2) ^ gf_mul(s3, 3);
state[3][c] = gf_mul(s0, 3) ^ s1 ^ s2 ^ gf_mul(s3, 2);
}
}
#[test]
fn test_hex_decode_valid() {
let result = hex_decode("AABBCCDD");
assert!(result.is_ok());
assert_eq!(result.expect("hex"), vec![0xAA, 0xBB, 0xCC, 0xDD]);
}
#[test]
fn test_hex_decode_invalid() {
assert!(hex_decode("GG").is_err());
assert!(hex_decode("A").is_err()); }
#[test]
fn test_base64_decode_roundtrip() {
let original = b"Hello, World!";
let encoded = base64_encode(original);
let decoded = base64_decode_pure(&encoded).expect("Should decode");
assert_eq!(decoded, original);
}
#[test]
fn test_pkcs7_padding_removal() {
let mut data = vec![0x41; 12];
data.extend_from_slice(&[0x04, 0x04, 0x04, 0x04]);
let result = remove_pkcs7_padding(&data);
assert!(result.is_ok());
assert_eq!(result.expect("unpadded").len(), 12);
let bad_data = vec![0x41; 16];
let result = remove_pkcs7_padding(&bad_data);
assert!(result.is_err());
}
#[test]
fn test_encrypted_key_wrong_password() {
let fake_encrypted_data: Vec<u8> = [0xDE, 0xAD, 0xBE, 0xEF]
.iter()
.copied()
.cycle()
.take(64)
.collect();
let encoded = base64_encode(&fake_encrypted_data);
let pem = format!(
"-----BEGIN ENCRYPTED PRIVATE KEY-----\n{encoded}\n-----END ENCRYPTED PRIVATE KEY-----\n"
);
let temp_path = temp_dir().join("test_wrong_password.pem");
fs::write(&temp_path, &pem).expect("Should write temp file");
let loader = PrivateKeyLoader::new();
let result = loader.load_encrypted_pem_file(&temp_path, "wrong_password");
let _ = fs::remove_file(&temp_path);
assert!(result.is_err());
}
#[test]
fn test_evp_bytes_to_key_deterministic() {
let password = b"test_password";
let salt = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
let key1 = evp_bytes_to_key_md5(password, &salt, 32);
let key2 = evp_bytes_to_key_md5(password, &salt, 32);
assert_eq!(key1, key2, "Same inputs should produce same key");
assert_eq!(key1.len(), 32);
let key3 = evp_bytes_to_key_md5(b"different", &salt, 32);
assert_ne!(key1, key3);
}
#[test]
fn test_load_encrypted_pem_roundtrip() {
let generator = SelfSignedGenerator::new("roundtrip-test");
let (_cert, key) = generator.generate().expect("Should generate certificate");
let key_der = match &key {
PrivateKeyDer::Pkcs8(k) => k.secret_pkcs8_der().to_vec(),
_ => panic!("Expected PKCS#8 key"),
};
let password = b"test_roundtrip_pw";
let iv = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
0x0F, 0x10,
];
let aes_key = evp_bytes_to_key_md5(password, &iv[..8], 32);
let pad_len = 16 - (key_der.len() % 16);
let mut padded = key_der.clone();
for _ in 0..pad_len {
padded.push(pad_len as u8);
}
let round_keys = aes_key_expansion(&aes_key).expect("Key expansion should work");
let mut ciphertext = Vec::new();
let mut prev_block = iv;
for chunk in padded.chunks_exact(16) {
let mut block = [0u8; 16];
for i in 0..16 {
block[i] = chunk[i] ^ prev_block[i];
}
let encrypted = aes_encrypt_block_for_test(&block, &round_keys);
ciphertext.extend_from_slice(&encrypted);
prev_block = encrypted;
}
let iv_hex: String = iv.iter().map(|b| format!("{b:02X}")).collect();
let b64_body = base64_encode(&ciphertext);
let pem = format!(
"-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\nDEK-Info: AES-256-CBC,{iv_hex}\n\n{b64_body}\n-----END RSA PRIVATE KEY-----\n"
);
let temp_path = temp_dir().join("test_encrypted_roundtrip.pem");
fs::write(&temp_path, &pem).expect("Should write temp file");
let loader = PrivateKeyLoader::new();
let result = loader.load_encrypted_pem_file(
&temp_path,
std::str::from_utf8(password).expect("valid utf8"),
);
let _ = fs::remove_file(&temp_path);
assert!(
result.is_ok(),
"Encrypted PEM roundtrip should succeed: {:?}",
result.err()
);
let loaded_key = result.expect("Should succeed");
match &loaded_key {
PrivateKeyDer::Pkcs8(k) => {
assert_eq!(
k.secret_pkcs8_der(),
key_der.as_slice(),
"Decrypted key should match original"
);
}
PrivateKeyDer::Pkcs1(k) => {
assert!(
!k.secret_pkcs1_der().is_empty(),
"Decrypted key should not be empty"
);
}
_ => panic!("Unexpected key type"),
}
}
#[test]
fn test_asn1_length_parsing() {
let data = [0x05]; let (len, consumed) = parse_asn1_length(&data).expect("Should parse");
assert_eq!(len, 5);
assert_eq!(consumed, 1);
let data = [0x81, 0x80]; let (len, consumed) = parse_asn1_length(&data).expect("Should parse");
assert_eq!(len, 128);
assert_eq!(consumed, 2);
let data = [0x82, 0x01, 0x00]; let (len, consumed) = parse_asn1_length(&data).expect("Should parse");
assert_eq!(len, 256);
assert_eq!(consumed, 3);
}
#[test]
fn test_pbkdf2_hmac_sha1_basic() {
let derived = pbkdf2_hmac_sha1(b"password", b"salt", 1, 20);
assert_eq!(derived.len(), 20);
let expected: [u8; 20] = [
0x0c, 0x60, 0xc8, 0x0f, 0x96, 0x1f, 0x0e, 0x71, 0xf3, 0xa9, 0xb5, 0x24, 0xaf, 0x60,
0x12, 0x06, 0x2f, 0xe0, 0x37, 0xa6,
];
assert_eq!(derived, expected, "PBKDF2-HMAC-SHA1 test vector mismatch");
}
}