//! omega_cache/core/cms.rs — lock-free Count-Min Sketch for frequency estimation.
1use crossbeam::utils::CachePadded;
2use std::hash::{Hash, Hasher};
3use std::sync::atomic::AtomicU64;
4use std::sync::atomic::Ordering::Relaxed;
5use twox_hash::xxhash64::Hasher as XxHash64;
6
/// A collection of high-entropy 64-bit prime seeds.
///
/// These seeds are used to initialize the `XxHash64` state for each row
/// of the `CountMinSketch`. Using static primes ensures:
///
/// 1. **Deterministic Hashing**: Identical keys map to identical physical
///    blocks across different instances of the sketch.
/// 2. **Independence**: Minimizes the probability of "secondary collisions"
///    where keys collide across multiple rows simultaneously.
/// 3. **Bit Distribution**: Ensures the hash output is spread evenly across
///    the `blocks_mask` and the 3-bit internal counter `shift`.
///
/// NOTE(review): several entries are well-known mixer constants (the
/// splitmix64 golden-ratio constant, SHA-512 IV words) rather than primes,
/// despite the doc wording — confirm intent before renaming.
static SEEDS: [u64; 8] = [
    0x9e3779b97f4a7c15,
    0xbf58476d1ce4e5b9,
    0x94d049bb133111eb,
    0xff51afd7ed558ccd,
    0x6a09e667f3bcc908,
    0xbb67ae8584caa73b,
    0x3c6ef372fe94f82b,
    0xa54ff53a5f1d36f1,
];
28
/// A lock-free probabilistic data structure for frequency estimation.
///
/// This implementation optimizes for CPU cache locality and multithreaded throughput
/// by bit-packing 8-bit saturating counters into 64-bit atomic blocks.
pub struct CountMinSketch {
    /// Bit-packed counter storage. Each `AtomicU64` contains 8 x 8-bit counters.
    /// Rows are laid out contiguously (row-major): block `i` of row `r` lives
    /// at index `r * blocks + i`.
    data: CachePadded<Box<[AtomicU64]>>,
    /// Number of `AtomicU64` blocks per row (always a power of two).
    blocks: usize,
    /// Number of independent rows (hash functions) in the sketch.
    depth: usize,
    /// Bitmask for power-of-two indexing within a row (`blocks - 1`).
    blocks_mask: usize,
}
43
44impl CountMinSketch {
45    /// Creates a new probabilistic estimator with a specified geometry.
46    ///
47    /// The physical storage is calculated by mapping the requested logical width
48    /// to 64-bit atomic blocks (8 counters per block).
49    ///
50    /// # Arguments
51    /// * `width` - The logical number of counters per row. This is rounded to the next
52    ///   power of two to enable bitwise indexing.
53    /// * `depth` - The number of independent hash functions (rows) used to minimize
54    ///   the probability of overestimation.
55    ///
56    /// # Panics
57    /// Panics if `depth` exceeds 8, as the implementation relies on a fixed set
58    /// of static prime seeds for hashing independence.
59    #[inline]
60    pub fn new(width: usize, depth: usize) -> Self {
61        assert!(depth <= 8, "depth must not exceed 8");
62        let blocks = (width / 8).next_power_of_two();
63        let blocks_mask = blocks - 1;
64
65        let data = (0..(blocks * depth))
66            .map(|_| AtomicU64::default())
67            .collect::<Vec<_>>()
68            .into_boxed_slice();
69
70        Self {
71            data: CachePadded::new(data),
72            blocks,
73            depth,
74            blocks_mask,
75        }
76    }
77
78    /// Increments the frequency counters for the given key.
79    ///
80    /// This operation is performed across all rows (defined by `depth`) using a
81    /// saturating add. Counters will not exceed 255.
82    ///
83    /// # Arguments
84    /// * `key` - The item whose frequency should be increased.
85    pub fn increment<K: Eq + Hash>(&self, key: &K) {
86        let mut skip = 0;
87
88        for seed in (0..self.depth).map(|index| SEEDS[index]) {
89            let hash = self.hash(key, seed) as usize;
90
91            let block_index = skip + (hash & self.blocks_mask);
92
93            let shift = ((hash >> 32) & 0x7) * 8;
94
95            let _ = self.data[block_index].fetch_update(Relaxed, Relaxed, |block| {
96                let frequency = (block >> shift) & 0xFF;
97                match frequency {
98                    255 => None,
99                    _ => Some(block + (1 << shift)),
100                }
101            });
102
103            skip += self.blocks;
104        }
105    }
106
107    /// Decrements the frequency counters for the given key.
108    ///
109    /// This is used for manual aging or item removal within the sketch.
110    /// Counters will not underflow below 0.
111    ///
112    /// # Arguments
113    /// * `key` - The item whose frequency should be decreased.
114    pub fn decrement<K: Eq + Hash>(&self, key: &K) {
115        let mut skip = 0;
116
117        for seed in (0..self.depth).map(|index| SEEDS[index]) {
118            let hash = self.hash(key, seed) as usize;
119
120            let block_index = skip + (hash & self.blocks_mask);
121
122            let shift = ((hash >> 32) & 0x7) * 8;
123
124            let _ = self.data[block_index].fetch_update(Relaxed, Relaxed, |block| {
125                let frequency = (block >> shift) & 0xFF;
126                match frequency {
127                    0 => None,
128                    _ => Some(block - (1 << shift)),
129                }
130            });
131
132            skip += self.blocks;
133        }
134    }
135
136    /// Performs an aging operation by halving all counters in the sketch.
137    ///
138    /// This uses a bitwise trick to divide eight 8-bit counters by 2
139    /// simultaneously within each 64-bit word.
140    pub fn decay(&self) {
141        let mask: u64 = 0xFEFEFEFEFEFEFEFE;
142
143        for i in 0..self.data.len() {
144            let _ = self.data[i].fetch_update(Relaxed, Relaxed, |block| {
145                if block == 0 {
146                    None
147                } else {
148                    Some((block & mask) >> 1)
149                }
150            });
151        }
152    }
153
154    /// Returns `true` if the key has an estimated frequency greater than zero.
155    ///
156    /// Derived by taking the minimum frequency observed across all rows.
157    ///
158    /// # Arguments
159    /// * `key` - The item to check for presence in the sketch.
160    pub fn contains<K: Eq + Hash>(&self, key: &K) -> bool {
161        self.get(key) > 0
162    }
163
164    /// Returns the estimated frequency of the provided key.
165    ///
166    /// The estimate is derived by taking the minimum value found across all rows
167    /// (defined by `depth`). Due to the probabilistic nature of the sketch, this
168    /// value is an upper bound of the actual frequency.
169    ///
170    /// # Arguments
171    /// * `key` - The item whose frequency estimate is being requested.
172    pub fn get<K: Eq + Hash>(&self, key: &K) -> u8 {
173        let mut skip = 0;
174        let mut frequency: u16 = u16::MAX;
175
176        for seed in (0..self.depth).map(|index| SEEDS[index]) {
177            let hash = self.hash(key, seed) as usize;
178
179            let block_index = skip + (hash & self.blocks_mask);
180
181            let shift = ((hash >> 32) & 0x7) * 8;
182
183            let block = self.data[block_index].load(Relaxed);
184
185            let current_frequency = (block >> shift) as u16;
186
187            frequency = frequency.min(current_frequency);
188
189            skip += self.blocks
190        }
191
192        if frequency == u16::MAX {
193            0
194        } else {
195            frequency as u8
196        }
197    }
198
199    /// Generates a 64-bit fingerprint for a specific row index.
200    ///
201    /// This uses a seeded hashing strategy to ensure that each row in the
202    /// sketch provides an independent observation of the key's frequency.
203    ///
204    /// # Arguments
205    /// * `key` - The item to be hashed.
206    /// * `seed` - The row-specific random seed to initialize the hasher.
207    fn hash<K: Eq + Hash>(&self, key: K, seed: u64) -> u64 {
208        let mut hasher = XxHash64::with_seed(seed);
209        key.hash(&mut hasher);
210        hasher.finish()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distr::{Alphanumeric, SampleString};
    use std::thread::scope;

    /// Generates a random alphanumeric key of the given length.
    fn random_key(len: usize) -> String {
        Alphanumeric.sample_string(&mut rand::rng(), len)
    }

    #[test]
    fn test_count_min_sketch_should_increment_and_retrieve_frequency() {
        // Fresh sketch with a single key: no other counters are populated,
        // so the estimate equals the exact count.
        let cms = CountMinSketch::new(128, 4);
        let key = random_key(10);

        cms.increment(&key);
        cms.increment(&key);
        cms.increment(&key);

        assert_eq!(
            cms.get(&key),
            3,
            "Frequency should reflect the exact number of increments."
        );
    }

    #[test]
    fn test_count_min_sketch_should_saturate_at_max_u8_without_overflow() {
        let cms = CountMinSketch::new(64, 2);
        let key = random_key(10);

        // Increment past 255 to test bit-slot protection
        for _ in 0..300 {
            cms.increment(&key);
        }

        assert_eq!(
            cms.get(&key),
            255,
            "Counters must cap at 255 to protect adjacent bit-packed slots."
        );
    }

    #[test]
    fn test_count_min_sketch_should_halve_all_counters_on_decay() {
        let cms = CountMinSketch::new(1024, 4);
        let key = random_key(10);

        for _ in 0..20 {
            cms.increment(&key);
        }
        assert_eq!(cms.get(&key), 20);

        // Each decay is a truncating halve: 20 -> 10 -> 5 -> 2.
        cms.decay();
        assert_eq!(cms.get(&key), 10);

        cms.decay();
        assert_eq!(cms.get(&key), 5);

        cms.decay();
        assert_eq!(cms.get(&key), 2);
    }

    #[test]
    fn test_count_min_sketch_should_saturate_at_zero_on_decrement() {
        let cms = CountMinSketch::new(128, 4);
        let key = random_key(10);

        cms.increment(&key);
        cms.decrement(&key);
        assert_eq!(cms.get(&key), 0);

        // Ensure no underflow (wrapping to 255)
        cms.decrement(&key);
        assert_eq!(cms.get(&key), 0, "Counter must not underflow below zero.");
    }

    #[test]
    fn test_count_min_sketch_should_maintain_consistent_state_under_contention() {
        let cms = CountMinSketch::new(16, 4); // Small width to force block collisions
        let num_threads = 10;
        let ops_per_thread = 20;
        let key = random_key(10);

        // `std::thread::scope` lets threads borrow `cms` and `key` directly;
        // the scope joins all threads before the asserts run.
        scope(|s| {
            for _ in 0..num_threads {
                s.spawn(|| {
                    for _ in 0..ops_per_thread {
                        cms.increment(&key);
                    }
                });
            }
        });

        assert_eq!(
            cms.get(&key),
            200,
            "Lock-free increments must be atomic across bit-packed blocks."
        );
    }

    #[test]
    fn test_count_min_sketch_should_return_zero_for_unknown_keys() {
        let cms = CountMinSketch::new(2048, 4);
        let key = random_key(10);

        assert_eq!(cms.get(&key), 0);
        assert!(!cms.contains(&key));
    }

    #[test]
    fn test_count_min_sketch_should_tolerate_collisions_within_probabilistic_bounds() {
        let cms = CountMinSketch::new(2048, 4);
        let key_a = random_key(10);
        let key_b = random_key(20);

        for _ in 0..50 {
            cms.increment(&key_a);
        }
        for _ in 0..5 {
            cms.increment(&key_b);
        }

        // CMS estimates are always upper bounds
        assert!(cms.get(&key_a) >= 50);
        assert!(cms.get(&key_b) >= 5);
    }
}