use std::io::Read;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use sha2::{Digest, Sha256};
use crate::codec::NeuCodecEncoder;
use crate::npy;
/// On-disk cache of encoded reference codes, keyed by the SHA-256 of the source WAV file.
///
/// Each entry is stored as `<dir>/<sha256-hex>.npy`.
#[derive(Debug, Clone)]
pub struct RefCodeCache {
    /// Directory that holds the `.npy` cache entries.
    dir: PathBuf,
}
impl RefCodeCache {
pub fn new() -> Result<Self> {
let base = dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from(".neutts_cache"));
Self::with_dir(base.join("neutts").join("ref_codes"))
}
pub fn with_dir(dir: impl Into<PathBuf>) -> Result<Self> {
let dir = dir.into();
std::fs::create_dir_all(&dir)
.with_context(|| format!("Cannot create cache directory: {}", dir.display()))?;
Ok(Self { dir })
}
pub fn dir(&self) -> &Path {
&self.dir
}
pub fn cache_path_for(&self, wav_path: &Path) -> Result<PathBuf> {
let hash = sha256_file(wav_path)?;
Ok(self.dir.join(format!("{hash}.npy")))
}
pub fn is_cached(&self, wav_path: &Path) -> Result<bool> {
let path = self.cache_path_for(wav_path)?;
Ok(path.exists())
}
pub fn try_load(&self, wav_path: &Path) -> Result<Option<(Vec<i32>, CacheOutcome)>> {
let hash = sha256_file(wav_path)
.with_context(|| format!("Failed to hash: {}", wav_path.display()))?;
let cache_file = self.dir.join(format!("{hash}.npy"));
if cache_file.exists() {
let codes = npy::load_npy_i32(&cache_file)
.with_context(|| format!("Failed to load cached codes: {}", cache_file.display()))?;
Ok(Some((codes, CacheOutcome::Hit { path: cache_file, hash })))
} else {
Ok(None)
}
}
pub fn store(&self, wav_path: &Path, codes: &[i32]) -> Result<CacheOutcome> {
let hash = sha256_file(wav_path)
.with_context(|| format!("Failed to hash: {}", wav_path.display()))?;
let cache_file = self.dir.join(format!("{hash}.npy"));
npy::write_npy_i32(&cache_file, codes)
.with_context(|| format!("Failed to write cache: {}", cache_file.display()))?;
Ok(CacheOutcome::Miss { path: cache_file, hash })
}
pub fn get_or_encode(
&self,
wav_path: &Path,
encoder: &NeuCodecEncoder,
) -> Result<(Vec<i32>, CacheOutcome)> {
if let Some(hit) = self.try_load(wav_path)? {
return Ok(hit);
}
let codes = encoder.encode_wav(wav_path)
.with_context(|| format!("Failed to encode: {}", wav_path.display()))?;
let outcome = self.store(wav_path, &codes)?;
Ok((codes, outcome))
}
pub fn evict(&self, wav_path: &Path) -> Result<bool> {
let path = self.cache_path_for(wav_path)?;
if path.exists() {
std::fs::remove_file(&path)
.with_context(|| format!("Failed to evict cache entry: {}", path.display()))?;
Ok(true)
} else {
Ok(false)
}
}
pub fn clear(&self) -> Result<usize> {
let mut count = 0;
for entry in std::fs::read_dir(&self.dir)
.with_context(|| format!("Cannot read cache dir: {}", self.dir.display()))?
{
let entry = entry.context("Failed to read dir entry")?;
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) == Some("npy") {
std::fs::remove_file(&path)
.with_context(|| format!("Failed to remove: {}", path.display()))?;
count += 1;
}
}
Ok(count)
}
}
/// Result of a cache lookup or store.
///
/// Each variant carries the entry's path and the source WAV's hash.
#[derive(Debug, Clone)]
pub enum CacheOutcome {
    /// Codes were loaded from an existing cache entry.
    Hit {
        /// Location of the cache entry.
        path: PathBuf,
        /// Hex SHA-256 of the source WAV file.
        hash: String,
    },
    /// Codes were not cached and have just been written to a new entry.
    Miss {
        /// Location of the newly written cache entry.
        path: PathBuf,
        /// Hex SHA-256 of the source WAV file.
        hash: String,
    },
}

impl CacheOutcome {
    /// `true` for `Hit`, `false` for `Miss`.
    pub fn is_hit(&self) -> bool {
        matches!(self, Self::Hit { .. })
    }

    /// Path of the cache entry this outcome refers to.
    pub fn path(&self) -> &Path {
        match self {
            Self::Hit { path, .. } => path,
            Self::Miss { path, .. } => path,
        }
    }

    /// Hash of the source WAV file this outcome refers to.
    pub fn hash(&self) -> &str {
        match self {
            Self::Hit { hash, .. } => hash,
            Self::Miss { hash, .. } => hash,
        }
    }

    /// First 16 characters of the hash (or all of it, if shorter), for compact display.
    ///
    /// This slices on a char boundary. Byte-slicing with `[..16]` would panic on short or
    /// non-ASCII strings, and the variant fields are public, so any `String` can get here.
    fn short_hash(&self) -> &str {
        let hash = self.hash();
        hash.char_indices().nth(16).map_or(hash, |(end, _)| &hash[..end])
    }
}

impl std::fmt::Display for CacheOutcome {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Hit { path, .. } => write!(
                f,
                "cache hit (sha256: {}…) ← {}",
                self.short_hash(),
                path.display()
            ),
            Self::Miss { path, .. } => write!(
                f,
                "cache miss (sha256: {}…) → {}",
                self.short_hash(),
                path.display()
            ),
        }
    }
}
pub fn sha256_file(path: &Path) -> Result<String> {
let mut file = std::fs::File::open(path)
.with_context(|| format!("Cannot open file for hashing: {}", path.display()))?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 65_536];
loop {
let n = file.read(&mut buf)
.with_context(|| format!("IO error while hashing: {}", path.display()))?;
if n == 0 { break; }
hasher.update(&buf[..n]);
}
Ok(format!("{:x}", hasher.finalize()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Creates a fresh, empty scratch directory unique to this call.
    ///
    /// Test functions run in parallel, and the system clock may be coarser than 1 ns on
    /// some platforms, so a sub-second-nanos suffix alone can collide. Two tests could then
    /// share a directory, letting e.g. `clear()` in one test delete another test's files.
    /// The process id plus a process-wide counter makes every name unique within a run.
    fn tmp_dir() -> PathBuf {
        static NEXT: AtomicUsize = AtomicUsize::new(0);
        let d = std::env::temp_dir().join(format!(
            "neutts_cache_test_{}_{}",
            std::process::id(),
            NEXT.fetch_add(1, Ordering::Relaxed)
        ));
        // A directory left by an earlier run whose pid was recycled must not leak stale
        // entries into this test.
        let _ = std::fs::remove_dir_all(&d);
        std::fs::create_dir_all(&d).unwrap();
        d
    }

    #[test]
    fn test_sha256_file_deterministic() {
        let dir = tmp_dir();
        let path = dir.join("test.bin");
        std::fs::write(&path, b"hello neutts").unwrap();
        let h1 = sha256_file(&path).unwrap();
        let h2 = sha256_file(&path).unwrap();
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 64);
    }

    #[test]
    fn test_sha256_changes_with_content() {
        let dir = tmp_dir();
        let p1 = dir.join("a.bin");
        let p2 = dir.join("b.bin");
        std::fs::write(&p1, b"file a").unwrap();
        std::fs::write(&p2, b"file b").unwrap();
        assert_ne!(sha256_file(&p1).unwrap(), sha256_file(&p2).unwrap());
    }

    #[test]
    fn test_cache_path_is_hash_based() {
        let dir = tmp_dir();
        let cache = RefCodeCache::with_dir(&dir).unwrap();
        let wav = dir.join("ref.wav");
        std::fs::write(&wav, b"fake wav content").unwrap();
        let path = cache.cache_path_for(&wav).unwrap();
        let hash = sha256_file(&wav).unwrap();
        assert_eq!(path, dir.join(format!("{hash}.npy")));
    }

    #[test]
    fn test_is_cached_returns_false_before_write() {
        let dir = tmp_dir();
        let cache = RefCodeCache::with_dir(&dir).unwrap();
        let wav = dir.join("ref.wav");
        std::fs::write(&wav, b"fake wav").unwrap();
        assert!(!cache.is_cached(&wav).unwrap());
    }

    #[test]
    fn test_try_load_miss_then_store_then_hit() {
        let dir = tmp_dir();
        let cache = RefCodeCache::with_dir(&dir).unwrap();
        let wav = dir.join("ref.wav");
        std::fs::write(&wav, b"fake wav content 123").unwrap();
        assert!(cache.try_load(&wav).unwrap().is_none());
        let codes: Vec<i32> = vec![1, 2, 3, 42, 1023];
        let outcome = cache.store(&wav, &codes).unwrap();
        assert!(!outcome.is_hit());
        let (loaded, outcome2) = cache.try_load(&wav).unwrap().unwrap();
        assert!(outcome2.is_hit());
        assert_eq!(loaded, codes);
        assert_eq!(outcome.path(), outcome2.path());
    }

    #[test]
    fn test_evict_removes_entry() {
        let dir = tmp_dir();
        let cache = RefCodeCache::with_dir(&dir).unwrap();
        let wav = dir.join("ref.wav");
        std::fs::write(&wav, b"fake wav").unwrap();
        let hash = sha256_file(&wav).unwrap();
        let npy = dir.join(format!("{hash}.npy"));
        std::fs::write(&npy, b"placeholder").unwrap();
        assert!(cache.is_cached(&wav).unwrap());
        let removed = cache.evict(&wav).unwrap();
        assert!(removed);
        assert!(!cache.is_cached(&wav).unwrap());
    }

    #[test]
    fn test_evict_nonexistent_returns_false() {
        let dir = tmp_dir();
        let cache = RefCodeCache::with_dir(&dir).unwrap();
        let wav = dir.join("ref.wav");
        std::fs::write(&wav, b"fake wav").unwrap();
        assert!(!cache.evict(&wav).unwrap());
    }

    #[test]
    fn test_clear_removes_all_npy() {
        let dir = tmp_dir();
        let cache = RefCodeCache::with_dir(&dir).unwrap();
        std::fs::write(dir.join("aaa.npy"), b"x").unwrap();
        std::fs::write(dir.join("bbb.npy"), b"y").unwrap();
        std::fs::write(dir.join("keep.txt"), b"z").unwrap();
        let removed = cache.clear().unwrap();
        assert_eq!(removed, 2);
        assert!(!dir.join("aaa.npy").exists());
        assert!(!dir.join("bbb.npy").exists());
        assert!(dir.join("keep.txt").exists());
    }

    #[test]
    fn test_tmp_dirs_are_unique() {
        let a = tmp_dir();
        let b = tmp_dir();
        assert_ne!(a, b);
    }

    #[test]
    fn test_cache_outcome_display() {
        let hit = CacheOutcome::Hit {
            path: PathBuf::from("/cache/abc.npy"),
            hash: "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890"
                .to_string(),
        };
        let s = format!("{hit}");
        assert!(s.contains("cache hit"));
        assert!(s.contains("abcdef12345678"));
    }
}