#![allow(
clippy::cast_lossless,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::items_after_statements,
clippy::similar_names,
clippy::unreadable_literal
)]
use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt;
use spg_crypto::crc32::crc32;
/// Magic number written (little-endian) as the first four bytes of an encoded
/// filter and checked by `BloomFilter::from_bytes`.
const BLOOM_MAGIC: u32 = 0xB100_F11E;
/// Initial hash state used by `fnv1a_64` (the 64-bit FNV offset basis).
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// Per-byte multiplier used by `fnv1a_64` (the 64-bit FNV prime).
const FNV_PRIME: u64 = 0x0000_0001_0000_01b3;
/// Inclusive upper bound on the number of probe positions per key; larger
/// values are clamped at construction and rejected when decoding.
const NUM_HASHES_MAX: u32 = 32;
/// Errors produced while decoding or validating a serialized Bloom filter.
#[derive(Debug, PartialEq, Eq)]
pub enum BloomError {
/// The input has fewer bytes (`got`) than the fixed header requires (`need`).
TooShort { got: usize, need: usize },
/// The leading four bytes decoded to `got` instead of `BLOOM_MAGIC`.
BadMagic { got: u32 },
/// The declared dimensions are inconsistent (bad `num_bits`, or a length that
/// does not match it); the string carries a human-readable explanation.
BadShape(String),
/// The CRC stored in the header (`expected`) differs from the CRC computed
/// over the received parameters and bit array (`got`).
BadCrc { expected: u32, got: u32 },
/// `num_hashes` was zero or greater than `NUM_HASHES_MAX`.
BadNumHashes { got: u32 },
}
impl fmt::Display for BloomError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooShort { got, need } => {
write!(f, "bloom: too short, got {got} bytes, need at least {need}")
}
Self::BadMagic { got } => {
write!(
f,
"bloom: bad magic 0x{got:08x}, expected 0x{BLOOM_MAGIC:08x}"
)
}
Self::BadShape(s) => write!(f, "bloom: bad shape: {s}"),
Self::BadCrc { expected, got } => write!(
f,
"bloom: crc mismatch, expected 0x{expected:08x}, got 0x{got:08x}"
),
Self::BadNumHashes { got } => write!(
f,
"bloom: bad num_hashes {got}, must be 1..={NUM_HASHES_MAX}"
),
}
}
}
/// A Bloom filter over byte-string keys, backed by an array of 64-bit words.
#[derive(Debug, Clone)]
pub struct BloomFilter {
// Bit array; bit `i` is stored in word `i / 64` at bit position `i % 64`.
bits: Vec<u64>,
// Total number of addressable bits; a positive multiple of 64 equal to
// `bits.len() * 64`.
num_bits: u64,
// Number of bit positions probed per key, in `1..=NUM_HASHES_MAX`.
num_hashes: u32,
}
impl BloomFilter {
    /// Length in bytes of the leading magic number.
    const MAGIC_LEN: usize = 4;
    /// Length in bytes of the CRC-covered parameters: `num_bits` (8) followed
    /// by `num_hashes` (4).
    const PARAMS_LEN: usize = 12;
    /// Length in bytes of the fixed header: magic, parameters, then a 4-byte CRC.
    const HEADER_LEN: usize = Self::MAGIC_LEN + Self::PARAMS_LEN + 4;

    /// Creates an empty filter sized for `num_items` keys at a target
    /// false-positive probability of `fp_rate`.
    ///
    /// The bit count is `-n * ln(p) / (ln 2)^2`, rounded up to a whole number
    /// of 64-bit words (at least one word). The hash count is
    /// `(m / n) * ln 2`, rounded up and clamped to `1..=NUM_HASHES_MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `num_items` is zero, if `fp_rate` is not strictly between 0
    /// and 1 (NaN included), or if the computed size cannot be represented
    /// (bit count overflows `u64`, or word count does not fit in `usize`).
    #[must_use]
    pub fn with_target_fp_rate(num_items: usize, fp_rate: f64) -> Self {
        assert!(num_items > 0, "BloomFilter: num_items must be > 0");
        assert!(
            fp_rate > 0.0 && fp_rate < 1.0,
            "BloomFilter: fp_rate must be in (0, 1), got {fp_rate}"
        );
        let n = num_items as f64;
        let ln_2 = libm_ln(2.0);
        let m_raw = -(n * libm_ln(fp_rate)) / (ln_2 * ln_2);
        let num_words = f64_ceil_to_u64(m_raw).max(64).div_ceil(64);
        // Checked so an absurdly large request fails loudly instead of
        // wrapping to a tiny (or zero-sized) filter in release builds.
        let num_bits = num_words
            .checked_mul(64)
            .expect("BloomFilter: requested bit count overflows u64");
        // Round-trip through usize to detect truncation on narrow targets;
        // a truncated word count would make later bit indexing go out of bounds.
        let word_count = num_words as usize;
        assert!(
            word_count as u64 == num_words,
            "BloomFilter: {num_words} words do not fit in usize on this platform"
        );
        let k_raw = (num_bits as f64 / n) * ln_2;
        // Clamp in u64 *before* narrowing so the cast can never truncate.
        let num_hashes = f64_ceil_to_u64(k_raw).clamp(1, u64::from(NUM_HASHES_MAX)) as u32;
        Self {
            bits: vec![0u64; word_count],
            num_bits,
            num_hashes,
        }
    }

    /// Builds a filter from already-decoded parts, validating that
    /// `num_bits` is a positive multiple of 64, that `num_hashes` is in
    /// `1..=NUM_HASHES_MAX`, and that `bits` holds exactly `num_bits / 64` words.
    ///
    /// # Errors
    ///
    /// Returns [`BloomError::BadShape`] or [`BloomError::BadNumHashes`] when a
    /// check fails.
    fn from_params(num_bits: u64, num_hashes: u32, bits: Vec<u64>) -> Result<Self, BloomError> {
        if num_bits == 0 || !num_bits.is_multiple_of(64) {
            return Err(BloomError::BadShape(format!(
                "num_bits {num_bits} must be a positive multiple of 64"
            )));
        }
        if !(1..=NUM_HASHES_MAX).contains(&num_hashes) {
            return Err(BloomError::BadNumHashes { got: num_hashes });
        }
        let expected_words = num_bits / 64;
        if bits.len() as u64 != expected_words {
            return Err(BloomError::BadShape(format!(
                "bits.len() = {} doesn't match num_bits/64 = {expected_words}",
                bits.len()
            )));
        }
        Ok(Self {
            bits,
            num_bits,
            num_hashes,
        })
    }

    /// Returns the `(word index, single-bit mask)` of the `i`-th probe for a
    /// key whose hash pair is `(h1, h2)`.
    fn bit_slot(&self, h1: u64, h2: u64, i: u32) -> (usize, u64) {
        let bit_idx = mix(h1, h2, i, self.num_bits);
        ((bit_idx / 64) as usize, 1u64 << (bit_idx % 64))
    }

    /// Adds `key` to the filter by setting each of its `num_hashes` probe bits.
    pub fn insert(&mut self, key: &[u8]) {
        let (h1, h2) = derive_hash_pair(key);
        for i in 0..self.num_hashes {
            let (word_idx, mask) = self.bit_slot(h1, h2, i);
            self.bits[word_idx] |= mask;
        }
    }

    /// Returns `false` if `key` was definitely never inserted, and `true` if
    /// every probe bit for `key` is set (it may or may not have been inserted).
    #[must_use]
    pub fn contains(&self, key: &[u8]) -> bool {
        let (h1, h2) = derive_hash_pair(key);
        // `all` stops at the first unset bit, like an early return.
        (0..self.num_hashes).all(|i| {
            let (word_idx, mask) = self.bit_slot(h1, h2, i);
            self.bits[word_idx] & mask != 0
        })
    }

    /// Total number of bits in the filter (always a multiple of 64).
    #[must_use]
    pub const fn num_bits(&self) -> u64 {
        self.num_bits
    }

    /// Number of bit positions probed per key.
    #[must_use]
    pub const fn num_hashes(&self) -> u32 {
        self.num_hashes
    }

    /// Exact length in bytes of the buffer produced by [`Self::to_bytes`].
    #[must_use]
    pub fn encoded_len(&self) -> usize {
        Self::HEADER_LEN + self.bits.len() * 8
    }

    /// Serializes the filter.
    ///
    /// Layout (all integers little-endian): magic `u32`, `num_bits` `u64`,
    /// `num_hashes` `u32`, CRC `u32`, then the bit array as `u64` words.
    /// The CRC covers `num_bits`, `num_hashes` and the bit array; the magic
    /// and the CRC field itself are excluded.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        // Assemble the CRC-covered bytes contiguously first, then splice the
        // checksum between the parameters and the bit array.
        let mut hashed = Vec::with_capacity(Self::PARAMS_LEN + self.bits.len() * 8);
        hashed.extend_from_slice(&self.num_bits.to_le_bytes());
        hashed.extend_from_slice(&self.num_hashes.to_le_bytes());
        for word in &self.bits {
            hashed.extend_from_slice(&word.to_le_bytes());
        }
        let body_crc = crc32(&hashed);
        let (params, body) = hashed.split_at(Self::PARAMS_LEN);

        let mut out = Vec::with_capacity(self.encoded_len());
        out.extend_from_slice(&BLOOM_MAGIC.to_le_bytes());
        out.extend_from_slice(params);
        out.extend_from_slice(&body_crc.to_le_bytes());
        out.extend_from_slice(body);
        out
    }

    /// Parses a buffer produced by [`Self::to_bytes`].
    ///
    /// # Errors
    ///
    /// Checks are performed in this order:
    /// - [`BloomError::TooShort`] if `input` is shorter than the header;
    /// - [`BloomError::BadMagic`] if the magic number does not match;
    /// - [`BloomError::BadShape`] if `num_bits` is zero or not a multiple of
    ///   64, or if `input` is not exactly header plus `num_bits / 8` bytes;
    /// - [`BloomError::BadCrc`] if the stored checksum does not match;
    /// - [`BloomError::BadNumHashes`] if `num_hashes` is out of range.
    pub fn from_bytes(input: &[u8]) -> Result<Self, BloomError> {
        if input.len() < Self::HEADER_LEN {
            return Err(BloomError::TooShort {
                got: input.len(),
                need: Self::HEADER_LEN,
            });
        }
        let (header, body) = input.split_at(Self::HEADER_LEN);

        let magic = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
        if magic != BLOOM_MAGIC {
            return Err(BloomError::BadMagic { got: magic });
        }
        let num_bits = u64::from_le_bytes([
            header[4], header[5], header[6], header[7], header[8], header[9], header[10],
            header[11],
        ]);
        let num_hashes = u32::from_le_bytes([header[12], header[13], header[14], header[15]]);
        let crc_stored = u32::from_le_bytes([header[16], header[17], header[18], header[19]]);

        if num_bits == 0 || !num_bits.is_multiple_of(64) {
            return Err(BloomError::BadShape(format!(
                "num_bits {num_bits} must be a positive multiple of 64"
            )));
        }
        // Compare sizes in u64: `num_bits` is untrusted, and narrowing it to
        // usize (or multiplying a narrowed value) could truncate or overflow
        // on 32-bit targets and let a bogus header pass the length check.
        let expected_body_bytes = num_bits / 8;
        if body.len() as u64 != expected_body_bytes {
            return Err(BloomError::BadShape(format!(
                "input is {} bytes, expected {}",
                input.len(),
                expected_body_bytes + Self::HEADER_LEN as u64
            )));
        }

        let params = &header[Self::MAGIC_LEN..Self::MAGIC_LEN + Self::PARAMS_LEN];
        let mut hashed = Vec::with_capacity(params.len() + body.len());
        hashed.extend_from_slice(params);
        hashed.extend_from_slice(body);
        let crc_computed = crc32(&hashed);
        if crc_computed != crc_stored {
            return Err(BloomError::BadCrc {
                expected: crc_stored,
                got: crc_computed,
            });
        }

        // `body.len()` is an exact multiple of 8 here, so no bytes are left over.
        let bits: Vec<u64> = body
            .chunks_exact(8)
            .map(|chunk| {
                let mut word = [0u8; 8];
                word.copy_from_slice(chunk);
                u64::from_le_bytes(word)
            })
            .collect();
        Self::from_params(num_bits, num_hashes, bits)
    }
}
/// Hashes `bytes` with 64-bit FNV-1a: starting from `FNV_OFFSET_BASIS`, each
/// byte is XORed into the state, which is then multiplied (wrapping) by
/// `FNV_PRIME`. An empty slice hashes to `FNV_OFFSET_BASIS`.
fn fnv1a_64(bytes: &[u8]) -> u64 {
    bytes.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Scrambles `x` with one SplitMix64 step: add the fixed increment, apply two
/// xor-shift-multiply rounds, and finish with a final xor-shift. All
/// arithmetic wraps.
const fn splitmix64(x: u64) -> u64 {
    const INCREMENT: u64 = 0x9e37_79b9_7f4a_7c15;
    const ROUND_1_MUL: u64 = 0xbf58_476d_1ce4_e5b9;
    const ROUND_2_MUL: u64 = 0x94d0_49bb_1331_11eb;

    let seeded = x.wrapping_add(INCREMENT);
    let round_1 = (seeded ^ (seeded >> 30)).wrapping_mul(ROUND_1_MUL);
    let round_2 = (round_1 ^ (round_1 >> 27)).wrapping_mul(ROUND_2_MUL);
    round_2 ^ (round_2 >> 31)
}
/// Derives the two base hashes used for probing: `h1` is the FNV-1a hash of
/// `key`, and `h2` is `h1` passed through `splitmix64`. A zero `h2` is
/// replaced by a fixed non-zero constant so the probe stride is never zero.
fn derive_hash_pair(key: &[u8]) -> (u64, u64) {
    let primary = fnv1a_64(key);
    let stride = match splitmix64(primary) {
        0 => 0xdead_beef_dead_beef,
        nonzero => nonzero,
    };
    (primary, stride)
}
/// Computes the `i`-th probe position as `(h1 + i * h2) mod num_bits`, with
/// the multiply and add wrapping at 64 bits.
///
/// Panics if `num_bits` is zero (remainder by zero).
#[inline]
fn mix(h1: u64, h2: u64, i: u32, num_bits: u64) -> u64 {
    let offset = u64::from(i).wrapping_mul(h2);
    h1.wrapping_add(offset) % num_bits
}
/// Rounds a non-negative `x` up to the nearest integer and returns it as
/// `u64`, without relying on a floating-point `ceil` routine.
///
/// Values at or above `2^64` (including `+inf`) saturate to `u64::MAX`.
/// Negative or NaN input trips a `debug_assert!`; in release builds the
/// float-to-int cast maps such input to 0.
fn f64_ceil_to_u64(x: f64) -> u64 {
    debug_assert!(x >= 0.0, "f64_ceil_to_u64: x must be >= 0");
    // `as` truncates toward zero and saturates at the bounds of u64.
    let truncated = x as u64;
    if (truncated as f64) < x {
        // A fractional part was dropped, or `x` exceeded the u64 range.
        // Saturate: when `truncated` is already u64::MAX a plain `+ 1`
        // would panic in debug builds and wrap to 0 in release builds.
        truncated.saturating_add(1)
    } else {
        truncated
    }
}
/// Approximates the natural logarithm of a positive `x` using only bit
/// manipulation and basic arithmetic (no external math library call).
///
/// `x` is split into `mantissa * 2^exponent` with `mantissa` in `[1, 2)`, and
/// the result is `exponent * LN_2 + ln(mantissa)`. `ln(mantissa)` comes from
/// the series `2 * (t + t^3/3 + t^5/5 + t^7/7)` with
/// `t = (mantissa - 1) / (mantissa + 1)`.
///
/// The series is cut after the `t^7` term, so the error grows as `mantissa`
/// approaches 2 (where `t` approaches 1/3); the first dropped term,
/// `2 * t^9 / 9`, is about `1e-5` there.
///
/// NOTE(review): subnormals, infinities and NaN are not special-cased — the
/// exponent and fraction bits are used as-is — so results for such inputs are
/// not meaningful. Non-positive input trips the `debug_assert!`.
fn libm_ln(x: f64) -> f64 {
debug_assert!(x > 0.0, "libm_ln: x must be > 0");
let bits = x.to_bits();
// Biased 11-bit exponent field; subtract the IEEE-754 bias of 1023.
let exponent_raw = ((bits >> 52) & 0x7ff) as i64;
let exponent = exponent_raw - 1023;
// Keep the 52 fraction bits and force the exponent field to 0x3ff so the
// reassembled value lies in [1, 2).
let mantissa_bits = (bits & 0x000f_ffff_ffff_ffff) | 0x3ff0_0000_0000_0000;
let mantissa = f64::from_bits(mantissa_bits);
use core::f64::consts::LN_2;
let y = mantissa - 1.0;
let t = y / (mantissa + 1.0);
let t2 = t * t;
// Odd-power series in t, truncated after the t^7 term.
let ln_mantissa = 2.0 * (t + t2 * t / 3.0 + t2 * t2 * t / 5.0 + t2 * t2 * t2 * t / 7.0);
(exponent as f64) * LN_2 + ln_mantissa
}
// Unit tests for sizing, membership, serialization round-trips and decode
// error paths.
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec::Vec;
// Produces `count` deterministic pseudo-random values from `seed` by
// repeatedly feeding the state (plus one) through `splitmix64`.
fn rng_stream(seed: u64, count: usize) -> Vec<u64> {
let mut s = seed;
let mut out = Vec::with_capacity(count);
for _ in 0..count {
s = splitmix64(s.wrapping_add(1));
out.push(s);
}
out
}
// `libm_ln` must agree with known logarithms to within 1e-5.
#[test]
fn libm_ln_matches_known_values() {
use core::f64::consts::{LN_2, LN_10};
let cases = [
(1.0_f64, 0.0_f64),
(2.0, LN_2),
(10.0, LN_10),
(0.5, -LN_2),
(0.01, -2.0 * LN_10),
];
for &(x, expected) in &cases {
let got = libm_ln(x);
let err = (got - expected).abs();
assert!(
err < 1e-5,
"ln({x}) expected {expected}, got {got}, err {err}"
);
}
}
// For n = 100_000 and p = 0.01 the filter must have at least 958_506 bits,
// at most one word of rounding slack beyond that, and 7 hash probes.
#[test]
fn with_target_fp_rate_sizes_match_spec() {
let bf = BloomFilter::with_target_fp_rate(100_000, 0.01);
assert_eq!(bf.num_bits() % 64, 0);
assert!(bf.num_bits() >= 958_506);
assert!(bf.num_bits() <= 958_506 + 64);
assert_eq!(bf.num_hashes(), 7);
}
// No false negatives: every inserted key must be reported as present.
#[test]
fn insert_then_contains_returns_true_for_inserted_keys() {
let mut bf = BloomFilter::with_target_fp_rate(10_000, 0.01);
let keys = rng_stream(0xc0ffee, 10_000);
for k in &keys {
bf.insert(&k.to_le_bytes());
}
for k in &keys {
assert!(
bf.contains(&k.to_le_bytes()),
"expected contains(inserted key {k}) == true"
);
}
}
// Empirical false-positive rate on keys that were not inserted must stay
// within 1.2x of the configured target.
#[test]
fn fuzz_oracle_fp_rate_under_target_x_1_2() {
const TARGET_FP: f64 = 0.01;
const N: usize = 100_000;
let mut bf = BloomFilter::with_target_fp_rate(N, TARGET_FP);
let inserted = rng_stream(0xfeed_beef, N);
for k in &inserted {
bf.insert(&k.to_le_bytes());
}
let probes = rng_stream(0xbeef_feed, N);
// Exact membership oracle so probes that happen to collide with inserted
// keys are skipped rather than counted as false positives.
let inserted_set: alloc::collections::BTreeSet<u64> = inserted.iter().copied().collect();
let mut fp = 0u64;
let mut tested = 0u64;
for k in &probes {
if inserted_set.contains(k) {
continue;
}
tested += 1;
if bf.contains(&k.to_le_bytes()) {
fp += 1;
}
}
let observed = fp as f64 / tested as f64;
let ceiling = TARGET_FP * 1.2;
assert!(
observed <= ceiling,
"observed FP {observed:.4} exceeded ceiling {ceiling:.4} (target {TARGET_FP})"
);
}
// Encoding then decoding must preserve dimensions, bit array and membership.
#[test]
fn to_bytes_then_from_bytes_roundtrip() {
let mut bf = BloomFilter::with_target_fp_rate(1_000, 0.005);
let keys = rng_stream(42, 500);
for k in &keys {
bf.insert(&k.to_le_bytes());
}
let bytes = bf.to_bytes();
assert_eq!(bytes.len(), bf.encoded_len());
let parsed = BloomFilter::from_bytes(&bytes).expect("roundtrip parses");
assert_eq!(parsed.num_bits(), bf.num_bits());
assert_eq!(parsed.num_hashes(), bf.num_hashes());
for k in &keys {
assert!(parsed.contains(&k.to_le_bytes()));
}
assert_eq!(parsed.bits, bf.bits);
}
// Ten bytes is shorter than the 20-byte header, so decoding reports TooShort.
#[test]
fn from_bytes_rejects_truncated_input() {
let bf = BloomFilter::with_target_fp_rate(100, 0.01);
let bytes = bf.to_bytes();
let truncated = &bytes[..10];
match BloomFilter::from_bytes(truncated) {
Err(BloomError::TooShort { .. }) => {}
other => panic!("expected TooShort, got {other:?}"),
}
}
// Corrupting the first byte breaks the magic number.
#[test]
fn from_bytes_rejects_bad_magic() {
let bf = BloomFilter::with_target_fp_rate(100, 0.01);
let mut bytes = bf.to_bytes();
bytes[0] ^= 0xff;
match BloomFilter::from_bytes(&bytes) {
Err(BloomError::BadMagic { .. }) => {}
other => panic!("expected BadMagic, got {other:?}"),
}
}
// Byte 25 lies in the bit array (past the 20-byte header); flipping one bit
// there must be caught by the CRC.
#[test]
fn from_bytes_rejects_bad_crc() {
let bf = BloomFilter::with_target_fp_rate(100, 0.01);
let mut bytes = bf.to_bytes();
bytes[25] ^= 0x01;
match BloomFilter::from_bytes(&bytes) {
Err(BloomError::BadCrc { .. }) => {}
other => panic!("expected BadCrc, got {other:?}"),
}
}
// Hand-builds an otherwise valid buffer (correct magic, length and CRC) whose
// num_hashes is 0, so decoding gets as far as the num_hashes check.
#[test]
fn from_bytes_rejects_zero_num_hashes() {
let num_bits: u64 = 128;
let num_hashes: u32 = 0;
let mut buf = Vec::new();
buf.extend_from_slice(&BLOOM_MAGIC.to_le_bytes());
buf.extend_from_slice(&num_bits.to_le_bytes());
buf.extend_from_slice(&num_hashes.to_le_bytes());
let crc_off = buf.len();
// Placeholder CRC, patched below once the body is in place.
buf.extend_from_slice(&0u32.to_le_bytes());
for _ in 0..2 {
buf.extend_from_slice(&0u64.to_le_bytes());
}
let body_crc = {
let mut to_hash = Vec::new();
to_hash.extend_from_slice(&buf[4..16]);
to_hash.extend_from_slice(&buf[20..]);
crc32(&to_hash)
};
buf[crc_off..crc_off + 4].copy_from_slice(&body_crc.to_le_bytes());
match BloomFilter::from_bytes(&buf) {
Err(BloomError::BadNumHashes { got: 0 }) => {}
other => panic!("expected BadNumHashes, got {other:?}"),
}
}
// Across a range of sizes and rates the bit count is a multiple of 64 and
// never below one word.
#[test]
fn num_bits_is_always_64_aligned() {
for &(n, p) in &[
(1_usize, 0.5_f64),
(10, 0.1),
(1_000, 0.01),
(1_000_000, 0.001),
] {
let bf = BloomFilter::with_target_fp_rate(n, p);
assert_eq!(bf.num_bits() % 64, 0, "n={n} p={p}");
assert!(bf.num_bits() >= 64);
}
}
}