1use serde::{Deserialize, Serialize};
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4use std::io;
5use std::path::PathBuf;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
/// File-backed cache of response bodies with a time-to-live.
///
/// Every entry is stored as its own JSON file inside `cache_dir`; an entry
/// whose age exceeds `ttl` is treated as absent by lookups.
#[derive(Debug, Clone)]
pub struct DiskCache {
    /// Directory that holds one `<key>.json` file per cached response.
    cache_dir: PathBuf,
    /// Maximum age an entry may reach before it is ignored and pruned.
    ttl: Duration,
}
23
24#[derive(Serialize, Deserialize)]
25struct CacheEntry {
26 ts: u64,
27 body: String,
28}
29
30impl DiskCache {
31 pub fn new(cache_dir: PathBuf, ttl: Duration) -> io::Result<Self> {
35 std::fs::create_dir_all(&cache_dir)?;
36 let cache = Self { cache_dir, ttl };
37 cache.prune();
38 Ok(cache)
39 }
40
41 pub fn default_location(ttl: Duration) -> io::Result<Self> {
49 let base = dirs::cache_dir().ok_or_else(|| {
50 io::Error::new(io::ErrorKind::NotFound, "no platform cache directory")
51 })?;
52 Self::new(base.join("papers").join("requests"), ttl)
53 }
54
55 pub fn get(&self, url: &str, query: &[(&str, String)], body: Option<&str>) -> Option<String> {
59 let key = cache_key(url, query, body);
60 let path = self.cache_dir.join(format!("{key:016x}.json"));
61 let data = std::fs::read_to_string(&path).ok()?;
62 let entry: CacheEntry = serde_json::from_str(&data).ok()?;
63 let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
64 if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
65 return None;
66 }
67 Some(entry.body)
68 }
69
70 pub fn set(&self, url: &str, query: &[(&str, String)], body: Option<&str>, response: &str) {
75 let _ = self.set_inner(url, query, body, response);
76 }
77
78 fn set_inner(
79 &self,
80 url: &str,
81 query: &[(&str, String)],
82 body: Option<&str>,
83 response: &str,
84 ) -> io::Result<()> {
85 let key = cache_key(url, query, body);
86 let ts = SystemTime::now()
87 .duration_since(UNIX_EPOCH)
88 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
89 .as_secs();
90 let entry = CacheEntry {
91 ts,
92 body: response.to_string(),
93 };
94 let json = serde_json::to_string(&entry)
95 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
96 let tmp_path = self.cache_dir.join(format!("{key:016x}.tmp"));
97 let final_path = self.cache_dir.join(format!("{key:016x}.json"));
98 std::fs::write(&tmp_path, json)?;
99 std::fs::rename(&tmp_path, &final_path)?;
100 Ok(())
101 }
102
103 pub fn prune(&self) {
108 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
109 Ok(d) => d.as_secs(),
110 Err(_) => return,
111 };
112 let entries = match std::fs::read_dir(&self.cache_dir) {
113 Ok(e) => e,
114 Err(_) => return,
115 };
116 for entry in entries.flatten() {
117 let path = entry.path();
118 let name = match path.file_name().and_then(|n| n.to_str()) {
119 Some(n) => n,
120 None => continue,
121 };
122 if name.ends_with(".tmp") {
124 let _ = std::fs::remove_file(&path);
125 continue;
126 }
127 if !name.ends_with(".json") {
129 continue;
130 }
131 let data = match std::fs::read_to_string(&path) {
132 Ok(d) => d,
133 Err(_) => {
134 let _ = std::fs::remove_file(&path);
135 continue;
136 }
137 };
138 let entry: CacheEntry = match serde_json::from_str(&data) {
139 Ok(e) => e,
140 Err(_) => {
141 let _ = std::fs::remove_file(&path);
142 continue;
143 }
144 };
145 if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
146 let _ = std::fs::remove_file(&path);
147 }
148 }
149 }
150}
151
/// Derives the 64-bit cache key for a request.
///
/// The URL is hashed first, then every query pair in sorted order (so the
/// order in which the caller lists parameters does not matter), then the body
/// if one is present. A `None` body contributes nothing to the hash.
///
/// NOTE(review): the standard library documents `DefaultHasher`'s algorithm as
/// unspecified across releases, so keys persisted by one toolchain may not
/// match those computed by another — confirm whether that matters here.
fn cache_key(url: &str, query: &[(&str, String)], body: Option<&str>) -> u64 {
    let mut pairs: Vec<(&str, &str)> = Vec::with_capacity(query.len());
    pairs.extend(query.iter().map(|(name, value)| (*name, value.as_str())));
    pairs.sort();

    let mut state = DefaultHasher::new();
    url.hash(&mut state);
    pairs.iter().for_each(|(name, value)| {
        name.hash(&mut state);
        value.hash(&mut state);
    });
    if let Some(payload) = body {
        payload.hash(&mut state);
    }
    state.finish()
}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniquely named scratch location under the system temp directory.
    ///
    /// The directory itself is NOT created here (so tests can check that
    /// `DiskCache::new` creates it); whatever exists at the path is removed
    /// when the guard is dropped, so test runs do not accumulate directories.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new() -> Self {
            Self(
                std::env::temp_dir()
                    .join("papers-test-cache")
                    .join(format!("{:x}", rand_u64())),
            )
        }

        /// Owned copy of the scratch path.
        fn path(&self) -> PathBuf {
            self.0.clone()
        }

        /// Opens a cache rooted at the scratch path with a TTL in seconds.
        fn cache(&self, ttl_secs: u64) -> DiskCache {
            DiskCache::new(self.path(), Duration::from_secs(ttl_secs)).unwrap()
        }

        /// Number of directory entries currently at the scratch path.
        fn file_count(&self) -> usize {
            std::fs::read_dir(&self.0).unwrap().count()
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Pseudo-random value used to name scratch directories: mixes the clock,
    /// the current thread id and a process-wide counter.
    fn rand_u64() -> u64 {
        static NEXT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let mut hasher = DefaultHasher::new();
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
            .hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        NEXT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            .hash(&mut hasher);
        hasher.finish()
    }

    /// Query list with no parameters.
    fn no_query() -> Vec<(&'static str, String)> {
        Vec::new()
    }

    #[test]
    fn key_is_deterministic() {
        let q = vec![("a", "1".to_string()), ("b", "2".to_string())];
        let k1 = cache_key("http://x", &q, None);
        let k2 = cache_key("http://x", &q, None);
        assert_eq!(k1, k2);
    }

    #[test]
    fn key_differs_by_url() {
        let q = no_query();
        assert_ne!(cache_key("http://a", &q, None), cache_key("http://b", &q, None));
    }

    #[test]
    fn key_differs_by_query() {
        let q1 = vec![("a", "1".to_string())];
        let q2 = vec![("a", "2".to_string())];
        assert_ne!(cache_key("http://x", &q1, None), cache_key("http://x", &q2, None));
    }

    #[test]
    fn key_differs_by_body() {
        let q = no_query();
        assert_ne!(
            cache_key("http://x", &q, Some("body1")),
            cache_key("http://x", &q, Some("body2"))
        );
    }

    #[test]
    fn key_query_order_independent() {
        let q1 = vec![("b", "2".to_string()), ("a", "1".to_string())];
        let q2 = vec![("a", "1".to_string()), ("b", "2".to_string())];
        assert_eq!(cache_key("http://x", &q1, None), cache_key("http://x", &q2, None));
    }

    #[test]
    fn set_get_roundtrip() {
        let tmp = ScratchDir::new();
        let cache = tmp.cache(60);
        let q = vec![("k", "v".to_string())];
        cache.set("http://x", &q, None, "response body");
        let got = cache.get("http://x", &q, None);
        assert_eq!(got.as_deref(), Some("response body"));
    }

    #[test]
    fn missing_key_returns_none() {
        let tmp = ScratchDir::new();
        let cache = tmp.cache(60);
        let q = no_query();
        assert!(cache.get("http://nonexistent", &q, None).is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let tmp = ScratchDir::new();
        let cache = tmp.cache(1);
        let q = no_query();
        cache.set("http://x", &q, None, "data");

        // Back-date the stored entry instead of sleeping past the TTL, which
        // keeps the test fast and independent of wall-clock timing.
        let key = cache_key("http://x", &q, None);
        let path = cache.cache_dir.join(format!("{key:016x}.json"));
        let stored = std::fs::read_to_string(&path).unwrap();
        let mut entry: CacheEntry = serde_json::from_str(&stored).unwrap();
        entry.ts -= 10;
        std::fs::write(&path, serde_json::to_string(&entry).unwrap()).unwrap();

        assert!(cache.get("http://x", &q, None).is_none());
    }

    #[test]
    fn corrupted_file_returns_none() {
        let tmp = ScratchDir::new();
        let cache = tmp.cache(60);
        let q = no_query();
        let key = cache_key("http://x", &q, None);
        let path = cache.cache_dir.join(format!("{key:016x}.json"));
        std::fs::write(&path, "not json").unwrap();
        assert!(cache.get("http://x", &q, None).is_none());
    }

    #[test]
    fn prune_removes_expired_entries() {
        let tmp = ScratchDir::new();
        let dir = tmp.path();
        let cache = tmp.cache(3600);

        // One fresh entry written through the public API.
        let q = no_query();
        cache.set("http://a", &q, None, "fresh");

        // One entry stamped at the epoch, far older than the TTL.
        let key = cache_key("http://old", &q, None);
        let expired = CacheEntry {
            ts: 0,
            body: "old".into(),
        };
        let json = serde_json::to_string(&expired).unwrap();
        std::fs::write(dir.join(format!("{key:016x}.json")), json).unwrap();

        // One stray temp file, as left behind by an interrupted write.
        std::fs::write(dir.join("leftover.tmp"), "junk").unwrap();

        assert_eq!(tmp.file_count(), 3);

        // Constructing a second cache over the same directory runs prune.
        let cache2 = tmp.cache(3600);

        assert_eq!(tmp.file_count(), 1);
        assert_eq!(cache2.get("http://a", &q, None).as_deref(), Some("fresh"));
        assert!(cache2.get("http://old", &q, None).is_none());
    }

    #[test]
    fn prune_removes_corrupted_files() {
        let tmp = ScratchDir::new();
        let dir = tmp.path();
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("badhash0000000000.json"), "not json").unwrap();
        assert_eq!(tmp.file_count(), 1);
        let _cache = tmp.cache(60);
        assert_eq!(tmp.file_count(), 0);
    }

    #[test]
    fn directory_creation() {
        let tmp = ScratchDir::new();
        let dir = tmp.path().join("nested").join("deep");
        let cache = DiskCache::new(dir.clone(), Duration::from_secs(60)).unwrap();
        assert!(dir.exists());
        let q = no_query();
        cache.set("http://x", &q, None, "ok");
        assert_eq!(cache.get("http://x", &q, None).as_deref(), Some("ok"));
    }
}