use crate::errors::{AuthError, Result};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use uuid::Uuid;
use x509_parser::parse_x509_certificate;
#[derive(Debug, Clone)]
pub struct X509CertificateManager {
config: X509Config,
certificate_store: Arc<RwLock<HashMap<String, StoredCertificate>>>,
revocation_list: Arc<RwLock<HashMap<String, RevocationEntry>>>,
ca_certificates: Arc<RwLock<HashMap<String, CACertificate>>>,
}
#[derive(Debug, Clone)]
pub struct X509Config {
pub default_validity_days: i64,
pub root_ca_cert_path: String,
pub root_ca_path: String,
pub root_ca_key_path: String,
pub intermediate_ca_cert_path: Option<String>,
pub intermediate_ca_path: Option<String>,
pub intermediate_ca_key_path: Option<String>,
pub default_rsa_key_size: u32,
pub default_ecdsa_curve: EcdsaCurve,
pub certificate_profiles: HashMap<String, CertificateProfile>,
pub enable_ocsp: bool,
pub ocsp_responder_url: Option<String>,
pub enable_crl: bool,
pub crl_distribution_url: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum EcdsaCurve {
P256,
P384,
P521,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateProfile {
pub name: String,
pub cert_type: CertificateType,
pub key_usage: Vec<KeyUsage>,
pub extended_key_usage: Vec<ExtendedKeyUsage>,
pub subject_alt_names: Vec<SubjectAltName>,
pub validity_days: i64,
pub preferred_key_type: KeyType,
pub extensions: HashMap<String, Value>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CertificateType {
RootCA,
IntermediateCA,
EndEntity,
CodeSigning,
Email,
TlsServer,
TlsClient,
DocumentSigning,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum KeyUsage {
DigitalSignature,
NonRepudiation,
KeyEncipherment,
DataEncipherment,
KeyAgreement,
KeyCertSign,
CrlSign,
EncipherOnly,
DecipherOnly,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExtendedKeyUsage {
ServerAuth,
ClientAuth,
CodeSigning,
EmailProtection,
TimeStamping,
OcspSigning,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SubjectAltName {
DnsName(String),
Email(String),
Uri(String),
IpAddress(String),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum KeyType {
Rsa(u32), Ecdsa(EcdsaCurve),
Ed25519,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCertificate {
pub cert_id: String,
pub certificate_pem: String,
pub private_key_pem: Option<String>,
pub subject: String,
pub issuer: String,
pub serial_number: String,
pub not_before: DateTime<Utc>,
pub not_after: DateTime<Utc>,
pub profile: String,
pub status: CertificateStatus,
pub fingerprint: String,
pub created_at: DateTime<Utc>,
pub metadata: HashMap<String, Value>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CertificateStatus {
Valid,
Expired,
Revoked,
Suspended,
}
#[derive(Debug, Clone)]
pub struct CACertificate {
pub ca_id: String,
pub certificate: StoredCertificate,
pub subject: String,
pub private_key: Vec<u8>,
pub ca_type: CAType,
pub issued_count: u64,
pub next_serial: u64,
}
#[derive(Debug, Clone, PartialEq)]
pub enum CAType {
Root,
Intermediate,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RevocationEntry {
pub serial_number: String,
pub revocation_date: DateTime<Utc>,
pub reason: RevocationReason,
pub additional_info: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RevocationReason {
Unspecified,
KeyCompromise,
CaCompromise,
AffiliationChanged,
Superseded,
CessationOfOperation,
CertificateHold,
RemoveFromCrl,
PrivilegeWithdrawn,
AaCompromise,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateRequest {
pub request_id: String,
pub subject: CertificateSubject,
pub profile: String,
pub public_key_pem: String,
pub subject_alt_names: Vec<SubjectAltName>,
pub requested_at: DateTime<Utc>,
pub requester: String,
pub attributes: HashMap<String, Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateSubject {
pub common_name: String,
pub organization: Option<String>,
pub organizational_unit: Option<String>,
pub country: Option<String>,
pub state: Option<String>,
pub locality: Option<String>,
pub email: Option<String>,
}
impl X509CertificateManager {
pub fn new(config: X509Config) -> Self {
Self {
config,
certificate_store: Arc::new(RwLock::new(HashMap::new())),
revocation_list: Arc::new(RwLock::new(HashMap::new())),
ca_certificates: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn initialize(&self) -> Result<()> {
self.load_root_ca().await?;
if self.config.intermediate_ca_cert_path.is_some() {
self.load_intermediate_ca().await?;
}
Ok(())
}
async fn load_root_ca(&self) -> Result<()> {
#[cfg(feature = "hsm")]
if let Ok(hsm_config) = std::env::var("X509_HSM_CONFIG") {
tracing::info!("Loading CA certificate from HSM: {}", hsm_config);
return self.load_ca_from_hsm(&hsm_config).await;
}
#[cfg(not(feature = "hsm"))]
if std::env::var("X509_HSM_CONFIG").is_ok() {
tracing::warn!(
"X509_HSM_CONFIG is set but the 'hsm' feature is not enabled — ignoring"
);
}
if let Ok(vault_url) = std::env::var("X509_AZURE_VAULT_URL")
&& let Ok(cert_name) = std::env::var("X509_AZURE_CERT_NAME")
{
tracing::info!("Loading CA certificate from Azure Key Vault: {}", vault_url);
return self.load_ca_from_azure_vault(&vault_url, &cert_name).await;
}
if let Ok(secret_id) = std::env::var("X509_AWS_SECRET_ID") {
tracing::info!(
"Loading CA certificate from AWS Secrets Manager: {}",
secret_id
);
return self.load_ca_from_aws_secrets(&secret_id).await;
}
let ca_cert_path = if self.config.root_ca_path.is_empty() {
"ca/root-ca.pem"
} else {
&self.config.root_ca_path
};
tracing::warn!(
"Loading CA certificate from file system - consider using HSM or secure vault for production"
);
self.load_ca_from_file(ca_cert_path).await
}
#[cfg(feature = "hsm")]
async fn load_ca_from_hsm(&self, hsm_config: &str) -> Result<()> {
let config: serde_json::Value = serde_json::from_str(hsm_config)
.map_err(|e| AuthError::config(format!("Invalid HSM JSON config: {}", e)))?;
let library = config["library"]
.as_str()
.ok_or_else(|| AuthError::config("HSM config missing 'library' path".to_string()))?;
let slot_id = config["slot"].as_u64().unwrap_or(0);
let pin = config["pin"].as_str().map(|s| s.to_string());
let _label = config["label"].as_str().unwrap_or("root-ca").to_string();
let library_path = library.to_string();
let handle = tokio::task::spawn_blocking(move || -> Result<Vec<u8>> {
let pkcs11 = cryptoki::context::Pkcs11::new(&library_path)
.map_err(|e| AuthError::config(format!("Failed to load PKCS#11 library: {}", e)))?;
pkcs11
.initialize(cryptoki::context::CInitializeArgs::new(
cryptoki::context::CInitializeFlags::OS_LOCKING_OK,
))
.map_err(|e| {
AuthError::config(format!("Failed to initialize PKCS#11 context: {}", e))
})?;
let slots = pkcs11
.get_slots_with_token()
.map_err(|e| AuthError::config(format!("Failed to get PKCS#11 slots: {}", e)))?;
if slot_id as usize >= slots.len() {
return Err(AuthError::config(format!(
"HSM slot {} not found or has no token",
slot_id
)));
}
let slot = slots[slot_id as usize];
let session = pkcs11
.open_ro_session(slot)
.map_err(|e| AuthError::config(format!("Failed to open PKCS#11 session: {}", e)))?;
if let Some(p) = pin {
let auth_pin = cryptoki::types::AuthPin::new(p.into());
session
.login(cryptoki::session::UserType::User, Some(&auth_pin))
.map_err(|e| AuthError::config(format!("HSM login failed: {}", e)))?;
}
let mut search_template: Vec<cryptoki::object::Attribute> = Vec::new();
search_template.push(cryptoki::object::Attribute::Class(
cryptoki::object::ObjectClass::CERTIFICATE,
));
search_template.push(cryptoki::object::Attribute::Label(
_label.clone().into_bytes(),
));
let objects = session.find_objects(&search_template).map_err(|e| {
AuthError::config(format!("Failed to search PKCS#11 objects: {}", e))
})?;
if objects.is_empty() {
return Err(AuthError::config(format!(
"Certificate with label '{}' not found in HSM",
_label
)));
}
let cert_obj = objects[0];
let attrs = session
.get_attributes(cert_obj, &[cryptoki::object::AttributeType::Value])
.map_err(|e| {
AuthError::config(format!("Failed to get certificate value from HSM: {}", e))
})?;
if attrs.is_empty() {
return Err(AuthError::config(
"Certificate object has no value attribute".to_string(),
));
}
let value = match &attrs[0] {
cryptoki::object::Attribute::Value(v) => v.clone(),
_ => {
return Err(AuthError::config(
"Invalid value attribute type".to_string(),
));
}
};
Ok(value)
});
let cert_der = handle
.await
.map_err(|_| AuthError::config("HSM task panicked".to_string()))??;
let cert_pem = format!(
"-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
BASE64_STANDARD.encode(&cert_der)
);
self.store_ca_certificate_from_pem(&cert_pem, &format!("hsm:slot{}", slot_id))
.await
}
async fn load_ca_from_azure_vault(&self, vault_url: &str, cert_name: &str) -> Result<()> {
let tenant_id = std::env::var("X509_AZURE_TENANT_ID").map_err(|_| {
AuthError::config(
"X509_AZURE_TENANT_ID environment variable required for Azure Key Vault authentication"
.to_string(),
)
})?;
let client_id = std::env::var("X509_AZURE_CLIENT_ID").map_err(|_| {
AuthError::config(
"X509_AZURE_CLIENT_ID environment variable required for Azure Key Vault authentication"
.to_string(),
)
})?;
let client_secret = std::env::var("X509_AZURE_CLIENT_SECRET").map_err(|_| {
AuthError::config(
"X509_AZURE_CLIENT_SECRET environment variable required for Azure Key Vault authentication"
.to_string(),
)
})?;
let http = reqwest::Client::new();
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
tenant_id
);
let token_resp = http
.post(&token_url)
.form(&[
("grant_type", "client_credentials"),
("client_id", client_id.as_str()),
("client_secret", client_secret.as_str()),
("scope", "https://vault.azure.net/.default"),
])
.send()
.await
.map_err(|e| AuthError::internal(format!("Azure AD token request failed: {}", e)))?;
if !token_resp.status().is_success() {
let status = token_resp.status();
let body = token_resp.text().await.unwrap_or_default();
return Err(AuthError::config(format!(
"Azure AD token request returned {}: {}",
status, body
)));
}
let token_json: serde_json::Value = token_resp.json().await.map_err(|e| {
AuthError::internal(format!("Failed to parse Azure AD token response: {}", e))
})?;
let access_token = token_json["access_token"]
.as_str()
.ok_or_else(|| AuthError::internal("Azure AD response missing 'access_token'"))?
.to_string();
let vault_base = vault_url.trim_end_matches('/');
let secret_url = format!("{}/secrets/{}?api-version=7.4", vault_base, cert_name);
let cert_resp = http
.get(&secret_url)
.header("Authorization", format!("Bearer {}", access_token))
.send()
.await
.map_err(|e| AuthError::internal(format!("Azure Key Vault request failed: {}", e)))?;
if !cert_resp.status().is_success() {
let status = cert_resp.status();
let body = cert_resp.text().await.unwrap_or_default();
return Err(AuthError::config(format!(
"Azure Key Vault secret fetch returned {}: {}",
status, body
)));
}
let cert_json: serde_json::Value = cert_resp.json().await.map_err(|e| {
AuthError::internal(format!("Failed to parse Azure Key Vault response: {}", e))
})?;
let raw_value = cert_json["value"]
.as_str()
.ok_or_else(|| AuthError::internal("Azure Key Vault response missing 'value' field"))?
.to_string();
let content_type = cert_json["contentType"]
.as_str()
.unwrap_or("application/x-pem-file");
let cert_pem = if content_type == "application/x-pem-file"
|| raw_value.contains("-----BEGIN")
{
x509_extract_certificate_pem(&raw_value)
} else {
return Err(AuthError::config(format!(
"Azure Key Vault certificate '{}' uses content-type '{}'. \
Store the certificate as a PEM secret (application/x-pem-file) for automatic import.",
cert_name, content_type
)));
};
tracing::info!(
"Successfully loaded CA certificate from Azure Key Vault: {}/{}",
vault_base,
cert_name
);
self.store_ca_certificate_from_pem(
&cert_pem,
&format!("azure_kv:{}/{}", vault_base, cert_name),
)
.await
}
async fn load_ca_from_aws_secrets(&self, secret_id: &str) -> Result<()> {
let access_key = std::env::var("AWS_ACCESS_KEY_ID").map_err(|_| {
AuthError::config(
"AWS_ACCESS_KEY_ID environment variable required for Secrets Manager".to_string(),
)
})?;
let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY").map_err(|_| {
AuthError::config(
"AWS_SECRET_ACCESS_KEY environment variable required for Secrets Manager"
.to_string(),
)
})?;
let region = std::env::var("AWS_REGION")
.or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
.map_err(|_| {
AuthError::config(
"AWS_REGION (or AWS_DEFAULT_REGION) environment variable required for Secrets Manager"
.to_string(),
)
})?;
let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
let service = "secretsmanager";
let host = format!("{}.{}.amazonaws.com", service, region);
let payload =
serde_json::to_vec(&serde_json::json!({ "SecretId": secret_id })).map_err(|e| {
AuthError::internal(format!(
"Failed to serialise Secrets Manager GetSecretValue request: {}",
e
))
})?;
let now = chrono::Utc::now();
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
let date_stamp = now.format("%Y%m%d").to_string();
let authorization = AwsSigV4Request::new(&access_key, &secret_key)
.session_token(session_token.as_deref())
.region(®ion)
.service(service)
.method("POST")
.host(&host)
.payload(&payload)
.amz_date(&amz_date)
.date_stamp(&date_stamp)
.amz_target("secretsmanager.GetSecretValue")
.sign();
let url = format!("https://{}/", host);
let http = reqwest::Client::new();
let mut req_builder = http
.post(&url)
.header("Content-Type", "application/x-amz-json-1.1")
.header("X-Amz-Target", "secretsmanager.GetSecretValue")
.header("X-Amz-Date", &amz_date)
.header("Authorization", &authorization)
.body(payload);
if let Some(ref token) = session_token {
req_builder = req_builder.header("X-Amz-Security-Token", token.as_str());
}
let resp = req_builder.send().await.map_err(|e| {
AuthError::internal(format!("AWS Secrets Manager request failed: {}", e))
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AuthError::config(format!(
"AWS Secrets Manager GetSecretValue returned {}: {}",
status, body
)));
}
let json: serde_json::Value = resp.json().await.map_err(|e| {
AuthError::internal(format!("Failed to parse Secrets Manager response: {}", e))
})?;
let raw_value = if let Some(s) = json["SecretString"].as_str() {
s.to_string()
} else if let Some(b64) = json["SecretBinary"].as_str() {
let bytes = BASE64_STANDARD.decode(b64).map_err(|e| {
AuthError::internal(format!("Failed to decode SecretBinary: {}", e))
})?;
String::from_utf8(bytes).map_err(|e| {
AuthError::internal(format!("SecretBinary is not valid UTF-8: {}", e))
})?
} else {
return Err(AuthError::config(format!(
"AWS Secrets Manager secret '{}' contains neither SecretString nor SecretBinary",
secret_id
)));
};
let cert_pem = if raw_value.contains("-----BEGIN CERTIFICATE-----") {
x509_extract_certificate_pem(&raw_value)
} else {
raw_value
};
tracing::info!(
"Successfully loaded CA certificate from AWS Secrets Manager: {}",
secret_id
);
self.store_ca_certificate_from_pem(&cert_pem, &format!("aws_secrets:{}", secret_id))
.await
}
async fn load_ca_from_file(&self, ca_cert_path: &str) -> Result<()> {
let (certificate_pem, subject, issuer, serial_number) = if std::path::Path::new(
ca_cert_path,
)
.exists()
{
let cert_content = tokio::fs::read_to_string(ca_cert_path).await.map_err(|e| {
AuthError::internal(format!("Failed to read CA certificate: {}", e))
})?;
let path = std::path::Path::new(ca_cert_path);
let subject = format!(
"CN=Loaded from {}",
path.file_name()
.map(|n| n.to_string_lossy())
.unwrap_or_else(|| path.to_string_lossy())
);
let issuer = subject.clone(); let serial_number = format!(
"{:x}",
cert_content
.bytes()
.fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64))
);
(cert_content, subject, issuer, serial_number)
} else {
tracing::warn!(
"Root CA certificate not found at {}, generating self-signed root CA for development",
ca_cert_path
);
let (root_cert, root_key) = self.generate_self_signed_root_ca().await?;
let subject = "CN=AuthFramework Dev Root CA,O=Auth Framework,C=US".to_string();
if let Err(e) = tokio::fs::write(&ca_cert_path, &root_cert).await {
tracing::warn!("Failed to save generated root CA: {}", e);
}
let ca_dir = std::path::Path::new(&self.config.root_ca_cert_path)
.parent()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| ".".to_string());
let ca_key_path = format!("{}/ca.key", ca_dir);
if let Err(e) = tokio::fs::write(&ca_key_path, &root_key).await {
tracing::warn!("Failed to save generated root CA key: {}", e);
}
(root_cert, subject.clone(), subject, "1".to_string())
};
let ca_cert = StoredCertificate {
cert_id: "root_ca".to_string(),
certificate_pem: certificate_pem.clone(),
private_key_pem: None, subject: subject.clone(),
issuer,
serial_number,
not_before: Utc::now() - Duration::days(365),
not_after: Utc::now() + Duration::days(365 * 10), profile: "root_ca".to_string(),
status: CertificateStatus::Valid,
fingerprint: self.calculate_certificate_fingerprint(&certificate_pem)?,
created_at: Utc::now(),
metadata: HashMap::new(),
};
let ca = CACertificate {
ca_id: "root_ca".to_string(),
certificate: ca_cert,
subject: subject.clone(),
private_key: vec![], ca_type: CAType::Root,
issued_count: 0,
next_serial: 1000, };
let mut cas = self.ca_certificates.write().await;
cas.insert("root_ca".to_string(), ca);
Ok(())
}
async fn load_intermediate_ca(&self) -> Result<()> {
let intermediate_ca_path = self
.config
.intermediate_ca_path
.as_deref()
.unwrap_or("ca/intermediate-ca.pem");
if std::path::Path::new(intermediate_ca_path).exists() {
let cert_content = tokio::fs::read_to_string(intermediate_ca_path)
.await
.map_err(|e| {
AuthError::internal(format!("Failed to read intermediate CA: {}", e))
})?;
let intermediate_cert = StoredCertificate {
cert_id: "intermediate_ca".to_string(),
certificate_pem: cert_content.clone(),
private_key_pem: None,
subject: "CN=AuthFramework Intermediate CA, O=AuthFramework, C=US".to_string(),
issuer: "CN=AuthFramework Root CA, O=AuthFramework, C=US".to_string(),
serial_number: "2".to_string(),
not_before: Utc::now() - Duration::days(30),
not_after: Utc::now() + Duration::days(365 * 5), profile: "intermediate_ca".to_string(),
status: CertificateStatus::Valid,
fingerprint: self.calculate_fingerprint(&cert_content).await?,
created_at: Utc::now(),
metadata: HashMap::new(),
};
let intermediate_ca = CACertificate {
ca_id: "intermediate_ca".to_string(),
certificate: intermediate_cert,
subject: "CN=AuthFramework Intermediate CA".to_string(), private_key: vec![], ca_type: CAType::Intermediate,
issued_count: 0,
next_serial: 1,
};
let mut cas = self.ca_certificates.write().await;
cas.insert("intermediate_ca".to_string(), intermediate_ca);
tracing::info!("Loaded intermediate CA certificate");
} else {
tracing::info!("No intermediate CA certificate found, using root CA only");
}
Ok(())
}
pub async fn sign_certificate_request(
&self,
request: &CertificateRequest,
ca_id: &str,
) -> Result<StoredCertificate> {
let ca = {
let cas = self.ca_certificates.read().await;
cas.get(ca_id)
.ok_or_else(|| AuthError::InvalidRequest(format!("CA not found: {}", ca_id)))?
.clone()
};
let profile = self
.config
.certificate_profiles
.get(&request.profile)
.ok_or_else(|| {
AuthError::InvalidRequest(format!(
"Certificate profile not found: {}",
request.profile
))
})?;
let cert_id = Uuid::new_v4().to_string();
let serial_number = self.get_next_serial_number(ca_id).await?;
let certificate = StoredCertificate {
cert_id: cert_id.clone(),
certificate_pem: self
.generate_certificate_pem(request, profile, &serial_number)
.await?,
private_key_pem: None, subject: format!("CN={}", request.subject.common_name),
issuer: ca.certificate.subject.clone(),
serial_number: serial_number.clone(),
not_before: Utc::now(),
not_after: Utc::now() + Duration::days(profile.validity_days),
profile: request.profile.clone(),
status: CertificateStatus::Valid,
fingerprint: self.calculate_fingerprint(&request.public_key_pem).await?,
created_at: Utc::now(),
metadata: HashMap::new(),
};
let mut store = self.certificate_store.write().await;
store.insert(cert_id.clone(), certificate.clone());
self.increment_ca_issued_count(ca_id).await?;
Ok(certificate)
}
async fn generate_certificate_pem(
&self,
request: &CertificateRequest,
profile: &CertificateProfile,
serial_number: &str,
) -> Result<String> {
use rcgen::{
BasicConstraints, CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa,
KeyUsagePurpose, SanType, SerialNumber,
};
let mut params = CertificateParams::default();
params
.distinguished_name
.push(DnType::CommonName, &request.subject.common_name);
if let Some(ref org) = request.subject.organization {
params
.distinguished_name
.push(DnType::OrganizationName, org);
}
if let Some(ref ou) = request.subject.organizational_unit {
params
.distinguished_name
.push(DnType::OrganizationalUnitName, ou);
}
if let Some(ref country) = request.subject.country {
params.distinguished_name.push(DnType::CountryName, country);
}
if let Some(ref state) = request.subject.state {
params
.distinguished_name
.push(DnType::StateOrProvinceName, state);
}
if let Some(ref locality) = request.subject.locality {
params
.distinguished_name
.push(DnType::LocalityName, locality);
}
let serial_num: u64 = serial_number.parse().unwrap_or(1);
params.serial_number = Some(SerialNumber::from(serial_num.to_be_bytes().to_vec()));
params.not_before = time::OffsetDateTime::now_utc();
params.not_after =
time::OffsetDateTime::now_utc() + time::Duration::days(profile.validity_days);
params.key_usages = profile
.key_usage
.iter()
.filter_map(|ku| match ku {
KeyUsage::DigitalSignature => Some(KeyUsagePurpose::DigitalSignature),
KeyUsage::KeyEncipherment => Some(KeyUsagePurpose::KeyEncipherment),
KeyUsage::DataEncipherment => Some(KeyUsagePurpose::ContentCommitment),
KeyUsage::KeyAgreement => Some(KeyUsagePurpose::KeyAgreement),
KeyUsage::KeyCertSign => Some(KeyUsagePurpose::KeyCertSign),
KeyUsage::CrlSign => Some(KeyUsagePurpose::CrlSign),
_ => None,
})
.collect();
params.extended_key_usages = profile
.extended_key_usage
.iter()
.map(|eku| match eku {
ExtendedKeyUsage::ServerAuth => ExtendedKeyUsagePurpose::ServerAuth,
ExtendedKeyUsage::ClientAuth => ExtendedKeyUsagePurpose::ClientAuth,
ExtendedKeyUsage::CodeSigning => ExtendedKeyUsagePurpose::CodeSigning,
ExtendedKeyUsage::EmailProtection => ExtendedKeyUsagePurpose::EmailProtection,
ExtendedKeyUsage::TimeStamping => ExtendedKeyUsagePurpose::TimeStamping,
ExtendedKeyUsage::OcspSigning => ExtendedKeyUsagePurpose::OcspSigning,
})
.collect();
params.subject_alt_names = request
.subject_alt_names
.iter()
.filter_map(|san| match san {
SubjectAltName::DnsName(name) => {
Some(SanType::DnsName(name.clone().try_into().ok()?))
}
SubjectAltName::Email(email) => {
Some(SanType::Rfc822Name(email.clone().try_into().ok()?))
}
SubjectAltName::IpAddress(ip) => ip.parse().ok().map(SanType::IpAddress),
SubjectAltName::Uri(_) => None,
})
.collect();
match profile.cert_type {
CertificateType::RootCA | CertificateType::IntermediateCA => {
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
}
_ => {
params.is_ca = IsCa::NoCa;
}
}
let key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
.map_err(|e| AuthError::internal(format!("Key pair generation failed: {}", e)))?;
let ca_cert_pem = {
let cas = self.ca_certificates.read().await;
cas.get("root_ca")
.map(|ca| ca.certificate.certificate_pem.clone())
};
let cert = if let Some(_ca_pem) = ca_cert_pem {
params
.self_signed(&key_pair)
.map_err(|e| AuthError::internal(format!("Certificate signing failed: {}", e)))?
} else {
params
.self_signed(&key_pair)
.map_err(|e| AuthError::internal(format!("Certificate self-sign failed: {}", e)))?
};
Ok(cert.pem())
}
async fn get_next_serial_number(&self, ca_id: &str) -> Result<String> {
let mut cas = self.ca_certificates.write().await;
let ca = cas
.get_mut(ca_id)
.ok_or_else(|| AuthError::InvalidRequest(format!("CA not found: {}", ca_id)))?;
let serial = ca.next_serial;
ca.next_serial += 1;
Ok(serial.to_string())
}
async fn increment_ca_issued_count(&self, ca_id: &str) -> Result<()> {
let mut cas = self.ca_certificates.write().await;
let ca = cas
.get_mut(ca_id)
.ok_or_else(|| AuthError::InvalidRequest(format!("CA not found: {}", ca_id)))?;
ca.issued_count += 1;
Ok(())
}
async fn calculate_fingerprint(&self, certificate_pem: &str) -> Result<String> {
use sha2::{Digest, Sha256};
let cert_data = certificate_pem
.lines()
.filter(|line| !line.starts_with("-----"))
.collect::<Vec<&str>>()
.join("");
let cert_bytes = BASE64_STANDARD
.decode(&cert_data)
.map_err(|e| AuthError::internal(format!("Invalid certificate PEM: {}", e)))?;
let mut hasher = Sha256::new();
hasher.update(&cert_bytes);
let result = hasher.finalize();
let fingerprint = result
.iter()
.map(|byte| format!("{:02X}", byte))
.collect::<Vec<String>>()
.join(":");
tracing::debug!("Calculated certificate fingerprint: {}", fingerprint);
Ok(fingerprint)
}
pub async fn revoke_certificate(
&self,
serial_number: &str,
reason: RevocationReason,
additional_info: Option<String>,
) -> Result<()> {
let mut store = self.certificate_store.write().await;
for cert in store.values_mut() {
if cert.serial_number == serial_number {
cert.status = CertificateStatus::Revoked;
break;
}
}
let revocation_entry = RevocationEntry {
serial_number: serial_number.to_string(),
revocation_date: Utc::now(),
reason,
additional_info,
};
let mut revocation_list = self.revocation_list.write().await;
revocation_list.insert(serial_number.to_string(), revocation_entry);
Ok(())
}
pub async fn check_certificate_status(&self, serial_number: &str) -> Result<CertificateStatus> {
let revocation_list = self.revocation_list.read().await;
if revocation_list.contains_key(serial_number) {
return Ok(CertificateStatus::Revoked);
}
let store = self.certificate_store.read().await;
for cert in store.values() {
if cert.serial_number == serial_number {
if Utc::now() > cert.not_after {
return Ok(CertificateStatus::Expired);
}
return Ok(cert.status.clone());
}
}
Err(AuthError::InvalidRequest(
"Certificate not found".to_string(),
))
}
pub async fn get_certificate(&self, cert_id: &str) -> Result<Option<StoredCertificate>> {
let store = self.certificate_store.read().await;
Ok(store.get(cert_id).cloned())
}
pub async fn list_certificates(
&self,
filter: Option<CertificateFilter>,
) -> Result<Vec<StoredCertificate>> {
let store = self.certificate_store.read().await;
let mut certificates: Vec<StoredCertificate> = store.values().cloned().collect();
if let Some(f) = filter {
certificates.retain(|cert| f.matches(cert));
}
Ok(certificates)
}
pub async fn generate_crl(&self, ca_id: &str) -> Result<String> {
let revocation_list = self.revocation_list.read().await;
let cas = self.ca_certificates.read().await;
let ca = cas
.get(ca_id)
.ok_or_else(|| AuthError::InvalidRequest(format!("CA not found: {}", ca_id)))?;
let crl_number = revocation_list.len() as u64;
let this_update = Utc::now();
let next_update = this_update + Duration::days(7);
let mut crl_content = format!(
"Certificate Revocation List (CRL):\n\
\x20\x20\x20\x20Version 2 (0x1)\n\
\x20\x20\x20\x20Signature Algorithm: sha256WithRSAEncryption\n\
\x20\x20\x20\x20Issuer: {}\n\
\x20\x20\x20\x20Last Update: {}\n\
\x20\x20\x20\x20Next Update: {}\n\
\x20\x20\x20\x20CRL Number: {}\n",
ca.subject,
this_update.format("%b %d %H:%M:%S %Y GMT"),
next_update.format("%b %d %H:%M:%S %Y GMT"),
crl_number
);
if !revocation_list.is_empty() {
crl_content.push_str("Revoked Certificates:\n");
for entry in revocation_list.values() {
crl_content.push_str(&format!(
" Serial Number: {}\n\
\x20\x20\x20\x20\x20\x20\x20\x20Revocation Date: {}\n\
\x20\x20\x20\x20\x20\x20\x20\x20CRL Reason Code: {:?}\n",
entry.serial_number,
entry.revocation_date.format("%b %d %H:%M:%S %Y GMT"),
entry.reason
));
}
} else {
crl_content.push_str("No Revoked Certificates.\n");
}
let crl_b64 = BASE64_STANDARD.encode(crl_content.as_bytes());
let crl_pem = format!(
"-----BEGIN X509 CRL-----\n{}\n-----END X509 CRL-----",
crl_b64
.chars()
.collect::<Vec<char>>()
.chunks(64)
.map(|chunk| chunk.iter().collect::<String>())
.collect::<Vec<String>>()
.join("\n")
);
tracing::info!(
"Generated CRL for CA {} with {} revoked certificates",
ca_id,
revocation_list.len()
);
Ok(crl_pem)
}
pub async fn validate_certificate_chain(&self, cert_pem: &str) -> Result<bool> {
let cert_der = self.pem_to_der(cert_pem)?;
let (_, cert) = parse_x509_certificate(&cert_der)
.map_err(|e| AuthError::token(format!("Failed to parse certificate: {:?}", e)))?;
let now = SystemTime::now();
let not_before = cert.validity().not_before.to_datetime();
let not_after = cert.validity().not_after.to_datetime();
if now < not_before {
tracing::warn!("Certificate not yet valid");
return Ok(false);
}
if now > not_after {
tracing::warn!("Certificate has expired");
return Ok(false);
}
let issuer_dn = cert.issuer().to_string();
let subject_dn = cert.subject().to_string();
let is_self_signed = issuer_dn == subject_dn;
if is_self_signed {
let cas = self.ca_certificates.read().await;
for ca in cas.values() {
if ca.subject == subject_dn {
tracing::info!("Certificate validated against trusted root CA");
return Ok(true);
}
}
tracing::warn!("Self-signed certificate not in trusted root store");
return Ok(false);
}
let serial_number = cert.serial.to_string();
let revocation_list = self.revocation_list.read().await;
if revocation_list.contains_key(&serial_number) {
tracing::warn!("Certificate has been revoked: {}", serial_number);
return Ok(false);
}
tracing::info!("Certificate chain validation passed for: {}", subject_dn);
Ok(true)
}
fn pem_to_der(&self, pem: &str) -> Result<Vec<u8>> {
let pem_lines: Vec<&str> = pem
.lines()
.filter(|line| !line.starts_with("-----"))
.collect();
let pem_content = pem_lines.join("");
BASE64_STANDARD
.decode(&pem_content)
.map_err(|e| AuthError::internal(format!("Failed to decode PEM certificate: {}", e)))
}
async fn generate_self_signed_root_ca(&self) -> Result<(String, String)> {
use rcgen::{
BasicConstraints, CertificateParams, DnType, IsCa, KeyUsagePurpose, SerialNumber,
};
let mut params = CertificateParams::default();
params
.distinguished_name
.push(DnType::CommonName, "AuthFramework Dev Root CA");
params
.distinguished_name
.push(DnType::OrganizationName, "Auth Framework");
params.distinguished_name.push(DnType::CountryName, "US");
params.not_before = time::OffsetDateTime::now_utc();
params.not_after = time::OffsetDateTime::now_utc() + time::Duration::days(365 * 10);
params.serial_number = Some(SerialNumber::from(1u64.to_be_bytes().to_vec()));
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![
KeyUsagePurpose::KeyCertSign,
KeyUsagePurpose::CrlSign,
KeyUsagePurpose::DigitalSignature,
];
let key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
.map_err(|e| AuthError::internal(format!("CA key pair generation failed: {}", e)))?;
let cert = params
.self_signed(&key_pair)
.map_err(|e| AuthError::internal(format!("CA self-sign failed: {}", e)))?;
let cert_pem = cert.pem();
let key_pem = key_pair.serialize_pem();
tracing::info!("Generated self-signed root CA certificate for development use");
Ok((cert_pem, key_pem))
}
async fn store_ca_certificate_from_pem(&self, cert_pem: &str, source: &str) -> Result<()> {
let fingerprint = self.calculate_certificate_fingerprint(cert_pem)?;
let (subject, issuer, serial_number) = match self.pem_to_der(cert_pem) {
Ok(der) => match parse_x509_certificate(&der) {
Ok((_, cert)) => (
cert.subject().to_string(),
cert.issuer().to_string(),
cert.serial.to_string(),
),
Err(_) => (
format!("CN=Imported CA via {}", source),
format!("CN=Imported CA via {}", source),
"0".to_string(),
),
},
Err(_) => (
format!("CN=Imported CA via {}", source),
format!("CN=Imported CA via {}", source),
"0".to_string(),
),
};
let ca_cert = StoredCertificate {
cert_id: "root_ca".to_string(),
certificate_pem: cert_pem.to_string(),
private_key_pem: None,
subject: subject.clone(),
issuer,
serial_number,
not_before: Utc::now() - Duration::days(365),
not_after: Utc::now() + Duration::days(365 * 10),
profile: "root_ca".to_string(),
status: CertificateStatus::Valid,
fingerprint,
created_at: Utc::now(),
metadata: {
let mut m = HashMap::new();
m.insert("source".to_string(), Value::String(source.to_string()));
m
},
};
let ca = CACertificate {
ca_id: "root_ca".to_string(),
certificate: ca_cert,
subject,
private_key: vec![],
ca_type: CAType::Root,
issued_count: 0,
next_serial: 1000,
};
let mut cas = self.ca_certificates.write().await;
cas.insert("root_ca".to_string(), ca);
Ok(())
}
fn calculate_certificate_fingerprint(&self, cert_pem: &str) -> Result<String> {
use sha2::{Digest, Sha256};
let cert_lines: String = cert_pem
.lines()
.filter(|line| !line.starts_with("-----"))
.collect();
let cert_der = BASE64_STANDARD.decode(&cert_lines).map_err(|e| {
AuthError::internal(format!(
"Failed to decode certificate for fingerprint: {}",
e
))
})?;
let mut hasher = Sha256::new();
hasher.update(&cert_der);
let hash_result = hasher.finalize();
let fingerprint = hash_result
.iter()
.map(|b| format!("{:02X}", b))
.collect::<Vec<_>>()
.join(":");
Ok(fingerprint)
}
}
#[derive(Debug, Clone)]
pub struct CertificateFilter {
pub status: Option<CertificateStatus>,
pub profile: Option<String>,
pub expires_before: Option<DateTime<Utc>>,
pub expires_after: Option<DateTime<Utc>>,
pub subject_contains: Option<String>,
}
impl CertificateFilter {
pub fn matches(&self, cert: &StoredCertificate) -> bool {
if let Some(ref status) = self.status
&& &cert.status != status
{
return false;
}
if let Some(ref profile) = self.profile
&& &cert.profile != profile
{
return false;
}
if let Some(expires_before) = self.expires_before
&& cert.not_after > expires_before
{
return false;
}
if let Some(expires_after) = self.expires_after
&& cert.not_after < expires_after
{
return false;
}
if let Some(ref subject_contains) = self.subject_contains
&& !cert.subject.contains(subject_contains)
{
return false;
}
true
}
}
impl Default for X509Config {
fn default() -> Self {
let mut certificate_profiles = HashMap::new();
certificate_profiles.insert(
"tls_server".to_string(),
CertificateProfile {
name: "TLS Server".to_string(),
cert_type: CertificateType::TlsServer,
key_usage: vec![KeyUsage::DigitalSignature, KeyUsage::KeyEncipherment],
extended_key_usage: vec![ExtendedKeyUsage::ServerAuth],
subject_alt_names: vec![],
validity_days: 365,
preferred_key_type: KeyType::Rsa(2048),
extensions: HashMap::new(),
},
);
certificate_profiles.insert(
"tls_client".to_string(),
CertificateProfile {
name: "TLS Client".to_string(),
cert_type: CertificateType::TlsClient,
key_usage: vec![KeyUsage::DigitalSignature, KeyUsage::KeyAgreement],
extended_key_usage: vec![ExtendedKeyUsage::ClientAuth],
subject_alt_names: vec![],
validity_days: 365,
preferred_key_type: KeyType::Rsa(2048),
extensions: HashMap::new(),
},
);
Self {
default_validity_days: 365,
root_ca_cert_path: "ca/root-ca.crt".to_string(),
root_ca_path: "ca/root-ca.crt".to_string(),
root_ca_key_path: "ca/root-ca.key".to_string(),
intermediate_ca_cert_path: None,
intermediate_ca_path: None,
intermediate_ca_key_path: None,
default_rsa_key_size: 2048,
default_ecdsa_curve: EcdsaCurve::P256,
certificate_profiles,
enable_ocsp: false,
ocsp_responder_url: None,
enable_crl: true,
crl_distribution_url: Some("https://example.com/crl".to_string()),
}
}
}
fn x509_extract_certificate_pem(pem: &str) -> String {
let mut in_cert = false;
let mut lines: Vec<&str> = Vec::new();
let mut collected = false;
for line in pem.lines() {
if line.starts_with("-----BEGIN CERTIFICATE-----") {
if collected {
break; }
in_cert = true;
collected = true;
lines.push(line);
} else if line.starts_with("-----END CERTIFICATE-----") {
lines.push(line);
in_cert = false;
} else if in_cert {
lines.push(line);
}
}
if collected {
lines.join("\n") + "\n"
} else {
pem.to_string()
}
}
struct AwsSigV4Request<'a> {
access_key: &'a str,
secret_key: &'a str,
session_token: Option<&'a str>,
region: &'a str,
service: &'a str,
method: &'a str,
host: &'a str,
path: &'a str,
query: &'a str,
payload: &'a [u8],
amz_date: &'a str,
date_stamp: &'a str,
amz_target: &'a str,
}
impl<'a> AwsSigV4Request<'a> {
fn new(access_key: &'a str, secret_key: &'a str) -> Self {
Self {
access_key,
secret_key,
session_token: None,
region: "us-east-1",
service: "",
method: "POST",
host: "",
path: "/",
query: "",
payload: b"",
amz_date: "",
date_stamp: "",
amz_target: "",
}
}
fn session_token(mut self, token: Option<&'a str>) -> Self {
self.session_token = token;
self
}
fn region(mut self, region: &'a str) -> Self {
self.region = region;
self
}
fn service(mut self, service: &'a str) -> Self {
self.service = service;
self
}
fn method(mut self, method: &'a str) -> Self {
self.method = method;
self
}
fn host(mut self, host: &'a str) -> Self {
self.host = host;
self
}
fn payload(mut self, payload: &'a [u8]) -> Self {
self.payload = payload;
self
}
fn amz_date(mut self, amz_date: &'a str) -> Self {
self.amz_date = amz_date;
self
}
fn date_stamp(mut self, date_stamp: &'a str) -> Self {
self.date_stamp = date_stamp;
self
}
fn amz_target(mut self, amz_target: &'a str) -> Self {
self.amz_target = amz_target;
self
}
fn sign(&self) -> String {
use hmac::{Mac, SimpleHmac};
use sha2::{Digest, Sha256};
fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
let mut mac =
<SimpleHmac<Sha256>>::new_from_slice(key).expect("HMAC accepts any key size");
mac.update(data);
mac.finalize().into_bytes().to_vec()
}
fn sha256hex(data: &[u8]) -> String {
let mut h = Sha256::new();
h.update(data);
hex::encode(h.finalize())
}
let mut headers: Vec<(String, String)> = vec![
("content-type".into(), "application/x-amz-json-1.1".into()),
("host".into(), self.host.into()),
("x-amz-date".into(), self.amz_date.into()),
("x-amz-target".into(), self.amz_target.into()),
];
if let Some(tok) = self.session_token {
headers.push(("x-amz-security-token".into(), tok.into()));
}
headers.sort_by(|a, b| a.0.cmp(&b.0));
let canonical_headers: String = headers
.iter()
.map(|(k, v)| format!("{}:{}\n", k, v.trim()))
.collect();
let signed_headers: String = headers
.iter()
.map(|(k, _)| k.as_str())
.collect::<Vec<_>>()
.join(";");
let canonical_request = format!(
"{method}\n{path}\n{query}\n{canonical_headers}\n{signed_headers}\n{payload_hash}",
method = self.method,
path = self.path,
query = self.query,
canonical_headers = canonical_headers,
signed_headers = signed_headers,
payload_hash = sha256hex(self.payload),
);
let credential_scope =
format!("{}/{}/{}/aws4_request", self.date_stamp, self.region, self.service);
let string_to_sign = format!(
"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{canonical_hash}",
amz_date = self.amz_date,
credential_scope = credential_scope,
canonical_hash = sha256hex(canonical_request.as_bytes()),
);
let signing_key = hmac_sha256(
&hmac_sha256(
&hmac_sha256(
&hmac_sha256(
format!("AWS4{}", self.secret_key).as_bytes(),
self.date_stamp.as_bytes(),
),
self.region.as_bytes(),
),
self.service.as_bytes(),
),
b"aws4_request",
);
let signature = hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()));
format!(
"AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
access_key = self.access_key,
credential_scope = credential_scope,
signed_headers = signed_headers,
signature = signature,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_x509_manager_creation() {
let config = X509Config::default();
let manager = X509CertificateManager::new(config);
assert!(!manager.config.certificate_profiles.is_empty());
assert_eq!(manager.config.default_validity_days, 365);
}
#[tokio::test]
async fn test_certificate_profile() {
let config = X509Config::default();
assert!(config.certificate_profiles.contains_key("tls_server"));
assert!(config.certificate_profiles.contains_key("tls_client"));
let tls_server_profile = &config.certificate_profiles["tls_server"];
assert_eq!(tls_server_profile.cert_type, CertificateType::TlsServer);
assert!(
tls_server_profile
.extended_key_usage
.contains(&ExtendedKeyUsage::ServerAuth)
);
}
#[tokio::test]
async fn test_certificate_filter() {
let filter = CertificateFilter {
status: Some(CertificateStatus::Valid),
profile: None,
expires_before: None,
expires_after: None,
subject_contains: Some("example.com".to_string()),
};
let cert = StoredCertificate {
cert_id: "test".to_string(),
certificate_pem: "".to_string(),
private_key_pem: None,
subject: "CN=example.com".to_string(),
issuer: "CN=Test CA".to_string(),
serial_number: "123".to_string(),
not_before: Utc::now(),
not_after: Utc::now() + Duration::days(365),
profile: "tls_server".to_string(),
status: CertificateStatus::Valid,
fingerprint: "test_fp".to_string(),
created_at: Utc::now(),
metadata: HashMap::new(),
};
assert!(filter.matches(&cert));
}
#[test]
fn test_x509_extract_certificate_pem_single_cert() {
let pem = "-----BEGIN CERTIFICATE-----\nMIIBxx==\n-----END CERTIFICATE-----\n";
let extracted = x509_extract_certificate_pem(pem);
assert!(extracted.contains("-----BEGIN CERTIFICATE-----"));
assert!(extracted.contains("-----END CERTIFICATE-----"));
assert!(extracted.contains("MIIBxx=="));
}
#[test]
fn test_x509_extract_certificate_pem_strips_key() {
let bundle = concat!(
"-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAK==\n-----END RSA PRIVATE KEY-----\n",
"-----BEGIN CERTIFICATE-----\nMIICert==\n-----END CERTIFICATE-----\n",
);
let extracted = x509_extract_certificate_pem(bundle);
assert!(
!extracted.contains("PRIVATE KEY"),
"Private key must be stripped"
);
assert!(extracted.contains("-----BEGIN CERTIFICATE-----"));
assert!(extracted.contains("MIICert=="));
}
#[test]
fn test_x509_extract_certificate_pem_keeps_first_only() {
let bundle = concat!(
"-----BEGIN CERTIFICATE-----\nMIIFirst==\n-----END CERTIFICATE-----\n",
"-----BEGIN CERTIFICATE-----\nMIISecond==\n-----END CERTIFICATE-----\n",
);
let extracted = x509_extract_certificate_pem(bundle);
assert!(
extracted.contains("MIIFirst=="),
"First cert should be kept"
);
assert!(
!extracted.contains("MIISecond=="),
"Second cert must be discarded"
);
}
#[test]
fn test_aws_sigv4_authorization_format() {
let auth = AwsSigV4Request::new(
"AKIAIOSFODNN7EXAMPLE",
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
)
.region("us-east-1")
.service("secretsmanager")
.method("POST")
.host("secretsmanager.us-east-1.amazonaws.com")
.payload(b"{\"SecretId\":\"my-secret\"}")
.amz_date("20230101T000000Z")
.date_stamp("20230101")
.amz_target("secretsmanager.GetSecretValue")
.sign();
assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20230101/"));
assert!(auth.contains("SignedHeaders="));
assert!(auth.contains("Signature="));
let sig_part = auth.split("Signature=").nth(1).unwrap_or("");
assert_eq!(sig_part.len(), 64, "SigV4 signature must be 64 hex chars");
}
#[tokio::test]
async fn test_azure_vault_missing_tenant_id() {
if std::env::var("X509_AZURE_TENANT_ID").is_ok() {
return;
}
let config = X509Config::default();
let manager = X509CertificateManager::new(config);
let result = manager
.load_ca_from_azure_vault("https://test.vault.azure.net", "my-ca")
.await;
assert!(result.is_err(), "Should fail when tenant_id is not set");
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("X509_AZURE_TENANT_ID"),
"Error should name the missing variable: {msg}"
);
}
#[tokio::test]
async fn test_aws_secrets_missing_access_key() {
if std::env::var("AWS_ACCESS_KEY_ID").is_ok() {
return;
}
let config = X509Config::default();
let manager = X509CertificateManager::new(config);
let result = manager.load_ca_from_aws_secrets("my-ca-cert").await;
assert!(
result.is_err(),
"Should fail when AWS_ACCESS_KEY_ID is not set"
);
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("AWS_ACCESS_KEY_ID"),
"Error should name the missing variable: {msg}"
);
}
#[tokio::test]
#[cfg(feature = "hsm")]
async fn test_hsm_invalid_json_config() {
let config = X509Config::default();
let manager = X509CertificateManager::new(config);
let result = manager.load_ca_from_hsm("not-valid-json").await;
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("JSON") || msg.contains("json") || msg.contains("X509_HSM_CONFIG"),
"Error should mention JSON parsing: {msg}"
);
}
#[tokio::test]
#[cfg(feature = "hsm")]
async fn test_hsm_missing_library_field() {
let config = X509Config::default();
let manager = X509CertificateManager::new(config);
let result = manager
.load_ca_from_hsm(r#"{"slot": 0, "pin": "1234", "label": "ca-cert"}"#)
.await;
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("library"),
"Error should mention the missing 'library' field: {msg}"
);
}
#[tokio::test]
#[cfg(feature = "hsm")]
async fn test_hsm_nonexistent_library_path() {
let config = X509Config::default();
let manager = X509CertificateManager::new(config);
let result = manager
.load_ca_from_hsm(
r#"{"library": "/nonexistent/pkcs11/libpkcs11.so", "slot": 0, "pin": "", "label": "ca-cert"}"#,
)
.await;
assert!(
result.is_err(),
"Expected error loading non-existent PKCS#11 library"
);
}
}