use std::collections::HashMap;
use std::io;
use crate::header::{Value, DType, read_header, convert_value_to_shape};
use crate::serialize::{Deserialize, TypeRead, DTypeError, ErrorKind as DTypeErrorKind};
/// A parsed `.npy` file: the interpreted header plus the reader positioned
/// at the start of the data payload.
#[derive(Clone)]
pub struct NpyFile<R: io::Read> {
    header: NpyHeader,
    reader: R,
}
/// The interpreted contents of an `.npy` header: dtype, shape, memory order,
/// plus values derived from them.
#[derive(Clone)]
pub struct NpyHeader {
    dtype: DType,
    shape: Vec<u64>,
    // Per-axis strides measured in items (not bytes), derived from `shape` and `order`.
    strides: Vec<u64>,
    order: Order,
    // Product of all dims in `shape`: the total number of scalar records.
    n_records: u64,
    // Size in bytes of one record; `None` when the dtype is a pickled object array.
    item_size: Option<usize>,
    uses_pickled_array: bool,
}
impl NpyHeader {
pub fn from_reader(r: impl io::Read) -> io::Result<NpyHeader> {
NpyHeader::read_and_interpret(r)
}
}
/// An iterator over the records of an `.npy` file, produced by `NpyFile::data`.
pub struct NpyReader<T: Deserialize, R: io::Read> {
    header: NpyHeader,
    type_reader: <T as Deserialize>::TypeReader,
    // Stored as one tuple so the reader and the position can both be borrowed
    // mutably at the same time without splitting the borrow of `self`.
    reader_and_current_index: (R, u64),
}
/// Legacy random-access view of `.npy` data backed by an in-memory byte slice.
#[deprecated(since = "0.5.0", note = "use NpyReader instead")]
pub struct NpyData<'a, T: Deserialize> {
    inner: NpyReader<T, &'a [u8]>,
}
/// Memory layout order of a multidimensional array.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Order {
    /// Row-major order (last axis varies fastest).
    C,
    /// Column-major order (first axis varies fastest).
    Fortran,
}

impl Order {
    /// Translate the `fortran_order` flag found in an `.npy` header dict.
    pub(crate) fn from_fortran_order(fortran_order: bool) -> Order {
        match fortran_order {
            true => Order::Fortran,
            false => Order::C,
        }
    }
}
impl<R: io::Read> NpyFile<R> {
    /// Read and interpret the header, leaving `reader` at the start of the data.
    pub fn new(mut reader: R) -> io::Result<Self> {
        let header = NpyHeader::read_and_interpret(&mut reader)?;
        Ok(Self { header, reader })
    }

    /// Assemble a file from an already-parsed header and a reader positioned
    /// at the start of the data payload.
    pub fn with_header(header: NpyHeader, data_reader: R) -> Self {
        Self { header, reader: data_reader }
    }

    /// Borrow the parsed header.
    pub fn header(&self) -> &NpyHeader {
        &self.header
    }

    /// Recover the underlying reader, discarding the header.
    pub fn into_inner(self) -> R {
        self.reader
    }
}
// Deref lets header accessors (shape(), dtype(), ...) be called directly on an
// NpyFile. NOTE(review): Deref to a non-pointer target is usually discouraged,
// but it is part of this type's public API, so it stays.
impl<R: io::Read> std::ops::Deref for NpyFile<R> {
    type Target = NpyHeader;

    fn deref(&self) -> &NpyHeader {
        &self.header
    }
}
impl NpyHeader {
    /// The scalar dtype of each record.
    pub fn dtype(&self) -> DType {
        self.dtype.clone()
    }

    /// The dimensions of the array.
    pub fn shape(&self) -> &[u64] {
        &self.shape
    }

    /// Per-axis strides, measured in items (not bytes).
    pub fn strides(&self) -> &[u64] {
        &self.strides
    }

    /// The memory layout order declared by the header.
    pub fn order(&self) -> Order {
        self.order
    }

    /// Total number of records (the product of all dimensions).
    pub fn len(&self) -> u64 {
        self.n_records
    }

    /// `true` when the array holds no records (some dimension is zero).
    //
    // Added for consistency with `len()` (clippy: len_without_is_empty);
    // purely additive, so callers are unaffected.
    pub fn is_empty(&self) -> bool {
        self.n_records == 0
    }

    /// `true` when the data payload is a pickled Python object array,
    /// which this crate cannot deserialize.
    pub fn uses_pickled_array(&self) -> bool {
        self.uses_pickled_array
    }

    /// Fail early on pickled arrays, before any typed reading is attempted.
    fn forbid_pickle(&self) -> Result<(), DTypeError> {
        if self.uses_pickled_array {
            Err(DTypeError(DTypeErrorKind::RequiresPickle))
        } else {
            Ok(())
        }
    }
}
impl<R: io::Read> NpyFile<R> {
    /// Read every record into a `Vec`, converting any dtype mismatch into an
    /// `io::Error` of kind `InvalidData`.
    pub fn into_vec<T: Deserialize>(self) -> io::Result<Vec<T>> {
        match self.data() {
            Ok(r) => r.collect(),
            Err(e) => Err(invalid_data(e)),
        }
    }

    /// Begin typed iteration over the records.
    ///
    /// Fails with a `DTypeError` when the array is pickled, or when `T` cannot
    /// be deserialized from the file's dtype.
    pub fn data<T: Deserialize>(self) -> Result<NpyReader<T, R>, DTypeError> {
        self.forbid_pickle()?;
        let NpyFile { reader, header } = self;
        let type_reader = T::reader(&header.dtype)?;
        Ok(NpyReader { type_reader, header, reader_and_current_index: (reader, 0) })
    }

    /// Like `data`, but hands `self` back unchanged on failure so the caller
    /// can retry with a different `T` or recover the file.
    pub fn try_data<T: Deserialize>(self) -> Result<NpyReader<T, R>, Self> {
        // Check the pickle flag before consulting T::reader, mirroring data().
        if self.uses_pickled_array {
            return Err(self);
        }
        let type_reader = match T::reader(&self.header.dtype) {
            Ok(r) => r,
            // The specific dtype error is discarded; only `self` is returned.
            Err(_) => return Err(self),
        };
        let NpyFile { reader, header } = self;
        Ok(NpyReader { type_reader, header, reader_and_current_index: (reader, 0) })
    }
}
impl NpyHeader {
    /// Read the raw header from `r` and interpret its python-dict literal
    /// into an `NpyHeader`.
    fn read_and_interpret(mut r: impl io::Read) -> io::Result<NpyHeader> {
        let header = read_header(&mut r)?;
        // The header must be a dict literal with string keys.
        let dict = match header {
            Value::Dict(dict) => dict
                .into_iter()
                .map(|(k, v)| Ok((k.as_string().ok_or(invalid_data("key is not string"))?.to_owned(), v)))
                .collect::<io::Result<HashMap<String, Value>>>()?,
            _ => return Err(invalid_data("expected a python dict literal")),
        };
        // All three keys ('descr', 'fortran_order', 'shape') are mandatory.
        let expect_key = |key: &str| {
            dict.get(key).ok_or_else(|| invalid_data(format_args!("dict is missing key '{}'", key)))
        };
        let order = match expect_key("fortran_order")? {
            &Value::Boolean(b) => Order::from_fortran_order(b),
            _ => return Err(invalid_data(format_args!("'fortran_order' value is not a bool"))),
        };
        let shape = convert_value_to_shape(expect_key("shape")?)?;
        let descr: &Value = expect_key("descr")?;
        let dtype = DType::from_descr(descr)?;
        Self::from_parts(dtype, shape, order)
    }

    /// Derive the cached fields (record count, item size, strides) from the
    /// parsed dtype, shape, and order.
    fn from_parts(dtype: DType, shape: Vec<u64>, order: Order) -> io::Result<NpyHeader> {
        let n_records = shape.iter().product();
        let uses_pickled_array = dtype.uses_pickled_array();
        // Pickled object arrays have no fixed item size; for everything else,
        // a size that does not fit in usize is rejected up front.
        let item_size = match uses_pickled_array {
            true => None,
            false => match dtype.num_bytes() {
                Some(num) => Some(num),
                None => Err(invalid_data(format_args!("dtype is larger than usize!")))?,
            },
        };
        let strides = strides(order, &shape);
        Ok(NpyHeader { dtype, shape, strides, order, n_records, item_size, uses_pickled_array })
    }
}
impl<T: Deserialize, R: io::Read> NpyReader<T, R> {
    /// Borrow the underlying data reader.
    #[inline(always)]
    fn reader(&self) -> &R {
        let (reader, _) = &self.reader_and_current_index;
        reader
    }

    /// The scalar dtype of each record.
    pub fn dtype(&self) -> DType {
        self.header.dtype.clone()
    }

    /// The dimensions of the array.
    pub fn shape(&self) -> &[u64] {
        &self.header.shape
    }

    /// Total number of records in the file, regardless of iteration progress.
    pub fn total_len(&self) -> u64 {
        self.header.n_records
    }

    /// Number of records not yet consumed.
    pub fn len(&self) -> u64 {
        let (_, current) = &self.reader_and_current_index;
        self.header.n_records - current
    }
}
impl<R: io::Read, T: Deserialize> NpyReader<T, R> where R: io::Seek {
    /// Position the reader so the next record read is record `index`.
    ///
    /// # Panics
    ///
    /// Panics when `index > self.total_len()` (`index == total_len` is
    /// allowed: it seeks to just past the last record).
    pub fn seek_to(&mut self, index: u64) -> io::Result<()> {
        const NO_SEEK_MSG: &str = "array with variable size elements does not support seeking";
        let len = self.total_len();
        assert!(index <= len, "index out of bounds for seeking (the index is {} but the len is {})", index, len);
        let (reader, current_index) = &mut self.reader_and_current_index;
        // Relative seek from the current record; no-op when already there.
        let delta = index as i64 - *current_index as i64;
        if delta != 0 {
            // item_size is None only for pickled arrays, which cannot seek.
            // ok_or_else: avoid constructing the io::Error on every seek
            // (clippy: or_fun_call).
            let item_size = self
                .header
                .item_size
                .ok_or_else(|| io::Error::new(io::ErrorKind::Unsupported, NO_SEEK_MSG))?;
            reader.seek(io::SeekFrom::Current(delta * item_size as i64))?;
            *current_index = index;
        }
        Ok(())
    }

    /// Seek to record `index` and read it; iteration continues from `index + 1`.
    ///
    /// # Panics
    ///
    /// Panics when `index >= self.total_len()`.
    pub fn read_at(&mut self, index: u64) -> io::Result<T> {
        let len = self.total_len();
        assert!(index < len, "index out of bounds for reading (the index is {} but the len is {})", index, len);
        self.seek_to(index)?;
        // The bounds assert above guarantees at least one record remains.
        self.next().unwrap()
    }
}
#[allow(deprecated)]
impl<'a, T: Deserialize> NpyData<'a, T> {
    /// Parse the header of `bytes` and wrap the remaining data for random access.
    pub fn from_bytes(bytes: &'a [u8]) -> io::Result<NpyData<'a, T>> {
        let inner = NpyFile::new(bytes)?.data().map_err(invalid_data)?;
        // Random access below relies on the data slice holding exactly
        // n_records * item_size bytes. Pickled arrays (item_size == None)
        // were already rejected by data().
        if let Some(item_size) = inner.header.item_size {
            assert_eq!(
                item_size as u64 * inner.header.n_records,
                inner.reader().len() as u64,
            );
        }
        Ok(NpyData { inner })
    }

    // `inner`'s position is never advanced by this type, so this is always
    // the full data payload (everything after the header).
    #[inline(always)] fn get_data_slice(&self) -> &'a [u8] {
        self.inner.reader()
    }

    /// Total number of records.
    pub fn len(&self) -> usize {
        self.inner.total_len() as usize
    }

    /// `true` when the array holds no records.
    pub fn is_empty(&self) -> bool { self.len() == 0 }

    /// Read record `i`, or `None` when `i` is out of bounds.
    pub fn get(&self, i: usize) -> Option<T> {
        if i < self.len() {
            Some(self.get_unchecked(i))
        } else {
            None
        }
    }

    /// Read record `i` without checking the index first.
    ///
    /// # Panics
    ///
    /// May panic (via slicing) when `i` is out of bounds.
    pub fn get_unchecked(&self, i: usize) -> T {
        // item_size is Some here: from_bytes() went through data(), which
        // rejects pickled arrays.
        let item_size = self.inner.header.item_size.unwrap();
        let item_bytes = &self.get_data_slice()[i * item_size..];
        self.inner.type_reader.read_one(item_bytes).unwrap()
    }

    /// Read every record into a `Vec`.
    pub fn to_vec(&self) -> Vec<T> {
        // Copies the data slice; reading advances the local copy, not `inner`.
        let &(mut reader) = self.inner.reader();
        (0..self.len()).map(|_| self.inner.type_reader.read_one(&mut reader).unwrap()).collect()
    }
}
/// Compute per-axis strides (in items) for the given layout order.
fn strides(order: Order, shape: &[u64]) -> Vec<u64> {
    match order {
        Order::Fortran => prefix_products(shape.iter().copied()).collect(),
        Order::C => {
            // C order: running products of the dims taken back-to-front,
            // then flipped so strides line up with the original axes.
            let reversed_dims = shape.iter().rev().copied();
            let mut result: Vec<u64> = prefix_products(reversed_dims).collect();
            result.reverse();
            result
        },
    }
}
/// Yield the running product of all elements *before* each position:
/// `[a, b, c]` -> `[1, a, a*b]`.
fn prefix_products<I: IntoIterator<Item=u64>>(iter: I) -> impl Iterator<Item=u64> {
    let mut running = 1;
    iter.into_iter().map(move |x| {
        let before = running;
        running *= x;
        before
    })
}
/// Build an `io::Error` of kind `InvalidData` from anything printable.
fn invalid_data<S: ToString>(s: S) -> io::Error {
    let message = s.to_string();
    io::Error::new(io::ErrorKind::InvalidData, message)
}
impl<R, T> Iterator for NpyReader<T, R> where T: Deserialize, R: io::Read {
    type Item = io::Result<T>;

    /// Deserialize the next record, or `None` once all records are consumed.
    fn next(&mut self) -> Option<Self::Item> {
        let total = self.header.n_records;
        let (reader, current_index) = &mut self.reader_and_current_index;
        if *current_index >= total {
            return None;
        }
        *current_index += 1;
        Some(self.type_reader.read_one(reader))
    }

    /// Exact remaining count, unless it cannot be represented in usize.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match usize::try_from(self.len()) {
            Ok(remaining) => (remaining, Some(remaining)),
            Err(_) => (usize::MAX, None),
        }
    }
}
/// Owning iterator over an `NpyData`.
#[deprecated(since = "0.5.0", note = "NpyData is being replaced with NpyReader.")]
pub struct IntoIter<'a, T: 'a + Deserialize> {
    #[allow(deprecated)]
    data: NpyData<'a, T>,
    // Index of the next record to yield.
    i: usize,
}
#[allow(deprecated)]
impl<'a, T> IntoIter<'a, T> where T: Deserialize {
    /// Start iteration at the first record.
    fn new(data: NpyData<'a, T>) -> Self {
        IntoIter { data, i: 0 }
    }
}
#[allow(deprecated)]
impl<'a, T: 'a> IntoIterator for NpyData<'a, T> where T: Deserialize {
    type Item = T;
    type IntoIter = IntoIter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIter::new(self)
    }
}
#[allow(deprecated)]
impl<'a, T> Iterator for IntoIter<'a, T> where T: Deserialize {
    type Item = T;

    /// Yield the next record, or `None` once exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // Only advance `i` when a record is actually yielded. The previous
        // version incremented unconditionally, so calling next() repeatedly
        // after exhaustion pushed `i` past len(), making size_hint()'s
        // `len - i` underflow (debug panic / wrapped value in release).
        let item = self.data.get(self.i)?;
        self.i += 1;
        Some(item)
    }

    /// Exact remaining count; also backs `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // saturating_sub keeps this total even if `i` ever passes len().
        let remaining = self.data.len().saturating_sub(self.i);
        (remaining, Some(remaining))
    }
}
// Sound because size_hint returns an exact count (lower bound == upper bound).
#[allow(deprecated)]
impl<'a, T> ExactSizeIterator for IntoIter<'a, T> where T: Deserialize {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::write::to_bytes_1d;

    #[test]
    fn test_strides() {
        assert_eq!(strides(Order::C, &[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(strides(Order::C, &[]), vec![]);
        assert_eq!(strides(Order::Fortran, &[2, 3, 4]), vec![1, 2, 6]);
        assert_eq!(strides(Order::Fortran, &[]), vec![]);
    }

    // len() shrinks as records are consumed; total_len() never changes.
    #[test]
    fn test_methods_after_partial_iteration() {
        let bytes = to_bytes_1d(&[100, 101, 102, 103, 104, 105, 106]).unwrap();
        let mut reader = NpyFile::new(&bytes[..]).unwrap().data().unwrap();
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.len(), 7);
        assert!(matches!(reader.next(), Some(Ok(100))));
        assert!(matches!(reader.next(), Some(Ok(101))));
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.len(), 5);
    }

    // next() keeps returning None after exhaustion without affecting len().
    #[test]
    fn test_next_after_finished_iteration() {
        let bytes = to_bytes_1d(&[100, 101, 102, 103, 104, 105, 106]).unwrap();
        let mut reader = NpyFile::new(&bytes[..]).unwrap().data::<i32>().unwrap();
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.len(), 7);
        assert_eq!(reader.by_ref().count(), 7);
        assert!(reader.next().is_none());
        assert!(reader.next().is_none());
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.len(), 0);
    }

    // seek_to/read_at reposition both the stream and the remaining count.
    #[test]
    fn test_methods_after_seek() {
        let bytes = to_bytes_1d(&[100, 101, 102, 103, 104, 105, 106]).unwrap();
        let mut reader = NpyFile::new(io::Cursor::new(&bytes[..])).unwrap().data().unwrap();
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.len(), 7);
        assert!(matches!(reader.next(), Some(Ok(100))));
        assert!(matches!(reader.next(), Some(Ok(101))));
        reader.seek_to(4).unwrap();
        assert_eq!(reader.total_len(), 7);
        assert_eq!(reader.len(), 3);
        assert!(matches!(reader.next(), Some(Ok(104))));
        // read_at leaves the cursor just past the record it read.
        assert_eq!(reader.read_at(2).unwrap(), 102);
        assert_eq!(reader.len(), 4);
    }

    // Boundary helpers: seeking to len is allowed, reading at len is not.
    fn check_seek_panic_boundary(items: &[i32], index: u64) {
        let bytes = to_bytes_1d(items).unwrap();
        let mut reader = NpyFile::new(io::Cursor::new(&bytes[..])).unwrap().data::<i32>().unwrap();
        let _ = reader.seek_to(index);
    }

    fn check_read_panic_boundary(items: &[i32], index: u64) {
        let bytes = to_bytes_1d(items).unwrap();
        let mut reader = NpyFile::new(io::Cursor::new(&bytes[..])).unwrap().data::<i32>().unwrap();
        let _ = reader.read_at(index);
    }

    #[test]
    fn test_seek_boundary_ok() { check_seek_panic_boundary(&[1, 2, 3], 3) }

    #[test]
    #[should_panic]
    fn test_seek_boundary_ng() { check_seek_panic_boundary(&[1, 2, 3], 4) }

    #[test]
    fn test_read_boundary_ok() { check_read_panic_boundary(&[1, 2, 3], 2) }

    #[test]
    #[should_panic]
    fn test_read_boundary_ng() { check_read_panic_boundary(&[1, 2, 3], 3) }

    // A header parsed once can drive several readers over clones of the
    // same data cursor.
    #[test]
    fn test_reusing_header() {
        let bytes = to_bytes_1d(&[100, 101, 102, 103, 104, 105, 106]).unwrap();
        let mut reader = io::Cursor::new(&bytes[..]);
        let header = NpyHeader::from_reader(&mut reader).unwrap();
        let npy_1 = NpyFile::with_header(header.clone(), reader.clone());
        let npy_2 = NpyFile::with_header(header.clone(), reader.clone());
        assert_eq!(
            npy_1.into_vec::<i32>().unwrap(),
            npy_2.into_vec::<i32>().unwrap(),
        );
    }
}