//! omega_cache/core/cms.rs
//!
//! Count-min sketch frequency estimator for the omega_cache core.
1use crate::core::thread_context::ThreadContext;
2use std::hash::{Hash, Hasher};
3use std::sync::atomic::AtomicU16;
4use std::sync::atomic::Ordering::{Acquire, Relaxed};
5use twox_hash::xxhash64::Hasher as XxHash64;
6
/// A collection of high-entropy 64-bit seeds for row-level hashing independence.
///
/// These seeds initialize the `XxHash64` state for each row. Using distinct static
/// constants ensures:
/// 1. **Independence**: Minimizes the probability of "secondary collisions" across rows.
/// 2. **Bit Distribution**: Spreads hash output evenly across the column bitmask.
///
/// The array length (8) is the hard upper bound on sketch depth: `CountMinSketch::new`
/// asserts `rows <= 8` against it.
// NOTE(review): despite the original "prime seeds" wording, not all of these values
// are prime — they are well-known mixing constants (golden-ratio/SplitMix64/SHA-style);
// independence comes from seeding XxHash64 differently per row, not from primality.
static SEEDS: [u64; 8] = [
    0x9e3779b97f4a7c15,
    0xbf58476d1ce4e5b9,
    0x94d049bb133111eb,
    0xff51afd7ed558ccd,
    0x6a09e667f3bcc908,
    0xbb67ae8584caa73b,
    0x3c6ef372fe94f82b,
    0xa54ff53a5f1d36f1,
];
22
/// A high-concurrency, memory-efficient frequency estimator.
///
/// Uses `AtomicU16` counters in a 2D matrix to provide a probabilistic upper bound
/// on item frequency. Designed for high-throughput environments where a small
/// overestimation error is acceptable in exchange for wait-free/lock-free performance.
pub struct CountMinSketch {
    // Flattened row-major `rows x columns` matrix of saturating counters.
    counters: Box<[AtomicU16]>,
    // Physical width of each row; always a power of two (enforced in `new`),
    // so column selection can use `hash & (columns - 1)` instead of modulo.
    columns: usize,
    // Number of hash rows; bounded by `SEEDS.len()` (8).
    rows: usize,
}
33
34impl CountMinSketch {
35    /// Creates a new `CountMinSketch` with the specified logical dimensions.
36    ///
37    /// # Arguments
38    /// * `columns` - The logical width per row. The actual storage is doubled and
39    ///   rounded to the next power of two to optimize indexing via bitwise masking.
40    /// * `rows` - The number of independent hash functions.
41    ///
42    /// # Panics
43    /// Panics if `rows` > 8, as it exceeds the available static prime seeds.
44    #[inline]
45    pub fn new(columns: usize, rows: usize) -> Self {
46        assert!(rows <= 8, "Depth exceeds available static seeds (8)");
47        let columns = (columns * 2).next_power_of_two();
48
49        let counts = (0..columns * rows)
50            .map(|_| AtomicU16::default())
51            .collect::<Vec<_>>()
52            .into_boxed_slice();
53
54        Self {
55            counters: counts,
56            columns,
57            rows,
58        }
59    }
60
    /// Increments the frequency estimate for a key across all rows.
    ///
    /// Each row's counter saturates at `u16::MAX`: once the counter is observed
    /// at the ceiling the CAS loop is skipped, so it can never wrap around to 0.
    /// On CAS failure (contention, or a spurious `compare_exchange_weak` failure)
    /// `context.wait()` is invoked before retrying; on success `context.decay()`
    /// is invoked.
    // NOTE(review): `wait`/`decay` semantics live in `ThreadContext` — presumably
    // wait() backs off and decay() relaxes the backoff state; confirm there.
    // NOTE(review): the Acquire load paired with Relaxed CAS orderings is stronger
    // than strictly needed for an independent counter — presumably intentional.
    pub fn increment<K>(&self, key: &K, context: &ThreadContext)
    where
        K: Eq + Hash + ?Sized,
    {
        // Element offset of the current row within the flattened counter matrix.
        let mut skip = 0;

        for seed in self.seeds() {
            let hash = self.hash(key, seed) as usize;

            // `columns` is a power of two, so masking is equivalent to modulo.
            let column = hash & (self.columns - 1);
            let index = skip + column;

            let mut counter = self.counters[index].load(Acquire);

            // Saturating CAS loop: exits without writing once u16::MAX is seen.
            while counter < u16::MAX {
                match self.counters[index].compare_exchange_weak(
                    counter,
                    counter + 1,
                    Relaxed,
                    Relaxed,
                ) {
                    Ok(_) => {
                        context.decay();
                        break;
                    }
                    Err(latest) => {
                        counter = latest;
                        context.wait();
                    }
                }
            }

            skip += self.columns;
        }
    }
101
102    /// Internal helper to map row indices to their respective static seeds.
103    #[inline(always)]
104    fn seeds(&self) -> Vec<u64> {
105        (0..self.rows).map(|index| SEEDS[index]).collect()
106    }
107
    /// Decrements the frequency estimate for a key, saturating at zero.
    ///
    /// Useful for manual aging or correction. The CAS loop prevents
    /// integer underflow, which would otherwise erroneously transform
    /// a "cold" item into an "ultra-hot" item (65,535).
    pub fn decrement<K>(&self, key: &K, context: &ThreadContext)
    where
        K: Eq + Hash + ?Sized,
    {
        // Element offset of the current row within the flattened counter matrix.
        let mut skip = 0;

        for seed in self.seeds() {
            let hash = self.hash(key, seed) as usize;

            // `columns` is a power of two, so masking is equivalent to modulo.
            let column = hash & (self.columns - 1);
            let index = skip + column;

            let mut counter = self.counters[index].load(Acquire);

            // Saturating CAS loop: counters observed at zero are left untouched.
            while counter > 0 {
                match self.counters[index].compare_exchange_weak(
                    counter,
                    counter - 1,
                    Relaxed,
                    Relaxed,
                ) {
                    Ok(_) => {
                        context.decay();
                        break;
                    }
                    Err(latest) => {
                        counter = latest;
                        context.wait();
                    }
                }
            }

            skip += self.columns;
        }
    }
148
149    /// Performs a global aging operation by halving every counter in the sketch.
150    ///
151    /// This reduces the "weight" of historical data, allowing the sketch
152    /// to adapt to changes in key distribution over time. Uses a CAS loop
153    /// per counter to maintain atomicity during the bit-shift.
154    pub fn decay(&self, context: &ThreadContext) {
155        for counter in &self.counters {
156            let mut counter_value = counter.load(Relaxed);
157
158            if counter_value > 0 {
159                match counter.compare_exchange_weak(
160                    counter_value,
161                    counter_value >> 1,
162                    Relaxed,
163                    Relaxed,
164                ) {
165                    Ok(_) => {
166                        context.decay();
167                    }
168                    Err(latest) => {
169                        counter_value = latest;
170                        context.wait();
171                    }
172                }
173            }
174        }
175    }
176
177    /// Checks if an item is likely present in the sketch (estimate > 0).
178    pub fn contains<K: Eq + Hash>(&self, key: &K) -> bool {
179        self.get(key) > 0
180    }
181
182    /// Retrieves the estimated frequency of a key.
183    ///
184    /// The estimate is the minimum value found across all rows for the given key.
185    /// This is a mathematically guaranteed upper bound of the true frequency.
186    pub fn get<K>(&self, key: &K) -> u16
187    where
188        K: Eq + Hash + ?Sized,
189    {
190        let mut skip = 0;
191        let mut frequency = u32::MAX;
192
193        for seed in (0..self.rows).map(|index| SEEDS[index]) {
194            let hash = self.hash(key, seed) as usize;
195            let index = skip + (hash & (self.columns - 1));
196
197            let counter = self.counters[index].load(Relaxed) as u32;
198            frequency = frequency.min(counter);
199
200            skip += self.columns
201        }
202
203        if frequency == u32::MAX {
204            0
205        } else {
206            frequency as u16
207        }
208    }
209
210    /// Hashes a key with a specific seed using the XxHash64 algorithm.
211    #[inline(always)]
212    fn hash<K: Eq + Hash + ?Sized>(&self, key: &K, seed: u64) -> u64 {
213        let mut hasher = XxHash64::with_seed(seed);
214        key.hash(&mut hasher);
215        hasher.finish()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distr::{Alphanumeric, SampleString};
    use std::thread::scope;

    /// Builds a random alphanumeric key of the given length.
    fn random_key(len: usize) -> String {
        Alphanumeric.sample_string(&mut rand::rng(), len)
    }

    #[test]
    fn test_count_min_sketch_should_increment_and_retrieve_frequency() {
        let sketch = CountMinSketch::new(128, 4);
        let item = random_key(10);
        let ctx = ThreadContext::default();

        for _ in 0..3 {
            sketch.increment(&item, &ctx);
        }

        assert_eq!(
            sketch.get(&item),
            3,
            "Frequency should reflect the exact number of increments."
        );
    }

    #[test]
    fn test_count_min_sketch_should_saturate_at_max_logical_value() {
        let sketch = CountMinSketch::new(64, 2);
        let item = random_key(10);
        let ctx = ThreadContext::default();

        (0..100000).for_each(|_| sketch.increment(&item, &ctx));

        assert_eq!(
            sketch.get(&item),
            u16::MAX,
            "Counters must cap at MAX_FREQUENCY to prevent wrap-around."
        );
    }

    #[test]
    fn test_count_min_sketch_should_halve_all_counters_on_decay() {
        let sketch = CountMinSketch::new(1024, 4);
        let item = random_key(10);
        let ctx = ThreadContext::default();

        (0..20).for_each(|_| sketch.increment(&item, &ctx));
        assert_eq!(sketch.get(&item), 20);

        // Successive right shifts: 20 -> 10 -> 5 -> 2 (5 >> 1 = 2).
        for expected in [10, 5, 2] {
            sketch.decay(&ctx);
            assert_eq!(sketch.get(&item), expected);
        }
    }

    #[test]
    fn test_count_min_sketch_should_saturate_at_zero_on_decrement() {
        let sketch = CountMinSketch::new(128, 4);
        let item = random_key(10);
        let ctx = ThreadContext::default();

        sketch.increment(&item, &ctx);
        sketch.decrement(&item, &ctx);
        assert_eq!(sketch.get(&item), 0);

        sketch.decrement(&item, &ctx);
        assert_eq!(sketch.get(&item), 0, "Counter must not underflow below zero.");
    }

    #[test]
    fn test_count_min_sketch_should_maintain_consistent_state_under_contention() {
        // Deliberately narrow sketch to force intra-row collisions.
        let sketch = CountMinSketch::new(16, 4);
        let threads = 8;
        let ops = 100;
        let item = random_key(10);

        scope(|s| {
            for _ in 0..threads {
                s.spawn(|| {
                    let ctx = ThreadContext::default();
                    for _ in 0..ops {
                        sketch.increment(&item, &ctx);
                    }
                });
            }
        });

        assert_eq!(
            sketch.get(&item),
            (threads * ops) as u16,
            "Atomic increments must be consistent across multiple threads."
        );
    }

    #[test]
    fn test_count_min_sketch_should_return_zero_for_unknown_keys() {
        let sketch = CountMinSketch::new(2048, 4);
        let item = random_key(10);

        assert_eq!(sketch.get(&item), 0);
        assert!(!sketch.contains(&item));
    }

    #[test]
    fn test_count_min_sketch_should_tolerate_collisions_within_probabilistic_bounds() {
        let sketch = CountMinSketch::new(2048, 4);
        let hot = random_key(10);
        let cold = random_key(20);
        let ctx = ThreadContext::default();

        (0..50).for_each(|_| sketch.increment(&hot, &ctx));
        (0..5).for_each(|_| sketch.increment(&cold, &ctx));

        // Estimates are upper bounds: they may exceed but never undercut
        // the true insertion counts.
        assert!(sketch.get(&hot) >= 50);
        assert!(sketch.get(&cold) >= 5);
    }
}