ciff/
binary_collection.rs

1use std::convert::TryFrom;
2use std::convert::TryInto;
3use std::error::Error;
4use std::fmt;
5use std::io::{self, Write};
6
/// Size in bytes of one encoded element (a `u32`), which is also the size of a length prefix.
const ELEMENT_SIZE: usize = ::std::mem::size_of::<u32>();
8
/// Error returned when a byte buffer does not follow the binary collection format.
///
/// It optionally wraps a human-readable message describing what was wrong;
/// the default value carries no message.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct InvalidFormat(Option<String>);
12
13impl InvalidFormat {
14    /// Constructs an error with a message.
15    pub fn new<S: Into<String>>(msg: S) -> Self {
16        Self(Some(msg.into()))
17    }
18}
19
20impl Error for InvalidFormat {}
21
22impl fmt::Display for InvalidFormat {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        write!(f, "Invalid binary collection format")?;
25        if let Some(msg) = &self.0 {
26            write!(f, ": {}", msg)?;
27        }
28        Ok(())
29    }
30}
31
/// A borrowed view over a buffer holding one binary collection.
///
/// The buffer is a concatenation of sequences. Each sequence begins with a 4-byte
/// length `n`, which is followed by `n` values of 4 bytes each.
///
/// The collection itself is the iterator: every call to `next` consumes one sequence
/// from the front of the view, so a fresh view is needed to iterate a second time.
///
/// # Examples
///
/// ```
/// # use ciff::{encode_u32_sequence, BinaryCollection, InvalidFormat};
/// # use std::convert::TryFrom;
/// # fn main() -> Result<(), anyhow::Error> {
/// let mut buffer: Vec<u8> = Vec::new();
/// encode_u32_sequence(&mut buffer, 3, &[1, 2, 3])?;
/// encode_u32_sequence(&mut buffer, 1, &[4])?;
/// encode_u32_sequence(&mut buffer, 3, &[5, 6, 7])?;
///
/// // Binary collection is actually an iterator
/// let mut collection = BinaryCollection::try_from(&buffer[..])?;
/// assert_eq!(
///     collection.next().unwrap().map(|seq| seq.iter().collect::<Vec<_>>()).ok(),
///     Some(vec![1_u32, 2, 3])
/// );
/// assert_eq!(
///     collection.next().unwrap().map(|seq| seq.iter().collect::<Vec<_>>()).ok(),
///     Some(vec![4_u32])
/// );
/// assert_eq!(
///     collection.next().unwrap().map(|seq| seq.iter().collect::<Vec<_>>()).ok(),
///     Some(vec![5_u32, 6, 7])
/// );
///
/// // Must create a new collection to iterate again.
/// let collection = BinaryCollection::try_from(&buffer[..])?;
/// let elements: Result<Vec<_>, InvalidFormat> = collection
///     .map(|sequence| Ok(sequence?.iter().collect::<Vec<_>>()))
///     .collect();
/// assert_eq!(elements?, vec![vec![1_u32, 2, 3], vec![4], vec![5, 6, 7]]);
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct BinaryCollection<'a> {
    /// The bytes of the sequences that have not been consumed yet.
    bytes: &'a [u8],
}
76
77impl<'a> TryFrom<&'a [u8]> for BinaryCollection<'a> {
78    type Error = InvalidFormat;
79    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
80        if bytes.len() % std::mem::size_of::<u32>() == 0 {
81            Ok(Self { bytes })
82        } else {
83            Err(InvalidFormat::new(
84                "The byte-length of the collection is not divisible by the element size (4)",
85            ))
86        }
87    }
88}
89
90fn get_from(bytes: &[u8]) -> Result<BinarySequence<'_>, InvalidFormat> {
91    let length_bytes = bytes
92        .get(..ELEMENT_SIZE)
93        .ok_or_else(InvalidFormat::default)?;
94    let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize;
95    let bytes = bytes
96        .get(ELEMENT_SIZE..(ELEMENT_SIZE * (length + 1)))
97        .ok_or_else(InvalidFormat::default)?;
98    Ok(BinarySequence { bytes, length })
99}
100
101fn get_next<'a>(
102    collection: &mut BinaryCollection<'a>,
103) -> Result<BinarySequence<'a>, InvalidFormat> {
104    let sequence = get_from(collection.bytes)?;
105    collection.bytes = &collection.bytes[ELEMENT_SIZE * (sequence.len() + 1)..];
106    Ok(sequence)
107}
108
109impl<'a> Iterator for BinaryCollection<'a> {
110    type Item = Result<BinarySequence<'a>, InvalidFormat>;
111
112    fn next(&mut self) -> Option<Self::Item> {
113        if self.bytes.is_empty() {
114            None
115        } else {
116            Some(get_next(self))
117        }
118    }
119}
120
121/// A version of [`BinaryCollection`] with random access to sequences.
122///
123/// Because the binary format underlying [`BinaryCollection`] does not
124/// support random access, implementing it requires precomputing memory
125/// offsets for the sequences, and storing them in the struct.
126/// This means [`RandomAccessBinaryCollection::try_from`] will have to
127/// perform one full pass through the entire collection to collect the
128/// offsets. Thus, use this class only if you need the random access
129/// functionality.
130///
131/// Note that the because offsets are stored within the struct, it is
132/// not `Copy` as opposed to [`BinaryCollection`], which is simply a view
133/// over a memory buffer.
134///
135/// # Examples
136///
137/// ```
138/// # use ciff::{encode_u32_sequence, RandomAccessBinaryCollection, InvalidFormat};
139/// # use std::convert::TryFrom;
140/// # fn main() -> Result<(), anyhow::Error> {
141/// let mut buffer: Vec<u8> = Vec::new();
142/// encode_u32_sequence(&mut buffer, 3, &[1, 2, 3])?;
143/// encode_u32_sequence(&mut buffer, 1, &[4])?;
144/// encode_u32_sequence(&mut buffer, 3, &[5, 6, 7])?;
145///
146/// let mut collection = RandomAccessBinaryCollection::try_from(&buffer[..])?;
147/// assert_eq!(
148///     collection.get(0).map(|seq| seq.iter().collect::<Vec<_>>()),
149///     Some(vec![1_u32, 2, 3]),
150/// );
151/// assert_eq!(
152///     collection.at(2).iter().collect::<Vec<_>>(),
153///     vec![5_u32, 6, 7],
154/// );
155/// assert_eq!(collection.get(3), None);
156/// # Ok(())
157/// # }
158/// ```
159///
160/// ```should_panic
161/// # use ciff::{encode_u32_sequence, RandomAccessBinaryCollection, InvalidFormat};
162/// # use std::convert::TryFrom;
163/// # fn main() -> Result<(), anyhow::Error> {
164/// # let mut buffer: Vec<u8> = Vec::new();
165/// # encode_u32_sequence(&mut buffer, 3, &[1, 2, 3])?;
166/// # encode_u32_sequence(&mut buffer, 1, &[4])?;
167/// # encode_u32_sequence(&mut buffer, 3, &[5, 6, 7])?;
168/// # let mut collection = RandomAccessBinaryCollection::try_from(&buffer[..])?;
169/// collection.at(3); // out of bounds
170/// # Ok(())
171/// # }
172/// ```
173#[derive(Debug, Clone)]
174pub struct RandomAccessBinaryCollection<'a> {
175    inner: BinaryCollection<'a>,
176    offsets: Vec<usize>,
177}
178
179impl<'a> TryFrom<&'a [u8]> for RandomAccessBinaryCollection<'a> {
180    type Error = InvalidFormat;
181    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
182        let collection = BinaryCollection::try_from(bytes)?;
183        let offsets = collection
184            .map(|sequence| sequence.map(|s| s.len()))
185            .scan(0, |offset, len| {
186                Some(len.map(|len| {
187                    let result = *offset;
188                    *offset += ELEMENT_SIZE * (len + 1);
189                    result
190                }))
191            })
192            .collect::<Result<Vec<_>, _>>()?;
193
194        Ok(Self {
195            inner: collection,
196            offsets,
197        })
198    }
199}
200
201impl<'a> RandomAccessBinaryCollection<'a> {
202    /// Returns an iterator over sequences.
203    pub fn iter(&self) -> impl Iterator<Item = Result<BinarySequence<'a>, InvalidFormat>> {
204        self.inner
205    }
206
207    /// Returns the sequence at the given index.
208    ///
209    /// # Panics
210    ///
211    /// Panics if the index is out of bounds.
212    #[must_use]
213    pub fn at(&self, index: usize) -> BinarySequence<'a> {
214        if let Some(sequence) = self.get(index) {
215            sequence
216        } else {
217            panic!(
218                "out of bounds: requested {} out of {} elements",
219                index,
220                self.len()
221            );
222        }
223    }
224
225    /// Returns the sequence at the given index or `None` if out of bounds.
226    #[must_use]
227    pub fn get(&self, index: usize) -> Option<BinarySequence<'a>> {
228        let byte_offset = *self.offsets.get(index)?;
229        if let Ok(sequence) = get_from(self.inner.bytes.get(byte_offset..)?) {
230            Some(sequence)
231        } else {
232            // The following case should be unreachable, because when constructing
233            // the collection, we iterate through all sequences. Though there still
234            // can be an error when iterating the sequence elements, the sequence
235            // itself must be Ok.
236            unreachable!()
237        }
238    }
239
240    /// Returns the number of sequences in the collection.
241    #[must_use]
242    pub fn len(&self) -> usize {
243        self.offsets.len()
244    }
245
246    /// Checks if the collection is empty.
247    #[must_use]
248    pub fn is_empty(&self) -> bool {
249        self.offsets.len() == 0
250    }
251}
252
/// A borrowed view over the elements of one binary sequence.
///
/// The view holds the element bytes only (4 bytes per element) together with
/// the number of elements.
///
/// # Examples
///
/// ```
/// # use ciff::BinarySequence;
/// # use std::convert::TryFrom;
/// # fn main() -> Result<(), ()> {
/// let bytes: [u8; 16] = [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0];
/// let sequence = BinarySequence::try_from(&bytes[..])?;
/// assert_eq!(sequence.len(), 4);
/// assert_eq!(sequence.get(0), Some(1));
/// assert_eq!(sequence.get(1), Some(2));
/// assert_eq!(sequence.get(2), Some(3));
/// assert_eq!(sequence.get(3), Some(4));
/// let elements: Vec<_> = sequence.iter().collect();
/// assert_eq!(elements, vec![1_u32, 2, 3, 4]);
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BinarySequence<'a> {
    /// The element bytes; the 4-byte length prefix is **not** part of this slice.
    bytes: &'a [u8],
    /// Number of elements in the sequence.
    length: usize,
}
280
281impl<'a> TryFrom<&'a [u8]> for BinarySequence<'a> {
282    type Error = ();
283    /// Tries to construct a binary sequence from a slice of bytes.
284    ///
285    /// # Errors
286    ///
287    /// It will fail if the length of the slice is not divisible by 4.
288    ///
289    /// # Examples
290    ///
291    /// ```
292    /// # use ciff::BinarySequence;
293    /// # use std::convert::TryFrom;
294    /// # fn main() -> Result<(), ()> {
295    /// let bytes: [u8; 8] = [1, 0, 0, 0, 2, 0, 0, 0];
296    /// let sequence = BinarySequence::try_from(&bytes[..])?;
297    /// assert_eq!(sequence.len(), 2);
298    /// # Ok(())
299    /// # }
300    /// ```
301    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
302        if bytes.len() % std::mem::size_of::<u32>() == 0 {
303            let length = bytes.len() / std::mem::size_of::<u32>();
304            Ok(Self { bytes, length })
305        } else {
306            Err(())
307        }
308    }
309}
310
/// Decodes a little-endian `u32` from a 4-byte slice.
///
/// # Safety
///
/// The length of `bytes` must be 4. The body no longer performs any unsafe
/// operation; the function stays `unsafe` only so that existing call sites, which
/// wrap the call in an `unsafe` block, keep compiling unchanged.
///
/// # Panics
///
/// Panics if the length of `bytes` is not 4. Violating the contract is therefore
/// a deterministic panic rather than undefined behavior.
unsafe fn bytes_to_u32(bytes: &[u8]) -> u32 {
    let array: [u8; 4] = bytes
        .try_into()
        .expect("bytes_to_u32 requires a slice of exactly 4 bytes");
    u32::from_le_bytes(array)
}
321
322impl<'a> BinarySequence<'a> {
323    /// Returns the number of elements in the sequence.
324    #[must_use]
325    pub fn len(&self) -> usize {
326        self.length
327    }
328
329    /// Checks if the sequence is empty.
330    #[must_use]
331    pub fn is_empty(&self) -> bool {
332        self.length == 0
333    }
334
335    /// Returns `index`-th element of the sequence or `None` if `index` is out of bounds.
336    #[must_use]
337    pub fn get(&self, index: usize) -> Option<u32> {
338        if index < self.len() {
339            let offset = index * std::mem::size_of::<u32>();
340            self.bytes.get(offset..offset + 4).map(|bytes| {
341                // SAFETY: it is safe because if `get` returns `Some`, the slice must be of length 4.
342                unsafe { bytes_to_u32(bytes) }
343            })
344        } else {
345            None
346        }
347    }
348
349    /// An iterator over all sequence elements.
350    #[must_use]
351    pub fn iter(&'a self) -> BinarySequenceIterator<'a> {
352        BinarySequenceIterator {
353            sequence: self,
354            index: 0,
355        }
356    }
357
358    /// Returns the byte slice of the sequence. This **does not** include the length.
359    #[must_use]
360    pub fn bytes(&'a self) -> &'a [u8] {
361        self.bytes
362    }
363}
364
365pub struct BinarySequenceIterator<'a> {
366    sequence: &'a BinarySequence<'a>,
367    index: usize,
368}
369
370impl<'a> Iterator for BinarySequenceIterator<'a> {
371    type Item = u32;
372
373    fn next(&mut self) -> Option<Self::Item> {
374        let index = self.index;
375        self.index += 1;
376        self.sequence.get(index)
377    }
378}
379
380/// Reorders a collection according to the given order.
381///
382/// The new collection will be written to `output`, such that a sequence at position `i`
383/// in `collection` will be at position `order[i]` in the new collection.
384pub fn reorder<W: Write>(
385    collection: &RandomAccessBinaryCollection<'_>,
386    order: &[usize],
387    output: &mut W,
388) -> io::Result<()> {
389    for &pos in order {
390        let sequence = collection.at(pos);
391        let length = sequence.len() as u32;
392        output.write_all(&length.to_le_bytes())?;
393        output.write_all(sequence.bytes)?;
394    }
395    output.flush()?;
396    Ok(())
397}
398
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck_macros::quickcheck;

    /// A collection of ten sequences used as a fixture by several tests below.
    const COLLECTION_BYTES: [u8; 100] = [
        1, 0, 0, 0, 3, 0, 0, 0, // Number of documents
        1, 0, 0, 0, 0, 0, 0, 0, // t0
        1, 0, 0, 0, 0, 0, 0, 0, // t1
        1, 0, 0, 0, 0, 0, 0, 0, // t2
        1, 0, 0, 0, 0, 0, 0, 0, // t3
        1, 0, 0, 0, 2, 0, 0, 0, // t4
        3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, // t5
        2, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, // t6
        3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, // t7
        1, 0, 0, 0, 1, 0, 0, 0, // t8
    ];

    /// Elements written as little-endian `u32`s are read back by `get`.
    #[test]
    fn test_binary_sequence() {
        let bytes: Vec<u8> = (0_u32..10).flat_map(|i| i.to_le_bytes().to_vec()).collect();
        let sequence = BinarySequence::try_from(bytes.as_ref()).unwrap();
        assert!(!sequence.is_empty());
        for n in 0..10 {
            assert_eq!(sequence.get(n).unwrap(), n as u32);
        }
    }

    /// `get` must never panic, whatever the bytes and whatever the index.
    #[allow(clippy::needless_pass_by_value)]
    #[quickcheck]
    fn binary_sequence_get_never_crashes(bytes: Vec<u8>, indices: Vec<usize>) {
        let sequence = BinarySequence {
            bytes: &bytes,
            length: bytes.len() / 4,
        };
        for idx in indices {
            let _ = sequence.get(idx);
        }
    }

    /// Iterating the fixture yields every sequence with its length and elements.
    #[test]
    fn test_binary_collection() {
        let coll = BinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        let sequences = coll
            .map(|sequence| {
                sequence.map(|sequence| (sequence.len(), sequence.iter().collect::<Vec<_>>()))
            })
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(
            sequences,
            vec![
                (1, vec![3]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![2]),
                (3, vec![0, 1, 2]),
                (2, vec![1, 2]),
                (3, vec![0, 1, 2]),
                (1, vec![1]),
            ]
        );
    }

    /// A buffer whose length is not a multiple of 4 is rejected up front.
    #[test]
    fn test_binary_collection_invalid_format() {
        let input: Vec<u8> = vec![1, 0, 0, 0, 3, 0, 0, 0, 1];
        let coll = BinaryCollection::try_from(input.as_ref());
        assert_eq!(
            coll.err(),
            Some(InvalidFormat::new(
                "The byte-length of the collection is not divisible by the element size (4)"
            ))
        );
    }

    /// Sequential iteration, recorded offsets and indexed access all agree on the fixture.
    #[test]
    fn test_random_access_binary_collection() {
        let coll = RandomAccessBinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        assert!(!coll.is_empty());
        let sequences = coll
            .iter()
            .map(|sequence| {
                sequence.map(|sequence| (sequence.len(), sequence.iter().collect::<Vec<_>>()))
            })
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(
            sequences,
            vec![
                (1, vec![3]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![2]),
                (3, vec![0, 1, 2]),
                (2, vec![1, 2]),
                (3, vec![0, 1, 2]),
                (1, vec![1]),
            ]
        );
        assert_eq!(coll.offsets, vec![0, 8, 16, 24, 32, 40, 48, 64, 76, 92]);
        assert_eq!(coll.len(), 10);
        assert_eq!(
            (0..coll.len())
                .map(|idx| coll.at(idx).iter().collect())
                .collect::<Vec<Vec<u32>>>(),
            vec![
                vec![3],
                vec![0],
                vec![0],
                vec![0],
                vec![0],
                vec![2],
                vec![0, 1, 2],
                vec![1, 2],
                vec![0, 1, 2],
                vec![1],
            ]
        );
    }

    /// `at` panics when asked for an index one past the last sequence.
    #[test]
    #[should_panic]
    fn test_random_access_binary_collection_out_of_bounds() {
        let coll = RandomAccessBinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        let _ = coll.at(10);
    }

    /// Output position `i` of `reorder` holds the input sequence at `order[i]`.
    #[test]
    fn test_reorder_collection() {
        let coll = RandomAccessBinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        let order = vec![0, 1, 4, 9, 5, 6, 7, 2, 3, 8];
        let mut output = Vec::<u8>::new();
        reorder(&coll, &order, &mut output).unwrap();
        let reordered = BinaryCollection::try_from(output.as_ref()).unwrap();
        let sequences = reordered
            .map(|sequence| {
                sequence.map(|sequence| (sequence.len(), sequence.iter().collect::<Vec<_>>()))
            })
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(
            sequences,
            vec![
                (1, vec![3]),       // 0
                (1, vec![0]),       // 1
                (1, vec![0]),       // 4
                (1, vec![1]),       // 9
                (1, vec![2]),       // 5
                (3, vec![0, 1, 2]), // 6
                (2, vec![1, 2]),    // 7
                (1, vec![0]),       // 2
                (1, vec![0]),       // 3
                (3, vec![0, 1, 2]), // 8
            ]
        );
    }
}