use std::fs::File;
use std::io::{BufReader, Cursor, Read};
use std::path::Path;
use ferray_core::dynarray::DynArray;
use ferray_core::error::{FerrayError, FerrayResult};
use zip::ZipArchive;
use crate::npy;
/// Type-erased, seekable byte source that backs an opened archive.
///
/// Boxing the reader lets [`NpzFile`] hold either a buffered file (see
/// `NpzFile::open`) or any caller-supplied reader (see `NpzFile::from_reader`)
/// behind one concrete `ZipArchive` type parameter.
type NpzReader = Box<dyn ReadSeek + Send>;
/// Combination of [`Read`] and [`std::io::Seek`] under a single trait name.
///
/// A trait object may name only one non-auto trait, so this marker exists to
/// make `dyn ReadSeek` expressible. It declares no methods of its own.
trait ReadSeek: std::io::Seek + Read {}

// Blanket implementation: every type (sized or not) that is both readable and
// seekable is automatically a `ReadSeek`.
impl<T> ReadSeek for T where T: ?Sized + Read + std::io::Seek {}
/// Writes the given named arrays into a new `.npz` archive at `path`, storing
/// each entry with the `Stored` zip method.
///
/// Each element of `arrays` is a `(name, array)` pair; see `savez_impl` for
/// how names are turned into archive entry names.
///
/// # Errors
///
/// Propagates any error returned by `savez_impl` (file creation, entry
/// creation, array serialization, or archive finalization).
pub fn savez<P: AsRef<Path>>(path: P, arrays: &[(&str, &DynArray)]) -> FerrayResult<()> {
    let method = zip::CompressionMethod::Stored;
    savez_impl(path, arrays, method)
}
/// Writes the given named arrays into a new `.npz` archive at `path`, storing
/// each entry with the `Deflated` zip method.
///
/// Identical to `savez` apart from the compression method handed to
/// `savez_impl`.
///
/// # Errors
///
/// Propagates any error returned by `savez_impl` (file creation, entry
/// creation, array serialization, or archive finalization).
pub fn savez_compressed<P: AsRef<Path>>(path: P, arrays: &[(&str, &DynArray)]) -> FerrayResult<()> {
    let method = zip::CompressionMethod::Deflated;
    savez_impl(path, arrays, method)
}
/// Shared implementation behind `savez` and `savez_compressed`.
///
/// Creates (or truncates) the file at `path`, then writes one zip entry per
/// `(name, array)` pair using the compression `method`, and finally writes
/// the archive's central directory.
///
/// Entry naming: a `name` whose extension is already `npy` (compared
/// ASCII-case-insensitively) is used verbatim; otherwise `.npy` is appended.
///
/// Every entry is started with the ZIP64 `large_file` option enabled so that
/// arrays whose serialized size exceeds 4 GiB can be written. Without it the
/// zip writer aborts such entries with an I/O error. The cost is a few extra
/// header bytes per entry.
///
/// # Errors
///
/// Returns an I/O `FerrayError` if the file cannot be created, an entry cannot
/// be started, or the archive cannot be finalized. Errors from
/// `npy::save_dynamic_to_writer` are propagated unchanged.
fn savez_impl<P: AsRef<Path>>(
    path: P,
    arrays: &[(&str, &DynArray)],
    method: zip::CompressionMethod,
) -> FerrayResult<()> {
    let file = File::create(path.as_ref()).map_err(|e| {
        FerrayError::io_error(format!(
            "failed to create .npz file '{}': {e}",
            path.as_ref().display()
        ))
    })?;
    let mut zip_writer = zip::ZipWriter::new(file);

    for (name, array) in arrays {
        // Avoid producing `foo.npy.npy` when the caller already supplied the
        // extension.
        let entry_name = if std::path::Path::new(name)
            .extension()
            .is_some_and(|ext| ext.eq_ignore_ascii_case("npy"))
        {
            name.to_string()
        } else {
            format!("{name}.npy")
        };

        // `large_file(true)` opts the entry into ZIP64 so payloads larger than
        // 4 GiB do not fail part-way through the write.
        let options = zip::write::SimpleFileOptions::default()
            .compression_method(method)
            .large_file(true);
        zip_writer.start_file(&entry_name, options).map_err(|e| {
            FerrayError::io_error(format!("failed to create zip entry '{entry_name}': {e}"))
        })?;

        // The array is serialized directly into the currently open zip entry.
        npy::save_dynamic_to_writer(&mut zip_writer, array)?;
    }

    zip_writer
        .finish()
        .map_err(|e| FerrayError::io_error(format!("failed to finalize .npz file: {e}")))?;
    Ok(())
}
/// Handle to an opened `.npz` archive.
///
/// Arrays are not decoded when the handle is created; each call to `get`
/// reads and decodes the requested entry from the underlying archive.
pub struct NpzFile {
// Underlying zip archive over a boxed, type-erased reader.
archive: ZipArchive<NpzReader>,
// Array keys: the archive's entry names with a trailing `.npy` removed
// where present (entries without that suffix are kept unchanged).
names: Vec<String>,
}
impl NpzFile {
    /// Opens the `.npz` archive stored at `path`.
    ///
    /// The file is wrapped in a `BufReader` before being handed to the zip
    /// reader.
    ///
    /// # Errors
    ///
    /// Returns an I/O `FerrayError` if the file cannot be opened or is not a
    /// readable zip archive.
    pub fn open<P: AsRef<Path>>(path: P) -> FerrayResult<Self> {
        let file = File::open(path.as_ref()).map_err(|e| {
            FerrayError::io_error(format!(
                "failed to open .npz file '{}': {e}",
                path.as_ref().display()
            ))
        })?;
        let reader: NpzReader = Box::new(BufReader::new(file));
        Self::from_boxed_reader(reader)
    }

    /// Opens an `.npz` archive from any owned, seekable reader (for example an
    /// in-memory `Cursor`).
    ///
    /// # Errors
    ///
    /// Returns an I/O `FerrayError` if the reader does not contain a readable
    /// zip archive.
    pub fn from_reader<R: Read + std::io::Seek + Send + 'static>(reader: R) -> FerrayResult<Self> {
        Self::from_boxed_reader(Box::new(reader))
    }

    /// Parses the zip directory from `reader` and records the array keys.
    ///
    /// A key is the entry name with a trailing `.npy` removed. Entries that do
    /// not end in `.npy` keep their full name as the key.
    fn from_boxed_reader(reader: NpzReader) -> FerrayResult<Self> {
        let archive = ZipArchive::new(reader)
            .map_err(|e| FerrayError::io_error(format!("failed to read .npz archive: {e}")))?;
        let names: Vec<String> = archive
            .file_names()
            .map(|n| n.strip_suffix(".npy").unwrap_or(n).to_string())
            .collect();
        Ok(Self { archive, names })
    }

    /// Returns the keys of all entries in the archive, in the order reported
    /// by the underlying zip reader.
    pub fn names(&self) -> Vec<&str> {
        self.names.iter().map(String::as_str).collect()
    }

    /// Loads the array stored under key `name`.
    ///
    /// `name` must be one of the keys reported by [`NpzFile::names`]. The key
    /// is resolved to the entry `"{name}.npy"` when the archive contains it,
    /// and otherwise to the entry called exactly `name`. The fallback makes
    /// keys loadable whose entry name was kept unchanged when the key list was
    /// built (for example `x.NPY` or an entry without any suffix); a lookup
    /// for `"{name}.npy"` alone can never find those entries.
    ///
    /// The entry is read completely into memory and then decoded.
    ///
    /// # Errors
    ///
    /// Returns an I/O `FerrayError` if `name` is not a known key or the entry
    /// cannot be read. Errors from `npy::load_dynamic_from_reader` are
    /// propagated unchanged.
    pub fn get(&mut self, name: &str) -> FerrayResult<DynArray> {
        if !self.names.iter().any(|n| n == name) {
            return Err(FerrayError::io_error(format!(
                "array '{name}' not found in .npz archive"
            )));
        }

        // Every key came from an entry named either `"{key}.npy"` or `key`,
        // so exactly one of the two candidates is guaranteed to exist.
        let npy_name = format!("{name}.npy");
        let entry_name = if self
            .archive
            .file_names()
            .any(|n| n == npy_name.as_str())
        {
            npy_name
        } else {
            name.to_string()
        };

        let mut entry = self.archive.by_name(&entry_name).map_err(|e| {
            FerrayError::io_error(format!("failed to read .npz entry '{entry_name}': {e}"))
        })?;
        let mut data = Vec::new();
        entry
            .read_to_end(&mut data)
            .map_err(|e| FerrayError::io_error(format!("failed to read .npz entry data: {e}")))?;

        // Decode from an in-memory cursor over the fully read entry bytes.
        let mut cursor = Cursor::new(data);
        npy::load_dynamic_from_reader(&mut cursor)
    }

    /// Number of entries in the archive.
    #[must_use]
    pub fn len(&self) -> usize {
        self.names.len()
    }

    /// Returns `true` when the archive contains no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.names.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Array;
    use ferray_core::dimension::IxDyn;
    use ferray_core::dtype::DType;

    /// Creates a fresh temporary directory and returns it together with the
    /// path of `name` inside it. Tests bind the returned `TempDir` to `_dir`
    /// so that it lives until the end of the test.
    fn temp_path(name: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let guard = tempfile::TempDir::new().expect("failed to create test TempDir");
        let target = guard.path().join(name);
        (guard, target)
    }

    /// Builds an `f64` dynamic array with the given flat data and shape.
    fn make_f64_dyn(data: Vec<f64>, shape: &[usize]) -> DynArray {
        DynArray::F64(Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap())
    }

    /// Builds an `i32` dynamic array with the given flat data and shape.
    fn make_i32_dyn(data: Vec<i32>, shape: &[usize]) -> DynArray {
        DynArray::I32(Array::<i32, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap())
    }

    /// Two arrays of different dtypes survive an uncompressed round trip.
    #[test]
    fn savez_and_load() {
        let floats = make_f64_dyn(vec![1.0, 2.0, 3.0], &[3]);
        let ints = make_i32_dyn(vec![10, 20, 30, 40], &[2, 2]);
        let (_dir, path) = temp_path("test.npz");
        savez(&path, &[("a", &floats), ("b", &ints)]).unwrap();

        let mut npz = NpzFile::open(&path).unwrap();
        assert_eq!(npz.len(), 2);

        // Sort before comparing so the check does not depend on entry order.
        let mut keys = npz.names();
        keys.sort_unstable();
        assert_eq!(keys, vec!["a", "b"]);

        let loaded_a = npz.get("a").unwrap();
        assert_eq!(loaded_a.dtype(), DType::F64);
        assert_eq!(loaded_a.shape(), &[3]);

        let loaded_b = npz.get("b").unwrap();
        assert_eq!(loaded_b.dtype(), DType::I32);
        assert_eq!(loaded_b.shape(), &[2, 2]);
    }

    /// A single array survives a deflate-compressed round trip.
    #[test]
    fn savez_compressed_and_load() {
        let matrix = make_f64_dyn(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let (_dir, path) = temp_path("test_compressed.npz");
        savez_compressed(&path, &[("data", &matrix)]).unwrap();

        let mut npz = NpzFile::open(&path).unwrap();
        assert_eq!(npz.len(), 1);

        let loaded = npz.get("data").unwrap();
        assert_eq!(loaded.dtype(), DType::F64);
        assert_eq!(loaded.shape(), &[2, 3]);
    }

    /// Asking for a key that was never written yields an error.
    #[test]
    fn npz_missing_key() {
        let single = make_f64_dyn(vec![1.0], &[1]);
        let (_dir, path) = temp_path("npz_missing.npz");
        savez(&path, &[("a", &single)]).unwrap();

        let mut npz = NpzFile::open(&path).unwrap();
        assert!(npz.get("nonexistent").is_err());
    }

    /// Entries can be fetched in any order and the same entry more than once.
    #[test]
    fn npz_lazy_get_then_get() {
        let floats = make_f64_dyn(vec![1.0, 2.0, 3.0], &[3]);
        let ints = make_i32_dyn(vec![10, 20, 30], &[3]);
        let (_dir, path) = temp_path("test_lazy.npz");
        savez(&path, &[("a", &floats), ("b", &ints)]).unwrap();

        let mut npz = NpzFile::open(&path).unwrap();
        let _b_first = npz.get("b").unwrap();
        let _a_only = npz.get("a").unwrap();
        let _b_again = npz.get("b").unwrap();
    }

    /// An archive written from an empty slice opens and reports no entries.
    #[test]
    fn npz_empty() {
        let (_dir, path) = temp_path("npz_empty.npz");
        savez(&path, &[]).unwrap();

        let npz = NpzFile::open(&path).unwrap();
        assert!(npz.is_empty());
        assert_eq!(npz.len(), 0);
    }
}