1extern crate alloc;
123
124use crate::lfuda::LfudaSegment;
125use crate::metrics::CacheMetrics;
126use alloc::boxed::Box;
127use alloc::collections::BTreeMap;
128use alloc::string::String;
129use alloc::vec::Vec;
130use core::borrow::Borrow;
131use core::hash::{BuildHasher, Hash};
132use core::num::NonZeroUsize;
133use parking_lot::Mutex;
134
135#[cfg(feature = "hashbrown")]
136use hashbrown::DefaultHashBuilder;
137
138#[cfg(not(feature = "hashbrown"))]
139use std::collections::hash_map::RandomState as DefaultHashBuilder;
140
/// A thread-safe LFUDA cache split into independently locked segments.
///
/// Each key is routed to exactly one segment by hashing it with
/// `hash_builder`; every segment is an [`LfudaSegment`] behind its own
/// [`Mutex`], so operations on keys that live in different segments do not
/// contend for the same lock.
pub struct ConcurrentLfudaCache<K, V, S = DefaultHashBuilder> {
    /// The segments. A key's segment is its hash modulo this slice's length.
    segments: Box<[Mutex<LfudaSegment<K, V, S>>]>,
    /// Hasher used to choose the segment for a key.
    hash_builder: S,
}
146
147impl<K, V> ConcurrentLfudaCache<K, V, DefaultHashBuilder>
148where
149 K: Hash + Eq + Clone + Send,
150 V: Clone + Send,
151{
152 pub fn init(
161 config: crate::config::ConcurrentLfudaCacheConfig,
162 hasher: Option<DefaultHashBuilder>,
163 ) -> Self {
164 let segment_count = config.segments;
165 let capacity = config.base.capacity;
166 let max_size = config.base.max_size;
167 let initial_age = config.base.initial_age;
168
169 let hash_builder = hasher.unwrap_or_default();
170
171 let segment_capacity = capacity.get() / segment_count;
172 let segment_cap = NonZeroUsize::new(segment_capacity.max(1)).unwrap();
173 let segment_max_size = max_size / segment_count as u64;
174
175 let segments: Vec<_> = (0..segment_count)
176 .map(|_| {
177 let segment_config = crate::config::LfudaCacheConfig {
178 capacity: segment_cap,
179 initial_age,
180 max_size: segment_max_size,
181 };
182 Mutex::new(LfudaSegment::init(segment_config, hash_builder.clone()))
183 })
184 .collect();
185
186 Self {
187 segments: segments.into_boxed_slice(),
188 hash_builder,
189 }
190 }
191}
192
193impl<K, V, S> ConcurrentLfudaCache<K, V, S>
194where
195 K: Hash + Eq + Clone + Send,
196 V: Clone + Send,
197 S: BuildHasher + Clone + Send,
198{
199 #[inline]
200 fn segment_index<Q>(&self, key: &Q) -> usize
201 where
202 K: Borrow<Q>,
203 Q: ?Sized + Hash,
204 {
205 (self.hash_builder.hash_one(key) as usize) % self.segments.len()
206 }
207
208 pub fn capacity(&self) -> usize {
210 let mut total = 0usize;
211 for segment in self.segments.iter() {
212 total += segment.lock().cap().get();
213 }
214 total
215 }
216
217 pub fn segment_count(&self) -> usize {
219 self.segments.len()
220 }
221
222 pub fn len(&self) -> usize {
224 let mut total = 0usize;
225 for segment in self.segments.iter() {
226 total += segment.lock().len();
227 }
228 total
229 }
230
231 pub fn is_empty(&self) -> bool {
233 for segment in self.segments.iter() {
234 if !segment.lock().is_empty() {
235 return false;
236 }
237 }
238 true
239 }
240
241 pub fn get<Q>(&self, key: &Q) -> Option<V>
246 where
247 K: Borrow<Q>,
248 Q: ?Sized + Hash + Eq,
249 {
250 let idx = self.segment_index(key);
251 let mut segment = self.segments[idx].lock();
252 segment.get(key).cloned()
253 }
254
255 pub fn get_with<Q, F, R>(&self, key: &Q, f: F) -> Option<R>
260 where
261 K: Borrow<Q>,
262 Q: ?Sized + Hash + Eq,
263 F: FnOnce(&V) -> R,
264 {
265 let idx = self.segment_index(key);
266 let mut segment = self.segments[idx].lock();
267 segment.get(key).map(f)
268 }
269
270 pub fn put(&self, key: K, value: V) -> Option<(K, V)> {
274 let idx = self.segment_index(&key);
275 let mut segment = self.segments[idx].lock();
276 segment.put(key, value)
277 }
278
279 pub fn put_with_size(&self, key: K, value: V, size: u64) -> Option<(K, V)> {
281 let idx = self.segment_index(&key);
282 let mut segment = self.segments[idx].lock();
283 segment.put_with_size(key, value, size)
284 }
285
286 pub fn remove<Q>(&self, key: &Q) -> Option<V>
288 where
289 K: Borrow<Q>,
290 Q: ?Sized + Hash + Eq,
291 {
292 let idx = self.segment_index(key);
293 let mut segment = self.segments[idx].lock();
294 segment.remove(key)
295 }
296
297 pub fn contains_key<Q>(&self, key: &Q) -> bool
299 where
300 K: Borrow<Q>,
301 Q: ?Sized + Hash + Eq,
302 {
303 let idx = self.segment_index(key);
304 let mut segment = self.segments[idx].lock();
305 segment.get(key).is_some()
306 }
307
308 pub fn clear(&self) {
310 for segment in self.segments.iter() {
311 segment.lock().clear();
312 }
313 }
314
315 pub fn current_size(&self) -> u64 {
317 self.segments.iter().map(|s| s.lock().current_size()).sum()
318 }
319
320 pub fn max_size(&self) -> u64 {
322 self.segments.iter().map(|s| s.lock().max_size()).sum()
323 }
324}
325
326impl<K, V, S> CacheMetrics for ConcurrentLfudaCache<K, V, S>
327where
328 K: Hash + Eq + Clone + Send,
329 V: Clone + Send,
330 S: BuildHasher + Clone + Send,
331{
332 fn metrics(&self) -> BTreeMap<String, f64> {
333 let mut aggregated = BTreeMap::new();
334 for segment in self.segments.iter() {
335 let segment_metrics = segment.lock().metrics().metrics();
336 for (key, value) in segment_metrics {
337 *aggregated.entry(key).or_insert(0.0) += value;
338 }
339 }
340 aggregated
341 }
342
343 fn algorithm_name(&self) -> &'static str {
344 "ConcurrentLFUDA"
345 }
346}
347
// SAFETY (NOTE(review): confirm against `LfudaSegment`'s internals):
// These auto-trait impls are written by hand, which suggests they are not
// derived automatically — presumably because `LfudaSegment` stores raw
// pointers or similar. Sending the cache to another thread is sound only if
// each segment exclusively owns its keys, values and hasher clone (no shared,
// non-atomic aliasing such as `Rc`), which is why `K`, `V` and `S` must be `Send`.
unsafe impl<K: Send, V: Send, S: Send> Send for ConcurrentLfudaCache<K, V, S> {}
// SAFETY: every `&self` method reaches segment state only through
// `Mutex::lock`, so at most one thread touches a given segment at a time; the
// only field read without a lock is `hash_builder` (via `&S` for segment
// selection), hence the extra `S: Sync` bound. This also relies on the same
// exclusive-ownership assumption about `LfudaSegment` as the `Send` impl above.
unsafe impl<K: Send, V: Send, S: Send + Sync> Sync for ConcurrentLfudaCache<K, V, S> {}
350
351impl<K, V, S> core::fmt::Debug for ConcurrentLfudaCache<K, V, S>
352where
353 K: Hash + Eq + Clone + Send,
354 V: Clone + Send,
355 S: BuildHasher + Clone + Send,
356{
357 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
358 f.debug_struct("ConcurrentLfudaCache")
359 .field("segment_count", &self.segments.len())
360 .field("total_len", &self.len())
361 .finish()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ConcurrentCacheConfig, ConcurrentLfudaCacheConfig, LfudaCacheConfig};

    extern crate std;
    use std::string::ToString;
    use std::sync::Arc;
    use std::thread;
    use std::vec::Vec;

    /// Builds a config with an unbounded size limit and zero initial age.
    fn make_config(capacity: usize, segments: usize) -> ConcurrentLfudaCacheConfig {
        ConcurrentCacheConfig {
            base: LfudaCacheConfig {
                capacity: NonZeroUsize::new(capacity).unwrap(),
                initial_age: 0,
                max_size: u64::MAX,
            },
            segments,
        }
    }

    #[test]
    fn test_basic_operations() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);

        assert_eq!(cache.get(&"a".to_string()), Some(1));
        assert_eq!(cache.get(&"b".to_string()), Some(2));
    }

    #[test]
    fn test_concurrent_access() {
        let cache: Arc<ConcurrentLfudaCache<String, i32>> =
            Arc::new(ConcurrentLfudaCache::init(make_config(1000, 16), None));
        let num_threads = 8;
        let ops_per_thread = 500;

        let mut handles: Vec<std::thread::JoinHandle<()>> = Vec::new();

        for t in 0..num_threads {
            let cache = Arc::clone(&cache);
            handles.push(thread::spawn(move || {
                for i in 0..ops_per_thread {
                    let key = std::format!("key_{}_{}", t, i);
                    cache.put(key.clone(), i);
                    let _ = cache.get(&key);
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!cache.is_empty());
        // Segment capacities must be respected even under contention.
        assert!(cache.len() <= cache.capacity());
    }

    #[test]
    fn test_capacity() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        // Each of the 16 segments holds at least one entry; the total never
        // exceeds the configured capacity.
        let capacity = cache.capacity();
        assert!(capacity >= 16);
        assert!(capacity <= 100);
    }

    #[test]
    #[should_panic]
    fn test_zero_segments_panics() {
        let _cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 0), None);
    }

    #[test]
    fn test_segment_count() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 8), None);

        assert_eq!(cache.segment_count(), 8);
    }

    #[test]
    fn test_len_and_is_empty() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);

        cache.put("key1".to_string(), 1);
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        cache.put("key2".to_string(), 2);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_remove() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("key1".to_string(), 1);
        cache.put("key2".to_string(), 2);

        assert_eq!(cache.remove(&"key1".to_string()), Some(1));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&"key1".to_string()), None);

        assert_eq!(cache.remove(&"nonexistent".to_string()), None);
    }

    #[test]
    fn test_clear() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("key1".to_string(), 1);
        cache.put("key2".to_string(), 2);
        cache.put("key3".to_string(), 3);

        assert_eq!(cache.len(), 3);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.get(&"key1".to_string()), None);
    }

    #[test]
    fn test_contains_key() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("exists".to_string(), 1);

        assert!(cache.contains_key(&"exists".to_string()));
        assert!(!cache.contains_key(&"missing".to_string()));
    }

    #[test]
    fn test_get_with() {
        let cache: ConcurrentLfudaCache<String, String> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("key".to_string(), "hello world".to_string());

        let len = cache.get_with(&"key".to_string(), |v: &String| v.len());
        assert_eq!(len, Some(11));

        let missing = cache.get_with(&"missing".to_string(), |v: &String| v.len());
        assert_eq!(missing, None);
    }

    #[test]
    fn test_aging_behavior() {
        // A single full segment makes the eviction decision independent of
        // how keys hash across segments.
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(3, 1), None);

        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);
        cache.put("c".to_string(), 3);

        // Raise the priority of "a" and "c", leaving "b" as the least used.
        for _ in 0..5 {
            let _ = cache.get("a");
            let _ = cache.get("c");
        }

        cache.put("d".to_string(), 4);

        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get("b"), None);
        assert_eq!(cache.get("a"), Some(1));
        assert_eq!(cache.get("c"), Some(3));
        assert_eq!(cache.get("d"), Some(4));
    }

    #[test]
    fn test_eviction_on_capacity() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(80, 16), None);

        // Insert far more keys than the cache can hold so every segment must evict.
        for i in 0..1000 {
            cache.put(std::format!("key{}", i), i);
        }

        assert!(!cache.is_empty());
        assert!(cache.len() <= cache.capacity());
        assert!(cache.len() <= 80);
    }

    #[test]
    fn test_metrics() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);

        let metrics = cache.metrics();
        assert!(!metrics.is_empty());
    }

    #[test]
    fn test_algorithm_name() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        assert_eq!(cache.algorithm_name(), "ConcurrentLFUDA");
    }

    #[test]
    fn test_empty_cache_operations() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.get(&"missing".to_string()), None);
        assert_eq!(cache.remove(&"missing".to_string()), None);
        assert!(!cache.contains_key(&"missing".to_string()));
    }

    #[test]
    fn test_borrowed_key_lookup() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("test_key".to_string(), 42);

        // `String: Borrow<str>` lets lookups use a plain `&str`.
        let key_str = "test_key";
        assert_eq!(cache.get(key_str), Some(42));
        assert!(cache.contains_key(key_str));
        assert_eq!(cache.remove(key_str), Some(42));
    }

    #[test]
    fn test_frequency_with_aging() {
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(100, 16), None);

        cache.put("key".to_string(), 1);

        for _ in 0..10 {
            let _ = cache.get(&"key".to_string());
        }

        assert_eq!(cache.get(&"key".to_string()), Some(1));
    }

    #[test]
    fn test_dynamic_aging() {
        // One segment of capacity 5, so the second batch forces evictions.
        let cache: ConcurrentLfudaCache<String, i32> =
            ConcurrentLfudaCache::init(make_config(5, 1), None);

        // key{i} is read i extra times, so "key4" ends with the highest priority.
        for i in 0..5 {
            cache.put(std::format!("key{}", i), i);
            for _ in 0..i {
                let _ = cache.get(&std::format!("key{}", i));
            }
        }

        // Five never-read keys: each insertion must evict something.
        for i in 5..10 {
            cache.put(std::format!("key{}", i), i);
        }

        assert_eq!(cache.len(), 5);
        // Five cheap insertions cannot age the cache past the hottest entry.
        assert_eq!(cache.get("key4"), Some(4));
    }
}