Skip to main content

hyperi_rustlib/cache/
mod.rs

1// Project:   hyperi-rustlib
2// File:      src/cache/mod.rs
3// Purpose:   In-memory cache with per-source TTL, metrics, and invalidation
4// Language:  Rust
5//
6// License:   FSL-1.1-ALv2
7// Copyright: (c) 2026 HYPERI PTY LIMITED
8
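//! In-memory cache with per-source TTL configuration, metrics, and invalidation.
//!
//! Wraps [`moka`] to provide a concurrent, async-friendly cache with
//! TinyLFU eviction. Mirrors hyperi-pylib's cache module API:
//! per-source TTL configuration and `get`/`set`/`invalidate_source`.
//! Note that per-source TTLs are read from config but not yet enforced.
//! Every entry currently expires after the cache-wide `default_ttl_secs`
//! (see [`Cache::set`]).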
14//!
15//! # Config Cascade
16//!
17//! ```yaml
18//! cache:
19//!   max_capacity: 10000
20//!   default_ttl_secs: 3600
21//!   source_ttls:
22//!     http: 86400
23//!     db: 1800
24//!     search: 3600
25//! ```
26//!
27//! # Usage
28//!
29//! ```rust,no_run
30//! use hyperi_rustlib::cache::{Cache, CacheConfig};
31//!
32//! #[tokio::main]
33//! async fn main() {
34//!     let cache = Cache::new(CacheConfig::default());
35//!
//!     // Set (entries currently use the cache-wide TTL)
37//!     cache.set("http", "https://api.example.com", "response_data").await.expect("cache set");
38//!
39//!     // Get
40//!     if let Some(value) = cache.get::<String>("http", "https://api.example.com").await {
41//!         println!("cached: {value}");
42//!     }
43//!
44//!     // Invalidate all entries for a source
45//!     cache.invalidate_source("http").await;
46//! }
47//! ```
48
49pub mod config;
50
51pub use config::CacheConfig;
52
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::Duration;

use moka::future::Cache as MokaCache;
58
59/// In-memory cache with per-source TTL and source-aware keys.
60pub struct Cache {
61    inner: MokaCache<String, Arc<Vec<u8>>>,
62    config: CacheConfig,
63    /// Track keys per source for invalidation.
64    source_keys: Mutex<HashMap<String, Vec<String>>>,
65}
66
67impl Cache {
68    /// Create a new cache with the given config.
69    #[must_use]
70    pub fn new(config: CacheConfig) -> Self {
71        let inner = MokaCache::builder()
72            .max_capacity(config.max_capacity)
73            .time_to_live(Duration::from_secs(config.default_ttl_secs))
74            .build();
75
76        Self {
77            inner,
78            config,
79            source_keys: Mutex::new(HashMap::new()),
80        }
81    }
82
83    /// Create a cache from the config cascade (or defaults).
84    #[must_use]
85    pub fn from_cascade() -> Self {
86        Self::new(CacheConfig::from_cascade())
87    }
88
89    /// Get a cached value by source and key.
90    ///
91    /// Returns `None` if not found or expired.
92    pub async fn get<T: serde::de::DeserializeOwned>(&self, source: &str, key: &str) -> Option<T> {
93        let full_key = format!("{source}:{key}");
94        let bytes = self.inner.get(&full_key).await;
95
96        #[cfg(feature = "metrics")]
97        if bytes.is_some() {
98            metrics::counter!("dfe_cache_hits_total", "source" => source.to_string()).increment(1);
99        } else {
100            metrics::counter!("dfe_cache_misses_total", "source" => source.to_string())
101                .increment(1);
102        }
103
104        let bytes = bytes?;
105        serde_json::from_slice(&bytes).ok()
106    }
107
108    /// Set a cached value.
109    ///
110    /// **TTL note:** all entries currently use the cache-wide TTL set at
111    /// construction time. Per-source TTL configuration is read from
112    /// [`CacheConfig::source_ttls`] but is not yet honoured for inserts
113    /// (moka's API requires per-entry expiration to go through its
114    /// `Expiry` trait, which would force a cache rebuild on every config
115    /// reload). The `ttl_for_source` helper remains for callers that want
116    /// to inspect the configured policy.
117    ///
118    /// # Errors
119    ///
120    /// Returns the underlying [`serde_json::Error`] when serialisation of
121    /// `value` fails. Previously this dropped the error silently; callers
122    /// had no way to know their cache write was a no-op.
123    pub async fn set<T: serde::Serialize>(
124        &self,
125        source: &str,
126        key: &str,
127        value: T,
128    ) -> Result<(), serde_json::Error> {
129        let full_key = format!("{source}:{key}");
130        let bytes = Arc::new(serde_json::to_vec(&value)?);
131
132        self.inner.insert(full_key.clone(), bytes).await;
133
134        #[cfg(feature = "metrics")]
135        metrics::gauge!("dfe_cache_entries").set(self.inner.entry_count() as f64);
136
137        // Track key for source-level invalidation
138        if let Ok(mut keys) = self.source_keys.lock() {
139            keys.entry(source.to_string()).or_default().push(full_key);
140        }
141
142        Ok(())
143    }
144
145    /// Invalidate all cached entries for a source.
146    pub async fn invalidate_source(&self, source: &str) {
147        let keys = {
148            let Ok(mut source_keys) = self.source_keys.lock() else {
149                return;
150            };
151            source_keys.remove(source).unwrap_or_default()
152        };
153
154        for key in keys {
155            self.inner.invalidate(&key).await;
156        }
157
158        #[cfg(feature = "metrics")]
159        metrics::gauge!("dfe_cache_entries").set(self.inner.entry_count() as f64);
160    }
161
162    /// Invalidate a single entry.
163    pub async fn invalidate(&self, source: &str, key: &str) {
164        let full_key = format!("{source}:{key}");
165        self.inner.invalidate(&full_key).await;
166    }
167
168    /// Get the TTL for a source (from config or default).
169    ///
170    /// Returns the per-source TTL when configured in
171    /// [`CacheConfig::source_ttls`], otherwise the cache-wide
172    /// `default_ttl_secs`. Note that moka uses only the cache-wide TTL
173    /// for actual expiration -- see [`Self::set`] for the gap; this
174    /// accessor lets callers inspect what the policy *would* be.
175    #[must_use]
176    pub fn ttl_for_source(&self, source: &str) -> Duration {
177        self.config.source_ttls.get(source).copied().map_or(
178            Duration::from_secs(self.config.default_ttl_secs),
179            Duration::from_secs,
180        )
181    }
182
183    /// Current number of entries in the cache.
184    pub fn entry_count(&self) -> u64 {
185        self.inner.entry_count()
186    }
187
188    /// Access the current config.
189    #[must_use]
190    pub fn config(&self) -> &CacheConfig {
191        &self.config
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Small cache config: 100 entries, 60s default TTL, and explicit TTLs
    /// for the `http` and `db` sources.
    fn test_config() -> CacheConfig {
        CacheConfig {
            source_ttls: HashMap::from([("http".into(), 3600), ("db".into(), 1800)]),
            default_ttl_secs: 60,
            max_capacity: 100,
        }
    }

    /// Store a string value, failing the test if the write errors.
    async fn put(cache: &Cache, source: &str, key: &str, value: &str) {
        cache
            .set(source, key, value.to_string())
            .await
            .expect("cache set");
    }

    /// Read a string value back from the cache.
    async fn fetch(cache: &Cache, source: &str, key: &str) -> Option<String> {
        cache.get::<String>(source, key).await
    }

    #[tokio::test]
    async fn set_and_get() {
        let cache = Cache::new(test_config());
        put(&cache, "http", "url1", "value1").await;

        let stored = fetch(&cache, "http", "url1").await;
        assert_eq!(stored.as_deref(), Some("value1"));
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let cache = Cache::new(test_config());
        let stored = fetch(&cache, "http", "nonexistent").await;
        assert!(stored.is_none());
    }

    #[tokio::test]
    async fn sources_are_isolated() {
        let cache = Cache::new(test_config());
        put(&cache, "http", "key1", "http_value").await;
        put(&cache, "db", "key1", "db_value").await;

        let from_http = fetch(&cache, "http", "key1").await;
        let from_db = fetch(&cache, "db", "key1").await;

        assert_eq!(from_http.as_deref(), Some("http_value"));
        assert_eq!(from_db.as_deref(), Some("db_value"));
    }

    #[tokio::test]
    async fn invalidate_source_removes_only_that_source() {
        let cache = Cache::new(test_config());
        let seed = [
            ("http", "url1", "v1"),
            ("http", "url2", "v2"),
            ("db", "query1", "v3"),
        ];
        for (source, key, value) in seed {
            put(&cache, source, key, value).await;
        }

        cache.invalidate_source("http").await;

        // Let moka apply the queued invalidations before reading back.
        cache.inner.run_pending_tasks().await;

        let http1 = fetch(&cache, "http", "url1").await;
        let http2 = fetch(&cache, "http", "url2").await;
        let db1 = fetch(&cache, "db", "query1").await;

        assert!(http1.is_none(), "http url1 should be invalidated");
        assert!(http2.is_none(), "http url2 should be invalidated");
        assert_eq!(db1.as_deref(), Some("v3"), "db should be preserved");
    }

    #[tokio::test]
    async fn invalidate_single_entry() {
        let cache = Cache::new(test_config());
        put(&cache, "http", "url1", "v1").await;
        put(&cache, "http", "url2", "v2").await;

        cache.invalidate("http", "url1").await;
        cache.inner.run_pending_tasks().await;

        let first = fetch(&cache, "http", "url1").await;
        let second = fetch(&cache, "http", "url2").await;

        assert!(first.is_none());
        assert_eq!(second.as_deref(), Some("v2"));
    }

    #[tokio::test]
    async fn entry_count() {
        let cache = Cache::new(test_config());
        assert_eq!(cache.entry_count(), 0);

        put(&cache, "http", "url1", "v1").await;
        put(&cache, "http", "url2", "v2").await;
        // entry_count is only up to date once pending tasks have run.
        cache.inner.run_pending_tasks().await;

        assert_eq!(cache.entry_count(), 2);
    }

    #[tokio::test]
    async fn complex_types() {
        let cache = Cache::new(test_config());
        let payload = serde_json::json!({"name": "test", "values": [1, 2, 3]});

        cache
            .set("db", "query1", payload.clone())
            .await
            .expect("cache set");

        let round_tripped = cache.get::<serde_json::Value>("db", "query1").await;
        assert_eq!(round_tripped, Some(payload));
    }
}