1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
extern crate bitvec;
extern crate byteorder;
extern crate digest;
extern crate murmurhash3;
extern crate rand;

use bitvec::{bitvec, BitVec, LittleEndian};
use byteorder::ReadBytesExt;
use murmurhash3::murmurhash3_x86_32;

use std::io::{Error, ErrorKind, Read};

/// A single Bloom filter layer.
///
/// Bits live in `bitvec`; an item's bit positions come from MurmurHash3
/// (x86, 32-bit) seeded with the hash-function index and `level`
/// (see `Bloom::hash`).
#[derive(Debug)]
pub struct Bloom {
    // Layer index; mixed into every hash seed, so layers with different
    // levels map the same key to different bits.
    level: u32,
    // Number of hash functions (bit positions touched) per item.
    n_hash_funcs: u32,
    // Number of addressable bits; hashes are reduced modulo this value.
    size: usize,
    // Backing bit storage. When loaded via `from_bytes` it is rounded up to
    // whole bytes, so it may be longer than `size`.
    bitvec: BitVec<bitvec::LittleEndian>,
}

/// Returns how many hash functions a Bloom filter needs to reach `error_rate`.
///
/// This is `ceil(log2(1 / error_rate))`, with the base-2 logarithm computed
/// as `ln(x) / ln(2)`.
pub fn calculate_n_hash_funcs(error_rate: f32) -> u32 {
    let inverse_rate = 1.0 / error_rate;
    let ln_two = (2.0_f32).ln();
    let log2_inverse = inverse_rate.ln() / ln_two;
    log2_inverse.ceil() as u32
}

/// Returns the number of bits a Bloom filter needs to hold `elements` items
/// at the given `error_rate`.
///
/// With `k = calculate_n_hash_funcs(error_rate)`, `n = elements` and
/// `p = error_rate`, the result is `ceil(1 - k * (n + 0.5) / ln(1 - p^(1/k)))`.
pub fn calculate_size(elements: usize, error_rate: f32) -> usize {
    let k = calculate_n_hash_funcs(error_rate) as f32;
    let numerator = k * (elements as f32 + 0.5);
    let denominator = (1.0_f32 - error_rate.powf(1.0 / k)).ln();
    let bits = 1.0_f32 - numerator / denominator;
    bits.ceil() as usize
}

impl Bloom {
    pub fn new(size: usize, n_hash_funcs: u32, level: u32) -> Bloom {
        let bitvec: BitVec<LittleEndian> = bitvec![LittleEndian; 0; size];

        Bloom {
            level: level,
            n_hash_funcs: n_hash_funcs,
            size: size,
            bitvec: bitvec,
        }
    }

    pub fn from_bytes(cursor: &mut &[u8]) -> Result<Bloom, Error> {
        // Load the layer metadata. bloomer.py writes size, nHashFuncs and level as little-endian
        // unsigned ints.
        let size = cursor.read_u32::<byteorder::LittleEndian>()? as usize;
        let n_hash_funcs = cursor.read_u32::<byteorder::LittleEndian>()?;
        let level = cursor.read_u32::<byteorder::LittleEndian>()?;

        let shifted_size = size.wrapping_shr(3);
        let byte_count = if size % 8 != 0 {
            shifted_size + 1
        } else {
            shifted_size
        };

        let mut bitvec_buf = vec![0u8; byte_count];
        cursor.read_exact(&mut bitvec_buf)?;

        Ok(Bloom {
            level,
            n_hash_funcs,
            size,
            bitvec: bitvec_buf.into(),
        })
    }

    fn hash(&self, n_fn: u32, key: &[u8]) -> usize {
        let hash_seed = (n_fn << 16) + self.level;
        let h = murmurhash3_x86_32(key, hash_seed) as usize % self.size;
        h
    }

    pub fn put(&mut self, item: &[u8]) {
        for i in 0..self.n_hash_funcs {
            let index = self.hash(i, item);
            self.bitvec.set(index, true);
        }
    }

    pub fn has(&self, item: &[u8]) -> bool {
        for i in 0..self.n_hash_funcs {
            match self.bitvec.get(self.hash(i, item)) {
                Some(false) => return false,
                Some(true) => (),
                None => panic!(
                    "access outside the bloom filter bit vector (this is almost certainly a bug)"
                ),
            }
        }

        true
    }

    pub fn clear(&mut self) {
        self.bitvec.clear()
    }
}

/// A filter cascade: a chain of Bloom filter layers in which each child layer
/// is built from the false positives of its parent (see `Cascade::initialize`).
#[derive(Debug)]
pub struct Cascade {
    // Bloom filter for this layer.
    filter: Bloom,
    // Next layer down, or `None` if this is the last layer.
    child_layer: Option<Box<Cascade>>,
}

impl Cascade {
    /// Creates a single-layer cascade (level 1) whose filter has `size` bits
    /// and `n_hash_funcs` hash functions.
    pub fn new(size: usize, n_hash_funcs: u32) -> Cascade {
        Cascade::new_layer(size, n_hash_funcs, 1)
    }

    /// Parses a serialized cascade.
    ///
    /// Each layer is a little-endian `u16` version (must be 1) followed by a
    /// layer in the `Bloom::from_bytes` format. Layers follow one another
    /// until the input is exhausted. Empty input yields `Ok(None)`. Parsing
    /// recurses once per layer.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` for an unsupported version. Any error from
    /// reading a truncated version header or layer is passed through.
    pub fn from_bytes(bytes: &[u8]) -> Result<Option<Box<Cascade>>, Error> {
        if bytes.is_empty() {
            return Ok(None);
        }
        let mut cursor = bytes;
        let version = cursor.read_u16::<byteorder::LittleEndian>()?;
        if version != 1 {
            return Err(Error::new(ErrorKind::InvalidInput, "Invalid version"));
        }
        // Parse this layer first; `cursor` then points at the next layer.
        let filter = Bloom::from_bytes(&mut cursor)?;
        let child_layer = Cascade::from_bytes(cursor)?;
        Ok(Some(Box::new(Cascade {
            filter,
            child_layer,
        })))
    }

    /// Creates a childless cascade node whose filter uses layer index `layer`.
    fn new_layer(size: usize, n_hash_funcs: u32, layer: u32) -> Cascade {
        Cascade {
            filter: Bloom::new(size, n_hash_funcs, layer),
            child_layer: None,
        }
    }

    /// Builds the cascade so that `has` is true for every item of `entries`
    /// and false for every item of `exclusions`.
    ///
    /// All `entries` go into this layer. Exclusions that then test positive
    /// are false positives, and they seed a child layer (sized for error
    /// rate 0.5) that is initialized with the two sets' roles swapped.
    /// Recursion stops once a layer produces no false positives.
    ///
    /// NOTE(review): if an item appears in both `entries` and `exclusions`,
    /// every layer reproduces it as a false positive and this never
    /// terminates. Callers must keep the two sets disjoint.
    pub fn initialize(&mut self, entries: Vec<Vec<u8>>, exclusions: Vec<Vec<u8>>) {
        for entry in &entries {
            self.filter.put(entry);
        }

        let false_positives: Vec<Vec<u8>> = exclusions
            .into_iter()
            .filter(|excluded| self.filter.has(excluded))
            .collect();

        if false_positives.is_empty() {
            return;
        }

        let n_hash_funcs = calculate_n_hash_funcs(0.5);
        let size = calculate_size(false_positives.len(), 0.5);
        let mut child = Box::new(Cascade::new_layer(
            size,
            n_hash_funcs,
            self.filter.level + 1,
        ));
        child.initialize(false_positives, entries);
        self.child_layer = Some(child);
    }

    /// Returns whether `entry` is in the set the cascade was built from.
    ///
    /// A hit in this layer is confirmed only if the child layer (which holds
    /// this layer's false positives) does not also report it.
    pub fn has(&self, entry: &[u8]) -> bool {
        if !self.filter.has(entry) {
            return false;
        }
        match self.child_layer {
            Some(ref child) => !child.has(entry),
            None => true,
        }
    }

    /// Returns `true` if every item of `entries` is reported present and no
    /// item of `exclusions` is.
    pub fn check(&self, entries: Vec<Vec<u8>>, exclusions: Vec<Vec<u8>>) -> bool {
        entries.iter().all(|entry| self.has(entry))
            && !exclusions.iter().any(|excluded| self.has(excluded))
    }
}

#[cfg(test)]
mod tests {
    use calculate_n_hash_funcs;
    use calculate_size;
    use rand::Rng;
    use Bloom;
    use Cascade;

    /// Builds an empty level-0 Bloom filter sized for `elements` items at
    /// `error_rate`.
    fn sized_bloom(elements: usize, error_rate: f32) -> Bloom {
        let hash_count = calculate_n_hash_funcs(error_rate);
        let bit_count = calculate_size(elements, error_rate);
        Bloom::new(bit_count, hash_count, 0)
    }

    #[test]
    fn bloom_test_bloom_size() {
        let bloom = sized_bloom(1024, 0.01);
        assert_eq!(bloom.bitvec.len(), 9829);
    }

    #[test]
    fn bloom_test_put() {
        let mut bloom = sized_bloom(1024, 0.01);
        let key: &[u8] = b"foo";

        bloom.put(key);
    }

    #[test]
    fn bloom_test_has() {
        let mut bloom = sized_bloom(1024, 0.01);
        let key: &[u8] = b"foo";

        bloom.put(key);
        assert!(bloom.has(key));
        assert!(!bloom.has(b"bar"));
    }

    #[test]
    fn bloom_test_from_bytes() {
        // Header: size = 9, n_hash_funcs = 1, level = 1, then 2 bytes of bits.
        let src: Vec<u8> = vec![
            0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x41, 0x00,
        ];

        let mut bloom = Bloom::from_bytes(&mut &src[..]).expect("Parsing failed");
        assert!(bloom.has(b"this"));
        assert!(bloom.has(b"that"));
        assert!(!bloom.has(b"other"));

        bloom.put(b"other");
        assert!(bloom.has(b"other"));

        // Same header, but one byte of bit data is missing.
        let short: Vec<u8> = vec![
            0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x41,
        ];
        assert!(
            Bloom::from_bytes(&mut &short[..]).is_err(),
            "Parsing should fail; data is truncated"
        );
    }

    #[test]
    fn bloom_test_from_file() {
        let v = include_bytes!("../test_data/test_bf");
        let bloom = Bloom::from_bytes(&mut &v[..]).expect("parsing Bloom should succeed");
        assert!(bloom.has(b"this"));
        assert!(bloom.has(b"that"));
        assert!(!bloom.has(b"yet another test"));
    }

    #[test]
    fn cascade_test() {
        let mut rng = rand::thread_rng();

        // Entries are the decimal strings "0".."499"; 100 of them, picked at
        // random, are moved out to become the exclusions.
        let mut foo: Vec<Vec<u8>> = (0..500).map(|i| i.to_string().into_bytes()).collect();
        let bar: Vec<Vec<u8>> = (0..100)
            .map(|_| {
                let idx = rng.gen_range(0, foo.len());
                foo.swap_remove(idx)
            })
            .collect();

        let error_rate = 0.5;
        let elements = 500;
        let n_hash_funcs = calculate_n_hash_funcs(error_rate);
        let size = calculate_size(elements, error_rate);

        let mut cascade = Cascade::new(size, n_hash_funcs);
        cascade.initialize(foo.clone(), bar.clone());

        assert!(cascade.check(foo, bar));
    }

    #[test]
    fn cascade_from_file_bytes_test() {
        let v = include_bytes!("../test_data/test_mlbf");
        let cascade = Cascade::from_bytes(v)
            .expect("parsing Cascade should succeed")
            .expect("Cascade should be Some");

        let present: [&[u8]; 3] = [b"test", b"another test", b"yet another test"];
        for key in present.iter() {
            assert!(cascade.has(key), "expected {:?} to be present", key);
        }

        let absent: [&[u8]; 3] = [b"blah", b"blah blah", b"blah blah blah"];
        for key in absent.iter() {
            assert!(!cascade.has(key), "expected {:?} to be absent", key);
        }

        let v = include_bytes!("../test_data/test_short_mlbf");
        Cascade::from_bytes(v).expect_err("parsing truncated Cascade should fail");
    }
}