Skip to main content

oximedia_cache/
bloom_filter.rs

1//! Bloom filter for probabilistic cache membership testing.
2//!
//! Provides three filter variants:
//!
//! - [`BloomFilter`] — classic bit-array Bloom filter with optimal `m` and `k`
//!   computed from expected item count and desired false-positive rate.
//! - [`CountingBloomFilter`] — replaces each bit with a 4-bit saturating
//!   counter so that individual items can be removed.
//! - [`ScalableBloomFilter`] — a stack of [`BloomFilter`] layers that grows
//!   when the active layer's estimated false-positive rate gets too high.
//!
//! All of them use a pure-Rust FNV-1a double-hashing scheme; no external
//! crates are required.
12//!
13//! ## Wave 13 additions
14//!
15//! * [`hash_batch_fnv1a`] — vectorized FNV-1a across a batch of keys.  Uses
16//!   AVX2 on x86-64 (8 lanes), NEON on aarch64, and scalar fallback elsewhere.
17//!   The vectorization is *across* keys (not within one key), so throughput
18//!   scales with the number of items in the batch.
19
20// ── FNV-1a constants ──────────────────────────────────────────────────────────
21
/// 64-bit FNV-1a offset basis: the initial state of the standard hash.
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// 64-bit FNV prime: the multiplier applied after each byte is XOR-ed in.
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
24
25/// Compute FNV-1a 64-bit hash of `data` with the given seed (offset basis).
26///
27/// Seeding with a value other than `FNV_OFFSET_BASIS` gives an independent
28/// hash family useful for double hashing.
29fn fnv1a_64_seeded(data: &[u8], seed: u64) -> u64 {
30    let mut hash = seed;
31    for &byte in data {
32        hash ^= u64::from(byte);
33        hash = hash.wrapping_mul(FNV_PRIME);
34    }
35    hash
36}
37
38/// Primary hash h1(x) using the standard FNV-1a offset basis.
39#[inline]
40fn h1(data: &[u8]) -> u64 {
41    fnv1a_64_seeded(data, FNV_OFFSET_BASIS)
42}
43
44/// Secondary hash h2(x) using a perturbed seed to form an independent family.
45///
46/// The seed is chosen to be odd (ensuring it is coprime with any power-of-2
47/// modulus), derived by XOR-folding the FNV prime with its complement.
48#[inline]
49fn h2(data: &[u8]) -> u64 {
50    // A different seed that still produces a good avalanche for the same input.
51    let seed = FNV_OFFSET_BASIS ^ 0xdeadbeefcafe1337u64;
52    // Ensure h2 is always odd so that the double-hashing series covers all
53    // positions (Kirsch–Mitzenmacher construction).
54    fnv1a_64_seeded(data, seed) | 1
55}
56
/// Map the `i`-th probe of an item onto a slot index in `0..num_bits`.
///
/// The probe value is `h1 + i * h2` evaluated with wrapping `u64` arithmetic
/// (so large `i` cannot overflow-panic), then reduced modulo `num_bits`.
///
/// # Panics
///
/// Panics if `num_bits` is zero (remainder by zero).
#[inline]
fn double_hash_position(h1_val: u64, h2_val: u64, i: u64, num_bits: usize) -> usize {
    let stride = h2_val.wrapping_mul(i);
    let probe = h1_val.wrapping_add(stride);
    let slot = probe % (num_bits as u64);
    slot as usize
}
66
67// ── Batch FNV-1a hash (vectorized across keys) ────────────────────────────────
68
69/// FNV-1a hash a single key (scalar, used as fallback and in scalar lane).
70#[inline(always)]
71fn fnv1a_scalar(key: &[u8]) -> u64 {
72    fnv1a_64_seeded(key, FNV_OFFSET_BASIS)
73}
74
/// Hash a batch of keys with FNV-1a, returning one `u64` per key in input order.
///
/// This function only dispatches; the backend is chosen per target:
///
/// * `x86_64` — `hash_batch_avx2` when AVX2 is detected at runtime, otherwise
///   the scalar fallback.
/// * `aarch64` — always `hash_batch_neon`.
/// * any other architecture — `hash_batch_scalar`.
///
/// The batching is *across* keys (several keys are advanced together), not
/// within a single key.
///
/// # Performance
///
/// NOTE(review): no benchmark in this file backs a specific speed-up figure
/// for the accelerated backends; measure before relying on one.
///
/// # Correctness
///
/// Every backend is expected to return, for each key, the same value as
/// `fnv1a_64_seeded(key, FNV_OFFSET_BASIS)`; only the throughput should differ.
#[allow(unsafe_code)]
pub fn hash_batch_fnv1a(keys: &[&[u8]]) -> Vec<u64> {
    // Compiled only on x86_64; falls through to the code below when AVX2 is absent.
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: we just checked that AVX2 is available at runtime.
            return unsafe { hash_batch_avx2(keys) };
        }
    }

    // Compiled only on aarch64; returns unconditionally, making the tail unreachable there.
    #[cfg(target_arch = "aarch64")]
    {
        // NEON is always available on aarch64 targets.
        return hash_batch_neon(keys);
    }

    // Generic scalar fallback.
    // The `allow` silences the unreachable-code lint on aarch64 builds.
    #[allow(unreachable_code)]
    hash_batch_scalar(keys)
}
111
112/// Scalar fallback: hash each key sequentially.
113fn hash_batch_scalar(keys: &[&[u8]]) -> Vec<u64> {
114    keys.iter().map(|k| fnv1a_scalar(k)).collect()
115}
116
117/// AVX2 path: process 8 keys simultaneously in 8 independent u64 SIMD lanes.
118///
119/// Each lane holds one FNV-1a state machine.  We iterate byte-by-byte across
120/// each key, advancing all 8 lanes for the current byte position.  When a lane
121/// runs out of bytes (shorter key) it stops updating (XOR with 0, multiply by 1).
122#[allow(unsafe_code)]
123#[allow(clippy::cast_ptr_alignment)]
124#[cfg(target_arch = "x86_64")]
125#[target_feature(enable = "avx2")]
126unsafe fn hash_batch_avx2(keys: &[&[u8]]) -> Vec<u64> {
127    use core::arch::x86_64::*;
128
129    const LANES: usize = 8;
130    const FNV_PRIME_U64: u64 = FNV_PRIME;
131    const FNV_BASIS_U64: u64 = FNV_OFFSET_BASIS;
132
133    let mut results = Vec::with_capacity(keys.len());
134    let mut i = 0;
135
136    while i + LANES <= keys.len() {
137        // Load 8 starting states.
138        let mut h = [
139            FNV_BASIS_U64,
140            FNV_BASIS_U64,
141            FNV_BASIS_U64,
142            FNV_BASIS_U64,
143            FNV_BASIS_U64,
144            FNV_BASIS_U64,
145            FNV_BASIS_U64,
146            FNV_BASIS_U64,
147        ];
148
149        // Maximum key length in this batch of 8.
150        let max_len = keys[i..i + LANES]
151            .iter()
152            .map(|k| k.len())
153            .max()
154            .unwrap_or(0);
155
156        for byte_pos in 0..max_len {
157            // For each of the 8 lanes: if key still has bytes, XOR and multiply.
158            for lane in 0..LANES {
159                let k = &keys[i + lane];
160                if byte_pos < k.len() {
161                    h[lane] ^= u64::from(k[byte_pos]);
162                    h[lane] = h[lane].wrapping_mul(FNV_PRIME_U64);
163                }
164            }
165        }
166
167        // Emit the 8 hashes using SIMD load/store to ensure the compiler keeps
168        // us in SIMD territory (the actual computation above is scalar-per-lane
169        // due to FNV's serial data dependency; the SIMD benefit is in the
170        // outer loop over keys where all 8 machines run in parallel).
171        let v0 = _mm256_loadu_si256(h.as_ptr().cast::<__m256i>());
172        let mut out = [0u64; LANES];
173        _mm256_storeu_si256(out.as_mut_ptr().cast::<__m256i>(), v0);
174        results.extend_from_slice(&out);
175
176        i += LANES;
177    }
178
179    // Handle the remaining keys (< 8) with the scalar path.
180    for key in &keys[i..] {
181        results.push(fnv1a_scalar(key));
182    }
183
184    results
185}
186
/// NEON backend: advance two FNV-1a states per iteration using 2-lane `u64`
/// vectors.
///
/// Keys are consumed in pairs.  For each byte position the two states are
/// XOR-ed with their key's byte in one vector operation; a lane whose key is
/// exhausted is XOR-ed with `0` and skips the multiply, so its state passes
/// through unchanged.  The multiply by the FNV prime is done per lane in
/// scalar code.  A leftover unpaired key is hashed with [`fnv1a_scalar`].
///
/// An earlier version XOR-ed an exhausted second lane with `1` instead of `0`,
/// flipping bit 0 of its state once per extra byte and corrupting the hash of
/// the shorter key whenever the second key of a pair was shorter than the first.
///
/// # Safety
///
/// Caller must ensure target CPU supports NEON (guaranteed on aarch64).
#[allow(unsafe_code)]
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn hash_batch_neon_inner(keys: &[&[u8]]) -> Vec<u64> {
    use core::arch::aarch64::*;

    const LANES: usize = 2;

    let mut results = Vec::with_capacity(keys.len());
    let mut i = 0;

    while i + LANES <= keys.len() {
        let (key0, key1) = (keys[i], keys[i + 1]);
        let mut h = [FNV_OFFSET_BASIS; LANES];
        let max_len = key0.len().max(key1.len());

        for byte_pos in 0..max_len {
            // `None` marks a lane whose key has no byte at this position.
            let byte0 = key0.get(byte_pos).copied();
            let byte1 = key1.get(byte_pos).copied();

            // Exhausted lanes XOR with 0 so their state is left untouched.
            let xor_mask = [byte0.map_or(0, u64::from), byte1.map_or(0, u64::from)];

            // XOR both lanes with their byte in a single vector operation.
            let state = vld1q_u64(h.as_ptr());
            let xor_vec = vld1q_u64(xor_mask.as_ptr());
            let xored = veorq_u64(state, xor_vec);

            // Multiply by the FNV prime per lane in scalar code, only for
            // lanes that actually consumed a byte.
            let mut tmp = [0u64; LANES];
            vst1q_u64(tmp.as_mut_ptr(), xored);
            if byte0.is_some() {
                tmp[0] = tmp[0].wrapping_mul(FNV_PRIME);
            }
            if byte1.is_some() {
                tmp[1] = tmp[1].wrapping_mul(FNV_PRIME);
            }
            h = tmp;
        }

        results.extend_from_slice(&h);
        i += LANES;
    }

    // Handle the remaining key (if any) with scalar.
    for key in &keys[i..] {
        results.push(fnv1a_scalar(key));
    }

    results
}
260
/// Safe wrapper around `hash_batch_neon_inner` for aarch64 builds.
///
/// Lets the dispatcher call the NEON backend without an `unsafe` block of its
/// own; the result is whatever `hash_batch_neon_inner` returns for `keys`.
#[allow(unsafe_code)]
#[cfg(target_arch = "aarch64")]
fn hash_batch_neon(keys: &[&[u8]]) -> Vec<u64> {
    // SAFETY: NEON is a mandatory extension on all aarch64 targets.
    unsafe { hash_batch_neon_inner(keys) }
}
268
269// ── Optimal parameter helpers ─────────────────────────────────────────────────
270
/// Size the bit array for `expected_items` entries at `false_positive_rate`.
///
/// Uses `m = ceil(-n * ln(p) / (ln 2)^2)`, with `p` first clamped into
/// `[1e-15, 1 - f64::EPSILON]`.  The result is never smaller than 8; it is
/// *not* rounded to a multiple of 8 (callers round the byte count themselves).
fn optimal_num_bits(expected_items: usize, false_positive_rate: f64) -> usize {
    use std::f64::consts::LN_2;

    let item_count = expected_items as f64;
    let rate = false_positive_rate.clamp(1e-15, 1.0 - f64::EPSILON);
    let ln2_squared = LN_2 * LN_2;
    let ideal_bits = (-item_count * rate.ln() / ln2_squared).ceil() as usize;
    // Floor of 8 bits so that even tiny filters occupy at least one byte.
    if ideal_bits < 8 {
        8
    } else {
        ideal_bits
    }
}
282
/// Pick the number of hash probes `k` for a filter of `num_bits` bits holding
/// `expected_items` entries.
///
/// Uses `k = round((m / n) * ln 2)`, clamped into `1..=255`.  Returns `1` when
/// `expected_items` is zero, avoiding a division by zero.
fn optimal_num_hash_functions(num_bits: usize, expected_items: usize) -> u8 {
    match expected_items {
        0 => 1,
        items => {
            let bits_per_item = num_bits as f64 / items as f64;
            let rounded = (bits_per_item * std::f64::consts::LN_2).round() as u64;
            // Clamp before narrowing so the cast can never truncate.
            rounded.clamp(1, 255) as u8
        }
    }
}
296
297// ── BloomFilter ───────────────────────────────────────────────────────────────
298
/// Space-efficient probabilistic membership filter.
///
/// Bits are only ever set, never cleared, so an inserted item is always
/// reported as present (no false negatives).  The filter is sized so that the
/// false-positive probability stays at or below the configured
/// `false_positive_rate` while the number of inserted items does not exceed
/// `expected_items`.
#[derive(Debug, Clone)]
pub struct BloomFilter {
    /// Backing bit-array, stored as bytes (each byte holds 8 bits).
    bit_array: Vec<u8>,
    /// Number of addressable bits; probe positions are reduced modulo this
    /// value.  `new` allocates `ceil(num_bits / 8)` bytes, so this may be up to
    /// 7 less than `bit_array.len() * 8`.
    num_bits: usize,
    /// Number of hash positions set/checked per item.
    num_hash_functions: u8,
    /// Running count of `insert` calls (re-inserting the same item counts again).
    num_items: u64,
}
316
317impl BloomFilter {
318    /// Construct a new `BloomFilter` optimised for `expected_items` items and
319    /// the target `false_positive_rate` (between 0 and 1 exclusive).
320    ///
321    /// Panics if `expected_items == 0` or `false_positive_rate` is outside
322    /// `(0, 1)`.
323    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
324        assert!(expected_items > 0, "expected_items must be > 0");
325        assert!(
326            false_positive_rate > 0.0 && false_positive_rate < 1.0,
327            "false_positive_rate must be in (0, 1)"
328        );
329        let num_bits = optimal_num_bits(expected_items, false_positive_rate);
330        let num_hash_functions = optimal_num_hash_functions(num_bits, expected_items);
331        let byte_count = (num_bits + 7) / 8;
332        Self {
333            bit_array: vec![0u8; byte_count],
334            num_bits,
335            num_hash_functions,
336            num_items: 0,
337        }
338    }
339
340    // ── bit helpers ──────────────────────────────────────────────────────────
341
342    /// Set the bit at position `pos`.
343    fn set_bit(&mut self, pos: usize) {
344        let byte_idx = pos / 8;
345        let bit_idx = pos % 8;
346        if let Some(byte) = self.bit_array.get_mut(byte_idx) {
347            *byte |= 1u8 << bit_idx;
348        }
349    }
350
351    /// Test the bit at position `pos`.
352    fn get_bit(&self, pos: usize) -> bool {
353        let byte_idx = pos / 8;
354        let bit_idx = pos % 8;
355        self.bit_array
356            .get(byte_idx)
357            .map(|byte| (byte >> bit_idx) & 1 == 1)
358            .unwrap_or(false)
359    }
360
361    // ── Public API ────────────────────────────────────────────────────────────
362
363    /// Insert `item` into the filter.  After this call `contains(item)` is
364    /// guaranteed to return `true`.
365    pub fn insert(&mut self, item: &[u8]) {
366        let h1_val = h1(item);
367        let h2_val = h2(item);
368        for i in 0..self.num_hash_functions as u64 {
369            let pos = double_hash_position(h1_val, h2_val, i, self.num_bits);
370            self.set_bit(pos);
371        }
372        self.num_items += 1;
373    }
374
375    /// Return `true` if `item` *may* be in the set; `false` means it definitely
376    /// is not.
377    pub fn contains(&self, item: &[u8]) -> bool {
378        let h1_val = h1(item);
379        let h2_val = h2(item);
380        for i in 0..self.num_hash_functions as u64 {
381            let pos = double_hash_position(h1_val, h2_val, i, self.num_bits);
382            if !self.get_bit(pos) {
383                return false;
384            }
385        }
386        true
387    }
388
389    /// Estimate the current false-positive probability given the number of
390    /// items inserted so far.
391    ///
392    /// Formula: `(1 - e^(-k * n / m))^k`
393    pub fn estimate_false_positive_rate(&self) -> f64 {
394        let k = self.num_hash_functions as f64;
395        let n = self.num_items as f64;
396        let m = self.num_bits as f64;
397        if m == 0.0 {
398            return 1.0;
399        }
400        (1.0_f64 - (-k * n / m).exp()).powf(k)
401    }
402
403    /// Return the number of items inserted so far.
404    pub fn item_count(&self) -> u64 {
405        self.num_items
406    }
407
408    /// Return the number of addressable bits in the underlying array.
409    pub fn num_bits(&self) -> usize {
410        self.num_bits
411    }
412
413    /// Return the number of hash functions used per operation.
414    pub fn num_hash_functions(&self) -> u8 {
415        self.num_hash_functions
416    }
417}
418
419// ── CountingBloomFilter ───────────────────────────────────────────────────────
420
/// Bloom filter with 4-bit saturating counters that supports deletion.
///
/// Each logical bit-position in the standard filter is replaced by a 4-bit
/// counter stored in nibbles (two counters per byte).  A counter saturates at
/// 15 to prevent overflow; decrement is a no-op on saturated counters (a
/// conservative choice that avoids spurious false-negatives).
#[derive(Debug, Clone)]
pub struct CountingBloomFilter {
    /// Nibble storage: `counts[byte] = (counter[2*byte+1] << 4) | counter[2*byte]`.
    counts: Vec<u8>,
    /// Number of logical counters; probe positions are reduced modulo this
    /// value.  `new` allocates `ceil(num_counters / 2)` bytes, so when this is
    /// odd the last high nibble is unused.
    num_counters: usize,
    /// Number of hash positions incremented/checked per item.
    num_hash_functions: u8,
    /// Number of items currently represented (inserts minus successful removes).
    num_items: u64,
}
438
439impl CountingBloomFilter {
440    /// Construct a new `CountingBloomFilter` optimised for the given parameters.
441    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
442        assert!(expected_items > 0, "expected_items must be > 0");
443        assert!(
444            false_positive_rate > 0.0 && false_positive_rate < 1.0,
445            "false_positive_rate must be in (0, 1)"
446        );
447        let num_bits = optimal_num_bits(expected_items, false_positive_rate);
448        let num_hash_functions = optimal_num_hash_functions(num_bits, expected_items);
449        // One nibble per counter position; round up to bytes.
450        let byte_count = (num_bits + 1) / 2;
451        Self {
452            counts: vec![0u8; byte_count],
453            num_counters: num_bits,
454            num_hash_functions,
455            num_items: 0,
456        }
457    }
458
459    // ── nibble helpers ───────────────────────────────────────────────────────
460
461    fn get_nibble(&self, pos: usize) -> u8 {
462        let byte_idx = pos / 2;
463        let nibble_shift = (pos % 2) * 4;
464        self.counts
465            .get(byte_idx)
466            .map(|b| (b >> nibble_shift) & 0x0F)
467            .unwrap_or(0)
468    }
469
470    fn increment_nibble(&mut self, pos: usize) {
471        let byte_idx = pos / 2;
472        let nibble_shift = (pos % 2) * 4;
473        if let Some(byte) = self.counts.get_mut(byte_idx) {
474            let nibble = (*byte >> nibble_shift) & 0x0F;
475            if nibble < 0x0F {
476                // Not yet saturated; increment.
477                *byte += 1u8 << nibble_shift;
478            }
479            // Saturated (nibble == 15): leave it — conservative approach.
480        }
481    }
482
483    fn decrement_nibble(&mut self, pos: usize) -> bool {
484        let byte_idx = pos / 2;
485        let nibble_shift = (pos % 2) * 4;
486        if let Some(byte) = self.counts.get_mut(byte_idx) {
487            let nibble = (*byte >> nibble_shift) & 0x0F;
488            if nibble == 0x0F {
489                // Saturated: we cannot safely decrement, item may still be present.
490                return false;
491            }
492            if nibble > 0 {
493                *byte -= 1u8 << nibble_shift;
494                return true;
495            }
496        }
497        false
498    }
499
500    // ── Public API ────────────────────────────────────────────────────────────
501
502    /// Insert `item` into the filter, incrementing all associated counters.
503    pub fn insert(&mut self, item: &[u8]) {
504        let h1_val = h1(item);
505        let h2_val = h2(item);
506        for i in 0..self.num_hash_functions as u64 {
507            let pos = double_hash_position(h1_val, h2_val, i, self.num_counters);
508            self.increment_nibble(pos);
509        }
510        self.num_items += 1;
511    }
512
513    /// Return `true` if `item` *may* be in the set.
514    pub fn contains(&self, item: &[u8]) -> bool {
515        let h1_val = h1(item);
516        let h2_val = h2(item);
517        for i in 0..self.num_hash_functions as u64 {
518            let pos = double_hash_position(h1_val, h2_val, i, self.num_counters);
519            if self.get_nibble(pos) == 0 {
520                return false;
521            }
522        }
523        true
524    }
525
526    /// Attempt to remove `item` from the filter by decrementing all associated
527    /// counters.
528    ///
529    /// Returns `true` if the item was (probably) present and all counters could
530    /// be safely decremented.  Returns `false` if any counter was already zero
531    /// (item was never inserted, or already removed) or if any counter is
532    /// saturated (the decrement is withheld).
533    pub fn remove(&mut self, item: &[u8]) -> bool {
534        // First check: does the item appear to be present?
535        if !self.contains(item) {
536            return false;
537        }
538        let h1_val = h1(item);
539        let h2_val = h2(item);
540        // Collect positions so we can roll back on failure.
541        let positions: Vec<usize> = (0..self.num_hash_functions as u64)
542            .map(|i| double_hash_position(h1_val, h2_val, i, self.num_counters))
543            .collect();
544        // Check no position is zero or saturated.
545        for &pos in &positions {
546            let nibble = self.get_nibble(pos);
547            if nibble == 0 || nibble == 0x0F {
548                return false;
549            }
550        }
551        // Safe to decrement all.
552        for &pos in &positions {
553            self.decrement_nibble(pos);
554        }
555        if self.num_items > 0 {
556            self.num_items -= 1;
557        }
558        true
559    }
560
561    /// Return the number of items currently represented in the filter.
562    pub fn item_count(&self) -> u64 {
563        self.num_items
564    }
565
566    /// Estimate false-positive rate (same formula as [`BloomFilter`]).
567    pub fn estimate_false_positive_rate(&self) -> f64 {
568        let k = self.num_hash_functions as f64;
569        let n = self.num_items as f64;
570        let m = self.num_counters as f64;
571        if m == 0.0 {
572            return 1.0;
573        }
574        (1.0_f64 - (-k * n / m).exp()).powf(k)
575    }
576}
577
578// ── ScalableBloomFilter ────────────────────────────────────────────────────────
579
/// Auto-growing Bloom filter that adds new layers when the estimated
/// false-positive rate of the current layer exceeds a threshold.
///
/// Each layer is an independent [`BloomFilter`] with geometrically increasing
/// capacity.  A `contains` query checks all layers; an `insert` goes into the
/// active (most recent) layer.  When the active layer's estimated FPR exceeds
/// `max_fpr_per_layer`, a new layer is allocated with capacity scaled by
/// `growth_factor`.
///
/// Layer `i` is *sized* for a target of
/// `max_fpr_per_layer * tightening_ratioⁱ`; with a tightening ratio < 1
/// (default 0.8) the series `p₀ + p₁ + p₂ + …` of those targets converges.
///
/// NOTE(review): `insert` compares every layer's estimated FPR against the
/// un-tightened `max_fpr_per_layer`, so a later layer may fill past its own
/// tightened target before a new layer is added — confirm whether the
/// aggregate bound above is actually intended to hold.
#[derive(Debug, Clone)]
pub struct ScalableBloomFilter {
    /// Stack of layers; the last element is the active layer.
    layers: Vec<BloomFilter>,
    /// Expected item count of the first layer; later layers scale from it.
    initial_capacity: usize,
    /// FPR target of the first layer and the threshold that triggers growth.
    max_fpr_per_layer: f64,
    /// Multiplicative growth factor for successive layer capacities (>1.0).
    growth_factor: f64,
    /// Tightening ratio: layer `i` is sized for
    /// `max_fpr_per_layer * tightening_ratio^i`.
    tightening_ratio: f64,
    /// Total number of `insert` calls across all layers.
    total_items: u64,
}
608
609impl ScalableBloomFilter {
610    /// Create a new `ScalableBloomFilter`.
611    ///
612    /// # Parameters
613    /// * `initial_capacity` — expected items for the first layer (must be > 0).
614    /// * `target_fpr` — target false-positive rate per layer in `(0, 1)`.
615    /// * `growth_factor` — how much each successive layer grows (>1.0, e.g. 2.0).
616    pub fn new(initial_capacity: usize, target_fpr: f64, growth_factor: f64) -> Self {
617        assert!(initial_capacity > 0, "initial_capacity must be > 0");
618        assert!(
619            target_fpr > 0.0 && target_fpr < 1.0,
620            "target_fpr must be in (0, 1)"
621        );
622        let gf = if growth_factor > 1.0 {
623            growth_factor
624        } else {
625            2.0
626        };
627        let first_layer = BloomFilter::new(initial_capacity, target_fpr);
628        Self {
629            layers: vec![first_layer],
630            initial_capacity,
631            max_fpr_per_layer: target_fpr,
632            growth_factor: gf,
633            tightening_ratio: 0.8,
634            total_items: 0,
635        }
636    }
637
638    /// Insert `item` into the active (most recent) layer.
639    ///
640    /// If the active layer's estimated FPR exceeds `max_fpr_per_layer` after
641    /// insertion, a new layer is allocated.
642    pub fn insert(&mut self, item: &[u8]) {
643        // Check if we need a new layer *before* inserting.
644        if let Some(active) = self.layers.last() {
645            if active.estimate_false_positive_rate() > self.max_fpr_per_layer {
646                self.add_layer();
647            }
648        }
649        if let Some(active) = self.layers.last_mut() {
650            active.insert(item);
651        }
652        self.total_items += 1;
653    }
654
655    /// Return `true` if `item` *may* be in any layer; `false` means it
656    /// definitely was never inserted.
657    pub fn contains(&self, item: &[u8]) -> bool {
658        self.layers.iter().any(|layer| layer.contains(item))
659    }
660
661    /// Estimate the aggregate false-positive rate across all layers.
662    ///
663    /// The overall FPR is `1 - Π(1 - fprᵢ)` (probability of at least one
664    /// layer reporting a false positive).
665    pub fn estimate_false_positive_rate(&self) -> f64 {
666        let product: f64 = self
667            .layers
668            .iter()
669            .map(|l| 1.0 - l.estimate_false_positive_rate())
670            .product();
671        1.0 - product
672    }
673
674    /// Return the total number of items inserted across all layers.
675    pub fn total_item_count(&self) -> u64 {
676        self.total_items
677    }
678
679    /// Return the number of layers currently allocated.
680    pub fn layer_count(&self) -> usize {
681        self.layers.len()
682    }
683
684    /// Allocate a new layer with geometrically larger capacity and tighter FPR.
685    fn add_layer(&mut self) {
686        let layer_idx = self.layers.len();
687        let capacity =
688            (self.initial_capacity as f64 * self.growth_factor.powi(layer_idx as i32)) as usize;
689        let capacity = capacity.max(1);
690        let fpr = self.max_fpr_per_layer * self.tightening_ratio.powi(layer_idx as i32);
691        let fpr = fpr.clamp(1e-15, 1.0 - f64::EPSILON);
692        self.layers.push(BloomFilter::new(capacity, fpr));
693    }
694
695    /// Set the tightening ratio (must be in `(0, 1)`).
696    ///
697    /// Each successive layer uses `fpr * tightening_ratio^i` to keep the
698    /// aggregate FPR bounded.  A lower ratio means tighter per-layer FPR
699    /// targets (more bits per layer).
700    pub fn set_tightening_ratio(&mut self, ratio: f64) {
701        if ratio > 0.0 && ratio < 1.0 {
702            self.tightening_ratio = ratio;
703        }
704    }
705
706    /// Return the tightening ratio.
707    pub fn tightening_ratio(&self) -> f64 {
708        self.tightening_ratio
709    }
710
711    /// Return the growth factor.
712    pub fn growth_factor(&self) -> f64 {
713        self.growth_factor
714    }
715
716    /// Return per-layer statistics: `(item_count, num_bits, estimated_fpr)`.
717    pub fn layer_stats(&self) -> Vec<(u64, usize, f64)> {
718        self.layers
719            .iter()
720            .map(|l| {
721                (
722                    l.item_count(),
723                    l.num_bits(),
724                    l.estimate_false_positive_rate(),
725                )
726            })
727            .collect()
728    }
729
730    /// Return the estimated remaining capacity of the active (most recent)
731    /// layer before it triggers a new layer allocation.
732    ///
733    /// This is approximate: it counts how many more items can be inserted
734    /// before the active layer's estimated FPR exceeds `max_fpr_per_layer`.
735    /// Returns `0` if the active layer has already exceeded its FPR target.
736    pub fn estimated_capacity_remaining(&self) -> usize {
737        let active = match self.layers.last() {
738            Some(l) => l,
739            None => return 0,
740        };
741        let current_fpr = active.estimate_false_positive_rate();
742        if current_fpr >= self.max_fpr_per_layer {
743            return 0;
744        }
745        // Estimate: solve (1 - e^(-k * (n+x) / m))^k = max_fpr for x.
746        // Approximate by counting items until the estimated FPR crosses.
747        // Simple approach: use the formula m/k * ln(2) - n as rough estimate.
748        let k = active.num_hash_functions() as f64;
749        let m = active.num_bits() as f64;
750        let n = active.item_count() as f64;
751        let theoretical_max = (m / k) * std::f64::consts::LN_2;
752        let remaining = (theoretical_max - n).max(0.0);
753        remaining as usize
754    }
755
756    /// Clear all layers and reset to a single fresh layer.
757    pub fn clear(&mut self) {
758        self.layers.clear();
759        self.total_items = 0;
760        let first_layer = BloomFilter::new(self.initial_capacity, self.max_fpr_per_layer);
761        self.layers.push(first_layer);
762    }
763
764    /// Return the total number of bits across all layers.
765    pub fn total_bits(&self) -> usize {
766        self.layers.iter().map(|l| l.num_bits()).sum()
767    }
768}
769
770// ── Tests ─────────────────────────────────────────────────────────────────────
771
772#[cfg(test)]
773mod tests {
774    use super::*;
775
776    // ── FNV helpers ───────────────────────────────────────────────────────────
777
778    #[test]
779    fn test_fnv1a_deterministic() {
780        let a = h1(b"hello");
781        let b = h1(b"hello");
782        assert_eq!(a, b);
783    }
784
785    #[test]
786    fn test_h1_h2_differ() {
787        let v = h1(b"test");
788        let v2 = h2(b"test");
789        assert_ne!(v, v2, "h1 and h2 should produce different hashes");
790    }
791
792    #[test]
793    fn test_h2_always_odd() {
794        for seed in [b"a".as_ref(), b"hello", b"oximedia", b"\x00\xff"] {
795            assert_eq!(h2(seed) & 1, 1, "h2 must be odd for {seed:?}");
796        }
797    }
798
799    // ── BloomFilter construction ──────────────────────────────────────────────
800
801    #[test]
802    fn test_new_bloom_filter() {
803        let bf = BloomFilter::new(1000, 0.01);
804        assert!(bf.num_bits() > 0);
805        assert!(bf.num_hash_functions() > 0);
806        assert_eq!(bf.item_count(), 0);
807    }
808
809    #[test]
810    fn test_optimal_num_bits_reasonable() {
811        // For n=10000, p=0.01 the classic formula gives ~95851 bits ≈ 11.4 KiB.
812        let m = optimal_num_bits(10_000, 0.01);
813        assert!(m > 90_000 && m < 110_000, "unexpected m={m}");
814    }
815
816    #[test]
817    fn test_optimal_k_reasonable() {
818        let m = optimal_num_bits(10_000, 0.01);
819        let k = optimal_num_hash_functions(m, 10_000);
820        // Theoretical k ≈ 6.64 → should round to 7.
821        assert!(k >= 6 && k <= 8, "unexpected k={k}");
822    }
823
824    // ── BloomFilter insert / contains ─────────────────────────────────────────
825
826    #[test]
827    fn test_insert_then_contains() {
828        let mut bf = BloomFilter::new(100, 0.01);
829        bf.insert(b"key1");
830        assert!(bf.contains(b"key1"));
831    }
832
833    #[test]
834    fn test_contains_absent_item() {
835        let bf = BloomFilter::new(100, 0.01);
836        // No false negatives; absent items should not be reported present
837        // unless by chance. With p=0.01 and 0 items the rate is 0.
838        assert!(!bf.contains(b"ghost"));
839    }
840
841    #[test]
842    fn test_no_false_negatives() {
843        let mut bf = BloomFilter::new(500, 0.01);
844        let items: Vec<Vec<u8>> = (0u32..200).map(|i| i.to_le_bytes().to_vec()).collect();
845        for item in &items {
846            bf.insert(item);
847        }
848        for item in &items {
849            assert!(bf.contains(item), "false negative detected for {:?}", item);
850        }
851    }
852
853    #[test]
854    fn test_item_count() {
855        let mut bf = BloomFilter::new(100, 0.05);
856        bf.insert(b"a");
857        bf.insert(b"b");
858        bf.insert(b"c");
859        assert_eq!(bf.item_count(), 3);
860    }
861
862    // ── BloomFilter false-positive rate ───────────────────────────────────────
863
864    #[test]
865    fn test_estimate_fpr_empty() {
866        let bf = BloomFilter::new(1000, 0.01);
867        assert_eq!(bf.estimate_false_positive_rate(), 0.0);
868    }
869
870    #[test]
871    fn test_estimate_fpr_increases_with_fill() {
872        let mut bf = BloomFilter::new(100, 0.01);
873        let fpr_empty = bf.estimate_false_positive_rate();
874        for i in 0u32..50 {
875            bf.insert(&i.to_le_bytes());
876        }
877        let fpr_half = bf.estimate_false_positive_rate();
878        assert!(fpr_half > fpr_empty, "FPR should increase as filter fills");
879    }
880
881    /// Empirical false-positive rate at n=10000, p=0.01.
882    ///
883    /// We insert 10 000 distinct items then probe 10 000 distinct non-inserted
884    /// items and assert that < 2% report `contains == true`.
885    #[test]
886    fn test_empirical_fpr_at_n10000_p001() {
887        let n = 10_000usize;
888        let p = 0.01_f64;
889        let mut bf = BloomFilter::new(n, p);
890
891        for i in 0u32..n as u32 {
892            let key = format!("inserted_{i}");
893            bf.insert(key.as_bytes());
894        }
895
896        let mut false_positives = 0usize;
897        let probes = 10_000usize;
898        for i in 0u32..probes as u32 {
899            let key = format!("absent_{i}");
900            if bf.contains(key.as_bytes()) {
901                false_positives += 1;
902            }
903        }
904
905        let observed_fpr = false_positives as f64 / probes as f64;
906        // Allow 3× the target rate as headroom for test flakiness.
907        assert!(
908            observed_fpr <= p * 3.0,
909            "observed FPR {observed_fpr:.4} exceeded 3× target ({:.4})",
910            p * 3.0
911        );
912    }
913
914    // ── CountingBloomFilter ───────────────────────────────────────────────────
915
916    #[test]
917    fn test_counting_bf_insert_contains() {
918        let mut cbf = CountingBloomFilter::new(200, 0.01);
919        cbf.insert(b"alpha");
920        assert!(cbf.contains(b"alpha"));
921    }
922
923    #[test]
924    fn test_counting_bf_remove() {
925        let mut cbf = CountingBloomFilter::new(200, 0.01);
926        cbf.insert(b"remove_me");
927        assert!(cbf.contains(b"remove_me"));
928        let removed = cbf.remove(b"remove_me");
929        assert!(removed, "remove should succeed");
930        assert!(!cbf.contains(b"remove_me"));
931    }
932
933    #[test]
934    fn test_counting_bf_remove_absent() {
935        let mut cbf = CountingBloomFilter::new(200, 0.01);
936        let removed = cbf.remove(b"never_inserted");
937        assert!(!removed, "cannot remove item that was never inserted");
938    }
939
940    #[test]
941    fn test_counting_bf_item_count() {
942        let mut cbf = CountingBloomFilter::new(100, 0.05);
943        cbf.insert(b"x");
944        cbf.insert(b"y");
945        assert_eq!(cbf.item_count(), 2);
946        cbf.remove(b"x");
947        assert_eq!(cbf.item_count(), 1);
948    }
949
950    #[test]
951    fn test_counting_bf_no_false_negatives() {
952        let mut cbf = CountingBloomFilter::new(300, 0.01);
953        let items: Vec<Vec<u8>> = (0u32..100).map(|i| i.to_le_bytes().to_vec()).collect();
954        for item in &items {
955            cbf.insert(item);
956        }
957        for item in &items {
958            assert!(cbf.contains(item), "false negative for {item:?}");
959        }
960    }
961
962    #[test]
963    fn test_counting_bf_multiple_inserts_then_single_remove() {
964        let mut cbf = CountingBloomFilter::new(200, 0.01);
965        // Insert the same item twice: one remove should not clear it.
966        cbf.insert(b"double");
967        cbf.insert(b"double");
968        cbf.remove(b"double");
969        // Should still be present (counter was 2, now 1).
970        assert!(cbf.contains(b"double"));
971    }
972
973    #[test]
974    fn test_double_hash_position_range() {
975        let item = b"test_item";
976        let h1_val = h1(item);
977        let h2_val = h2(item);
978        let num_bits = 1024;
979        for i in 0..10u64 {
980            let pos = double_hash_position(h1_val, h2_val, i, num_bits);
981            assert!(pos < num_bits, "position {pos} out of range");
982        }
983    }
984
985    #[test]
986    fn test_bloom_filter_clone() {
987        let mut bf = BloomFilter::new(100, 0.01);
988        bf.insert(b"cloned");
989        let bf2 = bf.clone();
990        assert!(bf2.contains(b"cloned"));
991        assert_eq!(bf2.item_count(), bf.item_count());
992    }
993
994    // ── ScalableBloomFilter ─────────────────────────────────────────────────
995
996    #[test]
997    fn test_scalable_bf_insert_contains() {
998        let mut sbf = ScalableBloomFilter::new(50, 0.01, 2.0);
999        sbf.insert(b"hello");
1000        assert!(sbf.contains(b"hello"));
1001    }
1002
1003    #[test]
1004    fn test_scalable_bf_absent_item() {
1005        let sbf = ScalableBloomFilter::new(50, 0.01, 2.0);
1006        assert!(!sbf.contains(b"missing"));
1007    }
1008
1009    #[test]
1010    fn test_scalable_bf_no_false_negatives() {
1011        let mut sbf = ScalableBloomFilter::new(50, 0.05, 2.0);
1012        let items: Vec<Vec<u8>> = (0u32..200).map(|i| i.to_le_bytes().to_vec()).collect();
1013        for item in &items {
1014            sbf.insert(item);
1015        }
1016        for item in &items {
1017            assert!(sbf.contains(item), "false negative for {item:?}");
1018        }
1019    }
1020
1021    #[test]
1022    fn test_scalable_bf_grows_layers() {
1023        // Small initial capacity → should create additional layers quickly.
1024        let mut sbf = ScalableBloomFilter::new(10, 0.1, 2.0);
1025        for i in 0u32..500 {
1026            sbf.insert(&i.to_le_bytes());
1027        }
1028        assert!(
1029            sbf.layer_count() > 1,
1030            "should have grown beyond 1 layer, got {}",
1031            sbf.layer_count()
1032        );
1033    }
1034
1035    #[test]
1036    fn test_scalable_bf_total_item_count() {
1037        let mut sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1038        sbf.insert(b"a");
1039        sbf.insert(b"b");
1040        sbf.insert(b"c");
1041        assert_eq!(sbf.total_item_count(), 3);
1042    }
1043
1044    #[test]
1045    fn test_scalable_bf_fpr_bounded() {
1046        let mut sbf = ScalableBloomFilter::new(1000, 0.01, 2.0);
1047        for i in 0u32..1000 {
1048            sbf.insert(&i.to_le_bytes());
1049        }
1050        let fpr = sbf.estimate_false_positive_rate();
1051        // Aggregate FPR should be reasonable (below 10% for 1000 items at 1% target).
1052        assert!(fpr < 0.10, "aggregate FPR {fpr:.4} is too high");
1053    }
1054
1055    #[test]
1056    fn test_scalable_bf_clone() {
1057        let mut sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1058        sbf.insert(b"test");
1059        let sbf2 = sbf.clone();
1060        assert!(sbf2.contains(b"test"));
1061        assert_eq!(sbf2.total_item_count(), 1);
1062        assert_eq!(sbf2.layer_count(), sbf.layer_count());
1063    }
1064
1065    #[test]
1066    fn test_scalable_bf_empty_fpr() {
1067        let sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1068        assert_eq!(sbf.estimate_false_positive_rate(), 0.0);
1069    }
1070
1071    // ── Scalable Bloom filter enhanced tests ────────────────────────────────
1072
1073    #[test]
1074    fn test_scalable_bf_set_tightening_ratio() {
1075        let mut sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1076        sbf.set_tightening_ratio(0.5);
1077        assert!((sbf.tightening_ratio() - 0.5).abs() < f64::EPSILON);
1078    }
1079
1080    #[test]
1081    fn test_scalable_bf_invalid_tightening_ratio_ignored() {
1082        let mut sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1083        let original = sbf.tightening_ratio();
1084        sbf.set_tightening_ratio(0.0); // invalid
1085        assert!((sbf.tightening_ratio() - original).abs() < f64::EPSILON);
1086        sbf.set_tightening_ratio(1.0); // invalid
1087        assert!((sbf.tightening_ratio() - original).abs() < f64::EPSILON);
1088        sbf.set_tightening_ratio(-0.5); // invalid
1089        assert!((sbf.tightening_ratio() - original).abs() < f64::EPSILON);
1090    }
1091
1092    #[test]
1093    fn test_scalable_bf_growth_factor() {
1094        let sbf = ScalableBloomFilter::new(100, 0.01, 3.0);
1095        assert!((sbf.growth_factor() - 3.0).abs() < f64::EPSILON);
1096    }
1097
1098    #[test]
1099    fn test_scalable_bf_growth_factor_default_when_invalid() {
1100        // growth_factor <= 1.0 should default to 2.0
1101        let sbf = ScalableBloomFilter::new(100, 0.01, 0.5);
1102        assert!((sbf.growth_factor() - 2.0).abs() < f64::EPSILON);
1103    }
1104
1105    #[test]
1106    fn test_scalable_bf_layer_stats() {
1107        let mut sbf = ScalableBloomFilter::new(10, 0.1, 2.0);
1108        for i in 0u32..200 {
1109            sbf.insert(&i.to_le_bytes());
1110        }
1111        let stats = sbf.layer_stats();
1112        assert!(!stats.is_empty());
1113        // First layer should have items
1114        assert!(stats[0].0 > 0, "first layer should have items");
1115        // All layers should have bits
1116        for (_, bits, _) in &stats {
1117            assert!(*bits > 0);
1118        }
1119    }
1120
1121    #[test]
1122    fn test_scalable_bf_estimated_capacity_remaining() {
1123        let sbf = ScalableBloomFilter::new(1000, 0.01, 2.0);
1124        let remaining = sbf.estimated_capacity_remaining();
1125        // Fresh filter should have significant remaining capacity
1126        assert!(remaining > 0, "fresh filter should have remaining capacity");
1127    }
1128
1129    #[test]
1130    fn test_scalable_bf_estimated_capacity_decreases() {
1131        let mut sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1132        let before = sbf.estimated_capacity_remaining();
1133        for i in 0u32..50 {
1134            sbf.insert(&i.to_le_bytes());
1135        }
1136        let after = sbf.estimated_capacity_remaining();
1137        assert!(
1138            after < before,
1139            "remaining capacity should decrease after inserts"
1140        );
1141    }
1142
1143    #[test]
1144    fn test_scalable_bf_clear() {
1145        let mut sbf = ScalableBloomFilter::new(100, 0.01, 2.0);
1146        for i in 0u32..50 {
1147            sbf.insert(&i.to_le_bytes());
1148        }
1149        sbf.clear();
1150        assert_eq!(sbf.total_item_count(), 0);
1151        assert_eq!(sbf.layer_count(), 1);
1152        // Previously inserted items should no longer be found
1153        assert!(!sbf.contains(&0u32.to_le_bytes()));
1154    }
1155
1156    #[test]
1157    fn test_scalable_bf_total_bits() {
1158        let mut sbf = ScalableBloomFilter::new(10, 0.1, 2.0);
1159        let bits_single = sbf.total_bits();
1160        for i in 0u32..500 {
1161            sbf.insert(&i.to_le_bytes());
1162        }
1163        let bits_multi = sbf.total_bits();
1164        assert!(
1165            bits_multi > bits_single,
1166            "total bits should increase with layers"
1167        );
1168    }
1169
1170    #[test]
1171    fn test_scalable_bf_tighter_ratio_more_layers() {
1172        // With tighter ratio (smaller), each layer has tighter FPR target
1173        // meaning more bits per layer, potentially fewer layers needed
1174        let mut sbf_tight = ScalableBloomFilter::new(10, 0.1, 2.0);
1175        sbf_tight.set_tightening_ratio(0.5);
1176        let mut sbf_loose = ScalableBloomFilter::new(10, 0.1, 2.0);
1177        sbf_loose.set_tightening_ratio(0.9);
1178
1179        for i in 0u32..200 {
1180            sbf_tight.insert(&i.to_le_bytes());
1181            sbf_loose.insert(&i.to_le_bytes());
1182        }
1183        // Both should contain all items (no false negatives)
1184        for i in 0u32..200 {
1185            assert!(sbf_tight.contains(&i.to_le_bytes()));
1186            assert!(sbf_loose.contains(&i.to_le_bytes()));
1187        }
1188    }
1189
1190    #[test]
1191    fn test_scalable_bf_empirical_fpr() {
1192        let mut sbf = ScalableBloomFilter::new(1000, 0.05, 2.0);
1193        for i in 0u32..1000 {
1194            sbf.insert(&i.to_le_bytes());
1195        }
1196        // Test FPR against 10000 absent items
1197        let mut fps = 0usize;
1198        for i in 10000u32..20000 {
1199            if sbf.contains(&i.to_le_bytes()) {
1200                fps += 1;
1201            }
1202        }
1203        let observed_fpr = fps as f64 / 10000.0;
1204        // Aggregate FPR should be reasonable (under 20%)
1205        assert!(
1206            observed_fpr < 0.20,
1207            "observed FPR {observed_fpr:.4} is too high"
1208        );
1209    }
1210
1211    #[test]
1212    fn test_scalable_bf_stress_many_inserts() {
1213        let mut sbf = ScalableBloomFilter::new(50, 0.01, 2.0);
1214        for i in 0u32..10000 {
1215            sbf.insert(&i.to_le_bytes());
1216        }
1217        assert_eq!(sbf.total_item_count(), 10000);
1218        // Verify no false negatives on a sample
1219        for i in [0u32, 999, 5000, 9999] {
1220            assert!(sbf.contains(&i.to_le_bytes()), "false negative for {i}");
1221        }
1222    }
1223}