//! semantic_diff/cache.rs — caching of semantic grouping results.
//!
//! Results are stored in `.git/semantic-diff-cache.json`, keyed by a hash of
//! the raw diff (full reuse) or by HEAD commit + per-file hashes (incremental).
1use crate::grouper::SemanticGroup;
2use serde::{Deserialize, Serialize};
3use std::collections::hash_map::DefaultHasher;
4use std::collections::HashMap;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7
/// Cached grouping result stored in .git/semantic-diff-cache.json.
///
/// Written by `save_with_state` and read back by `load` / `load_incremental`,
/// which validate file size and group count before trusting the contents.
/// Field names are part of the on-disk JSON format — do not rename.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
    /// Hash of the raw diff output — if this matches, the cache is valid.
    diff_hash: u64,
    /// The cached grouping result (serializable mirror of `SemanticGroup`s).
    groups: Vec<CachedGroup>,
    /// HEAD commit hash when this cache was saved. Used for incremental grouping.
    /// `#[serde(default)]` keeps cache files written before this field existed readable.
    #[serde(default)]
    head_commit: Option<String>,
    /// Per-file content hashes. Key = file path, Value = hash of hunk content.
    /// `#[serde(default)]` tolerates older cache files that lack this map.
    #[serde(default)]
    file_hashes: HashMap<String, u64>,
}
21
/// Serializable version of SemanticGroup.
///
/// Kept separate from `SemanticGroup` so the cache's JSON layout is decoupled
/// from the in-memory type; converted back via `SemanticGroup::new` on load.
#[derive(Debug, Serialize, Deserialize)]
struct CachedGroup {
    /// Short human-readable name for the group (e.g. "Auth").
    label: String,
    /// Longer explanation of what the group's changes do.
    description: String,
    /// The file/hunk changes belonging to this group.
    changes: Vec<CachedChange>,
}
29
/// Serializable version of GroupedChange.
#[derive(Debug, Serialize, Deserialize)]
struct CachedChange {
    /// Path of the changed file, as it appears in the diff.
    file: String,
    /// Indices of the hunks (within that file's diff) assigned to the group.
    hunks: Vec<usize>,
}
36
/// Return the current HEAD commit hash, or `None` if not in a git repo.
///
/// Shells out to `git rev-parse HEAD`; any spawn failure, non-zero exit
/// status, or non-UTF-8 output yields `None`.
pub fn get_head_commit() -> Option<String> {
    let output = std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .filter(|o| o.status.success())?;
    let stdout = String::from_utf8(output.stdout).ok()?;
    Some(stdout.trim().to_owned())
}
48
/// Compute a fast hash of the raw diff string.
///
/// Uses the standard library's `DefaultHasher`; the value is stable within a
/// process run, which is all the cache-validity check needs.
pub fn diff_hash(raw_diff: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(raw_diff, &mut state);
    Hasher::finish(&state)
}
55
56/// Try to load cached grouping for the given diff hash.
57/// Returns None if no cache, hash mismatch, parse error, or oversized file.
58pub fn load(hash: u64) -> Option<Vec<SemanticGroup>> {
59    let path = cache_path()?;
60
61    // Reject oversized cache files (FINDING-16: prevent OOM from crafted cache)
62    let metadata = std::fs::metadata(&path).ok()?;
63    if metadata.len() > 1_048_576 {
64        // 1MB limit
65        tracing::warn!("Cache file too large ({} bytes), ignoring", metadata.len());
66        return None;
67    }
68
69    let content = std::fs::read_to_string(&path).ok()?;
70    let entry: CacheEntry = serde_json::from_str(&content).ok()?;
71
72    // Validate cache structure (FINDING-16: reject unreasonable group counts)
73    if entry.groups.len() > 50 {
74        tracing::warn!(
75            "Cache has too many groups ({}), ignoring",
76            entry.groups.len()
77        );
78        return None;
79    }
80
81    if entry.diff_hash != hash {
82        tracing::debug!("Cache miss: hash mismatch");
83        return None;
84    }
85
86    tracing::info!("Cache hit: reusing {} groups", entry.groups.len());
87    Some(
88        entry
89            .groups
90            .into_iter()
91            .map(|g| SemanticGroup::new(
92                g.label,
93                g.description,
94                g.changes
95                    .into_iter()
96                    .map(|c| crate::grouper::GroupedChange {
97                        file: c.file,
98                        hunks: c.hunks,
99                    })
100                    .collect(),
101            ))
102            .collect(),
103    )
104}
105
106/// Save grouping result to the cache file with optional incremental state.
107pub fn save_with_state(
108    hash: u64,
109    groups: &[SemanticGroup],
110    head_commit: Option<&str>,
111    file_hashes: &HashMap<String, u64>,
112) {
113    let Some(path) = cache_path() else { return };
114
115    let entry = CacheEntry {
116        diff_hash: hash,
117        groups: groups
118            .iter()
119            .map(|g| CachedGroup {
120                label: g.label.clone(),
121                description: g.description.clone(),
122                changes: g
123                    .changes()
124                    .iter()
125                    .map(|c| CachedChange {
126                        file: c.file.clone(),
127                        hunks: c.hunks.clone(),
128                    })
129                    .collect(),
130            })
131            .collect(),
132        head_commit: head_commit.map(|s| s.to_string()),
133        file_hashes: file_hashes.clone(),
134    };
135
136    match serde_json::to_string(&entry) {
137        Ok(json) => {
138            if let Err(e) = std::fs::write(&path, json) {
139                tracing::warn!("Failed to write cache: {}", e);
140            } else {
141                tracing::debug!("Saved cache to {}", path.display());
142            }
143        }
144        Err(e) => tracing::warn!("Failed to serialize cache: {}", e),
145    }
146}
147
148/// Try to load cached grouping for the given HEAD commit (incremental mode).
149/// Returns the cached groups and file hashes so the caller can compute the delta.
150/// Returns None if no cache, HEAD mismatch, empty file hashes, parse error, or oversized file.
151pub fn load_incremental(
152    current_head: &str,
153) -> Option<(Vec<SemanticGroup>, HashMap<String, u64>)> {
154    let path = cache_path()?;
155
156    let metadata = std::fs::metadata(&path).ok()?;
157    if metadata.len() > 1_048_576 {
158        return None;
159    }
160
161    let content = std::fs::read_to_string(&path).ok()?;
162    let entry: CacheEntry = serde_json::from_str(&content).ok()?;
163
164    if entry.groups.len() > 50 {
165        return None;
166    }
167
168    // Match by HEAD commit, not diff hash.
169    let cached_head = entry.head_commit.as_deref()?;
170    if cached_head != current_head {
171        return None;
172    }
173
174    if entry.file_hashes.is_empty() {
175        return None;
176    }
177
178    tracing::info!(
179        "Incremental cache hit: {} groups, {} file hashes",
180        entry.groups.len(),
181        entry.file_hashes.len()
182    );
183
184    let groups = entry
185        .groups
186        .into_iter()
187        .map(|g| {
188            SemanticGroup::new(
189                g.label,
190                g.description,
191                g.changes
192                    .into_iter()
193                    .map(|c| crate::grouper::GroupedChange {
194                        file: c.file,
195                        hunks: c.hunks,
196                    })
197                    .collect(),
198            )
199        })
200        .collect();
201
202    Some((groups, entry.file_hashes))
203}
204
205/// Path to the cache file: .git/semantic-diff-cache.json
206/// Returns None if not in a git repo or if git-dir is outside the repo root.
207fn cache_path() -> Option<PathBuf> {
208    let output = std::process::Command::new("git")
209        .args(["rev-parse", "--git-dir"])
210        .output()
211        .ok()?;
212    if !output.status.success() {
213        return None;
214    }
215    let git_dir = String::from_utf8(output.stdout).ok()?.trim().to_string();
216    let git_path = PathBuf::from(&git_dir);
217
218    // Validate: git-dir should be within or adjacent to the current working directory.
219    // This prevents crafted .git files from redirecting cache writes to arbitrary locations.
220    let cwd = std::env::current_dir().ok()?;
221    let canonical_git = std::fs::canonicalize(&git_path).unwrap_or(git_path.clone());
222    let canonical_cwd = std::fs::canonicalize(&cwd).unwrap_or(cwd);
223    if !canonical_git.starts_with(&canonical_cwd) {
224        tracing::warn!(
225            "git-dir {} is outside repo root {}, refusing to use cache",
226            canonical_git.display(),
227            canonical_cwd.display()
228        );
229        return None;
230    }
231
232    Some(PathBuf::from(git_dir).join("semantic-diff-cache.json"))
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Same input must always hash to the same value.
    #[test]
    fn test_diff_hash_deterministic() {
        assert_eq!(diff_hash("hello world"), diff_hash("hello world"));
    }

    /// Different inputs should produce different hashes.
    #[test]
    fn test_diff_hash_changes() {
        assert_ne!(diff_hash("hello"), diff_hash("world"));
    }

    /// cache_path() must not panic and, when it returns Some, must point at
    /// the cache filename. None is acceptable (not in a git repo, or the
    /// git-dir validation failed).
    #[test]
    fn test_cache_path_validates_git_dir_within_cwd() {
        if let Some(p) = cache_path() {
            assert!(
                p.to_string_lossy().contains("semantic-diff-cache.json"),
                "cache path should contain cache filename, got: {}",
                p.display()
            );
        }
    }

    /// Sanity-check the 1MB oversized-cache threshold.
    #[test]
    fn test_load_rejects_oversized_cache() {
        let temp_dir = tempfile::tempdir().unwrap();
        let cache_file = temp_dir.path().join("oversized-cache.json");
        // One byte past the 1MB limit.
        std::fs::write(&cache_file, "x".repeat(1_048_577)).unwrap();
        assert!(
            std::fs::metadata(&cache_file).unwrap().len() > 1_048_576,
            "Test file should be larger than 1MB"
        );
        // We can't easily test the full load() path without mocking cache_path(),
        // but we verify the size check constant is correct
    }

    /// Round-trip a minimal valid cache entry through serde.
    #[test]
    fn test_cache_entry_with_valid_groups_deserializes() {
        let json = r#"{
            "diff_hash": 12345,
            "groups": [
                {"label": "Auth", "description": "Auth changes", "changes": [{"file": "src/auth.rs", "hunks": [0]}]}
            ]
        }"#;
        let entry: CacheEntry = serde_json::from_str(json).unwrap();
        assert_eq!(entry.groups.len(), 1);
        assert_eq!(entry.groups[0].label, "Auth");
    }

    /// An entry with more than 50 groups is over the validation limit.
    #[test]
    fn test_cache_entry_group_count_validation() {
        let groups: Vec<CachedGroup> = (0..60)
            .map(|i| CachedGroup {
                label: format!("Group {}", i),
                description: "desc".to_string(),
                changes: vec![],
            })
            .collect();
        let entry = CacheEntry {
            diff_hash: 99999,
            groups,
            head_commit: None,
            file_hashes: HashMap::new(),
        };
        assert!(entry.groups.len() > 50);
    }
}