use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::hash::Hash;
use crate::utils::single_hash;
use crate::traits::ProbabilisticSet;
/// Width, in bits, of a single stored fingerprint.
const FINGERPRINT_BITS: u32 = 16;
/// Mask that keeps only the low `FINGERPRINT_BITS` bits of a word.
const FINGERPRINT_MASK: usize = !(usize::MAX << FINGERPRINT_BITS);
/// How many fingerprints fit side by side in one `usize` bucket word.
const SLOTS_PER_BUCKET: u32 = usize::BITS / FINGERPRINT_BITS;
/// Maximum number of evictions a single insertion performs before it gives up.
const MAX_KICKS: usize = 128;
/// Approximate-membership set that stores short fingerprints in packed atomic bucket words.
pub struct CuckooFilter {
    /// One word per bucket. Each word packs `SLOTS_PER_BUCKET` fingerprints of
    /// `FINGERPRINT_BITS` bits; a slot holding zero is empty.
    buckets: Vec<AtomicUsize>,
    /// Number of buckets; bucket indices are reduced modulo this value.
    size: usize,
    /// Counter incremented on every eviction; its value modulo `SLOTS_PER_BUCKET`
    /// selects the slot that gets evicted.
    kick_count: AtomicUsize,
}
impl CuckooFilter {
    /// Creates an empty filter with `size` buckets, each holding `SLOTS_PER_BUCKET`
    /// fingerprint slots.
    ///
    /// A request for zero buckets is rounded up to one: every index computation
    /// reduces modulo the bucket count, so an empty table would divide by zero on
    /// the first insert or lookup.
    pub fn new(size: usize) -> Self {
        let size = size.max(1);
        let buckets = (0..size).map(|_| AtomicUsize::new(0)).collect();
        Self { buckets, size, kick_count: AtomicUsize::new(0) }
    }

    /// Extracts the fingerprint stored in `slot` of the packed bucket `word`.
    fn slot_value(word: usize, slot: u32) -> u16 {
        ((word >> (slot * FINGERPRINT_BITS)) & FINGERPRINT_MASK) as u16
    }

    /// Derives the non-zero fingerprint stored for `item`.
    ///
    /// The fingerprint is taken from a second hashing round rather than from the
    /// item hash itself. The primary bucket index is the item hash modulo the
    /// bucket count, so reusing the low bits of that same hash would make most
    /// fingerprint bits a function of the bucket index and waste their
    /// discriminating power. Zero is reserved for "empty slot" and is remapped to 1.
    ///
    /// NOTE(review): assumes `single_hash` is deterministic and mixes its input well.
    fn fingerprint<T: Hash>(&self, item: &T) -> u16 {
        let rehashed = single_hash(&single_hash(item));
        let fp = (rehashed & FINGERPRINT_MASK) as u16;
        if fp == 0 { 1 } else { fp }
    }

    /// Returns the first candidate bucket for `item`.
    fn primary_bucket_index<T: Hash>(&self, item: &T) -> usize {
        single_hash(item) % self.size
    }

    /// Returns the other candidate bucket for `fingerprint`, given one of its buckets.
    ///
    /// The mapping is `i -> (h(fingerprint) - i) mod size`. Unlike XOR followed by
    /// a modulo, this is its own inverse for every bucket count, not only for
    /// powers of two, so a relocated fingerprint can always be traced back.
    fn secondary_bucket_index(&self, fingerprint: u16, primary_index: usize) -> usize {
        let offset = single_hash(&fingerprint) % self.size;
        // Both operands are below `size`, so adding `size` keeps the subtraction non-negative.
        (offset + self.size - primary_index % self.size) % self.size
    }

    /// Reports whether any slot of the bucket currently holds `fingerprint`.
    fn bucket_contains(&self, bucket_index: usize, fingerprint: u16) -> bool {
        let word = self.buckets[bucket_index].load(Ordering::Relaxed);
        (0..SLOTS_PER_BUCKET).any(|slot| Self::slot_value(word, slot) == fingerprint)
    }

    /// Atomically rewrites the first slot holding `current_fp` so it holds `replacement_fp`.
    ///
    /// Returns `false` when no slot of the bucket holds `current_fp`. The
    /// compare-and-swap is retried against the freshly observed word whenever
    /// another writer changes the bucket in between.
    fn bucket_replace(&self, bucket_index: usize, current_fp: u16, replacement_fp: u16) -> bool {
        let cell = &self.buckets[bucket_index];
        let mut word = cell.load(Ordering::Acquire);
        loop {
            let slot = match (0..SLOTS_PER_BUCKET).find(|&s| Self::slot_value(word, s) == current_fp) {
                Some(slot) => slot,
                None => return false,
            };
            let shift = slot * FINGERPRINT_BITS;
            let updated = (word & !(FINGERPRINT_MASK << shift)) | ((replacement_fp as usize) << shift);
            match cell.compare_exchange(word, updated, Ordering::AcqRel, Ordering::Acquire) {
                Ok(_) => return true,
                Err(observed) => word = observed,
            }
        }
    }

    /// Stores `fingerprint` in the first empty slot of the bucket; `false` if the bucket is full.
    fn bucket_insert(&self, bucket_index: usize, fingerprint: u16) -> bool {
        // An empty slot is one that holds the reserved zero fingerprint.
        self.bucket_replace(bucket_index, 0, fingerprint)
    }

    /// Clears one slot holding `fingerprint`; `false` if the bucket has no such slot.
    fn bucket_delete(&self, bucket_index: usize, fingerprint: u16) -> bool {
        self.bucket_replace(bucket_index, fingerprint, 0)
    }

    /// Overwrites one slot of the bucket with `new_fp` and returns what the slot held.
    ///
    /// The slot is chosen from the shared `kick_count`, so successive evictions
    /// rotate through the slots. A return value of zero means the slot was empty
    /// and nothing was displaced.
    fn bucket_swap(&self, bucket_index: usize, new_fp: u16) -> u16 {
        let ticket = self.kick_count.fetch_add(1, Ordering::Relaxed);
        let shift = (ticket % SLOTS_PER_BUCKET as usize) as u32 * FINGERPRINT_BITS;
        let cell = &self.buckets[bucket_index];
        let mut word = cell.load(Ordering::Acquire);
        loop {
            let updated = (word & !(FINGERPRINT_MASK << shift)) | ((new_fp as usize) << shift);
            match cell.compare_exchange(word, updated, Ordering::AcqRel, Ordering::Acquire) {
                Ok(previous) => return ((previous >> shift) & FINGERPRINT_MASK) as u16,
                Err(observed) => word = observed,
            }
        }
    }

    /// Inserts `item`, relocating existing fingerprints when both candidate buckets are full.
    ///
    /// Returns `true` once every fingerprint involved has a slot. After
    /// `MAX_KICKS` unsuccessful evictions the evictions are undone in reverse
    /// order and `false` is returned, leaving previously stored fingerprints in
    /// place.
    ///
    /// The undo step is exact when no other thread touches the affected buckets
    /// in the meantime; under concurrent modification it is best effort.
    pub fn try_insert<T: Hash>(&self, item: &T) -> bool {
        let fingerprint = self.fingerprint(item);
        let primary = self.primary_bucket_index(item);
        if self.bucket_insert(primary, fingerprint) {
            return true;
        }
        let secondary = self.secondary_bucket_index(fingerprint, primary);
        if self.bucket_insert(secondary, fingerprint) {
            return true;
        }

        // Each entry records (bucket, fingerprint written there, fingerprint displaced).
        // Recording the written value lets the undo step locate the exact slot again
        // instead of relying on the rotating eviction slot.
        let mut chain: Vec<(usize, u16, u16)> = Vec::with_capacity(MAX_KICKS);
        let mut homeless = fingerprint;
        let mut index = secondary;
        for _ in 0..MAX_KICKS {
            let evicted = self.bucket_swap(index, homeless);
            if evicted == 0 {
                // The slot had been freed by another thread: nothing was displaced,
                // so every fingerprint now has a home.
                return true;
            }
            chain.push((index, homeless, evicted));
            homeless = evicted;
            index = self.secondary_bucket_index(homeless, index);
            if self.bucket_insert(index, homeless) {
                return true;
            }
        }

        for (bucket, placed, displaced) in chain.into_iter().rev() {
            if !self.bucket_replace(bucket, placed, displaced) {
                // A concurrent writer removed or moved `placed`. Re-home the displaced
                // fingerprint wherever it fits so an earlier insert is not lost.
                let alternate = self.secondary_bucket_index(displaced, bucket);
                let _rehomed = self.bucket_insert(bucket, displaced)
                    || self.bucket_insert(alternate, displaced);
            }
        }
        false
    }
}
impl ProbabilisticSet for CuckooFilter {
    /// Adds `item` by delegating to `try_insert`; the result is passed through unchanged.
    fn insert<T: Hash>(&self, item: &T) -> bool {
        self.try_insert(item)
    }

    /// Reports whether the fingerprint of `item` is present in either candidate bucket.
    fn contains<T: Hash>(&self, item: &T) -> bool {
        let fp = self.fingerprint(item);
        let first = self.primary_bucket_index(item);
        // The second bucket is only computed and probed when the first one misses.
        self.bucket_contains(first, fp)
            || self.bucket_contains(self.secondary_bucket_index(fp, first), fp)
    }

    /// Removes one stored copy of the fingerprint of `item`, if any.
    ///
    /// The primary bucket is tried first; the secondary bucket is touched only
    /// when the primary held no matching fingerprint. Absence is not an error.
    fn delete<T: Hash>(&self, item: &T) {
        let fp = self.fingerprint(item);
        let first = self.primary_bucket_index(item);
        if !self.bucket_delete(first, fp) {
            let second = self.secondary_bucket_index(fp, first);
            let _removed = self.bucket_delete(second, fp);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Bucket count used by the tests that do not need a deliberately tiny filter.
    const FILTER_SIZE: usize = 1024;

    /// A new filter has the requested number of buckets, all empty.
    #[test]
    fn construct() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        assert_eq!(cf.buckets.len(), FILTER_SIZE);
        assert!(cf.buckets.iter().all(|x| x.load(Ordering::Relaxed) == 0));
        assert_eq!(cf.size, FILTER_SIZE);
    }

    /// Zero marks an empty slot, so no item may map to a zero fingerprint.
    #[test]
    fn fingerprint_non_zero() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        for i in 0..1000 {
            let fp = cf.fingerprint(&i);
            assert_ne!(fp, 0, "Fingerprint should never be 0");
        }
    }

    /// An inserted item is found; an item that was never inserted is not.
    #[test]
    fn insert_and_contains() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        let item = "foo";
        assert!(cf.try_insert(&item));
        assert!(cf.contains(&item));
        let not_inserted = "bar";
        assert!(!cf.contains(&not_inserted));
    }

    /// Many inserts into a lightly loaded filter all succeed and stay visible.
    #[test]
    fn insert_multiple() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        for i in 0..100 {
            assert!(cf.try_insert(&i), "Failed to insert item {}", i);
        }
        for i in 0..100 {
            assert!(cf.contains(&i), "Item {} should be present", i);
        }
    }

    /// Deleting an inserted item makes it disappear from lookups.
    #[test]
    fn delete() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        let item = "foo";
        cf.insert(&item);
        assert!(cf.contains(&item));
        cf.delete(&item);
        assert!(!cf.contains(&item));
    }

    /// Deleting an item that was never inserted is a harmless no-op.
    #[test]
    fn delete_non_existent() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        let item = "foo";
        cf.delete(&item);
        assert!(!cf.contains(&item));
    }

    /// Applying the alternate-bucket mapping twice must lead back to the starting bucket.
    #[test]
    fn secondary_bucket_is_symmetric() {
        let cf = CuckooFilter::new(FILTER_SIZE);
        let item = "test";
        let fp = cf.fingerprint(&item);
        let primary = cf.primary_bucket_index(&item);
        let secondary = cf.secondary_bucket_index(fp, primary);
        let back_to_primary = cf.secondary_bucket_index(fp, secondary);
        assert_eq!(primary, back_to_primary, "XOR property should be symmetric");
    }

    /// A small filter still accepts most of a workload that forces evictions.
    #[test]
    fn cuckoo_eviction_works() {
        let cf = CuckooFilter::new(16);
        let mut inserted = 0;
        for i in 0..50 {
            if cf.try_insert(&i) {
                inserted += 1;
            }
        }
        assert!(inserted >= 40, "Should insert at least 40 items, got {}", inserted);
    }

    /// Once the table is saturated, further inserts must report failure.
    #[test]
    fn filter_full_returns_false() {
        let cf = CuckooFilter::new(4);
        let mut failures = 0;
        for i in 0..100 {
            if !cf.try_insert(&i) {
                failures += 1;
            }
        }
        assert!(failures > 0, "Small filter should reject some inserts");
    }

    /// Inserts issued from several threads at once remain visible afterwards.
    #[test]
    fn concurrent_inserts() {
        let cf = Arc::new(CuckooFilter::new(FILTER_SIZE));
        let num_threads = 4;
        let items_per_thread = 100;
        let handles: Vec<_> = (0..num_threads)
            .map(|t| {
                let cf = Arc::clone(&cf);
                thread::spawn(move || {
                    for i in 0..items_per_thread {
                        // Each thread inserts its own disjoint range of integers.
                        let item = t * items_per_thread + i;
                        cf.insert(&item);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let mut found = 0;
        for i in 0..(num_threads * items_per_thread) {
            if cf.contains(&i) {
                found += 1;
            }
        }
        assert!(
            found >= (num_threads * items_per_thread) * 90 / 100,
            "At least 90% of items should be found, got {}/{}",
            found,
            num_threads * items_per_thread
        );
    }

    /// Lookups of pre-inserted items keep succeeding while another thread inserts.
    #[test]
    fn concurrent_insert_and_lookup() {
        let cf = Arc::new(CuckooFilter::new(FILTER_SIZE));
        for i in 0..50 {
            cf.insert(&i);
        }
        let cf_insert = Arc::clone(&cf);
        let cf_lookup = Arc::clone(&cf);
        let insert_handle = thread::spawn(move || {
            for i in 50..150 {
                cf_insert.insert(&i);
            }
        });
        let lookup_handle = thread::spawn(move || {
            let mut found = 0;
            for _ in 0..1000 {
                for i in 0..50 {
                    if cf_lookup.contains(&i) {
                        found += 1;
                    }
                }
            }
            found
        });
        insert_handle.join().unwrap();
        let found = lookup_handle.join().unwrap();
        let expected = 50 * 1000;
        assert!(
            found >= expected * 99 / 100,
            "Pre-inserted items should almost always be found, got {}/{}",
            found,
            expected
        );
    }
}