// pr_bro/github/cache.rs

use anyhow::{Context, Result};
use http::{HeaderMap, Uri};
use octocrab::service::middleware::cache::{CacheKey, CacheStorage, CacheWriter, CachedResponse};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

/// Settings controlling whether HTTP responses are cached.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// `true` to use the response cache; `false` when `--no-cache` was given.
    pub enabled: bool,
}

14/// Get the platform-appropriate cache directory for pr-bro
15pub fn get_cache_path() -> PathBuf {
16    dirs::cache_dir()
17        .map(|p| p.join("pr-bro/http-cache"))
18        .unwrap_or_else(|| {
19            PathBuf::from(format!(
20                "{}/.cache/pr-bro/http-cache",
21                std::env::var("HOME").unwrap_or_default()
22            ))
23        })
24}
25
26/// Clear the HTTP cache directory
27pub fn clear_cache() -> Result<()> {
28    let cache_path = get_cache_path();
29    match std::fs::remove_dir_all(&cache_path) {
30        Ok(()) => Ok(()),
31        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
32        Err(e) => Err(e).context("Failed to remove cache directory"),
33    }
34}
35
36/// Evict cache entries older than 7 days. Returns number of entries removed.
37/// Best-effort: errors during listing or removal are silently ignored.
38pub fn evict_stale_entries() -> usize {
39    let cache_path = get_cache_path();
40    let threshold = std::time::SystemTime::now()
41        .duration_since(std::time::UNIX_EPOCH)
42        .map(|d| d.as_millis())
43        .unwrap_or(0);
44    // 7 days in milliseconds
45    let max_age_ms: u128 = 7 * 24 * 60 * 60 * 1000;
46    let cutoff = threshold.saturating_sub(max_age_ms);
47
48    let mut removed = 0usize;
49    for entry in cacache::list_sync(&cache_path).flatten() {
50        if entry.time < cutoff {
51            let _ = cacache::remove_sync(&cache_path, &entry.key);
52            removed += 1;
53        }
54    }
55    removed
56}
57
58/// Disk-persistent cache implementing octocrab's CacheStorage trait
59///
60/// Uses cacache for disk persistence and in-memory HashMap for fast access.
61/// Responses are cached by URI with ETag/Last-Modified headers for conditional requests.
62#[derive(Clone)]
63pub struct DiskCache {
64    inner: Arc<Mutex<CacheData>>,
65    cache_path: PathBuf,
66}
67
68struct CacheData {
69    keys: HashMap<String, CacheKey>,            // URI string -> CacheKey
70    responses: HashMap<String, CachedResponse>, // URI string -> cached response
71}
72
73/// Serializable representation of a cache entry for disk storage
74#[derive(serde::Serialize, serde::Deserialize)]
75struct DiskCacheEntry {
76    etag: Option<String>,
77    last_modified: Option<String>,
78    headers: Vec<(String, Vec<u8>)>, // header name -> value bytes
79    body: Vec<u8>,
80}
81
82impl DiskCacheEntry {
83    /// Create a DiskCacheEntry from CacheKey and CachedResponse
84    fn from_parts(key: &CacheKey, response: &CachedResponse) -> Self {
85        let (etag, last_modified) = match key {
86            CacheKey::ETag(etag) => (Some(etag.clone()), None),
87            CacheKey::LastModified(lm) => (None, Some(lm.clone())),
88            _ => (None, None), // Handle non-exhaustive enum
89        };
90
91        let headers: Vec<(String, Vec<u8>)> = response
92            .headers
93            .iter()
94            .map(|(name, value)| (name.to_string(), value.as_bytes().to_vec()))
95            .collect();
96
97        Self {
98            etag,
99            last_modified,
100            headers,
101            body: response.body.clone(),
102        }
103    }
104
105    /// Convert back to CacheKey and CachedResponse
106    fn to_parts(&self) -> Result<(CacheKey, CachedResponse)> {
107        let key = if let Some(etag) = &self.etag {
108            CacheKey::ETag(etag.clone())
109        } else if let Some(lm) = &self.last_modified {
110            CacheKey::LastModified(lm.clone())
111        } else {
112            anyhow::bail!("Invalid cache entry: no ETag or Last-Modified");
113        };
114
115        let mut headers = HeaderMap::new();
116        for (name, value) in &self.headers {
117            let header_name: http::HeaderName = name.parse().context("Invalid header name")?;
118            let header_value =
119                http::HeaderValue::from_bytes(value).context("Invalid header value")?;
120            headers.insert(header_name, header_value);
121        }
122
123        let response = CachedResponse {
124            headers,
125            body: self.body.clone(),
126        };
127
128        Ok((key, response))
129    }
130}
131
132impl DiskCache {
133    pub fn new(cache_path: PathBuf) -> Self {
134        // Don't pre-load disk cache - entries are loaded on demand
135        Self {
136            inner: Arc::new(Mutex::new(CacheData {
137                keys: HashMap::new(),
138                responses: HashMap::new(),
139            })),
140            cache_path,
141        }
142    }
143
144    /// Clear the in-memory cache to force fresh requests on next fetch
145    pub fn clear_memory(&self) {
146        let mut data = self.inner.lock().unwrap();
147        data.keys.clear();
148        data.responses.clear();
149    }
150
151    /// Try to load a cache entry from disk
152    fn load_from_disk(&self, uri_key: &str) -> Option<CacheKey> {
153        // Try to read from disk
154        let bytes = cacache::read_sync(&self.cache_path, uri_key).ok()?;
155
156        // Deserialize
157        let entry: DiskCacheEntry = serde_json::from_slice(&bytes).ok()?;
158
159        // Convert to CacheKey and CachedResponse
160        let (key, response) = entry.to_parts().ok()?;
161
162        // Populate in-memory cache for subsequent hits
163        let mut data = self.inner.lock().unwrap();
164        data.keys.insert(uri_key.to_string(), key.clone());
165        data.responses.insert(uri_key.to_string(), response);
166
167        Some(key)
168    }
169}
170
171impl CacheStorage for DiskCache {
172    fn try_hit(&self, uri: &Uri) -> Option<CacheKey> {
173        let uri_key = uri.to_string();
174
175        // Check in-memory first
176        {
177            let data = self.inner.lock().unwrap();
178            if let Some(cache_key) = data.keys.get(&uri_key) {
179                return Some(cache_key.clone());
180            }
181        }
182
183        // Try loading from disk
184        self.load_from_disk(&uri_key)
185    }
186
187    fn load(&self, uri: &Uri) -> Option<CachedResponse> {
188        let data = self.inner.lock().unwrap();
189        data.responses.get(&uri.to_string()).cloned()
190    }
191
192    fn writer(&self, uri: &Uri, key: CacheKey, headers: HeaderMap) -> Box<dyn CacheWriter> {
193        Box::new(DiskCacheWriter {
194            cache: self.inner.clone(),
195            cache_path: self.cache_path.clone(),
196            uri_key: uri.to_string(),
197            key,
198            response: CachedResponse {
199                body: Vec::new(),
200                headers,
201            },
202        })
203    }
204}
205
206/// Writer that persists cache entries to both memory and disk
207struct DiskCacheWriter {
208    cache: Arc<Mutex<CacheData>>,
209    cache_path: PathBuf,
210    uri_key: String,
211    key: CacheKey,
212    response: CachedResponse,
213}
214
215impl CacheWriter for DiskCacheWriter {
216    fn write_body(&mut self, data: &[u8]) {
217        self.response.body.extend_from_slice(data);
218    }
219}
220
221impl Drop for DiskCacheWriter {
222    fn drop(&mut self) {
223        let uri_key = self.uri_key.clone();
224        let key = self.key.clone();
225        let response = CachedResponse {
226            body: std::mem::take(&mut self.response.body),
227            headers: self.response.headers.clone(),
228        };
229
230        // Validate that the response body is valid JSON before caching
231        // Truncated/incomplete responses from network failures should not be persisted
232        if serde_json::from_slice::<serde_json::Value>(&response.body).is_err() {
233            // Skip caching - body is empty or invalid JSON
234            return;
235        }
236
237        // Write to in-memory cache
238        {
239            let mut data = self.cache.lock().unwrap();
240            data.keys.insert(uri_key.clone(), key.clone());
241            data.responses.insert(uri_key.clone(), response.clone());
242        }
243
244        // Write to disk (fire-and-forget, don't block on disk errors)
245        let entry = DiskCacheEntry::from_parts(&key, &response);
246        if let Ok(serialized) = serde_json::to_vec(&entry) {
247            let _ = cacache::write_sync(&self.cache_path, &uri_key, &serialized);
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, Uri};
    use octocrab::service::middleware::cache::{CacheKey, CacheStorage};

    /// Unique per-test cache directory under the system temp dir.
    ///
    /// The directory is removed when the guard is dropped, so it is cleaned up even
    /// when an assertion fails and the test unwinds.
    struct TempCacheDir(PathBuf);

    impl TempCacheDir {
        fn new(test_name: &str) -> Self {
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            Self(std::env::temp_dir().join(format!(
                "pr-bro-test-cache-{}-{}",
                test_name, timestamp
            )))
        }

        /// A fresh `DiskCache` (empty in-memory state) rooted at this directory.
        fn cache(&self) -> DiskCache {
            DiskCache::new(self.0.clone())
        }
    }

    impl Drop for TempCacheDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Stream `body` through a writer for `uri`, then commit it by dropping the writer.
    fn write_entry(cache: &DiskCache, uri: &Uri, etag: &str, body: &[u8]) {
        let mut writer = cache.writer(uri, CacheKey::ETag(etag.to_string()), HeaderMap::new());
        writer.write_body(body);
        drop(writer);
    }

    #[test]
    fn test_valid_json_is_cached() {
        let dir = TempCacheDir::new("valid");
        let cache = dir.cache();
        let uri = Uri::from_static("https://api.github.com/repos/test/test/pulls/1");

        write_entry(&cache, &uri, "test-etag", br#"{"login":"test","id":1}"#);

        assert!(cache.try_hit(&uri).is_some());
        assert!(cache.load(&uri).is_some());
    }

    #[test]
    fn test_truncated_json_is_not_cached() {
        let dir = TempCacheDir::new("truncated");
        let cache = dir.cache();
        let uri = Uri::from_static("https://api.github.com/repos/test/test/pulls/2");

        // Truncated JSON body (missing value and closing brace).
        write_entry(&cache, &uri, "test-etag-2", br#"{"login":"test","id":"#);

        assert!(cache.try_hit(&uri).is_none());
        assert!(cache.load(&uri).is_none());
    }

    #[test]
    fn test_empty_body_is_not_cached() {
        let dir = TempCacheDir::new("empty");
        let cache = dir.cache();
        let uri = Uri::from_static("https://api.github.com/repos/test/test/pulls/3");

        write_entry(&cache, &uri, "test-etag-3", b"");

        assert!(cache.try_hit(&uri).is_none());
        assert!(cache.load(&uri).is_none());
    }

    #[test]
    fn test_entry_is_reloaded_from_disk() {
        let dir = TempCacheDir::new("reload");
        let uri = Uri::from_static("https://api.github.com/repos/test/test/pulls/4");
        let body = br#"{"login":"test","id":4}"#;

        write_entry(&dir.cache(), &uri, "test-etag-4", body);

        // A separate instance starts with empty in-memory maps, so the hit must come from disk.
        let reopened = dir.cache();
        assert!(matches!(
            reopened.try_hit(&uri),
            Some(CacheKey::ETag(ref etag)) if etag == "test-etag-4"
        ));
        let response = reopened
            .load(&uri)
            .expect("try_hit should have loaded the response into memory");
        assert_eq!(response.body, body.to_vec());
    }
}
327
328        // Cleanup
329        let _ = std::fs::remove_dir_all(&cache_path);
330    }
331}