use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::merkle_spec_hash::{hex_encode, hex_to_array};
pub use crate::merkle_spec_hash::compute_spec_hash;
/// Serialized (schema v2) form of a single cache entry as stored in the
/// `entries` map of [`OnDisk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct EntryV2 {
    /// Modification-time value recorded for the path (presumably nanoseconds,
    /// going by the name; the value is whatever the caller recorded).
    mtime_ns: u64,
    /// Size value recorded for the path.
    size: u64,
    /// Content hash as a hex string; decoded with `hex_to_array` on load and
    /// produced with `hex_encode` on save.
    hash: String,
}
/// Top-level document that is serialized to / parsed from the cache file.
#[derive(Debug, Serialize, Deserialize)]
struct OnDisk {
    /// Schema version of the file; files whose version differs from
    /// `SCHEMA_VERSION` are ignored on load.
    version: u32,
    /// Hex-encoded spec hash. Left out of the output when `None`, and
    /// defaults to `None` when the field is missing from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    spec_hash: Option<String>,
    /// Cache entries keyed by the string form of their path.
    entries: HashMap<String, EntryV2>,
}
/// Schema version written into every saved cache file; on load, a file whose
/// `version` field differs is discarded and treated as a cold start.
const SCHEMA_VERSION: u32 = 2;
/// Number of independently locked shards in a `MerkleIndex`; `shard_index`
/// maps every path to a value in `0..MERKLE_SHARDS`.
const MERKLE_SHARDS: usize = 64;
/// Picks the shard that owns `path`.
///
/// The path is fed through a freshly constructed `DefaultHasher` and the
/// resulting digest is reduced modulo `MERKLE_SHARDS`, so the result is
/// always in `0..MERKLE_SHARDS` and is the same for equal paths.
fn shard_index(path: &Path) -> usize {
    let mut hasher = DefaultHasher::new();
    Hash::hash(path, &mut hasher);
    let digest = hasher.finish() as usize;
    digest % MERKLE_SHARDS
}
/// In-memory cache record kept per path inside a shard.
///
/// The type is `Copy`, so readers can take a snapshot of an entry without
/// holding on to the shard lock.
#[derive(Debug, Clone, Copy)]
struct CacheEntry {
    /// Modification-time value supplied by the caller (0 when recorded
    /// without metadata).
    mtime_ns: u64,
    /// Size value supplied by the caller (0 when recorded without metadata).
    size: u64,
    /// 32-byte content hash.
    hash: [u8; 32],
}
/// Path-keyed cache of content hashes and file metadata, split across
/// `MERKLE_SHARDS` maps that are each guarded by their own `RwLock`.
#[derive(Debug)]
pub struct MerkleIndex {
    /// One lock-protected map per shard; the shard for a path is chosen by
    /// `shard_index`.
    shards: Vec<RwLock<HashMap<PathBuf, CacheEntry>>>,
}
impl MerkleIndex {
    /// Creates an index that holds `MERKLE_SHARDS` empty shards.
    pub fn empty() -> Self {
        let mut shards = Vec::with_capacity(MERKLE_SHARDS);
        shards.resize_with(MERKLE_SHARDS, || RwLock::new(HashMap::new()));
        Self { shards }
    }

    /// Loads the index stored at `path` without checking any spec hash.
    ///
    /// Stale temp files next to `path` are swept first. Every failure mode
    /// (missing file, read error, parse error, schema mismatch) yields an
    /// empty index rather than an error.
    pub fn load(path: &Path) -> Self {
        sweep_stale_tmp_files(path);
        Self::load_with_spec_inner(path, None)
    }

    /// Loads the index stored at `path`, but only if the spec hash saved in
    /// the file equals `expected_spec_hash`.
    ///
    /// Stale temp files next to `path` are swept first. A missing,
    /// undecodable or different stored spec hash yields an empty index, as
    /// does every failure mode of [`MerkleIndex::load`].
    pub fn load_with_spec(path: &Path, expected_spec_hash: &[u8; 32]) -> Self {
        sweep_stale_tmp_files(path);
        Self::load_with_spec_inner(path, Some(expected_spec_hash))
    }

    /// Shared loader behind `load` and `load_with_spec`.
    ///
    /// Returns an empty index when the file is absent (silently), unreadable
    /// or unparsable (with a warning), written with another schema version
    /// (with a warning), or — when `expected_spec_hash` is `Some` — saved
    /// under a different spec hash. Entries whose hash string cannot be
    /// decoded by `hex_to_array` are skipped.
    fn load_with_spec_inner(path: &Path, expected_spec_hash: Option<&[u8; 32]>) -> Self {
        let raw = match std::fs::read(path) {
            Ok(raw) => raw,
            Err(err) => {
                // A missing file is the normal first-run case and is not logged.
                if err.kind() != std::io::ErrorKind::NotFound {
                    tracing::warn!(
                        cache = %path.display(),
                        error = %err,
                        "merkle index file read failed; treating as cold start"
                    );
                }
                return Self::empty();
            }
        };

        let parsed = match serde_json::from_slice::<OnDisk>(&raw) {
            Ok(parsed) => parsed,
            Err(error) => {
                tracing::warn!(
                    cache = %path.display(),
                    %error,
                    "merkle index parse failed; treating as cold start"
                );
                return Self::empty();
            }
        };
        let OnDisk {
            version: found_version,
            spec_hash: stored_spec,
            entries: stored_entries,
        } = parsed;

        if found_version != SCHEMA_VERSION {
            tracing::warn!(
                cache = %path.display(),
                version = found_version,
                expected = SCHEMA_VERSION,
                "merkle index schema mismatch; treating as cold start"
            );
            return Self::empty();
        }

        // Without an expected hash there is nothing to compare against.
        let spec_ok = match expected_spec_hash {
            None => true,
            Some(expected) => stored_spec
                .as_deref()
                .and_then(hex_to_array)
                .is_some_and(|stored| &stored == expected),
        };
        if !spec_ok {
            tracing::info!(
                cache = %path.display(),
                "detector spec changed since last scan; cache invalidated"
            );
            return Self::empty();
        }

        // Fill plain per-shard maps first; they are wrapped in locks at the end.
        let mut buckets: Vec<HashMap<PathBuf, CacheEntry>> =
            (0..MERKLE_SHARDS).map(|_| HashMap::new()).collect();
        let mut loaded = 0usize;
        for (raw_path, stored) in stored_entries {
            let Some(hash) = hex_to_array(&stored.hash) else {
                continue;
            };
            let key = PathBuf::from(raw_path);
            let slot = shard_index(&key);
            let record = CacheEntry {
                mtime_ns: stored.mtime_ns,
                size: stored.size,
                hash,
            };
            // Count distinct paths only: two stored keys may compare equal as paths.
            if buckets[slot].insert(key, record).is_none() {
                loaded += 1;
            }
        }
        tracing::info!(
            cache = %path.display(),
            count = loaded,
            "merkle index loaded"
        );

        Self {
            shards: buckets.into_iter().map(RwLock::new).collect(),
        }
    }

    /// Saves the index to `path` without a spec hash.
    ///
    /// # Errors
    /// Returns any I/O or encoding error raised while writing the file.
    pub fn save(&self, path: &Path) -> std::io::Result<()> {
        self.save_inner(path, None)
    }

    /// Saves the index to `path`, tagging the file with `spec_hash`.
    ///
    /// # Errors
    /// Returns any I/O or encoding error raised while writing the file.
    pub fn save_with_spec(&self, path: &Path, spec_hash: &[u8; 32]) -> std::io::Result<()> {
        self.save_inner(path, Some(spec_hash))
    }

    /// Shared writer behind `save` and `save_with_spec`.
    ///
    /// The file currently at `path` is loaded (with the same spec check that
    /// `load` / `load_with_spec` apply) and merged with this index; for paths
    /// present in both, the in-memory entry wins. The merged set is encoded
    /// as pretty-printed JSON, written to a temp file in the destination
    /// directory, synced, and then persisted over `path`.
    ///
    /// NOTE(review): because on-disk entries are merged back in, a path
    /// removed with `forget` reappears in the saved file if it is still
    /// present on disk — confirm this is intended.
    fn save_inner(&self, path: &Path, spec_hash: Option<&[u8; 32]>) -> std::io::Result<()> {
        use std::io::Write as _;

        let disk_snapshot = if let Some(hash) = spec_hash {
            Self::load_with_spec(path, hash)
        } else {
            Self::load(path)
        };

        // Disk shards first, then ours, so in-memory entries overwrite disk ones.
        let mut merged: HashMap<PathBuf, CacheEntry> = HashMap::new();
        for shard in disk_snapshot.shards.iter().chain(self.shards.iter()) {
            let guard = shard.read();
            for (entry_path, entry) in guard.iter() {
                merged.insert(entry_path.clone(), *entry);
            }
        }

        let mut entries: HashMap<String, EntryV2> = HashMap::with_capacity(merged.len());
        for (entry_path, entry) in &merged {
            entries.insert(
                entry_path.display().to_string(),
                EntryV2 {
                    mtime_ns: entry.mtime_ns,
                    size: entry.size,
                    hash: hex_encode(&entry.hash),
                },
            );
        }

        let document = OnDisk {
            version: SCHEMA_VERSION,
            spec_hash: spec_hash.map(hex_encode),
            entries,
        };
        let serialized = match serde_json::to_vec_pretty(&document) {
            Ok(buf) => buf,
            Err(e) => return Err(std::io::Error::other(format!("merkle index encode: {e}"))),
        };

        let parent = match path.parent() {
            Some(dir) => dir,
            None => std::path::Path::new("."),
        };
        std::fs::create_dir_all(parent)?;

        // The temp file lives in the destination directory so that the final
        // persist step stays within one directory.
        let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
        tmp.write_all(&serialized)?;
        tmp.as_file().sync_all()?;
        tmp.persist(path)
            .map(|_file| ())
            .map_err(|persist_err| persist_err.error)
    }

    /// Returns the BLAKE3 hash of `content` as 32 raw bytes.
    pub fn hash_content(content: &[u8]) -> [u8; 32] {
        let digest = blake3::hash(content);
        *digest.as_bytes()
    }

    /// Returns `true` when `path` is cached and its stored hash equals
    /// `content_hash`; `false` when the hash differs or the path is unknown.
    pub fn unchanged(&self, path: &Path, content_hash: &[u8; 32]) -> bool {
        let guard = self.shards[shard_index(path)].read();
        match guard.get(path) {
            Some(prev) => prev.hash == *content_hash,
            None => false,
        }
    }

    /// Returns `true` when `path` is cached and both its stored `mtime_ns`
    /// and `size` equal the given values.
    pub fn metadata_unchanged(&self, path: &Path, mtime_ns: u64, size: u64) -> bool {
        let guard = self.shards[shard_index(path)].read();
        match guard.get(path) {
            Some(prev) => prev.size == size && prev.mtime_ns == mtime_ns,
            None => false,
        }
    }

    /// Returns the stored `(mtime_ns, size, hash)` for `path`, or `None`
    /// when the path is not cached.
    pub fn lookup(&self, path: &Path) -> Option<(u64, u64, [u8; 32])> {
        let guard = self.shards[shard_index(path)].read();
        let found = guard.get(path)?;
        Some((found.mtime_ns, found.size, found.hash))
    }

    /// Records `content_hash` for `path` with `mtime_ns` and `size` both set
    /// to 0, replacing any existing entry.
    pub fn record(&self, path: PathBuf, content_hash: [u8; 32]) {
        self.record_with_metadata(path, 0, 0, content_hash);
    }

    /// Inserts or replaces the entry for `path` with the given metadata and
    /// content hash.
    pub fn record_with_metadata(
        &self,
        path: PathBuf,
        mtime_ns: u64,
        size: u64,
        content_hash: [u8; 32],
    ) {
        let fresh = CacheEntry {
            mtime_ns,
            size,
            hash: content_hash,
        };
        let slot = shard_index(&path);
        let mut guard = self.shards[slot].write();
        guard.insert(path, fresh);
    }

    /// Removes the entry for `path`, if there is one.
    pub fn forget(&self, path: &Path) {
        let slot = shard_index(path);
        let mut guard = self.shards[slot].write();
        guard.remove(path);
    }

    /// Total number of entries, summed over all shards. Each shard is
    /// read-locked in turn, so the result is not a single-instant snapshot
    /// if other threads are writing.
    pub fn len(&self) -> usize {
        let mut total = 0usize;
        for shard in &self.shards {
            total += shard.read().len();
        }
        total
    }

    /// Returns `true` when every shard is empty.
    pub fn is_empty(&self) -> bool {
        !self.shards.iter().any(|shard| !shard.read().is_empty())
    }
}
impl Default for MerkleIndex {
fn default() -> Self {
Self::empty()
}
}
/// Returns `<cache dir>/keyhog/merkle.idx`, where `<cache dir>` is whatever
/// `dirs::cache_dir()` reports, or `None` when no cache directory is
/// available.
pub fn default_cache_path() -> Option<PathBuf> {
    let base = dirs::cache_dir()?;
    Some(base.join("keyhog").join("merkle.idx"))
}
/// Minimum age, in seconds (one hour), that a temp file must have reached
/// before `sweep_stale_tmp_files` removes it.
const STALE_TMP_CUTOFF_SECS: u64 = 60 * 60;
/// Best-effort cleanup of temp files left next to `cache_path`.
///
/// Scans the directory containing `cache_path` and deletes every entry whose
/// name starts with `<stem>.tmp` (where `<stem>` is the cache file's stem,
/// falling back to `merkle`) or with `.tmp`, provided its modification time
/// is at least `STALE_TMP_CUTOFF_SECS` old. All errors are ignored: an
/// unreadable directory, a non-UTF-8 file name, missing metadata, a
/// modification time in the future, or a failed delete simply skip that
/// entry. A debug event is logged when at least one file was removed.
///
/// NOTE(review): for a bare file name `Path::parent` yields an empty path,
/// so `read_dir` fails and nothing is swept in that case.
fn sweep_stale_tmp_files(cache_path: &Path) {
    let Some(parent) = cache_path.parent() else {
        return;
    };
    let Ok(entries) = std::fs::read_dir(parent) else {
        return;
    };
    let stem = cache_path
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("merkle");
    // Built once: the prefix is identical for every directory entry.
    let stem_tmp_prefix = format!("{stem}.tmp");
    let now = std::time::SystemTime::now();
    let mut swept = 0usize;

    for entry in entries.flatten() {
        let name = entry.file_name();
        let Some(name_str) = name.to_str() else {
            continue;
        };
        let is_tmp_sibling =
            name_str.starts_with(stem_tmp_prefix.as_str()) || name_str.starts_with(".tmp");
        if !is_tmp_sibling {
            continue;
        }

        // `DirEntry::metadata` does not follow symlinks, so the age that is
        // checked belongs to the directory entry that would be deleted.
        let Ok(meta) = entry.metadata() else {
            continue;
        };
        let Ok(modified) = meta.modified() else {
            continue;
        };
        // A modification time in the future makes `duration_since` fail;
        // such entries are left alone.
        let Ok(age) = now.duration_since(modified) else {
            continue;
        };
        if age.as_secs() < STALE_TMP_CUTOFF_SECS {
            continue;
        }

        if std::fs::remove_file(entry.path()).is_ok() {
            swept += 1;
        }
    }

    if swept > 0 {
        tracing::debug!(
            count = swept,
            dir = %parent.display(),
            "swept stale cache tmp files left by an interrupted save"
        );
    }
}