use std::hash::{Hash, Hasher};
/// Item capacity used when `MmapBloom::new` is given 0 and by `with_default_capacity`.
const DEFAULT_CAPACITY: usize = 1_000 * 1_000;
/// False-positive probability every filter is sized for.
const TARGET_FP: f64 = 1e-2;
/// Number of bit positions probed (set or tested) per item.
const NUM_HASHES: u32 = 7;
/// Returns the number of bits a Bloom filter needs to hold `n` items at a
/// false-positive rate of roughly `fp`, rounded up to a power of two.
///
/// Uses the standard sizing formula `m = -n * ln(fp) / (ln 2)^2`, clamps the
/// result to at least 64 bits, and rounds up to the next power of two so that
/// callers can reduce hashes with a bit mask instead of a modulo. Because of
/// the rounding, the effective false-positive rate is usually below `fp`.
///
/// # Panics
///
/// Panics if the rounded bit count does not fit in a `usize` (an enormous
/// `n`, or `fp == 0.0` whose logarithm is infinite). `next_power_of_two`
/// would instead wrap to 0 in release builds, producing a zero-sized filter.
fn optimal_bits(n: usize, fp: f64) -> usize {
    let m = -(n as f64) * fp.ln() / (core::f64::consts::LN_2.powi(2));
    // Float-to-int `as` saturates: NaN/negative -> 0, +inf/too large -> usize::MAX.
    let m = (m.ceil() as usize).max(64);
    m.checked_next_power_of_two()
        .expect("bloom filter bit count overflows usize")
}
/// Records how a filter's bit array was obtained, so it can be released with
/// the matching deallocator.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AllocKind {
    /// Buffer owned by the global allocator (a leaked `Vec<u8>`).
    Heap,
    /// Anonymous `mmap`; `mapped_bytes` is the length that was mapped, which
    /// can exceed the filter size when rounded up to huge pages.
    Mmap { mapped_bytes: usize },
}
/// A Bloom filter whose bit array lives in an anonymous memory mapping when
/// one can be created, falling back to a heap buffer otherwise.
///
/// The struct exclusively owns the buffer behind `ptr` and releases it on drop.
pub struct MmapBloom {
    /// How `ptr` was allocated; selects the deallocator on drop.
    alloc_kind: AllocKind,
    /// Start of the zero-initialised bit array, `len_bytes` bytes long.
    ptr: *mut u8,
    /// Size of the bit array in bytes (number of bits / 8).
    len_bytes: usize,
    /// Number of bits minus one; the bit count is a power of two, so
    /// `hash & mask` selects a bit inside the array.
    mask: u64,
    /// Number of `insert` calls since construction or the last `clear`.
    count: usize,
}

// SAFETY: `MmapBloom` uniquely owns the buffer behind `ptr`; the pointer is
// private and never handed out, so moving the value to another thread moves
// sole ownership of the buffer with it.
unsafe impl Send for MmapBloom {}
// SAFETY: every method that writes the buffer takes `&mut self`, while `&self`
// methods only read it, so shared references on several threads cannot race.
unsafe impl Sync for MmapBloom {}
impl MmapBloom {
pub fn new(capacity: usize) -> Self {
let cap = if capacity == 0 {
DEFAULT_CAPACITY
} else {
capacity
};
let bits = optimal_bits(cap, TARGET_FP);
let len_bytes = bits / 8;
let (ptr, alloc_kind) = Self::alloc(len_bytes);
Self {
ptr,
len_bytes,
mask: (bits as u64) - 1,
count: 0,
alloc_kind,
}
}
pub fn with_default_capacity() -> Self {
Self::new(DEFAULT_CAPACITY)
}
fn alloc(len: usize) -> (*mut u8, AllocKind) {
#[cfg(unix)]
{
Self::alloc_unix(len)
}
#[cfg(not(unix))]
{
Self::alloc_heap(len)
}
}
#[cfg(unix)]
fn alloc_unix(len: usize) -> (*mut u8, AllocKind) {
use libc::{mmap, MAP_ANONYMOUS, MAP_FAILED, MAP_PRIVATE, PROT_READ, PROT_WRITE};
use std::ptr;
#[cfg(target_os = "linux")]
{
const MAP_HUGETLB: libc::c_int = 0x40000;
const HUGE_PAGE_SIZE: usize = 2 << 20; let aligned = (len + HUGE_PAGE_SIZE - 1) & !(HUGE_PAGE_SIZE - 1);
let p = unsafe {
mmap(
ptr::null_mut(),
aligned,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
-1,
0,
)
};
if p != MAP_FAILED {
return (
p as *mut u8,
AllocKind::Mmap {
mapped_bytes: aligned,
},
);
}
}
let p = unsafe {
mmap(
ptr::null_mut(),
len,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS,
-1,
0,
)
};
if p != MAP_FAILED {
return (p as *mut u8, AllocKind::Mmap { mapped_bytes: len });
}
Self::alloc_heap(len)
}
fn alloc_heap(len: usize) -> (*mut u8, AllocKind) {
let mut v: Vec<u8> = vec![0u8; len];
let ptr = v.as_mut_ptr();
std::mem::forget(v);
(ptr, AllocKind::Heap)
}
#[inline(always)]
fn hash_seeds<T: Hash + ?Sized>(item: &T) -> (u64, u64) {
let mut state = ahash::AHasher::default();
item.hash(&mut state);
let h1 = state.finish();
let mut x = h1;
x ^= x >> 33;
x = x.wrapping_mul(0xff51afd7ed558ccd);
x ^= x >> 33;
x = x.wrapping_mul(0xc4ceb9fe1a85ec53);
x ^= x >> 33;
(h1, x | 1)
}
#[inline]
pub fn insert<T: Hash + ?Sized>(&mut self, item: &T) {
let (h1, h2) = Self::hash_seeds(item);
let mask = self.mask;
let mut composite = h1;
for i in 0..NUM_HASHES as u64 {
let pos = composite & mask;
let byte_idx = (pos >> 3) as usize;
let bit_idx = (pos & 7) as u8;
unsafe {
let byte = &mut *self.ptr.add(byte_idx);
*byte |= 1 << bit_idx;
}
composite = composite.wrapping_add(h2).wrapping_add(i);
}
self.count += 1;
}
#[inline]
pub fn contains<T: Hash + ?Sized>(&self, item: &T) -> bool {
let (h1, h2) = Self::hash_seeds(item);
let mask = self.mask;
let mut composite = h1;
for i in 0..NUM_HASHES as u64 {
let pos = composite & mask;
let byte_idx = (pos >> 3) as usize;
let bit_idx = (pos & 7) as u8;
let set = unsafe {
let byte = *self.ptr.add(byte_idx);
byte & (1 << bit_idx) != 0
};
if !set {
return false;
}
composite = composite.wrapping_add(h2).wrapping_add(i);
}
true
}
#[inline]
pub fn len(&self) -> usize {
self.count
}
#[inline]
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn clear(&mut self) {
unsafe {
std::ptr::write_bytes(self.ptr, 0, self.len_bytes);
}
self.count = 0;
}
#[inline]
pub fn size_bytes(&self) -> usize {
self.len_bytes
}
}
impl Drop for MmapBloom {
    /// Releases the bit array with the deallocator matching how it was obtained.
    fn drop(&mut self) {
        if self.len_bytes == 0 || self.ptr.is_null() {
            return;
        }
        match self.alloc_kind {
            #[cfg(unix)]
            AllocKind::Mmap { mapped_bytes } => {
                // SAFETY: `ptr` and `mapped_bytes` are the address returned by
                // and the length passed to `mmap`, and the mapping is unmapped
                // exactly once, here.
                let rc = unsafe { libc::munmap(self.ptr.cast::<libc::c_void>(), mapped_bytes) };
                // Errors cannot be propagated out of `drop`; make a failure
                // (which would leak the mapping) visible in debug builds rather
                // than ignoring it.
                debug_assert_eq!(rc, 0, "munmap failed for bloom filter mapping");
            }
            #[cfg(not(unix))]
            AllocKind::Mmap { .. } => {
                // Non-unix builds only ever allocate on the heap. Freeing an
                // mmap'd region through the global allocator would be UB, so
                // treat this state as a bug rather than guessing.
                unreachable!("mmap allocation kind on a non-unix target");
            }
            AllocKind::Heap => {
                // SAFETY: the buffer came from a leaked `vec![0u8; len_bytes]`,
                // whose capacity is exactly `len_bytes`, so this rebuilds the
                // original Vec, which is dropped (freed) immediately.
                drop(unsafe { Vec::from_raw_parts(self.ptr, self.len_bytes, self.len_bytes) });
            }
        }
    }
}
impl std::fmt::Debug for MmapBloom {
    /// Formats the filter's geometry and allocation details; the bit array
    /// contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_bits = self.mask + 1;
        let mut out = f.debug_struct("MmapBloom");
        out.field("num_bits", &num_bits);
        out.field("count", &self.count);
        out.field("size_bytes", &self.len_bytes);
        out.field("alloc_kind", &self.alloc_kind);
        out.finish()
    }
}
impl Clone for MmapBloom {
    /// Deep-copies the bit array into a freshly allocated buffer, so the clone
    /// and the original can be modified independently.
    fn clone(&self) -> Self {
        let len_bytes = self.len_bytes;
        let (dst, alloc_kind) = Self::alloc(len_bytes);
        // SAFETY: `self.ptr` is valid for `len_bytes` reads, `dst` was just
        // allocated with at least `len_bytes` bytes, and the two buffers are
        // separate allocations, so they cannot overlap.
        unsafe { std::ptr::copy_nonoverlapping(self.ptr, dst, len_bytes) };
        Self {
            alloc_kind,
            count: self.count,
            mask: self.mask,
            len_bytes,
            ptr: dst,
        }
    }
}
impl PartialEq for MmapBloom {
    /// Two filters are equal when they have the same buffer size, the same
    /// insert count, and bit-for-bit identical contents.
    fn eq(&self, other: &Self) -> bool {
        let same_shape = self.len_bytes == other.len_bytes && self.count == other.count;
        if !same_shape {
            return false;
        }
        // SAFETY: each pointer is valid for its own `len_bytes` bytes, and the
        // shared borrows of `self`/`other` rule out concurrent writes.
        let (lhs, rhs) = unsafe {
            (
                std::slice::from_raw_parts(self.ptr, self.len_bytes),
                std::slice::from_raw_parts(other.ptr, other.len_bytes),
            )
        };
        lhs == rhs
    }
}

// Comparing sizes, counters and raw bytes is reflexive, so equality is total.
impl Eq for MmapBloom {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_contains() {
        let present = "https://example.com";
        let mut filter = MmapBloom::new(1000);
        filter.insert(&present);
        assert!(filter.contains(&present));
        assert!(!filter.contains(&"https://other.com"));
    }

    #[test]
    fn test_empty() {
        let filter = MmapBloom::new(1000);
        assert_eq!(filter.len(), 0);
        assert!(filter.is_empty());
        assert!(!filter.contains(&"anything"));
    }

    #[test]
    fn test_clear() {
        let url = "https://example.com";
        let mut filter = MmapBloom::new(1000);
        filter.insert(&url);
        assert!(filter.contains(&url));
        filter.clear();
        assert!(filter.is_empty());
        assert!(!filter.contains(&url));
    }

    #[test]
    fn test_false_positive_rate() {
        let n = 10_000;
        let test_count = 10_000;
        let mut filter = MmapBloom::new(n);
        (0..n).for_each(|i| filter.insert(&format!("url-{}", i)));
        assert_eq!(filter.len(), n);

        // Every probe key is outside the inserted range, so any hit is a false positive.
        let false_positives = (n..n + test_count)
            .filter(|i| filter.contains(&format!("url-{}", i)))
            .count();
        let fp_rate = false_positives as f64 / test_count as f64;
        assert!(
            fp_rate < 0.05,
            "False positive rate too high: {:.2}%",
            fp_rate * 100.0
        );
    }

    #[test]
    fn test_no_false_negatives() {
        let urls: Vec<String> = (0..5000)
            .map(|i| format!("https://site.com/{}", i))
            .collect();
        let mut filter = MmapBloom::new(5000);
        urls.iter().for_each(|url| filter.insert(url));
        for url in urls.iter() {
            assert!(filter.contains(url), "False negative for {}", url);
        }
    }

    #[test]
    fn test_clone() {
        let mut original = MmapBloom::new(100);
        for url in ["https://a.com", "https://b.com"] {
            original.insert(&url);
        }
        let copy = original.clone();
        assert!(copy.contains(&"https://a.com"));
        assert!(copy.contains(&"https://b.com"));
        assert_eq!(copy.len(), 2);
    }

    #[test]
    fn test_size_reasonable() {
        let size = MmapBloom::new(1_000_000).size_bytes();
        assert!(size > 1_000_000);
        // 2 MiB: 16_777_216 bits / 8.
        assert!(size <= 2_097_152);
    }

    #[test]
    fn test_default_capacity() {
        let filter = MmapBloom::with_default_capacity();
        assert!(filter.size_bytes() > 0);
        assert!(filter.is_empty());
    }

    #[test]
    fn test_optimal_bits() {
        let bits = optimal_bits(1_000_000, 0.01);
        assert!(bits.is_power_of_two());
        assert_eq!(bits, 1 << 24);
    }

    #[test]
    fn test_drop_safety() {
        // Capacity 0 exercises the DEFAULT_CAPACITY path.
        for capacity in [0, 1, 100, 10_000, 1_000_000] {
            drop(MmapBloom::new(capacity));
        }
    }

    #[test]
    fn test_clone_independence() {
        let mut first = MmapBloom::new(100);
        first.insert(&"url-a");
        let mut second = first.clone();
        second.insert(&"url-b");
        assert!(!first.contains(&"url-b"));
        assert!(second.contains(&"url-b"));
    }
}