1use rapidhash::{
12 rng::RapidRng,
13 v3::{rapidhash_v3_seeded, DEFAULT_RAPID_SECRETS},
14};
15use std::{collections::HashMap, fmt::Debug, marker::PhantomData};
16
17#[inline]
18fn rapidhash(data: &[u8]) -> u64 {
19 rapidhash_v3_seeded(data, &DEFAULT_RAPID_SECRETS)
20}
21
/// A set element with a fixed-size `N`-byte encoding.
///
/// The encoded bytes are hashed and XORed into coded symbols. Recovered
/// differences are turned back into values with [`Symbol::from_bytes`], so
/// `from_bytes(&x.to_bytes())` must reproduce `x`.
pub trait Symbol<const N: usize> {
    /// Serializes `self` into exactly `N` bytes.
    fn to_bytes(&self) -> [u8; N];

    /// Rebuilds a value from bytes produced by [`Symbol::to_bytes`].
    fn from_bytes(bytes: &[u8; N]) -> Self;
}
42
/// One cell of the coded stream: the XOR of the encoded bytes of every element
/// mapped to this index, plus the XOR of those elements' hashes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodedSymbol<const N: usize> {
    /// XOR of the `N`-byte encodings of the contributing elements.
    data: [u8; N],
    /// XOR of the hashes of the contributing elements.
    hash: u64,
}
49
50pub struct Encoder<const N: usize> {
52 next_index: u32,
53 state: HashMap<[u8; N], (rapidhash::rng::RapidRng, u32)>,
55}
56
57impl<const N: usize> Encoder<N> {
58 pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
60 let state = iter
61 .map(|value| {
62 let bytes = value.to_bytes();
63 let hash = rapidhash(&bytes);
64 let rng = RapidRng::new(hash);
65 (bytes, (rng, 0))
66 })
67 .collect();
68 Self {
69 next_index: 0,
70 state,
71 }
72 }
73
74 pub fn next_symbol(&mut self) -> CodedSymbol<N> {
76 let mut data = [0u8; N];
77 let mut hash = 0u64;
78 for (bytes, (r, i)) in self
79 .state
80 .iter_mut()
81 .filter(|(_, (_, i))| *i == self.next_index)
82 {
83 data.iter_mut().zip(bytes).for_each(|(a, b)| *a ^= *b);
84 hash ^= rapidhash(bytes);
85 Self::update_index(r, i);
86 }
87 self.next_index += 1;
88 CodedSymbol { data, hash }
89 }
90
91 fn add_symbol(&mut self, data: [u8; N]) -> Vec<usize> {
92 let hash = rapidhash(&data);
93 let mut rng = RapidRng::new(hash);
94 let mut index = 0;
95 let mut indices = Vec::new();
96 while index < self.next_index {
97 indices.push(index as usize);
98 Self::update_index(&mut rng, &mut index);
99 }
100 self.state.insert(data, (rng, index));
101 indices
102 }
103
104 fn remove_symbol(&mut self, data: &[u8; N]) -> Vec<usize> {
105 let hash = rapidhash(data);
106 let mut rng = RapidRng::new(hash);
107 let mut index = 0;
108 let mut indices = Vec::new();
109 while index < self.next_index {
110 indices.push(index as usize);
111 Self::update_index(&mut rng, &mut index);
112 }
113 self.state.remove(data);
114 indices
115 }
116
117 fn update_index(r: &mut RapidRng, i: &mut u32) {
118 const TP32: f64 = (1u64 << 32) as f64;
120 let diff = (*i as f64 + 1.5) * (TP32 / (r.next() as f64 + 1.0).sqrt() - 1.0);
121 *i += diff.ceil() as u32;
122 }
123
124 fn contains(&self, data: &[u8; N]) -> bool {
125 self.state.contains_key(data)
126 }
127}
128
129impl<const N: usize> Iterator for Encoder<N> {
130 type Item = CodedSymbol<N>;
131
132 fn next(&mut self) -> Option<Self::Item> {
134 Some(self.next_symbol())
135 }
136}
137
138pub struct CachedEncoder<const N: usize> {
141 encoder: Encoder<N>,
142 cache: Vec<Option<Box<CodedSymbol<N>>>>,
143}
144
145impl<const N: usize> CachedEncoder<N> {
146 const EMPTY_SYMBOL: CodedSymbol<N> = CodedSymbol {
147 data: [0u8; N],
148 hash: 0,
149 };
150
151 pub fn new<T: Symbol<N>>(iter: impl Iterator<Item = T>) -> Self {
153 Self {
154 encoder: Encoder::new(iter),
155 cache: Vec::new(),
156 }
157 }
158
159 pub fn get(&mut self, index: usize) -> CodedSymbol<N> {
161 loop {
162 match self.cache.get(index).cloned() {
163 None => self.cache.push(match self.encoder.next_symbol() {
164 s if s != Self::EMPTY_SYMBOL => Some(Box::new(s)),
165 _ => None,
166 }),
167 Some(val) => break val.as_deref().cloned().unwrap_or(Self::EMPTY_SYMBOL),
168 }
169 }
170 }
171}
172
/// An element recovered by the [`Decoder`], tagged with the side that lacks it.
#[derive(Debug)]
pub enum Peeled<T> {
    /// Present in the remote set but not in the local one.
    MissingLocal(T),
    /// Present in the local set but not in the remote one.
    MissingRemote(T),
}
179
180pub struct Decoder<const N: usize, T: Symbol<N>> {
182 encoder: Encoder<N>,
183 symbols: Vec<CodedSymbol<N>>,
184 done: bool,
185 _marker: PhantomData<T>,
186}
187
188impl<'a, const N: usize, T: Symbol<N>> Decoder<N, T> {
189 pub fn new(local: impl Iterator<Item = T>) -> Self {
191 Self {
192 encoder: Encoder::new(local),
193 symbols: Vec::new(),
194 done: false,
195 _marker: PhantomData,
196 }
197 }
198
199 pub fn next_symbol(&mut self, symbol: CodedSymbol<N>) -> (bool, Vec<Peeled<T>>) {
202 if self.done {
203 return (true, vec![]);
204 }
205 let mut local = self.encoder.next().unwrap();
206 local
207 .data
208 .iter_mut()
209 .zip(symbol.data)
210 .for_each(|(a, b)| *a ^= b);
211 local.hash ^= symbol.hash;
212 self.symbols.push(local);
213 (self.done, self.peel())
214 }
215
216 fn peel(&mut self) -> Vec<Peeled<T>> {
217 let mut peeled = Vec::new();
218 while let Some((i, pure_symbol)) = self
219 .symbols
220 .iter()
221 .enumerate()
222 .find(|(_, v)| rapidhash(&v.data) == v.hash)
223 .map(|(i, s)| (i, s.clone()))
224 {
225 let missing_remote = self.encoder.contains(&pure_symbol.data);
226 for i in self.encoder.add_symbol(pure_symbol.data) {
227 if let Some(symbol) = self.symbols.get_mut(i) {
228 symbol
229 .data
230 .iter_mut()
231 .zip(pure_symbol.data)
232 .for_each(|(a, b)| *a ^= b);
233 symbol.hash ^= pure_symbol.hash;
234 }
235 }
236 let t = T::from_bytes(&pure_symbol.data);
237 peeled.push(if missing_remote {
238 self.encoder.remove_symbol(&pure_symbol.data);
239 Peeled::MissingRemote(t)
240 } else {
241 Peeled::MissingLocal(t)
242 });
243 if i == 0 {
244 self.done = true;
245 break;
246 }
247 }
248 peeled
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    impl Symbol<8> for u64 {
        fn to_bytes(&self) -> [u8; 8] {
            self.to_ne_bytes()
        }

        fn from_bytes(bytes: &[u8; 8]) -> Self {
            Self::from_ne_bytes(*bytes)
        }
    }

    /// Reconciles two overlapping random sets and checks that both directions
    /// of the difference are recovered exactly.
    #[test]
    fn test_riblt() {
        const SIZE: usize = 1000;
        let mut rng = RapidRng::default();
        // The remote set is drawn first, then the local set, from the same RNG.
        let remote: HashSet<u64> = (0..SIZE).map(|_| rng.next() % SIZE as u64).collect();
        let local: HashSet<u64> = (0..SIZE).map(|_| rng.next() % SIZE as u64).collect();
        let diff = remote.symmetric_difference(&local).count();
        let elements = remote.union(&local).count();

        let mut encoder = Encoder::new(remote.clone().into_iter());
        let mut decoder = Decoder::new(local.clone().into_iter());
        let mut symbols = 0;
        let mut peeled = Vec::new();

        // Stream remote symbols into the decoder until it reports completion.
        let mut done = false;
        while !done {
            let symbol = encoder.next().unwrap();
            symbols += 1;
            let (finished, batch) = decoder.next_symbol(symbol);
            peeled.extend(batch);
            done = finished;
        }

        let efficiency = symbols as f64 / diff as f64;
        dbg!(&peeled, elements, diff, symbols, efficiency);

        let missing_local: HashSet<&u64> = peeled
            .iter()
            .filter_map(|p| match p {
                Peeled::MissingLocal(t) => Some(t),
                Peeled::MissingRemote(_) => None,
            })
            .collect();
        let missing_remote: HashSet<&u64> = peeled
            .iter()
            .filter_map(|p| match p {
                Peeled::MissingRemote(t) => Some(t),
                Peeled::MissingLocal(_) => None,
            })
            .collect();

        assert_eq!(remote.difference(&local).collect::<HashSet<_>>(), missing_local);
        assert_eq!(local.difference(&remote).collect::<HashSet<_>>(), missing_remote);
    }
}