do-riblt 1.0.2

An implementation of rateless invertible bloom lookup tables
Documentation
//! An efficient implementation of Rateless Invertible Bloom Lookup Tables (riblt).
//! The paper can be found here: https://arxiv.org/abs/2402.02668
//!
//! This crate can be used to efficiently synchronize sets between two devices.
//!
//! Why use riblts?
//! 1. they are rateless (there are infinite symbols)
//! 2. they are universal (you don't need to know anything from the other side to begin with)
//! 3. they provide low transportation cost (only around 1.3 to 1.7 times as many symbols as there are differences need to be sent)
//! 4. they provide low computation cost (by using mostly XOR operations)

use rapidhash::{
    rng::RapidRng,
    v3::{rapidhash_v3_seeded, DEFAULT_RAPID_SECRETS},
};
use std::{collections::HashMap, fmt::Debug, marker::PhantomData};

/// Hashes `data` with `rapidhash_v3_seeded`, using `DEFAULT_RAPID_SECRETS` as the secrets.
#[inline]
fn rapidhash(data: &[u8]) -> u64 {
    let secrets = &DEFAULT_RAPID_SECRETS;
    rapidhash_v3_seeded(data, secrets)
}

/// A trait to implement on everything that can be an item for an `Encoder` or a `Decoder`.
///
/// `N` is the fixed length, in bytes, of the serialised form of an item.
///
/// Example implementation for `u64`:
/// ```ignore
/// impl Symbol<8> for u64 {
///   fn to_bytes(&self) -> [u8; 8] {
///     self.to_be_bytes()
///   }
///   fn from_bytes(bytes: &[u8; 8]) -> Self {
///     Self::from_be_bytes(*bytes)
///   }
/// }
/// ```
pub trait Symbol<const N: usize> {
    /// Turn the value into a fixed-size sequence of `N` bytes
    fn to_bytes(&self) -> [u8; N];

    /// Rebuild a value from a sequence of `N` bytes
    fn from_bytes(bytes: &[u8; N]) -> Self;
}

/// A symbol to send from an encoder to a decoder
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodedSymbol<const N: usize> {
    /// XOR of the byte representations of every item folded into this symbol
    data: [u8; N],
    /// XOR of the hashes of every item folded into this symbol
    hash: u64,
}

/// The encoder to calculate `CodedSymbol`s from a set of items
pub struct Encoder<const N: usize> {
    /// Index of the coded symbol that the next call to `next_symbol` produces
    next_index: u32,
    /// bytes -> (rng, next_index)
    ///
    /// For every item: its mapping generator and the index of the next coded symbol
    /// the item is folded into.
    state: HashMap<[u8; N], (rapidhash::rng::RapidRng, u32)>,
}

impl<const N: usize> Encoder<N> {
    /// Create a new encoder from an iterator of items
    ///
    /// Every item is serialised with `Symbol::to_bytes`. Its mapping generator is seeded
    /// with the hash of those bytes and the item starts out mapped to coded symbol `0`.
    pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
        let state = iter
            .map(|value| {
                let bytes = value.to_bytes();
                let rng = RapidRng::new(rapidhash(&bytes));
                (bytes, (rng, 0))
            })
            .collect();
        Self {
            next_index: 0,
            state,
        }
    }

    /// Calculate the next `CodedSymbol`
    ///
    /// The symbol is the XOR of the bytes (and of the hashes) of every item whose next
    /// mapped index equals the index being produced. Each such item is then advanced to
    /// its following index.
    pub fn next_symbol(&mut self) -> CodedSymbol<N> {
        let mut data = [0u8; N];
        let mut hash = 0u64;
        let current = self.next_index;
        for (bytes, (rng, index)) in self
            .state
            .iter_mut()
            .filter(|(_, (_, index))| *index == current)
        {
            data.iter_mut().zip(bytes).for_each(|(a, b)| *a ^= *b);
            hash ^= rapidhash(bytes);
            Self::update_index(rng, index);
        }
        self.next_index += 1;
        CodedSymbol { data, hash }
    }

    /// Replays the index mapping of `data` from index `0`.
    ///
    /// Returns the generator state, the first mapped index that is not below
    /// `self.next_index`, and every mapped index below `self.next_index`.
    fn replay_mapping(&self, data: &[u8; N]) -> (RapidRng, u32, Vec<usize>) {
        let mut rng = RapidRng::new(rapidhash(data));
        let mut index = 0;
        let mut indices = Vec::new();
        while index < self.next_index {
            indices.push(index as usize);
            Self::update_index(&mut rng, &mut index);
        }
        (rng, index, indices)
    }

    /// Inserts `data` into the set (replacing any existing entry for the same bytes).
    ///
    /// Returns the indices of all already-produced coded symbols that `data` maps to.
    fn add_symbol(&mut self, data: [u8; N]) -> Vec<usize> {
        let (rng, index, indices) = self.replay_mapping(&data);
        self.state.insert(data, (rng, index));
        indices
    }

    /// Removes `data` from the set.
    ///
    /// Returns the indices of all already-produced coded symbols that `data` maps to.
    fn remove_symbol(&mut self, data: &[u8; N]) -> Vec<usize> {
        let (_, _, indices) = self.replay_mapping(data);
        self.state.remove(data);
        indices
    }

    /// Advances `i` to the next coded-symbol index for the item whose generator is `r`.
    fn update_index(r: &mut RapidRng, i: &mut u32) {
        // Adapted from https://github.com/samWighton/rateless_iblt/blob/main/src/mapping.rs
        const TP32: f64 = (1u64 << 32) as f64;
        let diff = (*i as f64 + 1.5) * (TP32 / (r.next() as f64 + 1.0).sqrt() - 1.0);
        // The float-to-int cast saturates at `u32::MAX`. A step of 0 (possible when the
        // generator output rounds to 2^64 as `f64`) would leave the index unchanged and
        // make the replay loop spin forever, so always move forward by at least one.
        let step = (diff.ceil() as u32).max(1);
        // Saturate instead of overflowing: a wrapped index would point back at symbols
        // that were already produced.
        *i = i.saturating_add(step);
    }

    /// Returns `true` if `data` is currently part of the set
    fn contains(&self, data: &[u8; N]) -> bool {
        self.state.contains_key(data)
    }
}

impl<const N: usize> Iterator for Encoder<N> {
    type Item = CodedSymbol<N>;

    /// Produces the next coded symbol; this never yields `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let symbol = self.next_symbol();
        Some(symbol)
    }
}

/// A cached version of the encoder
/// Useful to synchronize one set with multiple remotes at once (without calculating the symbols multiple times)
pub struct CachedEncoder<const N: usize> {
    /// Underlying encoder that produces symbols the cache does not hold yet
    encoder: Encoder<N>,
    /// One slot per produced symbol, in order; `None` stands for the empty symbol
    cache: Vec<Option<Box<CodedSymbol<N>>>>,
}

impl<const N: usize> CachedEncoder<N> {
    /// The symbol with all-zero data and a zero hash; stored in the cache as `None`.
    const EMPTY_SYMBOL: CodedSymbol<N> = CodedSymbol {
        data: [0u8; N],
        hash: 0,
    };

    /// Create a new cached encoder from an iterator of items
    pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
        let encoder = Encoder::new(iter);
        Self {
            encoder,
            cache: Vec::new(),
        }
    }

    /// Calculate all symbols up to and including `index`, cache them, and return a copy
    /// of the symbol at `index`
    pub fn get(&mut self, index: usize) -> CodedSymbol<N> {
        // Extend the cache until it covers `index`.
        while self.cache.len() <= index {
            let symbol = self.encoder.next_symbol();
            let slot = if symbol != Self::EMPTY_SYMBOL {
                Some(Box::new(symbol))
            } else {
                None
            };
            self.cache.push(slot);
        }
        self.cache
            .get(index)
            .and_then(|slot| slot.as_deref())
            .cloned()
            .unwrap_or(Self::EMPTY_SYMBOL)
    }
}

/// A peeled symbol
///
/// An item recovered by the decoder, tagged with the side it is missing from.
#[derive(Debug)]
pub enum Peeled<T> {
    /// The item is not part of the decoder's local set
    MissingLocal(T),
    /// The item is part of the decoder's local set (and was recovered as a difference)
    MissingRemote(T),
}

/// The decoder to calculate set differences
pub struct Decoder<const N: usize, T: Symbol<N>> {
    /// Encoder over the local set; peeled items are added to or removed from it
    encoder: Encoder<N>,
    /// Every received symbol XORed with the local symbol of the same index
    symbols: Vec<CodedSymbol<N>>,
    /// Set once decoding has finished; later symbols are then ignored
    done: bool,
    /// Ties the decoder to the item type `T` without storing a `T`
    _marker: PhantomData<T>,
}

impl<const N: usize, T: Symbol<N>> Decoder<N, T> {
    /// Create a new decoder with a local iterator of items (our set)
    pub fn new(local: impl Iterator<Item = T>) -> Self {
        Self {
            encoder: Encoder::new(local),
            symbols: Vec::new(),
            done: false,
            _marker: PhantomData,
        }
    }

    /// Consumes a received symbol and tries to peel as many items as possible.
    ///
    /// Returns `(done, peeled)`: `done` is `true` once decoding has finished (including
    /// on the very call that finishes it) and `false` if more symbols are needed;
    /// `peeled` holds the items recovered by this call. Once done, further symbols are
    /// ignored and `(true, vec![])` is returned.
    pub fn next_symbol(&mut self, symbol: CodedSymbol<N>) -> (bool, Vec<Peeled<T>>) {
        if self.done {
            return (true, Vec::new());
        }
        // Cancel everything both sides have in common: what remains in `local` is the
        // XOR of the items that differ.
        let mut local = self.encoder.next_symbol();
        local
            .data
            .iter_mut()
            .zip(symbol.data)
            .for_each(|(a, b)| *a ^= b);
        local.hash ^= symbol.hash;
        self.symbols.push(local);
        // Peel first, then read `done`: tuple operands are evaluated left to right, so
        // `(self.done, self.peel())` would report the state from before peeling.
        let peeled = self.peel();
        (self.done, peeled)
    }

    /// Repeatedly takes the first pure symbol (one whose data hashes to its stored
    /// hash), removes its item from every stored symbol it maps to, and records on
    /// which side the item is missing. Updates `self.done` afterwards.
    fn peel(&mut self) -> Vec<Peeled<T>> {
        let mut peeled = Vec::new();
        while let Some((position, pure_symbol)) = self
            .symbols
            .iter()
            .enumerate()
            .find(|(_, s)| rapidhash(&s.data) == s.hash)
            .map(|(position, s)| (position, s.clone()))
        {
            // Must be checked before `add_symbol`, which inserts the item locally.
            let missing_remote = self.encoder.contains(&pure_symbol.data);
            for i in self.encoder.add_symbol(pure_symbol.data) {
                if let Some(symbol) = self.symbols.get_mut(i) {
                    symbol
                        .data
                        .iter_mut()
                        .zip(pure_symbol.data)
                        .for_each(|(a, b)| *a ^= b);
                    symbol.hash ^= pure_symbol.hash;
                }
            }
            let item = T::from_bytes(&pure_symbol.data);
            peeled.push(if missing_remote {
                // The remote does not have it, so it must no longer be folded into the
                // local symbols produced from now on.
                self.encoder.remove_symbol(&pure_symbol.data);
                Peeled::MissingRemote(item)
            } else {
                Peeled::MissingLocal(item)
            });
            if position == 0 {
                // Every item maps to index 0, so a pure first symbol held the last
                // remaining difference.
                break;
            }
        }
        // Decoding is finished exactly when the first symbol is empty. This also covers
        // identical sets, where the first symbol is empty from the start and never pure.
        if let Some(first) = self.symbols.first() {
            self.done = first.hash == 0 && first.data.iter().all(|&b| b == 0);
        }
        peeled
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    // Test-only `Symbol` implementation: a `u64` is its 8 native-endian bytes.
    impl Symbol<8> for u64 {
        fn to_bytes(&self) -> [u8; 8] {
            self.to_ne_bytes()
        }

        fn from_bytes(bytes: &[u8; 8]) -> Self {
            Self::from_ne_bytes(*bytes)
        }
    }

    /// Reconciles two random sets and checks that the decoder reports exactly the
    /// items missing on each side.
    #[test]
    fn test_riblt() {
        const SIZE: usize = 1000;
        let mut rng = RapidRng::default();
        // Both sets draw values from `0..SIZE`, so they overlap partially.
        let remote: HashSet<_> = HashSet::from_iter((0..SIZE).map(|_| rng.next() % SIZE as u64));
        let local = HashSet::from_iter((0..SIZE).map(|_| rng.next() % SIZE as u64));
        let diff = remote.symmetric_difference(&local).count();
        let elements = remote.union(&local).count();

        let mut encoder = Encoder::new(remote.clone().into_iter());

        let mut decoder = Decoder::new(local.clone().into_iter());
        let mut symbols = 0;
        let mut peeled = Vec::new();

        // Feed coded symbols to the decoder until it reports completion.
        loop {
            let symbol = encoder.next().unwrap();
            symbols += 1;
            let (done, peeled_) = decoder.next_symbol(symbol);
            peeled.extend(peeled_);
            if done {
                break;
            }
        }

        // Coded symbols consumed per differing item.
        let efficiency = symbols as f64 / diff as f64;
        dbg!(&peeled, elements, diff, symbols, efficiency);

        // Items only the remote has must come back as `MissingLocal`.
        assert_eq!(
            remote.difference(&local).collect::<HashSet<_>>(),
            peeled
                .iter()
                .filter_map(|v| match v {
                    Peeled::MissingLocal(t) => Some(t),
                    _ => None,
                })
                .collect()
        );
        // Items only the local side has must come back as `MissingRemote`.
        assert_eq!(
            local.difference(&remote).collect::<HashSet<_>>(),
            peeled
                .iter()
                .filter_map(|v| match v {
                    Peeled::MissingRemote(t) => Some(t),
                    _ => None,
                })
                .collect()
        );
    }
}