use std::sync::{Arc, RwLock};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use crate::security::certificate::{CertificateBundle, MeshCertificate};
use crate::security::error::SecurityError;
use super::automerge_store::AutomergeStore;
use super::typed_collection::TypedCollection;
const CERTIFICATES_COLLECTION: &str = "certificates";
const REVOCATIONS_COLLECTION: &str = "revocations";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateEntry {
pub cert_data: String,
pub node_id: String,
pub tier: String,
pub subject: String,
pub issuer: String,
pub permissions: u8,
pub expires_at_ms: u64,
pub written_at_ms: u64,
}
impl CertificateEntry {
pub fn from_certificate(cert: &MeshCertificate) -> Self {
Self {
cert_data: base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
cert.encode(),
),
node_id: cert.node_id.clone(),
tier: cert.tier.as_str().to_string(),
subject: hex::encode(cert.subject_public_key),
issuer: hex::encode(cert.issuer_public_key),
permissions: cert.permissions,
expires_at_ms: cert.expires_at_ms,
written_at_ms: now_ms(),
}
}
pub fn to_certificate(&self) -> Result<MeshCertificate, SecurityError> {
let bytes =
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &self.cert_data)
.map_err(|e| SecurityError::SerializationError(format!("invalid base64: {e}")))?;
MeshCertificate::decode(&bytes)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RevocationEntry {
pub subject: String,
pub reason: String,
pub timestamp_ms: u64,
}
pub struct CertificateStore {
certificates: TypedCollection<CertificateEntry>,
revocations: TypedCollection<RevocationEntry>,
bundle: Arc<RwLock<CertificateBundle>>,
store: Arc<AutomergeStore>,
}
impl CertificateStore {
pub fn new(store: Arc<AutomergeStore>, trusted_authorities: &[[u8; 32]]) -> Self {
let mut bundle = CertificateBundle::new();
for auth in trusted_authorities {
bundle.add_authority(*auth);
}
let certificates = TypedCollection::new(store.clone(), CERTIFICATES_COLLECTION);
let revocations = TypedCollection::new(store.clone(), REVOCATIONS_COLLECTION);
Self {
certificates,
revocations,
bundle: Arc::new(RwLock::new(bundle)),
store,
}
}
pub fn bundle(&self) -> Arc<RwLock<CertificateBundle>> {
self.bundle.clone()
}
pub fn publish_certificate(&self, cert: &MeshCertificate) -> Result<()> {
let subject_hex = hex::encode(cert.subject_public_key);
let entry = CertificateEntry::from_certificate(cert);
self.certificates.upsert(&subject_hex, &entry)?;
let mut bundle = self.bundle.write().unwrap_or_else(|e| e.into_inner());
if let Err(e) = bundle.add_certificate(cert.clone()) {
debug!(subject = %subject_hex, error = %e, "adding cert unchecked (issuer may not be loaded yet)");
bundle.add_certificate_unchecked(cert.clone());
}
info!(
subject = %subject_hex,
node_id = %cert.node_id,
tier = %cert.tier,
"published certificate to CRDT"
);
Ok(())
}
pub fn publish_revocation(&self, subject_public_key: &[u8; 32], reason: &str) -> Result<()> {
let subject_hex = hex::encode(subject_public_key);
let entry = RevocationEntry {
subject: subject_hex.clone(),
reason: reason.to_string(),
timestamp_ms: now_ms(),
};
self.revocations.upsert(&subject_hex, &entry)?;
info!(subject = %subject_hex, reason, "published revocation to CRDT");
Ok(())
}
pub fn load_all(&self) -> Result<usize> {
let entries = self.certificates.scan()?;
let revocations = self.load_revocation_set()?;
let mut count = 0;
let mut bundle = self.bundle.write().unwrap_or_else(|e| e.into_inner());
for (_id, entry) in &entries {
if revocations.contains(&entry.subject) {
debug!(subject = %entry.subject, "skipping revoked certificate");
continue;
}
match entry.to_certificate() {
Ok(cert) => {
if let Err(e) = bundle.add_certificate(cert) {
warn!(
subject = %entry.subject,
error = %e,
"skipping certificate during load (validation failed)"
);
continue;
}
count += 1;
}
Err(e) => {
warn!(subject = %entry.subject, error = %e, "failed to decode certificate entry");
}
}
}
info!(
loaded = count,
total_entries = entries.len(),
"loaded certificates from CRDT"
);
Ok(count)
}
pub async fn watch_and_reload(&self) {
let mut cert_stream = self.certificates.subscribe_all();
let mut revoke_stream = self.revocations.subscribe_all();
loop {
tokio::select! {
Some(subject_hex) = cert_stream.next() => {
self.handle_cert_change(&subject_hex);
}
Some(subject_hex) = revoke_stream.next() => {
self.handle_revocation_change(&subject_hex);
}
else => break,
}
}
}
fn handle_cert_change(&self, subject_hex: &str) {
if let Ok(Some(_)) = self.revocations.get(subject_hex) {
debug!(subject = %subject_hex, "ignoring cert change for revoked subject");
return;
}
match self.certificates.get(subject_hex) {
Ok(Some(entry)) => match entry.to_certificate() {
Ok(cert) => {
let mut bundle = self.bundle.write().unwrap_or_else(|e| e.into_inner());
if let Err(e) = bundle.add_certificate(cert) {
warn!(
subject = %subject_hex,
error = %e,
"rejecting certificate on change (validation failed)"
);
return;
}
debug!(subject = %subject_hex, "hot-reloaded certificate from CRDT");
}
Err(e) => {
warn!(subject = %subject_hex, error = %e, "failed to decode cert on change");
}
},
Ok(None) => {
debug!(subject = %subject_hex, "cert change notification but entry missing");
}
Err(e) => {
warn!(subject = %subject_hex, error = %e, "failed to read cert on change");
}
}
}
fn handle_revocation_change(&self, subject_hex: &str) {
match self.revocations.get(subject_hex) {
Ok(Some(entry)) => {
if let Ok(subject_bytes) = hex::decode(&entry.subject) {
if subject_bytes.len() == 32 {
let mut key = [0u8; 32];
key.copy_from_slice(&subject_bytes);
if let Err(e) = self.rebuild_bundle() {
warn!(error = %e, "failed to rebuild bundle after revocation");
} else {
info!(subject = %subject_hex, reason = %entry.reason, "applied revocation");
}
}
}
}
Ok(None) => {}
Err(e) => {
warn!(subject = %subject_hex, error = %e, "failed to read revocation");
}
}
}
fn rebuild_bundle(&self) -> Result<()> {
let entries = self.certificates.scan()?;
let revocations = self.load_revocation_set()?;
let mut bundle = self.bundle.write().unwrap_or_else(|e| e.into_inner());
let auth_count = bundle.authority_count();
let mut count = 0;
for (_id, entry) in &entries {
if revocations.contains(&entry.subject) {
continue;
}
match entry.to_certificate() {
Ok(cert) => {
if let Err(e) = bundle.add_certificate(cert) {
warn!(subject = %entry.subject, error = %e, "skipping invalid cert in rebuild");
continue;
}
count += 1;
}
Err(e) => {
warn!(subject = %entry.subject, error = %e, "skipping bad cert in rebuild");
}
}
}
debug!(
certs = count,
authorities = auth_count,
"rebuilt certificate bundle"
);
Ok(())
}
fn load_revocation_set(&self) -> Result<std::collections::HashSet<String>> {
let entries = self.revocations.scan()?;
Ok(entries.into_iter().map(|(_, e)| e.subject).collect())
}
pub fn store(&self) -> &Arc<AutomergeStore> {
&self.store
}
}
fn now_ms() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::certificate::{permissions, MeshTier};
use crate::security::keypair::DeviceKeypair;
fn make_store() -> Arc<AutomergeStore> {
Arc::new(AutomergeStore::in_memory())
}
fn make_cert(
authority: &DeviceKeypair,
member: &DeviceKeypair,
node_id: &str,
tier: MeshTier,
) -> MeshCertificate {
let now = now_ms();
MeshCertificate::new(
member.public_key_bytes(),
"DEADBEEF".to_string(),
node_id.to_string(),
tier,
permissions::STANDARD,
now,
now + 3600000,
authority.public_key_bytes(),
)
.signed(authority)
}
#[test]
fn test_publish_and_load_certificate() {
let store = make_store();
let authority = DeviceKeypair::generate();
let member = DeviceKeypair::generate();
let cert_store = CertificateStore::new(store, &[authority.public_key_bytes()]);
let cert = make_cert(&authority, &member, "tac-1", MeshTier::Tactical);
cert_store.publish_certificate(&cert).unwrap();
let bundle = cert_store.bundle();
let b = bundle.read().unwrap();
assert!(b.validate_node_id("tac-1", now_ms()));
assert_eq!(b.len(), 1);
}
#[test]
fn test_load_all_from_crdt() {
let store = make_store();
let authority = DeviceKeypair::generate();
let cert_store = CertificateStore::new(store.clone(), &[authority.public_key_bytes()]);
for i in 0..3 {
let member = DeviceKeypair::generate();
let cert = make_cert(
&authority,
&member,
&format!("node-{i}"),
MeshTier::Tactical,
);
cert_store.publish_certificate(&cert).unwrap();
}
let cert_store2 = CertificateStore::new(store, &[authority.public_key_bytes()]);
let loaded = cert_store2.load_all().unwrap();
assert_eq!(loaded, 3);
let bundle = cert_store2.bundle();
let b = bundle.read().unwrap();
assert_eq!(b.len(), 3);
assert!(b.validate_node_id("node-0", now_ms()));
assert!(b.validate_node_id("node-1", now_ms()));
assert!(b.validate_node_id("node-2", now_ms()));
}
#[test]
fn test_revocation_skips_on_load() {
let store = make_store();
let authority = DeviceKeypair::generate();
let member = DeviceKeypair::generate();
let cert_store = CertificateStore::new(store.clone(), &[authority.public_key_bytes()]);
let cert = make_cert(&authority, &member, "revoked-node", MeshTier::Tactical);
cert_store.publish_certificate(&cert).unwrap();
cert_store
.publish_revocation(&member.public_key_bytes(), "compromised")
.unwrap();
let cert_store2 = CertificateStore::new(store, &[authority.public_key_bytes()]);
let loaded = cert_store2.load_all().unwrap();
assert_eq!(loaded, 0);
}
#[test]
fn test_certificate_entry_roundtrip() {
let authority = DeviceKeypair::generate();
let member = DeviceKeypair::generate();
let cert = make_cert(&authority, &member, "hub-1", MeshTier::Regional);
let entry = CertificateEntry::from_certificate(&cert);
let decoded = entry.to_certificate().unwrap();
assert_eq!(decoded.subject_public_key, cert.subject_public_key);
assert_eq!(decoded.mesh_id, cert.mesh_id);
assert_eq!(decoded.node_id, "hub-1");
assert_eq!(decoded.tier, MeshTier::Regional);
assert!(decoded.verify().is_ok());
}
#[test]
fn test_publish_root_certificate() {
let store = make_store();
let authority = DeviceKeypair::generate();
let cert_store = CertificateStore::new(store, &[authority.public_key_bytes()]);
let root = MeshCertificate::new_root(
&authority,
"DEADBEEF".to_string(),
"enterprise-0".to_string(),
MeshTier::Enterprise,
now_ms(),
0,
);
cert_store.publish_certificate(&root).unwrap();
let bundle = cert_store.bundle();
let b = bundle.read().unwrap();
assert!(b.validate_node_id("enterprise-0", now_ms()));
}
#[test]
fn test_rebuild_after_revocation() {
let store = make_store();
let authority = DeviceKeypair::generate();
let member1 = DeviceKeypair::generate();
let member2 = DeviceKeypair::generate();
let cert_store = CertificateStore::new(store, &[authority.public_key_bytes()]);
let cert1 = make_cert(&authority, &member1, "node-a", MeshTier::Tactical);
let cert2 = make_cert(&authority, &member2, "node-b", MeshTier::Tactical);
cert_store.publish_certificate(&cert1).unwrap();
cert_store.publish_certificate(&cert2).unwrap();
cert_store
.publish_revocation(&member1.public_key_bytes(), "lost device")
.unwrap();
cert_store.rebuild_bundle().unwrap();
let bundle = cert_store.bundle();
let b = bundle.read().unwrap();
assert!(b.validate_node_id("node-b", now_ms()));
}
#[tokio::test]
async fn test_enrollment_to_crdt_flow() {
use crate::security::enrollment::{
EnrollmentRequest, EnrollmentService, StaticEnrollmentService,
};
let store = make_store();
let authority = DeviceKeypair::generate();
let member = DeviceKeypair::generate();
let mut svc = StaticEnrollmentService::new(
authority.clone(),
"test-mesh".to_string(),
3600000, );
svc.add_token(
b"bootstrap-123".to_vec(),
MeshTier::Tactical,
permissions::STANDARD,
);
let request = EnrollmentRequest::new(
&member,
"test-mesh".to_string(),
"tac-node-1".to_string(),
MeshTier::Tactical,
b"bootstrap-123".to_vec(),
now_ms(),
);
let response = svc.process_request(&request).await.unwrap();
assert_eq!(
response.status,
crate::security::enrollment::EnrollmentStatus::Approved
);
let cert = response
.certificate
.expect("approved response should have certificate");
assert!(cert.verify().is_ok());
assert_eq!(cert.node_id, "tac-node-1");
assert_eq!(cert.subject_public_key, member.public_key_bytes());
let cert_store = CertificateStore::new(store.clone(), &[authority.public_key_bytes()]);
cert_store.publish_certificate(&cert).unwrap();
{
let bundle = cert_store.bundle();
let b = bundle.read().unwrap();
assert!(b.validate_node_id("tac-node-1", now_ms()));
assert_eq!(b.get_node_tier("tac-node-1"), Some(MeshTier::Tactical));
}
let cert_store2 = CertificateStore::new(store.clone(), &[authority.public_key_bytes()]);
let loaded = cert_store2.load_all().unwrap();
assert_eq!(loaded, 1);
{
let bundle2 = cert_store2.bundle();
let b = bundle2.read().unwrap();
assert!(b.validate_node_id("tac-node-1", now_ms()));
}
cert_store
.publish_revocation(&member.public_key_bytes(), "compromised device")
.unwrap();
cert_store.rebuild_bundle().unwrap();
let cert_store3 = CertificateStore::new(store, &[authority.public_key_bytes()]);
let loaded = cert_store3.load_all().unwrap();
assert_eq!(loaded, 0, "revoked certificate should not be loaded");
}
#[tokio::test]
async fn test_enrollment_bad_token_denied() {
use crate::security::enrollment::{
EnrollmentRequest, EnrollmentService, EnrollmentStatus, StaticEnrollmentService,
};
let authority = DeviceKeypair::generate();
let member = DeviceKeypair::generate();
let svc = StaticEnrollmentService::new(authority, "mesh-1".to_string(), 3600000);
let request = EnrollmentRequest::new(
&member,
"mesh-1".to_string(),
"rogue-node".to_string(),
MeshTier::Tactical,
b"invalid-token".to_vec(),
now_ms(),
);
let response = svc.process_request(&request).await.unwrap();
assert!(matches!(response.status, EnrollmentStatus::Denied { .. }));
assert!(response.certificate.is_none());
}
#[test]
fn test_bundle_validate_peer_bytes() {
let store = make_store();
let authority = DeviceKeypair::generate();
let member = DeviceKeypair::generate();
let cert_store = CertificateStore::new(store, &[authority.public_key_bytes()]);
let cert = make_cert(&authority, &member, "node-x", MeshTier::Tactical);
cert_store.publish_certificate(&cert).unwrap();
let bundle = cert_store.bundle();
let b = bundle.read().unwrap();
assert!(b.validate_peer(&member.public_key_bytes(), now_ms()));
let unknown = DeviceKeypair::generate();
assert!(!b.validate_peer(&unknown.public_key_bytes(), now_ms()));
}
}