1#![doc = include_str!("../README.md")]
2
3extern crate alloc;
4
5use alloc::sync::Arc;
6use core::error::Error;
7use core::hash::{Hash, Hasher};
8use core::time::Duration;
9use std::collections::HashMap;
10use std::collections::hash_map::DefaultHasher;
11use std::fs;
12use std::path::PathBuf;
13use std::sync::Mutex;
14
/// Suggested time-to-live for on-disk schema cache entries: twelve hours (43 200 s).
pub const DEFAULT_SCHEMA_CACHE_TTL: Duration = Duration::from_secs(43_200);
17
18use serde_json::Value;
19
/// Where the value returned by `SchemaCache::fetch` came from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheStatus {
    /// Served from the in-memory map or the disk cache, including a disk copy that the
    /// server confirmed unchanged via its etag.
    Hit,
    /// Downloaded from the network while a disk cache directory is configured.
    Miss,
    /// Downloaded from the network with no disk cache directory configured.
    Disabled,
}
30
/// Result of [`HttpClient::get_conditional`].
///
/// Derives `Debug`/`Clone`/`PartialEq` so the public type can be logged, copied and
/// compared in tests like the other public types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConditionalResponse {
    /// Response body, or `None` when the server reported the resource as unchanged
    /// (HTTP 304 Not Modified), i.e. the caller's cached copy is still current.
    pub body: Option<String>,
    /// `ETag` header of a full response, if the server sent one.
    pub etag: Option<String>,
}
38
39#[async_trait::async_trait]
41pub trait HttpClient: Clone + Send + Sync + 'static {
42 async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>>;
46
47 async fn get_conditional(
58 &self,
59 uri: &str,
60 _etag: Option<&str>,
61 ) -> Result<ConditionalResponse, Box<dyn Error + Send + Sync>> {
62 let body = self.get(uri).await?;
63 Ok(ConditionalResponse {
64 body: Some(body),
65 etag: None,
66 })
67 }
68}
69
70#[derive(Clone)]
72pub struct ReqwestClient(pub reqwest::Client);
73
74impl Default for ReqwestClient {
75 fn default() -> Self {
76 Self(reqwest::Client::new())
77 }
78}
79
80#[async_trait::async_trait]
81impl HttpClient for ReqwestClient {
82 async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
83 let resp = self.0.get(uri).send().await?.error_for_status()?;
84 Ok(resp.text().await?)
85 }
86
87 async fn get_conditional(
88 &self,
89 uri: &str,
90 etag: Option<&str>,
91 ) -> Result<ConditionalResponse, Box<dyn Error + Send + Sync>> {
92 let mut req = self.0.get(uri);
93 if let Some(etag) = etag {
94 req = req.header("If-None-Match", etag);
95 }
96 let resp = req.send().await?;
97 if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
98 return Ok(ConditionalResponse {
99 body: None,
100 etag: None,
101 });
102 }
103 let resp = resp.error_for_status()?;
104 let etag = resp
105 .headers()
106 .get("etag")
107 .and_then(|v| v.to_str().ok())
108 .map(String::from);
109 let body = resp.text().await?;
110 Ok(ConditionalResponse {
111 body: Some(body),
112 etag,
113 })
114 }
115}
116
117#[derive(Clone)]
124pub struct SchemaCache<C: HttpClient = ReqwestClient> {
125 cache_dir: Option<PathBuf>,
126 client: C,
127 skip_read: bool,
128 ttl: Option<Duration>,
129 memory_cache: Arc<Mutex<HashMap<String, Value>>>,
131}
132
133impl<C: HttpClient> SchemaCache<C> {
134 pub fn new(
135 cache_dir: Option<PathBuf>,
136 client: C,
137 skip_read: bool,
138 ttl: Option<Duration>,
139 ) -> Self {
140 Self {
141 cache_dir,
142 client,
143 skip_read,
144 ttl,
145 memory_cache: Arc::new(Mutex::new(HashMap::new())),
146 }
147 }
148
149 #[allow(clippy::missing_panics_doc)] #[tracing::instrument(skip(self), fields(status))]
163 pub async fn fetch(
164 &self,
165 uri: &str,
166 ) -> Result<(Value, CacheStatus), Box<dyn Error + Send + Sync>> {
167 if !self.skip_read
169 && let Some(value) = self
170 .memory_cache
171 .lock()
172 .expect("memory cache poisoned")
173 .get(uri)
174 .cloned()
175 {
176 tracing::Span::current().record("status", "memory_hit");
177 return Ok((value, CacheStatus::Hit));
178 }
179
180 let mut stored_etag: Option<String> = None;
182 let mut cached_content: Option<String> = None;
183
184 if let Some(ref cache_dir) = self.cache_dir {
185 let hash = Self::hash_uri(uri);
186 let cache_path = cache_dir.join(format!("{hash}.json"));
187 let etag_path = cache_dir.join(format!("{hash}.etag"));
188
189 if cache_path.exists() {
190 if !self.skip_read && !self.is_expired(&cache_path) {
191 if let Ok(content) = tokio::fs::read_to_string(&cache_path).await
193 && let Ok(value) = serde_json::from_str::<Value>(&content)
194 {
195 self.memory_cache
196 .lock()
197 .expect("memory cache poisoned")
198 .insert(uri.to_string(), value.clone());
199 tracing::Span::current().record("status", "cache_hit");
200 return Ok((value, CacheStatus::Hit));
201 }
202 }
203
204 if let Ok(etag) = tokio::fs::read_to_string(&etag_path).await {
206 stored_etag = Some(etag);
207 }
208 if let Ok(content) = tokio::fs::read_to_string(&cache_path).await {
210 cached_content = Some(content);
211 }
212 }
213 }
214
215 tracing::Span::current().record("status", "network_fetch");
217 let conditional = self
218 .client
219 .get_conditional(uri, stored_etag.as_deref())
220 .await?;
221
222 if conditional.body.is_none() {
223 if let Some(content) = cached_content {
225 let value: Value = serde_json::from_str(&content)?;
226 self.memory_cache
227 .lock()
228 .expect("memory cache poisoned")
229 .insert(uri.to_string(), value.clone());
230
231 if let Some(ref cache_dir) = self.cache_dir {
233 let hash = Self::hash_uri(uri);
234 let cache_path = cache_dir.join(format!("{hash}.json"));
235 let now = filetime::FileTime::now();
236 let _ = filetime::set_file_mtime(&cache_path, now);
237 }
238
239 tracing::Span::current().record("status", "etag_hit");
240 return Ok((value, CacheStatus::Hit));
241 }
242 }
243
244 let body = conditional.body.expect("non-304 response must have a body");
245 let value: Value = serde_json::from_str(&body)?;
246
247 self.memory_cache
249 .lock()
250 .expect("memory cache poisoned")
251 .insert(uri.to_string(), value.clone());
252
253 let status = if let Some(ref cache_dir) = self.cache_dir {
254 let hash = Self::hash_uri(uri);
255 let cache_path = cache_dir.join(format!("{hash}.json"));
256 let etag_path = cache_dir.join(format!("{hash}.etag"));
257 if let Err(e) = tokio::fs::write(&cache_path, &body).await {
258 tracing::warn!(
259 path = %cache_path.display(),
260 error = %e,
261 "failed to write schema to disk cache"
262 );
263 }
264 if let Some(etag) = conditional.etag {
266 let _ = tokio::fs::write(&etag_path, &etag).await;
267 }
268 CacheStatus::Miss
269 } else {
270 CacheStatus::Disabled
271 };
272
273 Ok((value, status))
274 }
275
276 fn is_expired(&self, path: &std::path::Path) -> bool {
282 let Some(ttl) = self.ttl else {
283 return false;
284 };
285 fs::metadata(path)
286 .ok()
287 .and_then(|m| m.modified().ok())
288 .and_then(|mtime| mtime.elapsed().ok())
289 .is_some_and(|age| age > ttl)
290 }
291
292 pub fn hash_uri(uri: &str) -> String {
293 let mut hasher = DefaultHasher::new();
294 uri.hash(&mut hasher);
295 format!("{:016x}", hasher.finish())
296 }
297}
298
299pub fn ensure_cache_dir() -> PathBuf {
304 let candidates = [
305 dirs::cache_dir().map(|d| d.join("lintel").join("schemas")),
306 Some(std::env::temp_dir().join("lintel").join("schemas")),
307 ];
308 for candidate in candidates.into_iter().flatten() {
309 if fs::create_dir_all(&candidate).is_ok() {
310 return candidate;
311 }
312 }
313 std::env::temp_dir().join("lintel").join("schemas")
314}
315
316#[async_trait::async_trait]
319impl<C: HttpClient> jsonschema::AsyncRetrieve for SchemaCache<C> {
320 async fn retrieve(
321 &self,
322 uri: &jsonschema::Uri<String>,
323 ) -> Result<Value, Box<dyn Error + Send + Sync>> {
324 let (value, _status) = self.fetch(uri.as_str()).await?;
325 Ok(value)
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock client backed by a fixed `uri -> body` map; unknown URIs are an error.
    /// It relies on the trait's default `get_conditional`, so it never reports
    /// "not modified".
    #[derive(Clone)]
    struct MockClient(HashMap<String, String>);

    #[async_trait::async_trait]
    impl HttpClient for MockClient {
        async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
            self.0
                .get(uri)
                .cloned()
                .ok_or_else(|| format!("mock: no response for {uri}").into())
        }
    }

    fn mock(entries: &[(&str, &str)]) -> MockClient {
        MockClient(
            entries
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        )
    }

    /// Mock client that supports conditional requests. It serves `body` with `etag`,
    /// except that a request carrying a matching etag gets "not modified" (`body: None`).
    #[derive(Clone)]
    struct EtagClient {
        body: String,
        etag: String,
    }

    #[async_trait::async_trait]
    impl HttpClient for EtagClient {
        async fn get(&self, _uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
            Ok(self.body.clone())
        }

        async fn get_conditional(
            &self,
            _uri: &str,
            etag: Option<&str>,
        ) -> Result<ConditionalResponse, Box<dyn Error + Send + Sync>> {
            if etag == Some(self.etag.as_str()) {
                return Ok(ConditionalResponse {
                    body: None,
                    etag: None,
                });
            }
            Ok(ConditionalResponse {
                body: Some(self.body.clone()),
                etag: Some(self.etag.clone()),
            })
        }
    }

    #[test]
    fn hash_uri_deterministic() {
        let a = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        let b = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        assert_eq!(a, b);
    }

    #[test]
    fn hash_uri_different_inputs() {
        let a = SchemaCache::<MockClient>::hash_uri("https://example.com/a.json");
        let b = SchemaCache::<MockClient>::hash_uri("https://example.com/b.json");
        assert_ne!(a, b);
    }

    // Takes the error by value so it can be passed point-free to `map_err`.
    #[allow(clippy::needless_pass_by_value)]
    fn boxerr(e: Box<dyn Error + Send + Sync>) -> anyhow::Error {
        anyhow::anyhow!("{e}")
    }

    #[tokio::test]
    async fn fetch_no_cache_dir() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        assert_eq!(status, CacheStatus::Disabled);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_memory_hit_on_second_call() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let (_, first) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        let (val, second) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(first, CacheStatus::Disabled);
        assert_eq!(second, CacheStatus::Hit);
        assert_eq!(val, serde_json::json!({"type": "object"}));
        Ok(())
    }

    #[tokio::test]
    async fn fetch_cold_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);

        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        assert!(cache_path.exists());
        Ok(())
    }

    #[tokio::test]
    async fn fetch_miss_stores_etag() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let client = EtagClient {
            body: r#"{"type":"array"}"#.to_string(),
            etag: "\"v1\"".to_string(),
        };
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "array"}));
        assert_eq!(status, CacheStatus::Miss);

        let hash = SchemaCache::<EtagClient>::hash_uri("https://example.com/s.json");
        let etag = fs::read_to_string(tmp.path().join(format!("{hash}.etag")))?;
        assert_eq!(etag, "\"v1\"");
        Ok(())
    }

    #[tokio::test]
    async fn fetch_warm_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // No responses registered: any network access would fail the test.
        let client = mock(&[]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_skip_read_bypasses_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // The network serves different content than the disk copy.
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, true, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_client_error() {
        let client = mock(&[]);
        let cache = SchemaCache::new(None, client, false, None);
        assert!(
            cache
                .fetch("https://example.com/missing.json")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn fetch_invalid_json() {
        let client = mock(&[("https://example.com/bad.json", "not json")]);
        let cache = SchemaCache::new(None, client, false, None);
        assert!(cache.fetch("https://example.com/bad.json").await.is_err());
    }

    #[tokio::test]
    async fn async_retrieve_trait_delegates() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let uri: jsonschema::Uri<String> = "https://example.com/s.json".parse()?;
        let val = jsonschema::AsyncRetrieve::retrieve(&cache, &uri)
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        Ok(())
    }

    #[tokio::test]
    async fn fetch_expired_ttl_refetches() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // Backdate the entry so it is older than the 1 s TTL below.
        let two_secs_ago = filetime::FileTime::from_system_time(
            std::time::SystemTime::now() - core::time::Duration::from_secs(2),
        );
        filetime::set_file_mtime(&cache_path, two_secs_ago)?;

        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(1)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_expired_with_matching_etag_revalidates() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<EtagClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;
        fs::write(tmp.path().join(format!("{hash}.etag")), "\"v1\"")?;

        let ten_secs_ago = filetime::FileTime::from_system_time(
            std::time::SystemTime::now() - core::time::Duration::from_secs(10),
        );
        filetime::set_file_mtime(&cache_path, ten_secs_ago)?;

        // A full fetch would return different content; "v1" must yield "not modified".
        let client = EtagClient {
            body: r#"{"type":"string"}"#.to_string(),
            etag: "\"v1\"".to_string(),
        };
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(1)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);

        // Revalidation refreshes the mtime, restarting the TTL.
        let age = fs::metadata(&cache_path)?
            .modified()?
            .elapsed()
            .unwrap_or_default();
        assert!(age < Duration::from_secs(5));
        Ok(())
    }

    #[tokio::test]
    async fn fetch_unexpired_ttl_serves_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // No responses registered: any network access would fail the test.
        let client = mock(&[]);
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(3600)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[test]
    fn ensure_cache_dir_ends_with_schemas() {
        let dir = ensure_cache_dir();
        assert!(dir.ends_with("lintel/schemas"));
    }
}