1use std::collections::HashMap;
2use std::collections::hash_map::DefaultHasher;
3use std::error::Error;
4use std::fs;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9
/// Default schema-cache time-to-live: twelve hours.
///
/// Not referenced elsewhere in this file; presumably callers pass it as the
/// `ttl` argument of `SchemaCache::new` (confirm against call sites).
pub const DEFAULT_SCHEMA_CACHE_TTL: Duration = Duration::from_secs(60 * 60 * 12);
12
13use serde_json::Value;
14
/// Where the value returned by `SchemaCache::fetch` came from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheStatus {
    /// Served from the in-memory map or from a disk cache file; no network
    /// request was made.
    Hit,
    /// Fetched over the network while a cache directory was configured
    /// (this includes fetches forced by `skip_read`).
    Miss,
    /// Fetched over the network with no cache directory configured.
    Disabled,
}
25
/// Minimal async HTTP GET abstraction used by `SchemaCache`.
///
/// Keeping the transport behind a trait lets it be swapped out; the test
/// module in this file substitutes an in-memory mock.
#[async_trait::async_trait]
pub trait HttpClient: Clone + Send + Sync + 'static {
    /// Performs a GET request for `uri` and returns the response body as text.
    ///
    /// # Errors
    ///
    /// Returns a boxed error when the implementation cannot produce a body.
    async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>>;
}
34
/// [`HttpClient`] implementation backed by a `reqwest::Client`.
///
/// The wrapped client is a public field, so callers can supply their own
/// pre-configured `reqwest::Client`.
#[derive(Clone)]
pub struct ReqwestClient(pub reqwest::Client);
38
39impl Default for ReqwestClient {
40 fn default() -> Self {
41 Self(reqwest::Client::new())
42 }
43}
44
45#[async_trait::async_trait]
46impl HttpClient for ReqwestClient {
47 async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
48 let resp = self.0.get(uri).send().await?.error_for_status()?;
49 Ok(resp.text().await?)
50 }
51}
52
/// Two-level (memory + optional disk) cache for JSON documents fetched by URI.
///
/// Cloning is cheap with respect to the in-memory layer: it lives behind an
/// `Arc`, so all clones share the same map.
#[derive(Clone)]
pub struct SchemaCache<C: HttpClient = ReqwestClient> {
    /// Directory holding on-disk entries; `None` disables the disk layer.
    cache_dir: Option<PathBuf>,
    /// Transport used when neither cache layer can serve a request.
    client: C,
    /// When `true`, `fetch` skips reading both cache layers (it still writes to them).
    skip_read: bool,
    /// Maximum age of a disk entry, judged by file modification time;
    /// `None` means disk entries never expire.
    ttl: Option<Duration>,
    /// In-memory layer keyed by the URI string.
    memory_cache: Arc<Mutex<HashMap<String, Value>>>,
}
68
69impl<C: HttpClient> SchemaCache<C> {
70 pub fn new(
71 cache_dir: Option<PathBuf>,
72 client: C,
73 skip_read: bool,
74 ttl: Option<Duration>,
75 ) -> Self {
76 Self {
77 cache_dir,
78 client,
79 skip_read,
80 ttl,
81 memory_cache: Arc::new(Mutex::new(HashMap::new())),
82 }
83 }
84
85 #[allow(clippy::missing_panics_doc)] #[tracing::instrument(skip(self), fields(status))]
99 pub async fn fetch(
100 &self,
101 uri: &str,
102 ) -> Result<(Value, CacheStatus), Box<dyn Error + Send + Sync>> {
103 if !self.skip_read
105 && let Some(value) = self
106 .memory_cache
107 .lock()
108 .expect("memory cache poisoned")
109 .get(uri)
110 .cloned()
111 {
112 tracing::Span::current().record("status", "memory_hit");
113 return Ok((value, CacheStatus::Hit));
114 }
115
116 if !self.skip_read
118 && let Some(ref cache_dir) = self.cache_dir
119 {
120 let hash = Self::hash_uri(uri);
121 let cache_path = cache_dir.join(format!("{hash}.json"));
122 if cache_path.exists()
123 && !self.is_expired(&cache_path)
124 && let Ok(content) = tokio::fs::read_to_string(&cache_path).await
125 && let Ok(value) = serde_json::from_str::<Value>(&content)
126 {
127 self.memory_cache
128 .lock()
129 .expect("memory cache poisoned")
130 .insert(uri.to_string(), value.clone());
131 tracing::Span::current().record("status", "cache_hit");
132 return Ok((value, CacheStatus::Hit));
133 }
134 }
135
136 tracing::Span::current().record("status", "network_fetch");
138 let body = self.client.get(uri).await?;
139 let value: Value = serde_json::from_str(&body)?;
140
141 self.memory_cache
143 .lock()
144 .expect("memory cache poisoned")
145 .insert(uri.to_string(), value.clone());
146
147 let status = if let Some(ref cache_dir) = self.cache_dir {
148 let hash = Self::hash_uri(uri);
149 let cache_path = cache_dir.join(format!("{hash}.json"));
150 if let Err(e) = tokio::fs::write(&cache_path, &body).await {
151 tracing::warn!(
152 path = %cache_path.display(),
153 error = %e,
154 "failed to write schema to disk cache"
155 );
156 }
157 CacheStatus::Miss
158 } else {
159 CacheStatus::Disabled
160 };
161
162 Ok((value, status))
163 }
164
165 fn is_expired(&self, path: &std::path::Path) -> bool {
171 let Some(ttl) = self.ttl else {
172 return false;
173 };
174 fs::metadata(path)
175 .ok()
176 .and_then(|m| m.modified().ok())
177 .and_then(|mtime| mtime.elapsed().ok())
178 .is_some_and(|age| age > ttl)
179 }
180
181 fn hash_uri(uri: &str) -> String {
182 let mut hasher = DefaultHasher::new();
183 uri.hash(&mut hasher);
184 format!("{:016x}", hasher.finish())
185 }
186}
187
188pub fn ensure_cache_dir() -> PathBuf {
193 let candidates = [
194 dirs::cache_dir().map(|d| d.join("lintel").join("schemas")),
195 Some(std::env::temp_dir().join("lintel").join("schemas")),
196 ];
197 for candidate in candidates.into_iter().flatten() {
198 if fs::create_dir_all(&candidate).is_ok() {
199 return candidate;
200 }
201 }
202 std::env::temp_dir().join("lintel").join("schemas")
203}
204
205#[async_trait::async_trait]
208impl<C: HttpClient> jsonschema::AsyncRetrieve for SchemaCache<C> {
209 async fn retrieve(
210 &self,
211 uri: &jsonschema::Uri<String>,
212 ) -> Result<Value, Box<dyn Error + Send + Sync>> {
213 let (value, _status) = self.fetch(uri.as_str()).await?;
214 Ok(value)
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory `HttpClient`: maps a URI to the body it should return.
    #[derive(Clone)]
    struct MockClient(HashMap<String, String>);

    #[async_trait::async_trait]
    impl HttpClient for MockClient {
        /// Returns the canned body for `uri`, or an error when none is registered.
        async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
            self.0
                .get(uri)
                .cloned()
                .ok_or_else(|| format!("mock: no response for {uri}").into())
        }
    }

    /// Builds a `MockClient` from `(uri, body)` pairs.
    fn mock(entries: &[(&str, &str)]) -> MockClient {
        MockClient(
            entries
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        )
    }

    #[test]
    fn hash_uri_deterministic() {
        let a = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        let b = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        assert_eq!(a, b);
    }

    #[test]
    fn hash_uri_different_inputs() {
        let a = SchemaCache::<MockClient>::hash_uri("https://example.com/a.json");
        let b = SchemaCache::<MockClient>::hash_uri("https://example.com/b.json");
        assert_ne!(a, b);
    }

    /// Converts a boxed error into `anyhow::Error` (via its `Display` text) so
    /// tests can use `?`.
    // Taken by value so it can be passed directly to `map_err`.
    #[allow(clippy::needless_pass_by_value)]
    fn boxerr(e: Box<dyn Error + Send + Sync>) -> anyhow::Error {
        anyhow::anyhow!("{e}")
    }

    /// Without a cache directory the value comes from the client and the
    /// status is `Disabled`.
    #[tokio::test]
    async fn fetch_no_cache_dir() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        assert_eq!(status, CacheStatus::Disabled);
        Ok(())
    }

    /// An empty cache directory yields `Miss` and leaves a file on disk.
    #[tokio::test]
    async fn fetch_cold_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);

        // The fetch must have written the entry under the hashed file name.
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        assert!(cache_path.exists());
        Ok(())
    }

    /// A pre-seeded disk entry is served as `Hit`.
    #[tokio::test]
    async fn fetch_warm_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // Empty mock: any network request would fail, so success proves the
        // value came from disk.
        let client = mock(&[]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    /// With `skip_read` set, an existing disk entry is ignored in favour of the client.
    #[tokio::test]
    async fn fetch_skip_read_bypasses_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // The client returns a different schema than the one on disk, so the
        // assertion below tells the two sources apart.
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, true, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    /// A client failure is surfaced as an error.
    #[tokio::test]
    async fn fetch_client_error() {
        let client = mock(&[]);
        let cache = SchemaCache::new(None, client, false, None);
        assert!(
            cache
                .fetch("https://example.com/missing.json")
                .await
                .is_err()
        );
    }

    /// A body that is not valid JSON is surfaced as an error.
    #[tokio::test]
    async fn fetch_invalid_json() {
        let client = mock(&[("https://example.com/bad.json", "not json")]);
        let cache = SchemaCache::new(None, client, false, None);
        assert!(cache.fetch("https://example.com/bad.json").await.is_err());
    }

    /// The `jsonschema::AsyncRetrieve` impl returns the same value as `fetch`.
    #[tokio::test]
    async fn async_retrieve_trait_delegates() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let uri: jsonschema::Uri<String> = "https://example.com/s.json".parse()?;
        let val = jsonschema::AsyncRetrieve::retrieve(&cache, &uri)
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        Ok(())
    }

    /// A disk entry older than the TTL is ignored and refetched.
    #[tokio::test]
    async fn fetch_expired_ttl_refetches() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // Backdate the file by 2s so it is older than the 1s TTL used below.
        let two_secs_ago = filetime::FileTime::from_system_time(
            std::time::SystemTime::now() - std::time::Duration::from_secs(2),
        );
        filetime::set_file_mtime(&cache_path, two_secs_ago)?;

        // The client returns a different schema than the stale one on disk.
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(1)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    /// A disk entry younger than the TTL is served as `Hit`.
    #[tokio::test]
    async fn fetch_unexpired_ttl_serves_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // Empty mock: success proves the freshly written entry was served.
        let client = mock(&[]);
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(3600)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    // Calls the real function, so this may create the directory on the host.
    #[test]
    fn ensure_cache_dir_ends_with_schemas() {
        let dir = ensure_cache_dir();
        assert!(dir.ends_with("lintel/schemas"));
    }
}