use crate::grouper::SemanticGroup;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
/// Top-level record stored in the cache file; serialized and deserialized with serde.
///
/// The two `#[serde(default)]` fields may be missing from the serialized form,
/// in which case they deserialize to `None` / an empty map.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
/// Hash of the diff the cached grouping was produced for.
diff_hash: u64,
/// The cached groups.
groups: Vec<CachedGroup>,
/// HEAD commit recorded alongside the groups, if one was supplied when saving.
#[serde(default)]
head_commit: Option<String>,
/// Per-key hashes recorded alongside the groups.
/// NOTE(review): keys are presumably file paths — confirm against the callers that build this map.
#[serde(default)]
file_hashes: HashMap<String, u64>,
}
/// Serializable form of one group stored inside a `CacheEntry`.
#[derive(Debug, Serialize, Deserialize)]
struct CachedGroup {
/// Short label of the group.
label: String,
/// Longer description of the group.
description: String,
/// The changes that belong to this group.
changes: Vec<CachedChange>,
}
/// Serializable form of one change stored inside a `CachedGroup`.
#[derive(Debug, Serialize, Deserialize)]
struct CachedChange {
/// The file this change refers to.
file: String,
/// Hunk numbers associated with `file`.
/// NOTE(review): presumably indices into that file's list of diff hunks — confirm against the grouper.
hunks: Vec<usize>,
}
/// Returns the output of `git rev-parse HEAD`, with surrounding whitespace trimmed.
///
/// Yields `None` when the `git` process cannot be run, exits with a
/// non-success status, or prints bytes that are not valid UTF-8.
pub fn get_head_commit() -> Option<String> {
    let rev_parse = std::process::Command::new("git")
        .arg("rev-parse")
        .arg("HEAD")
        .output()
        .ok()
        .filter(|finished| finished.status.success())?;

    let stdout = String::from_utf8(rev_parse.stdout).ok()?;
    Some(stdout.trim().to_owned())
}
/// Computes a 64-bit hash of `raw_diff`.
///
/// The value is produced by the standard library's `DefaultHasher`. Its
/// algorithm is unspecified and may differ between Rust releases, so a hash
/// computed by one build is not guaranteed to match one computed by another.
pub fn diff_hash(raw_diff: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(raw_diff, &mut state);
    Hasher::finish(&state)
}
pub fn load(hash: u64) -> Option<Vec<SemanticGroup>> {
let path = cache_path()?;
let metadata = std::fs::metadata(&path).ok()?;
if metadata.len() > 1_048_576 {
tracing::warn!("Cache file too large ({} bytes), ignoring", metadata.len());
return None;
}
let content = std::fs::read_to_string(&path).ok()?;
let entry: CacheEntry = serde_json::from_str(&content).ok()?;
if entry.groups.len() > 50 {
tracing::warn!(
"Cache has too many groups ({}), ignoring",
entry.groups.len()
);
return None;
}
if entry.diff_hash != hash {
tracing::debug!("Cache miss: hash mismatch");
return None;
}
tracing::info!("Cache hit: reusing {} groups", entry.groups.len());
Some(
entry
.groups
.into_iter()
.map(|g| SemanticGroup::new(
g.label,
g.description,
g.changes
.into_iter()
.map(|c| crate::grouper::GroupedChange {
file: c.file,
hunks: c.hunks,
})
.collect(),
))
.collect(),
)
}
pub fn save_with_state(
hash: u64,
groups: &[SemanticGroup],
head_commit: Option<&str>,
file_hashes: &HashMap<String, u64>,
) {
let Some(path) = cache_path() else { return };
let entry = CacheEntry {
diff_hash: hash,
groups: groups
.iter()
.map(|g| CachedGroup {
label: g.label.clone(),
description: g.description.clone(),
changes: g
.changes()
.iter()
.map(|c| CachedChange {
file: c.file.clone(),
hunks: c.hunks.clone(),
})
.collect(),
})
.collect(),
head_commit: head_commit.map(|s| s.to_string()),
file_hashes: file_hashes.clone(),
};
match serde_json::to_string(&entry) {
Ok(json) => {
if let Err(e) = std::fs::write(&path, json) {
tracing::warn!("Failed to write cache: {}", e);
} else {
tracing::debug!("Saved cache to {}", path.display());
}
}
Err(e) => tracing::warn!("Failed to serialize cache: {}", e),
}
}
pub fn load_incremental(
current_head: &str,
) -> Option<(Vec<SemanticGroup>, HashMap<String, u64>)> {
let path = cache_path()?;
let metadata = std::fs::metadata(&path).ok()?;
if metadata.len() > 1_048_576 {
return None;
}
let content = std::fs::read_to_string(&path).ok()?;
let entry: CacheEntry = serde_json::from_str(&content).ok()?;
if entry.groups.len() > 50 {
return None;
}
let cached_head = entry.head_commit.as_deref()?;
if cached_head != current_head {
return None;
}
if entry.file_hashes.is_empty() {
return None;
}
tracing::info!(
"Incremental cache hit: {} groups, {} file hashes",
entry.groups.len(),
entry.file_hashes.len()
);
let groups = entry
.groups
.into_iter()
.map(|g| {
SemanticGroup::new(
g.label,
g.description,
g.changes
.into_iter()
.map(|c| crate::grouper::GroupedChange {
file: c.file,
hunks: c.hunks,
})
.collect(),
)
})
.collect();
Some((groups, entry.file_hashes))
}
/// Resolves the location of the cache file: `semantic-diff-cache.json` inside
/// the directory printed by `git rev-parse --git-dir`.
///
/// Returns `None` when the `git` command cannot be run, fails, or prints
/// non-UTF-8 output, when the current directory cannot be determined, or when
/// the git directory does not lie under the current working directory.
///
/// The containment check compares canonicalized paths where canonicalization
/// succeeds and falls back to the raw paths otherwise. Note that it compares
/// against the current working directory, even though the warning text speaks
/// of the "repo root". The returned path is built from git's output as printed,
/// not from the canonicalized form.
fn cache_path() -> Option<PathBuf> {
    let rev_parse = std::process::Command::new("git")
        .arg("rev-parse")
        .arg("--git-dir")
        .output()
        .ok()?;
    if !rev_parse.status.success() {
        return None;
    }

    let printed = String::from_utf8(rev_parse.stdout).ok()?;
    let git_dir = PathBuf::from(printed.trim());
    let cwd = std::env::current_dir().ok()?;

    let resolved_git_dir = std::fs::canonicalize(&git_dir).unwrap_or_else(|_| git_dir.clone());
    let resolved_cwd = std::fs::canonicalize(&cwd).unwrap_or(cwd);

    if resolved_git_dir.starts_with(&resolved_cwd) {
        Some(git_dir.join("semantic-diff-cache.json"))
    } else {
        tracing::warn!(
            "git-dir {} is outside repo root {}, refusing to use cache",
            resolved_git_dir.display(),
            resolved_cwd.display()
        );
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing the same text twice gives the same value.
    #[test]
    fn test_diff_hash_deterministic() {
        assert_eq!(diff_hash("hello world"), diff_hash("hello world"));
    }

    /// Two different inputs hash to different values.
    #[test]
    fn test_diff_hash_changes() {
        assert_ne!(diff_hash("hello"), diff_hash("world"));
    }

    /// When a cache path is available, it names the cache file.
    ///
    /// Passes trivially when `cache_path` returns `None` (for example, when
    /// the test is not run from inside a git checkout).
    #[test]
    fn test_cache_path_validates_git_dir_within_cwd() {
        let Some(resolved) = cache_path() else { return };
        assert!(
            resolved
                .to_string_lossy()
                .contains("semantic-diff-cache.json"),
            "cache path should contain cache filename, got: {}",
            resolved.display()
        );
    }

    /// Builds a file just over the 1 MiB limit and checks its size.
    ///
    /// NOTE(review): this only verifies the fixture; `load` is never called
    /// here, so the size rejection itself is not exercised.
    #[test]
    fn test_load_rejects_oversized_cache() {
        let scratch = tempfile::tempdir().unwrap();
        let oversized = scratch.path().join("oversized-cache.json");
        std::fs::write(&oversized, "x".repeat(1_048_577)).unwrap();

        let size = std::fs::metadata(&oversized).unwrap().len();
        assert!(size > 1_048_576, "Test file should be larger than 1MB");
    }

    /// A minimal entry without `head_commit` / `file_hashes` still parses.
    #[test]
    fn test_cache_entry_with_valid_groups_deserializes() {
        let json = r#"{
"diff_hash": 12345,
"groups": [
{"label": "Auth", "description": "Auth changes", "changes": [{"file": "src/auth.rs", "hunks": [0]}]}
]
}"#;
        let parsed = serde_json::from_str::<CacheEntry>(json).unwrap();
        assert_eq!(parsed.groups.len(), 1);
        assert_eq!(parsed.groups[0].label, "Auth");
    }

    /// Constructs an entry with 60 groups and checks it exceeds the limit of 50.
    ///
    /// NOTE(review): this only verifies the fixture; the rejection logic in
    /// `load` / `load_incremental` is not exercised.
    #[test]
    fn test_cache_entry_group_count_validation() {
        let groups: Vec<CachedGroup> = (0..60)
            .map(|i| CachedGroup {
                label: format!("Group {}", i),
                description: "desc".to_string(),
                changes: Vec::new(),
            })
            .collect();
        let entry = CacheEntry {
            diff_hash: 99999,
            groups,
            head_commit: None,
            file_hashes: HashMap::new(),
        };
        assert!(entry.groups.len() > 50);
    }
}