use rapidhash::{
rng::RapidRng,
v3::{rapidhash_v3_seeded, DEFAULT_RAPID_SECRETS},
};
use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
/// Hashes `data` with `rapidhash_v3_seeded`, always using `DEFAULT_RAPID_SECRETS`.
///
/// Every hash in this file goes through this helper, so all hashes share the same secrets.
#[inline]
fn rapidhash(data: &[u8]) -> u64 {
    let secrets = &DEFAULT_RAPID_SECRETS;
    rapidhash_v3_seeded(data, secrets)
}
/// A set element that can be converted to and from a fixed-width array of `N` bytes.
///
/// The encoder keys its state by the bytes returned from `to_bytes`, and the decoder rebuilds
/// recovered elements by passing recovered bytes to `from_bytes`.
pub trait Symbol<const N: usize> {
/// Returns the `N`-byte representation of this value.
fn to_bytes(&self) -> [u8; N];
/// Builds a value from its `N`-byte representation.
fn from_bytes(bytes: &[u8; N]) -> Self;
}
/// One coded symbol: the XOR-combination of every source symbol mapped to its index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodedSymbol<const N: usize> {
/// Byte-wise XOR of the bytes of all contributing source symbols.
data: [u8; N],
/// XOR of the hashes of all contributing source symbols.
hash: u64,
}
/// Produces a stream of coded symbols from a set of `N`-byte source symbols.
pub struct Encoder<const N: usize> {
/// Index of the coded symbol that the next call to `next_symbol` will produce.
next_index: u32,
/// For each source symbol (keyed by its bytes): the generator that drives its index sequence
/// and the next coded-symbol index the symbol contributes to.
state: HashMap<[u8; N], (rapidhash::rng::RapidRng, u32)>,
}
impl<const N: usize> Encoder<N> {
    /// Builds an encoder over the source symbols yielded by `iter`.
    ///
    /// Symbols are keyed by their byte representation, so values with identical bytes collapse
    /// into a single entry. Every entry starts mapped to coded-symbol index 0 with an index
    /// generator seeded from the hash of its bytes.
    pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
        let state = iter
            .map(|value| {
                let bytes = value.to_bytes();
                let rng = RapidRng::new(rapidhash(&bytes));
                (bytes, (rng, 0))
            })
            .collect();
        Self {
            next_index: 0,
            state,
        }
    }

    /// Produces the coded symbol at the current index and advances to the next index.
    ///
    /// The result is the XOR of the bytes (and of the hashes) of every source symbol whose next
    /// mapped index equals the current index; each such symbol is then moved on to its
    /// following index.
    pub fn next_symbol(&mut self) -> CodedSymbol<N> {
        let current = self.next_index;
        let mut data = [0u8; N];
        let mut hash = 0u64;
        for (bytes, (rng, index)) in self.state.iter_mut() {
            if *index != current {
                continue;
            }
            data.iter_mut().zip(bytes).for_each(|(a, b)| *a ^= *b);
            hash ^= rapidhash(bytes);
            Self::update_index(rng, index);
        }
        self.next_index += 1;
        CodedSymbol { data, hash }
    }

    /// Replays the index sequence of `data` from index 0.
    ///
    /// Returns the generator state, the first mapped index that is not below `next_index`, and
    /// the list of mapped indices below `next_index` (the coded symbols already produced that
    /// `data` contributes to).
    fn replay_mapping(&self, data: &[u8; N]) -> (RapidRng, u32, Vec<usize>) {
        let mut rng = RapidRng::new(rapidhash(data));
        let mut index = 0;
        let mut indices = Vec::new();
        while index < self.next_index {
            indices.push(index as usize);
            Self::update_index(&mut rng, &mut index);
        }
        (rng, index, indices)
    }

    /// Adds `data` to the encoded set (replacing any existing entry with the same bytes).
    ///
    /// Returns the indices of the already-produced coded symbols that `data` maps to, so the
    /// caller can patch symbols it has stored.
    fn add_symbol(&mut self, data: [u8; N]) -> Vec<usize> {
        let (rng, index, indices) = self.replay_mapping(&data);
        self.state.insert(data, (rng, index));
        indices
    }

    /// Removes `data` from the encoded set.
    ///
    /// Returns the indices of the already-produced coded symbols that `data` maps to.
    fn remove_symbol(&mut self, data: &[u8; N]) -> Vec<usize> {
        let (_, _, indices) = self.replay_mapping(data);
        self.state.remove(data);
        indices
    }

    /// Advances `i` to the next coded-symbol index of a source symbol, drawing from `r`.
    ///
    /// The gap grows with the current index, so a symbol contributes to progressively sparser
    /// coded symbols.
    fn update_index(r: &mut RapidRng, i: &mut u32) {
        const TP32: f64 = (1u64 << 32) as f64;
        let diff = (*i as f64 + 1.5) * (TP32 / (r.next() as f64 + 1.0).sqrt() - 1.0);
        // The float-to-int cast saturates at `u32::MAX`. A step of 0 (possible when the draw
        // rounds to 2^64) would leave the index stuck behind `next_index`, so force progress.
        let step = (diff.ceil() as u32).max(1);
        // Saturate instead of overflowing: a plain `+=` panics in debug builds and wraps to a
        // smaller index in release builds when the step is huge.
        *i = i.saturating_add(step);
    }

    /// Returns whether `data` is currently part of the encoded set.
    fn contains(&self, data: &[u8; N]) -> bool {
        self.state.contains_key(data)
    }
}
/// Iterating an encoder yields its coded symbols in index order.
impl<const N: usize> Iterator for Encoder<N> {
    type Item = CodedSymbol<N>;

    /// Returns the next coded symbol; this implementation never returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let symbol = self.next_symbol();
        Some(symbol)
    }
}
/// An encoder that remembers the coded symbols it has produced so they can be fetched by index.
pub struct CachedEncoder<const N: usize> {
/// Source of coded symbols that have not been generated yet.
encoder: Encoder<N>,
/// Generated symbols by index; `None` stands for an all-zero (empty) coded symbol.
cache: Vec<Option<Box<CodedSymbol<N>>>>,
}
impl<const N: usize> CachedEncoder<N> {
    /// The coded symbol with all-zero data and a zero hash; cached as `None` to save space.
    const EMPTY_SYMBOL: CodedSymbol<N> = CodedSymbol {
        data: [0u8; N],
        hash: 0,
    };

    /// Builds a cached encoder over the source symbols yielded by `iter`.
    pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
        Self {
            encoder: Encoder::new(iter),
            cache: Vec::new(),
        }
    }

    /// Returns the coded symbol at `index`, generating and caching every symbol up to and
    /// including `index` that has not been produced yet.
    pub fn get(&mut self, index: usize) -> CodedSymbol<N> {
        // Fill the cache until it covers `index`.
        while self.cache.len() <= index {
            let symbol = self.encoder.next_symbol();
            let slot = if symbol == Self::EMPTY_SYMBOL {
                None
            } else {
                Some(Box::new(symbol))
            };
            self.cache.push(slot);
        }
        // Clone only the requested symbol, straight out of its slot.
        self.cache
            .get(index)
            .and_then(|slot| slot.as_deref())
            .cloned()
            .unwrap_or(Self::EMPTY_SYMBOL)
    }
}
/// An element recovered by the decoder, tagged with the side that lacks it.
#[derive(Debug)]
pub enum Peeled<T> {
/// The element was not in the decoder's local set when it was recovered.
MissingLocal(T),
/// The element was in the decoder's local set when it was recovered.
MissingRemote(T),
}
/// Recovers the elements that differ between a local set and a remote stream of coded symbols.
pub struct Decoder<const N: usize, T: Symbol<N>> {
/// Encoder over the local set, used to produce the local counterpart of each remote symbol.
encoder: Encoder<N>,
/// Remote coded symbols XOR-ed with the local coded symbols at the same index, in index order.
symbols: Vec<CodedSymbol<N>>,
/// Set once decoding has finished; further symbols are then ignored.
done: bool,
/// Ties the decoder to the element type `T` without storing a value of it.
_marker: PhantomData<T>,
}
impl<const N: usize, T: Symbol<N>> Decoder<N, T> {
    /// Creates a decoder whose local set is the elements yielded by `local`.
    pub fn new(local: impl Iterator<Item = T>) -> Self {
        Self {
            encoder: Encoder::new(local),
            symbols: Vec::new(),
            done: false,
            _marker: PhantomData,
        }
    }

    /// Feeds the next remote coded symbol into the decoder.
    ///
    /// Returns whether decoding is complete, together with the elements recovered by this
    /// call. Once decoding is complete, further symbols are ignored and `(true, vec![])` is
    /// returned.
    pub fn next_symbol(&mut self, symbol: CodedSymbol<N>) -> (bool, Vec<Peeled<T>>) {
        if self.done {
            return (true, Vec::new());
        }
        // Combine the remote symbol with the local symbol at the same index; what remains
        // encodes only the elements the two sides disagree on.
        let mut local = self.encoder.next_symbol();
        Self::xor_into(&mut local.data, &symbol.data);
        local.hash ^= symbol.hash;
        self.symbols.push(local);
        // Peel first: tuple operands are evaluated left to right, so reading `self.done`
        // inside the tuple before calling `peel` would report the stale flag.
        let peeled = self.peel();
        (self.done, peeled)
    }

    /// XORs `src` into `dst` byte by byte.
    fn xor_into(dst: &mut [u8; N], src: &[u8; N]) {
        dst.iter_mut().zip(src.iter()).for_each(|(a, b)| *a ^= *b);
    }

    /// Returns whether the first difference symbol has zero data and a zero hash.
    ///
    /// Every element maps to index 0, so an empty first symbol means no differing element is
    /// left to recover.
    fn first_symbol_is_empty(&self) -> bool {
        self.symbols
            .first()
            .map_or(false, |s| s.hash == 0 && s.data.iter().all(|&b| b == 0))
    }

    /// Repeatedly extracts pure symbols (whose data hashes to the stored hash) and removes
    /// their contribution from the stored symbols, until nothing more can be recovered.
    ///
    /// Sets `done` when the first difference symbol is empty, which also covers the case of
    /// identical local and remote sets.
    fn peel(&mut self) -> Vec<Peeled<T>> {
        let mut peeled = Vec::new();
        while !self.first_symbol_is_empty() {
            let pure = match self
                .symbols
                .iter()
                .find(|s| rapidhash(&s.data) == s.hash)
            {
                Some(s) => s.clone(),
                // Nothing decodable yet; wait for more coded symbols.
                None => return peeled,
            };
            // An element the local encoder already holds must be absent on the remote side.
            let missing_remote = self.encoder.contains(&pure.data);
            // Cancel the element out of every stored symbol it maps to. This also registers
            // it with the local encoder so that future local symbols include it.
            for i in self.encoder.add_symbol(pure.data) {
                if let Some(symbol) = self.symbols.get_mut(i) {
                    Self::xor_into(&mut symbol.data, &pure.data);
                    symbol.hash ^= pure.hash;
                }
            }
            let t = T::from_bytes(&pure.data);
            peeled.push(if missing_remote {
                // Drop it from the local encoder so future local symbols exclude it too.
                self.encoder.remove_symbol(&pure.data);
                Peeled::MissingRemote(t)
            } else {
                Peeled::MissingLocal(t)
            });
        }
        self.done = true;
        peeled
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
// `u64` elements are encoded as their 8 native-endian bytes.
impl Symbol<8> for u64 {
fn to_bytes(&self) -> [u8; 8] {
self.to_ne_bytes()
}
fn from_bytes(bytes: &[u8; 8]) -> Self {
Self::from_ne_bytes(*bytes)
}
}
// End-to-end check: stream coded symbols from a remote encoder into a local decoder and
// verify that the recovered elements are exactly the two one-sided set differences.
#[test]
fn test_riblt() {
const SIZE: usize = 1000;
let mut rng = RapidRng::default();
// Both sets draw SIZE values from 0..SIZE, so they overlap but are not equal in general.
let remote: HashSet<_> = HashSet::from_iter((0..SIZE).map(|_| rng.next() % SIZE as u64));
let local = HashSet::from_iter((0..SIZE).map(|_| rng.next() % SIZE as u64));
let diff = remote.symmetric_difference(&local).count();
let elements = remote.union(&local).count();
let mut encoder = Encoder::new(remote.clone().into_iter());
let mut decoder = Decoder::new(local.clone().into_iter());
// Number of coded symbols handed to the decoder.
let mut symbols = 0;
let mut peeled = Vec::new();
// Keep feeding coded symbols until the decoder reports completion.
loop {
let symbol = encoder.next().unwrap();
symbols += 1;
let (done, peeled_) = decoder.next_symbol(symbol);
peeled.extend(peeled_);
if done {
break;
}
}
// Coded symbols consumed per differing element (reported for information only).
let efficiency = symbols as f64 / diff as f64;
dbg!(&peeled, elements, diff, symbols, efficiency);
// Elements reported as missing locally must be exactly remote \ local.
assert_eq!(
remote.difference(&local).collect::<HashSet<_>>(),
peeled
.iter()
.filter_map(|v| match v {
Peeled::MissingLocal(t) => Some(t),
_ => None,
})
.collect()
);
// Elements reported as missing remotely must be exactly local \ remote.
assert_eq!(
local.difference(&remote).collect::<HashSet<_>>(),
peeled
.iter()
.filter_map(|v| match v {
Peeled::MissingRemote(t) => Some(t),
_ => None,
})
.collect()
);
}
}