ferray-io 0.2.8

NumPy-compatible file I/O (.npy, .npz, memory-mapped) for ferray
Documentation
// ferray-io: .npz file I/O
//
// REQ-4: savez(path, &[("name", DynArray)]) writes multiple arrays to .npz (zip archive)
// REQ-5: savez_compressed(path, ...) writes gzip-compressed .npz

use std::fs::File;
use std::io::{BufReader, Cursor, Read, Write};
use std::path::Path;

use ferray_core::dynarray::DynArray;
use ferray_core::error::{FerrayError, FerrayResult};

use crate::npy;

/// Save multiple arrays to an uncompressed `.npz` file (zip archive of `.npy` files).
///
/// Each entry in `arrays` is a `(name, array)` pair. The name is used as the
/// filename inside the archive (`.npy` extension is appended automatically).
///
/// # Errors
/// Returns `FerrayError::IoError` on file creation or write failures.
pub fn savez<P: AsRef<Path>>(path: P, arrays: &[(&str, &DynArray)]) -> FerrayResult<()> {
    let file = File::create(path.as_ref()).map_err(|e| {
        FerrayError::io_error(format!(
            "failed to create .npz file '{}': {e}",
            path.as_ref().display()
        ))
    })?;

    let mut zip_writer = zip::ZipWriter::new(file);

    for (name, array) in arrays {
        let entry_name = if name.ends_with(".npy") {
            name.to_string()
        } else {
            format!("{name}.npy")
        };

        let options = zip::write::SimpleFileOptions::default()
            .compression_method(zip::CompressionMethod::Stored);

        zip_writer.start_file(&entry_name, options).map_err(|e| {
            FerrayError::io_error(format!("failed to create zip entry '{entry_name}': {e}"))
        })?;

        // Write the .npy data into a buffer first
        let mut npy_buf = Vec::new();
        npy::save_dynamic_to_writer(&mut npy_buf, array)?;

        zip_writer.write_all(&npy_buf).map_err(|e| {
            FerrayError::io_error(format!("failed to write zip entry '{entry_name}': {e}"))
        })?;
    }

    zip_writer
        .finish()
        .map_err(|e| FerrayError::io_error(format!("failed to finalize .npz file: {e}")))?;

    Ok(())
}

/// Save multiple arrays to a gzip-compressed `.npz` file.
///
/// Same as [`savez`] but each entry is individually compressed with DEFLATE.
///
/// # Errors
/// Returns `FerrayError::IoError` on file creation or write failures.
pub fn savez_compressed<P: AsRef<Path>>(path: P, arrays: &[(&str, &DynArray)]) -> FerrayResult<()> {
    let file = File::create(path.as_ref()).map_err(|e| {
        FerrayError::io_error(format!(
            "failed to create .npz file '{}': {e}",
            path.as_ref().display()
        ))
    })?;

    let mut zip_writer = zip::ZipWriter::new(file);

    for (name, array) in arrays {
        let entry_name = if name.ends_with(".npy") {
            name.to_string()
        } else {
            format!("{name}.npy")
        };

        let options = zip::write::SimpleFileOptions::default()
            .compression_method(zip::CompressionMethod::Deflated);

        zip_writer.start_file(&entry_name, options).map_err(|e| {
            FerrayError::io_error(format!("failed to create zip entry '{entry_name}': {e}"))
        })?;

        let mut npy_buf = Vec::new();
        npy::save_dynamic_to_writer(&mut npy_buf, array)?;

        zip_writer.write_all(&npy_buf).map_err(|e| {
            FerrayError::io_error(format!("failed to write zip entry '{entry_name}': {e}"))
        })?;
    }

    zip_writer
        .finish()
        .map_err(|e| FerrayError::io_error(format!("failed to finalize .npz file: {e}")))?;

    Ok(())
}

/// A loaded `.npz` archive, providing access to named arrays.
pub struct NpzFile {
    /// The entries in the archive, keyed by name (without .npy extension).
    entries: Vec<(String, Vec<u8>)>,
}

impl NpzFile {
    /// Open and read a `.npz` file.
    ///
    /// All entries are read into memory. Use the [`get`](Self::get) method to
    /// retrieve individual arrays.
    pub fn open<P: AsRef<Path>>(path: P) -> FerrayResult<Self> {
        let file = File::open(path.as_ref()).map_err(|e| {
            FerrayError::io_error(format!(
                "failed to open .npz file '{}': {e}",
                path.as_ref().display()
            ))
        })?;
        let reader = BufReader::new(file);
        Self::from_reader(reader)
    }

    /// Read a `.npz` from a reader.
    pub fn from_reader<R: Read + std::io::Seek>(reader: R) -> FerrayResult<Self> {
        let mut archive = zip::ZipArchive::new(reader)
            .map_err(|e| FerrayError::io_error(format!("failed to read .npz archive: {e}")))?;

        let mut entries = Vec::new();
        for i in 0..archive.len() {
            let mut entry = archive.by_index(i).map_err(|e| {
                FerrayError::io_error(format!("failed to read .npz entry {i}: {e}"))
            })?;

            let name = entry
                .name()
                .strip_suffix(".npy")
                .unwrap_or(entry.name())
                .to_string();

            let mut data = Vec::new();
            entry.read_to_end(&mut data).map_err(|e| {
                FerrayError::io_error(format!("failed to read .npz entry data: {e}"))
            })?;

            entries.push((name, data));
        }

        Ok(Self { entries })
    }

    /// List the names of arrays stored in the archive.
    pub fn names(&self) -> Vec<&str> {
        self.entries.iter().map(|(name, _)| name.as_str()).collect()
    }

    /// Retrieve a named array as a `DynArray`.
    ///
    /// # Errors
    /// Returns `FerrayError::IoError` if the name is not found or the data is invalid.
    pub fn get(&self, name: &str) -> FerrayResult<DynArray> {
        let data = self
            .entries
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, d)| d)
            .ok_or_else(|| {
                FerrayError::io_error(format!("array '{name}' not found in .npz archive"))
            })?;

        let mut cursor = Cursor::new(data);
        npy::load_dynamic_from_reader(&mut cursor)
    }

    /// Number of arrays in the archive.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the archive is empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Array;
    use ferray_core::dimension::IxDyn;
    use ferray_core::dtype::DType;

    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_npz_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    fn make_f64_dyn(data: Vec<f64>, shape: &[usize]) -> DynArray {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap();
        DynArray::F64(arr)
    }

    fn make_i32_dyn(data: Vec<i32>, shape: &[usize]) -> DynArray {
        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap();
        DynArray::I32(arr)
    }

    #[test]
    fn savez_and_load() {
        let a = make_f64_dyn(vec![1.0, 2.0, 3.0], &[3]);
        let b = make_i32_dyn(vec![10, 20, 30, 40], &[2, 2]);

        let path = test_file("test.npz");
        savez(&path, &[("a", &a), ("b", &b)]).unwrap();

        let npz = NpzFile::open(&path).unwrap();
        assert_eq!(npz.len(), 2);

        let mut names = npz.names();
        names.sort();
        assert_eq!(names, vec!["a", "b"]);

        let loaded_a = npz.get("a").unwrap();
        assert_eq!(loaded_a.dtype(), DType::F64);
        assert_eq!(loaded_a.shape(), &[3]);

        let loaded_b = npz.get("b").unwrap();
        assert_eq!(loaded_b.dtype(), DType::I32);
        assert_eq!(loaded_b.shape(), &[2, 2]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn savez_compressed_and_load() {
        let a = make_f64_dyn(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let path = test_file("test_compressed.npz");
        savez_compressed(&path, &[("data", &a)]).unwrap();

        let npz = NpzFile::open(&path).unwrap();
        assert_eq!(npz.len(), 1);

        let loaded = npz.get("data").unwrap();
        assert_eq!(loaded.dtype(), DType::F64);
        assert_eq!(loaded.shape(), &[2, 3]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn npz_missing_key() {
        let a = make_f64_dyn(vec![1.0], &[1]);

        let path = test_file("npz_missing.npz");
        savez(&path, &[("a", &a)]).unwrap();

        let npz = NpzFile::open(&path).unwrap();
        assert!(npz.get("nonexistent").is_err());
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn npz_empty() {
        let path = test_file("npz_empty.npz");
        savez(&path, &[]).unwrap();

        let npz = NpzFile::open(&path).unwrap();
        assert!(npz.is_empty());
        assert_eq!(npz.len(), 0);
        let _ = std::fs::remove_file(&path);
    }
}