1use crate::Result;
2use byteorder::{ByteOrder as _, LittleEndian, ReadBytesExt as _};
3use crc;
4use std::io::Read;
5use trackable::error::Failed;
6
7#[derive(Debug)]
9pub struct TfRecord {
10 pub len: u64,
11 pub data: Vec<u8>,
12}
13impl TfRecord {
14 pub fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
15 let mut len_buf = [0; 8];
16 track_any_err!(reader.read_exact(&mut len_buf))?;
17 let len = LittleEndian::read_u64(&len_buf);
18 let len_crc = track_any_err!(reader.read_u32::<LittleEndian>())?;
19 track!(check_crc(&len_buf, len_crc))?;
20
21 let mut data = vec![0; len as usize];
22 track_any_err!(reader.read_exact(&mut data))?;
23 let data_crc = track_any_err!(reader.read_u32::<LittleEndian>())?;
24 track!(check_crc(&data, data_crc))?;
25 Ok(Self { len, data })
26 }
27}
28
/// An iterator adapter that decodes consecutive `TfRecord`s from a byte source.
///
/// The wrapped `reader` is consumed sequentially; the `Iterator` implementation pulls
/// one record from it per call to `next`.
#[derive(Debug)]
pub struct TfRecordStream<R> {
    reader: R,
}

impl<R: Read> TfRecordStream<R> {
    /// Wraps `reader` so that records are decoded starting at its current position.
    pub fn new(reader: R) -> Self {
        TfRecordStream { reader }
    }
}
38impl<R: Read> Iterator for TfRecordStream<R> {
39 type Item = Result<TfRecord>;
40
41 fn next(&mut self) -> Option<Self::Item> {
42 let mut peek = [0; 1];
43 match track_any_err!(self.reader.read(&mut peek)) {
44 Err(e) => Some(Err(e)),
45 Ok(0) => None,
46 Ok(_) => match track!(TfRecord::from_reader(peek.chain(&mut self.reader))) {
47 Err(e) => Some(Err(e)),
48 Ok(r) => Some(Ok(r)),
49 },
50 }
51 }
52}
53
54fn check_crc(bytes: &[u8], actual_crc: u32) -> Result<()> {
55 let expected_crc = crc::crc32::checksum_castagnoli(&bytes);
56 let expected_crc = (expected_crc.overflowing_shr(15).0 | expected_crc.overflowing_shl(17).0)
57 .overflowing_add(0xa282_ead8)
58 .0;
59 track_assert_eq!(actual_crc, expected_crc, Failed);
60 Ok(())
61}