1use serde::{Deserialize, Serialize};
24use std::{
25 borrow::Borrow, cmp::max, convert::TryFrom, fmt, hash::{Hash, Hasher}, marker::PhantomData, ops
26};
27use twox_hash::XxHash;
28
29use super::f64_to_usize;
30use crate::traits::{Intersect, IntersectPlusUnionIsPlus, New, UnionAssign};
31
32#[derive(Serialize, Deserialize)]
38#[serde(bound(
39 serialize = "C: Serialize, <C as New>::Config: Serialize",
40 deserialize = "C: Deserialize<'de>, <C as New>::Config: Deserialize<'de>"
41))]
42pub struct CountMinSketch<K: ?Sized, C: New> {
43 counters: Vec<Vec<C>>,
44 offsets: Vec<usize>, mask: usize,
46 k_num: usize,
47 config: <C as New>::Config,
48 marker: PhantomData<fn(K)>,
49}
50
51impl<K: ?Sized, C> CountMinSketch<K, C>
52where
53 K: Hash,
54 C: New + for<'a> UnionAssign<&'a C> + Intersect,
55{
56 pub fn new(probability: f64, tolerance: f64, config: C::Config) -> Self {
58 let width = Self::optimal_width(tolerance);
59 let k_num = Self::optimal_k_num(probability);
60 let counters: Vec<Vec<C>> = (0..k_num)
61 .map(|_| (0..width).map(|_| C::new(&config)).collect())
62 .collect();
63 let offsets = vec![0; k_num];
64 Self {
65 counters,
66 offsets,
67 mask: Self::mask(width),
68 k_num,
69 config,
70 marker: PhantomData,
71 }
72 }
73
74 pub fn push<Q: ?Sized, V: ?Sized>(&mut self, key: &Q, value: &V) -> C
76 where
77 Q: Hash,
78 K: Borrow<Q>,
79 C: for<'a> ops::AddAssign<&'a V> + IntersectPlusUnionIsPlus,
80 {
81 let offsets = self.offsets(key);
82 if !<C as IntersectPlusUnionIsPlus>::VAL {
83 self.offsets
84 .iter_mut()
85 .zip(offsets)
86 .for_each(|(offset, offset_new)| {
87 *offset = offset_new;
88 });
89 let mut lowest = C::intersect(
90 self.offsets
91 .iter()
92 .enumerate()
93 .map(|(k_i, &offset)| &self.counters[k_i][offset]),
94 )
95 .unwrap();
96 lowest += value;
97 self.counters
98 .iter_mut()
99 .zip(self.offsets.iter())
100 .for_each(|(counters, &offset)| {
101 counters[offset].union_assign(&lowest);
102 });
103 lowest
104 } else {
105 C::intersect(
106 self.counters
107 .iter_mut()
108 .zip(offsets)
109 .map(|(counters, offset)| {
110 counters[offset] += value;
111 &counters[offset]
112 }),
113 )
114 .unwrap()
115 }
116 }
117
118 pub fn union_assign<Q: ?Sized>(&mut self, key: &Q, value: &C)
120 where
121 Q: Hash,
122 K: Borrow<Q>,
123 {
124 let offsets = self.offsets(key);
125 self.counters
126 .iter_mut()
127 .zip(offsets)
128 .for_each(|(counters, offset)| {
129 counters[offset].union_assign(value);
130 })
131 }
132
133 pub fn get<Q: ?Sized>(&self, key: &Q) -> C
135 where
136 Q: Hash,
137 K: Borrow<Q>,
138 {
139 C::intersect(
140 self.counters
141 .iter()
142 .zip(self.offsets(key))
143 .map(|(counters, offset)| &counters[offset]),
144 )
145 .unwrap()
146 }
147
148 pub fn clear(&mut self) {
158 let config = &self.config;
159 self.counters
160 .iter_mut()
161 .flat_map(|x| x.iter_mut())
162 .for_each(|counter| {
163 *counter = C::new(config);
164 })
165 }
166
167 fn optimal_width(tolerance: f64) -> usize {
168 let e = tolerance;
169 let width = f64_to_usize((2.0 / e).round());
170 max(2, width)
171 .checked_next_power_of_two()
172 .expect("Width would be way too large")
173 }
174
175 fn mask(width: usize) -> usize {
176 assert!(width > 1);
177 assert_eq!(width & (width - 1), 0);
178 width - 1
179 }
180
181 fn optimal_k_num(probability: f64) -> usize {
182 max(
183 1,
184 f64_to_usize(((1.0 - probability).ln() / 0.5_f64.ln()).floor()),
185 )
186 }
187
188 fn offsets<Q: ?Sized>(&self, key: &Q) -> impl Iterator<Item = usize>
189 where
190 Q: Hash,
191 K: Borrow<Q>,
192 {
193 let mask = self.mask;
194 hashes(key).map(move |hash| usize::try_from(hash & u64::try_from(mask).unwrap()).unwrap())
195 }
196}
197
198fn hashes<Q: ?Sized>(key: &Q) -> impl Iterator<Item = u64>
199where
200 Q: Hash,
201{
202 #[allow(missing_copy_implementations, missing_debug_implementations)]
203 struct X(XxHash);
204 impl Iterator for X {
205 type Item = u64;
206 fn next(&mut self) -> Option<Self::Item> {
207 let ret = self.0.finish();
208 self.0.write(&[123]);
209 Some(ret)
210 }
211 }
212 let mut hasher = XxHash::default();
213 key.hash(&mut hasher);
214 X(hasher)
215}
216
217impl<K: ?Sized, C: New + Clone> Clone for CountMinSketch<K, C> {
218 fn clone(&self) -> Self {
219 Self {
220 counters: self.counters.clone(),
221 offsets: vec![0; self.offsets.len()],
222 mask: self.mask,
223 k_num: self.k_num,
224 config: self.config.clone(),
225 marker: PhantomData,
226 }
227 }
228}
229impl<K: ?Sized, C: New> fmt::Debug for CountMinSketch<K, C> {
230 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
231 fmt.debug_struct("CountMinSketch")
232 .finish()
234 }
235}
236
#[cfg(test)]
mod tests {
    type CountMinSketch8<K> = super::CountMinSketch<K, u8>;
    type CountMinSketch16<K> = super::CountMinSketch<K, u16>;
    type CountMinSketch64<K> = super::CountMinSketch<K, u64>;

    /// Pushing 300 increments into `u8` counters is expected to panic.
    // NOTE(review): presumably ignored because integer overflow only panics when overflow
    // checks are enabled; confirm before re-enabling.
    #[ignore]
    #[test]
    #[should_panic]
    fn test_overflow() {
        const PUSHES: u16 = 300;
        let mut cms = CountMinSketch8::<&str>::new(0.95, 10.0 / 100.0, ());
        for _ in 0..PUSHES {
            let _ = cms.push("key", &1);
        }
    }

    /// A single key pushed 300 times into `u16` counters reads back as exactly 300.
    #[test]
    fn test_increment() {
        const PUSHES: u16 = 300;
        let mut cms = CountMinSketch16::<&str>::new(0.95, 10.0 / 100.0, ());
        for _ in 0..PUSHES {
            let _ = cms.push("key", &1);
        }
        assert_eq!(cms.get("key"), PUSHES);
    }

    /// 100 distinct keys each receive 10,000 of the 1,000,000 pushes; every key must read
    /// back as at least 9,000.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_increment_multi() {
        const KEYS: u64 = 100;
        const PUSHES: u64 = 1_000_000;
        let mut cms = CountMinSketch64::<u64>::new(0.99, 2.0 / 100.0, ());
        for i in 0..PUSHES {
            let key = i % KEYS;
            let _ = cms.push(&key, &1);
        }
        for key in 0..KEYS {
            let estimate = cms.get(&key);
            assert!(estimate >= 9_000);
        }
    }
}