riblt/
symbol.rs

1use serde::{Deserialize, Serialize};
2use std::fmt::Debug;
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::marker::PhantomData;
5
/// An item of the set being reconciled.
///
/// A symbol must be convertible to and from a byte vector; the default
/// [`Symbol::hash_`] is derived from that byte encoding.
pub trait Symbol: Clone + Debug {
    /// Number of bytes in the encoding of a symbol.
    ///
    /// `CodedSymbol::apply` asserts that `encode_to_bytes` returns exactly this many bytes.
    const BYTE_ARRAY_LENGTH: usize;

    /// Encodes the symbol as a byte vector.
    fn encode_to_bytes(&self) -> Vec<u8>;

    /// Reconstructs a symbol from a byte vector.
    ///
    /// The parameter stays `&Vec<u8>` (not a slice) so existing implementors keep compiling.
    fn decode_from_bytes(bytes: &Vec<u8>) -> Self;

    /// Computes the 64-bit hash of the symbol.
    ///
    /// The default implementation hashes the encoded bytes with [`DefaultHasher`];
    /// implementors may override it if needed.
    fn hash_(&self) -> u64 {
        let mut state = DefaultHasher::new();
        self.encode_to_bytes().hash(&mut state);
        state.finish()
    }
}
22
23/// A RIBLT is an infinite sequence of CodedSymbols
24///
25/// The 'sum' field is the XOR of the symbols encoded into this CodedSymbol
26///
27/// The 'hash' field is the XOR of the hashes of the symbols encoded into this CodedSymbol
28///
29/// The 'count' field is the number of local symbols minus the number of remote symbols
30///
31/// The '_marker' phantom field is used to allow us to associate the CodedSymbol with a specific Symbol type.
32/// The type T is used by implemented methods to know what type of Symbol is encoded in the CodedSymbol.
33///
34/// A CodedSymbol can be peeled when the count is 1 or -1 and the hash matches
35#[derive(Clone, Debug, Serialize, Deserialize)]
36pub struct CodedSymbol<T: Symbol> {
37    _marker: PhantomData<T>,
38    pub sum: Vec<u8>,
39    pub hash: u64,
40    pub count: i64,
41}
42
43/// If a symbol can be successfully 'peeled' out of a CodedSymbol, it is returned in the PeelableResult.
44///
45/// This enum acts as a wrapper to keep track of if the symbol was local or remote.
46///
47/// It is very common that the symbol is not peelable, so this enum also has the NotPeelable variant.
48#[derive(PartialEq, Eq, Clone, Debug)]
49pub enum PeelableResult<T: Symbol> {
50    Local(T),
51    Remote(T),
52    NotPeelable,
53}
54
/// Tells `CodedSymbol::apply` how a symbol affects the coded symbol's `count`.
///
/// The enum is field-less, so it is `Copy` and comparable; callers can reuse a value without
/// cloning it and can print it in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// The symbol is applied and the count goes up by one.
    Add,
    /// The symbol is applied and the count goes down by one.
    Remove,
}
60
61impl<T: Symbol> CodedSymbol<T> {
62    pub fn new() -> Self {
63        let sum = vec![0u8; T::BYTE_ARRAY_LENGTH];
64        let hash = 0;
65        let count = 0;
66        CodedSymbol {
67            _marker: PhantomData,
68            sum,
69            hash,
70            count,
71        }
72    }
73
74    /// apply() adds or removes a symbol from the CodedSymbol
75    ///
76    /// Adding a local, or removing a remote, symbol increases the count by 1
77    ///
78    /// Removing a local, or adding a remote, symbol decreases the count by 1
79    pub fn apply(&mut self, s: &T, direction: Direction) {
80        //It might be nice to split this into an 'add' and 'remove'
81        assert_eq!(
82            self.sum.len(),
83            T::BYTE_ARRAY_LENGTH,
84            "self.sum must have the length specified by T::BYTE_ARRAY_LENGTH."
85        );
86        let encoded_s = s.encode_to_bytes();
87
88        assert_eq!(
89            encoded_s.len(),
90            T::BYTE_ARRAY_LENGTH,
91            "encoded_s must have the length specified by T::BYTE_ARRAY_LENGTH."
92        );
93
94        // Should be able to update in place here
95        self.sum = self
96            .sum
97            .iter()
98            .zip(encoded_s.iter())
99            .map(|(x, y)| x ^ y)
100            .collect();
101
102        self.hash ^= s.hash_();
103        match direction {
104            Direction::Add => self.count += 1,
105            Direction::Remove => self.count -= 1,
106        };
107    }
108
109    /// Used by the encoder to join two vectors of codedSymbols together produced from two distinct sets.
110    /// The results are only valid if there were no duplicates between the original sets.
111    pub fn combine(&self, b: &CodedSymbol<T>) -> CodedSymbol<T> {
112        assert_eq!(
113            self.sum.len(),
114            T::BYTE_ARRAY_LENGTH,
115            "self.sum must have the length specified by T::BYTE_ARRAY_LENGTH."
116        );
117        assert_eq!(
118            b.sum.len(),
119            T::BYTE_ARRAY_LENGTH,
120            "encoded_s must have the length specified by T::BYTE_ARRAY_LENGTH."
121        );
122
123        let mut new_coded_symbol = self.clone();
124
125        new_coded_symbol.hash ^= b.hash;
126        new_coded_symbol.count += b.count;
127
128        new_coded_symbol.sum = self
129            .sum
130            .iter()
131            .zip(b.sum.iter())
132            .map(|(x, y)| x ^ y)
133            .collect();
134
135        new_coded_symbol
136    }
137
138    /// Used by the encoder to 'subtract' a remote set of codedSymbols from a local set.
139    pub fn collapse(&self, b: &CodedSymbol<T>) -> CodedSymbol<T> {
140        assert_eq!(
141            self.sum.len(),
142            T::BYTE_ARRAY_LENGTH,
143            "self.sum must have the length specified by T::BYTE_ARRAY_LENGTH."
144        );
145        assert_eq!(
146            b.sum.len(),
147            T::BYTE_ARRAY_LENGTH,
148            "encoded_s must have the length specified by T::BYTE_ARRAY_LENGTH."
149        );
150
151        let mut new_coded_symbol = self.clone();
152
153        // new_coded_symbol.symbol = new_coded_symbol.symbol.xor(&b.symbol);
154        new_coded_symbol.hash ^= b.hash;
155        new_coded_symbol.count -= b.count;
156
157        new_coded_symbol.sum = self
158            .sum
159            .iter()
160            .zip(b.sum.iter())
161            .map(|(x, y)| x ^ y)
162            .collect();
163
164        new_coded_symbol
165    }
166
167    /// Checks if the CodedSymbol contains only one symbol and therefore can be peeled
168    ///
169    /// A count of 1 does not necessarily mean that the 'sum' field is the xor of only one encoded
170    /// symbol. It could be the xor of two local and one remote symbols. This is why we also
171    /// check the hash.
172    pub fn is_peelable(&self) -> bool {
173        if self.count == 1 || self.count == -1 {
174            if self.hash == T::decode_from_bytes(&self.sum).hash_() {
175                return true;
176            }
177        }
178        return false;
179    }
180
181    /// Peel extracts a symbol from the CodedSymbol (if possible) and returns it in a PeelableResult
182    /// A PeelableResult is used to keep track of if the symbol was local or remote (or was not
183    /// able to be peeled).
184    pub fn peel(&mut self) -> PeelableResult<T> {
185        if self.is_peelable() {
186            let return_result = if self.count == 1 {
187                PeelableResult::Local(T::decode_from_bytes(&self.sum))
188            } else {
189                PeelableResult::Remote(T::decode_from_bytes(&self.sum))
190            };
191
192            *self = CodedSymbol::new();
193            return return_result;
194        }
195        PeelableResult::NotPeelable
196    }
197    /// same as peel, but does not modify the CodedSymbol
198    pub fn peel_peek(&self) -> PeelableResult<T> {
199        if self.is_peelable() {
200            let return_result = if self.count == 1 {
201                PeelableResult::Local(T::decode_from_bytes(&self.sum))
202            } else {
203                PeelableResult::Remote(T::decode_from_bytes(&self.sum))
204            };
205
206            return return_result;
207        }
208        PeelableResult::NotPeelable
209    }
210
211    /// Checks if the CodedSymbol contains no symbols
212    pub fn is_empty(&self) -> bool {
213        // if self.symbol != T::empty() {
214        //     return false;
215        // }
216        if self.count != 0 {
217            return false;
218        }
219        if self.hash != 0 {
220            return false;
221        }
222        true
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::SimpleSymbol;

    /// Walks one coded symbol through add / add / remove and checks peelability at each step,
    /// then peels the single remaining symbol.
    #[test]
    fn test_symbol() {
        let symbol1 = SimpleSymbol { value: 42 };
        let symbol2 = SimpleSymbol { value: 100 };

        let mut coded_symbol = CodedSymbol::<SimpleSymbol>::new();
        assert!(
            !coded_symbol.is_peelable(),
            "an empty coded symbol must not be peelable"
        );

        coded_symbol.apply(&symbol1, Direction::Add);
        assert!(
            coded_symbol.is_peelable(),
            "a coded symbol holding one symbol must be peelable"
        );

        coded_symbol.apply(&symbol2, Direction::Add);
        assert!(
            !coded_symbol.is_peelable(),
            "a coded symbol holding two symbols must not be peelable"
        );

        coded_symbol.apply(&symbol1, Direction::Remove);
        assert!(
            coded_symbol.is_peelable(),
            "removing symbol1 must leave only symbol2 behind"
        );

        // The count is +1 at this point, so the only correct outcome is a local symbol.
        match coded_symbol.peel() {
            PeelableResult::Local(symbol) => assert_eq!(symbol.value, symbol2.value),
            PeelableResult::Remote(symbol) => {
                panic!("expected a local symbol, but peeled remote {:?}", symbol)
            }
            PeelableResult::NotPeelable => panic!("expected the coded symbol to be peelable"),
        }

        // A successful peel resets the coded symbol.
        assert!(coded_symbol.is_empty());
        assert!(!coded_symbol.is_peelable());
    }

    /// Subtracting a remote coded symbol that holds one extra symbol must expose that symbol
    /// as remote.
    #[test]
    fn test_collapse_exposes_remote_symbol() {
        let shared = SimpleSymbol { value: 1 };
        let remote_only = SimpleSymbol { value: 7 };

        let mut local = CodedSymbol::<SimpleSymbol>::new();
        local.apply(&shared, Direction::Add);

        let mut remote = CodedSymbol::<SimpleSymbol>::new();
        remote.apply(&shared, Direction::Add);
        remote.apply(&remote_only, Direction::Add);

        let difference = local.collapse(&remote);
        assert_eq!(difference.count, -1);

        match difference.peel_peek() {
            PeelableResult::Remote(symbol) => assert_eq!(symbol.value, remote_only.value),
            PeelableResult::Local(symbol) => {
                panic!("expected a remote symbol, but peeled local {:?}", symbol)
            }
            PeelableResult::NotPeelable => panic!("expected the difference to be peelable"),
        }
    }
}