use std::collections::HashMap;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use sha2::{Digest, Sha256};
/// Whether audit events should be recorded.
///
/// Serialized in snake_case (`"off"` / `"default"`). [`AuditMode::Default`] is
/// also this type's `Default` value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuditMode {
    /// Auditing switched off; [`AuditMode::enabled`] reports `false`.
    Off,
    /// Auditing switched on; [`AuditMode::enabled`] reports `true`.
    #[default]
    Default,
}

impl AuditMode {
    /// Reports whether this mode turns auditing on.
    pub fn enabled(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Off => false,
            Self::Default => true,
        }
    }
}
/// One record of a per-key audit chain, written as JSON inside a
/// length-prefixed frame of that chain's log file.
///
/// NOTE: serde serializes fields in declaration order, so reordering them
/// changes the on-disk JSON layout.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AuditEntry {
/// Position of this entry in its chain; the first entry recorded is 0.
pub id: u64,
/// Recording time in milliseconds since the Unix epoch (0 if the system
/// clock reads earlier than the epoch).
pub ts_ms: u64,
/// Identifier of the chain this entry belongs to (the `kid` given to
/// `AuditLog::record`).
pub kid: String,
/// Caller-supplied namespace.
pub ns: String,
/// Caller-supplied event name (the tests use e.g. `"intent.add"`).
pub event_type: String,
/// Arbitrary event details; hashed in its `serde_json::to_vec` form.
pub payload: serde_json::Value,
/// `entry_hash` of the preceding entry, or empty for a chain's first entry.
pub prev_hash: String,
/// Hex SHA-256 produced by `compute_entry_hash` for this entry.
pub entry_hash: String,
}
/// Mutable state of a single key's audit chain.
struct KeyChain {
    /// Append handle for the chain's log file; `None` for in-memory chains.
    file: Option<File>,
    /// Id to assign to the next recorded entry.
    next_id: u64,
    /// `entry_hash` of the most recent entry; empty while the chain is empty.
    head_hash: String,
    /// Number of entries attributed to this chain.
    count: u64,
}

impl KeyChain {
    /// Creates an empty chain that is never written to disk.
    fn in_memory() -> Self {
        Self {
            file: None,
            next_id: 0,
            head_hash: String::new(),
            count: 0,
        }
    }

    /// Opens (creating if missing) the log file at `path` and replays it to
    /// recover `next_id`, `head_hash` and `count`.
    ///
    /// The file is a sequence of frames: a little-endian `u32` byte length
    /// followed by that many bytes of JSON-serialized [`AuditEntry`]. Frames
    /// whose JSON does not parse are skipped. Replay stops at the first
    /// incomplete frame; such trailing bytes are reported on stderr but left in
    /// place (not truncated) so they stay available for inspection.
    ///
    /// # Errors
    /// Returns any I/O error raised while opening or reading the file.
    fn open(path: &Path) -> std::io::Result<Self> {
        // Append handle kept for `AuditLog::record`; `create(true)` lets a new
        // kid start from an empty file.
        let file = OpenOptions::new()
            .read(true)
            .create(true)
            .append(true)
            .open(path)?;
        // Separate, buffered read handle for the one-time replay so the append
        // handle is untouched and each frame does not cost two syscalls.
        let mut reader = std::io::BufReader::new(File::open(path)?);
        let mut next_id = 0u64;
        let mut head_hash = String::new();
        let mut count = 0u64;
        // Bytes covered by complete frames so far.
        let mut offset = 0u64;
        let mut payload = Vec::new();
        loop {
            let mut len_buf = [0u8; 4];
            match reader.read_exact(&mut len_buf) {
                Ok(()) => {}
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e),
            }
            let len = u64::from(u32::from_le_bytes(len_buf));
            // Read through `take` instead of pre-allocating `len` bytes: a torn
            // or corrupt prefix can claim up to 4 GiB, and the buffer should only
            // grow as far as the data actually present.
            payload.clear();
            reader.by_ref().take(len).read_to_end(&mut payload)?;
            if (payload.len() as u64) < len {
                break; // incomplete trailing frame; reported below
            }
            offset += 4 + len;
            if let Ok(entry) = serde_json::from_slice::<AuditEntry>(&payload) {
                next_id = entry.id.saturating_add(1);
                head_hash = entry.entry_hash;
                count += 1;
            }
        }
        // Anything past the last complete frame (a partial length prefix or a
        // short payload, e.g. from an interrupted write) would otherwise be
        // ignored silently while new frames are appended after it.
        if let Ok(meta) = file.metadata() {
            if meta.len() > offset {
                eprintln!(
                    "[audit_log] {}: {} trailing byte(s) after offset {} do not form a complete record; new entries will be appended after them",
                    path.display(),
                    meta.len() - offset,
                    offset
                );
            }
        }
        Ok(Self {
            file: Some(file),
            next_id,
            head_hash,
            count,
        })
    }
}
pub struct AuditLog {
dir: Option<PathBuf>,
chains: RwLock<HashMap<String, Arc<Mutex<KeyChain>>>>,
mode: AuditMode,
}
impl AuditLog {
pub fn new(data_dir: Option<&str>, mode: AuditMode) -> Self {
let dir = data_dir.map(|d| {
let p = PathBuf::from(d).join("_audit");
let _ = fs::create_dir_all(&p);
p
});
let s = Self {
dir,
chains: RwLock::new(HashMap::new()),
mode,
};
s.scan_existing();
s
}
pub fn mode(&self) -> AuditMode {
self.mode
}
fn scan_existing(&self) {
let Some(dir) = self.dir.clone() else {
return;
};
let Ok(entries) = fs::read_dir(&dir) else {
return;
};
let mut map = self.chains.write().unwrap();
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map(|e| e == "log").unwrap_or(false) {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
match KeyChain::open(&path) {
Ok(chain) => {
eprintln!(
"[audit_log] loaded kid={} entries={} head={}",
stem,
chain.count,
short(&chain.head_hash)
);
map.insert(stem.to_string(), Arc::new(Mutex::new(chain)));
}
Err(e) => eprintln!("[audit_log] error loading {}: {}", stem, e),
}
}
}
}
}
fn chain_for(&self, kid: &str) -> Arc<Mutex<KeyChain>> {
{
let map = self.chains.read().unwrap();
if let Some(c) = map.get(kid) {
return c.clone();
}
}
let mut map = self.chains.write().unwrap();
if let Some(c) = map.get(kid) {
return c.clone();
}
let chain = match self.dir.as_ref().map(|d| d.join(format!("{}.log", kid))) {
Some(path) => KeyChain::open(&path).unwrap_or_else(|e| {
eprintln!("[audit_log] cannot open {}: {}", path.display(), e);
KeyChain::in_memory()
}),
None => KeyChain::in_memory(),
};
let arc = Arc::new(Mutex::new(chain));
map.insert(kid.to_string(), arc.clone());
arc
}
pub fn record(&self, kid: &str, ns: &str, event_type: &str, payload: serde_json::Value) -> u64 {
let chain_arc = self.chain_for(kid);
let mut chain = chain_arc.lock().unwrap();
let id = chain.next_id;
let ts_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let prev_hash = chain.head_hash.clone();
let canonical_payload = serde_json::to_vec(&payload).unwrap_or_default();
let entry_hash = compute_entry_hash(
&prev_hash,
id,
ts_ms,
kid,
ns,
event_type,
&canonical_payload,
);
let entry = AuditEntry {
id,
ts_ms,
kid: kid.to_string(),
ns: ns.to_string(),
event_type: event_type.to_string(),
payload,
prev_hash,
entry_hash: entry_hash.clone(),
};
let serialized = serde_json::to_vec(&entry).unwrap_or_default();
if let Some(ref mut file) = chain.file {
let len = (serialized.len() as u32).to_le_bytes();
let _ = file.write_all(&len);
let _ = file.write_all(&serialized);
}
chain.next_id = id + 1;
chain.head_hash = entry_hash;
chain.count += 1;
id
}
pub fn heads(&self) -> Vec<ChainHead> {
let map = self.chains.read().unwrap();
let mut out: Vec<ChainHead> = map
.iter()
.map(|(kid, c)| {
let chain = c.lock().unwrap();
ChainHead {
kid: kid.clone(),
head_hash: chain.head_hash.clone(),
count: chain.count,
}
})
.collect();
out.sort_by(|a, b| a.kid.cmp(&b.kid));
out
}
}
/// Snapshot of one chain's current state, as returned by `AuditLog::heads`.
#[derive(Clone, Debug, serde::Serialize)]
pub struct ChainHead {
/// Identifier of the chain.
pub kid: String,
/// `entry_hash` of the chain's latest entry, or empty if it has none.
pub head_hash: String,
/// Entries attributed to the chain: those that parsed when its log file was
/// loaded plus those recorded since.
pub count: u64,
}
/// Computes the hex-encoded SHA-256 that links an audit entry to its
/// predecessor.
///
/// Bytes hashed, in order: `prev_hash` hex-decoded (32 zero bytes when it is
/// empty or not valid hex), `id` and `ts_ms` as little-endian `u64`s, then
/// `kid`, `ns` and `event_type` each followed by a NUL byte, and finally
/// `canonical_payload`. Persisted chains were hashed with exactly this layout,
/// so it must not change.
pub fn compute_entry_hash(
    prev_hash: &str,
    id: u64,
    ts_ms: u64,
    kid: &str,
    ns: &str,
    event_type: &str,
    canonical_payload: &[u8],
) -> String {
    const GENESIS: [u8; 32] = [0u8; 32];
    let prev_bytes: Vec<u8> = match prev_hash {
        "" => GENESIS.to_vec(),
        hex_str => hex::decode(hex_str).unwrap_or_else(|_| GENESIS.to_vec()),
    };

    let mut hasher = Sha256::new();
    hasher.update(&prev_bytes);
    for word in [id, ts_ms] {
        hasher.update(word.to_le_bytes());
    }
    for field in [kid, ns, event_type] {
        hasher.update(field.as_bytes());
        hasher.update([0u8]);
    }
    hasher.update(canonical_payload);
    hex::encode(hasher.finalize())
}
/// Returns the hex-encoded SHA-256 digest of `query`'s UTF-8 bytes.
pub fn hash_query(query: &str) -> String {
    let digest = Sha256::digest(query.as_bytes());
    hex::encode(digest)
}
/// Abbreviates `h` for log output: strings longer than 12 characters are cut
/// to their first 12 characters followed by `…`; shorter ones are returned
/// unchanged.
///
/// Truncation happens on a `char` boundary, so non-ASCII input cannot panic
/// (slicing `&h[..12]` by byte would when byte 12 falls inside a character).
fn short(h: &str) -> String {
    match h.char_indices().nth(12) {
        Some((cut, _)) => format!("{}…", &h[..cut]),
        None => h.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty directory under the system temp dir for one test.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "audit_log_{}_{}_{}",
            tag,
            std::process::id(),
            nanos
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("create scratch dir");
        dir
    }

    #[test]
    fn chain_records_and_links() {
        let log = AuditLog::new(None, AuditMode::Default);
        let id0 = log.record(
            "test-key",
            "ns1",
            "intent.add",
            serde_json::json!({"intent_id": "foo"}),
        );
        let id1 = log.record(
            "test-key",
            "ns1",
            "intent.add",
            serde_json::json!({"intent_id": "bar"}),
        );
        assert_eq!(id0, 0);
        assert_eq!(id1, 1);
        let heads = log.heads();
        assert_eq!(heads.len(), 1);
        assert_eq!(heads[0].count, 2);
        assert!(!heads[0].head_hash.is_empty());
    }

    #[test]
    fn separate_chains_per_key() {
        let log = AuditLog::new(None, AuditMode::Default);
        log.record("key-a", "ns", "x", serde_json::json!({}));
        log.record("key-b", "ns", "x", serde_json::json!({}));
        log.record("key-a", "ns", "x", serde_json::json!({}));
        let heads = log.heads();
        assert_eq!(heads.len(), 2);
        let a = heads.iter().find(|h| h.kid == "key-a").unwrap();
        let b = heads.iter().find(|h| h.kid == "key-b").unwrap();
        assert_eq!(a.count, 2);
        assert_eq!(b.count, 1);
        assert_ne!(a.head_hash, b.head_hash);
    }

    #[test]
    fn modes_simple_two_state() {
        assert!(!AuditMode::Off.enabled());
        assert!(AuditMode::Default.enabled());
    }

    #[test]
    fn entry_hashes_change_with_payload() {
        let h1 = compute_entry_hash("", 0, 1000, "k", "n", "x", b"{\"a\":1}");
        let h2 = compute_entry_hash("", 0, 1000, "k", "n", "x", b"{\"a\":2}");
        assert_ne!(h1, h2);
    }

    /// The link itself: the same entry hashes differently when it follows a
    /// different predecessor.
    #[test]
    fn entry_hash_depends_on_prev_hash() {
        let first = compute_entry_hash("", 0, 1000, "k", "n", "x", b"{}");
        let linked = compute_entry_hash(&first, 1, 1000, "k", "n", "x", b"{}");
        let unlinked = compute_entry_hash("", 1, 1000, "k", "n", "x", b"{}");
        assert_eq!(first.len(), 64);
        assert_ne!(linked, unlinked);
    }

    /// A persisted chain reloads with the same head and continues its ids.
    #[test]
    fn persisted_chain_reloads() {
        let dir = scratch_dir("reload");
        let dir_str = dir.to_str().expect("temp dir path is UTF-8");
        let before = {
            let log = AuditLog::new(Some(dir_str), AuditMode::Default);
            log.record("k", "ns", "x", serde_json::json!({"n": 1}));
            log.record("k", "ns", "x", serde_json::json!({"n": 2}));
            log.heads()
        };
        let log = AuditLog::new(Some(dir_str), AuditMode::Default);
        let after = log.heads();
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].kid, "k");
        assert_eq!(after[0].count, 2);
        assert_eq!(after[0].head_hash, before[0].head_hash);
        assert_eq!(log.record("k", "ns", "x", serde_json::json!({})), 2);
        let _ = std::fs::remove_dir_all(&dir);
    }
}