amadeus_streaming/
count_min.rs

1// This file includes source code from https://github.com/jedisct1/rust-count-min-sketch/blob/088274e22a3decc986dec928c92cc90a709a0274/src/lib.rs under the following MIT License:
2
3// Copyright (c) 2016 Frank Denis
4
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23use serde::{Deserialize, Serialize};
24use std::{
25	borrow::Borrow, cmp::max, convert::TryFrom, fmt, hash::{Hash, Hasher}, marker::PhantomData, ops
26};
27use twox_hash::XxHash;
28
29use super::f64_to_usize;
30use crate::traits::{Intersect, IntersectPlusUnionIsPlus, New, UnionAssign};
31
32/// An implementation of a [count-min sketch](https://en.wikipedia.org/wiki/Count–min_sketch) data structure with *conservative updating* for increased accuracy.
33///
34/// This data structure is also known as a [counting Bloom filter](https://en.wikipedia.org/wiki/Bloom_filter#Counting_filters).
35///
36/// See [*An Improved Data Stream Summary: The Count-Min Sketch and its Applications*](http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf) and [*New Directions in Traffic Measurement and Accounting*](http://pages.cs.wisc.edu/~suman/courses/740/papers/estan03tocs.pdf) for background on the count-min sketch with conservative updating.
37#[derive(Serialize, Deserialize)]
38#[serde(bound(
39	serialize = "C: Serialize, <C as New>::Config: Serialize",
40	deserialize = "C: Deserialize<'de>, <C as New>::Config: Deserialize<'de>"
41))]
42pub struct CountMinSketch<K: ?Sized, C: New> {
43	counters: Vec<Vec<C>>,
44	offsets: Vec<usize>, // to avoid malloc/free each push
45	mask: usize,
46	k_num: usize,
47	config: <C as New>::Config,
48	marker: PhantomData<fn(K)>,
49}
50
51impl<K: ?Sized, C> CountMinSketch<K, C>
52where
53	K: Hash,
54	C: New + for<'a> UnionAssign<&'a C> + Intersect,
55{
56	/// Create an empty `CountMinSketch` data structure with the specified error tolerance.
57	pub fn new(probability: f64, tolerance: f64, config: C::Config) -> Self {
58		let width = Self::optimal_width(tolerance);
59		let k_num = Self::optimal_k_num(probability);
60		let counters: Vec<Vec<C>> = (0..k_num)
61			.map(|_| (0..width).map(|_| C::new(&config)).collect())
62			.collect();
63		let offsets = vec![0; k_num];
64		Self {
65			counters,
66			offsets,
67			mask: Self::mask(width),
68			k_num,
69			config,
70			marker: PhantomData,
71		}
72	}
73
74	/// "Visit" an element.
75	pub fn push<Q: ?Sized, V: ?Sized>(&mut self, key: &Q, value: &V) -> C
76	where
77		Q: Hash,
78		K: Borrow<Q>,
79		C: for<'a> ops::AddAssign<&'a V> + IntersectPlusUnionIsPlus,
80	{
81		let offsets = self.offsets(key);
82		if !<C as IntersectPlusUnionIsPlus>::VAL {
83			self.offsets
84				.iter_mut()
85				.zip(offsets)
86				.for_each(|(offset, offset_new)| {
87					*offset = offset_new;
88				});
89			let mut lowest = C::intersect(
90				self.offsets
91					.iter()
92					.enumerate()
93					.map(|(k_i, &offset)| &self.counters[k_i][offset]),
94			)
95			.unwrap();
96			lowest += value;
97			self.counters
98				.iter_mut()
99				.zip(self.offsets.iter())
100				.for_each(|(counters, &offset)| {
101					counters[offset].union_assign(&lowest);
102				});
103			lowest
104		} else {
105			C::intersect(
106				self.counters
107					.iter_mut()
108					.zip(offsets)
109					.map(|(counters, offset)| {
110						counters[offset] += value;
111						&counters[offset]
112					}),
113			)
114			.unwrap()
115		}
116	}
117
118	/// Union the aggregated value for `key` with `value`.
119	pub fn union_assign<Q: ?Sized>(&mut self, key: &Q, value: &C)
120	where
121		Q: Hash,
122		K: Borrow<Q>,
123	{
124		let offsets = self.offsets(key);
125		self.counters
126			.iter_mut()
127			.zip(offsets)
128			.for_each(|(counters, offset)| {
129				counters[offset].union_assign(value);
130			})
131	}
132
133	/// Retrieve an estimate of the aggregated value for `key`.
134	pub fn get<Q: ?Sized>(&self, key: &Q) -> C
135	where
136		Q: Hash,
137		K: Borrow<Q>,
138	{
139		C::intersect(
140			self.counters
141				.iter()
142				.zip(self.offsets(key))
143				.map(|(counters, offset)| &counters[offset]),
144		)
145		.unwrap()
146	}
147
148	// pub fn estimate_memory(
149	// 	probability: f64, tolerance: f64,
150	// ) -> Result<usize, &'static str> {
151	// 	let width = Self::optimal_width(tolerance);
152	// 	let k_num = Self::optimal_k_num(probability);
153	// 	Ok(width * mem::size_of::<C>() * k_num)
154	// }
155
156	/// Clears the `CountMinSketch` data structure, as if it was new.
157	pub fn clear(&mut self) {
158		let config = &self.config;
159		self.counters
160			.iter_mut()
161			.flat_map(|x| x.iter_mut())
162			.for_each(|counter| {
163				*counter = C::new(config);
164			})
165	}
166
167	fn optimal_width(tolerance: f64) -> usize {
168		let e = tolerance;
169		let width = f64_to_usize((2.0 / e).round());
170		max(2, width)
171			.checked_next_power_of_two()
172			.expect("Width would be way too large")
173	}
174
175	fn mask(width: usize) -> usize {
176		assert!(width > 1);
177		assert_eq!(width & (width - 1), 0);
178		width - 1
179	}
180
181	fn optimal_k_num(probability: f64) -> usize {
182		max(
183			1,
184			f64_to_usize(((1.0 - probability).ln() / 0.5_f64.ln()).floor()),
185		)
186	}
187
188	fn offsets<Q: ?Sized>(&self, key: &Q) -> impl Iterator<Item = usize>
189	where
190		Q: Hash,
191		K: Borrow<Q>,
192	{
193		let mask = self.mask;
194		hashes(key).map(move |hash| usize::try_from(hash & u64::try_from(mask).unwrap()).unwrap())
195	}
196}
197
198fn hashes<Q: ?Sized>(key: &Q) -> impl Iterator<Item = u64>
199where
200	Q: Hash,
201{
202	#[allow(missing_copy_implementations, missing_debug_implementations)]
203	struct X(XxHash);
204	impl Iterator for X {
205		type Item = u64;
206		fn next(&mut self) -> Option<Self::Item> {
207			let ret = self.0.finish();
208			self.0.write(&[123]);
209			Some(ret)
210		}
211	}
212	let mut hasher = XxHash::default();
213	key.hash(&mut hasher);
214	X(hasher)
215}
216
217impl<K: ?Sized, C: New + Clone> Clone for CountMinSketch<K, C> {
218	fn clone(&self) -> Self {
219		Self {
220			counters: self.counters.clone(),
221			offsets: vec![0; self.offsets.len()],
222			mask: self.mask,
223			k_num: self.k_num,
224			config: self.config.clone(),
225			marker: PhantomData,
226		}
227	}
228}
229impl<K: ?Sized, C: New> fmt::Debug for CountMinSketch<K, C> {
230	fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
231		fmt.debug_struct("CountMinSketch")
232			// .field("counters", &self.counters)
233			.finish()
234	}
235}
236
#[cfg(test)]
mod tests {
	type CountMinSketch8<K> = super::CountMinSketch<K, u8>;
	type CountMinSketch16<K> = super::CountMinSketch<K, u16>;
	type CountMinSketch64<K> = super::CountMinSketch<K, u64>;

	/// Pushing a key more than `u8::MAX` times overflows its `u8` counter.
	/// That only panics when overflow checks are enabled, so the test is
	/// ignored by default.
	#[ignore] // release mode stops panic
	#[test]
	#[should_panic]
	fn test_overflow() {
		let mut cms = CountMinSketch8::<&str>::new(0.95, 10.0 / 100.0, ());
		for _ in 0..300 {
			let _ = cms.push("key", &1);
		}
	}

	#[test]
	fn test_increment() {
		let mut cms = CountMinSketch16::<&str>::new(0.95, 10.0 / 100.0, ());
		for _ in 0..300 {
			let _ = cms.push("key", &1);
		}
		assert_eq!(cms.get("key"), 300);
	}

	#[test]
	#[cfg_attr(miri, ignore)]
	fn test_increment_multi() {
		let mut cms = CountMinSketch64::<u64>::new(0.99, 2.0 / 100.0, ());
		for i in 0..1_000_000 {
			let _ = cms.push(&(i % 100), &1);
		}
		for key in 0..100 {
			// Each key was pushed exactly 10_000 times, and a count-min sketch
			// (with or without conservative updating) never underestimates.
			assert!(cms.get(&key) >= 10_000);
		}
	}

	#[test]
	fn test_clear() {
		let mut cms = CountMinSketch16::<&str>::new(0.95, 10.0 / 100.0, ());
		for _ in 0..10 {
			let _ = cms.push("key", &1);
		}
		cms.clear();
		assert_eq!(cms.get("key"), 0);
		// The sketch is fully usable again after clearing.
		let _ = cms.push("key", &1);
		assert_eq!(cms.get("key"), 1);
	}
}