1use parking_lot::RwLock;
31use std::collections::{HashMap, HashSet};
32use std::sync::atomic::{AtomicUsize, Ordering};
33use std::time::{Duration, Instant};
34
35use super::backend::{BackendStats, CacheBackend, CacheError, CacheResult};
36use super::invalidation::EntityTag;
37use super::key::{CacheKey, KeyPattern};
38
/// Tunable settings for the in-memory cache.
#[derive(Debug, Clone)]
pub struct MemoryCacheConfig {
    /// Entry count at which `set` starts evicting before it inserts.
    pub max_capacity: u64,
    /// Fallback TTL used when a `set` call does not pass its own; `None` means no default expiry.
    pub time_to_live: Option<Duration>,
    /// Idle timeout setting.
    /// NOTE(review): no code visible in this file reads this field — confirm where it is enforced.
    pub time_to_idle: Option<Duration>,
    /// Per-entry TTL switch.
    /// NOTE(review): no code visible in this file reads this field — confirm where it is enforced.
    pub per_entry_ttl: bool,
    /// When `false`, tag-index maintenance is skipped and `invalidate_tags` removes nothing.
    pub enable_tags: bool,
}

impl Default for MemoryCacheConfig {
    /// 10,000 entries, a 300-second TTL, no idle timeout, per-entry TTL and tags enabled.
    fn default() -> Self {
        let default_ttl = Duration::from_secs(300);

        Self {
            max_capacity: 10_000,
            time_to_live: Some(default_ttl),
            time_to_idle: None,
            per_entry_ttl: true,
            enable_tags: true,
        }
    }
}

impl MemoryCacheConfig {
    /// Starts from the defaults and overrides only the capacity.
    pub fn new(max_capacity: u64) -> Self {
        Self {
            max_capacity,
            ..Self::default()
        }
    }

    /// Returns the config with `time_to_live` set to `ttl`.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self {
            time_to_live: Some(ttl),
            ..self
        }
    }

    /// Returns the config with `time_to_idle` set to `tti`.
    pub fn with_tti(self, tti: Duration) -> Self {
        Self {
            time_to_idle: Some(tti),
            ..self
        }
    }

    /// Returns the config with tag support switched off.
    pub fn without_tags(self) -> Self {
        Self {
            enable_tags: false,
            ..self
        }
    }
}
93
94#[derive(Default)]
96pub struct MemoryCacheBuilder {
97 config: MemoryCacheConfig,
98}
99
100impl MemoryCacheBuilder {
101 pub fn new() -> Self {
103 Self::default()
104 }
105
106 pub fn max_capacity(mut self, capacity: u64) -> Self {
108 self.config.max_capacity = capacity;
109 self
110 }
111
112 pub fn time_to_live(mut self, ttl: Duration) -> Self {
114 self.config.time_to_live = Some(ttl);
115 self
116 }
117
118 pub fn time_to_idle(mut self, tti: Duration) -> Self {
120 self.config.time_to_idle = Some(tti);
121 self
122 }
123
124 pub fn per_entry_ttl(mut self, enabled: bool) -> Self {
126 self.config.per_entry_ttl = enabled;
127 self
128 }
129
130 pub fn enable_tags(mut self, enabled: bool) -> Self {
132 self.config.enable_tags = enabled;
133 self
134 }
135
136 pub fn build(self) -> MemoryCache {
138 MemoryCache::new(self.config)
139 }
140}
141
142#[derive(Clone)]
144struct CacheEntry {
145 data: Vec<u8>,
147 created_at: Instant,
149 expires_at: Option<Instant>,
151 last_accessed: Instant,
153 tags: Vec<EntityTag>,
155}
156
157impl CacheEntry {
158 fn new(data: Vec<u8>, ttl: Option<Duration>, tags: Vec<EntityTag>) -> Self {
159 let now = Instant::now();
160 Self {
161 data,
162 created_at: now,
163 expires_at: ttl.map(|d| now + d),
164 last_accessed: now,
165 tags,
166 }
167 }
168
169 fn is_expired(&self) -> bool {
170 self.expires_at.map_or(false, |exp| Instant::now() >= exp)
171 }
172
173 fn touch(&mut self) {
174 self.last_accessed = Instant::now();
175 }
176}
177
178pub struct MemoryCache {
182 config: MemoryCacheConfig,
183 entries: RwLock<HashMap<String, CacheEntry>>,
184 tag_index: RwLock<HashMap<String, HashSet<String>>>,
185 entry_count: AtomicUsize,
186}
187
188impl MemoryCache {
189 pub fn new(config: MemoryCacheConfig) -> Self {
191 Self {
192 entries: RwLock::new(HashMap::with_capacity(config.max_capacity as usize)),
193 tag_index: RwLock::new(HashMap::new()),
194 entry_count: AtomicUsize::new(0),
195 config,
196 }
197 }
198
199 pub fn builder() -> MemoryCacheBuilder {
201 MemoryCacheBuilder::new()
202 }
203
204 pub fn config(&self) -> &MemoryCacheConfig {
206 &self.config
207 }
208
209 pub fn evict_expired(&self) -> usize {
211 let mut entries = self.entries.write();
212 let before = entries.len();
213
214 let expired_keys: Vec<String> = entries
215 .iter()
216 .filter(|(_, e)| e.is_expired())
217 .map(|(k, _)| k.clone())
218 .collect();
219
220 for key in &expired_keys {
221 if let Some(entry) = entries.remove(key) {
222 self.remove_from_tag_index(key, &entry.tags);
223 }
224 }
225
226 let evicted = before - entries.len();
227 self.entry_count.fetch_sub(evicted, Ordering::Relaxed);
228 evicted
229 }
230
231 fn evict_lru(&self, count: usize) {
233 let mut entries = self.entries.write();
234
235 let mut by_access: Vec<_> = entries
237 .iter()
238 .map(|(k, e)| (k.clone(), e.last_accessed))
239 .collect();
240 by_access.sort_by_key(|(_, t)| *t);
241
242 for (key, _) in by_access.into_iter().take(count) {
243 if let Some(entry) = entries.remove(&key) {
244 self.remove_from_tag_index(&key, &entry.tags);
245 }
246 }
247
248 self.entry_count.store(entries.len(), Ordering::Relaxed);
249 }
250
251 fn add_to_tag_index(&self, key: &str, tags: &[EntityTag]) {
253 if !self.config.enable_tags || tags.is_empty() {
254 return;
255 }
256
257 let mut index = self.tag_index.write();
258 for tag in tags {
259 index
260 .entry(tag.value().to_string())
261 .or_default()
262 .insert(key.to_string());
263 }
264 }
265
266 fn remove_from_tag_index(&self, key: &str, tags: &[EntityTag]) {
268 if !self.config.enable_tags || tags.is_empty() {
269 return;
270 }
271
272 let mut index = self.tag_index.write();
273 for tag in tags {
274 if let Some(keys) = index.get_mut(tag.value()) {
275 keys.remove(key);
276 if keys.is_empty() {
277 index.remove(tag.value());
278 }
279 }
280 }
281 }
282}
283
284impl CacheBackend for MemoryCache {
285 async fn get<T>(&self, key: &CacheKey) -> CacheResult<Option<T>>
286 where
287 T: serde::de::DeserializeOwned,
288 {
289 let key_str = key.as_str();
290
291 {
293 let entries = self.entries.read();
294 if let Some(entry) = entries.get(&key_str) {
295 if entry.is_expired() {
296 return Ok(None);
298 }
299
300 let value: T = serde_json::from_slice(&entry.data)
302 .map_err(|e| CacheError::Deserialization(e.to_string()))?;
303
304 return Ok(Some(value));
305 }
306 }
307
308 {
310 let mut entries = self.entries.write();
311 if let Some(entry) = entries.get_mut(&key_str) {
312 entry.touch();
313 }
314 }
315
316 Ok(None)
317 }
318
319 async fn set<T>(&self, key: &CacheKey, value: &T, ttl: Option<Duration>) -> CacheResult<()>
320 where
321 T: serde::Serialize + Sync,
322 {
323 let key_str = key.as_str();
324
325 let data =
327 serde_json::to_vec(value).map_err(|e| CacheError::Serialization(e.to_string()))?;
328
329 let effective_ttl = ttl.or(self.config.time_to_live);
330 let entry = CacheEntry::new(data, effective_ttl, Vec::new());
331
332 let current = self.entry_count.load(Ordering::Relaxed);
334 if current >= self.config.max_capacity as usize {
335 self.evict_expired();
337 let still_over = self.entry_count.load(Ordering::Relaxed);
338 if still_over >= self.config.max_capacity as usize {
339 self.evict_lru((self.config.max_capacity as usize / 10).max(1));
340 }
341 }
342
343 {
345 let mut entries = self.entries.write();
346 let is_new = !entries.contains_key(&key_str);
347 entries.insert(key_str.clone(), entry);
348 if is_new {
349 self.entry_count.fetch_add(1, Ordering::Relaxed);
350 }
351 }
352
353 Ok(())
354 }
355
356 async fn delete(&self, key: &CacheKey) -> CacheResult<bool> {
357 let key_str = key.as_str();
358
359 let mut entries = self.entries.write();
360 if let Some(entry) = entries.remove(&key_str) {
361 self.remove_from_tag_index(&key_str, &entry.tags);
362 self.entry_count.fetch_sub(1, Ordering::Relaxed);
363 Ok(true)
364 } else {
365 Ok(false)
366 }
367 }
368
369 async fn exists(&self, key: &CacheKey) -> CacheResult<bool> {
370 let key_str = key.as_str();
371
372 let entries = self.entries.read();
373 if let Some(entry) = entries.get(&key_str) {
374 Ok(!entry.is_expired())
375 } else {
376 Ok(false)
377 }
378 }
379
380 async fn invalidate_pattern(&self, pattern: &KeyPattern) -> CacheResult<u64> {
381 let mut entries = self.entries.write();
382 let before = entries.len();
383
384 let matching_keys: Vec<String> = entries
385 .keys()
386 .filter(|k| pattern.matches_str(k))
387 .cloned()
388 .collect();
389
390 for key in &matching_keys {
391 if let Some(entry) = entries.remove(key) {
392 self.remove_from_tag_index(key, &entry.tags);
393 }
394 }
395
396 let removed = before - entries.len();
397 self.entry_count.fetch_sub(removed, Ordering::Relaxed);
398 Ok(removed as u64)
399 }
400
401 async fn invalidate_tags(&self, tags: &[EntityTag]) -> CacheResult<u64> {
402 if !self.config.enable_tags {
403 return Ok(0);
404 }
405
406 let keys_to_remove: HashSet<String> = {
407 let index = self.tag_index.read();
408 tags.iter()
409 .filter_map(|tag| index.get(tag.value()))
410 .flatten()
411 .cloned()
412 .collect()
413 };
414
415 let mut entries = self.entries.write();
416 let mut removed = 0u64;
417
418 for key in keys_to_remove {
419 if let Some(entry) = entries.remove(&key) {
420 self.remove_from_tag_index(&key, &entry.tags);
421 removed += 1;
422 }
423 }
424
425 self.entry_count
426 .fetch_sub(removed as usize, Ordering::Relaxed);
427 Ok(removed)
428 }
429
430 async fn clear(&self) -> CacheResult<()> {
431 let mut entries = self.entries.write();
432 entries.clear();
433 self.entry_count.store(0, Ordering::Relaxed);
434
435 if self.config.enable_tags {
436 let mut index = self.tag_index.write();
437 index.clear();
438 }
439
440 Ok(())
441 }
442
443 async fn len(&self) -> CacheResult<usize> {
444 Ok(self.entry_count.load(Ordering::Relaxed))
445 }
446
447 async fn stats(&self) -> CacheResult<BackendStats> {
448 let entries = self.entries.read();
449 let memory_estimate: usize = entries
450 .values()
451 .map(|e| e.data.len() + 64) .sum();
453
454 Ok(BackendStats {
455 entries: entries.len(),
456 memory_bytes: Some(memory_estimate),
457 connections: None,
458 info: Some(format!("MemoryCache (max: {})", self.config.max_capacity)),
459 })
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// A value written with `set` is returned by `get` and is gone after `delete`.
    #[tokio::test]
    async fn test_memory_cache_basic() {
        let config = MemoryCacheConfig::new(100);
        let cache = MemoryCache::new(config);
        let key = CacheKey::new("test", "key1");

        cache.set(&key, &"hello", None).await.unwrap();

        let stored: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(stored.as_deref(), Some("hello"));

        let was_present = cache.delete(&key).await.unwrap();
        assert!(was_present);

        let after_delete: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(after_delete, None);
    }

    /// An entry stored under a 50 ms default TTL is readable at once and gone after 60 ms.
    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let short_ttl = Duration::from_millis(50);
        let cache = MemoryCache::new(MemoryCacheConfig::new(100).with_ttl(short_ttl));
        let key = CacheKey::new("test", "ttl");

        cache.set(&key, &"expires soon", None).await.unwrap();

        let fresh: Option<String> = cache.get(&key).await.unwrap();
        assert!(fresh.is_some());

        tokio::time::sleep(Duration::from_millis(60)).await;

        let stale: Option<String> = cache.get(&key).await.unwrap();
        assert!(stale.is_none());
    }

    /// Writing ten keys into a five-entry cache never leaves more than five entries.
    #[tokio::test]
    async fn test_memory_cache_eviction() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(5));

        for i in 0..10 {
            let key = CacheKey::new("test", format!("key{}", i));
            cache.set(&key, &i, None).await.unwrap();
        }

        let remaining = cache.len().await.unwrap();
        assert!(remaining <= 5);
    }

    /// Invalidating the `User` entity pattern removes the five user keys and keeps the three posts.
    #[tokio::test]
    async fn test_memory_cache_pattern_invalidation() {
        let cache = MemoryCache::new(MemoryCacheConfig::new(100));

        for (entity, how_many) in [("User", 5), ("Post", 3)] {
            for i in 0..how_many {
                let key = CacheKey::new(entity, format!("id:{}", i));
                cache.set(&key, &i, None).await.unwrap();
            }
        }

        assert_eq!(cache.len().await.unwrap(), 8);

        let user_pattern = KeyPattern::entity("User");
        let removed = cache.invalidate_pattern(&user_pattern).await.unwrap();

        assert_eq!(removed, 5);
        assert_eq!(cache.len().await.unwrap(), 3);
    }

    /// Builder setters end up in the configuration of the built cache.
    #[tokio::test]
    async fn test_memory_cache_builder() {
        let one_minute = Duration::from_secs(60);
        let cache = MemoryCache::builder()
            .max_capacity(1000)
            .time_to_live(one_minute)
            .build();

        let config = cache.config();
        assert_eq!(config.max_capacity, 1000);
        assert_eq!(config.time_to_live, Some(one_minute));
    }
}