1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
pub(super) mod sketch;
#[cfg(test)]
mod tests;

use std::hash::BuildHasher;

pub use sketch::{Sketch, M1024, M128, M2048, M256, M4096, M512, M64};

use crate::AHasherDefaultBuilder;

/// `HyperTwoBits` implementation that is fully stack allocated and generic to avoid branches for
/// different numbers of sub streams.
///
/// Both the hasher and the sub stream size can be customized; by default it uses
/// `AHasherDefaultBuilder` and `M256`.
pub struct HyperTwoBits<SKETCH: Sketch = M256, HASH: BuildHasher = AHasherDefaultBuilder> {
    /// Hasher used to turn inserted values into 64 bit hashes.
    hash: HASH,
    /// Per sub stream values (0-3); value `n > 0` means a hash with at least `t + 4 * (n - 1)`
    /// trailing ones was seen for that sub stream.
    sketch: SKETCH,
    /// Number of sub streams with a non-zero value, maintained incrementally on insert.
    count: u32,
    /// Current threshold: a hash needs at least `t` trailing ones to set a sub stream to 1.
    t: u32,
}

impl<SKETCH: Sketch> Default for HyperTwoBits<SKETCH> {
    /// Builds an empty counter (threshold `t = 1`, no sub streams set) that uses the default
    /// `AHasherDefaultBuilder` hasher.
    fn default() -> Self {
        let hash = AHasherDefaultBuilder::default();
        let sketch = SKETCH::default();
        Self {
            hash,
            sketch,
            count: 0,
            t: 1,
        }
    }
}

impl<HASH: BuildHasher + Default, BITS: Sketch> HyperTwoBits<BITS, HASH> {
    /// Fraction of sub streams that may be set before an insert rescales the sketch.
    const ALPHA: f64 = 0.988;

    /// Fraction of sub streams that may be set before a merge rescales the sketch.
    ///
    /// The paper asks for action if the merged sketch is "nearly full", which is a very loose
    /// definition; we assume a sketch with 99% of its sub streams set is "nearly full".
    const MERGE_ALPHA: f64 = 0.99;

    /// Creates a new, empty `HyperTwoBits` counter with the hasher and sketch given by the type
    /// parameters (both built through `Default`).
    ///
    /// Use `HyperTwoBits::default()` for the default hasher and sketch size.
    #[must_use]
    pub fn new() -> Self {
        Self {
            hash: HASH::default(),
            sketch: BITS::default(),
            count: 0,
            t: 1,
        }
    }

    /// Merges another `HyperTwoBits` counter into this one.
    ///
    /// The counter with the smaller threshold `t` is aligned onto the one with the larger `t`:
    /// * same `t`: the sketches are combined with `Sketch::merge`;
    /// * `t` differs by less than 8 (in practice by 4): the sketches are combined with
    ///   `Sketch::merge_high_into_lo`;
    /// * `t` differs by 8 or more: the counter with the larger `t` is kept unchanged.
    ///
    /// Only after combining is the result checked for being nearly full and, if so, rescaled
    /// (`Sketch::decrement`, `t += 4`), so both inputs are always combined at the same `t`.
    ///
    /// # Panics
    /// If the hashers produce different hashes (e.g. they are seeded differently), as that
    /// prevents merging.
    pub fn merge(&mut self, mut other: Self) {
        assert_eq!(
            self.hash.hash_one(42),
            other.hash.hash_one(42),
            "Hashers must be the same, can not merge"
        );
        #[allow(
            clippy::cast_lossless,
            clippy::cast_sign_loss,
            clippy::cast_possible_truncation
        )]
        let threshold = const { (Self::MERGE_ALPHA * (BITS::STREAMS as f64)) as u32 };

        // for simplicity we ensure that `self` is always the sketch with the larger `t`
        if other.t > self.t {
            std::mem::swap(self, &mut other);
        }

        // If the values of T differ by 8 or more, use the larger value and its sketches.
        if self.t - other.t >= 8 {
            return;
        }

        if self.t == other.t {
            self.sketch.merge(&other.sketch);
        } else {
            // merge the high bits of other into the low bits of self
            self.sketch.merge_high_into_lo(&other.sketch);
        }
        self.count = self.sketch.count();

        // Rescale only now: decrementing before the merge would shift `self` to a new `t`
        // while `other` is still combined as if it were at the old one.
        if self.count >= threshold {
            self.count = self.sketch.decrement();
            self.t += 4;
        }
    }

    /// Inserts a value into the counter.
    #[inline]
    pub fn insert<V: std::hash::Hash + ?Sized>(&mut self, value: &V) {
        let hash = self.hash.hash_one(value);
        self.insert_hash(hash);
    }

    /// Inserts an already hashed value into the counter.
    ///
    /// The bits above `BITS::IDX_SHIFT` select the sub stream; the trailing ones of the bits in
    /// `BITS::HASH_MASK` decide which value the sub stream is raised to.
    #[inline]
    pub fn insert_hash(&mut self, hash: u64) {
        self.record_hash(hash);
        self.rescale_if_full();
    }

    /// Inserts 2 elements into the counter for micro batching purposes, note this will delay
    /// the count update (the fullness check / rescale) to the end.
    #[inline]
    pub fn insert2<V: std::hash::Hash + ?Sized>(&mut self, v1: &V, v2: &V) {
        let h1 = self.hash.hash_one(v1);
        self.record_hash(h1);
        let h2 = self.hash.hash_one(v2);
        self.record_hash(h2);
        self.rescale_if_full();
    }

    /// Inserts 4 elements into the counter for micro batching purposes, note this will delay
    /// the count update (the fullness check / rescale) to the end.
    #[inline]
    pub fn insert4<V: std::hash::Hash + ?Sized>(&mut self, v1: &V, v2: &V, v3: &V, v4: &V) {
        let h1 = self.hash.hash_one(v1);
        self.record_hash(h1);
        let h2 = self.hash.hash_one(v2);
        self.record_hash(h2);
        let h3 = self.hash.hash_one(v3);
        self.record_hash(h3);
        let h4 = self.hash.hash_one(v4);
        self.record_hash(h4);
        self.rescale_if_full();
    }

    /// Records one hash in the sketch without checking whether the sketch became full.
    #[inline]
    #[allow(clippy::cast_possible_truncation)]
    fn record_hash(&mut self, hash: u64) {
        // use most significant bits for the sub stream, the rest for the level
        let stream: u32 = (hash >> BITS::IDX_SHIFT) as u32;
        let ones = (hash & BITS::HASH_MASK).trailing_ones();

        // level `t`: a newly set sub stream increases the non-zero count
        if ones >= self.t && self.sketch.val(stream) < 1 {
            self.count += 1;
            self.sketch.set(stream, 1);
        }
        // level `t + 4` (2^4 times rarer than `t` for uniformly distributed hashes)
        if ones >= self.t + 4 && self.sketch.val(stream) < 2 {
            self.sketch.set(stream, 2);
        }
        // level `t + 8` (2^8 times rarer than `t` for uniformly distributed hashes)
        if ones >= self.t + 8 && self.sketch.val(stream) < 3 {
            self.sketch.set(stream, 3);
        }
    }

    /// Once at least `ALPHA` of the sub streams are set, calls `Sketch::decrement` (which
    /// returns the new non-zero count) and raises the threshold `t` by 4.
    #[inline]
    #[allow(
        clippy::cast_lossless,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    fn rescale_if_full(&mut self) {
        let threshold: u32 = const { (Self::ALPHA * BITS::STREAMS as f64) as u32 };
        if self.count >= threshold {
            self.count = self.sketch.decrement();
            self.t += 4;
        }
    }

    /// Returns the estimated count. This function is non destructive
    /// and can be called multiple times without changing the state of the counter.
    #[inline]
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn count(&self) -> u64 {
        let streams = f64::from(BITS::STREAMS);
        let beta = 1.0 - f64::from(self.count) / streams;
        let bias: f64 = (1.0 / beta).ln();
        (2.0_f64.powf(f64::from(self.t)) * streams * bias) as u64
    }
}