bloom_filters/buckets.rs

1use std::f64::consts::LN_2;
2use std::mem::size_of;
3use std::ptr::copy_nonoverlapping;
4
/// Storage unit that the bucket bits are packed into.
type Word = u64;
/// Size of one [`Word`] in bytes.
const BYTES_PER_WORD: usize = size_of::<Word>();
/// Size of one [`Word`] in bits.
const BITS_PER_WORD: usize = 8 * BYTES_PER_WORD;
8
9pub struct Buckets {
10    data: Vec<Word>,
11    count: usize,
12    bucket_size: u8,
13    max: u8,
14}
15
16impl Buckets {
17    pub fn with_fp_rate(items_count: usize, fp_rate: f64, bucket_size: u8) -> Self {
18        Self::new(compute_m_num(items_count, fp_rate), bucket_size)
19    }
20
21    /// Creates a new Buckets with the provided number of buckets where
22    /// each bucket is the specified number of bits.
23    pub fn new(count: usize, bucket_size: u8) -> Self {
24        debug_assert!(bucket_size < 8);
25        Self {
26            data: vec![0; (count * bucket_size as usize + BITS_PER_WORD - 1) / BITS_PER_WORD],
27            count,
28            bucket_size,
29            max: (1u8 << bucket_size) - 1,
30        }
31    }
32
33    pub fn with_raw_data(count: usize, bucket_size: u8, raw_data: &[u8]) -> Self {
34        debug_assert!(bucket_size < 8);
35        debug_assert!((count * bucket_size as usize + BITS_PER_WORD - 1) / BITS_PER_WORD * 8 == raw_data.len());
36        let data = raw_data
37            .chunks(BYTES_PER_WORD)
38            .map(|buf| {
39                let mut d = [0u8; BYTES_PER_WORD];
40                let d_slice = d.as_mut_ptr();
41                unsafe {
42                    copy_nonoverlapping(buf.as_ptr(), d_slice, BYTES_PER_WORD);
43                    (*(d_slice as *const Word)).to_le()
44                }
45            })
46            .collect::<Vec<_>>();
47
48        Self {
49            data,
50            count,
51            bucket_size,
52            max: (1u8 << bucket_size) - 1,
53        }
54    }
55
56    pub fn raw_data(&self) -> Vec<u8> {
57        let mut result = vec![0; self.data.len() * BYTES_PER_WORD];
58        for (d, chunk) in self.data.iter().zip(result.chunks_mut(BYTES_PER_WORD)) {
59            unsafe {
60                let bytes = *(&d.to_le() as *const _ as *const [u8; BYTES_PER_WORD]);
61                copy_nonoverlapping((&bytes).as_ptr(), chunk.as_mut_ptr(), BYTES_PER_WORD);
62            }
63        }
64        result
65    }
66
67    pub fn update(&mut self, raw_data: &[u8]) {
68        let new_data = self
69            .data
70            .iter()
71            .zip(raw_data.chunks(BYTES_PER_WORD))
72            .map(|(word, bytes)| {
73                bytes.iter().enumerate().fold(*word, |acc, (offset, byte)| {
74                    acc | (*byte as Word) << (offset * BYTES_PER_WORD)
75                })
76            })
77            .collect::<Vec<_>>();
78
79        self.data = new_data;
80    }
81
82    #[inline(always)]
83    pub fn len(&self) -> usize {
84        self.count
85    }
86
87    #[inline(always)]
88    pub fn max_value(&self) -> u8 {
89        self.max
90    }
91
92    pub fn reset(&mut self) {
93        self.data.iter_mut().for_each(|x| *x = 0)
94    }
95
96    pub fn increment(&mut self, bucket: usize, delta: i8) {
97        let v = (self.get(bucket) as i8).saturating_add(delta);
98        let value = if v < 0 {
99            0u8
100        } else if v > self.max as i8 {
101            self.max
102        } else {
103            v as u8
104        };
105        self.set(bucket, value);
106    }
107
108    pub fn set(&mut self, bucket: usize, byte: u8) {
109        let offset = bucket * self.bucket_size as usize;
110        let length = self.bucket_size as usize;
111        let word = if byte > self.max as u8 { self.max } else { byte } as Word;
112        self.set_word(offset, length, word);
113    }
114
115    pub fn get(&self, bucket: usize) -> u8 {
116        self.get_word(bucket * self.bucket_size as usize, self.bucket_size as usize) as u8
117    }
118
119    fn set_word(&mut self, offset: usize, length: usize, word: Word) {
120        let word_index = offset / BITS_PER_WORD;
121        let word_offset = offset % BITS_PER_WORD;
122
123        if word_offset + length > BITS_PER_WORD {
124            let remain = BITS_PER_WORD - word_offset;
125            self.set_word(offset, remain, word);
126            self.set_word(offset + remain, length - remain, word >> remain);
127        } else {
128            let bit_mask = (1 << length) - 1;
129            self.data[word_index] &= !(bit_mask << word_offset);
130            self.data[word_index] |= (word & bit_mask) << word_offset;
131        }
132    }
133
134    fn get_word(&self, offset: usize, length: usize) -> Word {
135        let word_index = offset / BITS_PER_WORD;
136        let word_offset = offset % BITS_PER_WORD;
137        if word_offset + length > BITS_PER_WORD {
138            let remain = BITS_PER_WORD - word_offset;
139            self.get_word(offset, remain) | (self.get_word(offset + remain, length - remain) << remain)
140        } else {
141            let bit_mask = (1 << length) - 1;
142            (self.data[word_index] & (bit_mask << word_offset)) >> word_offset
143        }
144    }
145}
146
/// `(ln 2)^2`, the denominator of the optimal bucket-count formula.
const LN_2_2: f64 = LN_2 * LN_2;

/// Returns the optimal number of buckets `m` for `items_count` items at the
/// desired false-positive rate `fp_rate`:
///
/// `m = ceil(n * |ln p| / (ln 2)^2)`
///
/// The preconditions `items_count > 0` and `0 < fp_rate < 1` are checked in
/// debug builds only.
fn compute_m_num(items_count: usize, fp_rate: f64) -> usize {
    debug_assert!(items_count > 0);
    debug_assert!(0.0 < fp_rate && fp_rate < 1.0);
    let n = items_count as f64;
    let buckets = n * fp_rate.ln().abs() / LN_2_2;
    buckets.ceil() as usize
}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// (bucket, value) pairs written by the 1-bit tests, in write order.
    const ONE_BIT: [(usize, u8); 4] = [(0, 1), (1, 0), (2, 1), (3, 0)];

    /// (bucket, value) pairs written by the 3-bit tests, in write order.
    const THREE_BIT: [(usize, u8); 6] = [(0, 1), (1, 2), (10, 3), (11, 4), (20, 5), (21, 6)];

    /// Builds `Buckets::new(count, bucket_size)` and applies `pairs` via `set`.
    fn filled(count: usize, bucket_size: u8, pairs: &[(usize, u8)]) -> Buckets {
        let mut buckets = Buckets::new(count, bucket_size);
        for &(bucket, value) in pairs {
            buckets.set(bucket, value);
        }
        buckets
    }

    /// Asserts that every `(bucket, value)` pair reads back from `buckets`.
    fn assert_holds(buckets: &Buckets, pairs: &[(usize, u8)]) {
        for &(bucket, value) in pairs {
            assert_eq!(value, buckets.get(bucket), "bucket {}", bucket);
        }
    }

    #[test]
    fn one_bit() {
        assert_holds(&filled(100, 1, &ONE_BIT), &ONE_BIT);
    }

    #[test]
    fn three_bits() {
        assert_holds(&filled(100, 3, &THREE_BIT), &THREE_BIT);
    }

    #[test]
    fn reset() {
        let mut buckets = filled(100, 1, &[(1, 1)]);
        assert_eq!(1, buckets.get(1));
        buckets.reset();
        assert_eq!(0, buckets.get(1));
    }

    #[test]
    fn increment() {
        // (delta, expected value afterwards) applied to one 3-bit bucket.
        let steps: [(i8, u8); 5] = [(2, 2), (1, 3), (100, 7), (-1, 6), (-10, 0)];
        let mut buckets = Buckets::new(100, 3);
        for &(delta, expected) in steps.iter() {
            buckets.increment(10, delta);
            assert_eq!(expected, buckets.get(10));
        }

        // test overflow: a 7-bit bucket saturates at 127.
        let mut wide = Buckets::new(3, 7);
        wide.increment(0, 127);
        assert_eq!(127, wide.get(0));
        wide.increment(0, 1);
        assert_eq!(127, wide.get(0));
    }

    #[test]
    fn with_raw_data() {
        let cases: [(u8, &[(usize, u8)]); 2] = [(1, &ONE_BIT), (3, &THREE_BIT)];
        for &(bucket_size, pairs) in cases.iter() {
            let raw_data = filled(100, bucket_size, pairs).raw_data();
            let restored = Buckets::with_raw_data(100, bucket_size, &raw_data);
            assert_holds(&restored, pairs);
        }
    }

    #[test]
    fn update() {
        let mut b1 = filled(100, 1, &[(0, 1), (20, 1), (63, 1)]);
        let b2 = filled(50, 1, &[(7, 1), (20, 1), (21, 1), (49, 1)]);

        b1.update(&b2.raw_data());
        assert_holds(&b1, &[(0, 1), (1, 0), (20, 1), (21, 1), (49, 1), (63, 1)]);
    }
}