1use std::convert::TryFrom;
2use std::convert::TryInto;
3use std::error::Error;
4use std::fmt;
5use std::io::{self, Write};
6
/// Size in bytes of one collection element (a `u32`); length prefixes use the same width.
const ELEMENT_SIZE: usize = std::mem::size_of::<u32>();
8
/// Error reported when a byte buffer does not follow the binary collection layout.
///
/// The wrapped value is an optional detail message; the [`Default`] value carries none.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct InvalidFormat(Option<String>);

impl InvalidFormat {
    /// Creates an error that carries `msg` as its detail message.
    pub fn new<S>(msg: S) -> Self
    where
        S: Into<String>,
    {
        let detail: String = msg.into();
        Self(Some(detail))
    }
}

impl Error for InvalidFormat {}

impl fmt::Display for InvalidFormat {
    /// Writes the generic description, followed by `": <detail>"` when a detail
    /// message is present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid binary collection format")?;
        match &self.0 {
            Some(detail) => write!(f, ": {}", detail),
            None => Ok(()),
        }
    }
}
31
32#[derive(Debug, Clone, Copy)]
73pub struct BinaryCollection<'a> {
74 bytes: &'a [u8],
75}
76
77impl<'a> TryFrom<&'a [u8]> for BinaryCollection<'a> {
78 type Error = InvalidFormat;
79 fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
80 if bytes.len() % std::mem::size_of::<u32>() == 0 {
81 Ok(Self { bytes })
82 } else {
83 Err(InvalidFormat::new(
84 "The byte-length of the collection is not divisible by the element size (4)",
85 ))
86 }
87 }
88}
89
90fn get_from(bytes: &[u8]) -> Result<BinarySequence<'_>, InvalidFormat> {
91 let length_bytes = bytes
92 .get(..ELEMENT_SIZE)
93 .ok_or_else(InvalidFormat::default)?;
94 let length = u32::from_le_bytes(length_bytes.try_into().unwrap()) as usize;
95 let bytes = bytes
96 .get(ELEMENT_SIZE..(ELEMENT_SIZE * (length + 1)))
97 .ok_or_else(InvalidFormat::default)?;
98 Ok(BinarySequence { bytes, length })
99}
100
101fn get_next<'a>(
102 collection: &mut BinaryCollection<'a>,
103) -> Result<BinarySequence<'a>, InvalidFormat> {
104 let sequence = get_from(collection.bytes)?;
105 collection.bytes = &collection.bytes[ELEMENT_SIZE * (sequence.len() + 1)..];
106 Ok(sequence)
107}
108
109impl<'a> Iterator for BinaryCollection<'a> {
110 type Item = Result<BinarySequence<'a>, InvalidFormat>;
111
112 fn next(&mut self) -> Option<Self::Item> {
113 if self.bytes.is_empty() {
114 None
115 } else {
116 Some(get_next(self))
117 }
118 }
119}
120
121#[derive(Debug, Clone)]
174pub struct RandomAccessBinaryCollection<'a> {
175 inner: BinaryCollection<'a>,
176 offsets: Vec<usize>,
177}
178
179impl<'a> TryFrom<&'a [u8]> for RandomAccessBinaryCollection<'a> {
180 type Error = InvalidFormat;
181 fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
182 let collection = BinaryCollection::try_from(bytes)?;
183 let offsets = collection
184 .map(|sequence| sequence.map(|s| s.len()))
185 .scan(0, |offset, len| {
186 Some(len.map(|len| {
187 let result = *offset;
188 *offset += ELEMENT_SIZE * (len + 1);
189 result
190 }))
191 })
192 .collect::<Result<Vec<_>, _>>()?;
193
194 Ok(Self {
195 inner: collection,
196 offsets,
197 })
198 }
199}
200
201impl<'a> RandomAccessBinaryCollection<'a> {
202 pub fn iter(&self) -> impl Iterator<Item = Result<BinarySequence<'a>, InvalidFormat>> {
204 self.inner
205 }
206
207 #[must_use]
213 pub fn at(&self, index: usize) -> BinarySequence<'a> {
214 if let Some(sequence) = self.get(index) {
215 sequence
216 } else {
217 panic!(
218 "out of bounds: requested {} out of {} elements",
219 index,
220 self.len()
221 );
222 }
223 }
224
225 #[must_use]
227 pub fn get(&self, index: usize) -> Option<BinarySequence<'a>> {
228 let byte_offset = *self.offsets.get(index)?;
229 if let Ok(sequence) = get_from(self.inner.bytes.get(byte_offset..)?) {
230 Some(sequence)
231 } else {
232 unreachable!()
237 }
238 }
239
240 #[must_use]
242 pub fn len(&self) -> usize {
243 self.offsets.len()
244 }
245
246 #[must_use]
248 pub fn is_empty(&self) -> bool {
249 self.offsets.len() == 0
250 }
251}
252
/// A borrowed run of bytes interpreted as a sequence of 32-bit elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BinarySequence<'a> {
    /// Raw element bytes, without any length prefix.
    bytes: &'a [u8],
    /// Number of elements in the sequence.
    length: usize,
}

impl<'a> TryFrom<&'a [u8]> for BinarySequence<'a> {
    type Error = ();

    /// Interprets all of `bytes` as elements; the element count is derived from the
    /// byte length.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the byte length is not a multiple of the size of a `u32`.
    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
        let width = std::mem::size_of::<u32>();
        match bytes.len() % width {
            0 => Ok(Self {
                bytes,
                length: bytes.len() / width,
            }),
            _ => Err(()),
        }
    }
}
310
/// Decodes the first four bytes of `bytes` as a little-endian `u32`.
///
/// # Safety
///
/// The body performs no unsafe operation and places no requirement on the caller.
/// The function remains `unsafe` only so that existing call sites, which wrap the
/// call in an `unsafe` block, keep compiling unchanged.
///
/// # Panics
///
/// Panics when `bytes` holds fewer than four bytes (previously this was an
/// out-of-bounds read).
unsafe fn bytes_to_u32(bytes: &[u8]) -> u32 {
    let mut value = [0_u8; 4];
    value.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(value)
}
321
322impl<'a> BinarySequence<'a> {
323 #[must_use]
325 pub fn len(&self) -> usize {
326 self.length
327 }
328
329 #[must_use]
331 pub fn is_empty(&self) -> bool {
332 self.length == 0
333 }
334
335 #[must_use]
337 pub fn get(&self, index: usize) -> Option<u32> {
338 if index < self.len() {
339 let offset = index * std::mem::size_of::<u32>();
340 self.bytes.get(offset..offset + 4).map(|bytes| {
341 unsafe { bytes_to_u32(bytes) }
343 })
344 } else {
345 None
346 }
347 }
348
349 #[must_use]
351 pub fn iter(&'a self) -> BinarySequenceIterator<'a> {
352 BinarySequenceIterator {
353 sequence: self,
354 index: 0,
355 }
356 }
357
358 #[must_use]
360 pub fn bytes(&'a self) -> &'a [u8] {
361 self.bytes
362 }
363}
364
365pub struct BinarySequenceIterator<'a> {
366 sequence: &'a BinarySequence<'a>,
367 index: usize,
368}
369
370impl<'a> Iterator for BinarySequenceIterator<'a> {
371 type Item = u32;
372
373 fn next(&mut self) -> Option<Self::Item> {
374 let index = self.index;
375 self.index += 1;
376 self.sequence.get(index)
377 }
378}
379
380pub fn reorder<W: Write>(
385 collection: &RandomAccessBinaryCollection<'_>,
386 order: &[usize],
387 output: &mut W,
388) -> io::Result<()> {
389 for &pos in order {
390 let sequence = collection.at(pos);
391 let length = sequence.len() as u32;
392 output.write_all(&length.to_le_bytes())?;
393 output.write_all(sequence.bytes)?;
394 }
395 output.flush()?;
396 Ok(())
397}
398
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck_macros::quickcheck;

    // Fixture: ten sequences, one per row. Each row is a little-endian `u32` length
    // prefix followed by that many little-endian `u32` elements (shown on the right).
    const COLLECTION_BYTES: [u8; 100] = [
        1, 0, 0, 0, 3, 0, 0, 0, // [3]
        1, 0, 0, 0, 0, 0, 0, 0, // [0]
        1, 0, 0, 0, 0, 0, 0, 0, // [0]
        1, 0, 0, 0, 0, 0, 0, 0, // [0]
        1, 0, 0, 0, 0, 0, 0, 0, // [0]
        1, 0, 0, 0, 2, 0, 0, 0, // [2]
        3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, // [0, 1, 2]
        2, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, // [1, 2]
        3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, // [0, 1, 2]
        1, 0, 0, 0, 1, 0, 0, 0, // [1]
    ];

    // A sequence built from the bytes of 0..10 returns each value at its own index.
    #[test]
    fn test_binary_sequence() {
        let bytes: Vec<u8> = (0_u32..10).flat_map(|i| i.to_le_bytes().to_vec()).collect();
        let sequence = BinarySequence::try_from(bytes.as_ref()).unwrap();
        assert!(!sequence.is_empty());
        for n in 0..10 {
            assert_eq!(sequence.get(n).unwrap(), n as u32);
        }
    }

    // Property test: `get` must not panic for arbitrary bytes and arbitrary indices.
    #[allow(clippy::needless_pass_by_value)]
    #[quickcheck]
    fn biniary_sequence_get_never_crashes(bytes: Vec<u8>, indices: Vec<usize>) {
        let sequence = BinarySequence {
            bytes: &bytes,
            length: bytes.len() / 4,
        };
        for idx in indices {
            let _ = sequence.get(idx);
        }
    }

    // Iterating the fixture yields every sequence with its length and elements.
    #[test]
    fn test_binary_collection() {
        let coll = BinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        let sequences = coll
            .map(|sequence| {
                sequence.map(|sequence| (sequence.len(), sequence.iter().collect::<Vec<_>>()))
            })
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(
            sequences,
            vec![
                (1, vec![3]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![2]),
                (3, vec![0, 1, 2]),
                (2, vec![1, 2]),
                (3, vec![0, 1, 2]),
                (1, vec![1]),
            ]
        );
    }

    // A buffer whose length is not a multiple of four is rejected up front.
    #[test]
    fn test_binary_collection_invalid_format() {
        let input: Vec<u8> = vec![1, 0, 0, 0, 3, 0, 0, 0, 1];
        let coll = BinaryCollection::try_from(input.as_ref());
        assert_eq!(
            coll.err(),
            Some(InvalidFormat::new(
                "The byte-length of the collection is not divisible by the element size (4)"
            ))
        );
    }

    // The random-access view iterates like the plain collection, records the byte
    // offset of every sequence, and returns the same sequences through `at`.
    #[test]
    fn test_random_access_binary_collection() {
        let coll = RandomAccessBinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        assert!(!coll.is_empty());
        let sequences = coll
            .iter()
            .map(|sequence| {
                sequence.map(|sequence| (sequence.len(), sequence.iter().collect::<Vec<_>>()))
            })
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(
            sequences,
            vec![
                (1, vec![3]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![2]),
                (3, vec![0, 1, 2]),
                (2, vec![1, 2]),
                (3, vec![0, 1, 2]),
                (1, vec![1]),
            ]
        );
        assert_eq!(coll.offsets, vec![0, 8, 16, 24, 32, 40, 48, 64, 76, 92]);
        assert_eq!(coll.len(), 10);
        assert_eq!(
            (0..coll.len())
                .map(|idx| coll.at(idx).iter().collect())
                .collect::<Vec<Vec<u32>>>(),
            vec![
                vec![3],
                vec![0],
                vec![0],
                vec![0],
                vec![0],
                vec![2],
                vec![0, 1, 2],
                vec![1, 2],
                vec![0, 1, 2],
                vec![1],
            ]
        );
    }

    // `at` panics for an index equal to the number of sequences.
    #[test]
    #[should_panic]
    fn test_random_access_binary_collection_out_of_bounds() {
        let coll = RandomAccessBinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        let _ = coll.at(10);
    }

    // Reordering writes a valid collection whose sequences follow the given order.
    #[test]
    fn test_reorder_collection() {
        let coll = RandomAccessBinaryCollection::try_from(COLLECTION_BYTES.as_ref()).unwrap();
        let order = vec![0, 1, 4, 9, 5, 6, 7, 2, 3, 8];
        let mut output = Vec::<u8>::new();
        reorder(&coll, &order, &mut output).unwrap();
        println!("{:?}", output);
        let reordered = BinaryCollection::try_from(output.as_ref()).unwrap();
        let sequences = reordered
            .map(|sequence| {
                sequence.map(|sequence| (sequence.len(), sequence.iter().collect::<Vec<_>>()))
            })
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(
            sequences,
            vec![
                (1, vec![3]),
                (1, vec![0]),
                (1, vec![0]),
                (1, vec![1]),
                (1, vec![2]),
                (3, vec![0, 1, 2]),
                (2, vec![1, 2]),
                (1, vec![0]),
                (1, vec![0]),
                (3, vec![0, 1, 2]),
            ]
        );
    }
}