use crate::{ReadNpyError, ReadNpyExt, ReadableElement, WriteNpyError, WriteNpyExt};
use ndarray::prelude::*;
use ndarray::DataOwned;
use std::error::Error;
use std::fmt;
use std::io::{BufWriter, Read, Seek, Write};
use zip::result::ZipError;
use zip::write::{FileOptionExtension, FileOptions, SimpleFileOptions};
use zip::{CompressionMethod, ZipArchive, ZipWriter};
/// An error writing a `.npz` archive.
#[derive(Debug)]
pub enum WriteNpzError {
    /// An error reported by the zip layer (starting an entry, finishing or
    /// flushing the archive).
    Zip(ZipError),
    /// An error serializing an array as an inner `.npy` entry.
    Npy(WriteNpyError),
}
impl Error for WriteNpzError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
WriteNpzError::Zip(err) => Some(err),
WriteNpzError::Npy(err) => Some(err),
}
}
}
impl fmt::Display for WriteNpzError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WriteNpzError::Zip(err) => write!(f, "zip file error: {}", err),
WriteNpzError::Npy(err) => write!(f, "error writing npy file to npz archive: {}", err),
}
}
}
impl From<ZipError> for WriteNpzError {
fn from(err: ZipError) -> WriteNpzError {
WriteNpzError::Zip(err)
}
}
impl From<WriteNpyError> for WriteNpzError {
fn from(err: WriteNpyError) -> WriteNpzError {
WriteNpzError::Npy(err)
}
}
/// Writer for `.npz` archives: a zip file in which each added array is stored
/// as its own `<name>.npy` entry.
pub struct NpzWriter<W: Write + Seek> {
    // Zip archive that the `.npy` entries are written into.
    zip: ZipWriter<W>,
    // Default per-entry options (e.g. compression method) used by `add_array`.
    options: SimpleFileOptions,
}
impl<W: Write + Seek> NpzWriter<W> {
    /// Creates a writer whose entries are stored uncompressed
    /// (`CompressionMethod::Stored`).
    pub fn new(writer: W) -> NpzWriter<W> {
        NpzWriter {
            zip: ZipWriter::new(writer),
            options: SimpleFileOptions::default().compression_method(CompressionMethod::Stored),
        }
    }

    /// Creates a writer whose entries are compressed with
    /// `CompressionMethod::Deflated`.
    ///
    /// Only available with the `compressed_npz` feature.
    #[cfg(feature = "compressed_npz")]
    pub fn new_compressed(writer: W) -> NpzWriter<W> {
        NpzWriter {
            zip: ZipWriter::new(writer),
            options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
        }
    }

    /// Creates a writer that uses `options` for every entry added via
    /// [`NpzWriter::add_array`].
    pub fn new_with_options(writer: W, options: SimpleFileOptions) -> NpzWriter<W> {
        NpzWriter {
            zip: ZipWriter::new(writer),
            options,
        }
    }

    /// Adds `array` as the entry `<name>.npy`, using the options this writer
    /// was constructed with.
    ///
    /// # Errors
    ///
    /// Same as [`NpzWriter::add_array_with_options`].
    pub fn add_array<N, T>(&mut self, name: N, array: &T) -> Result<(), WriteNpzError>
    where
        N: Into<String>,
        T: WriteNpyExt + ?Sized,
    {
        self.add_array_with_options(name, array, self.options)
    }

    /// Adds `array` as the entry `<name>.npy`, using `options` for this entry
    /// only.
    ///
    /// `.npy` is always appended to `name`.
    ///
    /// # Errors
    ///
    /// Returns [`WriteNpzError::Zip`] if the zip entry cannot be started or the
    /// buffered bytes cannot be flushed into it. Returns [`WriteNpzError::Npy`]
    /// if serializing the array fails.
    pub fn add_array_with_options<N, T, U>(
        &mut self,
        name: N,
        array: &T,
        options: FileOptions<'_, U>,
    ) -> Result<(), WriteNpzError>
    where
        N: Into<String>,
        T: WriteNpyExt + ?Sized,
        U: FileOptionExtension,
    {
        // Not generic over `N`: the name is converted once by the caller, so
        // this body is instantiated per (W, T, U) rather than per name type.
        fn inner<W, T, U>(
            npz_zip: &mut ZipWriter<W>,
            name: String,
            array: &T,
            options: FileOptions<'_, U>,
        ) -> Result<(), WriteNpzError>
        where
            W: Write + Seek,
            T: WriteNpyExt + ?Sized,
            U: FileOptionExtension,
        {
            npz_zip.start_file(name + ".npy", options)?;
            let mut buffered = BufWriter::new(npz_zip);
            array.write_npy(&mut buffered)?;
            // Flush explicitly. Dropping a `BufWriter` also flushes, but it
            // discards any I/O error, which would leave a truncated entry
            // while still reporting success.
            buffered.flush().map_err(ZipError::from)?;
            Ok(())
        }
        inner(&mut self.zip, name.into(), array, options)
    }

    /// Writes the zip central directory, flushes the underlying writer and
    /// returns it.
    ///
    /// # Errors
    ///
    /// Returns [`WriteNpzError::Zip`] if finishing the archive or flushing the
    /// writer fails.
    pub fn finish(self) -> Result<W, WriteNpzError> {
        let mut writer = self.zip.finish()?;
        writer.flush().map_err(ZipError::from)?;
        Ok(writer)
    }
}
/// An error reading a `.npz` archive.
#[derive(Debug)]
pub enum ReadNpzError {
    /// An error reported by the zip layer (opening the archive, locating an
    /// entry).
    Zip(ZipError),
    /// An error parsing an inner `.npy` entry.
    Npy(ReadNpyError),
}
impl Error for ReadNpzError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
ReadNpzError::Zip(err) => Some(err),
ReadNpzError::Npy(err) => Some(err),
}
}
}
impl fmt::Display for ReadNpzError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ReadNpzError::Zip(err) => write!(f, "zip file error: {}", err),
ReadNpzError::Npy(err) => write!(f, "error reading npy file in npz archive: {}", err),
}
}
}
impl From<ZipError> for ReadNpzError {
fn from(err: ZipError) -> ReadNpzError {
ReadNpzError::Zip(err)
}
}
impl From<ReadNpyError> for ReadNpzError {
fn from(err: ReadNpyError) -> ReadNpzError {
ReadNpzError::Npy(err)
}
}
/// Reader for `.npz` archives: a zip file whose entries are `.npy` arrays.
pub struct NpzReader<R: Read + Seek> {
    // Parsed zip archive; entries are opened on demand by name or index.
    zip: ZipArchive<R>,
}
impl<R: Read + Seek> NpzReader<R> {
pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
Ok(NpzReader {
zip: ZipArchive::new(reader)?,
})
}
pub fn is_empty(&self) -> bool {
self.zip.len() == 0
}
pub fn len(&self) -> usize {
self.zip.len()
}
pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
Ok((0..self.zip.len())
.map(|i| {
let file = self.zip.by_index(i)?;
let name = file.name();
let stripped = name.strip_suffix(".npy").unwrap_or(name);
Ok(stripped.to_owned())
})
.collect::<Result<_, ZipError>>()?)
}
pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
match self.zip.by_name(name) {
Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
Err(ZipError::FileNotFound) => {}
Err(err) => return Err(err.into()),
};
Ok(ArrayBase::<S, D>::read_npy(
self.zip.by_name(&format!("{name}.npy"))?,
)?)
}
pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
}
}