Skip to main content

gitstack/
cache.rs

1use std::collections::hash_map::DefaultHasher;
2use std::fs;
3use std::hash::{Hash, Hasher};
4use std::path::{Path, PathBuf};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use anyhow::Result;
8use git2::Repository;
9use serde::{Deserialize, Serialize};
10
/// Number of seconds in one hour; converts the caller-supplied TTL (in hours) to seconds.
const SECONDS_PER_HOUR: u64 = 60 * 60;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14struct CacheEnvelope<T> {
15    generated_at_unix: u64,
16    head_hash: String,
17    payload: T,
18}
19
20pub fn clear_analysis_cache() -> Result<()> {
21    let dir = analysis_cache_dir()?;
22    if dir.exists() {
23        fs::remove_dir_all(&dir)?;
24    }
25    Ok(())
26}
27
28pub fn load_or_compute(
29    key: &str,
30    ttl_hours: u64,
31    compute: impl FnOnce() -> Result<String>,
32) -> Result<String> {
33    load_or_compute_with_repo(None, key, ttl_hours, compute)
34}
35
36pub fn load_or_compute_with_repo(
37    repo: Option<&Repository>,
38    key: &str,
39    ttl_hours: u64,
40    compute: impl FnOnce() -> Result<String>,
41) -> Result<String> {
42    let owned_repo;
43    let repo = match repo {
44        Some(r) => r,
45        None => {
46            owned_repo = Repository::discover(".")?;
47            &owned_repo
48        }
49    };
50    let head_hash = repo
51        .head()?
52        .target()
53        .map(|oid| oid.to_string())
54        .unwrap_or_else(|| "HEAD".to_string());
55    let path = cache_file_path(repo, key)?;
56
57    if let Ok(content) = fs::read_to_string(&path) {
58        if let Ok(envelope) = serde_json::from_str::<CacheEnvelope<String>>(&content) {
59            if envelope.head_hash == head_hash && !is_expired(envelope.generated_at_unix, ttl_hours)
60            {
61                return Ok(envelope.payload);
62            }
63        }
64    }
65
66    let payload = compute()?;
67    let envelope = CacheEnvelope {
68        generated_at_unix: now_unix_secs(),
69        head_hash,
70        payload,
71    };
72
73    if let Some(parent) = path.parent() {
74        let _ = fs::create_dir_all(parent);
75    }
76    if let Ok(serialized) = serde_json::to_string(&envelope) {
77        let _ = fs::write(&path, serialized);
78    }
79
80    Ok(envelope.payload)
81}
82
83fn is_expired(generated_at_unix: u64, ttl_hours: u64) -> bool {
84    if ttl_hours == 0 {
85        return true;
86    }
87    let now = now_unix_secs();
88    now.saturating_sub(generated_at_unix) > ttl_hours.saturating_mul(SECONDS_PER_HOUR)
89}
90
91fn cache_file_path(repo: &Repository, key: &str) -> Result<PathBuf> {
92    Ok(analysis_cache_dir()?
93        .join(repo_cache_key(repo)?)
94        .join(format!("{}.json", key)))
95}
96
97fn analysis_cache_dir() -> Result<PathBuf> {
98    if let Some(base) = dirs::cache_dir() {
99        return Ok(base.join("gitstack").join("analysis"));
100    }
101    Ok(std::env::temp_dir().join("gitstack").join("analysis"))
102}
103
104fn repo_cache_key(repo: &Repository) -> Result<String> {
105    let path = repo
106        .workdir()
107        .or_else(|| repo.path().parent())
108        .unwrap_or_else(|| Path::new("."));
109
110    let mut hasher = DefaultHasher::new();
111    path.to_string_lossy().hash(&mut hasher);
112    Ok(format!("{:x}", hasher.finish()))
113}
114
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_unix_secs() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map(|elapsed| elapsed.as_secs()).unwrap_or(0)
}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `check` with the current Unix second and retries whenever the clock moves to
    /// the next second while `check` runs.
    ///
    /// `is_expired` reads the clock itself. A test that derives its inputs from an earlier
    /// `now_unix_secs()` call therefore sees a one-second skew if a tick lands in between,
    /// which flips exact-boundary checks. Requiring the same second before and after keeps
    /// those tests deterministic in practice.
    fn at_stable_now<T>(check: impl Fn(u64) -> T) -> T {
        loop {
            let now = now_unix_secs();
            let outcome = check(now);
            if now_unix_secs() == now {
                return outcome;
            }
        }
    }

    #[test]
    fn is_expired_returns_true_when_ttl_is_zero() {
        assert!(is_expired(now_unix_secs(), 0));
    }

    #[test]
    fn is_expired_returns_false_for_fresh_entry() {
        let now = now_unix_secs();
        assert!(!is_expired(now, 1));
    }

    #[test]
    fn is_expired_returns_true_for_old_entry() {
        let two_hours_ago = now_unix_secs() - 7200;
        assert!(is_expired(two_hours_ago, 1));
    }

    #[test]
    fn is_expired_boundary_exactly_at_ttl() {
        // At exactly 1 hour the entry is still fresh (`>` not `>=`). Pin the clock so a
        // second tick during the check cannot turn the age into 3601s.
        let expired = at_stable_now(|now| is_expired(now - 3600, 1));
        assert!(!expired);
    }

    #[test]
    fn is_expired_boundary_one_second_past_ttl() {
        let expired = at_stable_now(|now| is_expired(now - 3601, 1));
        assert!(expired);
    }

    #[test]
    fn is_expired_large_ttl() {
        let one_day_ago = now_unix_secs() - 86400;
        // 48 hours TTL, generated 24 hours ago => not expired
        assert!(!is_expired(one_day_ago, 48));
    }

    #[test]
    fn is_expired_saturating_sub_on_future_timestamp() {
        // generated_at in the "future" (larger than now) => saturating_sub => 0 => not expired
        let future = now_unix_secs() + 10000;
        assert!(!is_expired(future, 1));
    }

    #[test]
    fn now_unix_secs_returns_reasonable_value() {
        let secs = now_unix_secs();
        // Should be after 2024-01-01 (1704067200)
        assert!(secs > 1_704_067_200);
    }

    #[test]
    fn cache_envelope_serialization_roundtrip() {
        let envelope = CacheEnvelope {
            generated_at_unix: 1_700_000_000,
            head_hash: "abc123".to_string(),
            payload: "test payload".to_string(),
        };
        let json = serde_json::to_string(&envelope).unwrap();
        let deserialized: CacheEnvelope<String> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.generated_at_unix, 1_700_000_000);
        assert_eq!(deserialized.head_hash, "abc123");
        assert_eq!(deserialized.payload, "test payload");
    }

    #[test]
    fn cache_envelope_deserialization_from_known_json() {
        let json = r#"{"generated_at_unix":1700000000,"head_hash":"def456","payload":"hello"}"#;
        let envelope: CacheEnvelope<String> = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.generated_at_unix, 1_700_000_000);
        assert_eq!(envelope.head_hash, "def456");
        assert_eq!(envelope.payload, "hello");
    }

    #[test]
    fn cache_envelope_invalid_json_returns_error() {
        let result = serde_json::from_str::<CacheEnvelope<String>>("not json");
        assert!(result.is_err());
    }

    #[test]
    fn analysis_cache_dir_returns_path_with_gitstack() {
        let dir = analysis_cache_dir().unwrap();
        let dir_str = dir.to_string_lossy();
        assert!(dir_str.contains("gitstack"));
        assert!(dir_str.contains("analysis"));
    }
}