use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
/// Serialized form of the index as stored in the cache file.
///
/// `MerkleIndex::save` encodes this struct with `serde_json` and
/// `MerkleIndex::load` decodes it again.
#[derive(Debug, Serialize, Deserialize)]
struct OnDisk {
/// Format version; `load` compares it against `SCHEMA_VERSION`.
version: u32,
/// File path (rendered with `Path::display`) mapped to the hex-encoded
/// 32-byte content hash.
entries: HashMap<String, String>,
}
/// On-disk format version: written by `MerkleIndex::save` and checked by
/// `MerkleIndex::load`, which treats any other value as a cold start.
const SCHEMA_VERSION: u32 = 1;
/// Number of independently locked shards the in-memory index is split into;
/// also the modulus used by `shard_index`.
const MERKLE_SHARDS: usize = 64;
/// Maps `path` to the shard that owns it.
///
/// The path is fed through a fresh `DefaultHasher` and the digest is reduced
/// modulo `MERKLE_SHARDS`, so the result is always a valid index into the
/// shard vector. The shard number is never persisted (only paths and hashes
/// are written to disk), so it only has to be consistent in memory.
fn shard_index(path: &Path) -> usize {
    let mut hasher = DefaultHasher::new();
    path.hash(&mut hasher);
    let digest = hasher.finish() as usize;
    digest % MERKLE_SHARDS
}
/// In-memory map from file path to the last recorded 32-byte content hash.
///
/// Entries are spread across `MERKLE_SHARDS` maps, each behind its own
/// `parking_lot::Mutex`, so `record` and `unchanged` lock only the shard that
/// owns the path they are given.
#[derive(Debug)]
pub struct MerkleIndex {
/// One locked map per shard; a path's shard is chosen by `shard_index`.
shards: Vec<Mutex<HashMap<PathBuf, [u8; 32]>>>,
}
impl MerkleIndex {
    /// Creates an index with `MERKLE_SHARDS` empty shards.
    pub fn empty() -> Self {
        Self {
            shards: (0..MERKLE_SHARDS)
                .map(|_| Mutex::new(HashMap::new()))
                .collect(),
        }
    }

    /// Loads an index previously written by [`MerkleIndex::save`].
    ///
    /// This never fails. An empty index ("cold start") is returned when the
    /// file cannot be read, does not parse as the on-disk JSON format, or
    /// carries a schema version other than `SCHEMA_VERSION`. Individual
    /// entries whose hash string is rejected by `hex_to_array` are skipped.
    pub fn load(path: &Path) -> Self {
        let bytes = match std::fs::read(path) {
            Ok(bytes) => bytes,
            Err(e) => {
                // A missing cache is the normal first-run case and stays
                // silent; any other failure (permissions, I/O error, ...) is
                // surfaced so a permanently unreadable cache gets noticed.
                if e.kind() != std::io::ErrorKind::NotFound {
                    tracing::warn!(
                        cache = %path.display(),
                        error = %e,
                        "merkle index read failed; treating as cold start"
                    );
                }
                return Self::empty();
            }
        };
        let on_disk: OnDisk = match serde_json::from_slice(&bytes) {
            Ok(decoded) => decoded,
            Err(e) => {
                tracing::warn!(
                    cache = %path.display(),
                    error = %e,
                    "merkle index parse failed; treating as cold start"
                );
                return Self::empty();
            }
        };
        if on_disk.version != SCHEMA_VERSION {
            tracing::warn!(
                cache = %path.display(),
                version = on_disk.version,
                expected = SCHEMA_VERSION,
                "merkle index schema mismatch; treating as cold start"
            );
            return Self::empty();
        }

        // Insert straight into the shards instead of collecting into a
        // throwaway map first.
        let idx = Self::empty();
        for (p, h) in on_disk.entries {
            if let Some(hash) = hex_to_array(&h) {
                idx.record(PathBuf::from(p), hash);
            }
        }
        tracing::info!(
            cache = %path.display(),
            count = idx.len(),
            "merkle index loaded"
        );
        idx
    }

    /// Writes the index to `path` as pretty-printed JSON.
    ///
    /// Missing parent directories are created. The data is first written to a
    /// sibling temporary file (extension `tmp.<pid>`) and then renamed over
    /// `path`, so a failed write does not clobber an existing cache file.
    ///
    /// Paths are stored via `Path::display`, which is lossy for paths that
    /// are not valid UTF-8.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON encoding fails (wrapped with
    /// `std::io::Error::other`), or if creating the parent directory, writing
    /// the temporary file, or renaming it fails. On a failed write or rename
    /// the temporary file is removed on a best-effort basis.
    pub fn save(&self, path: &Path) -> std::io::Result<()> {
        let mut entries = HashMap::<String, String>::new();
        for shard in &self.shards {
            // Only one shard is locked at a time.
            let guard = shard.lock();
            entries.extend(
                guard
                    .iter()
                    .map(|(p, h)| (p.display().to_string(), hex_encode(h))),
            );
        }
        let on_disk = OnDisk {
            version: SCHEMA_VERSION,
            entries,
        };
        let serialized = serde_json::to_vec_pretty(&on_disk)
            .map_err(|e| std::io::Error::other(format!("merkle index encode: {e}")))?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let tmp = path.with_extension(format!("tmp.{}", std::process::id()));
        let outcome =
            std::fs::write(&tmp, &serialized).and_then(|()| std::fs::rename(&tmp, path));
        if outcome.is_err() {
            // Do not leave a stray temp file next to the cache; the original
            // error is what the caller needs, so a cleanup failure is ignored.
            let _ = std::fs::remove_file(&tmp);
        }
        outcome
    }

    /// Returns the 32-byte BLAKE3 digest of `content`.
    pub fn hash_content(content: &[u8]) -> [u8; 32] {
        *blake3::hash(content).as_bytes()
    }

    /// Returns `true` only if `path` has a recorded hash equal to
    /// `content_hash`; a path that was never recorded yields `false`.
    pub fn unchanged(&self, path: &Path, content_hash: &[u8; 32]) -> bool {
        let i = shard_index(path);
        self.shards[i]
            .lock()
            .get(path)
            .is_some_and(|prev| prev == content_hash)
    }

    /// Stores `content_hash` for `path`, replacing any previous value.
    pub fn record(&self, path: PathBuf, content_hash: [u8; 32]) {
        let i = shard_index(&path);
        self.shards[i].lock().insert(path, content_hash);
    }

    /// Returns the total number of recorded paths, summed shard by shard.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.lock().len()).sum()
    }

    /// Returns `true` if no shard holds any entry.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.lock().is_empty())
    }
}
impl Default for MerkleIndex {
fn default() -> Self {
Self::empty()
}
}
/// Returns `<cache dir>/keyhog/merkle.idx`, where the cache directory is
/// whatever `dirs::cache_dir` reports, or `None` when it reports none.
pub fn default_cache_path() -> Option<PathBuf> {
    let base = dirs::cache_dir()?;
    Some(base.join("keyhog").join("merkle.idx"))
}
/// Encodes a 32-byte hash as 64 lowercase hexadecimal characters.
///
/// Each byte becomes two digits, high nibble first. Digits are pushed
/// directly from a lookup table, so no temporary `String` is allocated per
/// byte.
fn hex_encode(bytes: &[u8; 32]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
/// Decodes exactly 64 hexadecimal digits into a 32-byte array.
///
/// Upper- and lowercase digits are both accepted. Returns `None` when the
/// input is not 64 bytes long or contains anything other than ASCII hex
/// digits.
///
/// Decoding works on raw bytes rather than `str` slices: the input comes from
/// a cache file, and slicing a string that contains multi-byte UTF-8 at fixed
/// byte offsets can panic. Sign characters such as `+` are rejected as well.
fn hex_to_array(hex: &str) -> Option<[u8; 32]> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    let digits = hex.as_bytes();
    if digits.len() != 64 {
        return None;
    }
    let mut out = [0u8; 32];
    for (slot, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (nibble(pair[0])? << 4) | nibble(pair[1])?;
    }
    Some(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A recorded hash is reported as unchanged; a different hash for the
    /// same path is not.
    #[test]
    fn record_and_unchanged_roundtrip() {
        let index = MerkleIndex::empty();
        let file = Path::new("/tmp/example.env");
        let original = MerkleIndex::hash_content(b"DB_PASS=secret123");
        let edited = MerkleIndex::hash_content(b"DB_PASS=changed");

        index.record(file.to_path_buf(), original);

        assert!(index.unchanged(file, &original));
        assert!(!index.unchanged(file, &edited));
    }

    /// A path that was never recorded always counts as changed.
    #[test]
    fn unknown_path_is_changed() {
        let index = MerkleIndex::empty();
        let digest = MerkleIndex::hash_content(b"x");
        let never_recorded = Path::new("/never/seen");

        assert!(!index.unchanged(never_recorded, &digest));
    }

    /// An entry written by `save` is present after `load`.
    #[test]
    fn save_and_load_preserves_entries() {
        let scratch = tempfile::tempdir().unwrap();
        let cache_file = scratch.path().join("merkle.idx");
        let file = Path::new("/tmp/secrets.env");
        let digest = MerkleIndex::hash_content(b"hello world");

        let index = MerkleIndex::empty();
        index.record(file.to_path_buf(), digest);
        index.save(&cache_file).expect("save");

        let reloaded = MerkleIndex::load(&cache_file);
        assert_eq!(reloaded.len(), 1);
        assert!(reloaded.unchanged(file, &digest));
    }

    /// A cache file that is not JSON loads as an empty index.
    #[test]
    fn corrupted_cache_treated_as_cold_start() {
        let scratch = tempfile::tempdir().unwrap();
        let cache_file = scratch.path().join("merkle.idx");
        std::fs::write(&cache_file, b"this is not json").unwrap();

        assert!(MerkleIndex::load(&cache_file).is_empty());
    }

    /// A cache path that does not exist loads as an empty index.
    #[test]
    fn missing_cache_returns_empty() {
        let absent = Path::new("/definitely/does/not/exist.idx");

        assert!(MerkleIndex::load(absent).is_empty());
    }

    /// A well-formed file with a different schema version is ignored.
    #[test]
    fn schema_version_mismatch_treated_as_cold_start() {
        let scratch = tempfile::tempdir().unwrap();
        let cache_file = scratch.path().join("merkle.idx");
        let other_schema = serde_json::json!({
            "version": 99,
            "entries": { "/foo": "00".repeat(32) }
        });
        let payload = serde_json::to_vec(&other_schema).unwrap();
        std::fs::write(&cache_file, payload).unwrap();

        assert!(MerkleIndex::load(&cache_file).is_empty());
    }
}