#include <stdlib.h>
#include "lib/malloc/malloc.h"
#include "lib/container/bloomfilt.h"
#include "lib/intmath/bits.h"
#include "lib/log/util_bug.h"
#include "ext/siphash.h"
/** Number of bit positions probed per item: each of the BLOOMFILT_N_HASHES
 * 64-bit hash values is split into its high and low 32-bit halves, and each
 * half selects one bit of the filter. */
#define N_BITS_PER_ITEM (BLOOMFILT_N_HASHES * 2)
/** A keyed bloom filter over opaque items supplied by the caller. */
struct bloomfilt_t {
/** One siphash key per hash function; every item is hashed once under
 * each of these keys when it is added or looked up. */
struct sipkey key[BLOOMFILT_N_HASHES];
/** Hash function supplied to bloomfilt_new(); maps (key, item) to a
 * 64-bit value. */
bloomfilt_hash_fn hashfn;
/** n_bits - 1, where n_bits is the power-of-two size of <b>ba</b>;
 * ANDing a 32-bit hash value with this yields a valid bit index. */
uint32_t mask;
/** The filter's bit array: n_bits bits, initially all zero. */
bitarray_t *ba;
};
/** Reduce the 32-bit hash value <b>n</b> to a bit index within
 * <b>set</b>'s bit array by masking with set->mask.  This is a valid
 * reduction because the array size is a power of two. */
#define BIT(set, n) ((n) & (set)->mask)
/**
 * Insert <b>item</b> into <b>set</b>.
 *
 * The item is hashed once under each of the filter's keys.  Both the high
 * and the low 32-bit halves of every 64-bit digest select a bit (via the
 * filter's mask), and each selected bit is set.
 */
void
bloomfilt_add(bloomfilt_t *set,
              const void *item)
{
  for (int idx = 0; idx < BLOOMFILT_N_HASHES; ++idx) {
    const uint64_t digest = set->hashfn(&set->key[idx], item);

    /* High half first, then low half: one bit position from each. */
    bitarray_set(set->ba, BIT(set, (uint32_t)(digest >> 32)));
    bitarray_set(set->ba, BIT(set, (uint32_t)digest));
  }
}
/**
 * Return 1 if every bit that bloomfilt_add() would set for <b>item</b> is
 * already set in <b>set</b>, and 0 otherwise.
 *
 * A result of 1 means the item may have been added (false positives are
 * possible).  A result of 0 means it was definitely never added to this
 * filter.
 */
int
bloomfilt_probably_contains(const bloomfilt_t *set,
                            const void *item)
{
  int all_set = 1;
  int idx = 0;

  /* Every hash is computed and every bit is probed; there is no early exit. */
  while (idx < BLOOMFILT_N_HASHES) {
    const uint64_t digest = set->hashfn(&set->key[idx], item);
    const uint32_t hi = (uint32_t)(digest >> 32);
    const uint32_t lo = (uint32_t)digest;

    all_set &= !! bitarray_is_set(set->ba, BIT(set, hi));
    all_set &= !! bitarray_is_set(set->ba, BIT(set, lo));
    ++idx;
  }

  return all_set;
}
/**
 * Allocate and return a new, empty bloom filter intended to hold about
 * <b>max_elements</b> items.
 *
 * Items are hashed with <b>hashfn</b>, keyed with the BLOOMFILT_KEY_LEN
 * bytes at <b>random_key</b>, which are copied into the filter.  The bit
 * array has n_bits = 2^(tor_log2(max_elements) + 5) bits.  That is a power
 * of two, so hash values can be reduced to bit indices with a mask.
 * NOTE(review): if tor_log2() is floor(log2), this gives between 16 and 32
 * bits per expected element; confirm against tor_log2's contract.
 *
 * <b>max_elements</b> must be non-negative and small enough that n_bits
 * fits in a positive int (tor_log2(max_elements) + 5 < 31).  Larger or
 * negative values previously produced an undefined shift or signed
 * overflow; they now fail the assertion.
 *
 * The returned filter can be released with bloomfilt_free_().
 */
bloomfilt_t *
bloomfilt_new(int max_elements,
              bloomfilt_hash_fn hashfn,
              const uint8_t *random_key)
{
  /* A negative count would be converted to a huge unsigned value when
   * passed to tor_log2(), making the shift count below far too large. */
  tor_assert(max_elements >= 0);

  const int log2_n_bits = tor_log2(max_elements) + 5;
  /* Shifting 1u by >= 32 is undefined.  A shift of 31 yields a value that
   * does not fit in a positive int, so "n_bits" would be out of range. */
  tor_assert(log2_n_bits >= 0 && log2_n_bits < 31);

  const int n_bits = (int)(1u << log2_n_bits);
  bloomfilt_t *r = tor_malloc(sizeof(*r));
  /* Do the subtraction in unsigned arithmetic: the mask is a uint32_t. */
  r->mask = (uint32_t)n_bits - 1u;
  r->ba = bitarray_init_zero(n_bits);

  tor_assert(sizeof(r->key) == BLOOMFILT_KEY_LEN);
  memcpy(r->key, random_key, sizeof(r->key));

  r->hashfn = hashfn;
  return r;
}
/**
 * Release <b>set</b>: first its bit array, then the filter structure
 * itself.  Passing NULL does nothing.
 */
void
bloomfilt_free_(bloomfilt_t *set)
{
  if (set != NULL) {
    bitarray_free(set->ba);
    tor_free(set);
  }
}