ndarray_npy/npz.rs
1use crate::{ReadNpyError, ReadNpyExt, ReadableElement, WriteNpyError, WriteNpyExt};
2use ndarray::prelude::*;
3use ndarray::DataOwned;
4use std::error::Error;
5use std::fmt;
6use std::io::{BufWriter, Read, Seek, Write};
7use zip::result::ZipError;
8use zip::write::{FileOptionExtension, FileOptions, SimpleFileOptions};
9use zip::{CompressionMethod, ZipArchive, ZipWriter};
10
11/// An error writing a `.npz` file.
12#[derive(Debug)]
13pub enum WriteNpzError {
14 /// An error caused by the zip file.
15 Zip(ZipError),
16 /// An error caused by writing an inner `.npy` file.
17 Npy(WriteNpyError),
18}
19
20impl Error for WriteNpzError {
21 fn source(&self) -> Option<&(dyn Error + 'static)> {
22 match self {
23 WriteNpzError::Zip(err) => Some(err),
24 WriteNpzError::Npy(err) => Some(err),
25 }
26 }
27}
28
29impl fmt::Display for WriteNpzError {
30 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31 match self {
32 WriteNpzError::Zip(err) => write!(f, "zip file error: {}", err),
33 WriteNpzError::Npy(err) => write!(f, "error writing npy file to npz archive: {}", err),
34 }
35 }
36}
37
38impl From<ZipError> for WriteNpzError {
39 fn from(err: ZipError) -> WriteNpzError {
40 WriteNpzError::Zip(err)
41 }
42}
43
44impl From<WriteNpyError> for WriteNpzError {
45 fn from(err: WriteNpyError) -> WriteNpzError {
46 WriteNpzError::Npy(err)
47 }
48}
49
50/// Writer for `.npz` files.
51///
52/// Note that the inner [`ZipWriter`] is wrapped in a [`BufWriter`] when
53/// writing each array with [`.add_array()`](NpzWriter::add_array). If desired,
54/// you could additionally buffer the innermost writer (e.g. the
55/// [`File`](std::fs::File) when writing to a file) by wrapping it in a
56/// [`BufWriter`]. This may be somewhat beneficial if the arrays are large and
57/// have non-standard layouts but may decrease performance if the arrays have
58/// standard or Fortran layout, so it's not recommended without testing to
59/// compare.
60///
61/// # Example
62///
63/// ```no_run
64/// use ndarray::{array, aview0, Array1, Array2};
65/// use ndarray_npy::NpzWriter;
66/// use std::fs::File;
67///
68/// let mut npz = NpzWriter::new(File::create("arrays.npz")?);
69/// let a: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
70/// let b: Array1<i32> = array![7, 8, 9];
71/// npz.add_array("a", &a)?;
72/// npz.add_array("b", &b)?;
73/// npz.add_array("c", &aview0(&10))?;
74/// npz.finish()?;
75/// # Ok::<_, Box<dyn std::error::Error>>(())
76/// ```
77pub struct NpzWriter<W: Write + Seek> {
78 zip: ZipWriter<W>,
79 options: SimpleFileOptions,
80}
81
82impl<W: Write + Seek> NpzWriter<W> {
83 /// Create a new `.npz` file without compression. See [`numpy.savez`].
84 ///
85 /// [`numpy.savez`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html
86 pub fn new(writer: W) -> NpzWriter<W> {
87 NpzWriter {
88 zip: ZipWriter::new(writer),
89 options: SimpleFileOptions::default().compression_method(CompressionMethod::Stored),
90 }
91 }
92
93 /// Creates a new `.npz` file with [`Deflated`](CompressionMethod::Deflated) compression. See
94 /// [`numpy.savez_compressed`].
95 ///
96 /// For other compression algorithms, use [`NpzWriter::new_with_options`] or
97 /// [`NpzWriter::add_array_with_options`].
98 ///
99 /// [`numpy.savez_compressed`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez_compressed.html
100 #[cfg(feature = "compressed_npz")]
101 pub fn new_compressed(writer: W) -> NpzWriter<W> {
102 NpzWriter {
103 zip: ZipWriter::new(writer),
104 options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
105 }
106 }
107
108 /// Creates a new `.npz` file with the specified options to be used for each array.
109 ///
110 /// This allows you to use a custom compression method, such as zstd, or set other options.
111 ///
112 /// Make sure to enable the relevant features of the `zip` crate.
113 pub fn new_with_options(writer: W, options: SimpleFileOptions) -> NpzWriter<W> {
114 NpzWriter {
115 zip: ZipWriter::new(writer),
116 options,
117 }
118 }
119
120 /// Adds an array with the specified `name` to the `.npz` file.
121 ///
122 /// This uses the file options passed to the `NpzWriter` constructor.
123 ///
124 /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
125 ///
126 /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
127 /// [`aview0`](ndarray::aview0).
128 pub fn add_array<N, T>(&mut self, name: N, array: &T) -> Result<(), WriteNpzError>
129 where
130 N: Into<String>,
131 T: WriteNpyExt + ?Sized,
132 {
133 self.add_array_with_options(name, array, self.options)
134 }
135
136 /// Adds an array with the specified `name` and options to the `.npz` file.
137 ///
138 /// The specified options override those passed to the [`NpzWriter`] constructor (if any).
139 ///
140 /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
141 ///
142 /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
143 /// [`aview0`](ndarray::aview0).
144 pub fn add_array_with_options<N, T, U>(
145 &mut self,
146 name: N,
147 array: &T,
148 options: FileOptions<'_, U>,
149 ) -> Result<(), WriteNpzError>
150 where
151 N: Into<String>,
152 T: WriteNpyExt + ?Sized,
153 U: FileOptionExtension,
154 {
155 fn inner<W, T, U>(
156 npz_zip: &mut ZipWriter<W>,
157 name: String,
158 array: &T,
159 options: FileOptions<'_, U>,
160 ) -> Result<(), WriteNpzError>
161 where
162 W: Write + Seek,
163 T: WriteNpyExt + ?Sized,
164 U: FileOptionExtension,
165 {
166 npz_zip.start_file(name + ".npy", options)?;
167 // Buffering when writing individual arrays is beneficial even when the
168 // underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
169 // only exception I saw in testing was the "compressed, in-memory
170 // writer, standard layout case". See
171 // https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481
172 // for details.
173 array.write_npy(BufWriter::new(npz_zip))?;
174 Ok(())
175 }
176
177 inner(&mut self.zip, name.into(), array, options)
178 }
179
180 /// Calls [`.finish()`](ZipWriter::finish) on the zip file and
181 /// [`.flush()`](Write::flush) on the writer, and then returns the writer.
182 ///
183 /// This finishes writing the remaining zip structures and flushes the
184 /// writer. While dropping will automatically attempt to finish the zip
185 /// file and (for writers that flush on drop, such as
186 /// [`BufWriter`](std::io::BufWriter)) flush the writer, any errors that
187 /// occur during drop will be silently ignored. So, it's necessary to call
188 /// `.finish()` to properly handle errors.
189 pub fn finish(self) -> Result<W, WriteNpzError> {
190 let mut writer = self.zip.finish()?;
191 writer.flush().map_err(ZipError::from)?;
192 Ok(writer)
193 }
194}
195
196/// An error reading a `.npz` file.
197#[derive(Debug)]
198pub enum ReadNpzError {
199 /// An error caused by the zip archive.
200 Zip(ZipError),
201 /// An error caused by reading an inner `.npy` file.
202 Npy(ReadNpyError),
203}
204
205impl Error for ReadNpzError {
206 fn source(&self) -> Option<&(dyn Error + 'static)> {
207 match self {
208 ReadNpzError::Zip(err) => Some(err),
209 ReadNpzError::Npy(err) => Some(err),
210 }
211 }
212}
213
214impl fmt::Display for ReadNpzError {
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 match self {
217 ReadNpzError::Zip(err) => write!(f, "zip file error: {}", err),
218 ReadNpzError::Npy(err) => write!(f, "error reading npy file in npz archive: {}", err),
219 }
220 }
221}
222
223impl From<ZipError> for ReadNpzError {
224 fn from(err: ZipError) -> ReadNpzError {
225 ReadNpzError::Zip(err)
226 }
227}
228
229impl From<ReadNpyError> for ReadNpzError {
230 fn from(err: ReadNpyError) -> ReadNpzError {
231 ReadNpzError::Npy(err)
232 }
233}
234
235/// Reader for `.npz` files.
236///
237/// # Example
238///
239/// ```no_run
240/// use ndarray::{Array1, Array2};
241/// use ndarray_npy::NpzReader;
242/// use std::fs::File;
243///
244/// let mut npz = NpzReader::new(File::open("arrays.npz")?)?;
245/// let a: Array2<i32> = npz.by_name("a")?;
246/// let b: Array1<i32> = npz.by_name("b")?;
247/// # Ok::<_, Box<dyn std::error::Error>>(())
248/// ```
249pub struct NpzReader<R: Read + Seek> {
250 zip: ZipArchive<R>,
251}
252
253impl<R: Read + Seek> NpzReader<R> {
254 /// Creates a new `.npz` file reader.
255 pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
256 Ok(NpzReader {
257 zip: ZipArchive::new(reader)?,
258 })
259 }
260
261 /// Returns `true` iff the `.npz` file doesn't contain any arrays.
262 pub fn is_empty(&self) -> bool {
263 self.zip.len() == 0
264 }
265
266 /// Returns the number of arrays in the `.npz` file.
267 pub fn len(&self) -> usize {
268 self.zip.len()
269 }
270
271 /// Returns the names of all of the arrays in the file.
272 ///
273 /// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches
274 /// NumPy's behavior.
275 pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
276 Ok((0..self.zip.len())
277 .map(|i| {
278 let file = self.zip.by_index(i)?;
279 let name = file.name();
280 let stripped = name.strip_suffix(".npy").unwrap_or(name);
281 Ok(stripped.to_owned())
282 })
283 .collect::<Result<_, ZipError>>()?)
284 }
285
286 /// Reads an array by name.
287 ///
288 /// Note that this first checks for `name` in the `.npz` file, and if that is not present,
289 /// checks for `format!("{name}.npy")`. This matches NumPy's behavior.
290 pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
291 where
292 S::Elem: ReadableElement,
293 S: DataOwned,
294 D: Dimension,
295 {
296 // TODO: Combine the two cases into a single `let file = match { ... }` once
297 // https://github.com/rust-lang/rust/issues/47680 is resolved.
298 match self.zip.by_name(name) {
299 Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
300 Err(ZipError::FileNotFound) => {}
301 Err(err) => return Err(err.into()),
302 };
303 Ok(ArrayBase::<S, D>::read_npy(
304 self.zip.by_name(&format!("{name}.npy"))?,
305 )?)
306 }
307
308 /// Reads an array by index in the `.npz` file.
309 pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
310 where
311 S::Elem: ReadableElement,
312 S: DataOwned,
313 D: Dimension,
314 {
315 Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
316 }
317}