use crate::{
ReadNpyError, ReadNpyExt, ReadableElement, WritableElement, WriteNpyError, WriteNpyExt,
};
use ndarray::prelude::*;
use ndarray::{Data, DataOwned};
use std::error::Error;
use std::fmt;
use std::io::{BufWriter, Read, Seek, Write};
use zip::result::ZipError;
use zip::write::FileOptions;
use zip::{CompressionMethod, ZipArchive, ZipWriter};
/// An error that can occur while writing an `.npz` archive.
#[derive(Debug)]
pub enum WriteNpzError {
    /// An error reported by the zip layer (starting an entry, finishing the
    /// archive, or flushing the underlying writer).
    Zip(ZipError),
    /// An error reported while writing an array's `.npy` data into an entry.
    Npy(WriteNpyError),
}

impl Error for WriteNpzError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Every variant wraps exactly one underlying error; expose it as the cause.
        let cause: &(dyn Error + 'static) = match self {
            Self::Zip(zip_err) => zip_err,
            Self::Npy(npy_err) => npy_err,
        };
        Some(cause)
    }
}

impl fmt::Display for WriteNpzError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Zip(zip_err) => write!(f, "zip file error: {}", zip_err),
            Self::Npy(npy_err) => write!(f, "error writing npy file to npz archive: {}", npy_err),
        }
    }
}

impl From<ZipError> for WriteNpzError {
    fn from(zip_err: ZipError) -> Self {
        Self::Zip(zip_err)
    }
}

impl From<WriteNpyError> for WriteNpzError {
    fn from(npy_err: WriteNpyError) -> Self {
        Self::Npy(npy_err)
    }
}
pub struct NpzWriter<W: Write + Seek> {
zip: ZipWriter<W>,
options: FileOptions,
}
impl<W: Write + Seek> NpzWriter<W> {
pub fn new(writer: W) -> NpzWriter<W> {
NpzWriter {
zip: ZipWriter::new(writer),
options: FileOptions::default().compression_method(CompressionMethod::Stored),
}
}
#[cfg(feature = "compressed_npz")]
pub fn new_compressed(writer: W) -> NpzWriter<W> {
NpzWriter {
zip: ZipWriter::new(writer),
options: FileOptions::default().compression_method(CompressionMethod::Deflated),
}
}
pub fn add_array<N, S, D>(
&mut self,
name: N,
array: &ArrayBase<S, D>,
) -> Result<(), WriteNpzError>
where
N: Into<String>,
S::Elem: WritableElement,
S: Data,
D: Dimension,
{
self.zip.start_file(name, self.options)?;
array.write_npy(BufWriter::new(&mut self.zip))?;
Ok(())
}
pub fn finish(mut self) -> Result<W, WriteNpzError> {
let mut writer = self.zip.finish()?;
writer.flush().map_err(ZipError::from)?;
Ok(writer)
}
}
/// An error that can occur while reading an `.npz` archive.
#[derive(Debug)]
pub enum ReadNpzError {
    /// An error reported by the zip layer (opening the archive or locating an
    /// entry by name or index).
    Zip(ZipError),
    /// An error reported while reading an entry's `.npy` data as an array.
    Npy(ReadNpyError),
}

impl Error for ReadNpzError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Every variant wraps exactly one underlying error; expose it as the cause.
        let cause: &(dyn Error + 'static) = match self {
            Self::Zip(zip_err) => zip_err,
            Self::Npy(npy_err) => npy_err,
        };
        Some(cause)
    }
}

impl fmt::Display for ReadNpzError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Zip(zip_err) => write!(f, "zip file error: {}", zip_err),
            Self::Npy(npy_err) => write!(f, "error reading npy file in npz archive: {}", npy_err),
        }
    }
}

impl From<ZipError> for ReadNpzError {
    fn from(zip_err: ZipError) -> Self {
        Self::Zip(zip_err)
    }
}

impl From<ReadNpyError> for ReadNpzError {
    fn from(npy_err: ReadNpyError) -> Self {
        Self::Npy(npy_err)
    }
}
pub struct NpzReader<R: Read + Seek> {
zip: ZipArchive<R>,
}
impl<R: Read + Seek> NpzReader<R> {
pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
Ok(NpzReader {
zip: ZipArchive::new(reader)?,
})
}
pub fn is_empty(&self) -> bool {
self.zip.len() == 0
}
pub fn len(&self) -> usize {
self.zip.len()
}
pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
Ok((0..self.zip.len())
.map(|i| Ok(self.zip.by_index(i)?.name().to_owned()))
.collect::<Result<_, ZipError>>()?)
}
pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
Ok(ArrayBase::<S, D>::read_npy(self.zip.by_name(name)?)?)
}
pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
}
}