Skip to main content

lintel_schema_cache/
lib.rs

1#![doc = include_str!("../README.md")]
2
3extern crate alloc;
4
5use alloc::sync::Arc;
6use core::error::Error;
7use core::time::Duration;
8use std::collections::HashMap;
9use std::fs;
10use std::path::PathBuf;
11use std::sync::Mutex;
12
13use serde_json::Value;
14use sha2::{Digest, Sha256};
15
/// How long a schema cached on disk is considered fresh by default: 12 hours.
pub const DEFAULT_SCHEMA_CACHE_TTL: Duration = Duration::from_secs(12 * 3_600);
18
/// Where the schema returned by a fetch came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStatus {
    /// Schema was served from a cache: the in-memory cache, a fresh disk
    /// cache entry, or a disk entry the server confirmed with `304 Not Modified`.
    Hit,
    /// Schema was fetched from the network (and possibly written to cache).
    Miss,
    /// Schema was fetched from the network and caching is disabled
    /// (`cache_dir` is `None`).
    Disabled,
}

/// Short human-readable label for each status.
impl core::fmt::Display for CacheStatus {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match *self {
            Self::Hit => "cached",
            Self::Miss => "fetched",
            Self::Disabled => "fetched (no cache)",
        };
        // Written verbatim: width/alignment flags on the formatter are not applied.
        f.write_str(label)
    }
}
39
/// Response from a conditional HTTP request.
///
/// Produced by `SchemaCache::get_conditional`. A `304 Not Modified` reply is
/// represented with both fields set to `None`.
struct ConditionalResponse {
    /// Response body. `None` indicates a 304 Not Modified response.
    body: Option<String>,
    /// `ETag` header from the response, if present and representable as a
    /// string. Always `None` for a 304 response.
    etag: Option<String>,
}
47
/// Internal HTTP backend.
///
/// Selected at construction time: `SchemaCacheBuilder::build` uses
/// `Reqwest`, `SchemaCache::memory` uses `Memory`.
enum HttpMode {
    /// Production mode — uses reqwest for HTTP requests.
    Reqwest(reqwest::Client),
    /// Test mode — no HTTP, no disk. Only serves from memory cache;
    /// fetching a URI that was not inserted beforehand is an error.
    Memory,
}
55
/// A disk-backed schema cache with HTTP fetching and JSON parsing.
///
/// Schemas are fetched via HTTP and stored as `<cache_dir>/<hash>.json`
/// where `<hash>` is a SHA-256 hex digest of the URI. When a schema is
/// requested, the cache is checked first; on a miss the schema is fetched
/// and written to disk for future use.
///
/// The `ETag` of a fetched schema, when the server sends one, is stored next
/// to it as `<cache_dir>/<hash>.etag` and used for conditional requests.
///
/// Cloning is cheap: the HTTP backend, the in-memory maps and the request
/// semaphore are all behind `Arc` and shared between clones.
#[derive(Clone)]
pub struct SchemaCache {
    /// Directory holding the on-disk cache; `None` disables disk caching.
    cache_dir: Option<PathBuf>,
    /// HTTP backend (real client or memory-only test mode).
    http: Arc<HttpMode>,
    /// When `true`, `fetch` skips the in-memory lookup and the fresh-disk
    /// shortcut and always contacts the network.
    skip_read: bool,
    /// Maximum age of a disk cache entry; `None` means entries never expire.
    ttl: Option<Duration>,
    /// In-memory cache shared across all clones via `Arc`.
    memory_cache: Arc<Mutex<HashMap<String, Value>>>,
    /// SHA-256 hex digests of the raw content fetched for each URI.
    content_hashes: Arc<Mutex<HashMap<String, String>>>,
    /// Semaphore that limits concurrent HTTP requests across all callers.
    http_semaphore: Arc<tokio::sync::Semaphore>,
}
75
/// Default maximum number of concurrent HTTP requests.
const DEFAULT_MAX_CONCURRENT_REQUESTS: usize = 20;

/// Builder for constructing a [`SchemaCache`] with sensible defaults.
///
/// Defaults (as set by [`SchemaCache::builder`]):
/// - `cache_dir`: [`ensure_cache_dir()`]
/// - `force_fetch`: `false`
/// - `ttl`: [`DEFAULT_SCHEMA_CACHE_TTL`] (12 hours)
/// - `max_concurrent_requests`: 20
///
/// # Examples
///
/// ```rust,ignore
/// let cache = SchemaCache::builder().build();
/// let cache = SchemaCache::builder().force_fetch(true).ttl(Duration::from_secs(3600)).build();
/// ```
#[derive(Debug, Clone)]
#[must_use]
pub struct SchemaCacheBuilder {
    /// Directory for the on-disk cache; `None` disables disk caching.
    cache_dir: Option<PathBuf>,
    /// Set through `force_fetch`: skip cache reads and always hit the network.
    skip_read: bool,
    /// Maximum age of a disk cache entry; `None` means entries never expire.
    ttl: Option<Duration>,
    /// Upper bound on simultaneous HTTP requests issued by the built cache.
    max_concurrent_requests: usize,
}
99
100impl SchemaCacheBuilder {
101    /// Override the default cache directory.
102    pub fn cache_dir(mut self, dir: PathBuf) -> Self {
103        self.cache_dir = Some(dir);
104        self
105    }
106
107    /// When `true`, bypass cache reads and always fetch from the network.
108    /// Fetched schemas are still written to the cache.
109    pub fn force_fetch(mut self, force: bool) -> Self {
110        self.skip_read = force;
111        self
112    }
113
114    /// Override the default TTL for cached schemas.
115    pub fn ttl(mut self, ttl: Duration) -> Self {
116        self.ttl = Some(ttl);
117        self
118    }
119
120    /// Set the maximum number of concurrent HTTP requests.
121    pub fn max_concurrent_requests(mut self, n: usize) -> Self {
122        self.max_concurrent_requests = n;
123        self
124    }
125
126    /// Returns the cache directory that will be used, or [`ensure_cache_dir()`]
127    /// if none was explicitly set.
128    ///
129    /// Useful when callers need the resolved path before calling [`build`](Self::build).
130    pub fn cache_dir_or_default(&self) -> PathBuf {
131        self.cache_dir.clone().unwrap_or_else(ensure_cache_dir)
132    }
133
134    /// Build the [`SchemaCache`].
135    pub fn build(self) -> SchemaCache {
136        SchemaCache {
137            cache_dir: self.cache_dir,
138            http: Arc::new(HttpMode::Reqwest(reqwest::Client::new())),
139            skip_read: self.skip_read,
140            ttl: self.ttl,
141            memory_cache: Arc::new(Mutex::new(HashMap::new())),
142            content_hashes: Arc::new(Mutex::new(HashMap::new())),
143            http_semaphore: Arc::new(tokio::sync::Semaphore::new(self.max_concurrent_requests)),
144        }
145    }
146}
147
148impl SchemaCache {
149    /// Returns a builder pre-configured with sensible defaults.
150    ///
151    /// - `cache_dir` = [`ensure_cache_dir()`]
152    /// - `ttl` = [`DEFAULT_SCHEMA_CACHE_TTL`]
153    /// - `force_fetch` = `false`
154    pub fn builder() -> SchemaCacheBuilder {
155        SchemaCacheBuilder {
156            cache_dir: Some(ensure_cache_dir()),
157            skip_read: false,
158            ttl: Some(DEFAULT_SCHEMA_CACHE_TTL),
159            max_concurrent_requests: DEFAULT_MAX_CONCURRENT_REQUESTS,
160        }
161    }
162
163    /// Test constructor — memory-only, no HTTP, no disk.
164    ///
165    /// Pre-populate with [`insert`](Self::insert). Calls to [`fetch`](Self::fetch)
166    /// for unknown URIs will error.
167    pub fn memory() -> Self {
168        Self {
169            cache_dir: None,
170            http: Arc::new(HttpMode::Memory),
171            skip_read: false,
172            ttl: None,
173            memory_cache: Arc::new(Mutex::new(HashMap::new())),
174            content_hashes: Arc::new(Mutex::new(HashMap::new())),
175            http_semaphore: Arc::new(tokio::sync::Semaphore::new(DEFAULT_MAX_CONCURRENT_REQUESTS)),
176        }
177    }
178
179    /// Insert a value into the in-memory cache (useful for tests).
180    #[allow(clippy::missing_panics_doc)] // Mutex poisoning is unreachable
181    pub fn insert(&self, uri: &str, value: Value) {
182        self.memory_cache
183            .lock()
184            .expect("memory cache poisoned")
185            .insert(uri.to_string(), value);
186    }
187
188    /// Look up a schema by URI from the in-memory cache only.
189    ///
190    /// Returns `None` if the URI is not in memory. Does not check disk cache
191    /// or fetch from the network.
192    #[allow(clippy::missing_panics_doc)] // Mutex poisoning is unreachable
193    pub fn get(&self, uri: &str) -> Option<Value> {
194        self.memory_cache
195            .lock()
196            .expect("memory cache poisoned")
197            .get(uri)
198            .cloned()
199    }
200
201    /// Return the SHA-256 hex digest of the raw content last fetched for `uri`.
202    ///
203    /// Returns `None` if the URI has not been fetched or was inserted via
204    /// [`insert`](Self::insert) (which has no raw content to hash).
205    #[allow(clippy::missing_panics_doc)] // Mutex poisoning is unreachable
206    pub fn content_hash(&self, uri: &str) -> Option<String> {
207        self.content_hashes
208            .lock()
209            .expect("content hashes poisoned")
210            .get(uri)
211            .cloned()
212    }
213
214    /// Compute SHA-256 of raw content and store it keyed by URI.
215    fn store_content_hash(&self, uri: &str, content: &str) {
216        let hash = Self::hash_content(content);
217        self.content_hashes
218            .lock()
219            .expect("content hashes poisoned")
220            .insert(uri.to_string(), hash);
221    }
222
223    /// Compute the SHA-256 hash of arbitrary content, returned as a 64-char hex string.
224    pub fn hash_content(content: &str) -> String {
225        let mut hasher = Sha256::new();
226        hasher.update(content.as_bytes());
227        format!("{:x}", hasher.finalize())
228    }
229
230    /// Fetch a schema by URI, using the disk cache when available.
231    ///
232    /// Returns the parsed schema and a [`CacheStatus`] indicating whether the
233    /// result came from the disk cache, the network, or caching was disabled.
234    ///
235    /// When `skip_read` is set, the cache read is skipped but fetched schemas
236    /// are still written to disk.
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if the schema cannot be fetched from the network,
241    /// read from disk cache, or parsed as JSON.
242    #[allow(clippy::missing_panics_doc)] // Mutex poisoning is unreachable
243    #[tracing::instrument(level = "debug", skip(self), fields(status))]
244    pub async fn fetch(
245        &self,
246        uri: &str,
247    ) -> Result<(Value, CacheStatus), Box<dyn Error + Send + Sync>> {
248        // Check in-memory cache first (unless skip_read is set)
249        if !self.skip_read
250            && let Some(value) = self
251                .memory_cache
252                .lock()
253                .expect("memory cache poisoned")
254                .get(uri)
255                .cloned()
256        {
257            tracing::Span::current().record("status", "memory_hit");
258            return Ok((value, CacheStatus::Hit));
259        }
260
261        // Memory-only mode: if not in cache, error out.
262        if matches!(*self.http, HttpMode::Memory) {
263            return Err(format!("memory-only cache: no entry for {uri}").into());
264        }
265
266        // Check disk cache (unless skip_read is set)
267        let mut stored_etag: Option<String> = None;
268        let mut cached_content: Option<String> = None;
269
270        if let Some(ref cache_dir) = self.cache_dir {
271            let hash = Self::hash_uri(uri);
272            let cache_path = cache_dir.join(format!("{hash}.json"));
273            let etag_path = cache_dir.join(format!("{hash}.etag"));
274
275            if cache_path.exists() {
276                if !self.skip_read && !self.is_expired(&cache_path) {
277                    // Fresh cache — return immediately
278                    if let Ok(content) = tokio::fs::read_to_string(&cache_path).await
279                        && let Ok(value) = serde_json::from_str::<Value>(&content)
280                    {
281                        self.store_content_hash(uri, &content);
282                        self.memory_cache
283                            .lock()
284                            .expect("memory cache poisoned")
285                            .insert(uri.to_string(), value.clone());
286                        tracing::Span::current().record("status", "cache_hit");
287                        return Ok((value, CacheStatus::Hit));
288                    }
289                }
290
291                // Stale or skip_read — read ETag for conditional fetch
292                if let Ok(etag) = tokio::fs::read_to_string(&etag_path).await {
293                    stored_etag = Some(etag);
294                }
295                // Keep cached content for 304 fallback
296                if let Ok(content) = tokio::fs::read_to_string(&cache_path).await {
297                    cached_content = Some(content);
298                }
299            }
300        }
301
302        // Acquire a permit before making the HTTP request
303        let _permit = self
304            .http_semaphore
305            .acquire()
306            .await
307            .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)?;
308
309        // Conditional network fetch
310        tracing::Span::current().record("status", "network_fetch");
311        let conditional = self.get_conditional(uri, stored_etag.as_deref()).await?;
312
313        if conditional.body.is_none() {
314            // 304 Not Modified — use cached content
315            if let Some(content) = cached_content {
316                let value: Value = serde_json::from_str(&content)?;
317                self.store_content_hash(uri, &content);
318                self.memory_cache
319                    .lock()
320                    .expect("memory cache poisoned")
321                    .insert(uri.to_string(), value.clone());
322
323                // Touch the cache file to reset TTL
324                if let Some(ref cache_dir) = self.cache_dir {
325                    let hash = Self::hash_uri(uri);
326                    let cache_path = cache_dir.join(format!("{hash}.json"));
327                    let now = filetime::FileTime::now();
328                    let _ = filetime::set_file_mtime(&cache_path, now);
329                }
330
331                tracing::Span::current().record("status", "etag_hit");
332                return Ok((value, CacheStatus::Hit));
333            }
334        }
335
336        let body = conditional.body.expect("non-304 response must have a body");
337        let value: Value = serde_json::from_str(&body)?;
338        self.store_content_hash(uri, &body);
339
340        // Populate in-memory cache
341        self.memory_cache
342            .lock()
343            .expect("memory cache poisoned")
344            .insert(uri.to_string(), value.clone());
345
346        let status = if let Some(ref cache_dir) = self.cache_dir {
347            let hash = Self::hash_uri(uri);
348            let cache_path = cache_dir.join(format!("{hash}.json"));
349            let etag_path = cache_dir.join(format!("{hash}.etag"));
350            if let Err(e) = tokio::fs::write(&cache_path, &body).await {
351                tracing::warn!(
352                    path = %cache_path.display(),
353                    error = %e,
354                    "failed to write schema to disk cache"
355                );
356            }
357            // Write ETag if present
358            if let Some(etag) = conditional.etag {
359                let _ = tokio::fs::write(&etag_path, &etag).await;
360            }
361            CacheStatus::Miss
362        } else {
363            CacheStatus::Disabled
364        };
365
366        Ok((value, status))
367    }
368
369    /// Check whether a cached file has exceeded the configured TTL.
370    ///
371    /// Returns `false` (not expired) when:
372    /// - No TTL is configured (`self.ttl` is `None`)
373    /// - The file metadata or mtime cannot be read (graceful degradation)
374    fn is_expired(&self, path: &std::path::Path) -> bool {
375        let Some(ttl) = self.ttl else {
376            return false;
377        };
378        fs::metadata(path)
379            .ok()
380            .and_then(|m| m.modified().ok())
381            .and_then(|mtime| mtime.elapsed().ok())
382            .is_some_and(|age| age > ttl)
383    }
384
385    /// Compute the SHA-256 hash of a URI, returned as a 64-char hex string.
386    pub fn hash_uri(uri: &str) -> String {
387        let mut hasher = Sha256::new();
388        hasher.update(uri.as_bytes());
389        format!("{:x}", hasher.finalize())
390    }
391
392    /// Internal: perform a conditional GET using reqwest.
393    async fn get_conditional(
394        &self,
395        uri: &str,
396        etag: Option<&str>,
397    ) -> Result<ConditionalResponse, Box<dyn Error + Send + Sync>> {
398        let HttpMode::Reqwest(ref client) = *self.http else {
399            return Err("HTTP not available in memory-only mode".into());
400        };
401
402        let mut req = client.get(uri);
403        if let Some(etag) = etag {
404            req = req.header("If-None-Match", etag);
405        }
406        let resp = req.send().await?;
407        if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
408            return Ok(ConditionalResponse {
409                body: None,
410                etag: None,
411            });
412        }
413        let resp = resp.error_for_status()?;
414        let etag = resp
415            .headers()
416            .get("etag")
417            .and_then(|v| v.to_str().ok())
418            .map(String::from);
419        let body = resp.text().await?;
420        Ok(ConditionalResponse {
421            body: Some(body),
422            etag,
423        })
424    }
425}
426
427/// Return a usable cache directory for schemas, creating it if necessary.
428///
429/// Tries `<system_cache>/lintel/schemas` first, falling back to
430/// `<temp_dir>/lintel/schemas` when the preferred path is unwritable.
431pub fn ensure_cache_dir() -> PathBuf {
432    let candidates = [
433        dirs::cache_dir().map(|d| d.join("lintel").join("schemas")),
434        Some(std::env::temp_dir().join("lintel").join("schemas")),
435    ];
436    for candidate in candidates.into_iter().flatten() {
437        if fs::create_dir_all(&candidate).is_ok() {
438            return candidate;
439        }
440    }
441    std::env::temp_dir().join("lintel").join("schemas")
442}
443
444// -- jsonschema trait impls --------------------------------------------------
445
446#[async_trait::async_trait]
447impl jsonschema::AsyncRetrieve for SchemaCache {
448    async fn retrieve(
449        &self,
450        uri: &jsonschema::Uri<String>,
451    ) -> Result<Value, Box<dyn Error + Send + Sync>> {
452        let (value, _status) = self.fetch(uri.as_str()).await?;
453        Ok(value)
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// URI used by the tests that pre-populate the memory cache.
    const SEEDED_URI: &str = "https://example.com/s.json";

    /// Convert a `Box<dyn Error + Send + Sync>` to `anyhow::Error`.
    #[allow(clippy::needless_pass_by_value)]
    fn boxerr(e: Box<dyn Error + Send + Sync>) -> anyhow::Error {
        anyhow::anyhow!("{e}")
    }

    /// Memory-only cache holding one object schema under [`SEEDED_URI`].
    fn seeded_cache() -> SchemaCache {
        let cache = SchemaCache::memory();
        cache.insert(SEEDED_URI, serde_json::json!({"type": "object"}));
        cache
    }

    #[test]
    fn hash_uri_deterministic() {
        let uri = "https://example.com/schema.json";
        assert_eq!(SchemaCache::hash_uri(uri), SchemaCache::hash_uri(uri));
    }

    #[test]
    fn hash_uri_different_inputs() {
        let first = SchemaCache::hash_uri("https://example.com/a.json");
        let second = SchemaCache::hash_uri("https://example.com/b.json");
        assert_ne!(first, second);
    }

    #[test]
    fn hash_uri_is_64_hex_chars() {
        let digest = SchemaCache::hash_uri("https://example.com/schema.json");
        assert_eq!(digest.len(), 64);
        assert!(digest.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[tokio::test]
    async fn memory_cache_insert_and_fetch() -> anyhow::Result<()> {
        let cache = seeded_cache();
        let (schema, status) = cache.fetch(SEEDED_URI).await.map_err(boxerr)?;
        assert_eq!(schema, serde_json::json!({"type": "object"}));
        assert_eq!(status, CacheStatus::Hit);
        Ok(())
    }

    #[tokio::test]
    async fn memory_cache_missing_uri_errors() {
        let cache = SchemaCache::memory();
        let outcome = cache.fetch("https://example.com/missing.json").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn async_retrieve_trait_delegates() -> anyhow::Result<()> {
        let cache = seeded_cache();
        let uri: jsonschema::Uri<String> = SEEDED_URI.parse()?;
        let schema = jsonschema::AsyncRetrieve::retrieve(&cache, &uri)
            .await
            .map_err(boxerr)?;
        assert_eq!(schema, serde_json::json!({"type": "object"}));
        Ok(())
    }

    #[test]
    fn ensure_cache_dir_ends_with_schemas() {
        let resolved = ensure_cache_dir();
        assert!(resolved.ends_with("lintel/schemas"));
    }
}