Skip to main content

rsigma_runtime/enrichment/
http_cache.rs

//! In-memory response cache for [`HttpEnricher`](super::http::HttpEnricher).
//!
//! Keyed on `(method, url, body_hash)` with a configurable TTL. Each
//! `HttpEnricher` instance owns its own cache so two recipes that hit
//! the same URL with different API keys (different `Authorization`
//! headers) cannot accidentally share each other's cached responses.
//!
//! Mandatory in practice for any rate-limited API (VirusTotal: 4 req/min
//! on the free tier) and a major win for any duplicate-detection burst.
//! Off by default; `cache_ttl: <duration>` on the enricher config flips
//! it on.
12
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use parking_lot::RwLock;
18
/// Composite cache key for one cached HTTP response.
///
/// Two keys compare equal only when the method, the URL and the body hash
/// all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Request method, normalized to upper case (`GET`, `POST`, …).
    pub method: String,
    /// Full request URL, stored exactly as supplied.
    pub url: String,
    /// 64-bit hash of the request body (zero when no body).
    pub body_hash: u64,
}

impl CacheKey {
    /// Build a cache key from raw components.
    ///
    /// The body is hashed once, at key-construction time, so the key only
    /// carries a fixed-size `u64` instead of the body bytes.
    ///
    /// # Parameters
    /// * `method` – request method; ASCII-upper-cased so `"get"` and `"GET"`
    ///   produce the same key.
    /// * `url` – request URL; copied verbatim.
    /// * `body` – request body. `None` and an empty slice are both treated
    ///   as "no body" and yield `body_hash == 0`, matching the contract
    ///   documented on [`CacheKey::body_hash`].
    ///
    /// # Returns
    /// The key identifying this request in the cache.
    pub fn new(method: &str, url: &str, body: Option<&[u8]>) -> Self {
        use std::hash::{Hash, Hasher};

        let body_hash = match body {
            Some(bytes) if !bytes.is_empty() => {
                // A 64-bit digest: distinct bodies can in principle collide
                // (including with the "no body" value 0), as with any hash
                // of this width.
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                bytes.hash(&mut hasher);
                hasher.finish()
            }
            // Absent or empty body: use the documented sentinel rather than
            // the (non-zero) hash of an empty slice.
            _ => 0,
        };

        Self {
            method: method.to_ascii_uppercase(),
            url: url.to_string(),
            body_hash,
        }
    }
}
44
45#[derive(Clone)]
46struct CacheEntry {
47    value: serde_json::Value,
48    stored_at: Instant,
49}
50
/// Classification of a single cache lookup.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheOutcome {
    /// A live entry (still within TTL) was found and its value returned.
    Hit,
    /// The key had no entry at all.
    Miss,
    /// An entry existed but had outlived its TTL; the read evicted it.
    Expired,
}
61
/// Stats counters that survive across [`HttpResponseCache::lookup`] /
/// [`HttpResponseCache::insert`] calls. Wired into Prometheus metrics
/// in Phase 4 (`rsigma_enrichment_http_cache_{hits,misses,expirations}_total`).
///
/// `Debug` is derived; it prints the struct name followed by the three
/// counters in declaration order.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Lookups answered from a live (within-TTL) entry.
    pub hits: u64,
    /// Lookups that found no entry, including every lookup on a disabled cache.
    pub misses: u64,
    /// Entries dropped for exceeding the TTL, whether on read or by a sweep.
    pub expirations: u64,
}
81
82/// In-memory `(method, url, body_hash) → JSON value` cache with TTL.
83///
84/// Cheap to clone (`Arc`-wrapped internals); each enricher instance keeps
85/// its own clone in its hot path, and the daemon may share instances
86/// across enricher reload cycles when the config has not changed.
87#[derive(Clone)]
88pub struct HttpResponseCache {
89    inner: Arc<HttpResponseCacheInner>,
90}
91
92struct HttpResponseCacheInner {
93    entries: RwLock<HashMap<CacheKey, CacheEntry>>,
94    stats: RwLock<CacheStats>,
95    ttl: Duration,
96}
97
98impl HttpResponseCache {
99    /// Build a new cache with the given TTL.
100    ///
101    /// A `ttl` of zero is treated as "disabled" — every lookup returns
102    /// [`CacheOutcome::Miss`] and inserts are no-ops, so call sites can
103    /// always go through the cache without checking `cache_ttl > 0`.
104    pub fn new(ttl: Duration) -> Self {
105        Self {
106            inner: Arc::new(HttpResponseCacheInner {
107                entries: RwLock::new(HashMap::new()),
108                stats: RwLock::new(CacheStats::default()),
109                ttl,
110            }),
111        }
112    }
113
114    /// Returns true when this cache is "off" (TTL is zero).
115    pub fn is_disabled(&self) -> bool {
116        self.inner.ttl.is_zero()
117    }
118
119    /// Configured TTL.
120    pub fn ttl(&self) -> Duration {
121        self.inner.ttl
122    }
123
124    /// Look up `key`. Returns the cached value if it is present and
125    /// within TTL; expires it lazily otherwise.
126    pub fn lookup(&self, key: &CacheKey) -> (CacheOutcome, Option<serde_json::Value>) {
127        if self.is_disabled() {
128            self.inner.stats.write().misses += 1;
129            return (CacheOutcome::Miss, None);
130        }
131
132        // Fast path: read lock for the common hit / miss case.
133        {
134            let map = self.inner.entries.read();
135            if let Some(entry) = map.get(key) {
136                if entry.stored_at.elapsed() <= self.inner.ttl {
137                    self.inner.stats.write().hits += 1;
138                    return (CacheOutcome::Hit, Some(entry.value.clone()));
139                }
140                // Expired — fall through to write-locked eviction.
141            } else {
142                self.inner.stats.write().misses += 1;
143                return (CacheOutcome::Miss, None);
144            }
145        }
146
147        // Slow path: take write lock to evict expired entry.
148        let mut map = self.inner.entries.write();
149        if let Some(entry) = map.get(key) {
150            if entry.stored_at.elapsed() > self.inner.ttl {
151                map.remove(key);
152                self.inner.stats.write().expirations += 1;
153                return (CacheOutcome::Expired, None);
154            }
155            // Race: re-validated by another thread.
156            self.inner.stats.write().hits += 1;
157            return (CacheOutcome::Hit, Some(entry.value.clone()));
158        }
159        // Race: removed by another thread.
160        self.inner.stats.write().misses += 1;
161        (CacheOutcome::Miss, None)
162    }
163
164    /// Insert `value` under `key`. No-op when the cache is disabled.
165    pub fn insert(&self, key: CacheKey, value: serde_json::Value) {
166        if self.is_disabled() {
167            return;
168        }
169        self.inner.entries.write().insert(
170            key,
171            CacheEntry {
172                value,
173                stored_at: Instant::now(),
174            },
175        );
176    }
177
178    /// Remove every entry whose stored_at + TTL is in the past. Called
179    /// periodically by a background sweep when the daemon's enrichment
180    /// pipeline has at least one cache configured.
181    pub fn evict_expired(&self) -> usize {
182        if self.is_disabled() {
183            return 0;
184        }
185        let mut map = self.inner.entries.write();
186        let before = map.len();
187        let ttl = self.inner.ttl;
188        map.retain(|_, e| e.stored_at.elapsed() <= ttl);
189        let removed = before - map.len();
190        if removed > 0 {
191            self.inner.stats.write().expirations += removed as u64;
192        }
193        removed
194    }
195
196    /// Snapshot the cumulative cache stats since construction.
197    pub fn stats(&self) -> (u64, u64, u64) {
198        let s = self.inner.stats.read();
199        (s.hits, s.misses, s.expirations)
200    }
201
202    /// Number of entries currently held.
203    pub fn len(&self) -> usize {
204        self.inner.entries.read().len()
205    }
206
207    /// True when no entries are held.
208    pub fn is_empty(&self) -> bool {
209        self.inner.entries.read().is_empty()
210    }
211}
212
213impl std::fmt::Debug for HttpResponseCache {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        let (h, m, x) = self.stats();
216        f.debug_struct("HttpResponseCache")
217            .field("ttl", &self.inner.ttl)
218            .field("len", &self.len())
219            .field("hits", &h)
220            .field("misses", &m)
221            .field("expirations", &x)
222            .finish()
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero TTL turns the cache off: inserts are dropped and every
    /// lookup is a miss.
    #[test]
    fn disabled_cache_always_misses() {
        let off = HttpResponseCache::new(Duration::ZERO);
        assert!(off.is_disabled());

        let k = CacheKey::new("GET", "https://x", None);
        off.insert(k.clone(), serde_json::json!("v"));

        assert_eq!(off.lookup(&k), (CacheOutcome::Miss, None));
    }

    /// A fresh entry is a hit; once the TTL has elapsed the same lookup
    /// reports `Expired` and yields no value.
    #[test]
    fn hit_then_miss_after_ttl() {
        let short_lived = HttpResponseCache::new(Duration::from_millis(50));
        let k = CacheKey::new("GET", "https://x", None);
        short_lived.insert(k.clone(), serde_json::json!("v"));

        assert_eq!(
            short_lived.lookup(&k),
            (CacheOutcome::Hit, Some(serde_json::json!("v")))
        );

        // Wait comfortably past the 50 ms TTL.
        std::thread::sleep(Duration::from_millis(80));

        assert_eq!(short_lived.lookup(&k), (CacheOutcome::Expired, None));
    }

    /// Same method and URL but different bodies must not share an entry.
    #[test]
    fn body_hash_separates_keys() {
        let with_a = CacheKey::new("POST", "https://x", Some(b"a"));
        let with_b = CacheKey::new("POST", "https://x", Some(b"b"));
        assert_ne!(with_a, with_b);

        let store = HttpResponseCache::new(Duration::from_secs(60));
        store.insert(with_a.clone(), serde_json::json!(1));

        let (outcome, _) = store.lookup(&with_b);
        assert_eq!(outcome, CacheOutcome::Miss);
    }

    /// Same URL but different methods produce different keys.
    #[test]
    fn method_difference_separates_keys() {
        let via_get = CacheKey::new("GET", "https://x", None);
        let via_post = CacheKey::new("POST", "https://x", None);
        assert_ne!(via_get, via_post);
    }

    /// The sweep removes every entry older than the TTL and reports the
    /// number removed.
    #[test]
    fn evict_expired_drops_old_entries() {
        let store = HttpResponseCache::new(Duration::from_millis(20));
        (0..5).for_each(|i| {
            let k = CacheKey::new("GET", &format!("https://x/{i}"), None);
            store.insert(k, serde_json::json!(i));
        });

        // Let all five entries age past the 20 ms TTL.
        std::thread::sleep(Duration::from_millis(40));

        assert_eq!(store.evict_expired(), 5);
        assert_eq!(store.len(), 0);
    }

    /// One miss before the insert and two hits after it must show up in
    /// the counters.
    #[test]
    fn stats_counters_increment() {
        let store = HttpResponseCache::new(Duration::from_secs(60));
        let k = CacheKey::new("GET", "https://x", None);

        let _ = store.lookup(&k); // miss: nothing stored yet
        store.insert(k.clone(), serde_json::json!("v"));
        let _ = store.lookup(&k); // hit
        let _ = store.lookup(&k); // hit

        let (hit_count, miss_count, _expired) = store.stats();
        assert_eq!(hit_count, 2);
        assert_eq!(miss_count, 1);
    }
}