use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::{Digest, Sha256};
use tandem_types::{TenantContext, TenantSource};
use tokio::fs::{self, OpenOptions};
use tokio::io::AsyncWriteExt;
use uuid::Uuid;
use crate::{now_ms, AppState};
/// Schema version reported by ledger verification when at least one record carries a
/// non-zero `seq`; ledgers without any such record are reported as version 1.
const AUDIT_SCHEMA_VERSION: u32 = 2;
/// Durability level recorded on an audit envelope.
///
/// Variants are serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuditDurability {
    /// Serialized as `best_effort`.
    BestEffort,
    /// Serialized as `durable_required`; the level assigned by
    /// `append_protected_audit_event`.
    DurableRequired,
}
/// One record of the protected audit ledger, stored as a single JSON line.
///
/// Fields marked `#[serde(default)]` may be absent from a serialized record and then take
/// their type's default value on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedAuditEnvelope {
    /// Unique identifier of the event; `append_protected_audit_event` uses a random UUID v4.
    pub event_id: String,
    /// Durability level the event was recorded with.
    pub durability: AuditDurability,
    /// Caller-supplied event type label.
    pub event_type: String,
    /// Tenant the event belongs to; defaults to `TenantContext::default()` when absent.
    #[serde(default)]
    pub tenant_context: TenantContext,
    /// Optional actor attributed to the event; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub actor: Option<String>,
    /// Arbitrary JSON payload describing the event.
    pub payload: Value,
    /// Creation timestamp as returned by `now_ms()` (presumably milliseconds since the Unix
    /// epoch; confirm against `now_ms`).
    pub created_at_ms: u64,
    /// Position in the ledger. Appends assign the highest existing `seq` plus one, starting
    /// at 1. Defaults to 0 when absent; verification excludes `seq == 0` records from the
    /// sequence check.
    #[serde(default)]
    pub seq: u64,
    /// `record_hash` of the record that was last in the ledger when this one was appended,
    /// or `None` when there was no such record or it carried no hash.
    #[serde(default)]
    pub prev_hash: Option<String>,
    /// Lowercase hex SHA-256 digest produced by `compute_audit_envelope_hash`. Defaults to
    /// an empty string when absent; verification skips records whose hash is empty.
    #[serde(default)]
    pub record_hash: String,
}
/// Borrowed view of the envelope fields that feed the record hash.
///
/// `record_hash` itself is deliberately not part of this view. The struct is serialized to
/// JSON and that text is hashed, so the field names and their declaration order define the
/// digest: renaming or reordering fields changes the hash of every record.
#[derive(Serialize)]
struct AuditEnvelopeForHashing<'a> {
    event_id: &'a str,
    /// Durability as the fixed string produced by `durability_str`.
    durability_str: &'a str,
    event_type: &'a str,
    // The tenant context is flattened into individual fields.
    tenant_org_id: &'a str,
    tenant_workspace_id: &'a str,
    tenant_deployment_id: &'a Option<String>,
    tenant_actor_id: &'a Option<String>,
    tenant_source: &'a TenantSource,
    actor: &'a Option<String>,
    payload: &'a Value,
    created_at_ms: u64,
    seq: u64,
    /// Including the previous record's hash is what links records into a chain.
    prev_hash: &'a Option<String>,
}
fn durability_str(d: &AuditDurability) -> &'static str {
match d {
AuditDurability::BestEffort => "best_effort",
AuditDurability::DurableRequired => "durable_required",
}
}
/// Computes the lowercase hex SHA-256 digest of an audit envelope.
///
/// The digest covers every envelope field except `record_hash`: the fields are projected
/// into an [`AuditEnvelopeForHashing`], serialized to JSON, and the resulting text is hashed.
///
/// # Panics
///
/// Panics if JSON serialization of the projection fails.
pub(crate) fn compute_audit_envelope_hash(envelope: &ProtectedAuditEnvelope) -> String {
    let tenant = &envelope.tenant_context;
    let canonical = serde_json::to_string(&AuditEnvelopeForHashing {
        event_id: envelope.event_id.as_str(),
        durability_str: durability_str(&envelope.durability),
        event_type: envelope.event_type.as_str(),
        tenant_org_id: &tenant.org_id,
        tenant_workspace_id: &tenant.workspace_id,
        tenant_deployment_id: &tenant.deployment_id,
        tenant_actor_id: &tenant.actor_id,
        tenant_source: &tenant.source,
        actor: &envelope.actor,
        payload: &envelope.payload,
        created_at_ms: envelope.created_at_ms,
        seq: envelope.seq,
        prev_hash: &envelope.prev_hash,
    })
    .expect("audit envelope hash serialization is infallible");

    let digest = Sha256::digest(canonical.as_bytes());
    format!("{:x}", digest)
}
/// Returns the process-wide async mutex associated with `path`, creating it on first use.
///
/// Locks are kept in a lazily initialised global registry keyed by the lossy string form of
/// the path, so repeated calls with the same path share one mutex.
async fn protected_audit_lock_for(path: &std::path::Path) -> Arc<tokio::sync::Mutex<()>> {
    type PathLock = Arc<tokio::sync::Mutex<()>>;
    static LOCKS: OnceLock<tokio::sync::Mutex<HashMap<String, PathLock>>> = OnceLock::new();

    let key = path.to_string_lossy().into_owned();
    let mut registry = LOCKS
        .get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
        .lock()
        .await;
    let lock = registry
        .entry(key)
        .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())));
    Arc::clone(lock)
}
/// Returns the ledger record with the highest `seq`, or `None` when the file cannot be read
/// or contains no parseable record.
///
/// Lines that fail to deserialize are skipped. When several records share the highest `seq`,
/// the one appearing last in the file is returned.
async fn read_last_protected_audit_record(
    path: &std::path::Path,
) -> Option<ProtectedAuditEnvelope> {
    let Ok(content) = fs::read_to_string(path).await else {
        return None;
    };

    let mut latest: Option<ProtectedAuditEnvelope> = None;
    for line in content.lines() {
        let Ok(record) = serde_json::from_str::<ProtectedAuditEnvelope>(line.trim()) else {
            continue;
        };
        // `>=` lets a later line replace an earlier one with the same `seq`.
        let replaces = match &latest {
            Some(best) => record.seq >= best.seq,
            None => true,
        };
        if replaces {
            latest = Some(record);
        }
    }
    latest
}
/// Reports whether `event` is visible to `tenant_context`.
///
/// A context for which `is_local_implicit()` is true matches every event. Otherwise the
/// event's org, workspace, and deployment ids must all equal those of the context.
pub fn protected_audit_event_matches_tenant(
    event: &ProtectedAuditEnvelope,
    tenant_context: &TenantContext,
) -> bool {
    if tenant_context.is_local_implicit() {
        return true;
    }

    let recorded = &event.tenant_context;
    recorded.org_id == tenant_context.org_id
        && recorded.workspace_id == tenant_context.workspace_id
        && recorded.deployment_id == tenant_context.deployment_id
}
pub async fn load_protected_audit_events_for_tenant(
state: &AppState,
tenant_context: &TenantContext,
) -> Vec<ProtectedAuditEnvelope> {
let content = match fs::read_to_string(&state.protected_audit_path).await {
Ok(content) => content,
Err(_) => return Vec::new(),
};
let mut rows = content
.lines()
.filter_map(|line| {
let trimmed = line.trim();
if trimmed.is_empty() {
return None;
}
serde_json::from_str::<ProtectedAuditEnvelope>(trimmed).ok()
})
.filter(|event| protected_audit_event_matches_tenant(event, tenant_context))
.collect::<Vec<_>>();
rows.sort_by(|a, b| {
a.created_at_ms
.cmp(&b.created_at_ms)
.then(a.event_id.cmp(&b.event_id))
});
rows
}
/// Appends a new hash-chained event to the protected audit ledger.
///
/// The ledger's parent directory is created if needed. While holding the per-path lock from
/// `protected_audit_lock_for`, the function reads the record with the highest `seq`, assigns
/// the new record the next sequence number (starting at 1), links it to that record's
/// `record_hash` (when non-empty), computes its own hash, and appends it as one JSON line.
///
/// Every record is written with [`AuditDurability::DurableRequired`], so the call only
/// returns `Ok` after the line has been synced to storage.
///
/// # Errors
///
/// Returns an error if the parent directory cannot be created, the record cannot be
/// serialized, or opening, writing, or syncing the ledger file fails.
pub async fn append_protected_audit_event(
    state: &AppState,
    event_type: impl Into<String>,
    tenant_context: &TenantContext,
    actor: Option<String>,
    payload: Value,
) -> anyhow::Result<()> {
    let path = state.protected_audit_path.clone();
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).await?;
    }

    // Serialize the read-modify-append cycle so concurrent callers in this process cannot
    // hand out the same sequence number.
    let audit_lock = protected_audit_lock_for(&path).await;
    let _guard = audit_lock.lock().await;

    let last = read_last_protected_audit_record(&path).await;
    let next_seq = last.as_ref().map(|e| e.seq).unwrap_or(0).saturating_add(1);
    let prev_hash = last
        .as_ref()
        .filter(|e| !e.record_hash.is_empty())
        .map(|e| e.record_hash.clone());

    let mut row = ProtectedAuditEnvelope {
        event_id: Uuid::new_v4().to_string(),
        durability: AuditDurability::DurableRequired,
        event_type: event_type.into(),
        tenant_context: tenant_context.clone(),
        actor,
        payload,
        created_at_ms: now_ms(),
        seq: next_seq,
        prev_hash,
        record_hash: String::new(),
    };
    // The hash covers every field except `record_hash`, so it is filled in last.
    row.record_hash = compute_audit_envelope_hash(&row);

    // Serialize before touching the file so a serialization failure cannot create or
    // modify the ledger, and build the full line so it is handed to the file in one write
    // instead of a record followed by a separate newline.
    let mut line = serde_json::to_string(&row)?;
    line.push('\n');

    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)
        .await?;
    file.write_all(line.as_bytes()).await?;
    file.flush().await?;
    // `flush` only drains tokio's in-flight write; a `DurableRequired` record must also
    // reach stable storage before success is reported.
    file.sync_all().await?;
    Ok(())
}
/// The kind of integrity violation found while verifying the audit ledger.
#[derive(Debug, Clone, PartialEq)]
pub enum AuditChainViolationKind {
    /// A record's stored `record_hash` differs from the hash recomputed from its fields;
    /// `expected` is the recomputed hash.
    RecordHashMismatch { expected: String },
    /// A record's `prev_hash` does not equal the `record_hash` of the preceding hashed
    /// record; `expected_prev` is that preceding hash.
    ChainBreak { expected_prev: String },
    /// A sequence number is missing; `expected_seq` is the number that should have come next.
    SeqGap { expected_seq: u64 },
    /// A record's sequence number is lower than the next expected one, i.e. it was already
    /// used; `seen_seq` is the offending number.
    SeqReplay { seen_seq: u64 },
}
/// A single integrity violation reported by ledger verification.
#[derive(Debug, Clone, PartialEq)]
pub struct AuditChainViolation {
    /// Sequence number the violation is attributed to: the offending record's `seq`, or the
    /// missing sequence number for a `SeqGap`.
    pub seq: u64,
    /// What went wrong.
    pub kind: AuditChainViolationKind,
}
/// Outcome of `verify_protected_audit_ledger`.
#[derive(Debug, Clone, PartialEq)]
pub struct AuditLedgerVerificationResult {
    /// `true` when the ledger was readable and no violation was found.
    pub valid: bool,
    /// Number of lines that deserialized into a record.
    pub record_count: u64,
    /// Number of records with a non-empty `record_hash`; reported as 0 when verification
    /// stops at a sequence violation or the file is unreadable.
    pub hashed_record_count: u64,
    /// `record_hash` of the last hashed record when the ledger is valid; `None` when it is
    /// invalid or contains no hashed record.
    pub root_hash: Option<String>,
    /// 0 when the file could not be read, `AUDIT_SCHEMA_VERSION` when any record has a
    /// non-zero `seq`, otherwise 1.
    pub schema_version: u32,
    /// The first violation encountered; `None` when valid or when the file is unreadable.
    pub violation: Option<AuditChainViolation>,
}
pub async fn verify_protected_audit_ledger(
path: &std::path::Path,
) -> AuditLedgerVerificationResult {
let content = match fs::read_to_string(path).await {
Ok(c) => c,
Err(_) => {
return AuditLedgerVerificationResult {
valid: false,
record_count: 0,
hashed_record_count: 0,
root_hash: None,
schema_version: 0,
violation: None,
}
}
};
let mut records: Vec<ProtectedAuditEnvelope> = content
.lines()
.filter_map(|line| serde_json::from_str(line.trim()).ok())
.collect();
records.sort_by_key(|e| e.seq);
let record_count = records.len() as u64;
let schema_version = records
.iter()
.find(|e| e.seq > 0)
.map(|_| AUDIT_SCHEMA_VERSION)
.unwrap_or(1);
let seq_records: Vec<_> = records.iter().filter(|e| e.seq > 0).collect();
if !seq_records.is_empty() {
let mut expected = 1u64;
for record in &seq_records {
if record.seq < expected {
return AuditLedgerVerificationResult {
valid: false,
record_count,
hashed_record_count: 0,
root_hash: None,
schema_version,
violation: Some(AuditChainViolation {
seq: record.seq,
kind: AuditChainViolationKind::SeqReplay {
seen_seq: record.seq,
},
}),
};
}
if record.seq > expected {
return AuditLedgerVerificationResult {
valid: false,
record_count,
hashed_record_count: 0,
root_hash: None,
schema_version,
violation: Some(AuditChainViolation {
seq: expected,
kind: AuditChainViolationKind::SeqGap {
expected_seq: expected,
},
}),
};
}
expected = expected.saturating_add(1);
}
}
let hashed: Vec<_> = records
.iter()
.filter(|e| !e.record_hash.is_empty())
.collect();
let hashed_record_count = hashed.len() as u64;
let mut prev_hash: Option<String> = None;
for record in &hashed {
let expected_hash = compute_audit_envelope_hash(record);
if expected_hash != record.record_hash {
return AuditLedgerVerificationResult {
valid: false,
record_count,
hashed_record_count,
root_hash: None,
schema_version,
violation: Some(AuditChainViolation {
seq: record.seq,
kind: AuditChainViolationKind::RecordHashMismatch {
expected: expected_hash,
},
}),
};
}
if let Some(ref expected) = prev_hash {
if record.prev_hash.as_deref() != Some(expected.as_str()) {
return AuditLedgerVerificationResult {
valid: false,
record_count,
hashed_record_count,
root_hash: None,
schema_version,
violation: Some(AuditChainViolation {
seq: record.seq,
kind: AuditChainViolationKind::ChainBreak {
expected_prev: expected.clone(),
},
}),
};
}
}
prev_hash = Some(record.record_hash.clone());
}
AuditLedgerVerificationResult {
valid: true,
record_count,
hashed_record_count,
root_hash: prev_hash,
schema_version,
violation: None,
}
}
/// Summary of a ledger file as produced by `generate_audit_ledger_manifest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLedgerManifest {
    /// Path of the ledger file, converted lossily to a string.
    pub ledger_path: String,
    /// Schema version reported by verification.
    pub schema_version: u32,
    /// Number of parseable records reported by verification.
    pub record_count: u64,
    /// Highest `seq` found in the ledger, or 0 when no record could be read.
    pub last_seq: u64,
    /// Root hash reported by verification; `None` when verification failed or the ledger has
    /// no hashed record.
    pub root_hash: Option<String>,
    /// Timestamp from `now_ms()` taken when the manifest was built.
    pub generated_at_ms: u64,
}
pub async fn generate_audit_ledger_manifest(
path: &std::path::Path,
) -> anyhow::Result<AuditLedgerManifest> {
let result = verify_protected_audit_ledger(path).await;
let last_seq = read_last_protected_audit_record(path)
.await
.map(|e| e.seq)
.unwrap_or(0);
Ok(AuditLedgerManifest {
ledger_path: path.to_string_lossy().into_owned(),
schema_version: result.schema_version,
record_count: result.record_count,
last_seq,
root_hash: result.root_hash,
generated_at_ms: now_ms(),
})
}