ndarray_npy/
npz.rs

1use crate::{
2    ReadNpyError, ReadNpyExt, ReadableElement, WritableElement, WriteNpyError, WriteNpyExt,
3};
4use ndarray::prelude::*;
5use ndarray::{Data, DataOwned};
6use std::error::Error;
7use std::fmt;
8use std::io::{BufWriter, Read, Seek, Write};
9use zip::result::ZipError;
10use zip::write::SimpleFileOptions;
11use zip::{CompressionMethod, ZipArchive, ZipWriter};
12
13/// An error writing a `.npz` file.
14#[derive(Debug)]
15pub enum WriteNpzError {
16    /// An error caused by the zip file.
17    Zip(ZipError),
18    /// An error caused by writing an inner `.npy` file.
19    Npy(WriteNpyError),
20}
21
22impl Error for WriteNpzError {
23    fn source(&self) -> Option<&(dyn Error + 'static)> {
24        match self {
25            WriteNpzError::Zip(err) => Some(err),
26            WriteNpzError::Npy(err) => Some(err),
27        }
28    }
29}
30
31impl fmt::Display for WriteNpzError {
32    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33        match self {
34            WriteNpzError::Zip(err) => write!(f, "zip file error: {}", err),
35            WriteNpzError::Npy(err) => write!(f, "error writing npy file to npz archive: {}", err),
36        }
37    }
38}
39
40impl From<ZipError> for WriteNpzError {
41    fn from(err: ZipError) -> WriteNpzError {
42        WriteNpzError::Zip(err)
43    }
44}
45
46impl From<WriteNpyError> for WriteNpzError {
47    fn from(err: WriteNpyError) -> WriteNpzError {
48        WriteNpzError::Npy(err)
49    }
50}
51
52/// Writer for `.npz` files.
53///
54/// Note that the inner [`ZipWriter`] is wrapped in a [`BufWriter`] when
55/// writing each array with [`.add_array()`](NpzWriter::add_array). If desired,
56/// you could additionally buffer the innermost writer (e.g. the
57/// [`File`](std::fs::File) when writing to a file) by wrapping it in a
58/// [`BufWriter`]. This may be somewhat beneficial if the arrays are large and
59/// have non-standard layouts but may decrease performance if the arrays have
60/// standard or Fortran layout, so it's not recommended without testing to
61/// compare.
62///
63/// # Example
64///
65/// ```no_run
66/// use ndarray::{array, aview0, Array1, Array2};
67/// use ndarray_npy::NpzWriter;
68/// use std::fs::File;
69///
70/// let mut npz = NpzWriter::new(File::create("arrays.npz")?);
71/// let a: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
72/// let b: Array1<i32> = array![7, 8, 9];
73/// npz.add_array("a", &a)?;
74/// npz.add_array("b", &b)?;
75/// npz.add_array("c", &aview0(&10))?;
76/// npz.finish()?;
77/// # Ok::<_, Box<dyn std::error::Error>>(())
78/// ```
79pub struct NpzWriter<W: Write + Seek> {
80    zip: ZipWriter<W>,
81    options: SimpleFileOptions,
82}
83
84impl<W: Write + Seek> NpzWriter<W> {
85    /// Create a new `.npz` file without compression. See [`numpy.savez`].
86    ///
87    /// [`numpy.savez`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html
88    pub fn new(writer: W) -> NpzWriter<W> {
89        NpzWriter {
90            zip: ZipWriter::new(writer),
91            options: SimpleFileOptions::default().compression_method(CompressionMethod::Stored),
92        }
93    }
94
95    /// Creates a new `.npz` file with compression. See [`numpy.savez_compressed`].
96    ///
97    /// [`numpy.savez_compressed`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez_compressed.html
98    #[cfg(feature = "compressed_npz")]
99    pub fn new_compressed(writer: W) -> NpzWriter<W> {
100        NpzWriter {
101            zip: ZipWriter::new(writer),
102            options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
103        }
104    }
105
106    /// Creates a new `.npz` file with the specified options.
107    ///
108    /// This allows you to use a custom compression method, such as zstd, or set other options.
109    ///
110    /// Make sure to enable the relevant features of the `zip` crate.
111    pub fn new_with_options(writer: W, options: SimpleFileOptions) -> NpzWriter<W> {
112        NpzWriter {
113            zip: ZipWriter::new(writer),
114            options,
115        }
116    }
117
118    /// Adds an array with the specified `name` to the `.npz` file.
119    ///
120    /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
121    ///
122    /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
123    /// [`aview0`](ndarray::aview0).
124    pub fn add_array<N, S, D>(
125        &mut self,
126        name: N,
127        array: &ArrayBase<S, D>,
128    ) -> Result<(), WriteNpzError>
129    where
130        N: Into<String>,
131        S::Elem: WritableElement,
132        S: Data,
133        D: Dimension,
134    {
135        self.zip.start_file(name.into() + ".npy", self.options)?;
136        // Buffering when writing individual arrays is beneficial even when the
137        // underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
138        // only exception I saw in testing was the "compressed, in-memory
139        // writer, standard layout case". See
140        // https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481
141        // for details.
142        array.write_npy(BufWriter::new(&mut self.zip))?;
143        Ok(())
144    }
145
146    /// Calls [`.finish()`](ZipWriter::finish) on the zip file and
147    /// [`.flush()`](Write::flush) on the writer, and then returns the writer.
148    ///
149    /// This finishes writing the remaining zip structures and flushes the
150    /// writer. While dropping will automatically attempt to finish the zip
151    /// file and (for writers that flush on drop, such as
152    /// [`BufWriter`](std::io::BufWriter)) flush the writer, any errors that
153    /// occur during drop will be silently ignored. So, it's necessary to call
154    /// `.finish()` to properly handle errors.
155    pub fn finish(self) -> Result<W, WriteNpzError> {
156        let mut writer = self.zip.finish()?;
157        writer.flush().map_err(ZipError::from)?;
158        Ok(writer)
159    }
160}
161
162/// An error reading a `.npz` file.
163#[derive(Debug)]
164pub enum ReadNpzError {
165    /// An error caused by the zip archive.
166    Zip(ZipError),
167    /// An error caused by reading an inner `.npy` file.
168    Npy(ReadNpyError),
169}
170
171impl Error for ReadNpzError {
172    fn source(&self) -> Option<&(dyn Error + 'static)> {
173        match self {
174            ReadNpzError::Zip(err) => Some(err),
175            ReadNpzError::Npy(err) => Some(err),
176        }
177    }
178}
179
180impl fmt::Display for ReadNpzError {
181    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
182        match self {
183            ReadNpzError::Zip(err) => write!(f, "zip file error: {}", err),
184            ReadNpzError::Npy(err) => write!(f, "error reading npy file in npz archive: {}", err),
185        }
186    }
187}
188
189impl From<ZipError> for ReadNpzError {
190    fn from(err: ZipError) -> ReadNpzError {
191        ReadNpzError::Zip(err)
192    }
193}
194
195impl From<ReadNpyError> for ReadNpzError {
196    fn from(err: ReadNpyError) -> ReadNpzError {
197        ReadNpzError::Npy(err)
198    }
199}
200
201/// Reader for `.npz` files.
202///
203/// # Example
204///
205/// ```no_run
206/// use ndarray::{Array1, Array2};
207/// use ndarray_npy::NpzReader;
208/// use std::fs::File;
209///
210/// let mut npz = NpzReader::new(File::open("arrays.npz")?)?;
211/// let a: Array2<i32> = npz.by_name("a")?;
212/// let b: Array1<i32> = npz.by_name("b")?;
213/// # Ok::<_, Box<dyn std::error::Error>>(())
214/// ```
215pub struct NpzReader<R: Read + Seek> {
216    zip: ZipArchive<R>,
217}
218
219impl<R: Read + Seek> NpzReader<R> {
220    /// Creates a new `.npz` file reader.
221    pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
222        Ok(NpzReader {
223            zip: ZipArchive::new(reader)?,
224        })
225    }
226
227    /// Returns `true` iff the `.npz` file doesn't contain any arrays.
228    pub fn is_empty(&self) -> bool {
229        self.zip.len() == 0
230    }
231
232    /// Returns the number of arrays in the `.npz` file.
233    pub fn len(&self) -> usize {
234        self.zip.len()
235    }
236
237    /// Returns the names of all of the arrays in the file.
238    ///
239    /// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches
240    /// NumPy's behavior.
241    pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
242        Ok((0..self.zip.len())
243            .map(|i| {
244                let file = self.zip.by_index(i)?;
245                let name = file.name();
246                let stripped = name.strip_suffix(".npy").unwrap_or(name);
247                Ok(stripped.to_owned())
248            })
249            .collect::<Result<_, ZipError>>()?)
250    }
251
252    /// Reads an array by name.
253    ///
254    /// Note that this first checks for `name` in the `.npz` file, and if that is not present,
255    /// checks for `format!("{name}.npy")`. This matches NumPy's behavior.
256    pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
257    where
258        S::Elem: ReadableElement,
259        S: DataOwned,
260        D: Dimension,
261    {
262        // TODO: Combine the two cases into a single `let file = match { ... }` once
263        // https://github.com/rust-lang/rust/issues/47680 is resolved.
264        match self.zip.by_name(name) {
265            Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
266            Err(ZipError::FileNotFound) => {}
267            Err(err) => return Err(err.into()),
268        };
269        Ok(ArrayBase::<S, D>::read_npy(
270            self.zip.by_name(&format!("{name}.npy"))?,
271        )?)
272    }
273
274    /// Reads an array by index in the `.npz` file.
275    pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
276    where
277        S::Elem: ReadableElement,
278        S: DataOwned,
279        D: Dimension,
280    {
281        Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
282    }
283}