Skip to main content

automerge/sync/
bloom.rs

1use std::borrow::Borrow;
2
3use crate::storage::parse;
4use crate::ChangeHash;
5
// Default Bloom filter parameters, chosen for roughly a 1% false-positive rate. Changing them does
// not break network-protocol compatibility: every encoded filter carries the parameters it was
// built with (`num_bits_per_entry` and `num_probes` are written into the wire format).

/// Number of filter bits allocated for each inserted hash.
const BITS_PER_ENTRY: u32 = 10;

/// Number of bit positions set on insert (and checked on lookup) for each hash.
const NUM_PROBES: u32 = 7;
11
/// A Bloom filter over change hashes.
///
/// Built with `BloomFilter::from_hashes`, queried with `contains_hash`, encoded with `to_bytes`
/// and decoded via `TryFrom<&[u8]>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize)]
pub struct BloomFilter {
    /// Number of hashes the filter was sized for. When `0`, `to_bytes` emits nothing and
    /// `contains_hash` always returns `false`.
    num_entries: u32,
    /// Bits allocated per entry; together with `num_entries` this determines the byte length of
    /// `bits`.
    num_bits_per_entry: u32,
    /// Number of bit positions set (on insert) and checked (on lookup) for each hash.
    num_probes: u32,
    /// The bit array: bit position `i` is stored as bit `i & 7` of byte `i >> 3`.
    bits: Vec<u8>,
}
19
20impl Default for BloomFilter {
21    fn default() -> Self {
22        BloomFilter {
23            num_entries: 0,
24            num_bits_per_entry: BITS_PER_ENTRY,
25            num_probes: NUM_PROBES,
26            bits: Vec::new(),
27        }
28    }
29}
30
/// Errors that can occur while parsing an encoded `BloomFilter`.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ParseError {
    /// Wraps a `parse::leb128::Error`, e.g. from decoding one of the LEB128 header fields; its
    /// message is forwarded unchanged (`#[error(transparent)]`).
    #[error(transparent)]
    Leb128(#[from] parse::leb128::Error),
}
36
37impl BloomFilter {
38    pub fn to_bytes(&self) -> Vec<u8> {
39        let mut buf = Vec::new();
40        if self.num_entries != 0 {
41            leb128::write::unsigned(&mut buf, self.num_entries as u64).unwrap();
42            leb128::write::unsigned(&mut buf, self.num_bits_per_entry as u64).unwrap();
43            leb128::write::unsigned(&mut buf, self.num_probes as u64).unwrap();
44            buf.extend(&self.bits);
45        }
46        buf
47    }
48
49    pub(crate) fn parse(input: parse::Input<'_>) -> parse::ParseResult<'_, Self, ParseError> {
50        if input.is_empty() {
51            Ok((input, Self::default()))
52        } else {
53            let (i, num_entries) = parse::leb128_u32(input)?;
54            let (i, num_bits_per_entry) = parse::leb128_u32(i)?;
55            let (i, num_probes) = parse::leb128_u32(i)?;
56            let (i, bits) = parse::take_n(bits_capacity(num_entries, num_bits_per_entry), i)?;
57            Ok((
58                i,
59                Self {
60                    num_entries,
61                    num_bits_per_entry,
62                    num_probes,
63                    bits: bits.to_vec(),
64                },
65            ))
66        }
67    }
68
69    fn get_probes(&self, hash: &ChangeHash) -> Vec<u32> {
70        let hash_bytes = hash.0;
71        let modulo = 8 * self.bits.len() as u32;
72
73        let mut x =
74            u32::from_le_bytes([hash_bytes[0], hash_bytes[1], hash_bytes[2], hash_bytes[3]])
75                % modulo;
76        let mut y =
77            u32::from_le_bytes([hash_bytes[4], hash_bytes[5], hash_bytes[6], hash_bytes[7]])
78                % modulo;
79        let z = u32::from_le_bytes([hash_bytes[8], hash_bytes[9], hash_bytes[10], hash_bytes[11]])
80            % modulo;
81
82        let mut probes = Vec::with_capacity(self.num_probes as usize);
83        probes.push(x);
84        for _ in 1..self.num_probes {
85            x = (x + y) % modulo;
86            y = (y + z) % modulo;
87            probes.push(x);
88        }
89        probes
90    }
91
92    fn add_hash(&mut self, hash: &ChangeHash) {
93        for probe in self.get_probes(hash) {
94            self.set_bit(probe as usize);
95        }
96    }
97
98    fn set_bit(&mut self, probe: usize) {
99        if let Some(byte) = self.bits.get_mut(probe >> 3) {
100            *byte |= 1 << (probe & 7);
101        }
102    }
103
104    fn get_bit(&self, probe: usize) -> Option<u8> {
105        self.bits
106            .get(probe >> 3)
107            .map(|byte| byte & (1 << (probe & 7)))
108    }
109
110    #[inline(never)]
111    pub fn contains_hash(&self, hash: &ChangeHash) -> bool {
112        if self.num_entries == 0 {
113            false
114        } else {
115            for probe in self.get_probes(hash) {
116                if let Some(bit) = self.get_bit(probe as usize) {
117                    if bit == 0 {
118                        return false;
119                    }
120                }
121            }
122            true
123        }
124    }
125
126    pub fn from_hashes<H: Borrow<ChangeHash>>(hashes: impl ExactSizeIterator<Item = H>) -> Self {
127        let num_entries = hashes.len() as u32;
128        let num_bits_per_entry = BITS_PER_ENTRY;
129        let num_probes = NUM_PROBES;
130        let bits = vec![0; bits_capacity(num_entries, num_bits_per_entry)];
131        let mut filter = Self {
132            num_entries,
133            num_bits_per_entry,
134            num_probes,
135            bits,
136        };
137        for hash in hashes {
138            filter.add_hash(hash.borrow());
139        }
140        filter
141    }
142}
143
/// Returns the number of bytes needed to hold `num_entries * num_bits_per_entry` bits, rounded
/// up to a whole byte.
///
/// Uses exact integer arithmetic: the product of two `u32`s always fits in a `u64`, whereas a
/// floating-point computation loses precision once the product exceeds 2^53. A result that does
/// not fit in `usize` saturates to `usize::MAX`.
fn bits_capacity(num_entries: u32, num_bits_per_entry: u32) -> usize {
    let total_bits = u64::from(num_entries) * u64::from(num_bits_per_entry);
    // Cannot overflow: total_bits <= (2^32 - 1)^2 = 2^64 - 2^33 + 1, far below u64::MAX - 7.
    let total_bytes = (total_bits + 7) / 8;
    usize::try_from(total_bytes).unwrap_or(usize::MAX)
}
148
/// Error returned when decoding a `BloomFilter` from bytes via `TryFrom<&[u8]>` fails.
///
/// Holds the rendered (`to_string`) message of the underlying parse error.
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct DecodeError(String);
152
153impl TryFrom<&[u8]> for BloomFilter {
154    type Error = DecodeError;
155
156    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
157        Self::parse(parse::Input::new(bytes))
158            .map(|(_, b)| b)
159            .map_err(|e| DecodeError(e.to_string()))
160    }
161}