// rsigma_runtime/enrichment/http_cache.rs
use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use parking_lot::RwLock;
18
/// Identity of a cacheable HTTP request.
///
/// Two requests share a cache slot when their upper-cased method, their exact
/// URL string and the hash of their body are all equal.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CacheKey {
    /// HTTP method, upper-cased by [`CacheKey::new`] (e.g. `"GET"`).
    pub method: String,
    /// Request URL, stored verbatim; no normalization is applied.
    pub url: String,
    /// 64-bit hash of the request body (a missing body hashes like an empty one).
    pub body_hash: u64,
}

impl CacheKey {
    /// Builds a key from a request's method, URL and optional body.
    ///
    /// The method is upper-cased, so `"get"` and `"GET"` yield equal keys.
    /// `None` and `Some(&[])` are hashed identically and therefore collide on
    /// purpose.
    ///
    /// NOTE: only a 64-bit `DefaultHasher` digest of the body is kept, so two
    /// different bodies can, in principle, map to the same key.
    pub fn new(method: &str, url: &str, body: Option<&[u8]>) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // Treat an absent body exactly like an empty one.
        let payload: &[u8] = match body {
            Some(bytes) => bytes,
            None => &[],
        };
        let mut state = DefaultHasher::new();
        payload.hash(&mut state);

        Self {
            method: method.to_ascii_uppercase(),
            url: url.to_owned(),
            body_hash: state.finish(),
        }
    }
}
44
45#[derive(Clone)]
46struct CacheEntry {
47 value: serde_json::Value,
48 stored_at: Instant,
49}
50
/// How a single cache lookup was resolved.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheOutcome {
    /// A fresh entry was found and its value returned.
    Hit,
    /// No entry existed for the key, or the cache is disabled.
    Miss,
    /// An entry existed but had outlived the TTL; it was dropped and nothing returned.
    Expired,
}
61
/// Running counters describing how an `HttpResponseCache` has been used.
///
/// `hits` and `misses` are bumped by lookups. `expirations` counts entries
/// discarded for being older than the TTL, whether a lookup removed them
/// lazily or a bulk eviction pass removed them.
///
/// The derived `Debug` output is identical to the former hand-written impl:
/// `CacheStats { hits: .., misses: .., expirations: .. }`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups that returned a fresh cached value.
    pub hits: u64,
    /// Lookups that found no entry (including every lookup on a disabled cache).
    pub misses: u64,
    /// Entries discarded because they had outlived the TTL.
    pub expirations: u64,
}
81
82#[derive(Clone)]
88pub struct HttpResponseCache {
89 inner: Arc<HttpResponseCacheInner>,
90}
91
92struct HttpResponseCacheInner {
93 entries: RwLock<HashMap<CacheKey, CacheEntry>>,
94 stats: RwLock<CacheStats>,
95 ttl: Duration,
96}
97
98impl HttpResponseCache {
99 pub fn new(ttl: Duration) -> Self {
105 Self {
106 inner: Arc::new(HttpResponseCacheInner {
107 entries: RwLock::new(HashMap::new()),
108 stats: RwLock::new(CacheStats::default()),
109 ttl,
110 }),
111 }
112 }
113
114 pub fn is_disabled(&self) -> bool {
116 self.inner.ttl.is_zero()
117 }
118
119 pub fn ttl(&self) -> Duration {
121 self.inner.ttl
122 }
123
124 pub fn lookup(&self, key: &CacheKey) -> (CacheOutcome, Option<serde_json::Value>) {
127 if self.is_disabled() {
128 self.inner.stats.write().misses += 1;
129 return (CacheOutcome::Miss, None);
130 }
131
132 {
134 let map = self.inner.entries.read();
135 if let Some(entry) = map.get(key) {
136 if entry.stored_at.elapsed() <= self.inner.ttl {
137 self.inner.stats.write().hits += 1;
138 return (CacheOutcome::Hit, Some(entry.value.clone()));
139 }
140 } else {
142 self.inner.stats.write().misses += 1;
143 return (CacheOutcome::Miss, None);
144 }
145 }
146
147 let mut map = self.inner.entries.write();
149 if let Some(entry) = map.get(key) {
150 if entry.stored_at.elapsed() > self.inner.ttl {
151 map.remove(key);
152 self.inner.stats.write().expirations += 1;
153 return (CacheOutcome::Expired, None);
154 }
155 self.inner.stats.write().hits += 1;
157 return (CacheOutcome::Hit, Some(entry.value.clone()));
158 }
159 self.inner.stats.write().misses += 1;
161 (CacheOutcome::Miss, None)
162 }
163
164 pub fn insert(&self, key: CacheKey, value: serde_json::Value) {
166 if self.is_disabled() {
167 return;
168 }
169 self.inner.entries.write().insert(
170 key,
171 CacheEntry {
172 value,
173 stored_at: Instant::now(),
174 },
175 );
176 }
177
178 pub fn evict_expired(&self) -> usize {
182 if self.is_disabled() {
183 return 0;
184 }
185 let mut map = self.inner.entries.write();
186 let before = map.len();
187 let ttl = self.inner.ttl;
188 map.retain(|_, e| e.stored_at.elapsed() <= ttl);
189 let removed = before - map.len();
190 if removed > 0 {
191 self.inner.stats.write().expirations += removed as u64;
192 }
193 removed
194 }
195
196 pub fn stats(&self) -> (u64, u64, u64) {
198 let s = self.inner.stats.read();
199 (s.hits, s.misses, s.expirations)
200 }
201
202 pub fn len(&self) -> usize {
204 self.inner.entries.read().len()
205 }
206
207 pub fn is_empty(&self) -> bool {
209 self.inner.entries.read().is_empty()
210 }
211}
212
213impl std::fmt::Debug for HttpResponseCache {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 let (h, m, x) = self.stats();
216 f.debug_struct("HttpResponseCache")
217 .field("ttl", &self.inner.ttl)
218 .field("len", &self.len())
219 .field("hits", &h)
220 .field("misses", &m)
221 .field("expirations", &x)
222 .finish()
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn disabled_cache_always_misses() {
        let cache = HttpResponseCache::new(Duration::from_secs(0));
        assert!(cache.is_disabled());
        let key = CacheKey::new("GET", "https://x", None);
        cache.insert(key.clone(), serde_json::json!("v"));
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Miss);
        assert!(val.is_none());
        // The insert was ignored and the lookup still counted as a miss.
        assert!(cache.is_empty());
        assert_eq!(cache.stats(), (0, 1, 0));
        assert_eq!(cache.evict_expired(), 0);
    }

    #[test]
    fn hit_then_expired_after_ttl() {
        let cache = HttpResponseCache::new(Duration::from_millis(50));
        let key = CacheKey::new("GET", "https://x", None);
        cache.insert(key.clone(), serde_json::json!("v"));
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Hit);
        assert_eq!(val, Some(serde_json::json!("v")));

        std::thread::sleep(Duration::from_millis(80));
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Expired);
        assert!(val.is_none());
        // The stale entry is removed lazily and counted as an expiration, not a miss.
        assert!(cache.is_empty());
        assert_eq!(cache.stats(), (1, 0, 1));

        // Once removed, the key is simply absent.
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Miss);
        assert!(val.is_none());
        assert_eq!(cache.stats(), (1, 1, 1));
    }

    #[test]
    fn body_hash_separates_keys() {
        let a = CacheKey::new("POST", "https://x", Some(b"a"));
        let b = CacheKey::new("POST", "https://x", Some(b"b"));
        assert_ne!(a, b);
        let cache = HttpResponseCache::new(Duration::from_secs(60));
        cache.insert(a.clone(), serde_json::json!(1));
        let (out, _) = cache.lookup(&b);
        assert_eq!(out, CacheOutcome::Miss);
    }

    #[test]
    fn method_difference_separates_keys() {
        let a = CacheKey::new("GET", "https://x", None);
        let b = CacheKey::new("POST", "https://x", None);
        assert_ne!(a, b);
    }

    #[test]
    fn method_is_case_insensitive() {
        assert_eq!(
            CacheKey::new("get", "https://x", None),
            CacheKey::new("GET", "https://x", None)
        );
    }

    #[test]
    fn missing_and_empty_body_share_a_key() {
        assert_eq!(
            CacheKey::new("POST", "https://x", None),
            CacheKey::new("POST", "https://x", Some(b""))
        );
    }

    #[test]
    fn insert_replaces_existing_value() {
        let cache = HttpResponseCache::new(Duration::from_secs(60));
        let key = CacheKey::new("GET", "https://x", None);
        cache.insert(key.clone(), serde_json::json!(1));
        cache.insert(key.clone(), serde_json::json!(2));
        assert_eq!(cache.len(), 1);
        assert_eq!(
            cache.lookup(&key),
            (CacheOutcome::Hit, Some(serde_json::json!(2)))
        );
    }

    #[test]
    fn evict_expired_drops_old_entries() {
        let cache = HttpResponseCache::new(Duration::from_millis(20));
        for i in 0..5 {
            cache.insert(
                CacheKey::new("GET", &format!("https://x/{i}"), None),
                serde_json::json!(i),
            );
        }
        std::thread::sleep(Duration::from_millis(40));
        let evicted = cache.evict_expired();
        assert_eq!(evicted, 5);
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats(), (0, 0, 5));
    }

    #[test]
    fn stats_counters_increment() {
        let cache = HttpResponseCache::new(Duration::from_secs(60));
        let key = CacheKey::new("GET", "https://x", None);
        let _ = cache.lookup(&key);
        cache.insert(key.clone(), serde_json::json!("v"));
        let _ = cache.lookup(&key);
        let _ = cache.lookup(&key);
        let (hits, misses, expirations) = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
        assert_eq!(expirations, 0);
    }
}
301}