1use std::f64::consts::LN_2;
2use std::mem::size_of;
3use std::ptr::copy_nonoverlapping;
4
/// Storage unit backing the packed bucket array.
type Word = u64;
/// Number of bytes in one `Word`.
const BYTES_PER_WORD: usize = size_of::<Word>();
/// Number of bits in one `Word`.
const BITS_PER_WORD: usize = BYTES_PER_WORD * 8;

/// A fixed-length array of small unsigned counters ("buckets") packed back to
/// back into 64-bit words.
///
/// Bucket `i` occupies bits `i * bucket_size .. (i + 1) * bucket_size` of the
/// bit stream formed by `data`, so a single bucket may straddle two adjacent
/// words.
pub struct Buckets {
    /// Packed bucket bits; bit 0 of the stream is the least-significant bit
    /// of `data[0]`.
    data: Vec<Word>,
    /// Number of buckets.
    count: usize,
    /// Width of every bucket in bits (debug-asserted to be below 8).
    bucket_size: u8,
    /// Largest value one bucket can hold, i.e. `2^bucket_size - 1`.
    max: u8,
}

impl Buckets {
    /// Creates zeroed buckets whose count is derived from the expected number
    /// of items and the target false-positive rate.
    ///
    /// The bucket count is `ceil(items_count * |ln(fp_rate)| / ln(2)^2)`, the
    /// usual Bloom-filter sizing formula.
    pub fn with_fp_rate(items_count: usize, fp_rate: f64, bucket_size: u8) -> Self {
        Self::new(compute_m_num(items_count, fp_rate), bucket_size)
    }

    /// Creates `count` zeroed buckets, each `bucket_size` bits wide.
    ///
    /// `bucket_size` must be below 8; this is only checked in debug builds.
    pub fn new(count: usize, bucket_size: u8) -> Self {
        debug_assert!(bucket_size < 8);
        Self {
            data: vec![0; Self::word_count(count, bucket_size)],
            count,
            bucket_size,
            max: Self::max_for(bucket_size),
        }
    }

    /// Rebuilds buckets from bytes previously produced by [`Buckets::raw_data`].
    ///
    /// `raw_data` is read as consecutive little-endian words. In debug builds
    /// its length must match exactly what `count` buckets of `bucket_size`
    /// bits require. A trailing chunk shorter than one word is zero-extended
    /// instead of being read past its end.
    pub fn with_raw_data(count: usize, bucket_size: u8, raw_data: &[u8]) -> Self {
        debug_assert!(bucket_size < 8);
        debug_assert!(Self::word_count(count, bucket_size) * BYTES_PER_WORD == raw_data.len());

        Self {
            data: raw_data
                .chunks(BYTES_PER_WORD)
                .map(Self::word_from_le)
                .collect(),
            count,
            bucket_size,
            max: Self::max_for(bucket_size),
        }
    }

    /// Serialises the packed words as little-endian bytes, independent of the
    /// host byte order.
    pub fn raw_data(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(self.data.len() * BYTES_PER_WORD);
        for word in &self.data {
            bytes.extend_from_slice(&word.to_le_bytes());
        }
        bytes
    }

    /// Bitwise-ORs `raw_data` (little-endian words, as produced by
    /// [`Buckets::raw_data`]) into the existing storage.
    ///
    /// Words past the end of `raw_data` are left untouched and input beyond
    /// the end of the storage is ignored, so the number of stored words never
    /// changes and every bucket stays addressable afterwards.
    pub fn update(&mut self, raw_data: &[u8]) {
        for (word, chunk) in self.data.iter_mut().zip(raw_data.chunks(BYTES_PER_WORD)) {
            *word |= Self::word_from_le(chunk);
        }
    }

    /// Returns the number of buckets.
    #[inline]
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when there are no buckets at all.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the largest value a single bucket can hold.
    #[inline]
    pub fn max_value(&self) -> u8 {
        self.max
    }

    /// Sets every bucket back to zero.
    pub fn reset(&mut self) {
        for word in &mut self.data {
            *word = 0;
        }
    }

    /// Adds `delta` to `bucket`, saturating at `0` and at [`Buckets::max_value`].
    ///
    /// # Panics
    ///
    /// Panics if the bucket's bits lie outside the backing storage.
    pub fn increment(&mut self, bucket: usize, delta: i8) {
        // Widen to i16 so neither the stored value nor the sum can wrap.
        let target = i16::from(self.get(bucket)) + i16::from(delta);
        let clamped = target.max(0).min(i16::from(self.max));
        self.set(bucket, clamped as u8);
    }

    /// Stores `byte` in `bucket`; values above [`Buckets::max_value`] are
    /// clamped to it.
    ///
    /// # Panics
    ///
    /// Panics if the bucket's bits lie outside the backing storage.
    pub fn set(&mut self, bucket: usize, byte: u8) {
        let width = usize::from(self.bucket_size);
        self.set_word(bucket * width, width, Word::from(byte.min(self.max)));
    }

    /// Reads the value stored in `bucket`.
    ///
    /// # Panics
    ///
    /// Panics if the bucket's bits lie outside the backing storage.
    pub fn get(&self, bucket: usize) -> u8 {
        let width = usize::from(self.bucket_size);
        self.get_word(bucket * width, width) as u8
    }

    /// Number of words needed to hold `count` buckets of `bucket_size` bits.
    fn word_count(count: usize, bucket_size: u8) -> usize {
        (count * usize::from(bucket_size) + BITS_PER_WORD - 1) / BITS_PER_WORD
    }

    /// Largest value representable in `bucket_size` bits.
    fn max_for(bucket_size: u8) -> u8 {
        (1u8 << bucket_size) - 1
    }

    /// Decodes up to `BYTES_PER_WORD` little-endian bytes into a word,
    /// treating missing high bytes as zero.
    fn word_from_le(chunk: &[u8]) -> Word {
        let mut bytes = [0u8; BYTES_PER_WORD];
        bytes[..chunk.len()].copy_from_slice(chunk);
        Word::from_le_bytes(bytes)
    }

    /// Writes the low `length` bits of `word` at bit position `offset`.
    fn set_word(&mut self, offset: usize, length: usize, word: Word) {
        let index = offset / BITS_PER_WORD;
        let shift = offset % BITS_PER_WORD;
        let available = BITS_PER_WORD - shift;

        if length > available {
            // The field straddles two words: the low `available` bits go into
            // this word, the remaining high bits start the next one.
            self.set_word(offset, available, word);
            self.set_word(offset + available, length - available, word >> available);
            return;
        }

        let mask: Word = (1 << length) - 1;
        let slot = &mut self.data[index];
        *slot = (*slot & !(mask << shift)) | ((word & mask) << shift);
    }

    /// Reads `length` bits starting at bit position `offset`.
    fn get_word(&self, offset: usize, length: usize) -> Word {
        let index = offset / BITS_PER_WORD;
        let shift = offset % BITS_PER_WORD;
        let available = BITS_PER_WORD - shift;

        if length > available {
            // Stitch the low part from this word onto the high part taken
            // from the start of the next word.
            let low = self.get_word(offset, available);
            let high = self.get_word(offset + available, length - available);
            return low | (high << available);
        }

        let mask: Word = (1 << length) - 1;
        (self.data[index] >> shift) & mask
    }
}

/// `ln(2)^2`, the denominator of the sizing formula in `compute_m_num`.
const LN_2_2: f64 = LN_2 * LN_2;

/// Returns `ceil(items_count * |ln(fp_rate)| / ln(2)^2)`.
///
/// In debug builds `items_count` must be non-zero and `fp_rate` must lie
/// strictly between 0 and 1.
fn compute_m_num(items_count: usize, fp_rate: f64) -> usize {
    debug_assert!(items_count > 0);
    debug_assert!(fp_rate > 0.0 && fp_rate < 1.0);
    let items = items_count as f64;
    let exact = items * fp_rate.ln().abs() / LN_2_2;
    exact.ceil() as usize
}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes every `(bucket, value)` pair into `buckets`, in order.
    fn fill(buckets: &mut Buckets, entries: &[(usize, u8)]) {
        for &(bucket, value) in entries {
            buckets.set(bucket, value);
        }
    }

    /// Asserts that every `(bucket, value)` pair reads back from `buckets`.
    fn check(buckets: &Buckets, entries: &[(usize, u8)]) {
        for &(bucket, value) in entries {
            assert_eq!(value, buckets.get(bucket), "bucket {}", bucket);
        }
    }

    /// Applies each delta to `bucket` and checks the value right after it.
    fn step(buckets: &mut Buckets, bucket: usize, steps: &[(i8, u8)]) {
        for &(delta, expected) in steps {
            buckets.increment(bucket, delta);
            assert_eq!(expected, buckets.get(bucket), "after delta {}", delta);
        }
    }

    const ONE_BIT_ENTRIES: [(usize, u8); 4] = [(0, 1), (1, 0), (2, 1), (3, 0)];
    const THREE_BIT_ENTRIES: [(usize, u8); 6] =
        [(0, 1), (1, 2), (10, 3), (11, 4), (20, 5), (21, 6)];

    #[test]
    fn one_bit() {
        let mut buckets = Buckets::new(100, 1);
        fill(&mut buckets, &ONE_BIT_ENTRIES);
        check(&buckets, &ONE_BIT_ENTRIES);
    }

    #[test]
    fn three_bits() {
        // Bucket 21 starts at bit 63, so it exercises the word-straddling path.
        let mut buckets = Buckets::new(100, 3);
        fill(&mut buckets, &THREE_BIT_ENTRIES);
        check(&buckets, &THREE_BIT_ENTRIES);
    }

    #[test]
    fn reset() {
        let mut buckets = Buckets::new(100, 1);
        buckets.set(1, 1);
        assert_eq!(1, buckets.get(1));

        buckets.reset();
        assert_eq!(0, buckets.get(1));
    }

    #[test]
    fn increment() {
        // 3-bit buckets saturate at 7 and never drop below 0.
        let mut buckets = Buckets::new(100, 3);
        step(
            &mut buckets,
            10,
            &[(2, 2), (1, 3), (100, 7), (-1, 6), (-10, 0)],
        );

        // 7-bit buckets top out at 127, the widest supported size.
        let mut buckets = Buckets::new(3, 7);
        step(&mut buckets, 0, &[(127, 127), (1, 127)]);
    }

    #[test]
    fn with_raw_data() {
        let mut original = Buckets::new(100, 1);
        fill(&mut original, &ONE_BIT_ENTRIES);
        let restored = Buckets::with_raw_data(100, 1, &original.raw_data());
        check(&restored, &ONE_BIT_ENTRIES);

        let mut original = Buckets::new(100, 3);
        fill(&mut original, &THREE_BIT_ENTRIES);
        let restored = Buckets::with_raw_data(100, 3, &original.raw_data());
        check(&restored, &THREE_BIT_ENTRIES);
    }

    #[test]
    fn update() {
        let mut b1 = Buckets::new(100, 1);
        fill(&mut b1, &[(0, 1), (20, 1), (63, 1)]);

        let mut b2 = Buckets::new(50, 1);
        fill(&mut b2, &[(7, 1), (20, 1), (21, 1), (49, 1)]);

        b1.update(&b2.raw_data());
        check(
            &b1,
            &[(0, 1), (1, 0), (20, 1), (21, 1), (49, 1), (63, 1)],
        );
    }
}