Skip to main content

papers_openalex/
cache.rs

1use serde::{Deserialize, Serialize};
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4use std::io;
5use std::path::PathBuf;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
/// A response cache persisted on disk, with one time-to-live per instance.
///
/// Each response text is stored inside a small JSON file. The file is named
/// after a hash of the request URL, its query parameters, and the POST body
/// when there is one. An entry older than the TTL is reported as a cache miss
/// rather than as an error.
///
/// # Atomic writes
///
/// An entry is first written to a temporary file and then renamed to its
/// final name; the intent is that concurrent readers never observe a
/// half-written entry.
#[derive(Debug, Clone)]
pub struct DiskCache {
    /// Directory that holds the entry files.
    cache_dir: PathBuf,
    /// Maximum age of an entry before lookups treat it as expired.
    ttl: Duration,
}
23
24#[derive(Serialize, Deserialize)]
25struct CacheEntry {
26    ts: u64,
27    body: String,
28}
29
30impl DiskCache {
31    /// Create a cache storing entries in `cache_dir` with the given TTL.
32    ///
33    /// Creates the directory (and parents) if it doesn't exist.
34    pub fn new(cache_dir: PathBuf, ttl: Duration) -> io::Result<Self> {
35        std::fs::create_dir_all(&cache_dir)?;
36        let cache = Self { cache_dir, ttl };
37        cache.prune();
38        Ok(cache)
39    }
40
41    /// Create a cache in the platform-standard cache directory.
42    ///
43    /// - Linux: `~/.cache/papers/requests`
44    /// - macOS: `~/Library/Caches/papers/requests`
45    /// - Windows: `{FOLDERID_LocalAppData}/papers/requests`
46    ///
47    /// Returns `Err` if no cache directory can be determined or created.
48    pub fn default_location(ttl: Duration) -> io::Result<Self> {
49        let base = dirs::cache_dir().ok_or_else(|| {
50            io::Error::new(io::ErrorKind::NotFound, "no platform cache directory")
51        })?;
52        Self::new(base.join("papers").join("requests"), ttl)
53    }
54
55    /// Look up a cached response.
56    ///
57    /// Returns `None` on cache miss, expired entry, or any I/O / parse error.
58    pub fn get(&self, url: &str, query: &[(&str, String)], body: Option<&str>) -> Option<String> {
59        let key = cache_key(url, query, body);
60        let path = self.cache_dir.join(format!("{key:016x}.json"));
61        let data = std::fs::read_to_string(&path).ok()?;
62        let entry: CacheEntry = serde_json::from_str(&data).ok()?;
63        let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
64        if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
65            return None;
66        }
67        Some(entry.body)
68    }
69
70    /// Store a response in the cache.
71    ///
72    /// Writes atomically via a `.tmp` file + rename. Errors are silently
73    /// ignored — a failed cache write should never break a request.
74    pub fn set(&self, url: &str, query: &[(&str, String)], body: Option<&str>, response: &str) {
75        let _ = self.set_inner(url, query, body, response);
76    }
77
78    fn set_inner(
79        &self,
80        url: &str,
81        query: &[(&str, String)],
82        body: Option<&str>,
83        response: &str,
84    ) -> io::Result<()> {
85        let key = cache_key(url, query, body);
86        let ts = SystemTime::now()
87            .duration_since(UNIX_EPOCH)
88            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
89            .as_secs();
90        let entry = CacheEntry {
91            ts,
92            body: response.to_string(),
93        };
94        let json = serde_json::to_string(&entry)
95            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
96        let tmp_path = self.cache_dir.join(format!("{key:016x}.tmp"));
97        let final_path = self.cache_dir.join(format!("{key:016x}.json"));
98        std::fs::write(&tmp_path, json)?;
99        std::fs::rename(&tmp_path, &final_path)?;
100        Ok(())
101    }
102
103    /// Remove expired entries and leftover `.tmp` files from the cache directory.
104    ///
105    /// Called automatically on construction. Errors on individual files are
106    /// silently ignored.
107    pub fn prune(&self) {
108        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
109            Ok(d) => d.as_secs(),
110            Err(_) => return,
111        };
112        let entries = match std::fs::read_dir(&self.cache_dir) {
113            Ok(e) => e,
114            Err(_) => return,
115        };
116        for entry in entries.flatten() {
117            let path = entry.path();
118            let name = match path.file_name().and_then(|n| n.to_str()) {
119                Some(n) => n,
120                None => continue,
121            };
122            // Clean up leftover .tmp files
123            if name.ends_with(".tmp") {
124                let _ = std::fs::remove_file(&path);
125                continue;
126            }
127            // Only process our .json cache files
128            if !name.ends_with(".json") {
129                continue;
130            }
131            let data = match std::fs::read_to_string(&path) {
132                Ok(d) => d,
133                Err(_) => {
134                    let _ = std::fs::remove_file(&path);
135                    continue;
136                }
137            };
138            let entry: CacheEntry = match serde_json::from_str(&data) {
139                Ok(e) => e,
140                Err(_) => {
141                    let _ = std::fs::remove_file(&path);
142                    continue;
143                }
144            };
145            if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
146                let _ = std::fs::remove_file(&path);
147            }
148        }
149    }
150}
151
/// Derive the cache key for a request from its URL, query pairs, and optional body.
///
/// The query pairs are sorted before hashing, so the order in which the
/// caller lists them does not affect the key. A missing body contributes
/// nothing to the hash.
///
/// NOTE: the standard library does not guarantee that `DefaultHasher`'s
/// algorithm stays the same across Rust releases, so keys computed by a
/// binary built with a different toolchain may differ.
fn cache_key(url: &str, query: &[(&str, String)], body: Option<&str>) -> u64 {
    let mut pairs = query
        .iter()
        .map(|(name, value)| (*name, value.as_str()))
        .collect::<Vec<(&str, &str)>>();
    // Equal pairs are indistinguishable, so an unstable sort yields the same sequence.
    pairs.sort_unstable();

    let mut state = DefaultHasher::new();
    url.hash(&mut state);
    pairs.iter().for_each(|(name, value)| {
        name.hash(&mut state);
        value.hash(&mut state);
    });
    if let Some(payload) = body {
        payload.hash(&mut state);
    }
    state.finish()
}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniquely named scratch directory under the system temp dir.
    ///
    /// The directory is not created here; it is removed (best effort) when the
    /// guard is dropped so test runs do not pile up files in the temp dir.
    struct Scratch(PathBuf);

    impl Scratch {
        fn new() -> Self {
            Scratch(
                std::env::temp_dir()
                    .join("papers-test-cache")
                    .join(format!("{:x}", rand_u64())),
            )
        }

        /// Path of the scratch directory.
        fn path(&self) -> PathBuf {
            self.0.clone()
        }

        /// Open a cache rooted at the scratch directory.
        fn cache(&self, ttl_secs: u64) -> DiskCache {
            DiskCache::new(self.path(), Duration::from_secs(ttl_secs)).unwrap()
        }
    }

    impl Drop for Scratch {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Pseudo-random value used to give each scratch directory its own name.
    fn rand_u64() -> u64 {
        // The counter keeps names distinct even if the clock is too coarse.
        static NEXT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let mut hasher = DefaultHasher::new();
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
            .hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        std::process::id().hash(&mut hasher);
        NEXT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            .hash(&mut hasher);
        hasher.finish()
    }

    /// File name the cache uses for the entry of the given request.
    fn entry_file_name(url: &str, query: &[(&str, String)], body: Option<&str>) -> String {
        let key = cache_key(url, query, body);
        format!("{key:016x}.json")
    }

    /// An empty query slice with the element type spelled out.
    fn no_query() -> Vec<(&'static str, String)> {
        Vec::new()
    }

    #[test]
    fn key_is_deterministic() {
        let q: Vec<(&str, String)> = vec![("a", "1".to_string()), ("b", "2".to_string())];
        let k1 = cache_key("http://x", &q, None);
        let k2 = cache_key("http://x", &q, None);
        assert_eq!(k1, k2);
    }

    #[test]
    fn key_differs_by_url() {
        let q = no_query();
        assert_ne!(cache_key("http://a", &q, None), cache_key("http://b", &q, None));
    }

    #[test]
    fn key_differs_by_query() {
        let q1: Vec<(&str, String)> = vec![("a", "1".to_string())];
        let q2: Vec<(&str, String)> = vec![("a", "2".to_string())];
        assert_ne!(cache_key("http://x", &q1, None), cache_key("http://x", &q2, None));
    }

    #[test]
    fn key_differs_by_body() {
        let q = no_query();
        assert_ne!(
            cache_key("http://x", &q, Some("body1")),
            cache_key("http://x", &q, Some("body2"))
        );
    }

    #[test]
    fn key_query_order_independent() {
        let q1: Vec<(&str, String)> = vec![("b", "2".to_string()), ("a", "1".to_string())];
        let q2: Vec<(&str, String)> = vec![("a", "1".to_string()), ("b", "2".to_string())];
        assert_eq!(cache_key("http://x", &q1, None), cache_key("http://x", &q2, None));
    }

    #[test]
    fn set_get_roundtrip() {
        let scratch = Scratch::new();
        let cache = scratch.cache(60);
        let q: Vec<(&str, String)> = vec![("k", "v".to_string())];
        cache.set("http://x", &q, None, "response body");
        let got = cache.get("http://x", &q, None);
        assert_eq!(got.as_deref(), Some("response body"));
    }

    #[test]
    fn set_leaves_no_temp_files() {
        let scratch = Scratch::new();
        let cache = scratch.cache(60);
        let q = no_query();
        cache.set("http://x", &q, None, "data");
        let names: Vec<String> = std::fs::read_dir(scratch.path())
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        assert_eq!(names, vec![entry_file_name("http://x", &q, None)]);
    }

    #[test]
    fn missing_key_returns_none() {
        let scratch = Scratch::new();
        let cache = scratch.cache(60);
        let q = no_query();
        assert!(cache.get("http://nonexistent", &q, None).is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let scratch = Scratch::new();
        let cache = scratch.cache(60);
        let q = no_query();
        // Backdate the entry on disk instead of sleeping past the TTL: the
        // test stays fast and does not depend on wall-clock timing.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let stale = CacheEntry {
            ts: now.saturating_sub(3600),
            body: "data".into(),
        };
        let json = serde_json::to_string(&stale).unwrap();
        let path = cache.cache_dir.join(entry_file_name("http://x", &q, None));
        std::fs::write(&path, json).unwrap();
        assert!(cache.get("http://x", &q, None).is_none());
    }

    #[test]
    fn corrupted_file_returns_none() {
        let scratch = Scratch::new();
        let cache = scratch.cache(60);
        let q = no_query();
        let path = cache.cache_dir.join(entry_file_name("http://x", &q, None));
        std::fs::write(&path, "not json").unwrap();
        assert!(cache.get("http://x", &q, None).is_none());
    }

    #[test]
    fn prune_removes_expired_entries() {
        let scratch = Scratch::new();
        let dir = scratch.path();
        // Create a cache and write one fresh entry through the public API.
        let cache = scratch.cache(3600);
        let q = no_query();
        cache.set("http://a", &q, None, "fresh");

        // Manually write an entry that expired long ago (timestamp 0).
        let expired = CacheEntry { ts: 0, body: "old".into() };
        let json = serde_json::to_string(&expired).unwrap();
        std::fs::write(dir.join(entry_file_name("http://old", &q, None)), json).unwrap();

        // Write a .tmp leftover
        std::fs::write(dir.join("leftover.tmp"), "junk").unwrap();

        let file_count = || std::fs::read_dir(&dir).unwrap().count();
        assert_eq!(file_count(), 3); // fresh + expired + tmp

        // Re-open the same directory with the same TTL — prune runs on construction.
        let cache2 = scratch.cache(3600);
        assert_eq!(file_count(), 1); // only fresh remains
        assert_eq!(cache2.get("http://a", &q, None).as_deref(), Some("fresh"));
        assert!(cache2.get("http://old", &q, None).is_none());
    }

    #[test]
    fn prune_removes_corrupted_files() {
        let scratch = Scratch::new();
        let dir = scratch.path();
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("badhash0000000000.json"), "not json").unwrap();
        assert_eq!(std::fs::read_dir(&dir).unwrap().count(), 1);
        let _cache = scratch.cache(60);
        assert_eq!(std::fs::read_dir(&dir).unwrap().count(), 0);
    }

    #[test]
    fn prune_ignores_unrelated_files() {
        let scratch = Scratch::new();
        let dir = scratch.path();
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("notes.txt"), "keep me").unwrap();
        let _cache = scratch.cache(60);
        assert!(dir.join("notes.txt").exists());
    }

    #[test]
    fn directory_creation() {
        let scratch = Scratch::new();
        let dir = scratch.path().join("nested").join("deep");
        let cache = DiskCache::new(dir.clone(), Duration::from_secs(60)).unwrap();
        assert!(dir.exists());
        let q = no_query();
        cache.set("http://x", &q, None, "ok");
        assert_eq!(cache.get("http://x", &q, None).as_deref(), Some("ok"));
    }
}