hypertwobits/
h2b.rs

1pub(super) mod sketch;
2#[cfg(test)]
3mod tests;
4
5use std::hash::BuildHasher;
6
7pub use sketch::{Sketch, M1024, M128, M2048, M256, M4096, M512, M64, M8192};
8
9use crate::AHasherDefaultBuilder;
10
11/// `HyperTwoBits` implementation that is fully stack allocated and generic to avoid branches for
12/// different numbers of sub streams.
13///
14/// Both the hasher and the sub stream size siaz can be customized, by default it uses `AHasherBuilder` and `M256`
15#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
16#[derive(Debug, Eq, PartialEq, Hash, Clone)]
17pub struct HyperTwoBits<SKETCH: Sketch = M256, HASH: BuildHasher = AHasherDefaultBuilder> {
18    hash: HASH,
19    sketch: SKETCH,
20    count: u32,
21    t: u32,
22}
23
24impl<SKETCH: Sketch, H: Default + BuildHasher> Default for HyperTwoBits<SKETCH, H> {
25    fn default() -> Self {
26        Self {
27            hash: H::default(),
28            sketch: SKETCH::default(),
29            count: 0,
30            t: 1,
31        }
32    }
33}
34
35impl<HASH: BuildHasher + Default, BITS: Sketch> HyperTwoBits<BITS, HASH> {
36    const ALPHA: f64 = 0.988;
37    #[must_use]
38    /// Creates a new `HyperTwoBits` counter with specified hasher and bitset,
39    /// use `HyperTwoBits::default()` for default values.
40    pub fn new() -> Self {
41        Self {
42            hash: HASH::default(),
43            sketch: BITS::default(),
44            count: 0,
45            t: 1,
46        }
47    }
48
49    /// Merges another `HyperTwoBits` counter into this one
50    /// # Panics
51    /// If hasheres are seeded as that prevents merging
52    pub fn merge(&mut self, mut other: Self) {
53        assert_eq!(
54            self.hash.hash_one(42),
55            other.hash.hash_one(42),
56            "Hashers must be the same, can not merge"
57        );
58        // The paper asks for actions if the sketch is "nearly full", this is a very loose definition
59        // we will assume 99% if substreams set is "nearly full"
60        #[allow(
61            clippy::cast_lossless,
62            clippy::cast_sign_loss,
63            clippy::cast_possible_truncation
64        )]
65        let threshold = const { (0.99 * (BITS::STREAMS as f64)) as u32 };
66        // for simplicity we ensure that `self` is always the larger sketch
67        if other.t > self.t {
68            std::mem::swap(self, &mut other);
69        }
70
71        // If the values of T differ by 8 or more, use the larger value and its sketches.
72        if self.t - other.t > 8 {
73            return;
74        }
75        // we pre-compute if self.t == other.t so we can do the decrement below before handling
76        // the other cases
77        let same = self.t == other.t;
78        // We now only have the first and third case left, so we can handle the decrement
79        if self.count >= threshold {
80            self.count = self.sketch.decrement();
81            self.t += 4;
82        }
83
84        if same {
85            // Merg sketches
86            self.sketch.merge(&other.sketch);
87        } else {
88            // merge the high bits of other into the low bits of self
89            self.sketch.merge_high_into_lo(&other.sketch);
90        }
91        // update the count
92        self.count = self.sketch.count();
93    }
94
95    #[inline]
96    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
97    /// Inserts a value into the counter
98    pub fn insert<V: std::hash::Hash + ?Sized>(&mut self, value: &V) {
99        let hash = self.hash.hash_one(value);
100        self.insert_hash(hash);
101    }
102
103    #[inline]
104    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
105    /// Inserts a value into the counter
106    pub fn insert_hash(&mut self, hash: u64) {
107        let threshold: u32 = const { (Self::ALPHA * BITS::STREAMS as f64) as u32 };
108        // use most significant bits for k the rest for x
109        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
110        let hash: u64 = hash & BITS::HASH_MASK;
111
112        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
113            self.count += 1;
114            self.sketch.set(stream, 1);
115        }
116        // 2^4
117        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
118            self.sketch.set(stream, 2);
119        }
120
121        // 2^8
122        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
123            self.sketch.set(stream, 3);
124        }
125
126        if self.count >= threshold {
127            self.count = self.sketch.decrement();
128            self.t += 4;
129        }
130    }
131
132    #[inline]
133    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
134    /// Inserts 2 elements into the counter for micro batching purposes, note this will delay
135    /// the count update to the end
136    pub fn insert2<V: std::hash::Hash>(&mut self, v1: &V, v2: &V) {
137        let threshold: u32 = const { (Self::ALPHA * BITS::STREAMS as f64) as u32 };
138
139        let hash = self.hash.hash_one(v1);
140        // use most significant bits for k the rest for x
141        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
142        let hash: u64 = hash & BITS::HASH_MASK;
143
144        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
145            self.count += 1;
146            self.sketch.set(stream, 1);
147        }
148        // 2^4
149        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
150            self.sketch.set(stream, 2);
151        }
152
153        // 2^8
154        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
155            self.sketch.set(stream, 3);
156        }
157
158        let hash = self.hash.hash_one(v2);
159        // use most significant bits for k the rest for x
160        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
161        let hash: u64 = hash & BITS::HASH_MASK;
162
163        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
164            self.count += 1;
165            self.sketch.set(stream, 1);
166        }
167        // 2^4
168        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
169            self.sketch.set(stream, 2);
170        }
171
172        // 2^8
173        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
174            self.sketch.set(stream, 3);
175        }
176
177        if self.count >= threshold {
178            self.count = self.sketch.decrement();
179            self.t += 4;
180        }
181    }
182
183    #[inline]
184    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
185    /// Inserts 4 elements into the counter for micro batching purposes, note this will delay
186    /// the count update to the end
187    pub fn insert4<V: std::hash::Hash>(&mut self, v1: &V, v2: &V, v3: &V, v4: &V) {
188        let threshold: u32 = const { (Self::ALPHA * BITS::STREAMS as f64) as u32 };
189
190        let hash = self.hash.hash_one(v1);
191        // use most significant bits for k the rest for x
192        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
193        let hash: u64 = hash & BITS::HASH_MASK;
194
195        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
196            self.count += 1;
197            self.sketch.set(stream, 1);
198        }
199        // 2^4
200        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
201            self.sketch.set(stream, 2);
202        }
203
204        // 2^8
205        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
206            self.sketch.set(stream, 3);
207        }
208
209        let hash = self.hash.hash_one(v2);
210        // use most significant bits for k the rest for x
211        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
212        let hash: u64 = hash & BITS::HASH_MASK;
213
214        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
215            self.count += 1;
216            self.sketch.set(stream, 1);
217        }
218        // 2^4
219        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
220            self.sketch.set(stream, 2);
221        }
222
223        // 2^8
224        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
225            self.sketch.set(stream, 3);
226        }
227
228        let hash = self.hash.hash_one(v3);
229        // use most significant bits for k the rest for x
230        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
231        let hash: u64 = hash & BITS::HASH_MASK;
232
233        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
234            self.count += 1;
235            self.sketch.set(stream, 1);
236        }
237        // 2^4
238        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
239            self.sketch.set(stream, 2);
240        }
241
242        // 2^8
243        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
244            self.sketch.set(stream, 3);
245        }
246
247        let hash = self.hash.hash_one(v4);
248        // use most significant bits for k the rest for x
249        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
250        let hash: u64 = hash & BITS::HASH_MASK;
251
252        if hash.trailing_ones() >= self.t && self.sketch.val(stream) < 1 {
253            self.count += 1;
254            self.sketch.set(stream, 1);
255        }
256        // 2^4
257        if hash.trailing_ones() >= self.t + 4 && self.sketch.val(stream) < 2 {
258            self.sketch.set(stream, 2);
259        }
260
261        // 2^8
262        if hash.trailing_ones() >= self.t + 8 && self.sketch.val(stream) < 3 {
263            self.sketch.set(stream, 3);
264        }
265
266        if self.count >= threshold {
267            self.count = self.sketch.decrement();
268            self.t += 4;
269        }
270    }
271
272    /// returns the estimated count. This function is non destructive
273    /// and can be called multiple times without changing the state of the counter
274    #[inline]
275    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
276    pub fn count(&self) -> u64 {
277        let beta = 1.0 - f64::from(self.count) / f64::from(BITS::STREAMS);
278        let bias: f64 = (1.0 / beta).ln();
279        (f64::from(self.t).exp2() * f64::from(BITS::STREAMS) * bias) as u64
280    }
281}