1use dashmap::DashMap;
9use parking_lot::RwLock;
10use serde::{de::DeserializeOwned, Serialize};
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::{debug, warn};
15
16use crate::cache::RedisCache;
17use crate::error::Result;
18
/// Tuning knobs for [`MultiTierCache`].
#[derive(Debug, Clone)]
pub struct MultiTierCacheConfig {
    /// L1 entry count at which inserting a new key first evicts the oldest queued keys.
    pub l1_max_size: usize,
    /// Age in seconds (measured from insertion) after which an L1 entry counts as expired.
    pub l1_ttl_secs: u64,
    /// TTL handed to the L2 cache whenever a value is written through `set`.
    pub l2_ttl_secs: u64,
    /// Number of L2 hits a key needs before it is copied into L1.
    pub promotion_threshold: u64,
    /// Idle time in seconds (since the last insert or read) after which `cleanup_l1`
    /// drops an L1 entry. This only has an effect when it is smaller than `l1_ttl_secs`,
    /// because an entry's idle time can never exceed its age. With the defaults (600 vs
    /// 300) the TTL always wins.
    pub demotion_threshold_secs: u64,
}

impl Default for MultiTierCacheConfig {
    /// 1000 L1 entries, 5 min L1 TTL, 1 h L2 TTL, promotion after 3 L2 hits,
    /// 10 min demotion threshold.
    fn default() -> Self {
        MultiTierCacheConfig {
            l1_max_size: 1_000,
            l1_ttl_secs: 5 * 60,
            l2_ttl_secs: 60 * 60,
            promotion_threshold: 3,
            demotion_threshold_secs: 10 * 60,
        }
    }
}
45
/// A value held in the in-process L1 tier, plus the timestamps and counter used for
/// expiry and demotion decisions.
#[derive(Debug, Clone)]
struct L1Entry<T> {
    /// The cached payload.
    value: T,
    /// When the entry was created; drives TTL expiry.
    created_at: Instant,
    /// When the entry was created or last touched; drives demotion.
    last_accessed: Instant,
    /// Creation counts as the first access; each `touch` adds one.
    access_count: u64,
}

impl<T> L1Entry<T> {
    /// Wraps `value`, stamping both timestamps with the same current instant.
    fn new(value: T) -> Self {
        let stamp = Instant::now();
        L1Entry {
            value,
            access_count: 1,
            last_accessed: stamp,
            created_at: stamp,
        }
    }

    /// True once the entry is strictly older than `ttl`.
    fn is_expired(&self, ttl: Duration) -> bool {
        ttl < self.created_at.elapsed()
    }

    /// Records an access: bumps the counter and refreshes `last_accessed`.
    fn touch(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// True once the entry has gone untouched for strictly longer than `threshold`.
    fn should_demote(&self, threshold: Duration) -> bool {
        threshold < self.last_accessed.elapsed()
    }
}
83
/// Counters describing how lookups and L1 maintenance played out in [`MultiTierCache`].
#[derive(Debug, Clone, Default)]
pub struct MultiTierStats {
    /// Lookups answered from the in-process L1 tier.
    pub l1_hits: u64,
    /// Lookups not answered by L1 (absent or expired) but found in L2.
    pub l2_hits: u64,
    /// Lookups found in neither tier.
    pub misses: u64,
    /// Keys copied from L2 into L1 after reaching the promotion threshold.
    pub promotions: u64,
    /// Entries dropped from L1 by `cleanup_l1` (expired or idle).
    pub demotions: u64,
    /// Entries dropped from L1 to make room for a new key.
    pub l1_evictions: u64,
}

impl MultiTierStats {
    /// Lookups served by either tier.
    pub fn total_hits(&self) -> u64 {
        let Self { l1_hits, l2_hits, .. } = self;
        l1_hits + l2_hits
    }

    /// Every recorded lookup: hits in either tier plus misses.
    fn lookups(&self) -> u64 {
        self.total_hits() + self.misses
    }

    /// Fraction of lookups served by either tier; `0.0` when nothing has been looked up.
    pub fn hit_rate(&self) -> f64 {
        match self.lookups() {
            0 => 0.0,
            n => self.total_hits() as f64 / n as f64,
        }
    }

    /// Fraction of lookups served directly by L1; `0.0` when nothing has been looked up.
    pub fn l1_hit_rate(&self) -> f64 {
        match self.lookups() {
            0 => 0.0,
            n => self.l1_hits as f64 / n as f64,
        }
    }

    /// Zeroes every counter.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
137
138pub struct MultiTierCache {
140 l1_cache: Arc<DashMap<String, L1Entry<Vec<u8>>>>,
142 l2_cache: Arc<RedisCache>,
144 lru_queue: Arc<RwLock<VecDeque<String>>>,
146 l2_access_counts: Arc<DashMap<String, u64>>,
148 config: MultiTierCacheConfig,
150 stats: Arc<RwLock<MultiTierStats>>,
152}
153
154impl MultiTierCache {
155 pub fn new(l2_cache: Arc<RedisCache>, config: MultiTierCacheConfig) -> Self {
157 Self {
158 l1_cache: Arc::new(DashMap::new()),
159 l2_cache,
160 lru_queue: Arc::new(RwLock::new(VecDeque::new())),
161 l2_access_counts: Arc::new(DashMap::new()),
162 config,
163 stats: Arc::new(RwLock::new(MultiTierStats::default())),
164 }
165 }
166
167 pub async fn get<T: DeserializeOwned + Serialize>(&self, key: &str) -> Result<Option<T>> {
169 if let Some(mut entry) = self.l1_cache.get_mut(key) {
171 if entry.is_expired(Duration::from_secs(self.config.l1_ttl_secs)) {
173 drop(entry);
174 self.l1_cache.remove(key);
175 self.remove_from_lru(key);
176 debug!(key = %key, "L1 entry expired");
177 } else {
178 entry.touch();
179 self.stats.write().l1_hits += 1;
180 debug!(key = %key, "L1 cache hit");
181
182 let value: T = serde_json::from_slice(&entry.value).map_err(|e| {
183 crate::error::DbError::Cache(format!("Deserialization error: {}", e))
184 })?;
185
186 return Ok(Some(value));
187 }
188 }
189
190 let l2_result: Option<T> = self.l2_cache.get(key).await?;
192
193 if let Some(value) = l2_result.as_ref() {
194 self.stats.write().l2_hits += 1;
195 debug!(key = %key, "L2 cache hit");
196
197 let mut count = self.l2_access_counts.entry(key.to_string()).or_insert(0);
199 *count += 1;
200
201 if *count >= self.config.promotion_threshold {
203 debug!(key = %key, count = *count, "Promoting to L1");
204 self.promote_to_l1(key, value).await?;
205 self.l2_access_counts.remove(key);
206 self.stats.write().promotions += 1;
207 }
208
209 Ok(Some(l2_result.unwrap()))
210 } else {
211 self.stats.write().misses += 1;
212 debug!(key = %key, "Cache miss");
213 Ok(None)
214 }
215 }
216
217 pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
219 let bytes = serde_json::to_vec(value)
221 .map_err(|e| crate::error::DbError::Cache(format!("Serialization error: {}", e)))?;
222
223 self.evict_if_needed();
225
226 let entry = L1Entry::new(bytes.clone());
227 self.l1_cache.insert(key.to_string(), entry);
228 self.lru_queue.write().push_front(key.to_string());
229
230 debug!(key = %key, "Stored in L1");
231
232 self.l2_cache
234 .set(key, value, self.config.l2_ttl_secs)
235 .await?;
236
237 debug!(key = %key, "Stored in L2");
238
239 Ok(())
240 }
241
242 pub async fn delete(&self, key: &str) -> Result<bool> {
244 self.l1_cache.remove(key);
245 self.remove_from_lru(key);
246 self.l2_access_counts.remove(key);
247
248 let l2_deleted = self.l2_cache.delete(key).await?;
249
250 debug!(key = %key, "Deleted from both tiers");
251
252 Ok(l2_deleted)
253 }
254
255 async fn promote_to_l1<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
257 let bytes = serde_json::to_vec(value)
258 .map_err(|e| crate::error::DbError::Cache(format!("Serialization error: {}", e)))?;
259
260 self.evict_if_needed();
261
262 let entry = L1Entry::new(bytes);
263 self.l1_cache.insert(key.to_string(), entry);
264 self.lru_queue.write().push_front(key.to_string());
265
266 Ok(())
267 }
268
269 fn evict_if_needed(&self) {
271 while self.l1_cache.len() >= self.config.l1_max_size {
272 if let Some(oldest_key) = self.lru_queue.write().pop_back() {
273 self.l1_cache.remove(&oldest_key);
274 self.stats.write().l1_evictions += 1;
275 debug!(key = %oldest_key, "Evicted from L1");
276 } else {
277 break;
278 }
279 }
280 }
281
282 fn remove_from_lru(&self, key: &str) {
284 let mut queue = self.lru_queue.write();
285 if let Some(pos) = queue.iter().position(|k| k == key) {
286 queue.remove(pos);
287 }
288 }
289
290 pub fn cleanup_l1(&self) {
292 let ttl = Duration::from_secs(self.config.l1_ttl_secs);
293 let demotion_threshold = Duration::from_secs(self.config.demotion_threshold_secs);
294
295 let keys_to_remove: Vec<String> = self
296 .l1_cache
297 .iter()
298 .filter(|entry| {
299 entry.value().is_expired(ttl) || entry.value().should_demote(demotion_threshold)
300 })
301 .map(|entry| entry.key().clone())
302 .collect();
303
304 for key in keys_to_remove {
305 self.l1_cache.remove(&key);
306 self.remove_from_lru(&key);
307 self.stats.write().demotions += 1;
308 debug!(key = %key, "Demoted/expired from L1");
309 }
310 }
311
312 pub fn stats(&self) -> MultiTierStats {
314 self.stats.read().clone()
315 }
316
317 pub fn l1_size(&self) -> usize {
319 self.l1_cache.len()
320 }
321
322 pub async fn clear(&self) -> Result<()> {
324 self.l1_cache.clear();
325 self.lru_queue.write().clear();
326 self.l2_access_counts.clear();
327
328 warn!("Cleared L1 cache, L2 entries will expire naturally");
331
332 Ok(())
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_tier_config_default() {
        let config = MultiTierCacheConfig::default();
        assert_eq!(config.l1_max_size, 1000);
        assert_eq!(config.l1_ttl_secs, 300);
        assert_eq!(config.l2_ttl_secs, 3600);
        assert_eq!(config.promotion_threshold, 3);
        assert_eq!(config.demotion_threshold_secs, 600);
    }

    #[test]
    fn test_l1_entry_expiration() {
        let entry = L1Entry::new("test".to_string());
        assert!(!entry.is_expired(Duration::from_secs(3600)));
    }

    /// A fresh entry starts with one access and identical timestamps.
    #[test]
    fn test_l1_entry_new_initial_state() {
        let entry = L1Entry::new(42u8);
        assert_eq!(entry.value, 42);
        assert_eq!(entry.access_count, 1);
        assert_eq!(entry.created_at, entry.last_accessed);
    }

    #[test]
    fn test_l1_entry_touch() {
        let mut entry = L1Entry::new("test".to_string());
        let initial_count = entry.access_count;

        entry.touch();

        assert_eq!(entry.access_count, initial_count + 1);
        assert!(entry.last_accessed >= entry.created_at);
    }

    /// A freshly created entry is not idle long enough to be demoted.
    #[test]
    fn test_l1_entry_fresh_not_demoted() {
        let entry = L1Entry::new(());
        assert!(!entry.should_demote(Duration::from_secs(3600)));
    }

    #[test]
    fn test_multi_tier_stats() {
        let stats = MultiTierStats {
            l1_hits: 80,
            l2_hits: 15,
            misses: 5,
            promotions: 3,
            demotions: 2,
            l1_evictions: 1,
        };

        assert_eq!(stats.total_hits(), 95);
        assert!((stats.hit_rate() - 0.95).abs() < 0.01);
        assert!((stats.l1_hit_rate() - 0.80).abs() < 0.01);
    }

    /// With no recorded lookups, the rate helpers return 0.0 instead of dividing by zero.
    #[test]
    fn test_stats_rates_without_lookups() {
        let stats = MultiTierStats::default();
        assert_eq!(stats.total_hits(), 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.l1_hit_rate(), 0.0);
    }

    #[test]
    fn test_stats_reset() {
        let mut stats = MultiTierStats {
            l1_hits: 100,
            l2_hits: 50,
            misses: 10,
            promotions: 5,
            demotions: 3,
            l1_evictions: 2,
        };

        stats.reset();

        assert_eq!(stats.l1_hits, 0);
        assert_eq!(stats.l2_hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.promotions, 0);
        assert_eq!(stats.demotions, 0);
        assert_eq!(stats.l1_evictions, 0);
    }
}