1use crate::grouper::SemanticGroup;
2use serde::{Deserialize, Serialize};
3use std::collections::hash_map::DefaultHasher;
4use std::collections::HashMap;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7
/// On-disk cache record, written by `save_with_state` as JSON and read back by
/// `load` / `load_incremental`.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
    /// Hash identifying the diff the groups were computed for; `load` only reuses
    /// the entry when this equals the hash it is given (presumably from `diff_hash`).
    diff_hash: u64,
    /// The grouping result that was saved.
    groups: Vec<CachedGroup>,
    /// Commit recorded at save time; `load_incremental` requires it to equal the
    /// current head. Defaults to `None` when the field is absent from the JSON.
    #[serde(default)]
    head_commit: Option<String>,
    /// Per-file hashes recorded at save time (keys presumably file paths — TODO
    /// confirm with callers). Defaults to empty when absent from the JSON.
    #[serde(default)]
    file_hashes: HashMap<String, u64>,
}
21
/// Serialized form of a `SemanticGroup`: its label, description and the changes it
/// holds.
#[derive(Debug, Serialize, Deserialize)]
struct CachedGroup {
    label: String,
    description: String,
    changes: Vec<CachedChange>,
}
29
/// Serialized form of a `grouper::GroupedChange`: a file plus the hunk indices
/// (presumably into that file's diff — TODO confirm) that belong to the group.
#[derive(Debug, Serialize, Deserialize)]
struct CachedChange {
    file: String,
    hunks: Vec<usize>,
}
36
/// Returns the commit id that `HEAD` resolves to, as printed by `git rev-parse HEAD`.
///
/// Runs git in the current directory. Yields `None` if git cannot be spawned, exits
/// unsuccessfully, or prints output that is not valid UTF-8. Surrounding whitespace
/// (the trailing newline) is trimmed.
pub fn get_head_commit() -> Option<String> {
    let result = std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output();

    match result {
        Ok(out) if out.status.success() => String::from_utf8(out.stdout)
            .ok()
            .map(|text| text.trim().to_string()),
        _ => None,
    }
}
48
/// Computes a 64-bit fingerprint of the raw diff text, used as the cache key.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to stay the same
/// across Rust releases. A cache written by a binary built with a different
/// toolchain may therefore just miss; switch to a fixed algorithm if that matters.
pub fn diff_hash(raw_diff: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(raw_diff, &mut state);
    state.finish()
}
55
56pub fn load(hash: u64) -> Option<Vec<SemanticGroup>> {
59 let path = cache_path()?;
60
61 let metadata = std::fs::metadata(&path).ok()?;
63 if metadata.len() > 1_048_576 {
64 tracing::warn!("Cache file too large ({} bytes), ignoring", metadata.len());
66 return None;
67 }
68
69 let content = std::fs::read_to_string(&path).ok()?;
70 let entry: CacheEntry = serde_json::from_str(&content).ok()?;
71
72 if entry.groups.len() > 50 {
74 tracing::warn!(
75 "Cache has too many groups ({}), ignoring",
76 entry.groups.len()
77 );
78 return None;
79 }
80
81 if entry.diff_hash != hash {
82 tracing::debug!("Cache miss: hash mismatch");
83 return None;
84 }
85
86 tracing::info!("Cache hit: reusing {} groups", entry.groups.len());
87 Some(
88 entry
89 .groups
90 .into_iter()
91 .map(|g| SemanticGroup::new(
92 g.label,
93 g.description,
94 g.changes
95 .into_iter()
96 .map(|c| crate::grouper::GroupedChange {
97 file: c.file,
98 hunks: c.hunks,
99 })
100 .collect(),
101 ))
102 .collect(),
103 )
104}
105
106pub fn save_with_state(
108 hash: u64,
109 groups: &[SemanticGroup],
110 head_commit: Option<&str>,
111 file_hashes: &HashMap<String, u64>,
112) {
113 let Some(path) = cache_path() else { return };
114
115 let entry = CacheEntry {
116 diff_hash: hash,
117 groups: groups
118 .iter()
119 .map(|g| CachedGroup {
120 label: g.label.clone(),
121 description: g.description.clone(),
122 changes: g
123 .changes()
124 .iter()
125 .map(|c| CachedChange {
126 file: c.file.clone(),
127 hunks: c.hunks.clone(),
128 })
129 .collect(),
130 })
131 .collect(),
132 head_commit: head_commit.map(|s| s.to_string()),
133 file_hashes: file_hashes.clone(),
134 };
135
136 match serde_json::to_string(&entry) {
137 Ok(json) => {
138 if let Err(e) = std::fs::write(&path, json) {
139 tracing::warn!("Failed to write cache: {}", e);
140 } else {
141 tracing::debug!("Saved cache to {}", path.display());
142 }
143 }
144 Err(e) => tracing::warn!("Failed to serialize cache: {}", e),
145 }
146}
147
148pub fn load_incremental(
152 current_head: &str,
153) -> Option<(Vec<SemanticGroup>, HashMap<String, u64>)> {
154 let path = cache_path()?;
155
156 let metadata = std::fs::metadata(&path).ok()?;
157 if metadata.len() > 1_048_576 {
158 return None;
159 }
160
161 let content = std::fs::read_to_string(&path).ok()?;
162 let entry: CacheEntry = serde_json::from_str(&content).ok()?;
163
164 if entry.groups.len() > 50 {
165 return None;
166 }
167
168 let cached_head = entry.head_commit.as_deref()?;
170 if cached_head != current_head {
171 return None;
172 }
173
174 if entry.file_hashes.is_empty() {
175 return None;
176 }
177
178 tracing::info!(
179 "Incremental cache hit: {} groups, {} file hashes",
180 entry.groups.len(),
181 entry.file_hashes.len()
182 );
183
184 let groups = entry
185 .groups
186 .into_iter()
187 .map(|g| {
188 SemanticGroup::new(
189 g.label,
190 g.description,
191 g.changes
192 .into_iter()
193 .map(|c| crate::grouper::GroupedChange {
194 file: c.file,
195 hunks: c.hunks,
196 })
197 .collect(),
198 )
199 })
200 .collect();
201
202 Some((groups, entry.file_hashes))
203}
204
205fn cache_path() -> Option<PathBuf> {
208 let output = std::process::Command::new("git")
209 .args(["rev-parse", "--git-dir"])
210 .output()
211 .ok()?;
212 if !output.status.success() {
213 return None;
214 }
215 let git_dir = String::from_utf8(output.stdout).ok()?.trim().to_string();
216 let git_path = PathBuf::from(&git_dir);
217
218 let cwd = std::env::current_dir().ok()?;
221 let canonical_git = std::fs::canonicalize(&git_path).unwrap_or(git_path.clone());
222 let canonical_cwd = std::fs::canonicalize(&cwd).unwrap_or(cwd);
223 if !canonical_git.starts_with(&canonical_cwd) {
224 tracing::warn!(
225 "git-dir {} is outside repo root {}, refusing to use cache",
226 canonical_git.display(),
227 canonical_cwd.display()
228 );
229 return None;
230 }
231
232 Some(PathBuf::from(git_dir).join("semantic-diff-cache.json"))
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diff_hash_deterministic() {
        assert_eq!(diff_hash("hello world"), diff_hash("hello world"));
    }

    #[test]
    fn test_diff_hash_changes() {
        assert_ne!(diff_hash("hello"), diff_hash("world"));
    }

    #[test]
    fn test_cache_path_points_at_cache_file() {
        // Outside a git checkout, or when the git-dir check refuses, there is no
        // path. Otherwise it must name the cache file itself.
        if let Some(p) = cache_path() {
            assert_eq!(
                p.file_name().and_then(|n| n.to_str()),
                Some("semantic-diff-cache.json"),
                "unexpected cache path: {}",
                p.display()
            );
        }
    }

    #[test]
    fn test_load_rejects_oversized_cache() {
        // NOTE(review): `load` always reads from `cache_path()`, so this test can only
        // validate the fixture. It does not exercise `load`'s size limit; that needs
        // a path-injectable reader.
        let temp_dir = tempfile::tempdir().unwrap();
        let cache_file = temp_dir.path().join("oversized-cache.json");
        let large_content = "x".repeat(1_048_577);
        std::fs::write(&cache_file, large_content).unwrap();
        let metadata = std::fs::metadata(&cache_file).unwrap();
        assert!(
            metadata.len() > 1_048_576,
            "Test file should be larger than 1MB"
        );
    }

    #[test]
    fn test_cache_entry_with_valid_groups_deserializes() {
        let json = r#"{
            "diff_hash": 12345,
            "groups": [
                {"label": "Auth", "description": "Auth changes", "changes": [{"file": "src/auth.rs", "hunks": [0]}]}
            ]
        }"#;
        let entry: CacheEntry = serde_json::from_str(json).unwrap();
        assert_eq!(entry.diff_hash, 12345);
        assert_eq!(entry.groups.len(), 1);
        assert_eq!(entry.groups[0].label, "Auth");
        assert_eq!(entry.groups[0].changes[0].file, "src/auth.rs");
        assert_eq!(entry.groups[0].changes[0].hunks, vec![0]);
        // Fields missing from older cache files fall back to their serde defaults.
        assert!(entry.head_commit.is_none());
        assert!(entry.file_hashes.is_empty());
    }

    #[test]
    fn test_cache_entry_round_trips_incremental_state() {
        let mut file_hashes = HashMap::new();
        file_hashes.insert("src/auth.rs".to_string(), 42u64);
        let entry = CacheEntry {
            diff_hash: 7,
            groups: vec![CachedGroup {
                label: "Auth".to_string(),
                description: "Auth changes".to_string(),
                changes: vec![CachedChange {
                    file: "src/auth.rs".to_string(),
                    hunks: vec![0, 2],
                }],
            }],
            head_commit: Some("abc123".to_string()),
            file_hashes,
        };

        let json = serde_json::to_string(&entry).unwrap();
        let back: CacheEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(back.diff_hash, 7);
        assert_eq!(back.head_commit.as_deref(), Some("abc123"));
        assert_eq!(back.file_hashes.get("src/auth.rs"), Some(&42));
        assert_eq!(back.groups[0].changes[0].hunks, vec![0, 2]);
    }

    #[test]
    fn test_cache_entry_group_count_validation() {
        // NOTE(review): like the oversized test, this cannot reach `load`'s
        // group-count limit. It checks that an over-limit entry still serializes
        // faithfully, so a later limit check would see the real count.
        let groups: Vec<CachedGroup> = (0..60)
            .map(|i| CachedGroup {
                label: format!("Group {}", i),
                description: "desc".to_string(),
                changes: vec![],
            })
            .collect();
        let entry = CacheEntry {
            diff_hash: 99999,
            groups,
            head_commit: None,
            file_hashes: HashMap::new(),
        };
        assert!(entry.groups.len() > 50);

        let json = serde_json::to_string(&entry).unwrap();
        let back: CacheEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(back.groups.len(), 60);
        assert_eq!(back.groups[59].label, "Group 59");
    }
}