use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher};
use std::marker::PhantomData;
use rand::Rng;
use succinct::{IntVec, IntVecMut, IntVector};
use crate::filters::Filter;
use crate::helpers::all_zero_intvector;
/// Maximum number of evictions ("kicks") an insertion performs before it gives up and reports
/// `CuckooFilterFull`.
const MAX_NUM_KICKS: usize = 500;
/// Error returned by insertion and union when an element cannot be placed.
///
/// It is produced once `MAX_NUM_KICKS` evictions have failed to free a slot in either of the
/// element's two candidate buckets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CuckooFilterFull;

impl fmt::Display for CuckooFilterFull {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("cuckoo filter is full")
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and other generic error handling.
impl std::error::Error for CuckooFilterFull {}
/// Approximate set-membership filter based on cuckoo hashing that also supports deletion.
///
/// Elements are reduced to non-zero `l_fingerprint`-bit fingerprints. Each fingerprint is stored
/// in one of two candidate buckets (each with `bucketsize` slots); a slot value of 0 marks an
/// empty slot.
#[derive(Clone)]
pub struct CuckooFilter<T, R, B = BuildHasherDefault<DefaultHasher>>
where
T: Hash + ?Sized,
R: Rng,
B: BuildHasher + Clone + Eq,
{
// Flat slot storage: bucket `i` occupies slots `i * bucketsize .. (i + 1) * bucketsize`.
table: IntVector<u64>,
// Number of fingerprints currently stored (maintained by insert/delete/union/clear).
n_elements: usize,
// Hasher factory used for both fingerprints and bucket indices.
buildhasher: B,
// Slots per bucket.
bucketsize: usize,
// Number of buckets; always a power of two so indices can be masked.
n_buckets: usize,
// Bits per fingerprint (2..=64).
l_fingerprint: usize,
// Random source for choosing which bucket/slot to evict from when both buckets are full.
rng: R,
// `fn() -> T` ties the filter to `T` without storing one, so the filter stays `Send` even
// when `T` is not (see the `send` test).
phantom: PhantomData<fn() -> T>,
}
impl<T, R> CuckooFilter<T, R>
where
    T: Hash + ?Sized,
    R: Rng,
{
    /// Create a filter from explicit parameters, hashing with the default
    /// `BuildHasherDefault<DefaultHasher>` factory.
    ///
    /// Forwards to `with_params_and_hash`, which documents the parameters and panic conditions.
    pub fn with_params(rng: R, bucketsize: usize, n_buckets: usize, l_fingerprint: usize) -> Self {
        Self::with_params_and_hash(
            rng,
            bucketsize,
            n_buckets,
            l_fingerprint,
            BuildHasherDefault::<DefaultHasher>::default(),
        )
    }

    /// Create a filter with 4 slots per bucket, sized from `false_positive_rate` and
    /// `expected_elements`, hashing with the default hasher factory.
    ///
    /// Forwards to `with_properties_and_hash_4`.
    pub fn with_properties_4(false_positive_rate: f64, expected_elements: usize, rng: R) -> Self {
        Self::with_properties_and_hash_4(
            false_positive_rate,
            expected_elements,
            rng,
            BuildHasherDefault::<DefaultHasher>::default(),
        )
    }

    /// Create a filter with 8 slots per bucket, sized from `false_positive_rate` and
    /// `expected_elements`, hashing with the default hasher factory.
    ///
    /// Forwards to `with_properties_and_hash_8`.
    pub fn with_properties_8(false_positive_rate: f64, expected_elements: usize, rng: R) -> Self {
        Self::with_properties_and_hash_8(
            false_positive_rate,
            expected_elements,
            rng,
            BuildHasherDefault::<DefaultHasher>::default(),
        )
    }
}
impl<T, R, B> CuckooFilter<T, R, B>
where
    T: Hash + ?Sized,
    R: Rng,
    B: BuildHasher + Clone + Eq,
{
    /// Create a new filter with explicit parameters and a custom hasher factory.
    ///
    /// - `rng`: random source used to choose where to evict from when both candidate buckets of
    ///   a new element are full.
    /// - `bucketsize`: number of fingerprint slots per bucket.
    /// - `n_buckets`: number of buckets.
    /// - `l_fingerprint`: number of bits stored per fingerprint.
    /// - `bh`: hasher factory used to derive fingerprints and bucket indices.
    ///
    /// # Panics
    ///
    /// Panics if `bucketsize < 2`, if `n_buckets` is not a power of two that is at least 2, if
    /// `l_fingerprint` is not in `2..=64`, or if `n_buckets * bucketsize` overflows `usize`.
    pub fn with_params_and_hash(
        rng: R,
        bucketsize: usize,
        n_buckets: usize,
        l_fingerprint: usize,
        bh: B,
    ) -> Self {
        assert!(
            bucketsize >= 2,
            "bucketsize ({}) must be greater or equal than 2",
            bucketsize
        );
        assert!(
            n_buckets.is_power_of_two() && n_buckets >= 2,
            "n_buckets ({}) must be a power of 2 and greater or equal than 2",
            n_buckets
        );
        assert!(
            (2..=64).contains(&l_fingerprint),
            "l_fingerprint ({}) must be greater than 1 and less or equal than 64",
            l_fingerprint
        );
        let table_size = n_buckets
            .checked_mul(bucketsize)
            .expect("Table size too large");
        Self {
            table: all_zero_intvector(l_fingerprint, table_size),
            n_elements: 0,
            buildhasher: bh,
            bucketsize,
            n_buckets,
            l_fingerprint,
            rng,
            phantom: PhantomData,
        }
    }

    /// Create a filter with 4 slots per bucket (assumed load factor 0.95), sized for
    /// `expected_elements` at the requested `false_positive_rate`, with a custom hasher factory.
    ///
    /// # Panics
    ///
    /// Panics if `expected_elements` is 0, if `false_positive_rate` is not strictly between 0
    /// and 1, or if the derived parameters are rejected by `with_params_and_hash` (e.g. a rate
    /// so small that the fingerprint would need more than 64 bits).
    pub fn with_properties_and_hash_4(
        false_positive_rate: f64,
        expected_elements: usize,
        rng: R,
        bh: B,
    ) -> Self {
        let bucketsize = 4usize;
        let load_factor = 0.95f64;
        Self::with_properties_and_hash_n(
            bucketsize,
            load_factor,
            false_positive_rate,
            expected_elements,
            rng,
            bh,
        )
    }

    /// Create a filter with 8 slots per bucket (assumed load factor 0.98), sized for
    /// `expected_elements` at the requested `false_positive_rate`, with a custom hasher factory.
    ///
    /// # Panics
    ///
    /// Same conditions as `with_properties_and_hash_4`.
    pub fn with_properties_and_hash_8(
        false_positive_rate: f64,
        expected_elements: usize,
        rng: R,
        bh: B,
    ) -> Self {
        let bucketsize = 8usize;
        let load_factor = 0.98f64;
        Self::with_properties_and_hash_n(
            bucketsize,
            load_factor,
            false_positive_rate,
            expected_elements,
            rng,
            bh,
        )
    }

    /// Shared sizing logic for the `with_properties_and_hash_*` constructors.
    fn with_properties_and_hash_n(
        bucketsize: usize,
        load_factor: f64,
        false_positive_rate: f64,
        expected_elements: usize,
        rng: R,
        bh: B,
    ) -> Self {
        assert!(
            expected_elements >= 1,
            "expected_elements ({}) must be at least 1",
            expected_elements
        );
        assert!(
            (false_positive_rate > 0.) && (false_positive_rate < 1.),
            "false_positive_rate ({}) must be greater than 0 and smaller than 1",
            false_positive_rate
        );
        // Smallest l with 2 * bucketsize / 2^l <= false_positive_rate.
        let l_fingerprint = (2.0 * (bucketsize as f64) / false_positive_rate)
            .log2()
            .ceil() as usize;
        let costs = (l_fingerprint as f64) / load_factor;
        // NOTE(review): `costs / l_fingerprint` reduces to `1 / load_factor`, so this requests
        // ~expected_elements / load_factor *buckets* (each holding `bucketsize` slots). The
        // existing tests pin this sizing, so it is intentionally left unchanged.
        let n_buckets = ((costs * (expected_elements as f64) / (l_fingerprint as f64)).ceil()
            as usize)
            .next_power_of_two();
        Self::with_params_and_hash(rng, bucketsize, n_buckets, l_fingerprint, bh)
    }

    /// Number of fingerprint slots per bucket.
    pub fn bucketsize(&self) -> usize {
        self.bucketsize
    }

    /// Number of buckets in the table.
    pub fn n_buckets(&self) -> usize {
        self.n_buckets
    }

    /// Number of bits stored per fingerprint.
    pub fn l_fingerprint(&self) -> usize {
        self.l_fingerprint
    }

    /// Remove one stored fingerprint matching `t` from either of its candidate buckets.
    ///
    /// Returns `true` if a matching fingerprint was found and removed. Matching is by
    /// fingerprint only, so deleting an element that was never inserted may remove the
    /// fingerprint of a different, colliding element.
    pub fn delete(&mut self, t: &T) -> bool {
        let (f, i1, i2) = self.start(t);
        let removed = self.remove_from_bucket(i1, f) || self.remove_from_bucket(i2, f);
        if removed {
            self.n_elements -= 1;
        }
        removed
    }

    /// Compute `(fingerprint, primary bucket, alternate bucket)` for `t`.
    ///
    /// The alternate bucket is `i1 ^ hash(fingerprint)`, so applying the same XOR to either
    /// bucket yields the other one. This lets evictions relocate a fingerprint without the
    /// original element.
    fn start(&self, t: &T) -> (u64, usize, usize) {
        let f = self.fingerprint(t);
        let i1 = self.hash(t);
        let i2 = i1 ^ self.hash(&f);
        (f, i1, i2)
    }

    /// Fingerprint of `t`, in `1..=2^l_fingerprint - 1`; 0 is reserved for empty slots.
    fn fingerprint(&self, t: &T) -> u64 {
        let mut hasher = self.buildhasher.build_hasher();
        // Prefix 0 separates fingerprint hashing from bucket hashing (which uses prefix 1).
        hasher.write_usize(0);
        t.hash(&mut hasher);
        // 2^l - 1. The shift form also handles l == 64 without overflow, since the
        // constructor guarantees l_fingerprint is in 2..=64.
        let x_mod = u64::MAX >> (64 - self.l_fingerprint);
        1 + (hasher.finish() % x_mod)
    }

    /// Bucket index in `0..n_buckets` derived from `obj`.
    fn hash<U>(&self, obj: &U) -> usize
    where
        U: Hash + ?Sized,
    {
        let mut hasher = self.buildhasher.build_hasher();
        hasher.write_usize(1);
        obj.hash(&mut hasher);
        // n_buckets is a power of two, so masking is equivalent to `% n_buckets`.
        (hasher.finish() & (self.n_buckets as u64 - 1)) as usize
    }

    /// Table slot indices belonging to bucket `i`.
    fn bucket_slots(&self, i: usize) -> std::ops::Range<usize> {
        let offset = i * self.bucketsize;
        offset..(offset + self.bucketsize)
    }

    /// Store `f` in the first empty slot of bucket `i`, recording the write in `log`.
    ///
    /// The slot's previous (empty) value is appended to `log` so that `restore_state` can undo
    /// this write as well as evictions. Returns `false`, without touching the table or `log`,
    /// if the bucket is full.
    fn write_to_bucket(&mut self, i: usize, f: u64, log: &mut Vec<(usize, u64)>) -> bool {
        let empty_slot = self
            .bucket_slots(i)
            .find(|&x| self.table.get(x as u64) == 0);
        match empty_slot {
            Some(x) => {
                log.push((x, 0));
                self.table.set(x as u64, f);
                true
            }
            None => false,
        }
    }

    /// Whether bucket `i` contains fingerprint `f`.
    fn has_in_bucket(&self, i: usize, f: u64) -> bool {
        self.bucket_slots(i).any(|x| self.table.get(x as u64) == f)
    }

    /// Clear the first slot of bucket `i` holding `f`; returns whether one was found.
    fn remove_from_bucket(&mut self, i: usize, f: u64) -> bool {
        let slot = self
            .bucket_slots(i)
            .find(|&x| self.table.get(x as u64) == f);
        if let Some(x) = slot {
            self.table.set(x as u64, 0);
        }
        slot.is_some()
    }

    /// Place fingerprint `f` with candidate buckets `i1`/`i2`, evicting if necessary.
    ///
    /// Every table mutation (direct writes and evictions) is appended to `log` as
    /// `(slot, previous value)`. A caller can therefore undo this call, and any earlier calls
    /// sharing the same log, with `restore_state`. On success the element count is incremented
    /// and `Ok(true)` is returned, whichever path placed the fingerprint. On failure the table
    /// is left with partial evictions applied (the count is not changed) and
    /// `Err(CuckooFilterFull)` is returned; the caller must roll back via `log`.
    fn insert_internal(
        &mut self,
        mut f: u64,
        i1: usize,
        i2: usize,
        log: &mut Vec<(usize, u64)>,
    ) -> Result<bool, CuckooFilterFull> {
        // Fast path: one of the two candidate buckets has a free slot.
        if self.write_to_bucket(i1, f, log) || self.write_to_bucket(i2, f, log) {
            self.n_elements += 1;
            return Ok(true);
        }

        // Both buckets are full: repeatedly swap `f` into a random slot and try to relocate the
        // evicted fingerprint to its alternate bucket.
        let mut i = if self.rng.gen::<bool>() { i1 } else { i2 };
        for _ in 0..MAX_NUM_KICKS {
            let e: usize = self.rng.gen_range(0..self.bucketsize);
            let x = i * self.bucketsize + e;
            let evicted = self.table.get(x as u64);
            log.push((x, evicted));
            self.table.set(x as u64, f);
            f = evicted;
            // The evicted fingerprint's other candidate bucket depends only on the fingerprint.
            i ^= self.hash(&f);
            if self.write_to_bucket(i, f, log) {
                self.n_elements += 1;
                return Ok(true);
            }
        }
        Err(CuckooFilterFull)
    }

    /// Undo the table mutations recorded in `log`, most recent first.
    ///
    /// Only the table is restored; callers are responsible for `n_elements`.
    fn restore_state(&mut self, log: &[(usize, u64)]) {
        for &(pos, data) in log.iter().rev() {
            self.table.set(pos as u64, data);
        }
    }
}
impl<T, R, B> Filter<T> for CuckooFilter<T, R, B>
where
    T: Hash + ?Sized,
    R: Rng,
    B: BuildHasher + Clone + Eq,
{
    type InsertErr = CuckooFilterFull;

    /// Empty the filter: the element count drops to zero and every slot is reset to 0.
    fn clear(&mut self) {
        self.n_elements = 0;
        let bits = self.table.element_bits();
        let slots = self.table.len();
        self.table = IntVector::with_fill(bits, slots, 0);
    }

    /// Insert `obj`; on failure, the evictions performed during the attempt are rolled back.
    fn insert(&mut self, obj: &T) -> Result<bool, Self::InsertErr> {
        let (fingerprint, bucket_a, bucket_b) = self.start(obj);
        let mut undo_log: Vec<(usize, u64)> = Vec::new();
        let outcome = self.insert_internal(fingerprint, bucket_a, bucket_b, &mut undo_log);
        if let Err(_) = outcome {
            self.restore_state(&undo_log);
        }
        outcome
    }

    /// Copy every fingerprint stored in `other` into `self`.
    ///
    /// # Panics
    ///
    /// Panics if the two filters differ in `bucketsize`, `n_buckets`, `l_fingerprint` or
    /// hasher factory.
    ///
    /// # Errors
    ///
    /// Returns `CuckooFilterFull` if some fingerprint cannot be placed; the recorded table
    /// mutations are then undone and the element count is restored.
    fn union(&mut self, other: &Self) -> Result<(), Self::InsertErr> {
        assert_eq!(
            self.bucketsize, other.bucketsize,
            "bucketsize must be equal (left={}, right={})",
            self.bucketsize, other.bucketsize
        );
        assert_eq!(
            self.n_buckets, other.n_buckets,
            "n_buckets must be equal (left={}, right={})",
            self.n_buckets, other.n_buckets
        );
        assert_eq!(
            self.l_fingerprint, other.l_fingerprint,
            "l_fingerprint must be equal (left={}, right={})",
            self.l_fingerprint, other.l_fingerprint
        );
        assert!(
            self.buildhasher == other.buildhasher,
            "buildhasher must be equal",
        );

        let count_before = self.n_elements;
        let mut undo_log: Vec<(usize, u64)> = Vec::new();
        for (slot, f) in other.table.iter().enumerate() {
            // A zero slot is empty and carries no element.
            if f == 0 {
                continue;
            }
            // Buckets are laid out contiguously, `bucketsize` slots each.
            let primary = slot / other.bucketsize;
            let alternate = primary ^ other.hash(&f);
            if let Err(err) = self.insert_internal(f, primary, alternate, &mut undo_log) {
                self.restore_state(&undo_log);
                self.n_elements = count_before;
                return Err(err);
            }
        }
        Ok(())
    }

    /// Whether the filter currently holds no fingerprints.
    fn is_empty(&self) -> bool {
        self.n_elements == 0
    }

    /// Number of fingerprints currently stored.
    fn len(&self) -> usize {
        self.n_elements
    }

    /// Whether `obj`'s fingerprint is present in either candidate bucket (may be a false
    /// positive).
    fn query(&self, obj: &T) -> bool {
        let (fingerprint, bucket_a, bucket_b) = self.start(obj);
        self.has_in_bucket(bucket_a, fingerprint) || self.has_in_bucket(bucket_b, fingerprint)
    }
}
impl<T, R, B> fmt::Debug for CuckooFilter<T, R, B>
where
    T: Hash + ?Sized,
    R: Rng,
    B: BuildHasher + Clone + Eq,
{
    /// Show the table geometry only. The fingerprint table, hasher and RNG are omitted.
    ///
    /// Uses `debug_struct` so that `{:#?}` pretty-printing and formatter flags are honored. The
    /// compact `{:?}` output is `CuckooFilter { bucketsize: .., n_buckets: .. }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CuckooFilter")
            .field("bucketsize", &self.bucketsize)
            .field("n_buckets", &self.n_buckets)
            .finish()
    }
}
// Unit tests for `CuckooFilter`: constructor validation, basic filter operations, sizing helpers
// and union semantics. Every filter uses a fixed-seed `ChaChaRng`, so eviction choices are
// deterministic.
#[cfg(test)]
mod tests {
use super::CuckooFilter;
use crate::{
filters::Filter,
hash_utils::BuildHasherSeeded,
test_util::{assert_send, NotSend},
};
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
// --- `with_params` must reject invalid bucketsize / n_buckets / l_fingerprint values. ---
#[test]
#[should_panic(expected = "bucketsize (0) must be greater or equal than 2")]
fn new_panics_bucketsize_0() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 0, 16, 8);
}
#[test]
#[should_panic(expected = "bucketsize (1) must be greater or equal than 2")]
fn new_panics_bucketsize_1() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 1, 16, 8);
}
#[test]
#[should_panic(expected = "n_buckets (0) must be a power of 2 and greater or equal than 2")]
fn new_panics_n_buckets_0() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 0, 8);
}
#[test]
#[should_panic(expected = "n_buckets (1) must be a power of 2 and greater or equal than 2")]
fn new_panics_n_buckets_1() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 1, 8);
}
#[test]
#[should_panic(expected = "n_buckets (5) must be a power of 2 and greater or equal than 2")]
fn new_panics_n_buckets_5() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 5, 8);
}
#[test]
#[should_panic(expected = "l_fingerprint (0) must be greater than 1 and less or equal than 64")]
fn new_panics_l_fingerprint_0() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 0);
}
#[test]
#[should_panic(expected = "l_fingerprint (1) must be greater than 1 and less or equal than 64")]
fn new_panics_l_fingerprint_1() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 1);
}
#[test]
#[should_panic(
expected = "l_fingerprint (65) must be greater than 1 and less or equal than 64"
)]
fn new_panics_l_fingerprint_65() {
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 65);
}
// --- Oversized tables must panic with "Table size too large". ---
// Cases 1 and 2 overflow `n_buckets * bucketsize` itself.
#[test]
#[should_panic(expected = "Table size too large")]
fn new_panics_table_size_overflow_1() {
CuckooFilter::<u64, ChaChaRng>::with_params(
ChaChaRng::from_seed([0; 32]),
usize::max_value(),
2,
2,
);
}
#[test]
#[should_panic(expected = "Table size too large")]
fn new_panics_table_size_overflow_2() {
CuckooFilter::<u64, ChaChaRng>::with_params(
ChaChaRng::from_seed([0; 32]),
2,
(((usize::max_value() as u128) + 1) / 2) as usize,
2,
);
}
// The slot count here fits in `usize`; presumably the overflow of the total *bit* count
// (slots * 64) is what `all_zero_intvector` rejects — NOTE(review): confirm in helpers.
#[test]
#[should_panic(expected = "Table size too large")]
fn new_panics_table_size_overflow_3() {
CuckooFilter::<u64, ChaChaRng>::with_params(
ChaChaRng::from_seed([0; 32]),
2,
(((usize::max_value() as u128) + 1) / 8) as usize,
64,
);
}
// Getters echo the constructor parameters.
#[test]
fn getter() {
let cf =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert_eq!(cf.bucketsize(), 2);
assert_eq!(cf.n_buckets(), 16);
assert_eq!(cf.l_fingerprint(), 8);
}
// A new filter is empty.
#[test]
fn is_empty() {
let cf =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert!(cf.is_empty());
assert_eq!(cf.len(), 0);
}
// A single insert is reported as new and becomes queryable.
#[test]
fn insert() {
let mut cf = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert!(cf.insert(&13).unwrap());
assert!(!cf.is_empty());
assert_eq!(cf.len(), 1);
assert!(cf.query(&13));
assert!(!cf.query(&42));
}
// Inserting the same element twice stores a second fingerprint copy and still reports `true`.
#[test]
fn double_insert() {
let mut cf = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert!(cf.insert(&13).unwrap());
assert!(cf.insert(&13).unwrap());
assert!(cf.query(&13));
}
// Deleting one element leaves the others intact and decrements the count.
#[test]
fn delete() {
let mut cf = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
cf.insert(&13).unwrap();
cf.insert(&42).unwrap();
assert!(cf.query(&13));
assert!(cf.query(&42));
assert_eq!(cf.len(), 2);
assert!(cf.delete(&13));
assert!(!cf.query(&13));
assert!(cf.query(&42));
assert_eq!(cf.len(), 1);
}
// `clear` removes all fingerprints and resets the count.
#[test]
fn clear() {
let mut cf = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
cf.insert(&1).unwrap();
cf.clear();
assert!(!cf.query(&1));
assert!(cf.is_empty());
}
// 2 buckets x 2 slots = 4 slots: a fifth insert must fail and leave the filter unchanged.
#[test]
fn full() {
let mut cf = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 2, 8);
for i in 0..4 {
cf.insert(&i).unwrap();
}
assert_eq!(cf.len(), 4);
for i in 0..4 {
assert!(cf.query(&i));
}
assert!(cf.insert(&5).is_err());
assert_eq!(cf.len(), 4);
assert!(!cf.query(&5)); }
// Pins the compact `{:?}` representation.
#[test]
fn debug() {
let cf =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert_eq!(
format!("{:?}", cf),
"CuckooFilter { bucketsize: 2, n_buckets: 16 }"
);
}
// Clones are independent: later inserts into the original do not show up in the clone.
#[test]
fn clone() {
let mut cf1 = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
cf1.insert(&13).unwrap();
assert!(cf1.query(&13));
let cf2 = cf1.clone();
cf1.insert(&42).unwrap();
assert!(cf2.query(&13));
assert!(!cf2.query(&42));
}
// --- Parameter derivation of the `with_properties_*` constructors (values pinned). ---
#[test]
fn with_properties_4() {
let cf = CuckooFilter::<u64, ChaChaRng>::with_properties_4(
0.02,
1000,
ChaChaRng::from_seed([0; 32]),
);
assert_eq!(cf.bucketsize(), 4);
assert_eq!(cf.n_buckets(), 2048);
assert_eq!(cf.l_fingerprint(), 9);
}
#[test]
fn with_properties_8() {
let cf = CuckooFilter::<u64, ChaChaRng>::with_properties_8(
0.02,
1000,
ChaChaRng::from_seed([0; 32]),
);
assert_eq!(cf.bucketsize(), 8);
assert_eq!(cf.n_buckets(), 1024);
assert_eq!(cf.l_fingerprint(), 10);
}
// --- `with_properties_*` input validation. ---
#[test]
#[should_panic(expected = "expected_elements (0) must be at least 1")]
fn with_properties_4_panics_expected_elements_0() {
CuckooFilter::<u64, ChaChaRng>::with_properties_4(0.02, 0, ChaChaRng::from_seed([0; 32]));
}
#[test]
#[should_panic(expected = "false_positive_rate (0) must be greater than 0 and smaller than 1")]
fn with_properties_4_panics_false_positive_rate_0() {
CuckooFilter::<u64, ChaChaRng>::with_properties_4(0., 1000, ChaChaRng::from_seed([0; 32]));
}
#[test]
#[should_panic(expected = "false_positive_rate (1) must be greater than 0 and smaller than 1")]
fn with_properties_4_panics_false_positive_rate_1() {
CuckooFilter::<u64, ChaChaRng>::with_properties_4(1., 1000, ChaChaRng::from_seed([0; 32]));
}
// Union adds `cf2`'s elements to `cf1` and leaves `cf2` untouched.
#[test]
fn union() {
let mut cf1 =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
let mut cf2 = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
cf1.insert(&13).unwrap();
cf1.insert(&42).unwrap();
cf2.insert(&130).unwrap();
cf2.insert(&420).unwrap();
cf1.union(&cf2).unwrap();
assert!(cf1.query(&13));
assert!(cf1.query(&42));
assert!(cf1.query(&130));
assert!(cf1.query(&420));
assert!(!cf2.query(&13));
assert!(!cf2.query(&42));
assert!(cf2.query(&130));
assert!(cf2.query(&420));
}
// --- Union of incompatible filters must panic with a descriptive message. ---
#[test]
#[should_panic(expected = "bucketsize must be equal (left=2, right=3)")]
fn union_panics_bucketsize() {
let mut cf1 =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
let cf2 = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 3, 16, 8);
cf1.union(&cf2).unwrap();
}
#[test]
#[should_panic(expected = "n_buckets must be equal (left=16, right=32)")]
fn union_panics_n_buckets() {
let mut cf1 =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
let cf2 = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 32, 8);
cf1.union(&cf2).unwrap();
}
#[test]
#[should_panic(expected = "l_fingerprint must be equal (left=8, right=16)")]
fn union_panics_l_fingerprint() {
let mut cf1 =
CuckooFilter::<u64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
let cf2 = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 16);
cf1.union(&cf2).unwrap();
}
#[test]
#[should_panic(expected = "buildhasher must be equal")]
fn union_panics_buildhasher() {
let mut cf1 = CuckooFilter::<u64, ChaChaRng, BuildHasherSeeded>::with_params_and_hash(
ChaChaRng::from_seed([0; 32]),
2,
16,
8,
BuildHasherSeeded::new(0),
);
let cf2 = CuckooFilter::with_params_and_hash(
ChaChaRng::from_seed([0; 32]),
2,
16,
8,
BuildHasherSeeded::new(1),
);
cf1.union(&cf2).unwrap();
}
// `cf1` is filled until an insert fails; merging it into `cf2` must then fail and leave `cf2`
// unchanged (same length, none of `cf1`'s elements queryable).
#[test]
fn union_full() {
let mut cf1 =
CuckooFilter::<i64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
let mut cf2 =
CuckooFilter::<i64, ChaChaRng>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
let mut obj = 0;
loop {
if cf1.insert(&obj).is_err() {
break;
}
obj += 1;
}
assert!(cf1.query(&0));
let n_cf2 = 10;
for i in 0..n_cf2 {
cf2.insert(&-i).unwrap();
}
assert_eq!(cf2.len(), n_cf2 as usize);
assert!(!cf2.query(&1));
assert!(cf2.union(&cf1).is_err());
assert_eq!(cf2.len(), n_cf2 as usize);
assert!(!cf2.query(&1));
}
// Unsized element types (`str`) are supported thanks to the `T: ?Sized` bound.
#[test]
fn insert_unsized() {
let mut cf = CuckooFilter::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert!(cf.insert("test1").unwrap());
assert!(!cf.is_empty());
assert_eq!(cf.len(), 1);
assert!(cf.query("test1"));
assert!(!cf.query("test2"));
}
// The filter is `Send` even when the element type is not.
#[test]
fn send() {
let cf = CuckooFilter::<NotSend, _>::with_params(ChaChaRng::from_seed([0; 32]), 2, 16, 8);
assert_send(&cf);
}
}