1use anyhow::{Context, Result};
2use http::{HeaderMap, Uri};
3use octocrab::service::middleware::cache::{CacheKey, CacheStorage, CacheWriter, CachedResponse};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::{Arc, Mutex};
7
/// Settings for the HTTP response cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Whether caching is turned on.
    ///
    /// NOTE(review): nothing in this file reads this flag; presumably the code
    /// that builds the HTTP client consults it — confirm against the caller.
    pub enabled: bool,
}
13
14pub fn get_cache_path() -> PathBuf {
16 dirs::cache_dir()
17 .map(|p| p.join("pr-bro/http-cache"))
18 .unwrap_or_else(|| {
19 PathBuf::from(format!(
20 "{}/.cache/pr-bro/http-cache",
21 std::env::var("HOME").unwrap_or_default()
22 ))
23 })
24}
25
26pub fn clear_cache() -> Result<()> {
28 let cache_path = get_cache_path();
29 match std::fs::remove_dir_all(&cache_path) {
30 Ok(()) => Ok(()),
31 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
32 Err(e) => Err(e).context("Failed to remove cache directory"),
33 }
34}
35
36pub fn evict_stale_entries() -> usize {
39 let cache_path = get_cache_path();
40 let threshold = std::time::SystemTime::now()
41 .duration_since(std::time::UNIX_EPOCH)
42 .map(|d| d.as_millis())
43 .unwrap_or(0);
44 let max_age_ms: u128 = 7 * 24 * 60 * 60 * 1000;
46 let cutoff = threshold.saturating_sub(max_age_ms);
47
48 let mut removed = 0usize;
49 for entry in cacache::list_sync(&cache_path).flatten() {
50 if entry.time < cutoff {
51 let _ = cacache::remove_sync(&cache_path, &entry.key);
52 removed += 1;
53 }
54 }
55 removed
56}
57
/// Cache storage that keeps entries in memory and also persists them on disk.
///
/// Cloning is cheap in the sense that clones share the same in-memory maps
/// through the `Arc`.
#[derive(Clone)]
pub struct DiskCache {
    /// In-memory maps of cache keys and responses, shared between clones and
    /// with the writers created by `writer`.
    inner: Arc<Mutex<CacheData>>,
    /// Directory of the on-disk store that is read and written via `cacache`.
    cache_path: PathBuf,
}
67
/// In-memory cache state; both maps are keyed by the request URI rendered as
/// a string.
struct CacheData {
    /// Cache key (ETag or Last-Modified validator) stored for each URI.
    keys: HashMap<String, CacheKey>,
    /// Cached response (headers and body) stored for each URI.
    responses: HashMap<String, CachedResponse>,
}
72
/// Serializable form of one cache entry as it is stored on disk.
///
/// An entry built from a supported cache key has exactly one of `etag` or
/// `last_modified` set; an entry with neither is rejected when it is
/// converted back.
#[derive(serde::Serialize, serde::Deserialize)]
struct DiskCacheEntry {
    /// ETag validator, when the entry was keyed by an ETag.
    etag: Option<String>,
    /// Last-Modified validator, when the entry was keyed by Last-Modified.
    last_modified: Option<String>,
    /// Response headers as (name, raw value bytes) pairs.
    headers: Vec<(String, Vec<u8>)>,
    /// Raw response body bytes.
    body: Vec<u8>,
}
81
82impl DiskCacheEntry {
83 fn from_parts(key: &CacheKey, response: &CachedResponse) -> Self {
85 let (etag, last_modified) = match key {
86 CacheKey::ETag(etag) => (Some(etag.clone()), None),
87 CacheKey::LastModified(lm) => (None, Some(lm.clone())),
88 _ => (None, None), };
90
91 let headers: Vec<(String, Vec<u8>)> = response
92 .headers
93 .iter()
94 .map(|(name, value)| (name.to_string(), value.as_bytes().to_vec()))
95 .collect();
96
97 Self {
98 etag,
99 last_modified,
100 headers,
101 body: response.body.clone(),
102 }
103 }
104
105 fn to_parts(&self) -> Result<(CacheKey, CachedResponse)> {
107 let key = if let Some(etag) = &self.etag {
108 CacheKey::ETag(etag.clone())
109 } else if let Some(lm) = &self.last_modified {
110 CacheKey::LastModified(lm.clone())
111 } else {
112 anyhow::bail!("Invalid cache entry: no ETag or Last-Modified");
113 };
114
115 let mut headers = HeaderMap::new();
116 for (name, value) in &self.headers {
117 let header_name: http::HeaderName = name.parse().context("Invalid header name")?;
118 let header_value =
119 http::HeaderValue::from_bytes(value).context("Invalid header value")?;
120 headers.insert(header_name, header_value);
121 }
122
123 let response = CachedResponse {
124 headers,
125 body: self.body.clone(),
126 };
127
128 Ok((key, response))
129 }
130}
131
132impl DiskCache {
133 pub fn new(cache_path: PathBuf) -> Self {
134 Self {
136 inner: Arc::new(Mutex::new(CacheData {
137 keys: HashMap::new(),
138 responses: HashMap::new(),
139 })),
140 cache_path,
141 }
142 }
143
144 pub fn clear_memory(&self) {
146 let mut data = self.inner.lock().unwrap();
147 data.keys.clear();
148 data.responses.clear();
149 }
150
151 fn load_from_disk(&self, uri_key: &str) -> Option<CacheKey> {
153 let bytes = cacache::read_sync(&self.cache_path, uri_key).ok()?;
155
156 let entry: DiskCacheEntry = serde_json::from_slice(&bytes).ok()?;
158
159 let (key, response) = entry.to_parts().ok()?;
161
162 let mut data = self.inner.lock().unwrap();
164 data.keys.insert(uri_key.to_string(), key.clone());
165 data.responses.insert(uri_key.to_string(), response);
166
167 Some(key)
168 }
169}
170
171impl CacheStorage for DiskCache {
172 fn try_hit(&self, uri: &Uri) -> Option<CacheKey> {
173 let uri_key = uri.to_string();
174
175 {
177 let data = self.inner.lock().unwrap();
178 if let Some(cache_key) = data.keys.get(&uri_key) {
179 return Some(cache_key.clone());
180 }
181 }
182
183 self.load_from_disk(&uri_key)
185 }
186
187 fn load(&self, uri: &Uri) -> Option<CachedResponse> {
188 let data = self.inner.lock().unwrap();
189 data.responses.get(&uri.to_string()).cloned()
190 }
191
192 fn writer(&self, uri: &Uri, key: CacheKey, headers: HeaderMap) -> Box<dyn CacheWriter> {
193 Box::new(DiskCacheWriter {
194 cache: self.inner.clone(),
195 cache_path: self.cache_path.clone(),
196 uri_key: uri.to_string(),
197 key,
198 response: CachedResponse {
199 body: Vec::new(),
200 headers,
201 },
202 })
203 }
204}
205
/// Writer that accumulates a response body and stores the finished entry in
/// memory and on disk when it is dropped.
struct DiskCacheWriter {
    /// Shared in-memory maps that receive the entry on drop.
    cache: Arc<Mutex<CacheData>>,
    /// Directory of the on-disk store that the entry is written to.
    cache_path: PathBuf,
    /// Request URI rendered as a string; used as the key in memory and on disk.
    uri_key: String,
    /// Cache key (validator) to store alongside the response.
    key: CacheKey,
    /// Response headers plus the body bytes collected so far.
    response: CachedResponse,
}
214
215impl CacheWriter for DiskCacheWriter {
216 fn write_body(&mut self, data: &[u8]) {
217 self.response.body.extend_from_slice(data);
218 }
219}
220
221impl Drop for DiskCacheWriter {
222 fn drop(&mut self) {
223 let uri_key = self.uri_key.clone();
224 let key = self.key.clone();
225 let response = CachedResponse {
226 body: std::mem::take(&mut self.response.body),
227 headers: self.response.headers.clone(),
228 };
229
230 if serde_json::from_slice::<serde_json::Value>(&response.body).is_err() {
233 return;
235 }
236
237 {
239 let mut data = self.cache.lock().unwrap();
240 data.keys.insert(uri_key.clone(), key.clone());
241 data.responses.insert(uri_key.clone(), response.clone());
242 }
243
244 let entry = DiskCacheEntry::from_parts(&key, &response);
246 if let Ok(serialized) = serde_json::to_vec(&entry) {
247 let _ = cacache::write_sync(&self.cache_path, &uri_key, &serialized);
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, Uri};
    use octocrab::service::middleware::cache::{CacheKey, CacheStorage};

    /// Builds a per-test cache directory under the system temp dir, made
    /// unique with a nanosecond timestamp.
    fn unique_cache_path(test_name: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir_name = format!("pr-bro-test-cache-{}-{}", test_name, nanos);
        std::env::temp_dir().join(dir_name)
    }

    /// Writes `body` through a fresh cache writer for `url`, drops the writer,
    /// and reports `(try_hit found a key, load found a response)`.
    /// The temporary cache directory is removed before returning.
    fn write_and_probe(
        test_name: &str,
        url: &'static str,
        etag: &str,
        body: &[u8],
    ) -> (bool, bool) {
        let cache_path = unique_cache_path(test_name);
        let cache = DiskCache::new(cache_path.clone());

        let uri = Uri::from_static(url);
        let key = CacheKey::ETag(etag.to_string());

        let mut writer = cache.writer(&uri, key, HeaderMap::new());
        writer.write_body(body);
        drop(writer);

        let hit = cache.try_hit(&uri).is_some();
        let loaded = cache.load(&uri).is_some();

        let _ = std::fs::remove_dir_all(&cache_path);
        (hit, loaded)
    }

    #[test]
    fn test_valid_json_is_cached() {
        let (hit, loaded) = write_and_probe(
            "valid",
            "https://api.github.com/repos/test/test/pulls/1",
            "test-etag",
            br#"{"login":"test","id":1}"#,
        );

        assert!(hit);
        assert!(loaded);
    }

    #[test]
    fn test_truncated_json_is_not_cached() {
        let (hit, loaded) = write_and_probe(
            "truncated",
            "https://api.github.com/repos/test/test/pulls/2",
            "test-etag-2",
            br#"{"login":"test","id":"#,
        );

        assert!(!hit);
        assert!(!loaded);
    }

    #[test]
    fn test_empty_body_is_not_cached() {
        let (hit, loaded) = write_and_probe(
            "empty",
            "https://api.github.com/repos/test/test/pulls/3",
            "test-etag-3",
            b"",
        );

        assert!(!hit);
        assert!(!loaded);
    }
}