Skip to main content

lintel_schema_cache/
lib.rs

1use std::collections::HashMap;
2use std::collections::hash_map::DefaultHasher;
3use std::error::Error;
4use std::fs;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9
/// Default time-to-live for on-disk cached schemas: twelve hours.
pub const DEFAULT_SCHEMA_CACHE_TTL: Duration = Duration::from_secs(60 * 60 * 12);
12
13use serde_json::Value;
14
/// Where a schema returned by `SchemaCache::fetch` came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStatus {
    /// Schema was served from a cache: either the in-memory cache (which is
    /// consulted even when no disk cache directory is configured) or the
    /// disk cache.
    Hit,
    /// Schema was fetched from the network and a disk cache directory is
    /// configured (the disk write itself may still have failed).
    Miss,
    /// Schema was fetched from the network and no disk cache directory is
    /// configured (`cache_dir` is `None`).
    Disabled,
}
25
26/// Trait for fetching content over HTTP.
27#[async_trait::async_trait]
28pub trait HttpClient: Clone + Send + Sync + 'static {
29    /// # Errors
30    ///
31    /// Returns an error if the HTTP request fails or the response cannot be read.
32    async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>>;
33}
34
35/// Default HTTP client using reqwest.
36#[derive(Clone)]
37pub struct ReqwestClient(pub reqwest::Client);
38
39impl Default for ReqwestClient {
40    fn default() -> Self {
41        Self(reqwest::Client::new())
42    }
43}
44
45#[async_trait::async_trait]
46impl HttpClient for ReqwestClient {
47    async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
48        let resp = self.0.get(uri).send().await?.error_for_status()?;
49        Ok(resp.text().await?)
50    }
51}
52
53/// A disk-backed cache for JSON Schema files.
54///
55/// Schemas are fetched via HTTP and stored as `<cache_dir>/<hash>.json`
56/// where `<hash>` is a hex-encoded hash of the URI. When a schema is
57/// requested, the cache is checked first; on a miss the schema is fetched
58/// and written to disk for future use.
59#[derive(Clone)]
60pub struct SchemaCache<C: HttpClient = ReqwestClient> {
61    cache_dir: Option<PathBuf>,
62    client: C,
63    skip_read: bool,
64    ttl: Option<Duration>,
65    /// In-memory cache shared across all clones via `Arc`.
66    memory_cache: Arc<Mutex<HashMap<String, Value>>>,
67}
68
69impl<C: HttpClient> SchemaCache<C> {
70    pub fn new(
71        cache_dir: Option<PathBuf>,
72        client: C,
73        skip_read: bool,
74        ttl: Option<Duration>,
75    ) -> Self {
76        Self {
77            cache_dir,
78            client,
79            skip_read,
80            ttl,
81            memory_cache: Arc::new(Mutex::new(HashMap::new())),
82        }
83    }
84
85    /// Fetch a schema by URI, using the disk cache when available.
86    ///
87    /// Returns the parsed schema and a [`CacheStatus`] indicating whether the
88    /// result came from the disk cache, the network, or caching was disabled.
89    ///
90    /// When `skip_read` is set, the cache read is skipped but fetched schemas
91    /// are still written to disk.
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the schema cannot be fetched from the network,
96    /// read from disk cache, or parsed as JSON.
97    #[allow(clippy::missing_panics_doc)] // Mutex poisoning is unreachable
98    #[tracing::instrument(skip(self), fields(status))]
99    pub async fn fetch(
100        &self,
101        uri: &str,
102    ) -> Result<(Value, CacheStatus), Box<dyn Error + Send + Sync>> {
103        // Check in-memory cache first (unless skip_read is set)
104        if !self.skip_read
105            && let Some(value) = self
106                .memory_cache
107                .lock()
108                .expect("memory cache poisoned")
109                .get(uri)
110                .cloned()
111        {
112            tracing::Span::current().record("status", "memory_hit");
113            return Ok((value, CacheStatus::Hit));
114        }
115
116        // Check disk cache (unless skip_read is set)
117        if !self.skip_read
118            && let Some(ref cache_dir) = self.cache_dir
119        {
120            let hash = Self::hash_uri(uri);
121            let cache_path = cache_dir.join(format!("{hash}.json"));
122            if cache_path.exists()
123                && !self.is_expired(&cache_path)
124                && let Ok(content) = tokio::fs::read_to_string(&cache_path).await
125                && let Ok(value) = serde_json::from_str::<Value>(&content)
126            {
127                self.memory_cache
128                    .lock()
129                    .expect("memory cache poisoned")
130                    .insert(uri.to_string(), value.clone());
131                tracing::Span::current().record("status", "cache_hit");
132                return Ok((value, CacheStatus::Hit));
133            }
134        }
135
136        // Fetch from network
137        tracing::Span::current().record("status", "network_fetch");
138        let body = self.client.get(uri).await?;
139        let value: Value = serde_json::from_str(&body)?;
140
141        // Populate in-memory cache
142        self.memory_cache
143            .lock()
144            .expect("memory cache poisoned")
145            .insert(uri.to_string(), value.clone());
146
147        let status = if let Some(ref cache_dir) = self.cache_dir {
148            let hash = Self::hash_uri(uri);
149            let cache_path = cache_dir.join(format!("{hash}.json"));
150            if let Err(e) = tokio::fs::write(&cache_path, &body).await {
151                tracing::warn!(
152                    path = %cache_path.display(),
153                    error = %e,
154                    "failed to write schema to disk cache"
155                );
156            }
157            CacheStatus::Miss
158        } else {
159            CacheStatus::Disabled
160        };
161
162        Ok((value, status))
163    }
164
165    /// Check whether a cached file has exceeded the configured TTL.
166    ///
167    /// Returns `false` (not expired) when:
168    /// - No TTL is configured (`self.ttl` is `None`)
169    /// - The file metadata or mtime cannot be read (graceful degradation)
170    fn is_expired(&self, path: &std::path::Path) -> bool {
171        let Some(ttl) = self.ttl else {
172            return false;
173        };
174        fs::metadata(path)
175            .ok()
176            .and_then(|m| m.modified().ok())
177            .and_then(|mtime| mtime.elapsed().ok())
178            .is_some_and(|age| age > ttl)
179    }
180
181    fn hash_uri(uri: &str) -> String {
182        let mut hasher = DefaultHasher::new();
183        uri.hash(&mut hasher);
184        format!("{:016x}", hasher.finish())
185    }
186}
187
188/// Return a usable cache directory for schemas, creating it if necessary.
189///
190/// Tries `<system_cache>/lintel/schemas` first, falling back to
191/// `<temp_dir>/lintel/schemas` when the preferred path is unwritable.
192pub fn ensure_cache_dir() -> PathBuf {
193    let candidates = [
194        dirs::cache_dir().map(|d| d.join("lintel").join("schemas")),
195        Some(std::env::temp_dir().join("lintel").join("schemas")),
196    ];
197    for candidate in candidates.into_iter().flatten() {
198        if fs::create_dir_all(&candidate).is_ok() {
199            return candidate;
200        }
201    }
202    std::env::temp_dir().join("lintel").join("schemas")
203}
204
205// -- jsonschema trait impls --------------------------------------------------
206
207#[async_trait::async_trait]
208impl<C: HttpClient> jsonschema::AsyncRetrieve for SchemaCache<C> {
209    async fn retrieve(
210        &self,
211        uri: &jsonschema::Uri<String>,
212    ) -> Result<Value, Box<dyn Error + Send + Sync>> {
213        let (value, _status) = self.fetch(uri.as_str()).await?;
214        Ok(value)
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// [`HttpClient`] that serves canned bodies keyed by URI and returns an
    /// error for any URI it does not know.
    #[derive(Clone)]
    struct MockClient(HashMap<String, String>);

    #[async_trait::async_trait]
    impl HttpClient for MockClient {
        async fn get(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
            self.0
                .get(uri)
                .cloned()
                .ok_or_else(|| format!("mock: no response for {uri}").into())
        }
    }

    fn mock(entries: &[(&str, &str)]) -> MockClient {
        MockClient(
            entries
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        )
    }

    #[test]
    fn hash_uri_deterministic() {
        let a = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        let b = SchemaCache::<MockClient>::hash_uri("https://example.com/schema.json");
        assert_eq!(a, b);
    }

    #[test]
    fn hash_uri_different_inputs() {
        let a = SchemaCache::<MockClient>::hash_uri("https://example.com/a.json");
        let b = SchemaCache::<MockClient>::hash_uri("https://example.com/b.json");
        assert_ne!(a, b);
    }

    /// Convert a `Box<dyn Error + Send + Sync>` to `anyhow::Error`.
    #[allow(clippy::needless_pass_by_value)]
    fn boxerr(e: Box<dyn Error + Send + Sync>) -> anyhow::Error {
        anyhow::anyhow!("{e}")
    }

    #[tokio::test]
    async fn fetch_no_cache_dir() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        assert_eq!(status, CacheStatus::Disabled);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_cold_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);

        // Verify file was written to disk
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        assert!(cache_path.exists());
        Ok(())
    }

    #[tokio::test]
    async fn fetch_warm_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // Client has no entries — if it were called, it would error
        let client = mock(&[]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_skip_read_bypasses_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // With skip_read, the cached value is ignored and the client is called
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, true, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_skip_read_bypasses_memory_cache() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, true, None);
        for _ in 0..2 {
            let (_, status) = cache
                .fetch("https://example.com/s.json")
                .await
                .map_err(boxerr)?;
            // Every call goes to the network; the memory cache is never read.
            assert_eq!(status, CacheStatus::Disabled);
        }
        Ok(())
    }

    #[tokio::test]
    async fn fetch_repeat_served_from_memory_across_clones() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let (_, first) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(first, CacheStatus::Disabled);

        // Clones share the in-memory cache, so this lookup is a hit even
        // though no disk cache is configured.
        let clone = cache.clone();
        let (val, second) = clone
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        assert_eq!(second, CacheStatus::Hit);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_corrupt_disk_cache_refetches() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, "not json")?;

        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(Some(tmp.path().to_path_buf()), client, false, None);
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        // The corrupt file is replaced by the freshly fetched body.
        assert_eq!(fs::read_to_string(&cache_path)?, r#"{"type":"string"}"#);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_client_error() {
        let client = mock(&[]);
        let cache = SchemaCache::new(None, client, false, None);
        assert!(
            cache
                .fetch("https://example.com/missing.json")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn fetch_invalid_json() {
        let client = mock(&[("https://example.com/bad.json", "not json")]);
        let cache = SchemaCache::new(None, client, false, None);
        assert!(cache.fetch("https://example.com/bad.json").await.is_err());
    }

    #[tokio::test]
    async fn async_retrieve_trait_delegates() -> anyhow::Result<()> {
        let client = mock(&[("https://example.com/s.json", r#"{"type":"object"}"#)]);
        let cache = SchemaCache::new(None, client, false, None);
        let uri: jsonschema::Uri<String> = "https://example.com/s.json".parse()?;
        let val = jsonschema::AsyncRetrieve::retrieve(&cache, &uri)
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "object"}));
        Ok(())
    }

    #[tokio::test]
    async fn fetch_expired_ttl_refetches() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // Set mtime to 2 seconds ago
        let two_secs_ago = filetime::FileTime::from_system_time(
            std::time::SystemTime::now() - std::time::Duration::from_secs(2),
        );
        filetime::set_file_mtime(&cache_path, two_secs_ago)?;

        // TTL of 1 second — the cached file is stale
        let client = mock(&[("https://example.com/s.json", r#"{"type":"string"}"#)]);
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(1)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "string"}));
        assert_eq!(status, CacheStatus::Miss);
        Ok(())
    }

    #[tokio::test]
    async fn fetch_unexpired_ttl_serves_cache() -> anyhow::Result<()> {
        let tmp = tempfile::tempdir()?;
        let hash = SchemaCache::<MockClient>::hash_uri("https://example.com/s.json");
        let cache_path = tmp.path().join(format!("{hash}.json"));
        fs::write(&cache_path, r#"{"type":"number"}"#)?;

        // TTL of 1 hour — the file was just written, so it's fresh
        let client = mock(&[]);
        let cache = SchemaCache::new(
            Some(tmp.path().to_path_buf()),
            client,
            false,
            Some(Duration::from_secs(3600)),
        );
        let (val, status) = cache
            .fetch("https://example.com/s.json")
            .await
            .map_err(boxerr)?;
        assert_eq!(val, serde_json::json!({"type": "number"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[test]
    fn ensure_cache_dir_ends_with_schemas() {
        let dir = ensure_cache_dir();
        assert!(dir.ends_with("lintel/schemas"));
    }
}