1use crossbeam::utils::CachePadded;
2use std::hash::{Hash, Hasher};
3use std::sync::atomic::AtomicU64;
4use std::sync::atomic::Ordering::Relaxed;
5use twox_hash::xxhash64::Hasher as XxHash64;
6
/// One 64-bit seed per sketch row (at most 8 rows, enforced in `new`); each
/// seed parameterizes an independent xxHash64 instance so the rows behave as
/// independent hash functions. The values appear to be well-known mixing
/// constants (golden-ratio/splitmix64- and SHA-2-style words) — chosen for
/// bit diversity, not for any structural property the sketch relies on.
static SEEDS: [u64; 8] = [
    0x9e3779b97f4a7c15,
    0xbf58476d1ce4e5b9,
    0x94d049bb133111eb,
    0xff51afd7ed558ccd,
    0x6a09e667f3bcc908,
    0xbb67ae8584caa73b,
    0x3c6ef372fe94f82b,
    0xa54ff53a5f1d36f1,
];
28
/// A concurrent count-min sketch with saturating 8-bit counters, eight of
/// which are bit-packed into each `AtomicU64` block. Updates use lock-free
/// CAS loops (`fetch_update`) with relaxed ordering.
pub struct CountMinSketch {
    // Row-major counter storage: `blocks` u64 blocks per row, `depth` rows.
    // `CachePadded` pads the whole allocation (the Box header/pointer), not
    // individual counters, to avoid false sharing with neighboring data.
    data: CachePadded<Box<[AtomicU64]>>,
    // Number of u64 blocks per row; always a power of two (see `new`).
    blocks: usize,
    // Number of hash rows; at most 8, one `SEEDS` entry per row.
    depth: usize,
    // `blocks - 1`: power-of-two modulus mask for block indexing.
    blocks_mask: usize,
}
43
44impl CountMinSketch {
45 #[inline]
60 pub fn new(width: usize, depth: usize) -> Self {
61 assert!(depth <= 8, "depth must not exceed 8");
62 let blocks = (width / 8).next_power_of_two();
63 let blocks_mask = blocks - 1;
64
65 let data = (0..(blocks * depth))
66 .map(|_| AtomicU64::default())
67 .collect::<Vec<_>>()
68 .into_boxed_slice();
69
70 Self {
71 data: CachePadded::new(data),
72 blocks,
73 depth,
74 blocks_mask,
75 }
76 }
77
78 pub fn increment<K: Eq + Hash>(&self, key: &K) {
86 let mut skip = 0;
87
88 for seed in (0..self.depth).map(|index| SEEDS[index]) {
89 let hash = self.hash(key, seed) as usize;
90
91 let block_index = skip + (hash & self.blocks_mask);
92
93 let shift = ((hash >> 32) & 0x7) * 8;
94
95 let _ = self.data[block_index].fetch_update(Relaxed, Relaxed, |block| {
96 let frequency = (block >> shift) & 0xFF;
97 match frequency {
98 255 => None,
99 _ => Some(block + (1 << shift)),
100 }
101 });
102
103 skip += self.blocks;
104 }
105 }
106
107 pub fn decrement<K: Eq + Hash>(&self, key: &K) {
115 let mut skip = 0;
116
117 for seed in (0..self.depth).map(|index| SEEDS[index]) {
118 let hash = self.hash(key, seed) as usize;
119
120 let block_index = skip + (hash & self.blocks_mask);
121
122 let shift = ((hash >> 32) & 0x7) * 8;
123
124 let _ = self.data[block_index].fetch_update(Relaxed, Relaxed, |block| {
125 let frequency = (block >> shift) & 0xFF;
126 match frequency {
127 0 => None,
128 _ => Some(block - (1 << shift)),
129 }
130 });
131
132 skip += self.blocks;
133 }
134 }
135
136 pub fn decay(&self) {
141 let mask: u64 = 0xFEFEFEFEFEFEFEFE;
142
143 for i in 0..self.data.len() {
144 let _ = self.data[i].fetch_update(Relaxed, Relaxed, |block| {
145 if block == 0 {
146 None
147 } else {
148 Some((block & mask) >> 1)
149 }
150 });
151 }
152 }
153
154 pub fn contains<K: Eq + Hash>(&self, key: &K) -> bool {
161 self.get(key) > 0
162 }
163
164 pub fn get<K: Eq + Hash>(&self, key: &K) -> u8 {
173 let mut skip = 0;
174 let mut frequency: u16 = u16::MAX;
175
176 for seed in (0..self.depth).map(|index| SEEDS[index]) {
177 let hash = self.hash(key, seed) as usize;
178
179 let block_index = skip + (hash & self.blocks_mask);
180
181 let shift = ((hash >> 32) & 0x7) * 8;
182
183 let block = self.data[block_index].load(Relaxed);
184
185 let current_frequency = (block >> shift) as u16;
186
187 frequency = frequency.min(current_frequency);
188
189 skip += self.blocks
190 }
191
192 if frequency == u16::MAX {
193 0
194 } else {
195 frequency as u8
196 }
197 }
198
199 fn hash<K: Eq + Hash>(&self, key: K, seed: u64) -> u64 {
208 let mut hasher = XxHash64::with_seed(seed);
209 key.hash(&mut hasher);
210 hasher.finish()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distr::{Alphanumeric, SampleString};
    use std::thread::scope;

    /// Produces a random alphanumeric key of length `len`.
    fn random_key(len: usize) -> String {
        Alphanumeric.sample_string(&mut rand::rng(), len)
    }

    #[test]
    fn test_count_min_sketch_should_increment_and_retrieve_frequency() {
        let sketch = CountMinSketch::new(128, 4);
        let key = random_key(10);

        (0..3).for_each(|_| sketch.increment(&key));

        assert_eq!(
            sketch.get(&key),
            3,
            "Frequency should reflect the exact number of increments."
        );
    }

    #[test]
    fn test_count_min_sketch_should_saturate_at_max_u8_without_overflow() {
        let sketch = CountMinSketch::new(64, 2);
        let key = random_key(10);

        // Far more increments than a u8 can hold.
        (0..300).for_each(|_| sketch.increment(&key));

        assert_eq!(
            sketch.get(&key),
            255,
            "Counters must cap at 255 to protect adjacent bit-packed slots."
        );
    }

    #[test]
    fn test_count_min_sketch_should_halve_all_counters_on_decay() {
        let sketch = CountMinSketch::new(1024, 4);
        let key = random_key(10);

        (0..20).for_each(|_| sketch.increment(&key));
        assert_eq!(sketch.get(&key), 20);

        // Each decay halves, rounding down: 20 -> 10 -> 5 -> 2.
        for expected in [10, 5, 2] {
            sketch.decay();
            assert_eq!(sketch.get(&key), expected);
        }
    }

    #[test]
    fn test_count_min_sketch_should_saturate_at_zero_on_decrement() {
        let sketch = CountMinSketch::new(128, 4);
        let key = random_key(10);

        sketch.increment(&key);
        sketch.decrement(&key);
        assert_eq!(sketch.get(&key), 0);

        // A second decrement against an already-zero counter must be a no-op.
        sketch.decrement(&key);
        assert_eq!(sketch.get(&key), 0, "Counter must not underflow below zero.");
    }

    #[test]
    fn test_count_min_sketch_should_maintain_consistent_state_under_contention() {
        let sketch = CountMinSketch::new(16, 4);
        let threads = 10;
        let increments_per_thread = 20;
        let key = random_key(10);

        // Hammer one key from many threads; scoped threads borrow `sketch`.
        scope(|s| {
            for _ in 0..threads {
                s.spawn(|| {
                    (0..increments_per_thread).for_each(|_| sketch.increment(&key));
                });
            }
        });

        assert_eq!(
            sketch.get(&key),
            200,
            "Lock-free increments must be atomic across bit-packed blocks."
        );
    }

    #[test]
    fn test_count_min_sketch_should_return_zero_for_unknown_keys() {
        let sketch = CountMinSketch::new(2048, 4);
        let key = random_key(10);

        assert_eq!(sketch.get(&key), 0);
        assert!(!sketch.contains(&key));
    }

    #[test]
    fn test_count_min_sketch_should_tolerate_collisions_within_probabilistic_bounds() {
        let sketch = CountMinSketch::new(2048, 4);
        let heavy_key = random_key(10);
        let light_key = random_key(20);

        (0..50).for_each(|_| sketch.increment(&heavy_key));
        (0..5).for_each(|_| sketch.increment(&light_key));

        // Count-min never under-estimates; collisions can only inflate.
        assert!(sketch.get(&heavy_key) >= 50);
        assert!(sketch.get(&light_key) >= 5);
    }
}