Skip to main content

do_riblt/
lib.rs

//! An efficient implementation of Rateless Invertible Bloom Lookup Tables (RIBLTs).
//! The paper can be found here: <https://arxiv.org/abs/2402.02668>
//!
//! This crate can be used to efficiently synchronize sets between two devices.
//! Why use RIBLTs?
//! 1. They are rateless: an encoder can produce an unbounded stream of coded symbols.
//! 2. They are universal: you don't need to know anything about the other side (not even the
//!    size of the set difference) to begin.
//! 3. They have a low communication cost: only around 1.3 to 1.7 times as many symbols as there
//!    are items in the set difference need to be sent.
//! 4. They have a low computation cost: encoding and decoding use mostly XOR operations.
10
11use rapidhash::{
12    rng::RapidRng,
13    v3::{rapidhash_v3_seeded, DEFAULT_RAPID_SECRETS},
14};
15use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
16
17#[inline]
18fn rapidhash(data: &[u8]) -> u64 {
19    rapidhash_v3_seeded(data, &DEFAULT_RAPID_SECRETS)
20}
21
/// A trait to implement on everything that can be an item for an [`Encoder`] or a [`Decoder`].
///
/// Items are identified purely by their byte representation. Two items with the same bytes are
/// the same item, and those bytes are hashed and XOR-ed into coded symbols.
///
/// Both peers must therefore produce the same bytes for the same item (e.g. use a fixed
/// endianness rather than the native one). [`Symbol::from_bytes`] must also invert
/// [`Symbol::to_bytes`], because the decoder rebuilds peeled items from their bytes.
///
/// # Examples
///
/// A `u64` identifier wrapped in a newtype. The orphan rule forbids implementing this trait for
/// `u64` itself outside of this crate.
///
/// ```
/// use do_riblt::Symbol;
///
/// struct Id(u64);
///
/// impl Symbol<8> for Id {
///     fn to_bytes(&self) -> [u8; 8] {
///         self.0.to_be_bytes()
///     }
///
///     fn from_bytes(bytes: &[u8; 8]) -> Self {
///         Self(u64::from_be_bytes(*bytes))
///     }
/// }
///
/// assert_eq!(Id::from_bytes(&Id(42).to_bytes()).0, 42);
/// ```
pub trait Symbol<const N: usize> {
    /// Turns the value into its fixed-size byte representation.
    fn to_bytes(&self) -> [u8; N];

    /// Reconstructs a value from bytes produced by [`Symbol::to_bytes`].
    fn from_bytes(bytes: &[u8; N]) -> Self;
}

/// Raw byte arrays are their own byte representation.
impl<const N: usize> Symbol<N> for [u8; N] {
    fn to_bytes(&self) -> [u8; N] {
        *self
    }

    fn from_bytes(bytes: &[u8; N]) -> Self {
        *bytes
    }
}
42
/// A symbol to send from an encoder to a decoder.
///
/// An encoder symbol holds the XOR of the byte representations of every item mapped to its
/// index, together with the XOR of those items' hashes.
///
/// To transport a symbol, send [`CodedSymbol::data`] and [`CodedSymbol::hash`], then rebuild it
/// on the receiving side with [`CodedSymbol::new`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodedSymbol<const N: usize> {
    data: [u8; N],
    hash: u64,
}

impl<const N: usize> CodedSymbol<N> {
    /// Rebuilds a coded symbol from its parts, e.g. after receiving it from a remote peer.
    pub fn new(data: [u8; N], hash: u64) -> Self {
        Self { data, hash }
    }

    /// The XOR of the byte representations of the items mapped to this symbol.
    pub fn data(&self) -> &[u8; N] {
        &self.data
    }

    /// The XOR of the hashes of the items mapped to this symbol.
    pub fn hash(&self) -> u64 {
        self.hash
    }
}
49
50/// The encoder to calculate `CodedSymbol`s from a set of items
51pub struct Encoder<const N: usize> {
52    next_index: u32,
53    /// bytes -> (rng, next_index)
54    state: HashMap<[u8; N], (rapidhash::rng::RapidRng, u32)>,
55}
56
57impl<const N: usize> Encoder<N> {
58    /// Create a new decoder from an iterator of items
59    pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
60        let state = iter
61            .map(|value| {
62                let bytes = value.to_bytes();
63                let hash = rapidhash(&bytes);
64                let rng = RapidRng::new(hash);
65                (bytes, (rng, 0))
66            })
67            .collect();
68        Self {
69            next_index: 0,
70            state,
71        }
72    }
73
74    /// Calculate the next `CodedSymbol`
75    pub fn next_symbol(&mut self) -> CodedSymbol<N> {
76        let mut data = [0u8; N];
77        let mut hash = 0u64;
78        for (bytes, (r, i)) in self
79            .state
80            .iter_mut()
81            .filter(|(_, (_, i))| *i == self.next_index)
82        {
83            data.iter_mut().zip(bytes).for_each(|(a, b)| *a ^= *b);
84            hash ^= rapidhash(bytes);
85            Self::update_index(r, i);
86        }
87        self.next_index += 1;
88        CodedSymbol { data, hash }
89    }
90
91    fn add_symbol(&mut self, data: [u8; N]) -> Vec<usize> {
92        let hash = rapidhash(&data);
93        let mut rng = RapidRng::new(hash);
94        let mut index = 0;
95        let mut indices = Vec::new();
96        while index < self.next_index {
97            indices.push(index as usize);
98            Self::update_index(&mut rng, &mut index);
99        }
100        self.state.insert(data, (rng, index));
101        indices
102    }
103
104    fn remove_symbol(&mut self, data: &[u8; N]) -> Vec<usize> {
105        let hash = rapidhash(data);
106        let mut rng = RapidRng::new(hash);
107        let mut index = 0;
108        let mut indices = Vec::new();
109        while index < self.next_index {
110            indices.push(index as usize);
111            Self::update_index(&mut rng, &mut index);
112        }
113        self.state.remove(data);
114        indices
115    }
116
117    fn update_index(r: &mut RapidRng, i: &mut u32) {
118        //  Stolen from https://github.com/samWighton/rateless_iblt/blob/main/src/mapping.rs
119        const TP32: f64 = (1u64 << 32) as f64;
120        let diff = (*i as f64 + 1.5) * (TP32 / (r.next() as f64 + 1.0).sqrt() - 1.0);
121        *i += diff.ceil() as u32;
122    }
123
124    fn contains(&self, data: &[u8; N]) -> bool {
125        self.state.contains_key(data)
126    }
127}
128
129impl<const N: usize> Iterator for Encoder<N> {
130    type Item = CodedSymbol<N>;
131
132    /// Returns the next symbol
133    fn next(&mut self) -> Option<Self::Item> {
134        Some(self.next_symbol())
135    }
136}
137
138/// A cached version of the encoder
139/// Useful to synchronize one set with multiple remotes at once (without calculating the symbols multiple times)
140pub struct CachedEncoder<const N: usize> {
141    encoder: Encoder<N>,
142    cache: Vec<Option<Box<CodedSymbol<N>>>>,
143}
144
145impl<const N: usize> CachedEncoder<N> {
146    const EMPTY_SYMBOL: CodedSymbol<N> = CodedSymbol {
147        data: [0u8; N],
148        hash: 0,
149    };
150
151    /// Create a new decoder from an iterator of items
152    pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
153        Self {
154            encoder: Encoder::new(iter),
155            cache: Vec::new(),
156        }
157    }
158
159    /// Calculate all symbols up to `index` and cache them
160    pub fn get(&mut self, index: usize) -> CodedSymbol<N> {
161        loop {
162            match self.cache.get(index).cloned() {
163                None => self.cache.push(match self.encoder.next_symbol() {
164                    s if s != Self::EMPTY_SYMBOL => Some(Box::new(s)),
165                    _ => None,
166                }),
167                Some(val) => break val.as_deref().cloned().unwrap_or(Self::EMPTY_SYMBOL),
168            }
169        }
170    }
171}
172
/// An item recovered by the [`Decoder`], tagged with the side that is missing it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Peeled<T> {
    /// The item is in the remote set but not in the local set.
    MissingLocal(T),
    /// The item is in the local set but not in the remote set.
    MissingRemote(T),
}

impl<T> Peeled<T> {
    /// Returns the recovered item, discarding which side is missing it.
    pub fn into_inner(self) -> T {
        match self {
            Self::MissingLocal(item) | Self::MissingRemote(item) => item,
        }
    }
}
179
180/// The decoder to calculate set differences
181pub struct Decoder<const N: usize, T: Symbol<N>> {
182    encoder: Encoder<N>,
183    symbols: Vec<CodedSymbol<N>>,
184    done: bool,
185    _marker: PhantomData<T>,
186}
187
188impl<'a, const N: usize, T: Symbol<N>> Decoder<N, T> {
189    /// Create a new decoder with a local iterator of items (our set)
190    pub fn new(local: impl Iterator<Item = T>) -> Self {
191        Self {
192            encoder: Encoder::new(local),
193            symbols: Vec::new(),
194            done: false,
195            _marker: PhantomData,
196        }
197    }
198
199    /// Consumes a received symbol and tries to peel as many items as possible.\
200    /// Returns `true` when done and `false` if there are more items to peel
201    pub fn next_symbol(&mut self, symbol: CodedSymbol<N>) -> (bool, Vec<Peeled<T>>) {
202        if self.done {
203            return (true, vec![]);
204        }
205        let mut local = self.encoder.next().unwrap();
206        local
207            .data
208            .iter_mut()
209            .zip(symbol.data)
210            .for_each(|(a, b)| *a ^= b);
211        local.hash ^= symbol.hash;
212        self.symbols.push(local);
213        (self.done, self.peel())
214    }
215
216    fn peel(&mut self) -> Vec<Peeled<T>> {
217        let mut peeled = Vec::new();
218        while let Some((i, pure_symbol)) = self
219            .symbols
220            .iter()
221            .enumerate()
222            .find(|(_, v)| rapidhash(&v.data) == v.hash)
223            .map(|(i, s)| (i, s.clone()))
224        {
225            let missing_remote = self.encoder.contains(&pure_symbol.data);
226            for i in self.encoder.add_symbol(pure_symbol.data) {
227                if let Some(symbol) = self.symbols.get_mut(i) {
228                    symbol
229                        .data
230                        .iter_mut()
231                        .zip(pure_symbol.data)
232                        .for_each(|(a, b)| *a ^= b);
233                    symbol.hash ^= pure_symbol.hash;
234                }
235            }
236            let t = T::from_bytes(&pure_symbol.data);
237            peeled.push(if missing_remote {
238                self.encoder.remove_symbol(&pure_symbol.data);
239                Peeled::MissingRemote(t)
240            } else {
241                Peeled::MissingLocal(t)
242            });
243            if i == 0 {
244                self.done = true;
245                break;
246            }
247        }
248        peeled
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Native-endian bytes are fine here: both "devices" run in the same process.
    impl Symbol<8> for u64 {
        fn to_bytes(&self) -> [u8; 8] {
            self.to_ne_bytes()
        }

        fn from_bytes(bytes: &[u8; 8]) -> Self {
            Self::from_ne_bytes(*bytes)
        }
    }

    /// Builds two random sets of values in `0..1000` and streams remote symbols into a local
    /// decoder until it reports completion. Then checks that the peeled items are exactly the
    /// symmetric difference, attributed to the correct side.
    #[test]
    fn test_riblt() {
        const SIZE: usize = 1000;
        let mut rng = RapidRng::default();
        let remote: HashSet<u64> = (0..SIZE).map(|_| rng.next() % SIZE as u64).collect();
        let local: HashSet<u64> = (0..SIZE).map(|_| rng.next() % SIZE as u64).collect();
        let diff = remote.symmetric_difference(&local).count();
        let elements = remote.union(&local).count();

        let mut encoder = Encoder::new(remote.clone().into_iter());
        let mut decoder = Decoder::new(local.clone().into_iter());
        let mut symbols = 0;
        let mut peeled = Vec::new();

        let mut done = false;
        while !done {
            let received = encoder.next().unwrap();
            symbols += 1;
            let (finished, batch) = decoder.next_symbol(received);
            peeled.extend(batch);
            done = finished;
        }

        let efficiency = symbols as f64 / diff as f64;
        dbg!(&peeled, elements, diff, symbols, efficiency);

        let missing_local: HashSet<&u64> = peeled
            .iter()
            .filter_map(|item| match item {
                Peeled::MissingLocal(value) => Some(value),
                Peeled::MissingRemote(_) => None,
            })
            .collect();
        let missing_remote: HashSet<&u64> = peeled
            .iter()
            .filter_map(|item| match item {
                Peeled::MissingRemote(value) => Some(value),
                Peeled::MissingLocal(_) => None,
            })
            .collect();

        assert_eq!(remote.difference(&local).collect::<HashSet<_>>(), missing_local);
        assert_eq!(local.difference(&remote).collect::<HashSet<_>>(), missing_remote);
    }
}