nasbench/
tfrecord.rs

1use crate::Result;
2use byteorder::{ByteOrder as _, LittleEndian, ReadBytesExt as _};
3use crc;
4use std::io::Read;
5use trackable::error::Failed;
6
7/// See: https://www.tensorflow.org/tutorials/load_data/tf_records.
8#[derive(Debug)]
9pub struct TfRecord {
10    pub len: u64,
11    pub data: Vec<u8>,
12}
13impl TfRecord {
14    pub fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
15        let mut len_buf = [0; 8];
16        track_any_err!(reader.read_exact(&mut len_buf))?;
17        let len = LittleEndian::read_u64(&len_buf);
18        let len_crc = track_any_err!(reader.read_u32::<LittleEndian>())?;
19        track!(check_crc(&len_buf, len_crc))?;
20
21        let mut data = vec![0; len as usize];
22        track_any_err!(reader.read_exact(&mut data))?;
23        let data_crc = track_any_err!(reader.read_u32::<LittleEndian>())?;
24        track!(check_crc(&data, data_crc))?;
25        Ok(Self { len, data })
26    }
27}
28
/// Wrapper around a byte source `R` from which TFRecord entries are decoded.
///
/// Construction is free of side effects: nothing is read from the wrapped
/// reader until the stream is advanced.
#[derive(Debug)]
pub struct TfRecordStream<R> {
    /// The wrapped byte source.
    reader: R,
}
impl<R> TfRecordStream<R>
where
    R: Read,
{
    /// Creates a stream that takes ownership of `reader`.
    pub fn new(reader: R) -> Self {
        TfRecordStream { reader }
    }
}
38impl<R: Read> Iterator for TfRecordStream<R> {
39    type Item = Result<TfRecord>;
40
41    fn next(&mut self) -> Option<Self::Item> {
42        let mut peek = [0; 1];
43        match track_any_err!(self.reader.read(&mut peek)) {
44            Err(e) => Some(Err(e)),
45            Ok(0) => None,
46            Ok(_) => match track!(TfRecord::from_reader(peek.chain(&mut self.reader))) {
47                Err(e) => Some(Err(e)),
48                Ok(r) => Some(Ok(r)),
49            },
50        }
51    }
52}
53
54fn check_crc(bytes: &[u8], actual_crc: u32) -> Result<()> {
55    let expected_crc = crc::crc32::checksum_castagnoli(&bytes);
56    let expected_crc = (expected_crc.overflowing_shr(15).0 | expected_crc.overflowing_shl(17).0)
57        .overflowing_add(0xa282_ead8)
58        .0;
59    track_assert_eq!(actual_crc, expected_crc, Failed);
60    Ok(())
61}