1use parking_lot::RwLock;
31use std::collections::{HashMap, HashSet};
32use std::sync::atomic::{AtomicUsize, Ordering};
33use std::time::{Duration, Instant};
34
35use super::backend::{BackendStats, CacheBackend, CacheError, CacheResult};
36use super::invalidation::EntityTag;
37use super::key::{CacheKey, KeyPattern};
38
/// Tunable settings for [`MemoryCache`].
///
/// Build one with [`MemoryCacheConfig::new`] or [`Default::default`] and refine it with the
/// `with_*` / `without_*` methods. Configurations compare by value, which makes them easy to
/// assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryCacheConfig {
    /// Entry-count threshold. Once the cache holds this many entries, storing a value first
    /// evicts expired entries and, if that is not enough, the least recently accessed ones.
    pub max_capacity: u64,
    /// Expiry applied to values stored without an explicit TTL; `None` means such values never
    /// expire.
    pub time_to_live: Option<Duration>,
    /// Idle timeout.
    ///
    /// NOTE(review): this value is stored but not consulted by the cache code in this file —
    /// confirm whether idle expiry is implemented elsewhere.
    pub time_to_idle: Option<Duration>,
    /// Per-entry TTL switch.
    ///
    /// NOTE(review): this flag is stored but not consulted by the cache code in this file; an
    /// explicit TTL passed when storing a value is always honoured — confirm intended meaning.
    pub per_entry_ttl: bool,
    /// Whether the tag index is maintained. When `false`, tag bookkeeping is skipped and
    /// tag-based invalidation removes nothing.
    pub enable_tags: bool,
}
53
54impl Default for MemoryCacheConfig {
55 fn default() -> Self {
56 Self {
57 max_capacity: 10_000,
58 time_to_live: Some(Duration::from_secs(300)),
59 time_to_idle: None,
60 per_entry_ttl: true,
61 enable_tags: true,
62 }
63 }
64}
65
66impl MemoryCacheConfig {
67 pub fn new(max_capacity: u64) -> Self {
69 Self {
70 max_capacity,
71 ..Default::default()
72 }
73 }
74
75 pub fn with_ttl(mut self, ttl: Duration) -> Self {
77 self.time_to_live = Some(ttl);
78 self
79 }
80
81 pub fn with_tti(mut self, tti: Duration) -> Self {
83 self.time_to_idle = Some(tti);
84 self
85 }
86
87 pub fn without_tags(mut self) -> Self {
89 self.enable_tags = false;
90 self
91 }
92}
93
94#[derive(Default)]
96pub struct MemoryCacheBuilder {
97 config: MemoryCacheConfig,
98}
99
100impl MemoryCacheBuilder {
101 pub fn new() -> Self {
103 Self::default()
104 }
105
106 pub fn max_capacity(mut self, capacity: u64) -> Self {
108 self.config.max_capacity = capacity;
109 self
110 }
111
112 pub fn time_to_live(mut self, ttl: Duration) -> Self {
114 self.config.time_to_live = Some(ttl);
115 self
116 }
117
118 pub fn time_to_idle(mut self, tti: Duration) -> Self {
120 self.config.time_to_idle = Some(tti);
121 self
122 }
123
124 pub fn per_entry_ttl(mut self, enabled: bool) -> Self {
126 self.config.per_entry_ttl = enabled;
127 self
128 }
129
130 pub fn enable_tags(mut self, enabled: bool) -> Self {
132 self.config.enable_tags = enabled;
133 self
134 }
135
136 pub fn build(self) -> MemoryCache {
138 MemoryCache::new(self.config)
139 }
140}
141
142#[derive(Clone)]
144struct CacheEntry {
145 data: Vec<u8>,
147 created_at: Instant,
149 expires_at: Option<Instant>,
151 last_accessed: Instant,
153 tags: Vec<EntityTag>,
155}
156
157impl CacheEntry {
158 fn new(data: Vec<u8>, ttl: Option<Duration>, tags: Vec<EntityTag>) -> Self {
159 let now = Instant::now();
160 Self {
161 data,
162 created_at: now,
163 expires_at: ttl.map(|d| now + d),
164 last_accessed: now,
165 tags,
166 }
167 }
168
169 fn is_expired(&self) -> bool {
170 self.expires_at.map_or(false, |exp| Instant::now() >= exp)
171 }
172
173 fn touch(&mut self) {
174 self.last_accessed = Instant::now();
175 }
176}
177
178pub struct MemoryCache {
182 config: MemoryCacheConfig,
183 entries: RwLock<HashMap<String, CacheEntry>>,
184 tag_index: RwLock<HashMap<String, HashSet<String>>>,
185 entry_count: AtomicUsize,
186}
187
188impl MemoryCache {
189 pub fn new(config: MemoryCacheConfig) -> Self {
191 Self {
192 entries: RwLock::new(HashMap::with_capacity(config.max_capacity as usize)),
193 tag_index: RwLock::new(HashMap::new()),
194 entry_count: AtomicUsize::new(0),
195 config,
196 }
197 }
198
199 pub fn builder() -> MemoryCacheBuilder {
201 MemoryCacheBuilder::new()
202 }
203
204 pub fn config(&self) -> &MemoryCacheConfig {
206 &self.config
207 }
208
209 pub fn evict_expired(&self) -> usize {
211 let mut entries = self.entries.write();
212 let before = entries.len();
213
214 let expired_keys: Vec<String> = entries
215 .iter()
216 .filter(|(_, e)| e.is_expired())
217 .map(|(k, _)| k.clone())
218 .collect();
219
220 for key in &expired_keys {
221 if let Some(entry) = entries.remove(key) {
222 self.remove_from_tag_index(key, &entry.tags);
223 }
224 }
225
226 let evicted = before - entries.len();
227 self.entry_count.fetch_sub(evicted, Ordering::Relaxed);
228 evicted
229 }
230
231 fn evict_lru(&self, count: usize) {
233 let mut entries = self.entries.write();
234
235 let mut by_access: Vec<_> = entries
237 .iter()
238 .map(|(k, e)| (k.clone(), e.last_accessed))
239 .collect();
240 by_access.sort_by_key(|(_, t)| *t);
241
242 for (key, _) in by_access.into_iter().take(count) {
243 if let Some(entry) = entries.remove(&key) {
244 self.remove_from_tag_index(&key, &entry.tags);
245 }
246 }
247
248 self.entry_count
249 .store(entries.len(), Ordering::Relaxed);
250 }
251
252 fn add_to_tag_index(&self, key: &str, tags: &[EntityTag]) {
254 if !self.config.enable_tags || tags.is_empty() {
255 return;
256 }
257
258 let mut index = self.tag_index.write();
259 for tag in tags {
260 index
261 .entry(tag.value().to_string())
262 .or_default()
263 .insert(key.to_string());
264 }
265 }
266
267 fn remove_from_tag_index(&self, key: &str, tags: &[EntityTag]) {
269 if !self.config.enable_tags || tags.is_empty() {
270 return;
271 }
272
273 let mut index = self.tag_index.write();
274 for tag in tags {
275 if let Some(keys) = index.get_mut(tag.value()) {
276 keys.remove(key);
277 if keys.is_empty() {
278 index.remove(tag.value());
279 }
280 }
281 }
282 }
283}
284
285impl CacheBackend for MemoryCache {
286 async fn get<T>(&self, key: &CacheKey) -> CacheResult<Option<T>>
287 where
288 T: serde::de::DeserializeOwned,
289 {
290 let key_str = key.as_str();
291
292 {
294 let entries = self.entries.read();
295 if let Some(entry) = entries.get(&key_str) {
296 if entry.is_expired() {
297 return Ok(None);
299 }
300
301 let value: T = serde_json::from_slice(&entry.data)
303 .map_err(|e| CacheError::Deserialization(e.to_string()))?;
304
305 return Ok(Some(value));
306 }
307 }
308
309 {
311 let mut entries = self.entries.write();
312 if let Some(entry) = entries.get_mut(&key_str) {
313 entry.touch();
314 }
315 }
316
317 Ok(None)
318 }
319
320 async fn set<T>(
321 &self,
322 key: &CacheKey,
323 value: &T,
324 ttl: Option<Duration>,
325 ) -> CacheResult<()>
326 where
327 T: serde::Serialize + Sync,
328 {
329 let key_str = key.as_str();
330
331 let data = serde_json::to_vec(value)
333 .map_err(|e| CacheError::Serialization(e.to_string()))?;
334
335 let effective_ttl = ttl.or(self.config.time_to_live);
336 let entry = CacheEntry::new(data, effective_ttl, Vec::new());
337
338 let current = self.entry_count.load(Ordering::Relaxed);
340 if current >= self.config.max_capacity as usize {
341 self.evict_expired();
343 let still_over = self.entry_count.load(Ordering::Relaxed);
344 if still_over >= self.config.max_capacity as usize {
345 self.evict_lru((self.config.max_capacity as usize / 10).max(1));
346 }
347 }
348
349 {
351 let mut entries = self.entries.write();
352 let is_new = !entries.contains_key(&key_str);
353 entries.insert(key_str.clone(), entry);
354 if is_new {
355 self.entry_count.fetch_add(1, Ordering::Relaxed);
356 }
357 }
358
359 Ok(())
360 }
361
362 async fn delete(&self, key: &CacheKey) -> CacheResult<bool> {
363 let key_str = key.as_str();
364
365 let mut entries = self.entries.write();
366 if let Some(entry) = entries.remove(&key_str) {
367 self.remove_from_tag_index(&key_str, &entry.tags);
368 self.entry_count.fetch_sub(1, Ordering::Relaxed);
369 Ok(true)
370 } else {
371 Ok(false)
372 }
373 }
374
375 async fn exists(&self, key: &CacheKey) -> CacheResult<bool> {
376 let key_str = key.as_str();
377
378 let entries = self.entries.read();
379 if let Some(entry) = entries.get(&key_str) {
380 Ok(!entry.is_expired())
381 } else {
382 Ok(false)
383 }
384 }
385
386 async fn invalidate_pattern(&self, pattern: &KeyPattern) -> CacheResult<u64> {
387 let mut entries = self.entries.write();
388 let before = entries.len();
389
390 let matching_keys: Vec<String> = entries
391 .keys()
392 .filter(|k| pattern.matches_str(k))
393 .cloned()
394 .collect();
395
396 for key in &matching_keys {
397 if let Some(entry) = entries.remove(key) {
398 self.remove_from_tag_index(key, &entry.tags);
399 }
400 }
401
402 let removed = before - entries.len();
403 self.entry_count.fetch_sub(removed, Ordering::Relaxed);
404 Ok(removed as u64)
405 }
406
407 async fn invalidate_tags(&self, tags: &[EntityTag]) -> CacheResult<u64> {
408 if !self.config.enable_tags {
409 return Ok(0);
410 }
411
412 let keys_to_remove: HashSet<String> = {
413 let index = self.tag_index.read();
414 tags.iter()
415 .filter_map(|tag| index.get(tag.value()))
416 .flatten()
417 .cloned()
418 .collect()
419 };
420
421 let mut entries = self.entries.write();
422 let mut removed = 0u64;
423
424 for key in keys_to_remove {
425 if let Some(entry) = entries.remove(&key) {
426 self.remove_from_tag_index(&key, &entry.tags);
427 removed += 1;
428 }
429 }
430
431 self.entry_count.fetch_sub(removed as usize, Ordering::Relaxed);
432 Ok(removed)
433 }
434
435 async fn clear(&self) -> CacheResult<()> {
436 let mut entries = self.entries.write();
437 entries.clear();
438 self.entry_count.store(0, Ordering::Relaxed);
439
440 if self.config.enable_tags {
441 let mut index = self.tag_index.write();
442 index.clear();
443 }
444
445 Ok(())
446 }
447
448 async fn len(&self) -> CacheResult<usize> {
449 Ok(self.entry_count.load(Ordering::Relaxed))
450 }
451
452 async fn stats(&self) -> CacheResult<BackendStats> {
453 let entries = self.entries.read();
454 let memory_estimate: usize = entries
455 .values()
456 .map(|e| e.data.len() + 64) .sum();
458
459 Ok(BackendStats {
460 entries: entries.len(),
461 memory_bytes: Some(memory_estimate),
462 connections: None,
463 info: Some(format!("MemoryCache (max: {})", self.config.max_capacity)),
464 })
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one value: store it, read it back, delete it, and confirm it is gone.
    #[tokio::test]
    async fn test_memory_cache_basic() {
        let config = MemoryCacheConfig::new(100);
        let cache = MemoryCache::new(config);
        let key = CacheKey::new("test", "key1");

        cache.set(&key, &"hello", None).await.unwrap();

        let stored: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(stored, Some("hello".to_string()));

        let deleted = cache.delete(&key).await.unwrap();
        assert!(deleted);

        let after_delete: Option<String> = cache.get(&key).await.unwrap();
        assert!(after_delete.is_none());
    }

    /// A value stored under a 50 ms default TTL is readable at once and gone after 60 ms.
    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let short_lived = MemoryCacheConfig::new(100).with_ttl(Duration::from_millis(50));
        let cache = MemoryCache::new(short_lived);
        let key = CacheKey::new("test", "ttl");

        cache.set(&key, &"expires soon", None).await.unwrap();

        let fresh: Option<String> = cache.get(&key).await.unwrap();
        assert!(fresh.is_some());

        tokio::time::sleep(Duration::from_millis(60)).await;

        let stale: Option<String> = cache.get(&key).await.unwrap();
        assert!(stale.is_none());
    }

    /// Writing twice as many keys as the capacity allows never leaves the cache over capacity.
    #[tokio::test]
    async fn test_memory_cache_eviction() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(5));

        for n in 0..10 {
            let key = CacheKey::new("test", format!("key{}", n));
            cache.set(&key, &n, None).await.unwrap();
        }

        assert!(cache.len().await.unwrap() <= 5);
    }

    /// Invalidating the `User` entity pattern removes the five user keys and keeps the posts.
    #[tokio::test]
    async fn test_memory_cache_pattern_invalidation() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));

        for user in 0..5 {
            let key = CacheKey::new("User", format!("id:{}", user));
            cache.set(&key, &user, None).await.unwrap();
        }
        for post in 0..3 {
            let key = CacheKey::new("Post", format!("id:{}", post));
            cache.set(&key, &post, None).await.unwrap();
        }

        assert_eq!(cache.len().await.unwrap(), 8);

        let user_pattern = KeyPattern::entity("User");
        let removed = cache.invalidate_pattern(&user_pattern).await.unwrap();

        assert_eq!(removed, 5);
        assert_eq!(cache.len().await.unwrap(), 3);
    }

    /// Settings chosen through the builder are visible on the built cache's configuration.
    #[tokio::test]
    async fn test_memory_cache_builder() {
        let one_minute = Duration::from_secs(60);
        let cache = MemoryCache::builder()
            .max_capacity(1000)
            .time_to_live(one_minute)
            .build();

        let config = cache.config();
        assert_eq!(config.max_capacity, 1000);
        assert_eq!(config.time_to_live, Some(Duration::from_secs(60)));
    }
}
564