// papers_zotero/cache.rs

use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::io;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
/// On-disk cache for HTTP response text, with a configurable time-to-live.
///
/// Each response is stored as a small JSON file inside `cache_dir`. The file
/// name is derived from a hash of the request URL, its query parameters and
/// the optional POST body. An entry older than `ttl` is reported as a cache
/// miss rather than as an error.
///
/// # Atomic writes
///
/// An entry is first written to a temporary file and then renamed onto its
/// final name, so a reader never sees a half-written entry.
///
/// # Cloning
///
/// Cloning copies the directory path and the TTL; every clone reads and
/// writes the same directory on disk.
#[derive(Clone, Debug)]
pub struct DiskCache {
    /// Directory holding the `*.json` entry files.
    cache_dir: PathBuf,
    /// Maximum age of an entry before it is treated as expired.
    ttl: Duration,
}
23
24#[derive(Serialize, Deserialize)]
25struct CacheEntry {
26    ts: u64,
27    body: String,
28}
29
30impl DiskCache {
31    /// Create a cache storing entries in `cache_dir` with the given TTL.
32    ///
33    /// Creates the directory (and parents) if it doesn't exist.
34    pub fn new(cache_dir: PathBuf, ttl: Duration) -> io::Result<Self> {
35        std::fs::create_dir_all(&cache_dir)?;
36        let cache = Self { cache_dir, ttl };
37        cache.prune();
38        Ok(cache)
39    }
40
41    /// Create a cache in the platform-standard cache directory.
42    ///
43    /// - Linux: `~/.cache/papers/requests`
44    /// - macOS: `~/Library/Caches/papers/requests`
45    /// - Windows: `{FOLDERID_LocalAppData}/papers/requests`
46    ///
47    /// Returns `Err` if no cache directory can be determined or created.
48    pub fn default_location(ttl: Duration) -> io::Result<Self> {
49        let base = dirs::cache_dir().ok_or_else(|| {
50            io::Error::new(io::ErrorKind::NotFound, "no platform cache directory")
51        })?;
52        Self::new(base.join("papers").join("requests"), ttl)
53    }
54
55    /// Look up a cached response.
56    ///
57    /// Returns `None` on cache miss, expired entry, or any I/O / parse error.
58    pub fn get(&self, url: &str, query: &[(&str, String)], body: Option<&str>) -> Option<String> {
59        let key = cache_key(url, query, body);
60        let path = self.cache_dir.join(format!("{key:016x}.json"));
61        let data = std::fs::read_to_string(&path).ok()?;
62        let entry: CacheEntry = serde_json::from_str(&data).ok()?;
63        let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
64        if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
65            return None;
66        }
67        Some(entry.body)
68    }
69
70    /// Store a response in the cache.
71    ///
72    /// Writes atomically via a `.tmp` file + rename. Errors are silently
73    /// ignored — a failed cache write should never break a request.
74    pub fn set(&self, url: &str, query: &[(&str, String)], body: Option<&str>, response: &str) {
75        let _ = self.set_inner(url, query, body, response);
76    }
77
78    fn set_inner(
79        &self,
80        url: &str,
81        query: &[(&str, String)],
82        body: Option<&str>,
83        response: &str,
84    ) -> io::Result<()> {
85        let key = cache_key(url, query, body);
86        let ts = SystemTime::now()
87            .duration_since(UNIX_EPOCH)
88            .map_err(io::Error::other)?
89            .as_secs();
90        let entry = CacheEntry {
91            ts,
92            body: response.to_string(),
93        };
94        let json = serde_json::to_string(&entry).map_err(io::Error::other)?;
95        let tmp_path = self.cache_dir.join(format!("{key:016x}.tmp"));
96        let final_path = self.cache_dir.join(format!("{key:016x}.json"));
97        std::fs::write(&tmp_path, json)?;
98        std::fs::rename(&tmp_path, &final_path)?;
99        Ok(())
100    }
101
102    /// Remove expired entries and leftover `.tmp` files from the cache directory.
103    ///
104    /// Called automatically on construction. Errors on individual files are
105    /// silently ignored.
106    pub fn prune(&self) {
107        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
108            Ok(d) => d.as_secs(),
109            Err(_) => return,
110        };
111        let entries = match std::fs::read_dir(&self.cache_dir) {
112            Ok(e) => e,
113            Err(_) => return,
114        };
115        for entry in entries.flatten() {
116            let path = entry.path();
117            let name = match path.file_name().and_then(|n| n.to_str()) {
118                Some(n) => n,
119                None => continue,
120            };
121            // Clean up leftover .tmp files
122            if name.ends_with(".tmp") {
123                let _ = std::fs::remove_file(&path);
124                continue;
125            }
126            // Only process our .json cache files
127            if !name.ends_with(".json") {
128                continue;
129            }
130            let data = match std::fs::read_to_string(&path) {
131                Ok(d) => d,
132                Err(_) => {
133                    let _ = std::fs::remove_file(&path);
134                    continue;
135                }
136            };
137            let entry: CacheEntry = match serde_json::from_str(&data) {
138                Ok(e) => e,
139                Err(_) => {
140                    let _ = std::fs::remove_file(&path);
141                    continue;
142                }
143            };
144            if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
145                let _ = std::fs::remove_file(&path);
146            }
147        }
148    }
149}
150
/// Compute the cache key for a request from its URL, query pairs and optional body.
///
/// The query pairs are sorted before hashing, so the key does not depend on
/// the order in which parameters are supplied. The body is hashed only when
/// present.
///
/// The key is produced by [`DefaultHasher`], whose algorithm the standard
/// library does not guarantee to stay the same across Rust releases. Keys are
/// therefore stable within one build of the program, but entries written by a
/// build made with a different toolchain may no longer be found.
fn cache_key(url: &str, query: &[(&str, String)], body: Option<&str>) -> u64 {
    let mut pairs: Vec<(&str, &str)> = query
        .iter()
        .map(|(name, value)| (*name, value.as_str()))
        .collect();
    // Equal pairs are indistinguishable, so an unstable sort gives the same order.
    pairs.sort_unstable();

    let mut hasher = DefaultHasher::new();
    url.hash(&mut hasher);
    for (name, value) in pairs {
        name.hash(&mut hasher);
        value.hash(&mut hasher);
    }
    if let Some(payload) = body {
        payload.hash(&mut hasher);
    }
    hasher.finish()
}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a cache in a fresh, uniquely named directory under the system temp dir.
    fn temp_cache(ttl_secs: u64) -> DiskCache {
        let dir = std::env::temp_dir()
            .join("papers-zotero-test-cache")
            .join(format!("{:x}", rand_u64()));
        DiskCache::new(dir, Duration::from_secs(ttl_secs)).unwrap()
    }

    /// Cheap pseudo-random value (clock nanoseconds + thread id) for directory names.
    fn rand_u64() -> u64 {
        let mut hasher = DefaultHasher::new();
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
            .hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        hasher.finish()
    }

    /// Best-effort removal of a test cache directory.
    fn cleanup(cache: &DiskCache) {
        let _ = std::fs::remove_dir_all(&cache.cache_dir);
    }

    /// Write an entry file with an explicit timestamp directly into the cache directory.
    fn write_entry(cache: &DiskCache, file_name: &str, ts: u64, body: &str) -> PathBuf {
        let path = cache.cache_dir.join(file_name);
        let entry = CacheEntry {
            ts,
            body: body.to_string(),
        };
        std::fs::write(&path, serde_json::to_string(&entry).unwrap()).unwrap();
        path
    }

    #[test]
    fn key_is_deterministic() {
        let q: Vec<(&str, String)> = vec![("a", "1".into()), ("b", "2".into())];
        let k1 = cache_key("http://x", &q, None);
        let k2 = cache_key("http://x", &q, None);
        assert_eq!(k1, k2);
    }

    #[test]
    fn key_query_order_independent() {
        let q1: Vec<(&str, String)> = vec![("b", "2".into()), ("a", "1".into())];
        let q2: Vec<(&str, String)> = vec![("a", "1".into()), ("b", "2".into())];
        assert_eq!(
            cache_key("http://x", &q1, None),
            cache_key("http://x", &q2, None)
        );
    }

    #[test]
    fn key_depends_on_body() {
        let q: Vec<(&str, String)> = vec![("a", "1".into())];
        assert_ne!(
            cache_key("http://x", &q, None),
            cache_key("http://x", &q, Some("payload"))
        );
    }

    #[test]
    fn set_get_roundtrip() {
        let cache = temp_cache(60);
        let q: Vec<(&str, String)> = vec![("k", "v".into())];
        cache.set("http://x", &q, None, "response body");
        let got = cache.get("http://x", &q, None);
        assert_eq!(got.as_deref(), Some("response body"));
        cleanup(&cache);
    }

    #[test]
    fn missing_key_returns_none() {
        let cache = temp_cache(60);
        let q: Vec<(&str, String)> = vec![];
        assert!(cache.get("http://nonexistent", &q, None).is_none());
        cleanup(&cache);
    }

    #[test]
    fn expired_entry_returns_none() {
        let cache = temp_cache(60);
        let q: Vec<(&str, String)> = vec![];
        // Plant an entry stamped at the Unix epoch instead of sleeping past the TTL.
        let file_name = format!("{:016x}.json", cache_key("http://x", &q, None));
        write_entry(&cache, &file_name, 0, "data");
        assert!(cache.get("http://x", &q, None).is_none());
        cleanup(&cache);
    }

    #[test]
    fn prune_removes_expired_and_corrupt_entries() {
        let cache = temp_cache(60);
        let q: Vec<(&str, String)> = vec![];

        let expired = write_entry(&cache, "0000000000000001.json", 0, "old");
        let corrupt = cache.cache_dir.join("0000000000000002.json");
        std::fs::write(&corrupt, "not json").unwrap();
        let unrelated = cache.cache_dir.join("notes.txt");
        std::fs::write(&unrelated, "keep me").unwrap();
        cache.set("http://fresh", &q, None, "fresh");

        cache.prune();

        assert!(!expired.exists());
        assert!(!corrupt.exists());
        assert!(unrelated.exists());
        assert_eq!(cache.get("http://fresh", &q, None).as_deref(), Some("fresh"));
        cleanup(&cache);
    }
}