Skip to main content

lintel_schema_cache/
lib.rs

1#![doc = include_str!("../README.md")]
2
3extern crate alloc;
4
5use alloc::sync::Arc;
6use core::error::Error;
7use core::hash::{Hash, Hasher};
8use core::time::Duration;
9use std::collections::HashMap;
10use std::collections::hash_map::DefaultHasher;
11use std::fs;
12use std::path::PathBuf;
13use std::sync::Mutex;
14
/// Time-to-live applied to cached schemas by default: twelve hours.
pub const DEFAULT_SCHEMA_CACHE_TTL: Duration = Duration::from_secs(12 * 3600);
17
18use serde_json::Value;
19
/// Describes where the result of a schema fetch came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStatus {
    /// The schema was served from a cached copy instead of a fresh download.
    Hit,
    /// The schema was downloaded from the network (and possibly stored in the
    /// disk cache afterwards).
    Miss,
    /// Disk caching is turned off because no `cache_dir` was configured.
    Disabled,
}
30
/// Outcome of a conditional HTTP request (one that may carry `If-None-Match`).
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so responses can be
/// logged, duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConditionalResponse {
    /// Response body. `None` indicates a 304 Not Modified response, i.e. the
    /// caller's cached copy is still current.
    pub body: Option<String>,
    /// `ETag` header from the response, if present.
    pub etag: Option<String>,
}
38
39/// Trait for fetching content over HTTP.
40#[async_trait::async_trait]
41pub trait HttpClient: Clone + Send + Sync + 'static {
42    /// # Errors
43    ///
44    /// Returns an error if the HTTP request fails or the response cannot be read.
45    async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>>;
46
47    /// Fetch with conditional request support.
48    ///
49    /// If `etag` is `Some`, sends an `If-None-Match` header. Returns
50    /// `body: None` on 304 Not Modified.
51    ///
52    /// The default implementation ignores the `ETag` and delegates to [`get`](Self::get).
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if the HTTP request fails or the response cannot be read.
57    async fn get_conditional(
58        &self,
59        uri: &str,
60        _etag: Option<&str>,
61    ) -> Result<ConditionalResponse, Box<dyn Error + Send + Sync>> {
62        let body = self.get(uri).await?;
63        Ok(ConditionalResponse {
64            body: Some(body),
65            etag: None,
66        })
67    }
68}
69
70/// Default HTTP client using reqwest.
71#[derive(Clone)]
72pub struct ReqwestClient(pub reqwest::Client);
73
74impl Default for ReqwestClient {
75    fn default() -> Self {
76        Self(reqwest::Client::new())
77    }
78}
79
80#[async_trait::async_trait]
81impl HttpClient for ReqwestClient {
82    async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
83        let resp = self.0.get(uri).send().await?.error_for_status()?;
84        Ok(resp.text().await?)
85    }
86
87    async fn get_conditional(
88        &self,
89        uri: &str,
90        etag: Option<&str>,
91    ) -> Result<ConditionalResponse, Box<dyn Error + Send + Sync>> {
92        let mut req = self.0.get(uri);
93        if let Some(etag) = etag {
94            req = req.header("If-None-Match", etag);
95        }
96        let resp = req.send().await?;
97        if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
98            return Ok(ConditionalResponse {
99                body: None,
100                etag: None,
101            });
102        }
103        let resp = resp.error_for_status()?;
104        let etag = resp
105            .headers()
106            .get("etag")
107            .and_then(|v| v.to_str().ok())
108            .map(String::from);
109        let body = resp.text().await?;
110        Ok(ConditionalResponse {
111            body: Some(body),
112            etag,
113        })
114    }
115}
116
117/// A disk-backed cache for JSON Schema files.
118///
119/// Schemas are fetched via HTTP and stored as `<cache_dir>/<hash>.json`
120/// where `<hash>` is a hex-encoded hash of the URI. When a schema is
121/// requested, the cache is checked first; on a miss the schema is fetched
122/// and written to disk for future use.
123#[derive(Clone)]
124pub struct SchemaCache<C: HttpClient = ReqwestClient> {
125    cache_dir: Option<PathBuf>,
126    client: C,
127    skip_read: bool,
128    ttl: Option<Duration>,
129    /// In-memory cache shared across all clones via `Arc`.
130    memory_cache: Arc<Mutex<HashMap<String, Value>>>,
131}
132
133impl<C: HttpClient> SchemaCache<C> {
134    pub fn new(
135        cache_dir: Option<PathBuf>,
136        client: C,
137        skip_read: bool,
138        ttl: Option<Duration>,
139    ) -> Self {
140        Self {
141            cache_dir,
142            client,
143            skip_read,
144            ttl,
145            memory_cache: Arc::new(Mutex::new(HashMap::new())),
146        }
147    }
148
149    /// Fetch a schema by URI, using the disk cache when available.
150    ///
151    /// Returns the parsed schema and a [`CacheStatus`] indicating whether the
152    /// result came from the disk cache, the network, or caching was disabled.
153    ///
154    /// When `skip_read` is set, the cache read is skipped but fetched schemas
155    /// are still written to disk.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if the schema cannot be fetched from the network,
160    /// read from disk cache, or parsed as JSON.
161    #[allow(clippy::missing_panics_doc)] // Mutex poisoning is unreachable
162    #[tracing::instrument(skip(self), fields(status))]
163    pub async fn fetch(
164        &self,
165        uri: &str,
166    ) -> Result<(Value, CacheStatus), Box<dyn Error + Send + Sync>> {
167        // Check in-memory cache first (unless skip_read is set)
168        if !self.skip_read
169            && let Some(value) = self
170                .memory_cache
171                .lock()
172                .expect("memory cache poisoned")
173                .get(uri)
174                .cloned()
175        {
176            tracing::Span::current().record("status", "memory_hit");
177            return Ok((value, CacheStatus::Hit));
178        }
179
180        // Check disk cache (unless skip_read is set)
181        let mut stored_etag: Option<String> = None;
182        let mut cached_content: Option<String> = None;
183
184        if let Some(ref cache_dir) = self.cache_dir {
185            let hash = Self::hash_uri(uri);
186            let cache_path = cache_dir.join(format!("{hash}.json"));
187            let etag_path = cache_dir.join(format!("{hash}.etag"));
188
189            if cache_path.exists() {
190                if !self.skip_read && !self.is_expired(&cache_path) {
191                    // Fresh cache — return immediately
192                    if let Ok(content) = tokio::fs::read_to_string(&cache_path).await
193                        && let Ok(value) = serde_json::from_str::<Value>(&content)
194                    {
195                        self.memory_cache
196                            .lock()
197                            .expect("memory cache poisoned")
198                            .insert(uri.to_string(), value.clone());
199                        tracing::Span::current().record("status", "cache_hit");
200                        return Ok((value, CacheStatus::Hit));
201                    }
202                }
203
204                // Stale or skip_read — read ETag for conditional fetch
205                if let Ok(etag) = tokio::fs::read_to_string(&etag_path).await {
206                    stored_etag = Some(etag);
207                }
208                // Keep cached content for 304 fallback
209                if let Ok(content) = tokio::fs::read_to_string(&cache_path).await {
210                    cached_content = Some(content);
211                }
212            }
213        }
214
215        // Conditional network fetch
216        tracing::Span::current().record("status", "network_fetch");
217        let conditional = self
218            .client
219            .get_conditional(uri, stored_etag.as_deref())
220            .await?;
221
222        if conditional.body.is_none() {
223            // 304 Not Modified — use cached content
224            if let Some(content) = cached_content {
225                let value: Value = serde_json::from_str(&content)?;
226                self.memory_cache
227                    .lock()
228                    .expect("memory cache poisoned")
229                    .insert(uri.to_string(), value.clone());
230
231                // Touch the cache file to reset TTL
232                if let Some(ref cache_dir) = self.cache_dir {
233                    let hash = Self::hash_uri(uri);
234                    let cache_path = cache_dir.join(format!("{hash}.json"));
235                    let now = filetime::FileTime::now();
236                    let _ = filetime::set_file_mtime(&cache_path, now);
237                }
238
239                tracing::Span::current().record("status", "etag_hit");
240                return Ok((value, CacheStatus::Hit));
241            }
242        }
243
244        let body = conditional.body.expect("non-304 response must have a body");
245        let value: Value = serde_json::from_str(&body)?;
246
247        // Populate in-memory cache
248        self.memory_cache
249            .lock()
250            .expect("memory cache poisoned")
251            .insert(uri.to_string(), value.clone());
252
253        let status = if let Some(ref cache_dir) = self.cache_dir {
254            let hash = Self::hash_uri(uri);
255            let cache_path = cache_dir.join(format!("{hash}.json"));
256            let etag_path = cache_dir.join(format!("{hash}.etag"));
257            if let Err(e) = tokio::fs::write(&cache_path, &body).await {
258                tracing::warn!(
259                    path = %cache_path.display(),
260                    error = %e,
261                    "failed to write schema to disk cache"
262                );
263            }
264            // Write ETag if present
265            if let Some(etag) = conditional.etag {
266                let _ = tokio::fs::write(&etag_path, &etag).await;
267            }
268            CacheStatus::Miss
269        } else {
270            CacheStatus::Disabled
271        };
272
273        Ok((value, status))
274    }
275
276    /// Check whether a cached file has exceeded the configured TTL.
277    ///
278    /// Returns `false` (not expired) when:
279    /// - No TTL is configured (`self.ttl` is `None`)
280    /// - The file metadata or mtime cannot be read (graceful degradation)
281    fn is_expired(&self, path: &std::path::Path) -> bool {
282        let Some(ttl) = self.ttl else {
283            return false;
284        };
285        fs::metadata(path)
286            .ok()
287            .and_then(|m| m.modified().ok())
288            .and_then(|mtime| mtime.elapsed().ok())
289            .is_some_and(|age| age > ttl)
290    }
291
292    pub fn hash_uri(uri: &str) -> String {
293        let mut hasher = DefaultHasher::new();
294        uri.hash(&mut hasher);
295        format!("{:016x}", hasher.finish())
296    }
297}
298
299/// Return a usable cache directory for schemas, creating it if necessary.
300///
301/// Tries `<system_cache>/lintel/schemas` first, falling back to
302/// `<temp_dir>/lintel/schemas` when the preferred path is unwritable.
303pub fn ensure_cache_dir() -> PathBuf {
304    let candidates = [
305        dirs::cache_dir().map(|d| d.join("lintel").join("schemas")),
306        Some(std::env::temp_dir().join("lintel").join("schemas")),
307    ];
308    for candidate in candidates.into_iter().flatten() {
309        if fs::create_dir_all(&candidate).is_ok() {
310            return candidate;
311        }
312    }
313    std::env::temp_dir().join("lintel").join("schemas")
314}
315
316// -- jsonschema trait impls --------------------------------------------------
317
318#[async_trait::async_trait]
319impl<C: HttpClient> jsonschema::AsyncRetrieve for SchemaCache<C> {
320    async fn retrieve(
321        &self,
322        uri: &jsonschema::Uri<String>,
323    ) -> Result<Value, Box<dyn Error + Send + Sync>> {
324        let (value, _status) = self.fetch(uri.as_str()).await?;
325        Ok(value)
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// URI used by every test that works with a single schema.
    const SCHEMA_URI: &str = "https://example.com/s.json";

    /// HTTP client stub serving canned bodies keyed by URI.
    #[derive(Clone)]
    struct MockClient(HashMap<String, String>);

    #[async_trait::async_trait]
    impl HttpClient for MockClient {
        async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
            match self.0.get(uri) {
                Some(body) => Ok(body.clone()),
                None => Err(format!("mock: no response for {uri}").into()),
            }
        }
    }

    /// Build a [`MockClient`] from `(uri, body)` pairs.
    fn mock(entries: &[(&str, &str)]) -> MockClient {
        let mut responses = HashMap::with_capacity(entries.len());
        for (uri, body) in entries {
            responses.insert((*uri).to_string(), (*body).to_string());
        }
        MockClient(responses)
    }

    /// Convert a `Box<dyn Error + Send + Sync>` to `anyhow::Error`.
    #[allow(clippy::needless_pass_by_value)]
    fn boxerr(e: Box<dyn Error + Send + Sync>) -> anyhow::Error {
        anyhow::anyhow!("{e}")
    }

    /// Run `fetch` and convert its error into `anyhow::Error`.
    async fn fetch_ok(
        cache: &SchemaCache<MockClient>,
        uri: &str,
    ) -> anyhow::Result<(Value, CacheStatus)> {
        cache.fetch(uri).await.map_err(boxerr)
    }

    /// Write `body` as the on-disk cache entry for `uri` inside `dir` and
    /// return the path of the file.
    fn seed_disk_cache(
        dir: &std::path::Path,
        uri: &str,
        body: &str,
    ) -> std::io::Result<PathBuf> {
        let hash = SchemaCache::<MockClient>::hash_uri(uri);
        let path = dir.join(format!("{hash}.json"));
        fs::write(&path, body)?;
        Ok(path)
    }

    #[test]
    fn hash_uri_deterministic() {
        let first = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        let second = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        assert_eq!(first, second);
    }

    #[test]
    fn hash_uri_different_inputs() {
        let hash_a = SchemaCache::<MockClient>::hash_uri("https://example.com/a.json");
        let hash_b = SchemaCache::<MockClient>::hash_uri("https://example.com/b.json");
        assert_ne!(hash_a, hash_b);
    }

    #[tokio::test]
    async fn fetch_no_cache_dir() -> anyhow::Result<()> {
        let cache = SchemaCache::new(
            None,
            mock(&[(SCHEMA_URI, r#"{"type":"object"}"#)]),
            false,
            None,
        );

        let (schema, status) = fetch_ok(&cache, SCHEMA_URI).await?;

        assert_eq!(schema, serde_json::json!({"type": "object"}));
        assert_eq!(status, CacheStatus::Disabled);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_cold_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            mock(&[(SCHEMA_URI, r#"{"type":"string"}"#)]),
            false,
            None,
        );

        let (schema, status) = fetch_ok(&cache, SCHEMA_URI).await?;

        assert_eq!(schema, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);

        // The downloaded schema must now exist on disk.
        let hash = SchemaCache::<MockClient>::hash_uri(SCHEMA_URI);
        assert!(tmp.path().join(format!("{hash}.json")).exists());
        Ok(())
    }

    #[tokio::test]
    async fn fetch_warm_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        seed_disk_cache(tmp.path(), SCHEMA_URI, r#"{"type":"number"}"#)?;

        // The client knows no URIs, so any network call would fail the test.
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), mock(&[]), false, None);

        let (schema, status) = fetch_ok(&cache, SCHEMA_URI).await?;

        assert_eq!(schema, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_skip_read_bypasses_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        seed_disk_cache(tmp.path(), SCHEMA_URI, r#"{"type":"number"}"#)?;

        // skip_read = true: the seeded entry is ignored and the client is used.
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            mock(&[(SCHEMA_URI, r#"{"type":"string"}"#)]),
            true,
            None,
        );

        let (schema, status) = fetch_ok(&cache, SCHEMA_URI).await?;

        assert_eq!(schema, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_client_error() {
        let cache = SchemaCache::new(None, mock(&[]), false, None);
        let outcome = cache.fetch("https://example.com/missing.json").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn fetch_invalid_json() {
        let cache = SchemaCache::new(
            None,
            mock(&[("https://example.com/bad.json", "not json")]),
            false,
            None,
        );
        let outcome = cache.fetch("https://example.com/bad.json").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn async_retrieve_trait_delegates() -> anyhow::Result<()> {
        let cache = SchemaCache::new(
            None,
            mock(&[(SCHEMA_URI, r#"{"type":"object"}"#)]),
            false,
            None,
        );
        let uri: jsonschema::Uri<String> = "https://example.com/s.json".parse()?;

        let schema = jsonschema::AsyncRetrieve::retrieve(&cache, &uri)
            .await
            .map_err(boxerr)?;

        assert_eq!(schema, serde_json::json!({"type": "object"}));
        Ok(())
    }

    #[tokio::test]
    async fn fetch_expired_ttl_refetches() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let cache_path = seed_disk_cache(tmp.path(), SCHEMA_URI, r#"{"type":"number"}"#)?;

        // Age the entry: move its mtime two seconds into the past.
        let aged = std::time::SystemTime::now() - Duration::from_secs(2);
        filetime::set_file_mtime(&cache_path, filetime::FileTime::from_system_time(aged))?;

        // With a one-second TTL the aged entry is stale and must be refetched.
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            mock(&[(SCHEMA_URI, r#"{"type":"string"}"#)]),
            false,
            Some(Duration::from_secs(1)),
        );

        let (schema, status) = fetch_ok(&cache, SCHEMA_URI).await?;

        assert_eq!(schema, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_unexpired_ttl_serves_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        seed_disk_cache(tmp.path(), SCHEMA_URI, r#"{"type":"number"}"#)?;

        // The entry was written just now, far inside the one-hour TTL.
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            mock(&[]),
            false,
            Some(Duration::from_secs(3600)),
        );

        let (schema, status) = fetch_ok(&cache, SCHEMA_URI).await?;

        assert_eq!(schema, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[test]
    fn ensure_cache_dir_ends_with_schemas() {
        let chosen = ensure_cache_dir();
        assert!(chosen.ends_with("lintel/schemas"));
    }
}