idx_parser 0.3.0

Parses IDX files, such as the ones used by the MNIST database.
Documentation
#![feature(array_chunks)]

pub mod matrix;
use matrix::*;
pub mod raw_data;
use raw_data::*;

/// The parsed contents of an IDX file: its shape plus its elements nested by dimension.
#[derive(Clone, Debug, PartialEq)]
pub struct IDXFile {
    /// Length of each dimension, outermost first. For example, `[10, 20, 100]` means
    /// 10 entries along the first dimension, 20 along the second and 100 along the third.
    pub dimensions: Vec<u32>,
    /// Element data as nested rows, one nesting level per dimension.
    pub matrix_data: matrix::Row,
}
impl IDXFile {
    /// parses bytes from an IDX file into a IDXFile
    /// # Example
    /// ```
    /// use std::fs::File;
    /// use std::io::prelude::*;
    /// use crate::idx_parser::IDXFile;
    ///
    /// // byte vec
    /// let mut buf: Vec<u8> = vec![];
    /// // your IDX file
    /// let mut file = File::open("tests/u8_dim3_test").unwrap();
    /// // get bytes from IDX file
    /// file.read_to_end(&mut buf).unwrap();
    /// let my_idx_file = IDXFile::from_bytes(buf).unwrap();
    /// assert_eq!(my_idx_file.dimensions, [1, 2,3])
    ///
    /// ```
    pub fn from_bytes(bytes: Vec<u8>) -> Result<IDXFile, ParseError> {
        let mut iter = bytes.iter();

        if (*iter.next().unwrap(), *iter.next().unwrap()) != (0, 0) {
            return Err(ParseError::FirstTwoBytesNotZero(bytes[0], bytes[1]));
        };

        let data_type = match iter.next().unwrap() {
            0x08 => DataType::UnsignedByte,
            0x09 => DataType::SignedByte,
            0x0b => DataType::Short,
            0x0c => DataType::Int,
            0x0d => DataType::Float,
            0x0e => DataType::Double,
            _ => return Err(ParseError::InvalidDataType(bytes[2])),
        };

        let dim_len = *iter.next().unwrap() as usize;
        if dim_len <= 0 {
            return Err(ParseError::ZeroDim);
        }

        let dimentions: Vec<u32> = (0..dim_len)
            .into_iter()
            .map(|_| {
                u32::from_be_bytes([
                    *iter.next().unwrap(),
                    *iter.next().unwrap(),
                    *iter.next().unwrap(),
                    *iter.next().unwrap(),
                ])
            })
            .collect();

        let raw_data_bytes: Vec<u8> = iter.map(|a| *a).collect();
        let raw_data: Vec<Matrix> = match data_type {
            DataType::UnsignedByte => raw_data_bytes
                .iter()
                .map(|b| Matrix::Data(RawData::UnsignedByte(*b)))
                .collect(),
            DataType::SignedByte => raw_data_bytes
                .into_iter()
                .map(|b| Matrix::Data(RawData::SignedByte(b as i8)))
                .collect(),
            DataType::Short => raw_data_bytes
                .array_chunks::<2>()
                .into_iter()
                .map(|c| Matrix::Data(RawData::Short(i16::from_le_bytes(*c))))
                .collect(),
            DataType::Int => raw_data_bytes
                .array_chunks::<4>()
                .into_iter()
                .map(|c| Matrix::Data(RawData::Int(i32::from_le_bytes(*c))))
                .collect(),
            DataType::Float => raw_data_bytes
                .array_chunks::<4>()
                .into_iter()
                .map(|c| Matrix::Data(RawData::Float(f32::from_le_bytes(*c))))
                .collect(),
            DataType::Double => raw_data_bytes
                .array_chunks::<8>()
                .into_iter()
                .map(|c| Matrix::Data(RawData::Double(f64::from_le_bytes(*c))))
                .collect(),
        };

        let mut dim_iter = dimentions.iter().rev().peekable();
        let mut matrix_data: Matrix =
            Matrix::Row(raw_data.iter().map(|d| Box::new(d.clone())).collect());

        while let Some(&dim) = dim_iter.next() {
            if let Matrix::Row(r) = matrix_data {
                matrix_data = Matrix::Row(
                    r.chunks_exact(dim as usize)
                        .map(|c| Box::new(Matrix::Row(c.to_vec())))
                        .collect(),
                );
            }
        }

        if let Matrix::Row(r) = matrix_data {
            let m: Matrix = *r[0].clone();
            if let Matrix::Row(matrix_data) = m {
                return Ok(IDXFile {
                    matrix_data,
                    dimensions: dimentions,
                });
            } else {
                return Err(ParseError::ZeroDim);
            }
        } else {
            return Err(ParseError::ZeroDim);
        }
    }
}

/// Element type of an IDX file, selected by the third header byte.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DataType {
    /// Type code `0x08`: unsigned 8-bit integer (`u8`).
    UnsignedByte,
    /// Type code `0x09`: signed 8-bit integer (`i8`).
    SignedByte,
    /// Type code `0x0b`: 16-bit integer (`i16`).
    Short,
    /// Type code `0x0c`: 32-bit integer (`i32`).
    Int,
    /// Type code `0x0d`: 32-bit float (`f32`).
    Float,
    /// Type code `0x0e`: 64-bit float (`f64`).
    Double,
}

/// Reasons [`IDXFile::from_bytes`] can reject its input.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The first two bytes, carried here as found, were not both zero.
    FirstTwoBytesNotZero(u8, u8),
    /// A dimension-related byte was rejected.
    ///
    /// Not constructed by `IDXFile::from_bytes`; NOTE(review): confirm its intended use.
    ZeroDimByte(u8),
    /// The data-type byte, carried here, is not a known IDX type code.
    InvalidDataType(u8),
    /// The header declares zero dimensions, or the matrix could not be built from them.
    ZeroDim,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ParseError::FirstTwoBytesNotZero(first, second) => write!(
                f,
                "expected the first two bytes to be 0x00 0x00, found {:#04x} {:#04x}",
                first, second
            ),
            ParseError::ZeroDimByte(byte) => write!(f, "invalid dimension byte {:#04x}", byte),
            ParseError::InvalidDataType(code) => {
                write!(f, "unknown IDX data type code {:#04x}", code)
            }
            ParseError::ZeroDim => {
                f.write_str("IDX file has zero dimensions or its data does not fit them")
            }
        }
    }
}

/// Lets `ParseError` flow through `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for ParseError {}

#[cfg(test)]
mod test {
    use super::*;
    use std::convert::TryInto;
    use Matrix::*;
    use RawData::*;

    /// Loads a fixture file from disk; a missing fixture fails the test.
    fn read_fixture(path: &str) -> Vec<u8> {
        std::fs::read(path).unwrap()
    }

    /// A fixture starting with 0x01 0x00 must be rejected, reporting those two bytes.
    #[test]
    pub fn test_bad_magic_number() {
        let parsed = IDXFile::from_bytes(read_fixture("./tests/bad_magic_number_test"));
        assert_eq!(
            parsed.unwrap_err(),
            ParseError::FirstTwoBytesNotZero(0x01, 0x00)
        )
    }

    /// The bad-dimension fixture must be rejected with `ZeroDim`.
    #[test]
    pub fn test_bad_dim() {
        let parsed = IDXFile::from_bytes(read_fixture("./tests/bad_dim_test"));
        assert_eq!(parsed.unwrap_err(), ParseError::ZeroDim)
    }

    /// A 2x2x3 u8 file nests its first 12 elements; trailing payload bytes are ignored.
    #[test]
    pub fn test_u8_dim3() {
        // Header: magic 00 00, type 0x08 (u8), 3 dimensions, then sizes 2, 2, 3.
        let mut bytes: Vec<u8> = vec![
            0x00, 0x00, 0x08, 0x03, //
            0x00, 0x00, 0x00, 0x02, //
            0x00, 0x00, 0x00, 0x02, //
            0x00, 0x00, 0x00, 0x03,
        ];
        // Payload 0..=24: more elements than the declared shape needs.
        bytes.extend(0u8..=24);

        let idx = IDXFile::from_bytes(bytes).unwrap();
        assert_eq!(idx.dimensions, vec![2, 2, 3]);

        let leaf = |value: u8| -> Box<Matrix> { Data(UnsignedByte(value)).into() };
        let row = |values: [u8; 3]| -> Box<Matrix> {
            Row(values.iter().map(|&v| leaf(v)).collect()).into()
        };
        let expected: matrix::Row = vec![
            Row(vec![row([0, 1, 2]), row([3, 4, 5])]).into(),
            Row(vec![row([6, 7, 8]), row([9, 10, 11])]).into(),
        ];
        assert_eq!(idx.matrix_data, expected);

        let picked: u8 = idx.matrix_data[1][1][1].clone().try_into().unwrap();
        assert_eq!(picked, 10u8);
    }
}