use anyhow::Result;
use rskim_core::Mode;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::SystemTime;
use crate::cascade::TruncationOptions;
/// On-disk form of one cached transformation, serialized to and from JSON.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
/// Display rendering of the source path passed to `write_cache`.
/// `read_cache` does not consult this field.
path: String,
/// Source file modification time, in whole seconds since the Unix epoch,
/// as observed when the entry was written.
mtime_secs: u64,
/// `Debug` rendering of the requested `Mode`; `read_cache` compares it
/// against the mode of the current request.
mode: String,
/// The cached output text.
content: String,
/// Optional token count supplied by the writer; left out of the JSON when
/// `None` and read back as `None` when the key is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
original_tokens: Option<usize>,
/// Optional token count supplied by the writer; same serialization rules
/// as `original_tokens`.
#[serde(default, skip_serializing_if = "Option::is_none")]
transformed_tokens: Option<usize>,
/// `Debug` rendering of the writer's `effective_mode`, when one was given.
/// It is written to the JSON but not returned by `read_cache`.
#[serde(default, skip_serializing_if = "Option::is_none")]
effective_mode: Option<String>,
}
/// Result of a successful cache lookup, as returned by `read_cache`.
#[derive(Debug)]
pub(crate) struct CacheHit {
/// The cached output text stored in the entry.
pub(crate) content: String,
/// Token count stored with the entry, if the writer provided one.
pub(crate) original_tokens: Option<usize>,
/// Token count stored with the entry, if the writer provided one.
pub(crate) transformed_tokens: Option<usize>,
}
/// Everything `write_cache` needs to store one entry.
pub(crate) struct CacheWriteParams<'a> {
/// Source file the output was produced from; its metadata supplies the
/// modification time used in the cache key.
pub(crate) path: &'a Path,
/// Requested mode; part of the cache key and stored in the entry.
pub(crate) mode: Mode,
/// Output text to store.
pub(crate) content: &'a str,
/// Optional token count stored alongside the content.
pub(crate) original_tokens: Option<usize>,
/// Optional token count stored alongside the content.
pub(crate) transformed_tokens: Option<usize>,
/// Truncation options; part of the cache key.
pub(crate) trunc: TruncationOptions,
/// Optional mode recorded in the entry as `effective_mode`; it does not
/// take part in the cache key.
pub(crate) effective_mode: Option<Mode>,
}
pub(crate) fn get_cache_dir() -> Result<PathBuf> {
let cache_dir = dirs::cache_dir()
.ok_or_else(|| anyhow::anyhow!("Failed to determine cache directory"))?
.join("skim");
#[cfg(unix)]
{
use std::fs::DirBuilder;
use std::os::unix::fs::DirBuilderExt;
let mut builder = DirBuilder::new();
builder.mode(0o700); builder.recursive(true);
builder.create(&cache_dir)?;
}
#[cfg(not(unix))]
{
fs::create_dir_all(&cache_dir)?;
}
Ok(cache_dir)
}
/// Derives the cache file stem for one (file, mtime, mode, truncation) request.
///
/// The key is the lowercase hex SHA-256 digest of a `|`-separated string made
/// of the canonicalized path, the modification time, the `Debug` rendering of
/// `mode`, and the three truncation limits (`none` stands in for an unset
/// limit).
///
/// The modification time is hashed as `secs.nanos` rather than whole seconds.
/// With second granularity, a file rewritten within the same second as the
/// cached version would map to the same key and the stale entry would be
/// served. On filesystems without sub-second timestamps the nanosecond part
/// is simply zero.
///
/// # Errors
///
/// Fails if `path` cannot be canonicalized or if `mtime` lies before the Unix
/// epoch.
fn cache_key(
    path: &Path,
    mtime: SystemTime,
    mode: Mode,
    trunc: &TruncationOptions,
) -> Result<String> {
    let canonical_path = path.canonicalize()?;
    let since_epoch = mtime.duration_since(SystemTime::UNIX_EPOCH)?;

    // Unset limits are spelled out so `None` never collides with a number.
    let fmt_opt = |opt: Option<usize>| opt.map_or_else(|| "none".to_string(), |n| n.to_string());

    let hash_input = format!(
        "{}|{}.{:09}|{:?}|{}|{}|{}",
        canonical_path.display(),
        since_epoch.as_secs(),
        since_epoch.subsec_nanos(),
        mode,
        fmt_opt(trunc.max_lines),
        fmt_opt(trunc.last_lines),
        fmt_opt(trunc.token_budget),
    );

    let mut hasher = Sha256::new();
    hasher.update(hash_input.as_bytes());
    Ok(format!("{:x}", hasher.finalize()))
}
pub(crate) fn read_cache(path: &Path, mode: Mode, trunc: &TruncationOptions) -> Option<CacheHit> {
let metadata = fs::metadata(path).ok()?;
let mtime = metadata.modified().ok()?;
let key = cache_key(path, mtime, mode, trunc).ok()?;
let cache_file = get_cache_dir().ok()?.join(format!("{key}.json"));
let cache_content = fs::read_to_string(&cache_file).ok()?;
let entry: CacheEntry = serde_json::from_str(&cache_content).ok()?;
let mtime_secs = mtime.duration_since(SystemTime::UNIX_EPOCH).ok()?.as_secs();
let mode_str = format!("{mode:?}");
if entry.mtime_secs == mtime_secs && entry.mode == mode_str {
Some(CacheHit {
content: entry.content,
original_tokens: entry.original_tokens,
transformed_tokens: entry.transformed_tokens,
})
} else {
let _ = fs::remove_file(&cache_file);
None
}
}
/// Stores one transformation result in the cache.
///
/// The entry is written as JSON to `<cache dir>/<key>.json`, where the key is
/// derived from the source path, its current modification time, the requested
/// mode and the truncation options. An existing file with the same name is
/// overwritten. On Unix the file's permissions are set to `0o600` after the
/// write.
///
/// # Errors
///
/// Fails if the source file's metadata or modification time cannot be read,
/// the key cannot be derived, the cache directory is unavailable, the entry
/// cannot be serialized or written, or (on Unix) the permissions cannot be
/// set.
pub(crate) fn write_cache(params: &CacheWriteParams<'_>) -> Result<()> {
    let metadata = fs::metadata(params.path)?;
    let mtime = metadata.modified()?;
    // The truncation options are passed by reference; the source text had a
    // mis-encoded token here in place of `&params.trunc`.
    let key = cache_key(params.path, mtime, params.mode, &params.trunc)?;
    let cache_file = get_cache_dir()?.join(format!("{key}.json"));
    let mtime_secs = mtime.duration_since(SystemTime::UNIX_EPOCH)?.as_secs();
    let mode = params.mode;

    let entry = CacheEntry {
        path: params.path.display().to_string(),
        mtime_secs,
        mode: format!("{mode:?}"),
        content: params.content.to_string(),
        original_tokens: params.original_tokens,
        transformed_tokens: params.transformed_tokens,
        effective_mode: params.effective_mode.map(|m| format!("{m:?}")),
    };

    let json = serde_json::to_string(&entry)?;
    fs::write(&cache_file, json)?;

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        // Restrict the entry to its owner.
        fs::set_permissions(&cache_file, fs::Permissions::from_mode(0o600))?;
    }

    Ok(())
}
/// Deletes every `*.json` file directly inside the cache directory.
///
/// Individual deletions are best-effort: a file that cannot be removed is
/// skipped silently. Subdirectories and files with other extensions are left
/// alone.
///
/// # Errors
///
/// Fails if the cache directory cannot be resolved, cannot be listed, or if
/// reading a directory entry fails.
pub(crate) fn clear_cache() -> Result<()> {
    let cache_dir = get_cache_dir()?;
    if !cache_dir.exists() {
        return Ok(());
    }

    for dir_entry in fs::read_dir(&cache_dir)? {
        let candidate = dir_entry?.path();
        let has_json_ext = candidate.extension().is_some_and(|ext| ext == "json");
        if candidate.is_file() && has_json_ext {
            // Removal failures are deliberately ignored.
            let _ = fs::remove_file(&candidate);
        }
    }

    Ok(())
}
// Unit tests for the cache module. Tests that call `write_cache` or
// `read_cache` go through `get_cache_dir()`, so their entries are written to
// the directory that function resolves, not to a temporary location.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
// `cache_key` must be deterministic for identical inputs and must differ when
// the mode or any single truncation option differs.
#[test]
fn test_cache_key_generation() {
let mut temp_file = NamedTempFile::new().unwrap();
write!(temp_file, "test content").unwrap();
let path = temp_file.path();
let metadata = fs::metadata(path).unwrap();
let mtime = metadata.modified().unwrap();
let default_trunc = TruncationOptions::default();
// Same inputs -> same key.
let key1 = cache_key(path, mtime, Mode::Structure, &default_trunc).unwrap();
let key2 = cache_key(path, mtime, Mode::Structure, &default_trunc).unwrap();
assert_eq!(key1, key2);
// Different mode -> different key.
let key3 = cache_key(path, mtime, Mode::Signatures, &default_trunc).unwrap();
assert_ne!(key1, key3);
// Setting `max_lines` changes the key, and does so reproducibly.
let trunc_max = TruncationOptions {
max_lines: Some(50),
..Default::default()
};
let key4 = cache_key(path, mtime, Mode::Structure, &trunc_max).unwrap();
assert_ne!(key1, key4);
let key5 = cache_key(path, mtime, Mode::Structure, &trunc_max).unwrap();
assert_eq!(key4, key5);
// Setting `token_budget` changes the key, and does so reproducibly.
let trunc_budget = TruncationOptions {
token_budget: Some(500),
..Default::default()
};
let key6 = cache_key(path, mtime, Mode::Structure, &trunc_budget).unwrap();
assert_ne!(key1, key6);
let key7 = cache_key(path, mtime, Mode::Structure, &trunc_budget).unwrap();
assert_eq!(key6, key7);
// Combining both options yields a key distinct from either option alone.
let trunc_both = TruncationOptions {
max_lines: Some(50),
token_budget: Some(500),
..Default::default()
};
let key8 = cache_key(path, mtime, Mode::Structure, &trunc_both).unwrap();
assert_ne!(key4, key8);
assert_ne!(key6, key8);
// Setting `last_lines` changes the key, and does so reproducibly.
let trunc_last = TruncationOptions {
last_lines: Some(10),
..Default::default()
};
let key9 = cache_key(path, mtime, Mode::Structure, &trunc_last).unwrap();
assert_ne!(key1, key9);
let key10 = cache_key(path, mtime, Mode::Structure, &trunc_last).unwrap();
assert_eq!(key9, key10);
}
// Round trip with default truncation: miss before the write, hit after it,
// and misses for any other mode or truncation setting.
#[test]
fn test_cache_read_write() {
let mut temp_file = NamedTempFile::new().unwrap();
write!(temp_file, "test content").unwrap();
let path = temp_file.path().to_path_buf();
let default_trunc = TruncationOptions::default();
// Nothing has been written for this file yet.
assert!(read_cache(&path, Mode::Structure, &default_trunc).is_none());
let content = "transformed output";
write_cache(&CacheWriteParams {
path: &path,
mode: Mode::Structure,
content,
original_tokens: Some(100),
transformed_tokens: Some(50),
trunc: default_trunc,
effective_mode: None,
})
.unwrap();
// The stored content and token counts come back unchanged.
let hit = read_cache(&path, Mode::Structure, &default_trunc).unwrap();
assert_eq!(hit.content, content);
assert_eq!(hit.original_tokens, Some(100));
assert_eq!(hit.transformed_tokens, Some(50));
// A different mode must not see the entry.
assert!(read_cache(&path, Mode::Signatures, &default_trunc).is_none());
// Neither must any request with a truncation option set.
let trunc_max = TruncationOptions {
max_lines: Some(50),
..Default::default()
};
assert!(read_cache(&path, Mode::Structure, &trunc_max).is_none());
let trunc_last = TruncationOptions {
last_lines: Some(10),
..Default::default()
};
assert!(read_cache(&path, Mode::Structure, &trunc_last).is_none());
let trunc_budget = TruncationOptions {
token_budget: Some(500),
..Default::default()
};
assert!(read_cache(&path, Mode::Structure, &trunc_budget).is_none());
}
// Round trip for an entry written with a token budget: only a request with
// the same budget and mode hits.
#[test]
fn test_cache_read_write_with_token_budget() {
let mut temp_file = NamedTempFile::new().unwrap();
write!(temp_file, "test content for token budget").unwrap();
let path = temp_file.path().to_path_buf();
let trunc = TruncationOptions {
token_budget: Some(500),
..Default::default()
};
assert!(read_cache(&path, Mode::Structure, &trunc).is_none());
write_cache(&CacheWriteParams {
path: &path,
mode: Mode::Structure,
content: "budget-transformed output",
original_tokens: Some(200),
transformed_tokens: Some(80),
trunc,
effective_mode: None,
})
.unwrap();
let hit = read_cache(&path, Mode::Structure, &trunc).unwrap();
assert_eq!(hit.content, "budget-transformed output");
assert_eq!(hit.original_tokens, Some(200));
assert_eq!(hit.transformed_tokens, Some(80));
// No budget, a different budget, or a different mode must all miss.
let default_trunc = TruncationOptions::default();
assert!(read_cache(&path, Mode::Structure, &default_trunc).is_none());
let trunc_1000 = TruncationOptions {
token_budget: Some(1000),
..Default::default()
};
assert!(read_cache(&path, Mode::Structure, &trunc_1000).is_none());
assert!(read_cache(&path, Mode::Signatures, &trunc).is_none());
}
// `effective_mode` is persisted in the entry's JSON even though it is not
// part of `CacheHit`; the lookup still uses the requested mode.
#[test]
fn test_cache_stores_effective_mode() {
let mut temp_file = NamedTempFile::new().unwrap();
write!(temp_file, "effective mode test content").unwrap();
let path = temp_file.path().to_path_buf();
let trunc = TruncationOptions {
token_budget: Some(100),
..Default::default()
};
write_cache(&CacheWriteParams {
path: &path,
mode: Mode::Structure,
content: "escalated output",
original_tokens: Some(150),
transformed_tokens: Some(60),
trunc,
effective_mode: Some(Mode::Signatures),
})
.unwrap();
// The entry is found under the requested mode, not the effective one.
let hit = read_cache(&path, Mode::Structure, &trunc).unwrap();
assert_eq!(hit.content, "escalated output");
assert_eq!(hit.original_tokens, Some(150));
assert_eq!(hit.transformed_tokens, Some(60));
// Re-derive the entry's file name and inspect the raw JSON on disk.
let metadata = fs::metadata(&path).unwrap();
let mtime = metadata.modified().unwrap();
let key = cache_key(&path, mtime, Mode::Structure, &trunc).unwrap();
let cache_file = get_cache_dir().unwrap().join(format!("{key}.json"));
let raw_json = fs::read_to_string(&cache_file).unwrap();
let raw: serde_json::Value = serde_json::from_str(&raw_json).unwrap();
assert_eq!(
raw["effective_mode"].as_str(),
Some("Signatures"),
"effective_mode should be serialized in cache entry JSON"
);
}
// Rewriting the source file at a later time must turn a previous hit into a
// miss.
#[test]
fn test_cache_invalidation_on_mtime_change() {
use std::fs::File;
use std::io::Write as IoWrite;
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.path().to_path_buf();
let default_trunc = TruncationOptions::default();
{
let mut file = File::create(&path).unwrap();
file.write_all(b"original content").unwrap();
file.flush().unwrap();
}
write_cache(&CacheWriteParams {
path: &path,
mode: Mode::Structure,
content: "cached v1",
original_tokens: None,
transformed_tokens: None,
trunc: default_trunc,
effective_mode: None,
})
.unwrap();
let hit = read_cache(&path, Mode::Structure, &default_trunc).unwrap();
assert_eq!(hit.content, "cached v1");
// Wait so the rewrite below lands in a later second; the stored entry
// records the modification time in whole seconds.
std::thread::sleep(std::time::Duration::from_secs(1));
{
let mut file = File::create(&path).unwrap();
file.write_all(b"modified content").unwrap();
file.flush().unwrap();
}
assert!(read_cache(&path, Mode::Structure, &default_trunc).is_none());
}
}