1use serde::{Deserialize, Serialize};
2use std::hash::{Hash, Hasher};
3use twox_hash::XxHash64;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct HyperLogLog {
8 precision: u8,
10 m: usize,
12 registers: Vec<u8>,
14}
15
16impl HyperLogLog {
17 pub fn new(precision: u8) -> Result<Self, crate::error::HllError> {
20 if !(4..=16).contains(&precision) {
21 return Err(crate::error::HllError::InvalidPrecision(precision));
22 }
23
24 let m = 1 << precision;
25 Ok(HyperLogLog {
26 precision,
27 m,
28 registers: vec![0; m],
29 })
30 }
31
32 pub fn add<T: Hash>(&mut self, element: &T) {
34 let hash = self.hash_element(element);
35
36 let idx = (hash >> (64 - self.precision)) as usize;
38
39 let remaining = hash << self.precision;
41 let leading_zeros = if remaining == 0 {
42 64 - self.precision + 1
43 } else {
44 remaining.leading_zeros() as u8 + 1
45 };
46
47 if leading_zeros > self.registers[idx] {
49 self.registers[idx] = leading_zeros;
50 }
51 }
52
53 pub fn add_str(&mut self, element: &str) {
55 self.add(&element);
56 }
57
58 pub fn count(&self) -> u64 {
60 let m = self.m as f64;
61
62 let sum: f64 = self.registers.iter()
64 .map(|&val| 2.0_f64.powi(-(val as i32)))
65 .sum();
66
67 let alpha = self.alpha_m();
68 let raw_estimate = alpha * m * m / sum;
69
70 if raw_estimate <= 2.5 * m {
72 let zeros = self.registers.iter().filter(|&&x| x == 0).count();
74 if zeros != 0 {
75 return (m * (m / zeros as f64).ln()) as u64;
76 }
77 }
78
79 if raw_estimate <= (1.0 / 30.0) * (1u64 << 32) as f64 {
80 return raw_estimate as u64;
81 }
82
83 (-((1u64 << 32) as f64) * (1.0 - raw_estimate / ((1u64 << 32) as f64)).ln()) as u64
85 }
86
87 pub fn merge(&mut self, other: &HyperLogLog) -> Result<(), crate::error::HllError> {
89 if self.precision != other.precision {
90 return Err(crate::error::HllError::Storage(
91 "Cannot merge HyperLogLogs with different precision".to_string()
92 ));
93 }
94
95 for (i, &val) in other.registers.iter().enumerate() {
96 if val > self.registers[i] {
97 self.registers[i] = val;
98 }
99 }
100
101 Ok(())
102 }
103
104 pub fn precision(&self) -> u8 {
106 self.precision
107 }
108
109 fn hash_element<T: Hash>(&self, element: &T) -> u64 {
111 let mut hasher = XxHash64::with_seed(0);
112 element.hash(&mut hasher);
113 hasher.finish()
114 }
115
116 fn alpha_m(&self) -> f64 {
118 match self.m {
119 16 => 0.673,
120 32 => 0.697,
121 64 => 0.709,
122 _ => 0.7213 / (1.0 + 1.079 / self.m as f64),
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sketch of the given precision and adds every integer in `items`.
    fn sketch_of(precision: u8, items: std::ops::Range<i32>) -> HyperLogLog {
        let mut sketch = HyperLogLog::new(precision).unwrap();
        for item in items {
            sketch.add(&item);
        }
        sketch
    }

    /// Absolute relative error of `estimate` with respect to `expected`.
    fn relative_error(estimate: u64, expected: f64) -> f64 {
        ((estimate as f64 - expected) / expected).abs()
    }

    #[test]
    fn test_new_valid_precision() {
        for precision in 4..=16 {
            let built = HyperLogLog::new(precision);
            assert!(built.is_ok(), "Precision {} should be valid", precision);

            let sketch = built.unwrap();
            let expected_registers = 1usize << precision;
            assert_eq!(sketch.precision(), precision);
            assert_eq!(sketch.m, expected_registers);
            assert_eq!(sketch.registers.len(), expected_registers);
        }
    }

    #[test]
    fn test_new_invalid_precision() {
        assert!(HyperLogLog::new(3).is_err());
        assert!(HyperLogLog::new(17).is_err());
        assert!(HyperLogLog::new(0).is_err());
        assert!(HyperLogLog::new(255).is_err());
    }

    #[test]
    fn test_add_deduplication() {
        let mut sketch = HyperLogLog::new(10).unwrap();

        // Adding the same element repeatedly must not grow the estimate.
        for _ in 0..100 {
            sketch.add_str("same_element");
        }

        let count = sketch.count();
        assert!(count <= 5, "Count should be close to 1, got {}", count);
    }

    #[test]
    fn test_basic_counting_small() {
        let sketch = sketch_of(14, 0..100);

        let error_rate = relative_error(sketch.count(), 100.0);

        assert!(error_rate < 0.15, "Error rate: {:.2}%", error_rate * 100.0);
    }

    #[test]
    fn test_basic_counting_medium() {
        let sketch = sketch_of(14, 0..10000);

        let error_rate = relative_error(sketch.count(), 10000.0);

        assert!(error_rate < 0.05, "Error rate: {:.2}%", error_rate * 100.0);
    }

    #[test]
    fn test_basic_counting_large() {
        let sketch = sketch_of(14, 0..100000);

        let error_rate = relative_error(sketch.count(), 100000.0);

        assert!(error_rate < 0.03, "Error rate: {:.2}%", error_rate * 100.0);
    }

    #[test]
    fn test_string_elements() {
        let mut sketch = HyperLogLog::new(10).unwrap();

        sketch.add_str("user:1");
        sketch.add_str("user:2");
        sketch.add_str("user:3");

        let count = sketch.count();
        assert!((2..=5).contains(&count), "Count should be ~3, got {}", count);
    }

    #[test]
    fn test_merge_disjoint() {
        let mut left = sketch_of(10, 0..100);
        let right = sketch_of(10, 100..200);

        left.merge(&right).unwrap();
        let count = left.count();

        assert!(count > 150 && count < 250, "Count should be ~200, got {}", count);
    }

    #[test]
    fn test_merge_overlapping() {
        // The two ranges share 100..150, so the union holds 250 distinct values.
        let mut left = sketch_of(12, 0..150);
        let right = sketch_of(12, 100..250);

        let count1 = left.count();
        let count2 = right.count();

        left.merge(&right).unwrap();
        let merged_count = left.count();

        assert!(
            merged_count > 200 && merged_count < 300,
            "Merged count should be ~250, got {}. Individual counts: {}, {}",
            merged_count,
            count1,
            count2
        );
    }

    #[test]
    fn test_merge_precision_mismatch() {
        let mut narrow = HyperLogLog::new(10).unwrap();
        let wide = HyperLogLog::new(12).unwrap();

        let result = narrow.merge(&wide);
        assert!(result.is_err(), "Should fail to merge different precisions");
    }

    #[test]
    fn test_merge_same_data() {
        let mut left = sketch_of(10, 0..100);
        let right = sketch_of(10, 0..100);

        let count_before = left.count();
        left.merge(&right).unwrap();
        let count_after = left.count();

        let diff = relative_error(count_after, count_before as f64);
        assert!(diff < 0.1, "Counts should be similar: {} vs {}", count_before, count_after);
    }

    #[test]
    fn test_clone() {
        let original = sketch_of(10, 0..1000);

        let copy = original.clone();

        assert_eq!(original.precision(), copy.precision());
        assert_eq!(original.count(), copy.count());
        assert_eq!(original.registers, copy.registers);
    }

    #[test]
    fn test_serialization() {
        let original = sketch_of(12, 0..5000);

        // Round-trip through JSON and compare observable state.
        let json = serde_json::to_string(&original).unwrap();
        let restored: HyperLogLog = serde_json::from_str(&json).unwrap();

        assert_eq!(original.precision(), restored.precision());
        assert_eq!(original.count(), restored.count());
        assert_eq!(original.registers, restored.registers);
    }

    #[test]
    fn test_empty_count() {
        let sketch = HyperLogLog::new(10).unwrap();
        let count = sketch.count();

        assert!(count < 10, "Empty HLL count should be ~0, got {}", count);
    }

    #[test]
    // The float literal is an arbitrary sample value, not an approximation of PI.
    #[allow(clippy::approx_constant)]
    fn test_different_types() {
        let mut sketch = HyperLogLog::new(10).unwrap();

        sketch.add(&42u32);
        sketch.add(&"string");
        sketch.add(&true);
        // `f64` does not implement `Hash`, so its bit pattern is hashed instead.
        sketch.add(&3.14f64.to_bits());

        let count = sketch.count();
        assert!((3..=6).contains(&count), "Should count ~4 items, got {}", count);
    }

    #[test]
    fn test_precision_memory_size() {
        for precision in 4..=16 {
            let sketch = HyperLogLog::new(precision).unwrap();
            let expected_size = 1 << precision;
            assert_eq!(
                sketch.registers.len(),
                expected_size,
                "Precision {} should have {} registers",
                precision,
                expected_size
            );
        }
    }
}