1use serde::{Deserialize, Serialize};
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4use std::io;
5use std::path::PathBuf;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
/// File-backed cache of response bodies with a fixed time-to-live.
///
/// Each entry is stored as its own JSON file inside `cache_dir`; the file name
/// is derived from a hash of the request (URL, query pairs and optional body).
#[derive(Debug, Clone)]
pub struct DiskCache {
    /// Directory that holds one `<key>.json` file per cached response.
    cache_dir: PathBuf,
    /// Maximum age of an entry; compared at whole-second granularity.
    ttl: Duration,
}
23
24#[derive(Serialize, Deserialize)]
25struct CacheEntry {
26 ts: u64,
27 body: String,
28}
29
30impl DiskCache {
31 pub fn new(cache_dir: PathBuf, ttl: Duration) -> io::Result<Self> {
35 std::fs::create_dir_all(&cache_dir)?;
36 let cache = Self { cache_dir, ttl };
37 cache.prune();
38 Ok(cache)
39 }
40
41 pub fn default_location(ttl: Duration) -> io::Result<Self> {
49 let base = dirs::cache_dir().ok_or_else(|| {
50 io::Error::new(io::ErrorKind::NotFound, "no platform cache directory")
51 })?;
52 Self::new(base.join("papers").join("requests"), ttl)
53 }
54
55 pub fn get(&self, url: &str, query: &[(&str, String)], body: Option<&str>) -> Option<String> {
59 let key = cache_key(url, query, body);
60 let path = self.cache_dir.join(format!("{key:016x}.json"));
61 let data = std::fs::read_to_string(&path).ok()?;
62 let entry: CacheEntry = serde_json::from_str(&data).ok()?;
63 let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
64 if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
65 return None;
66 }
67 Some(entry.body)
68 }
69
70 pub fn set(&self, url: &str, query: &[(&str, String)], body: Option<&str>, response: &str) {
75 let _ = self.set_inner(url, query, body, response);
76 }
77
78 fn set_inner(
79 &self,
80 url: &str,
81 query: &[(&str, String)],
82 body: Option<&str>,
83 response: &str,
84 ) -> io::Result<()> {
85 let key = cache_key(url, query, body);
86 let ts = SystemTime::now()
87 .duration_since(UNIX_EPOCH)
88 .map_err(io::Error::other)?
89 .as_secs();
90 let entry = CacheEntry {
91 ts,
92 body: response.to_string(),
93 };
94 let json = serde_json::to_string(&entry).map_err(io::Error::other)?;
95 let tmp_path = self.cache_dir.join(format!("{key:016x}.tmp"));
96 let final_path = self.cache_dir.join(format!("{key:016x}.json"));
97 std::fs::write(&tmp_path, json)?;
98 std::fs::rename(&tmp_path, &final_path)?;
99 Ok(())
100 }
101
102 pub fn prune(&self) {
107 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
108 Ok(d) => d.as_secs(),
109 Err(_) => return,
110 };
111 let entries = match std::fs::read_dir(&self.cache_dir) {
112 Ok(e) => e,
113 Err(_) => return,
114 };
115 for entry in entries.flatten() {
116 let path = entry.path();
117 let name = match path.file_name().and_then(|n| n.to_str()) {
118 Some(n) => n,
119 None => continue,
120 };
121 if name.ends_with(".tmp") {
123 let _ = std::fs::remove_file(&path);
124 continue;
125 }
126 if !name.ends_with(".json") {
128 continue;
129 }
130 let data = match std::fs::read_to_string(&path) {
131 Ok(d) => d,
132 Err(_) => {
133 let _ = std::fs::remove_file(&path);
134 continue;
135 }
136 };
137 let entry: CacheEntry = match serde_json::from_str(&data) {
138 Ok(e) => e,
139 Err(_) => {
140 let _ = std::fs::remove_file(&path);
141 continue;
142 }
143 };
144 if now.saturating_sub(entry.ts) > self.ttl.as_secs() {
145 let _ = std::fs::remove_file(&path);
146 }
147 }
148 }
149}
150
/// Derives the 64-bit cache key for a request.
///
/// The key covers the URL, every query pair and, when present, the request
/// body. Query pairs are sorted first so the key does not depend on the order
/// in which the caller supplied them.
fn cache_key(url: &str, query: &[(&str, String)], body: Option<&str>) -> u64 {
    let mut pairs: Vec<(&str, &str)> = Vec::with_capacity(query.len());
    pairs.extend(query.iter().map(|(name, value)| (*name, value.as_str())));
    // Equal pairs are indistinguishable, so an unstable sort yields the same order.
    pairs.sort_unstable();

    let mut state = DefaultHasher::new();
    url.hash(&mut state);
    for pair in &pairs {
        pair.0.hash(&mut state);
        pair.1.hash(&mut state);
    }
    match body {
        Some(payload) => payload.hash(&mut state),
        None => {}
    }
    state.finish()
}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A cache in a unique temp directory that is deleted when the guard drops,
    /// so test runs do not accumulate directories under the system temp dir.
    struct TempCache {
        cache: DiskCache,
    }

    impl Drop for TempCache {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.cache.cache_dir);
        }
    }

    fn temp_cache(ttl_secs: u64) -> TempCache {
        let dir = std::env::temp_dir()
            .join("papers-zotero-test-cache")
            .join(format!("{:x}", rand_u64()));
        TempCache {
            cache: DiskCache::new(dir, Duration::from_secs(ttl_secs)).unwrap(),
        }
    }

    /// Cheap unique-ish value for naming temp directories (not cryptographic).
    fn rand_u64() -> u64 {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let mut hasher = DefaultHasher::new();
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
            .hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        std::process::id().hash(&mut hasher);
        COUNTER.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
        hasher.finish()
    }

    /// Writes an entry for `url` (no query, no body) with an explicit timestamp
    /// and returns its path, so expiry can be tested without sleeping.
    fn write_entry(cache: &DiskCache, url: &str, ts: u64, body: &str) -> PathBuf {
        let no_query: Vec<(&str, String)> = Vec::new();
        let key = cache_key(url, &no_query, None);
        let entry = CacheEntry {
            ts,
            body: body.to_string(),
        };
        let path = cache.cache_dir.join(format!("{key:016x}.json"));
        std::fs::write(&path, serde_json::to_string(&entry).unwrap()).unwrap();
        path
    }

    #[test]
    fn key_is_deterministic() {
        let q: Vec<(&str, String)> = vec![("a", "1".into()), ("b", "2".into())];
        let k1 = cache_key("http://x", &q, None);
        let k2 = cache_key("http://x", &q, None);
        assert_eq!(k1, k2);
    }

    #[test]
    fn key_query_order_independent() {
        let q1: Vec<(&str, String)> = vec![("b", "2".into()), ("a", "1".into())];
        let q2: Vec<(&str, String)> = vec![("a", "1".into()), ("b", "2".into())];
        assert_eq!(
            cache_key("http://x", &q1, None),
            cache_key("http://x", &q2, None)
        );
    }

    #[test]
    fn set_get_roundtrip() {
        let tmp = temp_cache(60);
        let q: Vec<(&str, String)> = vec![("k", "v".into())];
        tmp.cache.set("http://x", &q, None, "response body");
        let got = tmp.cache.get("http://x", &q, None);
        assert_eq!(got.as_deref(), Some("response body"));
    }

    #[test]
    fn missing_key_returns_none() {
        let tmp = temp_cache(60);
        let q: Vec<(&str, String)> = vec![];
        assert!(tmp.cache.get("http://nonexistent", &q, None).is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let tmp = temp_cache(1);
        let q: Vec<(&str, String)> = vec![];
        // An entry stamped at the epoch is far older than the 1-second TTL.
        write_entry(&tmp.cache, "http://x", 0, "data");
        assert!(tmp.cache.get("http://x", &q, None).is_none());
    }

    #[test]
    fn prune_removes_expired_and_tmp_files() {
        let tmp = temp_cache(60);
        let q: Vec<(&str, String)> = vec![];

        let stale = write_entry(&tmp.cache, "http://stale", 0, "old");
        let leftover = tmp.cache.cache_dir.join("deadbeef.tmp");
        std::fs::write(&leftover, "partial").unwrap();
        tmp.cache.set("http://fresh", &q, None, "new");

        tmp.cache.prune();

        assert!(!stale.exists());
        assert!(!leftover.exists());
        assert_eq!(
            tmp.cache.get("http://fresh", &q, None).as_deref(),
            Some("new")
        );
    }
}