use crate::{
ReadNpyError, ReadNpyExt, ReadableElement, WritableElement, WriteNpyError, WriteNpyExt
};
use ndarray::{Data, DataOwned};
use ndarray::prelude::*;
use std::error::Error;
use std::io::{self, Read, Seek, Write};
use zip::{CompressionMethod, ZipArchive, ZipWriter};
use zip::result::ZipError;
use zip::write::FileOptions;
quick_error! {
// NOTE(review): each `display(x) -> ("{}: {}", x.description(), err)` calls the deprecated
// `std::error::Error::description`; `display("<same text>: {}", err)` would produce the same
// message without it (the `use std::error::Error;` import exists only to support this call).
/// Error returned when writing an `.npz` archive with [`NpzWriter`].
#[derive(Debug)]
pub enum WriteNpzError {
/// Wraps a `std::io::Error`.
Io(err: io::Error) {
description("I/O error")
display(x) -> ("{}: {}", x.description(), err)
cause(err)
from()
}
/// Wraps an error reported by the `zip` crate (e.g. when starting a new archive entry).
Zip(err: ZipError) {
description("zip file error")
display(x) -> ("{}: {}", x.description(), err)
cause(err)
from()
}
/// Wraps an error raised while serializing an array in `.npy` format into an archive entry.
Npy(err: WriteNpyError) {
description("error writing npy file to npz archive")
display(x) -> ("{}: {}", x.description(), err)
cause(err)
from()
}
}
}
/// Writer for `.npz` archives: a zip file in which every entry holds one `.npy`-format array.
///
/// Each call to [`NpzWriter::add_array`] starts a new zip entry and serializes one array into
/// it. Call [`NpzWriter::finish`] once all arrays have been added. Otherwise the archive is only
/// finalized when the inner zip writer is dropped, and errors at that point cannot be returned.
pub struct NpzWriter<W: Write + Seek> {
    /// The zip archive being written.
    zip: ZipWriter<W>,
    /// Per-entry options (notably the compression method) used for every `add_array` call.
    options: FileOptions,
}

impl<W: Write + Seek> NpzWriter<W> {
    /// Creates a writer whose entries are stored uncompressed (`CompressionMethod::Stored`).
    pub fn new(writer: W) -> NpzWriter<W> {
        NpzWriter {
            zip: ZipWriter::new(writer),
            options: FileOptions::default().compression_method(CompressionMethod::Stored),
        }
    }

    /// Creates a writer whose entries are compressed with `CompressionMethod::Deflated`.
    ///
    /// Only available when the `compressed_npz` feature is enabled.
    #[cfg(feature = "compressed_npz")]
    pub fn new_compressed(writer: W) -> NpzWriter<W> {
        NpzWriter {
            zip: ZipWriter::new(writer),
            options: FileOptions::default().compression_method(CompressionMethod::Deflated),
        }
    }

    /// Adds `array` to the archive as a new entry named `name`.
    ///
    /// `name` is used verbatim as the zip entry name; no extension is appended.
    ///
    /// # Errors
    ///
    /// Returns [`WriteNpzError::Zip`] if the zip entry cannot be started and
    /// [`WriteNpzError::Npy`] if writing the array in `.npy` format fails.
    pub fn add_array<N, S, D>(
        &mut self,
        name: N,
        array: &ArrayBase<S, D>,
    ) -> Result<(), WriteNpzError>
    where
        N: Into<String>,
        S::Elem: WritableElement,
        S: Data,
        D: Dimension,
    {
        self.zip.start_file(name, self.options)?;
        // `&mut ZipWriter<W>` is a `Write` that appends to the entry just started.
        array.write_npy(&mut self.zip)?;
        Ok(())
    }

    /// Finishes the archive and returns the underlying writer.
    ///
    /// This finishes the last entry and writes the remaining zip structures. If this method is
    /// never called, the inner `ZipWriter` attempts the same finalization when dropped. Any
    /// failure at that point cannot be reported to the caller, so call `finish` to observe errors.
    ///
    /// # Errors
    ///
    /// Returns [`WriteNpzError::Zip`] if finalizing the archive fails.
    pub fn finish(mut self) -> Result<W, WriteNpzError> {
        Ok(self.zip.finish()?)
    }
}
quick_error! {
// NOTE(review): each `display(x) -> ("{}: {}", x.description(), err)` calls the deprecated
// `std::error::Error::description`; `display("<same text>: {}", err)` would produce the same
// message without it (the `use std::error::Error;` import exists only to support this call).
/// Error returned when reading an `.npz` archive with [`NpzReader`].
#[derive(Debug)]
pub enum ReadNpzError {
/// Wraps a `std::io::Error`.
Io(err: io::Error) {
description("I/O error")
display(x) -> ("{}: {}", x.description(), err)
cause(err)
from()
}
/// Wraps an error reported by the `zip` crate (e.g. opening the archive or looking up an entry).
Zip(err: ZipError) {
description("zip file error")
display(x) -> ("{}: {}", x.description(), err)
cause(err)
from()
}
/// Wraps an error raised while parsing an archive entry as an `.npy` array.
Npy(err: ReadNpyError) {
description("error reading npy file in npz archive")
display(x) -> ("{}: {}", x.description(), err)
cause(err)
from()
}
}
}
/// Reader for `.npz` archives: a zip file in which every entry holds one `.npy`-format array.
pub struct NpzReader<R: Read + Seek> {
    /// The opened zip archive.
    zip: ZipArchive<R>,
}

impl<R: Read + Seek> NpzReader<R> {
    /// Opens an archive by reading the zip structure from `reader`.
    ///
    /// # Errors
    ///
    /// Returns [`ReadNpzError::Zip`] if `ZipArchive::new` fails to read `reader` as a zip archive.
    pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
        Ok(NpzReader {
            zip: ZipArchive::new(reader)?,
        })
    }

    /// Returns the number of entries in the archive.
    pub fn len(&self) -> usize {
        self.zip.len()
    }

    /// Returns `true` if the archive contains no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the names of all entries, in archive index order (`0..len()`).
    ///
    /// Names are returned exactly as stored in the zip file; nothing is stripped from them.
    ///
    /// # Errors
    ///
    /// Returns [`ReadNpzError::Zip`] if any entry cannot be accessed by index.
    pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
        // `by_index` needs `&mut self.zip`, which is why this method takes `&mut self`.
        Ok((0..self.zip.len())
            .map(|i| Ok(self.zip.by_index(i)?.name().to_owned()))
            .collect::<Result<_, ZipError>>()?)
    }

    /// Reads the array stored in the entry whose zip name is exactly `name`.
    ///
    /// # Errors
    ///
    /// Returns [`ReadNpzError::Zip`] if the entry lookup fails (e.g. no entry has that name).
    /// Returns [`ReadNpzError::Npy`] if the entry cannot be read as an `ArrayBase<S, D>`.
    pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
    where
        S::Elem: ReadableElement,
        S: DataOwned,
        D: Dimension,
    {
        Ok(ArrayBase::<S, D>::read_npy(self.zip.by_name(name)?)?)
    }

    /// Reads the array stored in the entry at position `index` (in `0..len()`).
    ///
    /// # Errors
    ///
    /// Returns [`ReadNpzError::Zip`] if the entry lookup fails (e.g. `index` is out of range).
    /// Returns [`ReadNpzError::Npy`] if the entry cannot be read as an `ArrayBase<S, D>`.
    pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
    where
        S::Elem: ReadableElement,
        S: DataOwned,
        D: Dimension,
    {
        Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
    }
}