1use crate::core::thread_context::ThreadContext;
2use std::hash::{Hash, Hasher};
3use std::sync::atomic::AtomicU16;
4use std::sync::atomic::Ordering::{Acquire, Relaxed};
5use twox_hash::xxhash64::Hasher as XxHash64;
6
/// Fixed hash seeds, one per sketch row.
///
/// Row `i` of a sketch hashes keys with `SEEDS[i]`, so the length of this
/// table is the maximum number of rows a sketch can have.
static SEEDS: [u64; 8] = [
    0x9e37_79b9_7f4a_7c15,
    0xbf58_476d_1ce4_e5b9,
    0x94d0_49bb_1331_11eb,
    0xff51_afd7_ed55_8ccd,
    0x6a09_e667_f3bc_c908,
    0xbb67_ae85_84ca_a73b,
    0x3c6e_f372_fe94_f82b,
    0xa54f_f53a_5f1d_36f1,
];
22
/// A count-min sketch backed by atomic 16-bit counters.
///
/// The counters form a `rows x columns` table stored row-major in a single
/// flat slice. All operations take `&self`; mutation happens through the
/// atomics.
#[derive(Debug)]
pub struct CountMinSketch {
    /// Flat row-major counter table of length `columns * rows`.
    counters: Box<[AtomicU16]>,
    /// Counters per row; always a power of two so a hash can be reduced to a
    /// column with a bit mask.
    columns: usize,
    /// Number of rows, i.e. independent hash functions applied to each key.
    rows: usize,
}
33
34impl CountMinSketch {
35 #[inline]
45 pub fn new(columns: usize, rows: usize) -> Self {
46 assert!(rows <= 8, "Depth exceeds available static seeds (8)");
47 let columns = (columns * 2).next_power_of_two();
48
49 let counts = (0..columns * rows)
50 .map(|_| AtomicU16::default())
51 .collect::<Vec<_>>()
52 .into_boxed_slice();
53
54 Self {
55 counters: counts,
56 columns,
57 rows,
58 }
59 }
60
61 pub fn increment<K>(&self, key: &K, context: &ThreadContext)
67 where
68 K: Eq + Hash + ?Sized,
69 {
70 let mut skip = 0;
71
72 for seed in self.seeds() {
73 let hash = self.hash(key, seed) as usize;
74
75 let column = hash & (self.columns - 1);
76 let index = skip + column;
77
78 let mut counter = self.counters[index].load(Acquire);
79
80 while counter < u16::MAX {
81 match self.counters[index].compare_exchange_weak(
82 counter,
83 counter + 1,
84 Relaxed,
85 Relaxed,
86 ) {
87 Ok(_) => {
88 context.decay();
89 break;
90 }
91 Err(latest) => {
92 counter = latest;
93 context.wait();
94 }
95 }
96 }
97
98 skip += self.columns;
99 }
100 }
101
102 #[inline(always)]
104 fn seeds(&self) -> Vec<u64> {
105 (0..self.rows).map(|index| SEEDS[index]).collect()
106 }
107
108 pub fn decrement<K>(&self, key: &K, context: &ThreadContext)
114 where
115 K: Eq + Hash + ?Sized,
116 {
117 let mut skip = 0;
118
119 for seed in self.seeds() {
120 let hash = self.hash(key, seed) as usize;
121
122 let column = hash & (self.columns - 1);
123 let index = skip + column;
124
125 let mut counter = self.counters[index].load(Acquire);
126
127 while counter > 0 {
128 match self.counters[index].compare_exchange_weak(
129 counter,
130 counter - 1,
131 Relaxed,
132 Relaxed,
133 ) {
134 Ok(_) => {
135 context.decay();
136 break;
137 }
138 Err(latest) => {
139 counter = latest;
140 context.wait();
141 }
142 }
143 }
144
145 skip += self.columns;
146 }
147 }
148
149 pub fn decay(&self, context: &ThreadContext) {
155 for counter in &self.counters {
156 let mut counter_value = counter.load(Relaxed);
157
158 if counter_value > 0 {
159 match counter.compare_exchange_weak(
160 counter_value,
161 counter_value >> 1,
162 Relaxed,
163 Relaxed,
164 ) {
165 Ok(_) => {
166 context.decay();
167 }
168 Err(latest) => {
169 counter_value = latest;
170 context.wait();
171 }
172 }
173 }
174 }
175 }
176
177 pub fn contains<K: Eq + Hash>(&self, key: &K) -> bool {
179 self.get(key) > 0
180 }
181
182 pub fn get<K>(&self, key: &K) -> u16
187 where
188 K: Eq + Hash + ?Sized,
189 {
190 let mut skip = 0;
191 let mut frequency = u32::MAX;
192
193 for seed in (0..self.rows).map(|index| SEEDS[index]) {
194 let hash = self.hash(key, seed) as usize;
195 let index = skip + (hash & (self.columns - 1));
196
197 let counter = self.counters[index].load(Relaxed) as u32;
198 frequency = frequency.min(counter);
199
200 skip += self.columns
201 }
202
203 if frequency == u32::MAX {
204 0
205 } else {
206 frequency as u16
207 }
208 }
209
210 #[inline(always)]
212 fn hash<K: Eq + Hash + ?Sized>(&self, key: &K, seed: u64) -> u64 {
213 let mut hasher = XxHash64::with_seed(seed);
214 key.hash(&mut hasher);
215 hasher.finish()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distr::{Alphanumeric, SampleString};
    use std::thread::scope;

    /// Builds a random alphanumeric key of `len` characters.
    fn random_key(len: usize) -> String {
        let mut rng = rand::rng();
        Alphanumeric.sample_string(&mut rng, len)
    }

    /// Three increments of one key must read back as exactly three.
    #[test]
    fn test_count_min_sketch_should_increment_and_retrieve_frequency() {
        let sketch = CountMinSketch::new(128, 4);
        let key = random_key(10);
        let context = ThreadContext::default();

        (0..3).for_each(|_| sketch.increment(&key, &context));

        let observed = sketch.get(&key);
        assert_eq!(
            observed, 3,
            "Frequency should reflect the exact number of increments."
        );
    }

    /// Incrementing past `u16::MAX` must pin the counter at the maximum.
    #[test]
    fn test_count_min_sketch_should_saturate_at_max_logical_value() {
        let sketch = CountMinSketch::new(64, 2);
        let key = random_key(10);
        let context = ThreadContext::default();

        (0..100000).for_each(|_| sketch.increment(&key, &context));

        let observed = sketch.get(&key);
        assert_eq!(
            observed,
            u16::MAX,
            "Counters must cap at MAX_FREQUENCY to prevent wrap-around."
        );
    }

    /// Each decay pass halves the stored frequency, rounding down.
    #[test]
    fn test_count_min_sketch_should_halve_all_counters_on_decay() {
        let sketch = CountMinSketch::new(1024, 4);
        let key = random_key(10);
        let context = ThreadContext::default();

        (0..20).for_each(|_| sketch.increment(&key, &context));
        assert_eq!(sketch.get(&key), 20);

        // 20 -> 10 -> 5 -> 2 (integer halving).
        for expected in [10, 5, 2] {
            sketch.decay(&context);
            assert_eq!(sketch.get(&key), expected);
        }
    }

    /// Decrementing a zero counter must leave it at zero.
    #[test]
    fn test_count_min_sketch_should_saturate_at_zero_on_decrement() {
        let sketch = CountMinSketch::new(128, 4);
        let key = random_key(10);
        let context = ThreadContext::default();

        sketch.increment(&key, &context);
        sketch.decrement(&key, &context);
        let after_matching_decrement = sketch.get(&key);
        assert_eq!(after_matching_decrement, 0);

        sketch.decrement(&key, &context);
        let after_extra_decrement = sketch.get(&key);
        assert_eq!(
            after_extra_decrement, 0,
            "Counter must not underflow below zero."
        );
    }

    /// Concurrent increments of one key from several threads must all land.
    #[test]
    fn test_count_min_sketch_should_maintain_consistent_state_under_contention() {
        const WORKERS: u16 = 8;
        const OPS_PER_WORKER: u16 = 100;

        let sketch = CountMinSketch::new(16, 4);
        let key = random_key(10);

        scope(|threads| {
            (0..WORKERS).for_each(|_| {
                threads.spawn(|| {
                    let context = ThreadContext::default();
                    (0..OPS_PER_WORKER).for_each(|_| sketch.increment(&key, &context));
                });
            });
        });

        let observed = sketch.get(&key);
        assert_eq!(
            observed,
            WORKERS * OPS_PER_WORKER,
            "Atomic increments must be consistent across multiple threads."
        );
    }

    /// A key that was never inserted reads as zero and is not contained.
    #[test]
    fn test_count_min_sketch_should_return_zero_for_unknown_keys() {
        let sketch = CountMinSketch::new(2048, 4);
        let key = random_key(10);

        let observed = sketch.get(&key);
        assert_eq!(observed, 0);
        assert!(!sketch.contains(&key));
    }

    /// Estimates may over-count on collisions but must never under-count.
    #[test]
    fn test_count_min_sketch_should_tolerate_collisions_within_probabilistic_bounds() {
        let sketch = CountMinSketch::new(2048, 4);
        let key_a = random_key(10);
        let key_b = random_key(20);
        let context = ThreadContext::default();

        (0..50).for_each(|_| sketch.increment(&key_a, &context));
        (0..5).for_each(|_| sketch.increment(&key_b, &context));

        let estimate_a = sketch.get(&key_a);
        let estimate_b = sketch.get(&key_b);
        assert!(estimate_a >= 50);
        assert!(estimate_b >= 5);
    }
}