// ndarray_npy/npz.rs

1use crate::{ReadNpyError, ReadNpyExt, ReadableElement, WriteNpyError, WriteNpyExt};
2use ndarray::prelude::*;
3use ndarray::DataOwned;
4use std::error::Error;
5use std::fmt;
6use std::io::{BufWriter, Read, Seek, Write};
7use zip::result::ZipError;
8use zip::write::{FileOptionExtension, FileOptions, SimpleFileOptions};
9use zip::{CompressionMethod, ZipArchive, ZipWriter};
10
11/// An error writing a `.npz` file.
12#[derive(Debug)]
13pub enum WriteNpzError {
14    /// An error caused by the zip file.
15    Zip(ZipError),
16    /// An error caused by writing an inner `.npy` file.
17    Npy(WriteNpyError),
18}
19
20impl Error for WriteNpzError {
21    fn source(&self) -> Option<&(dyn Error + 'static)> {
22        match self {
23            WriteNpzError::Zip(err) => Some(err),
24            WriteNpzError::Npy(err) => Some(err),
25        }
26    }
27}
28
29impl fmt::Display for WriteNpzError {
30    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31        match self {
32            WriteNpzError::Zip(err) => write!(f, "zip file error: {}", err),
33            WriteNpzError::Npy(err) => write!(f, "error writing npy file to npz archive: {}", err),
34        }
35    }
36}
37
38impl From<ZipError> for WriteNpzError {
39    fn from(err: ZipError) -> WriteNpzError {
40        WriteNpzError::Zip(err)
41    }
42}
43
44impl From<WriteNpyError> for WriteNpzError {
45    fn from(err: WriteNpyError) -> WriteNpzError {
46        WriteNpzError::Npy(err)
47    }
48}
49
50/// Writer for `.npz` files.
51///
52/// Note that the inner [`ZipWriter`] is wrapped in a [`BufWriter`] when
53/// writing each array with [`.add_array()`](NpzWriter::add_array). If desired,
54/// you could additionally buffer the innermost writer (e.g. the
55/// [`File`](std::fs::File) when writing to a file) by wrapping it in a
56/// [`BufWriter`]. This may be somewhat beneficial if the arrays are large and
57/// have non-standard layouts but may decrease performance if the arrays have
58/// standard or Fortran layout, so it's not recommended without testing to
59/// compare.
60///
61/// # Example
62///
63/// ```no_run
64/// use ndarray::{array, aview0, Array1, Array2};
65/// use ndarray_npy::NpzWriter;
66/// use std::fs::File;
67///
68/// let mut npz = NpzWriter::new(File::create("arrays.npz")?);
69/// let a: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
70/// let b: Array1<i32> = array![7, 8, 9];
71/// npz.add_array("a", &a)?;
72/// npz.add_array("b", &b)?;
73/// npz.add_array("c", &aview0(&10))?;
74/// npz.finish()?;
75/// # Ok::<_, Box<dyn std::error::Error>>(())
76/// ```
77pub struct NpzWriter<W: Write + Seek> {
78    zip: ZipWriter<W>,
79    options: SimpleFileOptions,
80}
81
82impl<W: Write + Seek> NpzWriter<W> {
83    /// Create a new `.npz` file without compression. See [`numpy.savez`].
84    ///
85    /// [`numpy.savez`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html
86    pub fn new(writer: W) -> NpzWriter<W> {
87        NpzWriter {
88            zip: ZipWriter::new(writer),
89            options: SimpleFileOptions::default().compression_method(CompressionMethod::Stored),
90        }
91    }
92
93    /// Creates a new `.npz` file with [`Deflated`](CompressionMethod::Deflated) compression. See
94    /// [`numpy.savez_compressed`].
95    ///
96    /// For other compression algorithms, use [`NpzWriter::new_with_options`] or
97    /// [`NpzWriter::add_array_with_options`].
98    ///
99    /// [`numpy.savez_compressed`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez_compressed.html
100    #[cfg(feature = "compressed_npz")]
101    pub fn new_compressed(writer: W) -> NpzWriter<W> {
102        NpzWriter {
103            zip: ZipWriter::new(writer),
104            options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
105        }
106    }
107
108    /// Creates a new `.npz` file with the specified options to be used for each array.
109    ///
110    /// This allows you to use a custom compression method, such as zstd, or set other options.
111    ///
112    /// Make sure to enable the relevant features of the `zip` crate.
113    pub fn new_with_options(writer: W, options: SimpleFileOptions) -> NpzWriter<W> {
114        NpzWriter {
115            zip: ZipWriter::new(writer),
116            options,
117        }
118    }
119
120    /// Adds an array with the specified `name` to the `.npz` file.
121    ///
122    /// This uses the file options passed to the `NpzWriter` constructor.
123    ///
124    /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
125    ///
126    /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
127    /// [`aview0`](ndarray::aview0).
128    pub fn add_array<N, T>(&mut self, name: N, array: &T) -> Result<(), WriteNpzError>
129    where
130        N: Into<String>,
131        T: WriteNpyExt + ?Sized,
132    {
133        self.add_array_with_options(name, array, self.options)
134    }
135
136    /// Adds an array with the specified `name` and options to the `.npz` file.
137    ///
138    /// The specified options override those passed to the [`NpzWriter`] constructor (if any).
139    ///
140    /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
141    ///
142    /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
143    /// [`aview0`](ndarray::aview0).
144    pub fn add_array_with_options<N, T, U>(
145        &mut self,
146        name: N,
147        array: &T,
148        options: FileOptions<'_, U>,
149    ) -> Result<(), WriteNpzError>
150    where
151        N: Into<String>,
152        T: WriteNpyExt + ?Sized,
153        U: FileOptionExtension,
154    {
155        fn inner<W, T, U>(
156            npz_zip: &mut ZipWriter<W>,
157            name: String,
158            array: &T,
159            options: FileOptions<'_, U>,
160        ) -> Result<(), WriteNpzError>
161        where
162            W: Write + Seek,
163            T: WriteNpyExt + ?Sized,
164            U: FileOptionExtension,
165        {
166            npz_zip.start_file(name + ".npy", options)?;
167            // Buffering when writing individual arrays is beneficial even when the
168            // underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
169            // only exception I saw in testing was the "compressed, in-memory
170            // writer, standard layout case". See
171            // https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481
172            // for details.
173            array.write_npy(BufWriter::new(npz_zip))?;
174            Ok(())
175        }
176
177        inner(&mut self.zip, name.into(), array, options)
178    }
179
180    /// Calls [`.finish()`](ZipWriter::finish) on the zip file and
181    /// [`.flush()`](Write::flush) on the writer, and then returns the writer.
182    ///
183    /// This finishes writing the remaining zip structures and flushes the
184    /// writer. While dropping will automatically attempt to finish the zip
185    /// file and (for writers that flush on drop, such as
186    /// [`BufWriter`](std::io::BufWriter)) flush the writer, any errors that
187    /// occur during drop will be silently ignored. So, it's necessary to call
188    /// `.finish()` to properly handle errors.
189    pub fn finish(self) -> Result<W, WriteNpzError> {
190        let mut writer = self.zip.finish()?;
191        writer.flush().map_err(ZipError::from)?;
192        Ok(writer)
193    }
194}
195
196/// An error reading a `.npz` file.
197#[derive(Debug)]
198pub enum ReadNpzError {
199    /// An error caused by the zip archive.
200    Zip(ZipError),
201    /// An error caused by reading an inner `.npy` file.
202    Npy(ReadNpyError),
203}
204
205impl Error for ReadNpzError {
206    fn source(&self) -> Option<&(dyn Error + 'static)> {
207        match self {
208            ReadNpzError::Zip(err) => Some(err),
209            ReadNpzError::Npy(err) => Some(err),
210        }
211    }
212}
213
214impl fmt::Display for ReadNpzError {
215    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216        match self {
217            ReadNpzError::Zip(err) => write!(f, "zip file error: {}", err),
218            ReadNpzError::Npy(err) => write!(f, "error reading npy file in npz archive: {}", err),
219        }
220    }
221}
222
223impl From<ZipError> for ReadNpzError {
224    fn from(err: ZipError) -> ReadNpzError {
225        ReadNpzError::Zip(err)
226    }
227}
228
229impl From<ReadNpyError> for ReadNpzError {
230    fn from(err: ReadNpyError) -> ReadNpzError {
231        ReadNpzError::Npy(err)
232    }
233}
234
235/// Reader for `.npz` files.
236///
237/// # Example
238///
239/// ```no_run
240/// use ndarray::{Array1, Array2};
241/// use ndarray_npy::NpzReader;
242/// use std::fs::File;
243///
244/// let mut npz = NpzReader::new(File::open("arrays.npz")?)?;
245/// let a: Array2<i32> = npz.by_name("a")?;
246/// let b: Array1<i32> = npz.by_name("b")?;
247/// # Ok::<_, Box<dyn std::error::Error>>(())
248/// ```
249pub struct NpzReader<R: Read + Seek> {
250    zip: ZipArchive<R>,
251}
252
253impl<R: Read + Seek> NpzReader<R> {
254    /// Creates a new `.npz` file reader.
255    pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
256        Ok(NpzReader {
257            zip: ZipArchive::new(reader)?,
258        })
259    }
260
261    /// Returns `true` iff the `.npz` file doesn't contain any arrays.
262    pub fn is_empty(&self) -> bool {
263        self.zip.len() == 0
264    }
265
266    /// Returns the number of arrays in the `.npz` file.
267    pub fn len(&self) -> usize {
268        self.zip.len()
269    }
270
271    /// Returns the names of all of the arrays in the file.
272    ///
273    /// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches
274    /// NumPy's behavior.
275    pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
276        Ok((0..self.zip.len())
277            .map(|i| {
278                let file = self.zip.by_index(i)?;
279                let name = file.name();
280                let stripped = name.strip_suffix(".npy").unwrap_or(name);
281                Ok(stripped.to_owned())
282            })
283            .collect::<Result<_, ZipError>>()?)
284    }
285
286    /// Reads an array by name.
287    ///
288    /// Note that this first checks for `name` in the `.npz` file, and if that is not present,
289    /// checks for `format!("{name}.npy")`. This matches NumPy's behavior.
290    pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
291    where
292        S::Elem: ReadableElement,
293        S: DataOwned,
294        D: Dimension,
295    {
296        // TODO: Combine the two cases into a single `let file = match { ... }` once
297        // https://github.com/rust-lang/rust/issues/47680 is resolved.
298        match self.zip.by_name(name) {
299            Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
300            Err(ZipError::FileNotFound) => {}
301            Err(err) => return Err(err.into()),
302        };
303        Ok(ArrayBase::<S, D>::read_npy(
304            self.zip.by_name(&format!("{name}.npy"))?,
305        )?)
306    }
307
308    /// Reads an array by index in the `.npz` file.
309    pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
310    where
311        S::Elem: ReadableElement,
312        S: DataOwned,
313        D: Dimension,
314    {
315        Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
316    }
317}