sbbf_rs_safe/
lib.rs

1use sbbf_rs::{FilterFn, ALIGNMENT, BUCKET_SIZE};
2use std::alloc::{alloc_zeroed, dealloc, Layout};
3use std::fmt;
4
5/// A split block bloom filter that handles it's own memory
6pub struct Filter {
7    filter_fn: FilterFn,
8    buf: Buf,
9    num_buckets: usize,
10}
11
12impl Filter {
13    /// Create a new filter using the parameters.
14    ///
15    /// Calculated length will be rounded up to the nearest multiple of [BUCKET_SIZE]
16    ///
17    /// `bits_per_key` can be used to adjust the false positive rate.
18    /// Some info can be found [here](https://github.com/apache/parquet-format/blob/master/BloomFilter.md#sizing-an-sbbf).
19    ///
20    /// `num_keys` means the number of unique hashes that are expected to be inserted to this bloom filter.
21    pub fn new(bits_per_key: usize, num_keys: usize) -> Self {
22        let len = bits_per_key * num_keys / 8;
23        let len = ((len + BUCKET_SIZE - 1) / BUCKET_SIZE) * BUCKET_SIZE;
24        let len = if len == 0 { BUCKET_SIZE } else { len };
25        Self {
26            filter_fn: FilterFn::new(),
27            buf: Buf::new(len),
28            num_buckets: len / BUCKET_SIZE,
29        }
30    }
31
32    /// Check if the filter contains the hash.
33    #[inline(always)]
34    pub fn contains_hash(&self, hash: u64) -> bool {
35        unsafe {
36            self.filter_fn
37                .contains(self.buf.ptr, self.num_buckets, hash)
38        }
39    }
40
41    /// Insert the hash into the filter.
42    ///
43    /// Returns true if the hash was already in the filter.
44    #[inline(always)]
45    pub fn insert_hash(&mut self, hash: u64) -> bool {
46        unsafe { self.filter_fn.insert(self.buf.ptr, self.num_buckets, hash) }
47    }
48
49    /// Returns the slice of bytes that represent this filter.
50    ///
51    /// The filter can be restored using these bytes with the `Filter::from_bytes` method.
52    #[inline(always)]
53    pub fn as_bytes(&self) -> &[u8] {
54        unsafe { std::slice::from_raw_parts(self.buf.ptr, self.buf.len) }
55    }
56
57    /// Returns a mutable reference to the slice of bytes that represent this filter.
58    ///
59    /// This can be used to directly read into the filter from a file
60    #[inline(always)]
61    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
62        unsafe { std::slice::from_raw_parts_mut(self.buf.ptr, self.buf.len) }
63    }
64
65    /// Restore a filter from the given bytes.
66    ///
67    /// Returns None if the bytes are invalid.
68    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
69        let len = bytes.len();
70
71        if len == 0 || (len % BUCKET_SIZE != 0) {
72            return None;
73        }
74
75        let buf = Buf::new(len);
76
77        let buf_bytes = unsafe { std::slice::from_raw_parts_mut(buf.ptr, buf.len) };
78        buf_bytes[..len].copy_from_slice(bytes);
79
80        Some(Self {
81            filter_fn: FilterFn::new(),
82            buf,
83            num_buckets: len / BUCKET_SIZE,
84        })
85    }
86
87    /// Resets all bits to zero in the filter.
88    ///
89    /// The filter is empty when all the bits are zero.
90    pub fn reset(&mut self) {
91        self.as_bytes_mut().fill(0);
92    }
93}
94
/// Heap buffer with an explicit allocation layout.
struct Buf {
    // Start of the allocation.
    ptr: *mut u8,
    // Layout the allocation was requested with; needed again for deallocation.
    layout: Layout,
    // Number of bytes exposed to users of the buffer.
    len: usize,
}
100
101impl Buf {
102    fn new(len: usize) -> Self {
103        let padded_len = (len + ALIGNMENT - 1) / ALIGNMENT * ALIGNMENT;
104
105        let layout = Layout::from_size_align(padded_len, ALIGNMENT).unwrap();
106        let ptr = unsafe { alloc_zeroed(layout) };
107
108        Self { layout, ptr, len }
109    }
110}
111
112impl Drop for Buf {
113    fn drop(&mut self) {
114        unsafe {
115            dealloc(self.ptr, self.layout);
116        }
117    }
118}
119
120unsafe impl Send for Filter {}
121unsafe impl Sync for Filter {}
122
123impl Clone for Filter {
124    fn clone(&self) -> Self {
125        Self::from_bytes(self.as_bytes()).unwrap()
126    }
127}
128
129impl fmt::Debug for Filter {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        f.debug_struct("Filter")
132            .field("filter_fn", &self.filter_fn.which())
133            .field("num_buckets", &self.num_buckets)
134            .finish()
135    }
136}