ndarray_npy/npz.rs
1use crate::{
2 ReadNpyError, ReadNpyExt, ReadableElement, WritableElement, WriteNpyError, WriteNpyExt,
3};
4use ndarray::prelude::*;
5use ndarray::{Data, DataOwned};
6use std::error::Error;
7use std::fmt;
8use std::io::{BufWriter, Read, Seek, Write};
9use zip::result::ZipError;
10use zip::write::SimpleFileOptions;
11use zip::{CompressionMethod, ZipArchive, ZipWriter};
12
13/// An error writing a `.npz` file.
14#[derive(Debug)]
15pub enum WriteNpzError {
16 /// An error caused by the zip file.
17 Zip(ZipError),
18 /// An error caused by writing an inner `.npy` file.
19 Npy(WriteNpyError),
20}
21
22impl Error for WriteNpzError {
23 fn source(&self) -> Option<&(dyn Error + 'static)> {
24 match self {
25 WriteNpzError::Zip(err) => Some(err),
26 WriteNpzError::Npy(err) => Some(err),
27 }
28 }
29}
30
31impl fmt::Display for WriteNpzError {
32 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33 match self {
34 WriteNpzError::Zip(err) => write!(f, "zip file error: {}", err),
35 WriteNpzError::Npy(err) => write!(f, "error writing npy file to npz archive: {}", err),
36 }
37 }
38}
39
40impl From<ZipError> for WriteNpzError {
41 fn from(err: ZipError) -> WriteNpzError {
42 WriteNpzError::Zip(err)
43 }
44}
45
46impl From<WriteNpyError> for WriteNpzError {
47 fn from(err: WriteNpyError) -> WriteNpzError {
48 WriteNpzError::Npy(err)
49 }
50}
51
52/// Writer for `.npz` files.
53///
54/// Note that the inner [`ZipWriter`] is wrapped in a [`BufWriter`] when
55/// writing each array with [`.add_array()`](NpzWriter::add_array). If desired,
56/// you could additionally buffer the innermost writer (e.g. the
57/// [`File`](std::fs::File) when writing to a file) by wrapping it in a
58/// [`BufWriter`]. This may be somewhat beneficial if the arrays are large and
59/// have non-standard layouts but may decrease performance if the arrays have
60/// standard or Fortran layout, so it's not recommended without testing to
61/// compare.
62///
63/// # Example
64///
65/// ```no_run
66/// use ndarray::{array, aview0, Array1, Array2};
67/// use ndarray_npy::NpzWriter;
68/// use std::fs::File;
69///
70/// let mut npz = NpzWriter::new(File::create("arrays.npz")?);
71/// let a: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
72/// let b: Array1<i32> = array![7, 8, 9];
73/// npz.add_array("a", &a)?;
74/// npz.add_array("b", &b)?;
75/// npz.add_array("c", &aview0(&10))?;
76/// npz.finish()?;
77/// # Ok::<_, Box<dyn std::error::Error>>(())
78/// ```
79pub struct NpzWriter<W: Write + Seek> {
80 zip: ZipWriter<W>,
81 options: SimpleFileOptions,
82}
83
84impl<W: Write + Seek> NpzWriter<W> {
85 /// Create a new `.npz` file without compression. See [`numpy.savez`].
86 ///
87 /// [`numpy.savez`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html
88 pub fn new(writer: W) -> NpzWriter<W> {
89 NpzWriter {
90 zip: ZipWriter::new(writer),
91 options: SimpleFileOptions::default().compression_method(CompressionMethod::Stored),
92 }
93 }
94
95 /// Creates a new `.npz` file with compression. See [`numpy.savez_compressed`].
96 ///
97 /// [`numpy.savez_compressed`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez_compressed.html
98 #[cfg(feature = "compressed_npz")]
99 pub fn new_compressed(writer: W) -> NpzWriter<W> {
100 NpzWriter {
101 zip: ZipWriter::new(writer),
102 options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
103 }
104 }
105
106 /// Creates a new `.npz` file with the specified options.
107 ///
108 /// This allows you to use a custom compression method, such as zstd, or set other options.
109 ///
110 /// Make sure to enable the relevant features of the `zip` crate.
111 pub fn new_with_options(writer: W, options: SimpleFileOptions) -> NpzWriter<W> {
112 NpzWriter {
113 zip: ZipWriter::new(writer),
114 options,
115 }
116 }
117
118 /// Adds an array with the specified `name` to the `.npz` file.
119 ///
120 /// Note that a `.npy` extension will be appended to `name`; this matches NumPy's behavior.
121 ///
122 /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
123 /// [`aview0`](ndarray::aview0).
124 pub fn add_array<N, S, D>(
125 &mut self,
126 name: N,
127 array: &ArrayBase<S, D>,
128 ) -> Result<(), WriteNpzError>
129 where
130 N: Into<String>,
131 S::Elem: WritableElement,
132 S: Data,
133 D: Dimension,
134 {
135 self.zip.start_file(name.into() + ".npy", self.options)?;
136 // Buffering when writing individual arrays is beneficial even when the
137 // underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
138 // only exception I saw in testing was the "compressed, in-memory
139 // writer, standard layout case". See
140 // https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481
141 // for details.
142 array.write_npy(BufWriter::new(&mut self.zip))?;
143 Ok(())
144 }
145
146 /// Calls [`.finish()`](ZipWriter::finish) on the zip file and
147 /// [`.flush()`](Write::flush) on the writer, and then returns the writer.
148 ///
149 /// This finishes writing the remaining zip structures and flushes the
150 /// writer. While dropping will automatically attempt to finish the zip
151 /// file and (for writers that flush on drop, such as
152 /// [`BufWriter`](std::io::BufWriter)) flush the writer, any errors that
153 /// occur during drop will be silently ignored. So, it's necessary to call
154 /// `.finish()` to properly handle errors.
155 pub fn finish(self) -> Result<W, WriteNpzError> {
156 let mut writer = self.zip.finish()?;
157 writer.flush().map_err(ZipError::from)?;
158 Ok(writer)
159 }
160}
161
162/// An error reading a `.npz` file.
163#[derive(Debug)]
164pub enum ReadNpzError {
165 /// An error caused by the zip archive.
166 Zip(ZipError),
167 /// An error caused by reading an inner `.npy` file.
168 Npy(ReadNpyError),
169}
170
171impl Error for ReadNpzError {
172 fn source(&self) -> Option<&(dyn Error + 'static)> {
173 match self {
174 ReadNpzError::Zip(err) => Some(err),
175 ReadNpzError::Npy(err) => Some(err),
176 }
177 }
178}
179
180impl fmt::Display for ReadNpzError {
181 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
182 match self {
183 ReadNpzError::Zip(err) => write!(f, "zip file error: {}", err),
184 ReadNpzError::Npy(err) => write!(f, "error reading npy file in npz archive: {}", err),
185 }
186 }
187}
188
189impl From<ZipError> for ReadNpzError {
190 fn from(err: ZipError) -> ReadNpzError {
191 ReadNpzError::Zip(err)
192 }
193}
194
195impl From<ReadNpyError> for ReadNpzError {
196 fn from(err: ReadNpyError) -> ReadNpzError {
197 ReadNpzError::Npy(err)
198 }
199}
200
201/// Reader for `.npz` files.
202///
203/// # Example
204///
205/// ```no_run
206/// use ndarray::{Array1, Array2};
207/// use ndarray_npy::NpzReader;
208/// use std::fs::File;
209///
210/// let mut npz = NpzReader::new(File::open("arrays.npz")?)?;
211/// let a: Array2<i32> = npz.by_name("a")?;
212/// let b: Array1<i32> = npz.by_name("b")?;
213/// # Ok::<_, Box<dyn std::error::Error>>(())
214/// ```
215pub struct NpzReader<R: Read + Seek> {
216 zip: ZipArchive<R>,
217}
218
219impl<R: Read + Seek> NpzReader<R> {
220 /// Creates a new `.npz` file reader.
221 pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
222 Ok(NpzReader {
223 zip: ZipArchive::new(reader)?,
224 })
225 }
226
227 /// Returns `true` iff the `.npz` file doesn't contain any arrays.
228 pub fn is_empty(&self) -> bool {
229 self.zip.len() == 0
230 }
231
232 /// Returns the number of arrays in the `.npz` file.
233 pub fn len(&self) -> usize {
234 self.zip.len()
235 }
236
237 /// Returns the names of all of the arrays in the file.
238 ///
239 /// Note that a single ".npy" suffix (if present) will be stripped from each name; this matches
240 /// NumPy's behavior.
241 pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
242 Ok((0..self.zip.len())
243 .map(|i| {
244 let file = self.zip.by_index(i)?;
245 let name = file.name();
246 let stripped = name.strip_suffix(".npy").unwrap_or(name);
247 Ok(stripped.to_owned())
248 })
249 .collect::<Result<_, ZipError>>()?)
250 }
251
252 /// Reads an array by name.
253 ///
254 /// Note that this first checks for `name` in the `.npz` file, and if that is not present,
255 /// checks for `format!("{name}.npy")`. This matches NumPy's behavior.
256 pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
257 where
258 S::Elem: ReadableElement,
259 S: DataOwned,
260 D: Dimension,
261 {
262 // TODO: Combine the two cases into a single `let file = match { ... }` once
263 // https://github.com/rust-lang/rust/issues/47680 is resolved.
264 match self.zip.by_name(name) {
265 Ok(file) => return Ok(ArrayBase::<S, D>::read_npy(file)?),
266 Err(ZipError::FileNotFound) => {}
267 Err(err) => return Err(err.into()),
268 };
269 Ok(ArrayBase::<S, D>::read_npy(
270 self.zip.by_name(&format!("{name}.npy"))?,
271 )?)
272 }
273
274 /// Reads an array by index in the `.npz` file.
275 pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
276 where
277 S::Elem: ReadableElement,
278 S: DataOwned,
279 D: Dimension,
280 {
281 Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
282 }
283}