1use crossbeam_utils::CachePadded;
30use crossbeam_utils::sync::{ShardedLock, ShardedLockReadGuard, ShardedLockWriteGuard};
31use std::collections::HashMap;
32use std::hash::{BuildHasher, Hash, RandomState};
33use std::sync::atomic::{AtomicBool, Ordering};
34use std::sync::{Arc, PoisonError};
35
/// A concurrent, weight-aware cache using the SIEVE eviction algorithm.
///
/// The total capacity is split across a power-of-two number of independently
/// locked shards; a key's shard is chosen by hashing it with `hash_builder`.
pub struct SieveCache<K, V, S = RandomState> {
    // One SIEVE shard per lock; length is always a power of two so shard
    // selection can use a bit mask.
    shards: Vec<ShardedLock<Shard<K, V, S>>>,
    // Hasher used for shard selection (each shard's map gets a clone).
    hash_builder: S,
    // Sum of all shard capacities, as passed to the constructor.
    total_capacity: usize,
}
42
impl<K, V, S> SieveCache<K, V, S> {
    /// Default shard count used by [`SieveCache::new`]; a power of two, as
    /// required by the constructor's shard-index masking.
    pub const DEFAULT_SHARDS: usize = 256;
}
47
/// A single cache entry stored inside a shard's `nodes` vector.
struct Node<K, V> {
    key: K,
    value: Arc<V>,
    // SIEVE "visited" bit. Atomic so `get` can set it under a read lock;
    // `CachePadded` to reduce false sharing between concurrently touched
    // entries.
    visited: CachePadded<AtomicBool>,
    // Caller-supplied weight counted against the shard's capacity.
    charge: usize,
}
54
/// One independently locked partition of the cache.
struct Shard<K, V, S> {
    // key -> index into `nodes`. Indices are unstable: removals use
    // swap_remove, so this map is patched on every removal.
    map: HashMap<K, usize, S>,
    nodes: Vec<Node<K, V>>,
    // SIEVE eviction hand: index into `nodes` where the next eviction sweep
    // starts; `None` means start from the newest entry.
    hand: Option<usize>,
    // Sum of the charges of all live entries.
    used_charge: usize,
    // Charge budget for this shard.
    capacity: usize,
}
72
73impl<K, V, S> Shard<K, V, S>
74where
75 K: Eq + Hash + Clone,
76 V: Send + Sync + 'static,
77 S: BuildHasher + Clone,
78{
79 fn new(hash_builder: S, capacity: usize) -> Self {
80 Self {
81 map: HashMap::with_hasher(hash_builder),
82 nodes: Vec::new(),
83 hand: None,
84 used_charge: 0,
85 capacity,
86 }
87 }
88
89 fn insert(&mut self, key: K, value: V, charge: usize) -> Option<Arc<V>> {
90 if let Some(&idx) = self.map.get(&key) {
91 let (prev, old_charge) = {
94 let node = &mut self.nodes[idx];
95 let prev = node.value.clone();
96 let old_charge = node.charge;
97 node.value = Arc::new(value);
98 node.charge = charge;
99 node.visited.store(true, Ordering::Relaxed);
100 (prev, old_charge)
101 };
102 self.used_charge = self.used_charge.saturating_sub(old_charge);
103 self.used_charge = self.used_charge.saturating_add(charge);
104 self.evict_until_within_capacity();
105 return Some(prev);
106 }
107
108 let idx = self.nodes.len();
109 self.nodes.push(Node {
110 key: key.clone(),
111 value: Arc::new(value),
112 visited: CachePadded::new(AtomicBool::new(false)),
114 charge,
115 });
116 self.map.insert(key, idx);
117 self.used_charge = self.used_charge.saturating_add(charge);
118 self.evict_until_within_capacity();
119 None
120 }
121
122 fn get(&self, key: &K) -> Option<Arc<V>> {
123 let idx = self.map.get(key).copied()?;
124 let node = &self.nodes[idx];
125 node.visited.store(true, Ordering::Relaxed);
128 Some(node.value.clone())
129 }
130
131 fn remove(&mut self, key: &K) -> Option<Arc<V>> {
132 let idx = self.map.get(key).copied()?;
133 let removed = self.remove_at_index(idx);
134 Some(removed.value)
135 }
136
137 fn remove_if<F>(&mut self, predicate: F) -> usize
138 where
139 F: Fn(&K) -> bool,
140 {
141 let keys: Vec<K> = self
142 .map
143 .keys()
144 .filter(|key| predicate(key))
145 .cloned()
146 .collect();
147 let mut removed = 0;
148 for key in keys {
149 if let Some(idx) = self.map.get(&key).copied() {
150 let _ = self.remove_at_index(idx);
151 removed += 1;
152 }
153 }
154 removed
155 }
156
157 fn contains_key(&self, key: &K) -> bool {
158 self.map.contains_key(key)
159 }
160
161 fn len(&self) -> usize {
162 self.map.len()
163 }
164
165 fn usage(&self) -> (usize, usize) {
166 (self.used_charge, self.capacity)
167 }
168
169 fn evict_until_within_capacity(&mut self) {
170 while self.used_charge > self.capacity {
171 if !self.evict_one() && !self.evict_one() {
174 break;
175 }
176 }
177 }
178
179 fn evict_one(&mut self) -> bool {
180 if self.nodes.is_empty() {
181 return false;
182 }
183
184 let mut current_idx = self.hand.unwrap_or(self.nodes.len() - 1);
188 let start_idx = current_idx;
189 let mut wrapped = false;
190 let mut found_idx = None;
191
192 loop {
193 if !self.nodes[current_idx]
194 .visited
195 .swap(false, Ordering::Relaxed)
196 {
197 found_idx = Some(current_idx);
199 break;
200 }
201
202 current_idx = if current_idx > 0 {
204 current_idx - 1
205 } else {
206 if wrapped {
207 break;
208 }
209 wrapped = true;
210 self.nodes.len() - 1
211 };
212
213 if current_idx == start_idx {
214 break;
215 }
216 }
217
218 if let Some(idx) = found_idx {
219 let _ = self.remove_at_index(idx);
220 true
221 } else {
222 false
223 }
224 }
225
226 fn remove_at_index(&mut self, idx: usize) -> Node<K, V> {
227 let last_idx = self.nodes.len() - 1;
228
229 if let Some(hand_idx) = self.hand {
231 if hand_idx == idx {
232 self.hand = if idx > 0 {
233 Some(idx - 1)
234 } else if self.nodes.len() > 1 {
235 Some(self.nodes.len() - 2)
236 } else {
237 None
238 };
239 } else if hand_idx == last_idx && idx != last_idx {
240 self.hand = Some(idx);
241 }
242 }
243
244 let removed = if idx == last_idx {
246 self.nodes.pop().expect("index must exist")
247 } else {
248 let moved_key = self.nodes[last_idx].key.clone();
249 let removed = self.nodes.swap_remove(idx);
250 self.map.insert(moved_key, idx);
251 removed
252 };
253
254 self.map.remove(&removed.key);
255 self.used_charge = self.used_charge.saturating_sub(removed.charge);
256
257 if self.nodes.is_empty() {
258 self.hand = None;
259 }
260
261 removed
262 }
263}
264
265impl<K, V> SieveCache<K, V, RandomState>
266where
267 K: Eq + Hash + Clone,
268 V: Send + Sync + 'static,
269{
270 pub fn new(total_capacity_bytes: usize) -> Self {
275 Self::with_hasher(
276 total_capacity_bytes,
277 SieveCache::<K, V>::DEFAULT_SHARDS,
278 RandomState::new(),
279 )
280 }
281
282 pub fn with_shards(total_capacity_bytes: usize, num_shards: usize) -> Self {
288 Self::with_hasher(total_capacity_bytes, num_shards, RandomState::new())
289 }
290}
291
impl<K, V, S> SieveCache<K, V, S>
where
    K: Eq + Hash + Clone,
    V: Send + Sync + 'static,
    S: BuildHasher + Clone,
{
    /// Builds a cache with an explicit shard count and hasher.
    ///
    /// The total capacity is divided evenly across shards; the first
    /// `total_capacity_bytes % num_shards` shards each receive one extra
    /// unit, so the shard capacities always sum to `total_capacity_bytes`.
    ///
    /// # Panics
    /// Panics if `num_shards` is zero or not a power of two (the power-of-two
    /// requirement lets `shard_index` mask instead of using `%`).
    pub fn with_hasher(total_capacity_bytes: usize, num_shards: usize, hash_builder: S) -> Self {
        assert!(num_shards > 0, "num_shards must be > 0");
        assert!(
            num_shards.is_power_of_two(),
            "num_shards must be a power of two"
        );

        let base_capacity = total_capacity_bytes / num_shards;
        let remainder = total_capacity_bytes % num_shards;
        let shards = (0..num_shards)
            .map(|idx| {
                // Low-index shards absorb the remainder deterministically.
                let shard_capacity = base_capacity + usize::from(idx < remainder);
                ShardedLock::new(Shard::new(hash_builder.clone(), shard_capacity))
            })
            .collect();

        Self {
            shards,
            hash_builder,
            total_capacity: total_capacity_bytes,
        }
    }

    /// Inserts `key`/`value` with weight `charge` into the owning shard,
    /// returning the previous value if `key` was already present.
    pub fn insert(&self, key: K, value: V, charge: usize) -> Option<Arc<V>> {
        let shard_idx = self.shard_index(&key);
        let mut shard = self.lock_shard_write(shard_idx);
        shard.insert(key, value, charge)
    }

    /// Looks up `key`, marking the entry visited. Takes only a read lock:
    /// the visited bit is updated atomically inside the shard.
    pub fn get(&self, key: &K) -> Option<Arc<V>> {
        let shard_idx = self.shard_index(key);
        let shard = self.lock_shard_read(shard_idx);
        shard.get(key)
    }

    /// Removes `key`, returning its value if present.
    pub fn remove(&self, key: &K) -> Option<Arc<V>> {
        let shard_idx = self.shard_index(key);
        let mut shard = self.lock_shard_write(shard_idx);
        shard.remove(key)
    }

    /// Removes every entry whose key matches `predicate`; returns the total
    /// count. Shards are locked one at a time, so this is not an atomic
    /// snapshot across the whole cache.
    pub fn remove_if<F>(&self, predicate: F) -> usize
    where
        F: Fn(&K) -> bool,
    {
        let mut removed = 0;
        for shard_mutex in &self.shards {
            let mut shard = self.lock_shard_write_ref(shard_mutex);
            removed += shard.remove_if(&predicate);
        }
        removed
    }

    pub fn contains_key(&self, key: &K) -> bool {
        let shard_idx = self.shard_index(key);
        self.lock_shard_read(shard_idx).contains_key(key)
    }

    /// Total number of entries. Shards are read-locked one at a time, so the
    /// result is approximate under concurrent mutation.
    pub fn len(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| self.lock_shard_read_ref(shard).len())
            .sum()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Sum of used charge across all shards (same per-shard snapshot caveat
    /// as `len`).
    pub fn total_charge(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| self.lock_shard_read_ref(shard).used_charge)
            .sum()
    }

    /// Capacity passed to the constructor.
    pub fn total_capacity(&self) -> usize {
        self.total_capacity
    }

    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }

    /// Returns `(used_charge, capacity)` for shard `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.shard_count()`.
    pub fn shard_usage(&self, idx: usize) -> (usize, usize) {
        let shard = self.lock_shard_read(idx);
        shard.usage()
    }

    // Maps a key to its shard. Relies on the constructor's guarantee that
    // the shard count is a power of two, so masking equals modulo.
    fn shard_index(&self, key: &K) -> usize {
        (self.hash_builder.hash_one(key) as usize) & (self.shards.len() - 1)
    }

    fn lock_shard_write(&self, idx: usize) -> ShardedLockWriteGuard<'_, Shard<K, V, S>> {
        self.lock_shard_write_ref(&self.shards[idx])
    }

    fn lock_shard_read(&self, idx: usize) -> ShardedLockReadGuard<'_, Shard<K, V, S>> {
        self.lock_shard_read_ref(&self.shards[idx])
    }

    // Write-locks a shard, recovering the guard from a poisoned lock.
    // NOTE(review): poison recovery assumes no shard method panics midway
    // through a multi-field update — confirm if shard code ever panics.
    fn lock_shard_write_ref<'a>(
        &self,
        shard: &'a ShardedLock<Shard<K, V, S>>,
    ) -> ShardedLockWriteGuard<'a, Shard<K, V, S>> {
        shard.write().unwrap_or_else(PoisonError::into_inner)
    }

    // Read-locks a shard, recovering the guard from a poisoned lock.
    fn lock_shard_read_ref<'a>(
        &self,
        shard: &'a ShardedLock<Shard<K, V, S>>,
    ) -> ShardedLockReadGuard<'a, Shard<K, V, S>> {
        shard.read().unwrap_or_else(PoisonError::into_inner)
    }

    // Test-only: asserts the per-shard structural invariants — map and node
    // storage agree, charge accounting matches live nodes, and non-empty
    // shards are within budget.
    #[cfg(test)]
    fn validate_invariants(&self) {
        for shard_mutex in &self.shards {
            let shard = self.lock_shard_read_ref(shard_mutex);
            let live_sum: usize = shard.nodes.iter().map(|node| node.charge).sum();
            assert_eq!(
                shard.map.len(),
                shard.nodes.len(),
                "map/nodes length mismatch"
            );
            assert_eq!(shard.used_charge, live_sum, "used charge mismatch");
            for (key, idx) in &shard.map {
                assert!(
                    *key == shard.nodes[*idx].key,
                    "index map points to wrong key"
                );
            }
            if !shard.map.is_empty() {
                assert!(
                    shard.used_charge <= shard.capacity,
                    "used {} exceeds capacity {}",
                    shard.used_charge,
                    shard.capacity
                );
            }
        }
    }

    // Test-only: reads an entry's visited bit without touching it.
    #[cfg(test)]
    fn entry_visited(&self, key: &K) -> Option<bool> {
        let shard = self.lock_shard_read(self.shard_index(key));
        let idx = shard.map.get(key).copied()?;
        Some(shard.nodes[idx].visited.load(Ordering::Relaxed))
    }

    // Test-only: forces an entry's visited bit. A read lock suffices because
    // the bit is an atomic.
    #[cfg(test)]
    fn set_entry_visited(&self, key: &K, visited: bool) {
        let shard = self.lock_shard_read(self.shard_index(key));
        if let Some(idx) = shard.map.get(key).copied() {
            shard.nodes[idx].visited.store(visited, Ordering::Relaxed);
        }
    }
}
485
#[cfg(test)]
mod tests {
    use super::SieveCache;
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use std::sync::Arc;
    use std::thread;

    // Smoke test of the basic insert/get/remove lifecycle of one key.
    #[test]
    fn basic_insert_get_remove() {
        let cache = SieveCache::<u64, String>::with_shards(1024, 4);
        assert!(cache.insert(1, "a".to_string(), 8).is_none());
        assert_eq!(&*cache.get(&1).unwrap(), "a");
        assert_eq!(&*cache.remove(&1).unwrap(), "a");
        assert!(cache.get(&1).is_none());
    }

    // Replacing a key returns the old value and re-charges with the new weight.
    #[test]
    fn replacement_returns_previous_and_updates_charge() {
        let cache = SieveCache::<u64, String>::with_shards(64, 2);
        assert!(cache.insert(1, "a".to_string(), 10).is_none());
        let prev = cache.insert(1, "b".to_string(), 7).unwrap();
        assert_eq!(&*prev, "a");
        assert_eq!(&*cache.get(&1).unwrap(), "b");
        assert_eq!(cache.total_charge(), 7);
        cache.validate_invariants();
    }

    // Two entries of charge 6 cannot both fit a 10-unit shard: one is evicted.
    #[test]
    fn weighted_eviction_happens_by_charge() {
        let cache = SieveCache::<u64, &'static str>::with_shards(10, 1);
        cache.insert(1, "a", 6);
        cache.insert(2, "b", 6);
        assert_eq!(cache.total_charge(), 6);
        assert!(!cache.contains_key(&1) || !cache.contains_key(&2));
        cache.validate_invariants();
    }

    // Pins SIEVE's "quick demotion": an unvisited fresh insert is evicted
    // before older entries (visited or not), because the sweep starts at the
    // newest entry.
    #[test]
    fn fresh_insert_can_be_evicted_before_older_unvisited_entry() {
        let cache = SieveCache::<u64, &'static str>::with_shards(8, 1);
        cache.insert(1, "a", 4);
        cache.insert(2, "b", 4);
        cache.set_entry_visited(&1, true);
        cache.set_entry_visited(&2, false);
        cache.insert(3, "c", 4);
        assert!(cache.contains_key(&1));
        assert!(cache.contains_key(&2));
        assert!(!cache.contains_key(&3));
        cache.validate_invariants();
    }

    // In-place replacement must not corrupt the key -> index map.
    #[test]
    fn replacement_does_not_break_index_map() {
        let cache = SieveCache::<u64, &'static str>::with_shards(8, 1);
        cache.insert(1, "old", 4);
        cache.insert(1, "new", 4);
        cache.insert(2, "x", 4);
        cache.insert(3, "y", 4);
        if let Some(v) = cache.get(&1) {
            assert_eq!(*v, "new");
        }
        cache.validate_invariants();
    }

    // Removal (which swap-shuffles node indices) followed by insertion must
    // leave the structure consistent.
    #[test]
    fn remove_keeps_structure_consistent() {
        let cache = SieveCache::<u64, &'static str>::with_shards(8, 1);
        cache.insert(1, "a", 4);
        cache.insert(2, "b", 4);
        cache.remove(&1);
        cache.insert(3, "c", 4);
        assert!(cache.contains_key(&2) || cache.contains_key(&3));
        cache.validate_invariants();
    }

    // A key must always hash to the same shard within one cache instance.
    #[test]
    fn shard_selection_stability() {
        let cache = SieveCache::<u64, u64>::with_shards(1024, 8);
        for key in 0..100 {
            cache.insert(key, key, 1);
            assert_eq!(cache.shard_index(&key), cache.shard_index(&key));
        }
        cache.validate_invariants();
    }

    // `get` must set the SIEVE visited bit.
    #[test]
    fn visited_bit_set_on_get() {
        let cache = SieveCache::<u64, &'static str>::with_shards(8, 1);
        cache.insert(1, "a", 4);
        cache.set_entry_visited(&1, false);
        assert_eq!(cache.entry_visited(&1), Some(false));
        let _ = cache.get(&1);
        assert_eq!(cache.entry_visited(&1), Some(true));
        cache.validate_invariants();
    }

    // First insert starts unvisited; replacement counts as a touch.
    #[test]
    fn fresh_insert_starts_unvisited() {
        let cache = SieveCache::<u64, &'static str>::with_shards(8, 1);
        cache.insert(1, "a", 4);
        assert_eq!(cache.entry_visited(&1), Some(false));

        cache.insert(1, "b", 4);
        assert_eq!(cache.entry_visited(&1), Some(true));
        cache.validate_invariants();
    }

    // An entry larger than the whole budget is admitted then evicted again.
    #[test]
    fn oversize_entry_policy_is_insert_then_evict() {
        let cache = SieveCache::<u64, &'static str>::with_shards(8, 1);
        cache.insert(1, "big", 32);
        assert_eq!(cache.total_charge(), 0);
        assert!(!cache.contains_key(&1));
        cache.validate_invariants();
    }

    // Hammers the cache from 8 threads with a seeded random mix of
    // insert/get/remove; invariants must hold afterwards.
    #[test]
    fn concurrency_smoke_test() {
        let cache = Arc::new(SieveCache::<u64, u64>::with_shards(4096, 16));
        let mut threads = Vec::new();
        for tid in 0..8 {
            let cache = cache.clone();
            threads.push(thread::spawn(move || {
                let mut rng = StdRng::seed_from_u64(1234 + tid);
                for _ in 0..10_000 {
                    let key = rng.gen_range(0..256);
                    match rng.gen_range(0..3) {
                        0 => {
                            cache.insert(key, key, rng.gen_range(1..32));
                        }
                        1 => {
                            let _ = cache.get(&key);
                        }
                        _ => {
                            let _ = cache.remove(&key);
                        }
                    }
                }
            }));
        }
        for t in threads {
            t.join().unwrap();
        }
        cache.validate_invariants();
    }

    // Charge bookkeeping must survive replacement and removal.
    #[test]
    fn charge_accounting_after_replacement_and_remove() {
        let cache = SieveCache::<u64, &'static str>::with_shards(100, 2);
        cache.insert(1, "a", 20);
        cache.insert(1, "b", 30);
        assert_eq!(cache.total_charge(), 30);
        cache.remove(&1);
        assert_eq!(cache.total_charge(), 0);
        cache.validate_invariants();
    }

    // End-to-end sequence on a tiny single-shard cache (no evictions:
    // used charge never exceeds the 3-unit budget).
    #[test]
    fn basic_sequence() {
        let cache = SieveCache::<String, String>::with_shards(3, 1);
        assert!(
            cache
                .insert("foo".to_string(), "foocontent".to_string(), 1)
                .is_none()
        );
        assert!(
            cache
                .insert("bar".to_string(), "barcontent".to_string(), 1)
                .is_none()
        );
        assert_eq!(
            cache
                .remove(&"bar".to_string())
                .as_deref()
                .map(String::as_str),
            Some("barcontent")
        );
        assert!(
            cache
                .insert("bar2".to_string(), "bar2content".to_string(), 1)
                .is_none()
        );
        assert!(
            cache
                .insert("bar3".to_string(), "bar3content".to_string(), 1)
                .is_none()
        );
        assert_eq!(
            cache.get(&"foo".to_string()).as_deref().map(String::as_str),
            Some("foocontent")
        );
        assert_eq!(cache.get(&"bar".to_string()), None);
        assert_eq!(
            cache
                .get(&"bar2".to_string())
                .as_deref()
                .map(String::as_str),
            Some("bar2content")
        );
        assert_eq!(
            cache
                .get(&"bar3".to_string())
                .as_deref()
                .map(String::as_str),
            Some("bar3content")
        );
    }

    // A replaced (hence visited) entry survives the eviction forced by a
    // later insert.
    #[test]
    fn visited_flag_update() {
        let cache = SieveCache::<String, String>::with_shards(2, 1);
        cache.insert("key1".to_string(), "value1".to_string(), 1);
        cache.insert("key2".to_string(), "value2".to_string(), 1);
        cache.insert("key1".to_string(), "updated".to_string(), 1);
        cache.insert("key3".to_string(), "value3".to_string(), 1);
        assert_eq!(
            cache
                .get(&"key1".to_string())
                .as_deref()
                .map(String::as_str),
            Some("updated")
        );
        cache.validate_invariants();
    }

    // Even when every resident entry is visited, inserting must not leave
    // the cache over budget (second-chance sweep still finds a victim).
    #[test]
    fn insert_never_exceeds_capacity_when_all_visited() {
        let cache = SieveCache::<String, u64>::with_shards(2, 1);
        cache.insert("a".to_string(), 1, 1);
        cache.insert("b".to_string(), 2, 1);
        assert!(cache.get(&"a".to_string()).is_some());
        assert!(cache.get(&"b".to_string()).is_some());
        cache.insert("c".to_string(), 3, 1);
        assert!(cache.len() <= 2);
        assert!(cache.total_charge() <= cache.total_capacity());
    }

    // 10 units over 4 shards: the first two shards absorb the remainder.
    #[test]
    fn shard_capacity_remainder_distribution_is_deterministic() {
        let cache = SieveCache::<u64, u64>::with_shards(10, 4);
        let caps: Vec<usize> = (0..4).map(|i| cache.shard_usage(i).1).collect();
        assert_eq!(caps, vec![3, 3, 2, 2]);
    }

    // Zero-charge entries are allowed and never trigger eviction.
    #[test]
    fn zero_charge_entries_do_not_increase_budget_usage() {
        let cache = SieveCache::<u64, u64>::with_shards(1, 1);
        cache.insert(1, 10, 0);
        cache.insert(2, 20, 0);
        cache.insert(3, 30, 0);
        assert_eq!(cache.total_charge(), 0);
        assert_eq!(cache.len(), 3);
        assert_eq!(*cache.get(&1).unwrap(), 10);
    }
}
740}