1use anyhow::{Context, Result};
2use http::{HeaderMap, Uri};
3use octocrab::service::middleware::cache::{CacheKey, CacheStorage, CacheWriter, CachedResponse};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::{Arc, Mutex};
7
/// User-facing switch for the HTTP response cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// When `false`, responses are neither stored nor served from cache.
    pub enabled: bool,
}
13
14pub fn get_cache_path() -> PathBuf {
16 dirs::cache_dir()
17 .map(|p| p.join("pr-bro/http-cache"))
18 .unwrap_or_else(|| {
19 PathBuf::from(format!(
20 "{}/.cache/pr-bro/http-cache",
21 std::env::var("HOME").unwrap_or_default()
22 ))
23 })
24}
25
26pub fn clear_cache() -> Result<()> {
28 let cache_path = get_cache_path();
29 match std::fs::remove_dir_all(&cache_path) {
30 Ok(()) => Ok(()),
31 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
32 Err(e) => Err(e).context("Failed to remove cache directory"),
33 }
34}
35
36pub fn evict_stale_entries() -> usize {
39 let cache_path = get_cache_path();
40 let threshold = std::time::SystemTime::now()
41 .duration_since(std::time::UNIX_EPOCH)
42 .map(|d| d.as_millis())
43 .unwrap_or(0);
44 let max_age_ms: u128 = 7 * 24 * 60 * 60 * 1000;
46 let cutoff = threshold.saturating_sub(max_age_ms);
47
48 let mut removed = 0usize;
49 for entry in cacache::list_sync(&cache_path).flatten() {
50 if entry.time < cutoff {
51 let _ = cacache::remove_sync(&cache_path, &entry.key);
52 removed += 1;
53 }
54 }
55 removed
56}
57
58#[derive(Clone)]
63pub struct DiskCache {
64 inner: Arc<Mutex<CacheData>>,
65 cache_path: PathBuf,
66}
67
68struct CacheData {
69 keys: HashMap<String, CacheKey>, responses: HashMap<String, CachedResponse>, }
72
73#[derive(serde::Serialize, serde::Deserialize)]
75struct DiskCacheEntry {
76 etag: Option<String>,
77 last_modified: Option<String>,
78 headers: Vec<(String, Vec<u8>)>, body: Vec<u8>,
80}
81
82impl DiskCacheEntry {
83 fn from_parts(key: &CacheKey, response: &CachedResponse) -> Self {
85 let (etag, last_modified) = match key {
86 CacheKey::ETag(etag) => (Some(etag.clone()), None),
87 CacheKey::LastModified(lm) => (None, Some(lm.clone())),
88 _ => (None, None), };
90
91 let headers: Vec<(String, Vec<u8>)> = response
92 .headers
93 .iter()
94 .map(|(name, value)| (name.to_string(), value.as_bytes().to_vec()))
95 .collect();
96
97 Self {
98 etag,
99 last_modified,
100 headers,
101 body: response.body.clone(),
102 }
103 }
104
105 fn to_parts(&self) -> Result<(CacheKey, CachedResponse)> {
107 let key = if let Some(etag) = &self.etag {
108 CacheKey::ETag(etag.clone())
109 } else if let Some(lm) = &self.last_modified {
110 CacheKey::LastModified(lm.clone())
111 } else {
112 anyhow::bail!("Invalid cache entry: no ETag or Last-Modified");
113 };
114
115 let mut headers = HeaderMap::new();
116 for (name, value) in &self.headers {
117 let header_name: http::HeaderName = name.parse().context("Invalid header name")?;
118 let header_value =
119 http::HeaderValue::from_bytes(value).context("Invalid header value")?;
120 headers.insert(header_name, header_value);
121 }
122
123 let response = CachedResponse {
124 headers,
125 body: self.body.clone(),
126 };
127
128 Ok((key, response))
129 }
130}
131
132impl DiskCache {
133 pub fn new(cache_path: PathBuf) -> Self {
134 Self {
136 inner: Arc::new(Mutex::new(CacheData {
137 keys: HashMap::new(),
138 responses: HashMap::new(),
139 })),
140 cache_path,
141 }
142 }
143
144 pub fn clear_memory(&self) {
146 let mut data = self.inner.lock().unwrap();
147 data.keys.clear();
148 data.responses.clear();
149 }
150
151 fn load_from_disk(&self, uri_key: &str) -> Option<CacheKey> {
153 let bytes = cacache::read_sync(&self.cache_path, uri_key).ok()?;
155
156 let entry: DiskCacheEntry = serde_json::from_slice(&bytes).ok()?;
158
159 let (key, response) = entry.to_parts().ok()?;
161
162 let mut data = self.inner.lock().unwrap();
164 data.keys.insert(uri_key.to_string(), key.clone());
165 data.responses.insert(uri_key.to_string(), response);
166
167 Some(key)
168 }
169}
170
171impl CacheStorage for DiskCache {
172 fn try_hit(&self, uri: &Uri) -> Option<CacheKey> {
173 let uri_key = uri.to_string();
174
175 {
177 let data = self.inner.lock().unwrap();
178 if let Some(cache_key) = data.keys.get(&uri_key) {
179 return Some(cache_key.clone());
180 }
181 }
182
183 self.load_from_disk(&uri_key)
185 }
186
187 fn load(&self, uri: &Uri) -> Option<CachedResponse> {
188 let data = self.inner.lock().unwrap();
189 data.responses.get(&uri.to_string()).cloned()
190 }
191
192 fn writer(&self, uri: &Uri, key: CacheKey, headers: HeaderMap) -> Box<dyn CacheWriter> {
193 Box::new(DiskCacheWriter {
194 cache: self.inner.clone(),
195 cache_path: self.cache_path.clone(),
196 uri_key: uri.to_string(),
197 key,
198 response: CachedResponse {
199 body: Vec::new(),
200 headers,
201 },
202 })
203 }
204}
205
206struct DiskCacheWriter {
208 cache: Arc<Mutex<CacheData>>,
209 cache_path: PathBuf,
210 uri_key: String,
211 key: CacheKey,
212 response: CachedResponse,
213}
214
215impl CacheWriter for DiskCacheWriter {
216 fn write_body(&mut self, data: &[u8]) {
217 self.response.body.extend_from_slice(data);
218 }
219}
220
221impl Drop for DiskCacheWriter {
222 fn drop(&mut self) {
223 let uri_key = self.uri_key.clone();
224 let key = self.key.clone();
225 let response = CachedResponse {
226 body: std::mem::take(&mut self.response.body),
227 headers: self.response.headers.clone(),
228 };
229
230 {
232 let mut data = self.cache.lock().unwrap();
233 data.keys.insert(uri_key.clone(), key.clone());
234 data.responses.insert(uri_key.clone(), response.clone());
235 }
236
237 let entry = DiskCacheEntry::from_parts(&key, &response);
239 if let Ok(serialized) = serde_json::to_vec(&entry) {
240 let _ = cacache::write_sync(&self.cache_path, &uri_key, &serialized);
241 }
242 }
243}