use core::{hash::BuildHasher, marker::PhantomData};
use alloc::{vec, vec::Vec};
use crate::{
hash::{mix64, DefaultHashBuilder},
Error,
};
/// Number of fingerprint slots per bucket. `insert` picks an eviction slot
/// with `& (BUCKET_SIZE - 1)`, which only reaches every slot while this
/// stays a power of two.
const BUCKET_SIZE: usize = 4;
/// Maximum number of eviction rounds a single `insert` performs before it
/// gives up and parks the displaced fingerprint in the victim slot.
const MAX_KICKS: usize = 500;
/// Slot value that marks an unoccupied slot; `fingerprint_of` remaps a
/// computed fingerprint equal to this value so it is never stored.
const EMPTY: u16 = 0;
/// A fingerprint that was displaced during insertion and could not be placed
/// in either of its buckets.
///
/// The filter keeps at most one of these outside the table so the displaced
/// entry is still found by lookups and can be put back after a removal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Victim {
    /// The displaced fingerprint value.
    fingerprint: u16,
    /// The bucket index the fingerprint last failed to enter; its other
    /// candidate bucket is derived from this index and the fingerprint.
    index: usize,
}
/// An approximate-membership set that stores 16-bit fingerprints in
/// fixed-size buckets and supports insertion, lookup and removal.
///
/// `T` is the item type being hashed (it may be unsized, e.g. `str`) and `S`
/// is the [`BuildHasher`] used to hash items.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CuckooFilter<T: ?Sized, S = DefaultHashBuilder> {
    /// Flat slot table: bucket `i` occupies slots
    /// `i * BUCKET_SIZE .. (i + 1) * BUCKET_SIZE`.
    buckets: Vec<u16>,
    /// Bucket count minus one; the bucket count is a power of two, so this
    /// masks a hash down to a bucket index.
    mask: usize,
    /// Number of fingerprints currently held, including a parked victim.
    len: usize,
    /// Fingerprint left over from an insertion that ran out of eviction
    /// rounds, if any.
    victim: Option<Victim>,
    /// State of the internal pseudo-random generator used to choose which
    /// slot to evict.
    rng_state: u64,
    /// Hasher factory; not serialized when the `serde` feature is enabled.
    #[cfg_attr(feature = "serde", serde(skip))]
    hasher: S,
    /// Zero-sized marker naming the item type; the filter never stores a `T`.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<fn(&T)>,
}

// `#[derive(Clone)]` would add a `T: Clone` bound, which also forces
// `T: Sized`. No `T` is ever stored, so that bound would needlessly make
// filters over `str` (or any non-`Clone` item type) impossible to clone.
impl<T: ?Sized, S: Clone> Clone for CuckooFilter<T, S> {
    fn clone(&self) -> Self {
        Self {
            buckets: self.buckets.clone(),
            mask: self.mask,
            len: self.len,
            victim: self.victim,
            rng_state: self.rng_state,
            hasher: self.hasher.clone(),
            _marker: PhantomData,
        }
    }
}
impl<T: ?Sized> CuckooFilter<T, DefaultHashBuilder> {
    /// Creates an empty filter sized for `capacity` items, hashing with
    /// [`DefaultHashBuilder`].
    ///
    /// # Errors
    ///
    /// Forwards any error from [`CuckooFilter::with_hasher`]; in particular
    /// [`Error::InvalidParameter`] when `capacity` is zero.
    pub fn new(capacity: usize) -> Result<Self, Error> {
        let hasher = DefaultHashBuilder;
        Self::with_hasher(capacity, hasher)
    }
}
impl<T: ?Sized, S: BuildHasher> CuckooFilter<T, S> {
/// Creates an empty filter sized for `capacity` items that hashes with
/// `hasher`.
///
/// The bucket count is `capacity_to_buckets(capacity)`, so the table has
/// `BUCKET_SIZE` slots for each of those buckets.
///
/// # Errors
///
/// Returns [`Error::InvalidParameter`] when `capacity` is zero, or when it
/// is so large that the slot table could not be addressed or allocated.
pub fn with_hasher(capacity: usize, hasher: S) -> Result<Self, Error> {
    if capacity == 0 {
        return Err(Error::InvalidParameter {
            param: "capacity",
            reason: "must be greater than zero",
        });
    }
    let num_buckets = capacity_to_buckets(capacity);
    // An unchecked `num_buckets * BUCKET_SIZE` wraps to zero in release
    // builds for huge capacities, yielding an empty table whose mask still
    // addresses billions of buckets. A `Vec` also cannot hold more than
    // `isize::MAX` bytes, so reject anything past that bound up front.
    let max_slots = isize::MAX as usize / core::mem::size_of::<u16>();
    let slot_count = match num_buckets.checked_mul(BUCKET_SIZE) {
        Some(slots) if slots <= max_slots => slots,
        _ => {
            return Err(Error::InvalidParameter {
                param: "capacity",
                reason: "too large for the slot table to be allocated",
            });
        }
    };
    Ok(Self {
        buckets: vec![EMPTY; slot_count],
        mask: num_buckets - 1,
        len: 0,
        victim: None,
        // Fixed odd constant mixed with the table size; the result is never
        // zero because `num_buckets` is a power of two.
        rng_state: 0x9E37_79B9_7F4A_7C15 ^ num_buckets as u64,
        hasher,
        _marker: PhantomData,
    })
}
/// Adds `item` to the filter.
///
/// The fingerprint goes into the first free slot of either candidate
/// bucket. If both are full, up to `MAX_KICKS` resident fingerprints are
/// relocated to make room; if that also fails, the fingerprint still in
/// hand is parked in the victim slot and the call succeeds.
///
/// # Errors
///
/// Returns [`Error::CapacityExceeded`] while a victim is parked, i.e. after
/// an earlier insertion exhausted its relocation budget.
pub fn insert(&mut self, item: &T) -> Result<(), Error>
where
    T: core::hash::Hash,
{
    if self.victim.is_some() {
        return Err(Error::CapacityExceeded);
    }

    let hash = self.hasher.hash_one(item);
    let fingerprint = fingerprint_of(hash);
    let primary = (hash as usize) & self.mask;
    let secondary = self.alt_index(primary, fingerprint);

    let stored_directly =
        self.try_put(primary, fingerprint) || self.try_put(secondary, fingerprint);

    if !stored_directly {
        // Start evicting from one of the two candidate buckets at random.
        let mut bucket = if self.next_rng() & 1 == 0 {
            primary
        } else {
            secondary
        };
        let mut in_hand = fingerprint;
        let mut settled = false;

        for _ in 0..MAX_KICKS {
            let slot = (self.next_rng() as usize) & (BUCKET_SIZE - 1);
            let pos = bucket * BUCKET_SIZE + slot;
            // Drop the carried fingerprint here and pick up the evicted one.
            in_hand = core::mem::replace(&mut self.buckets[pos], in_hand);
            bucket = self.alt_index(bucket, in_hand);
            if self.try_put(bucket, in_hand) {
                settled = true;
                break;
            }
        }

        if !settled {
            self.victim = Some(Victim {
                fingerprint: in_hand,
                index: bucket,
            });
        }
    }

    // Every path that reaches this point has stored exactly one fingerprint.
    self.len += 1;
    Ok(())
}
/// Reports whether `item` may be in the filter.
///
/// Returns `true` when the item's fingerprint is found in either of its
/// candidate buckets or matches the parked victim, and `false` otherwise.
/// Because only fingerprints are compared, a `true` answer can also be
/// caused by a different item with the same fingerprint and bucket.
#[must_use]
pub fn contains(&self, item: &T) -> bool
where
    T: core::hash::Hash,
{
    let hash = self.hasher.hash_one(item);
    let fingerprint = fingerprint_of(hash);
    let primary = (hash as usize) & self.mask;
    let secondary = self.alt_index(primary, fingerprint);

    if self.bucket_contains(primary, fingerprint) {
        return true;
    }
    if self.bucket_contains(secondary, fingerprint) {
        return true;
    }
    self.victim_matches(fingerprint, primary, secondary)
}
/// Removes one stored copy of `item`'s fingerprint, if there is one.
///
/// Returns `true` when a fingerprint was removed from a bucket or from the
/// victim slot, and `false` when nothing matched. After a bucket removal
/// the filter tries to move a parked victim back into the table.
pub fn remove(&mut self, item: &T) -> bool
where
    T: core::hash::Hash,
{
    let hash = self.hasher.hash_one(item);
    let fingerprint = fingerprint_of(hash);
    let primary = (hash as usize) & self.mask;
    let secondary = self.alt_index(primary, fingerprint);

    let taken_from_bucket =
        self.try_take(primary, fingerprint) || self.try_take(secondary, fingerprint);

    if taken_from_bucket {
        self.len -= 1;
        // A slot just opened up, so the parked victim may fit now.
        self.reinsert_victim();
        true
    } else if self.victim_matches(fingerprint, primary, secondary) {
        self.victim = None;
        self.len -= 1;
        true
    } else {
        false
    }
}
/// Empties the filter: every slot is reset to `EMPTY`, the victim is
/// dropped and the length returns to zero. The table size is unchanged.
pub fn clear(&mut self) {
    self.buckets.fill(EMPTY);
    self.victim = None;
    self.len = 0;
}
/// Returns the number of fingerprints currently stored, counting a parked
/// victim and counting duplicates once per insertion.
#[inline]
#[must_use]
pub fn len(&self) -> usize {
    self.len
}
/// Returns `true` when no fingerprints are stored.
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
    let stored = self.len;
    stored == 0
}
/// Returns the total number of fingerprint slots in the table.
#[inline]
#[must_use]
pub fn capacity(&self) -> usize {
    let slots: &[u16] = &self.buckets;
    slots.len()
}
/// Returns the stored-fingerprint count divided by the total slot count.
#[must_use]
pub fn load_factor(&self) -> f64 {
    let stored = self.len as f64;
    let slots = self.buckets.len() as f64;
    stored / slots
}
/// Returns the other candidate bucket for `fingerprint` given one of them.
///
/// The offset depends only on the fingerprint and is combined with XOR, so
/// applying this twice yields the original `index`.
#[inline]
fn alt_index(&self, index: usize, fingerprint: u16) -> usize {
    let mixed = mix64(u64::from(fingerprint));
    let offset = (mixed as usize) & self.mask;
    index ^ offset
}
/// Writes `fingerprint` into the first empty slot of bucket `index`.
///
/// Returns `true` on success and `false` when the bucket has no free slot.
#[inline]
fn try_put(&mut self, index: usize, fingerprint: u16) -> bool {
    let start = index * BUCKET_SIZE;
    let bucket = &mut self.buckets[start..start + BUCKET_SIZE];
    match bucket.iter_mut().find(|slot| **slot == EMPTY) {
        Some(free) => {
            *free = fingerprint;
            true
        }
        None => false,
    }
}
/// Clears the first slot of bucket `index` that holds `fingerprint`.
///
/// Returns `true` when a slot was cleared and `false` when the bucket does
/// not contain the fingerprint.
#[inline]
fn try_take(&mut self, index: usize, fingerprint: u16) -> bool {
    let start = index * BUCKET_SIZE;
    let bucket = &mut self.buckets[start..start + BUCKET_SIZE];
    match bucket.iter_mut().find(|slot| **slot == fingerprint) {
        Some(hit) => {
            *hit = EMPTY;
            true
        }
        None => false,
    }
}
/// Reports whether any slot of bucket `index` holds `fingerprint`.
#[inline]
fn bucket_contains(&self, index: usize, fingerprint: u16) -> bool {
    let start = index * BUCKET_SIZE;
    self.buckets[start..start + BUCKET_SIZE]
        .iter()
        .any(|&stored| stored == fingerprint)
}
#[inline]
fn victim_matches(&self, fingerprint: u16, i1: usize, i2: usize) -> bool {
matches!(
self.victim,
Some(Victim { fingerprint: vf, index: vi })
if vf == fingerprint && (vi == i1 || vi == i2)
)
}
fn reinsert_victim(&mut self) {
let Some(victim) = self.victim else {
return;
};
let alt = self.alt_index(victim.index, victim.fingerprint);
if self.try_put(victim.index, victim.fingerprint) || self.try_put(alt, victim.fingerprint) {
self.victim = None;
}
}
/// Advances the internal xorshift generator (shifts 13, 7, 17) and returns
/// the new state, which is also stored back into `rng_state`.
#[inline]
fn next_rng(&mut self) -> u64 {
    let seed = self.rng_state;
    let step1 = seed ^ (seed << 13);
    let step2 = step1 ^ (step1 >> 7);
    let next = step2 ^ (step2 << 17);
    self.rng_state = next;
    next
}
}
/// Derives a slot fingerprint from the top 16 bits of `hash`.
///
/// A result equal to `EMPTY` is replaced by `1`, so a stored fingerprint can
/// never be mistaken for an unoccupied slot.
#[inline]
fn fingerprint_of(hash: u64) -> u16 {
    match (hash >> 48) as u16 {
        EMPTY => 1,
        bits => bits,
    }
}
/// Converts a requested item capacity into a bucket count.
///
/// The count is sized so that `capacity` items fill the table to a 0.95
/// load, then rounded up to a power of two (and never below one) so a
/// bit mask can be used for bucket indexing.
fn capacity_to_buckets(capacity: usize) -> usize {
    let target_load = 0.95;
    let items_per_bucket = BUCKET_SIZE as f64 * target_load;
    let needed = libm::ceil(capacity as f64 / items_per_bucket) as usize;
    core::cmp::max(needed, 1).next_power_of_two()
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// A zero capacity is rejected as an invalid parameter.
    #[test]
    fn test_new_rejects_zero_capacity() {
        let outcome = CuckooFilter::<&str>::new(0);
        assert!(matches!(outcome, Err(Error::InvalidParameter { .. })));
    }

    /// An inserted item is found, and is gone after a single removal.
    #[test]
    fn test_insert_contains_remove() {
        let mut filter = CuckooFilter::new(1_000).unwrap();
        filter.insert("alice").unwrap();
        assert!(filter.contains("alice"));

        let removed_first_time = filter.remove("alice");
        assert!(removed_first_time);
        assert!(!filter.contains("alice"));

        let removed_second_time = filter.remove("alice");
        assert!(!removed_second_time);
    }

    /// Every inserted item must be reported as present.
    #[test]
    fn test_no_false_negatives() {
        let mut filter = CuckooFilter::new(5_000).unwrap();
        let items = 0..2_000u32;
        for i in items.clone() {
            filter.insert(&i).unwrap();
        }
        for i in items {
            assert!(filter.contains(&i), "inserted item {i} reported absent");
        }
    }

    /// Inserting the same item twice stores two copies that need two removes.
    #[test]
    fn test_duplicates_need_matching_removes() {
        let mut filter = CuckooFilter::new(100).unwrap();
        for _ in 0..2 {
            filter.insert("dup").unwrap();
        }
        assert_eq!(filter.len(), 2);

        assert!(filter.remove("dup"));
        assert!(filter.contains("dup"));
        assert!(filter.remove("dup"));
        assert!(!filter.contains("dup"));
    }

    /// `len` tracks insertions and `clear` resets the filter to empty.
    #[test]
    fn test_len_and_clear() {
        let mut filter = CuckooFilter::new(100).unwrap();
        assert!(filter.is_empty());

        (0..10u32).for_each(|i| filter.insert(&i).unwrap());
        assert_eq!(filter.len(), 10);

        filter.clear();
        assert!(filter.is_empty());
        assert_eq!(filter.len(), 0);
    }

    /// The slot count covers the requested capacity in whole buckets.
    #[test]
    fn test_capacity_has_headroom() {
        let filter = CuckooFilter::<u32>::new(1_000).unwrap();
        let slots = filter.capacity();
        assert!(slots >= 1_000);
        assert_eq!(slots % BUCKET_SIZE, 0);
    }

    /// Overfilling a small filter eventually fails, without losing any item
    /// that was accepted beforehand.
    #[test]
    fn test_fill_to_capacity_eventually_errors() {
        let mut filter = CuckooFilter::<u64>::new(64).unwrap();
        let mut accepted = 0u64;
        let mut hit_limit = false;

        for candidate in 0..10_000u64 {
            match filter.insert(&candidate) {
                Ok(()) => accepted += 1,
                Err(Error::CapacityExceeded) => {
                    hit_limit = true;
                    break;
                }
                Err(other) => panic!("unexpected error: {other:?}"),
            }
        }

        assert!(hit_limit, "filter never reported CapacityExceeded");
        for i in 0..accepted {
            assert!(filter.contains(&i), "accepted item {i} reported absent");
        }
    }

    /// Filling half of the slots yields a load factor of about one half.
    #[test]
    fn test_load_factor() {
        let mut filter = CuckooFilter::<u32>::new(100).unwrap();
        let half = (filter.capacity() / 2) as u32;
        for i in 0..half {
            filter.insert(&i).unwrap();
        }

        let lf = filter.load_factor();
        assert!((0.49..=0.51).contains(&lf), "unexpected load factor {lf}");
    }
}