pub mod header;
use self::header::{
FormatHeaderError, Header, ParseHeaderError, ReadHeaderError, WriteHeaderError,
};
use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
use ndarray::prelude::*;
use ndarray::{Data, DataOwned, IntoDimension};
use py_literal::Value as PyValue;
use std::error::Error;
use std::fmt;
use std::io;
use std::mem;
/// Reads a `ReadNpyExt` value (typically an owned array) from the `.npy`
/// file at `path`.
///
/// The file is opened with `File::open` and handed directly to `T::read_npy`.
///
/// # Errors
///
/// Returns `ReadNpyError::Io` if the file cannot be opened; otherwise
/// propagates whatever `T::read_npy` reports.
pub fn read_npy<P, T>(path: P) -> Result<T, ReadNpyError>
where
    P: AsRef<std::path::Path>,
    T: ReadNpyExt,
{
    let file = std::fs::File::open(path)?;
    T::read_npy(file)
}
pub fn write_npy<P, T>(path: P, array: &T) -> Result<(), WriteNpyError>
where
P: AsRef<std::path::Path>,
T: WriteNpyExt,
{
array.write_npy(std::fs::File::create(path)?)
}
/// An error writing the element data of an `.npy` file.
#[derive(Debug)]
pub enum WriteDataError {
    /// An error from the underlying writer.
    Io(io::Error),
    /// The element data could not be formatted.
    FormatData(Box<dyn Error + Send + Sync + 'static>),
}

impl Error for WriteDataError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(err) => err,
            Self::FormatData(err) => &**err,
        };
        Some(cause)
    }
}

impl fmt::Display for WriteDataError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error: {}", err),
            Self::FormatData(err) => write!(f, "error formatting data: {}", err),
        }
    }
}

impl From<io::Error> for WriteDataError {
    /// Wraps any writer error as `WriteDataError::Io`.
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}
/// An array element type that can be written to the data section of an
/// `.npy` file.
///
/// # Safety
///
/// NOTE(review): nothing visible in this file states the contract that makes
/// this trait `unsafe`. Presumably implementors must guarantee that
/// `type_descriptor` accurately describes the bytes emitted by `write` /
/// `write_slice`, so that readers interpret the data correctly. Confirm and
/// document.
pub unsafe trait WritableElement: Sized {
/// Returns the `.npy` header type descriptor for this element type (the
/// implementations in this file return a `PyValue::String` such as `"<i4"`).
fn type_descriptor() -> PyValue;
/// Writes the bytes of a single element to `writer`.
fn write<W: io::Write>(&self, writer: W) -> Result<(), WriteDataError>;
/// Writes the bytes of every element of `slice`, in order, to `writer`.
fn write_slice<W: io::Write>(slice: &[Self], writer: W) -> Result<(), WriteDataError>;
}
/// An error writing an `.npy` file (header or data).
#[derive(Debug)]
pub enum WriteNpyError {
    /// An error from the underlying writer.
    Io(io::Error),
    /// The header could not be formatted.
    FormatHeader(FormatHeaderError),
    /// The element data could not be formatted.
    FormatData(Box<dyn Error + Send + Sync + 'static>),
}

impl Error for WriteNpyError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(err) => err,
            Self::FormatHeader(err) => err,
            Self::FormatData(err) => &**err,
        };
        Some(cause)
    }
}

impl fmt::Display for WriteNpyError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error: {}", err),
            Self::FormatHeader(err) => write!(f, "error formatting header: {}", err),
            Self::FormatData(err) => write!(f, "error formatting data: {}", err),
        }
    }
}

impl From<io::Error> for WriteNpyError {
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}

impl From<WriteHeaderError> for WriteNpyError {
    /// Flattens a header-writing error into the matching top-level variant.
    fn from(err: WriteHeaderError) -> Self {
        match err {
            WriteHeaderError::Io(io_err) => Self::Io(io_err),
            WriteHeaderError::Format(format_err) => Self::FormatHeader(format_err),
        }
    }
}

impl From<FormatHeaderError> for WriteNpyError {
    fn from(err: FormatHeaderError) -> Self {
        Self::FormatHeader(err)
    }
}

impl From<WriteDataError> for WriteNpyError {
    /// Flattens a data-writing error into the matching top-level variant.
    fn from(err: WriteDataError) -> Self {
        match err {
            WriteDataError::Io(io_err) => Self::Io(io_err),
            WriteDataError::FormatData(format_err) => Self::FormatData(format_err),
        }
    }
}
/// Extension trait for writing a value in `.npy` format to any `io::Write`.
pub trait WriteNpyExt {
/// Writes `self` to `writer` in `.npy` format. For `ArrayBase`, this means
/// the header followed by the element data.
fn write_npy<W: io::Write>(&self, writer: W) -> Result<(), WriteNpyError>;
}
impl<A, S, D> WriteNpyExt for ArrayBase<S, D>
where
    A: WritableElement,
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Writes the array as `.npy`: a header with the element type
    /// descriptor, memory order and shape, followed by the element bytes.
    ///
    /// Some arrays are contiguous in C order, or in Fortran order (detected
    /// by checking whether the axis-reversed view is in standard layout).
    /// These are written with one `write_slice` call over their memory, and
    /// the order is recorded in the header's `fortran_order` flag. Any other
    /// layout is written one element at a time, in the order yielded by
    /// `iter()`, with `fortran_order: false`.
    fn write_npy<W: io::Write>(&self, mut writer: W) -> Result<(), WriteNpyError> {
        // `Some(fortran_order)` when the elements form one contiguous block.
        let contiguous_order = if self.is_standard_layout() {
            Some(false)
        } else if self.view().reversed_axes().is_standard_layout() {
            Some(true)
        } else {
            None
        };

        let header = Header {
            type_descriptor: A::type_descriptor(),
            fortran_order: contiguous_order.unwrap_or(false),
            shape: self.shape().to_owned(),
        };
        header.write(&mut writer)?;

        match contiguous_order {
            Some(_) => {
                // Contiguity was just established, so a memory-order slice exists.
                let data = self.as_slice_memory_order().unwrap();
                A::write_slice(data, &mut writer)?;
            }
            None => {
                for elem in self.iter() {
                    elem.write(&mut writer)?;
                }
            }
        }
        Ok(())
    }
}
/// An error reading the data section (everything after the header) of an
/// `.npy` file.
#[derive(Debug)]
pub enum ReadDataError {
    /// An error from the underlying reader. The `From<io::Error>` conversion
    /// maps `UnexpectedEof` to `MissingData` instead.
    Io(io::Error),
    /// The header's type descriptor is not supported by the element type.
    WrongDescriptor(PyValue),
    /// EOF was reached before all elements were read.
    MissingData,
    /// This many bytes remained after all elements were read.
    ExtraBytes(usize),
    /// The bytes were read but could not be interpreted as elements.
    ParseData(Box<dyn Error + Send + Sync + 'static>),
}

impl Error for ReadDataError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::ParseData(err) => Some(&**err),
            Self::WrongDescriptor(_) | Self::MissingData | Self::ExtraBytes(_) => None,
        }
    }
}

impl fmt::Display for ReadDataError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error: {}", err),
            Self::WrongDescriptor(desc) => {
                write!(f, "incorrect descriptor ({}) for this type", desc)
            }
            Self::MissingData => f.write_str("reached EOF before reading all data"),
            Self::ExtraBytes(num_extra_bytes) => {
                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
            }
            Self::ParseData(err) => write!(f, "error parsing data: {}", err),
        }
    }
}

impl From<io::Error> for ReadDataError {
    /// Converts a reader error. Hitting EOF early becomes `MissingData`, and
    /// every other kind is kept as `Io`.
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => Self::MissingData,
            _ => Self::Io(err),
        }
    }
}
/// An array element type that can be read from the data section of an
/// `.npy` file.
pub trait ReadableElement: Sized {
/// Reads exactly `len` elements from `reader`, interpreting the bytes
/// according to the header's `type_desc`.
///
/// The implementations in this file return `WrongDescriptor` for an
/// unsupported descriptor, `MissingData` if the reader ends early, and
/// `ExtraBytes` if any bytes remain after the `len` elements.
fn read_to_end_exact_vec<R: io::Read>(
reader: R,
type_desc: &PyValue,
len: usize,
) -> Result<Vec<Self>, ReadDataError>;
}
/// An error reading an `.npy` file (header or data).
#[derive(Debug)]
pub enum ReadNpyError {
    /// An error from the underlying reader.
    Io(io::Error),
    /// The header could not be parsed.
    ParseHeader(ParseHeaderError),
    /// The element bytes could not be interpreted.
    ParseData(Box<dyn Error + Send + Sync + 'static>),
    /// The shape in the header describes too many elements.
    LengthOverflow,
    /// `(expected NDIM of the Dimension type, actual ndim in the header)`.
    WrongNdim(Option<usize>, usize),
    /// The header's type descriptor is not supported by the element type.
    WrongDescriptor(PyValue),
    /// EOF was reached before all elements were read.
    MissingData,
    /// This many bytes remained after all elements were read.
    ExtraBytes(usize),
}

impl Error for ReadNpyError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::ParseHeader(err) => Some(err),
            Self::ParseData(err) => Some(&**err),
            Self::LengthOverflow
            | Self::WrongNdim(_, _)
            | Self::WrongDescriptor(_)
            | Self::MissingData
            | Self::ExtraBytes(_) => None,
        }
    }
}

impl fmt::Display for ReadNpyError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error: {}", err),
            Self::ParseHeader(err) => write!(f, "error parsing header: {}", err),
            Self::ParseData(err) => write!(f, "error parsing data: {}", err),
            Self::LengthOverflow => f.write_str("overflow computing length from shape"),
            Self::WrongNdim(expected, actual) => write!(
                f,
                "ndim {} of array did not match Dimension type with NDIM = {:?}",
                actual, expected
            ),
            Self::WrongDescriptor(desc) => {
                write!(f, "incorrect descriptor ({}) for this type", desc)
            }
            Self::MissingData => f.write_str("reached EOF before reading all data"),
            Self::ExtraBytes(num_extra_bytes) => {
                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
            }
        }
    }
}

impl From<io::Error> for ReadNpyError {
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}

impl From<ReadHeaderError> for ReadNpyError {
    /// Flattens a header-reading error into the matching top-level variant.
    fn from(err: ReadHeaderError) -> Self {
        match err {
            ReadHeaderError::Io(io_err) => Self::Io(io_err),
            ReadHeaderError::Parse(parse_err) => Self::ParseHeader(parse_err),
        }
    }
}

impl From<ParseHeaderError> for ReadNpyError {
    fn from(err: ParseHeaderError) -> Self {
        Self::ParseHeader(err)
    }
}

impl From<ReadDataError> for ReadNpyError {
    /// Flattens a data-reading error into the matching top-level variant.
    fn from(err: ReadDataError) -> Self {
        match err {
            ReadDataError::Io(io_err) => Self::Io(io_err),
            ReadDataError::WrongDescriptor(desc) => Self::WrongDescriptor(desc),
            ReadDataError::MissingData => Self::MissingData,
            ReadDataError::ExtraBytes(nbytes) => Self::ExtraBytes(nbytes),
            ReadDataError::ParseData(parse_err) => Self::ParseData(parse_err),
        }
    }
}
/// Extension trait for constructing a value by reading `.npy` data from any
/// `io::Read`.
pub trait ReadNpyExt: Sized {
/// Reads an `.npy` header and data from `reader` and builds `Self` from them.
fn read_npy<R: io::Read>(reader: R) -> Result<Self, ReadNpyError>;
}
impl<A, S, D> ReadNpyExt for ArrayBase<S, D>
where
A: ReadableElement,
S: DataOwned<Elem = A>,
D: Dimension,
{
fn read_npy<R: io::Read>(mut reader: R) -> Result<Self, ReadNpyError> {
let header = Header::from_reader(&mut reader)?;
let shape = header.shape.into_dimension();
let ndim = shape.ndim();
let len = match shape.size_checked() {
Some(len) if len <= std::isize::MAX as usize => len,
_ => return Err(ReadNpyError::LengthOverflow),
};
let data = A::read_to_end_exact_vec(&mut reader, &header.type_descriptor, len)?;
ArrayBase::from_shape_vec(shape.set_f(header.fortran_order), data)
.unwrap()
.into_dimensionality()
.map_err(|_| ReadNpyError::WrongNdim(D::NDIM, ndim))
}
}
// Implements `WritableElement` for a primitive by writing its in-memory
// bytes verbatim. The header advertises `$le_descr` or `$be_descr`, matching
// the target's byte order.
macro_rules! impl_writable_primitive {
    ($ty:ty, $le_descr:expr, $be_descr:expr) => {
        unsafe impl WritableElement for $ty {
            fn type_descriptor() -> PyValue {
                // The bytes are written in native order, so the descriptor
                // must name the native byte order.
                let descr = if cfg!(target_endian = "little") {
                    $le_descr
                } else if cfg!(target_endian = "big") {
                    $be_descr
                } else {
                    unreachable!()
                };
                PyValue::String(descr.into())
            }

            fn write<W: io::Write>(&self, writer: W) -> Result<(), WriteDataError> {
                // A single element is written exactly like a one-element slice.
                Self::write_slice(std::slice::from_ref(self), writer)
            }

            fn write_slice<W: io::Write>(
                slice: &[Self],
                mut writer: W,
            ) -> Result<(), WriteDataError> {
                // SAFETY: `slice` is a valid, initialized `[$ty]`. Viewing its
                // `size_of_val(slice)` bytes as `u8` stays in bounds, and `u8`
                // has alignment 1. The types this macro is invoked with in this
                // file (integers, floats, bool) have no padding bytes.
                let bytes: &[u8] = unsafe {
                    std::slice::from_raw_parts(
                        slice.as_ptr().cast::<u8>(),
                        mem::size_of_val(slice),
                    )
                };
                writer.write_all(bytes)?;
                Ok(())
            }
        }
    };
}
/// Verifies that `reader` is exhausted.
///
/// Returns `Ok(())` if no bytes remain, or `Err(ReadDataError::ExtraBytes(n))`
/// with the number `n` of remaining bytes. The remaining bytes are counted by
/// streaming them into `io::sink()` instead of collecting them, so trailing
/// data of any size is counted without allocating memory proportional to it.
/// On targets where the count does not fit in `usize`, it saturates at
/// `usize::MAX`.
///
/// # Errors
///
/// I/O errors from `reader` are converted with `ReadDataError::from`.
pub fn check_for_extra_bytes<R: io::Read>(reader: &mut R) -> Result<(), ReadDataError> {
    use std::convert::TryFrom;

    let num_extra_bytes = io::copy(reader, &mut io::sink())?;
    if num_extra_bytes == 0 {
        Ok(())
    } else {
        let count = usize::try_from(num_extra_bytes).unwrap_or(std::usize::MAX);
        Err(ReadDataError::ExtraBytes(count))
    }
}
// Implements `ReadableElement` for a single-byte primitive. Any descriptor in
// `[$accepted, ...]` is recognized (single-byte data has no byte order), and
// the data is filled by calling `$read_method` on the reader.
macro_rules! impl_readable_primitive_one_byte {
    ($ty:ty, [$($accepted:expr),*], $zero:expr, $read_method:ident) => {
        impl ReadableElement for $ty {
            fn read_to_end_exact_vec<R: io::Read>(
                mut reader: R,
                type_desc: &PyValue,
                len: usize,
            ) -> Result<Vec<Self>, ReadDataError> {
                let recognized = match *type_desc {
                    PyValue::String(ref s) => $(s == $accepted)||*,
                    _ => false,
                };
                if !recognized {
                    return Err(ReadDataError::WrongDescriptor(type_desc.clone()));
                }
                let mut values = vec![$zero; len];
                reader.$read_method(&mut values)?;
                check_for_extra_bytes(&mut reader)?;
                Ok(values)
            }
        }
    };
}
// Implements reading and writing for a single-byte primitive. Such types have
// no byte order, so the same descriptor `$descr` is written on every target,
// while any of `[$accepted, ...]` is recognized when reading.
macro_rules! impl_primitive_one_byte {
    ($ty:ty, $descr:expr, [$($accepted:expr),*], $zero:expr, $read_method:ident) => {
        impl_writable_primitive!($ty, $descr, $descr);
        impl_readable_primitive_one_byte!($ty, [$($accepted),*], $zero, $read_method);
    };
}
impl_primitive_one_byte!(i8, "|i1", ["|i1", "i1", "b"], 0, read_i8_into);
impl_primitive_one_byte!(u8, "|u1", ["|u1", "u1", "B"], 0, read_exact);
// Implements `ReadableElement` for a multi-byte primitive whose `.npy`
// descriptors are `$little_desc` / `$big_desc`. The data is read with the
// byteorder bulk method `$read_into`, using the byte order the descriptor
// names.
macro_rules! impl_readable_primitive_multi_byte {
    ($elem:ty, $little_desc:expr, $big_desc:expr, $zero:expr, $read_into:ident) => {
        impl ReadableElement for $elem {
            fn read_to_end_exact_vec<R: io::Read>(
                mut reader: R,
                type_desc: &PyValue,
                len: usize,
            ) -> Result<Vec<Self>, ReadDataError> {
                // Resolve the byte order *before* allocating. Otherwise an
                // unsupported descriptor paired with a huge `len` would first
                // allocate `len` elements and only then be rejected.
                let is_little_endian = match *type_desc {
                    PyValue::String(ref s) if s == $little_desc => true,
                    PyValue::String(ref s) if s == $big_desc => false,
                    ref other => return Err(ReadDataError::WrongDescriptor(other.clone())),
                };
                let mut out = vec![$zero; len];
                if is_little_endian {
                    reader.$read_into::<LittleEndian>(&mut out)?;
                } else {
                    reader.$read_into::<BigEndian>(&mut out)?;
                }
                check_for_extra_bytes(&mut reader)?;
                Ok(out)
            }
        }
    };
}
macro_rules! impl_primitive_multi_byte {
($elem:ty, $little_desc:expr, $big_desc:expr, $zero:expr, $read_into:ident) => {
impl_writable_primitive!($elem, $little_desc, $big_desc);
impl_readable_primitive_multi_byte!($elem, $little_desc, $big_desc, $zero, $read_into);
};
}
impl_primitive_multi_byte!(i16, "<i2", ">i2", 0, read_i16_into);
impl_primitive_multi_byte!(i32, "<i4", ">i4", 0, read_i32_into);
impl_primitive_multi_byte!(i64, "<i8", ">i8", 0, read_i64_into);
impl_primitive_multi_byte!(u16, "<u2", ">u2", 0, read_u16_into);
impl_primitive_multi_byte!(u32, "<u4", ">u4", 0, read_u32_into);
impl_primitive_multi_byte!(u64, "<u8", ">u8", 0, read_u64_into);
impl_primitive_multi_byte!(f32, "<f4", ">f4", 0., read_f32_into);
impl_primitive_multi_byte!(f64, "<f8", ">f8", 0., read_f64_into);
/// A byte in `"|b1"` (bool) data that is neither `0x00` nor `0x01`.
#[derive(Debug)]
struct ParseBoolError {
    /// The offending byte.
    bad_value: u8,
}

// `source` keeps its default implementation, which returns `None`.
impl Error for ParseBoolError {}

impl fmt::Display for ParseBoolError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let ParseBoolError { bad_value } = *self;
        write!(f, "error parsing value {:#04x} as a bool", bad_value)
    }
}
impl From<ParseBoolError> for ReadDataError {
fn from(err: ParseBoolError) -> ReadDataError {
ReadDataError::ParseData(Box::new(err))
}
}
impl ReadableElement for bool {
    /// Reads `len` booleans stored one per byte under descriptor `"|b1"`.
    ///
    /// Each byte must be `0x00` (`false`) or `0x01` (`true`). The first other
    /// value is reported as `ParseData` wrapping a `ParseBoolError`. Trailing
    /// bytes are reported as `ExtraBytes`, and that check happens before the
    /// values are validated.
    fn read_to_end_exact_vec<R: io::Read>(
        mut reader: R,
        type_desc: &PyValue,
        len: usize,
    ) -> Result<Vec<Self>, ReadDataError> {
        match *type_desc {
            PyValue::String(ref s) if s == "|b1" => {
                let mut bytes: Vec<u8> = vec![0; len];
                reader.read_exact(&mut bytes)?;
                check_for_extra_bytes(&mut reader)?;
                // Validate and convert in one safe pass. This avoids
                // reinterpreting the `u8` buffer as `bool` via
                // `mem::forget` + `Vec::from_raw_parts`.
                bytes
                    .into_iter()
                    .map(|byte| match byte {
                        0 => Ok(false),
                        1 => Ok(true),
                        bad_value => Err(ReadDataError::from(ParseBoolError { bad_value })),
                    })
                    .collect()
            }
            ref other => Err(ReadDataError::WrongDescriptor(other.clone())),
        }
    }
}
// A `bool` is written as its single in-memory byte; one byte has no byte
// order, so the descriptor is the same on every target.
impl_writable_primitive!(bool, "|b1", "|b1");
#[cfg(test)]
mod test {
    use super::{ReadDataError, ReadableElement, WritableElement};
    use py_literal::Value as PyValue;
    use std::io::Cursor;

    /// Builds a string type descriptor such as `"|b1"` or `"<i2"`.
    fn desc(s: &str) -> PyValue {
        PyValue::String(String::from(s))
    }

    #[test]
    fn read_bool() {
        let data: &[u8] = &[0x00, 0x01, 0x00, 0x00, 0x01];
        let out =
            <bool>::read_to_end_exact_vec(Cursor::new(data), &desc("|b1"), data.len()).unwrap();
        assert_eq!(out, vec![false, true, false, false, true]);
    }

    #[test]
    fn read_bool_bad_value() {
        let data: &[u8] = &[0x00, 0x01, 0x05, 0x00, 0x01];
        match <bool>::read_to_end_exact_vec(Cursor::new(data), &desc("|b1"), data.len()) {
            Err(ReadDataError::ParseData(err)) => {
                assert_eq!(format!("{}", err), "error parsing value 0x05 as a bool");
            }
            other => panic!("expected ParseData error, got {:?}", other),
        }
    }

    #[test]
    fn read_bool_extra_bytes() {
        let data: &[u8] = &[0x00, 0x01, 0x01];
        match <bool>::read_to_end_exact_vec(Cursor::new(data), &desc("|b1"), 2) {
            Err(ReadDataError::ExtraBytes(1)) => {}
            other => panic!("expected ExtraBytes(1), got {:?}", other),
        }
    }

    #[test]
    fn read_bool_missing_data() {
        let data: &[u8] = &[0x00];
        match <bool>::read_to_end_exact_vec(Cursor::new(data), &desc("|b1"), 2) {
            Err(ReadDataError::MissingData) => {}
            other => panic!("expected MissingData, got {:?}", other),
        }
    }

    #[test]
    fn read_wrong_descriptor() {
        let data: &[u8] = &[0x00, 0x01];
        match <bool>::read_to_end_exact_vec(Cursor::new(data), &desc("<i2"), 2) {
            Err(ReadDataError::WrongDescriptor(_)) => {}
            other => panic!("expected WrongDescriptor, got {:?}", other),
        }
        match <i16>::read_to_end_exact_vec(Cursor::new(data), &desc("|b1"), 1) {
            Err(ReadDataError::WrongDescriptor(_)) => {}
            other => panic!("expected WrongDescriptor, got {:?}", other),
        }
    }

    #[test]
    fn read_u8_alias_descriptor() {
        let data: &[u8] = &[3, 4];
        let out = <u8>::read_to_end_exact_vec(Cursor::new(data), &desc("B"), 2).unwrap();
        assert_eq!(out, vec![3, 4]);
    }

    #[test]
    fn read_i16_both_byte_orders() {
        let little: &[u8] = &[0x01, 0x00, 0xff, 0xff];
        let out = <i16>::read_to_end_exact_vec(Cursor::new(little), &desc("<i2"), 2).unwrap();
        assert_eq!(out, vec![1, -1]);

        let big: &[u8] = &[0x00, 0x01, 0xff, 0xfe];
        let out = <i16>::read_to_end_exact_vec(Cursor::new(big), &desc(">i2"), 2).unwrap();
        assert_eq!(out, vec![1, -2]);
    }

    #[test]
    fn write_uses_native_bytes() {
        let mut buf = Vec::new();
        <u16 as WritableElement>::write_slice(&[0x0102, 0x0304], &mut buf).unwrap();
        let mut expected = Vec::new();
        expected.extend_from_slice(&0x0102u16.to_ne_bytes());
        expected.extend_from_slice(&0x0304u16.to_ne_bytes());
        assert_eq!(buf, expected);

        let mut single = Vec::new();
        0x0102u16.write(&mut single).unwrap();
        assert_eq!(single, 0x0102u16.to_ne_bytes().to_vec());

        let mut bools = Vec::new();
        <bool as WritableElement>::write_slice(&[true, false], &mut bools).unwrap();
        assert_eq!(bools, vec![1u8, 0]);
    }
}