ndarray_npz/
lib.rs

1//! Advanced [`.npz`] file format support for [`ndarray`].
2//!
3//! # Accessing [`.npy`] Files
4//!
5//!   * See [`ndarray_npy`].
6//!
7//! # Accessing [`.npz`] Files
8//!
9//!   * Reading: [`NpzReader`]
10//!   * Writing: [`NpzWriter`]
11//!   * Immutable viewing (primarily for use with memory-mapped files):
12//!       * [`NpzView`] providing an [`NpyView`] for each uncompressed [`.npy`] file within
13//!         the archive
14//!   * Mutable viewing (primarily for use with memory-mapped files):
15//!       * [`NpzViewMut`] providing an [`NpyViewMut`] for each uncompressed [`.npy`] file within
16//!         the archive
17//!
18//! [`.npy`]: https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html
19//! [`.npz`]: https://numpy.org/doc/stable/reference/generated/numpy.savez.html
20//!
21//! # Features
22//!
23//! Both features are enabled by default.
24//!
25//!   * `compressed`: Enables zip archives with *deflate* compression.
26//!   * `num-complex-0_4`: Enables complex element types of crate `num-complex`.
27
28#![forbid(unsafe_code)]
29#![deny(
30	missing_docs,
31	rustdoc::broken_intra_doc_links,
32	rustdoc::missing_crate_level_docs
33)]
34#![allow(clippy::tabs_in_doc_comments)]
35#![cfg_attr(docsrs, feature(doc_auto_cfg))]
36
37// [`NpzReader`] and [`NpzWriter`] are derivative works of [`ndarray_npy`].
38
39pub use ndarray;
40pub use ndarray_npy;
41
42use ndarray::{
43	prelude::*,
44	{Data, DataOwned},
45};
46use ndarray_npy::{
47	ReadNpyError, ReadNpyExt, ReadableElement, ViewElement, ViewMutElement, ViewMutNpyExt,
48	ViewNpyError, ViewNpyExt, WritableElement, WriteNpyError, WriteNpyExt,
49};
50use std::{
51	collections::{BTreeMap, HashMap, HashSet},
52	error::Error,
53	fmt,
54	io::{self, BufWriter, Cursor, Read, Seek, Write},
55	ops::Range,
56};
57use zip::{
58	result::ZipError,
59	write::SimpleFileOptions,
60	{CompressionMethod, ZipArchive, ZipWriter},
61};
62
63/// An error writing a `.npz` file.
64#[derive(Debug)]
65pub enum WriteNpzError {
66	/// An error caused by the zip file.
67	Zip(ZipError),
68	/// An error caused by writing an inner `.npy` file.
69	Npy(WriteNpyError),
70}
71
72impl Error for WriteNpzError {
73	fn source(&self) -> Option<&(dyn Error + 'static)> {
74		match self {
75			WriteNpzError::Zip(err) => Some(err),
76			WriteNpzError::Npy(err) => Some(err),
77		}
78	}
79}
80
81impl fmt::Display for WriteNpzError {
82	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83		match self {
84			WriteNpzError::Zip(err) => write!(f, "zip file error: {err}"),
85			WriteNpzError::Npy(err) => write!(f, "error writing npy file to npz archive: {err}"),
86		}
87	}
88}
89
90impl From<ZipError> for WriteNpzError {
91	fn from(err: ZipError) -> WriteNpzError {
92		WriteNpzError::Zip(err)
93	}
94}
95
96impl From<WriteNpyError> for WriteNpzError {
97	fn from(err: WriteNpyError) -> WriteNpzError {
98		WriteNpzError::Npy(err)
99	}
100}
101
102/// Writer for `.npz` files.
103///
104/// Note that the inner [`ZipWriter`] is wrapped in a [`BufWriter`] when
105/// writing each array with [`.add_array()`](NpzWriter::add_array). If desired,
106/// you could additionally buffer the innermost writer (e.g. the
107/// [`File`](std::fs::File) when writing to a file) by wrapping it in a
108/// [`BufWriter`]. This may be somewhat beneficial if the arrays are large and
109/// have non-standard layouts but may decrease performance if the arrays have
110/// standard or Fortran layout, so it's not recommended without testing to
111/// compare.
112///
113/// # Example
114///
115/// ```no_run
116/// use ndarray_npz::{
117/// 	ndarray::{array, aview0, Array1, Array2},
118/// 	NpzWriter,
119/// };
120/// use std::fs::File;
121///
122/// let mut npz = NpzWriter::new(File::create("arrays.npz")?);
123/// let a: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
124/// let b: Array1<i32> = array![7, 8, 9];
125/// npz.add_array("a", &a)?;
126/// npz.add_array("b", &b)?;
127/// npz.add_array("c", &aview0(&10))?;
128/// npz.finish()?;
129/// # Ok::<_, Box<dyn std::error::Error>>(())
130/// ```
131pub struct NpzWriter<W: Write + Seek> {
132	zip: ZipWriter<W>,
133	options: SimpleFileOptions,
134}
135
136impl<W: Write + Seek> NpzWriter<W> {
137	/// Creates a new `.npz` file without compression. See [`numpy.savez`].
138	///
139	/// Ensures `.npy` files are 64-byte aligned for memory-mapping via [`NpzView`]/[`NpzViewMut`].
140	///
141	/// [`numpy.savez`]: https://numpy.org/doc/stable/reference/generated/numpy.savez.html
142	#[must_use]
143	pub fn new(writer: W) -> NpzWriter<W> {
144		NpzWriter {
145			zip: ZipWriter::new(writer),
146			options: SimpleFileOptions::default()
147				.with_alignment(64)
148				.compression_method(CompressionMethod::Stored),
149		}
150	}
151
152	/// Creates a new `.npz` file with compression. See [`numpy.savez_compressed`].
153	///
154	/// [`numpy.savez_compressed`]: https://numpy.org/doc/stable/reference/generated/numpy.savez_compressed.html
155	#[cfg(feature = "compressed")]
156	#[must_use]
157	pub fn new_compressed(writer: W) -> NpzWriter<W> {
158		NpzWriter {
159			zip: ZipWriter::new(writer),
160			options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
161		}
162	}
163
164	/// Adds an array with the specified `name` to the `.npz` file.
165	///
166	/// To write a scalar value, create a zero-dimensional array using [`arr0`] or [`aview0`].
167	///
168	/// # Errors
169	///
170	/// Adding an array can fail with [`WriteNpyError`].
171	pub fn add_array<N, S, D>(
172		&mut self,
173		name: N,
174		array: &ArrayBase<S, D>,
175	) -> Result<(), WriteNpzError>
176	where
177		N: Into<String>,
178		S::Elem: WritableElement,
179		S: Data,
180		D: Dimension,
181	{
182		self.zip.start_file(name.into(), self.options)?;
183		array.write_npy(BufWriter::new(&mut self.zip))?;
184		Ok(())
185	}
186
187	/// Calls [`.finish()`](ZipWriter::finish) on the zip file and
188	/// [`.flush()`](Write::flush) on the writer, and then returns the writer.
189	///
190	/// This finishes writing the remaining zip structures and flushes the
191	/// writer. While dropping will automatically attempt to finish the zip
192	/// file and (for writers that flush on drop, such as [`BufWriter`]) flush
193	/// the writer, any errors that occur during drop will be silently ignored.
194	/// So, it's necessary to call `.finish()` to properly handle errors.
195	///
196	/// # Errors
197	///
198	/// Finishing the zip archive can fail with [`ZipError`].
199	pub fn finish(self) -> Result<W, WriteNpzError> {
200		let mut writer = self.zip.finish()?;
201		writer.flush().map_err(ZipError::from)?;
202		Ok(writer)
203	}
204}
205
206/// An error reading a `.npz` file.
207#[derive(Debug)]
208pub enum ReadNpzError {
209	/// An error caused by the zip archive.
210	Zip(ZipError),
211	/// An error caused by reading an inner `.npy` file.
212	Npy(ReadNpyError),
213}
214
215impl Error for ReadNpzError {
216	fn source(&self) -> Option<&(dyn Error + 'static)> {
217		match self {
218			ReadNpzError::Zip(err) => Some(err),
219			ReadNpzError::Npy(err) => Some(err),
220		}
221	}
222}
223
224impl fmt::Display for ReadNpzError {
225	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226		match self {
227			ReadNpzError::Zip(err) => write!(f, "zip file error: {err}"),
228			ReadNpzError::Npy(err) => write!(f, "error reading npy file in npz archive: {err}"),
229		}
230	}
231}
232
233impl From<ZipError> for ReadNpzError {
234	fn from(err: ZipError) -> ReadNpzError {
235		ReadNpzError::Zip(err)
236	}
237}
238
239impl From<ReadNpyError> for ReadNpzError {
240	fn from(err: ReadNpyError) -> ReadNpzError {
241		ReadNpzError::Npy(err)
242	}
243}
244
245/// Reader for `.npz` files.
246///
247/// # Example
248///
249/// ```no_run
250/// use ndarray_npz::{
251/// 	ndarray::{Array1, Array2},
252/// 	NpzReader,
253/// };
254/// use std::fs::File;
255///
256/// let mut npz = NpzReader::new(File::open("arrays.npz")?)?;
257/// let a: Array2<i32> = npz.by_name("a")?;
258/// let b: Array1<i32> = npz.by_name("b")?;
259/// # Ok::<_, Box<dyn std::error::Error>>(())
260/// ```
261pub struct NpzReader<R: Read + Seek> {
262	zip: ZipArchive<R>,
263}
264
265impl<R: Read + Seek> NpzReader<R> {
266	/// Creates a new `.npz` file reader.
267	///
268	/// # Errors
269	///
270	/// Reading a zip archive can fail with [`ZipError`].
271	pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
272		Ok(NpzReader {
273			zip: ZipArchive::new(reader)?,
274		})
275	}
276
277	/// Returns `true` iff the `.npz` file doesn't contain any arrays.
278	#[must_use]
279	pub fn is_empty(&self) -> bool {
280		self.zip.len() == 0
281	}
282
283	/// Returns the number of arrays in the `.npz` file.
284	#[must_use]
285	pub fn len(&self) -> usize {
286		self.zip.len()
287	}
288
289	/// Returns the names of all of the arrays in the file.
290	///
291	/// # Errors
292	///
293	/// Reading a zip archive can fail with [`ZipError`].
294	pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
295		Ok((0..self.zip.len())
296			.map(|i| Ok(self.zip.by_index(i)?.name().to_owned()))
297			.collect::<Result<_, ZipError>>()?)
298	}
299
300	/// Reads an array by name.
301	///
302	/// # Errors
303	///
304	/// Reading an array from an archive can fail with [`ReadNpyError`] or [`ZipError`].
305	pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
306	where
307		S::Elem: ReadableElement,
308		S: DataOwned,
309		D: Dimension,
310	{
311		Ok(ArrayBase::<S, D>::read_npy(self.zip.by_name(name)?)?)
312	}
313
314	/// Reads an array by index in the `.npz` file.
315	///
316	/// # Errors
317	///
318	/// Reading an array from an archive can fail with [`ReadNpyError`] or [`ZipError`].
319	pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
320	where
321		S::Elem: ReadableElement,
322		S: DataOwned,
323		D: Dimension,
324	{
325		Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
326	}
327}
328
329/// An error viewing a `.npz` file.
330#[derive(Debug)]
331#[non_exhaustive]
332pub enum ViewNpzError {
333	/// An error caused by the zip archive.
334	Zip(ZipError),
335	/// An error caused by viewing an inner `.npy` file.
336	Npy(ViewNpyError),
337	/// A mutable `.npy` file view has already been moved out of its `.npz` file view.
338	MovedNpyViewMut,
339	/// Directories cannot be viewed.
340	Directory,
341	/// Compressed files cannot be viewed.
342	CompressedFile,
343	/// Encrypted files cannot be viewed.
344	EncryptedFile,
345}
346
347impl Error for ViewNpzError {
348	fn source(&self) -> Option<&(dyn Error + 'static)> {
349		match self {
350			ViewNpzError::Zip(err) => Some(err),
351			ViewNpzError::Npy(err) => Some(err),
352			ViewNpzError::MovedNpyViewMut
353			| ViewNpzError::Directory
354			| ViewNpzError::CompressedFile
355			| ViewNpzError::EncryptedFile => None,
356		}
357	}
358}
359
360impl fmt::Display for ViewNpzError {
361	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
362		match self {
363			ViewNpzError::Zip(err) => write!(f, "zip file error: {err}"),
364			ViewNpzError::Npy(err) => write!(f, "error viewing npy file in npz archive: {err}"),
365			ViewNpzError::MovedNpyViewMut => write!(
366				f,
367				"mutable npy file view already moved out of npz file view"
368			),
369			ViewNpzError::Directory => write!(f, "directories cannot be viewed"),
370			ViewNpzError::CompressedFile => write!(f, "compressed files cannot be viewed"),
371			ViewNpzError::EncryptedFile => write!(f, "encrypted files cannot be viewed"),
372		}
373	}
374}
375
376impl From<ZipError> for ViewNpzError {
377	fn from(err: ZipError) -> ViewNpzError {
378		ViewNpzError::Zip(err)
379	}
380}
381
382impl From<ViewNpyError> for ViewNpzError {
383	fn from(err: ViewNpyError) -> ViewNpzError {
384		ViewNpzError::Npy(err)
385	}
386}
387
388/// Immutable view for memory-mapped `.npz` files.
389///
390/// The primary use-case for this is viewing `.npy` files within a memory-mapped
391/// `.npz` archive.
392///
393/// # Notes
394///
395/// - For types for which not all bit patterns are valid, such as `bool`, the
396///   implementation iterates over all of the elements when creating the view
397///   to ensure they have a valid bit pattern.
398/// - The data in the buffer containing an `.npz` archive must be properly
399///   aligned for its `.npy` file with the maximum alignment requirement for its
400///   element type. Typically, this should not be a concern for memory-mapped
401///   files (unless an option like `MAP_FIXED` is used), since memory mappings
402///   are usually aligned to a page boundary.
403/// - The `.npy` files within the `.npz` archive must be properly aligned for
404///   their element type. Archives not created by this crate can be aligned with
405///   the help of the CLI tool [`rezip`] as in `rezip in.npz -o out.npz`.
406///
407/// [`rezip`]: https://crates.io/crates/rezip
408///
409/// # Example
410///
411/// This is an example of opening an immutably memory-mapped `.npz` archive as
412/// an [`NpzView`] providing an [`NpyView`] for each non-compressed and
413/// non-encrypted `.npy` file within the archive which can be accessed via
414/// [`NpyView::view`] as immutable [`ArrayView`].
415///
416/// This example uses the [`memmap2`](https://crates.io/crates/memmap2) crate
417/// because that appears to be the best-maintained memory-mapping crate at the
418/// moment, but [`Self::new`] takes a `&mut [u8]` instead of a file so that you
419/// can use the memory-mapping crate you're most comfortable with.
420///
421/// ```
422/// # if !cfg!(miri) { // Miri doesn't support mmap.
423/// use std::fs::OpenOptions;
424///
425/// use memmap2::MmapOptions;
426/// use ndarray::Ix1;
427/// use ndarray_npz::{NpzView, ViewNpzError};
428///
429/// // Open `.npz` archive of non-compressed and non-encrypted `.npy` files in
430/// // native endian.
431/// #[cfg(target_endian = "little")]
432/// let file = OpenOptions::new()
433/// 	.read(true)
434/// 	.open("tests/examples_little_endian_64_byte_aligned.npz")
435/// 	.unwrap();
436/// #[cfg(target_endian = "big")]
437/// let file = OpenOptions::new()
438/// 	.read(true)
439/// 	.open("tests/examples_big_endian_64_byte_aligned.npz")
440/// 	.unwrap();
441/// // Memory-map `.npz` archive of 64-byte aligned `.npy` files.
442/// let mmap = unsafe { MmapOptions::new().map(&file).unwrap() };
443/// let npz = NpzView::new(&mmap)?;
444/// // List non-compressed and non-encrypted files only.
445/// for npy in npz.names() {
446/// 	println!("{}", npy);
447/// }
448/// // Get immutable `.npy` views.
449/// let mut x_npy_view = npz.by_name("i64.npy")?;
450/// let mut y_npy_view = npz.by_name("f64.npy")?;
451/// // Optionally verify CRC-32 checksums.
452/// x_npy_view.verify()?;
453/// y_npy_view.verify()?;
454/// // Get and print immutable `ArrayView`s.
455/// let x_array_view = x_npy_view.view::<i64, Ix1>()?;
456/// let y_array_view = y_npy_view.view::<f64, Ix1>()?;
457/// println!("{}", x_array_view);
458/// println!("{}", y_array_view);
459/// # }
460/// # Ok::<(), ndarray_npz::ViewNpzError>(())
461/// ```
462#[derive(Debug, Clone)]
463pub struct NpzView<'a> {
464	files: HashMap<usize, NpyView<'a>>,
465	names: HashMap<String, usize>,
466	directory_names: HashSet<String>,
467	compressed_names: HashSet<String>,
468	encrypted_names: HashSet<String>,
469}
470
471impl<'a> NpzView<'a> {
472	/// Creates a new immutable view of a memory-mapped `.npz` file.
473	///
474	/// # Errors
475	///
476	/// Viewing an archive can fail with [`ZipError`].
477	pub fn new(bytes: &'a [u8]) -> Result<Self, ViewNpzError> {
478		let mut zip = ZipArchive::new(Cursor::new(bytes))?;
479		let mut archive = Self {
480			files: HashMap::new(),
481			names: HashMap::new(),
482			directory_names: HashSet::new(),
483			compressed_names: HashSet::new(),
484			encrypted_names: zip.file_names().map(From::from).collect(),
485		};
486		// Initially assume all files to be encrypted.
487		let mut index = 0;
488		for zip_index in 0..zip.len() {
489			// Skip encrypted files.
490			let file = match zip.by_index(zip_index) {
491				Err(ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED)) => continue,
492				Err(err) => return Err(err.into()),
493				Ok(file) => file,
494			};
495			// File name of non-encrypted file.
496			let name = file.name().to_string();
497			// Remove file name from encrypted files.
498			archive.encrypted_names.remove(&name);
499			// Skip directories and compressed files.
500			if file.is_dir() {
501				archive.directory_names.insert(name);
502				continue;
503			}
504			if file.compression() != CompressionMethod::Stored {
505				archive.compressed_names.insert(name);
506				continue;
507			}
508			// Store file index by file names.
509			archive.names.insert(name, index);
510			let file = NpyView {
511				data: slice_at(bytes, file.data_start(), 0..file.size())?,
512				central_crc32: slice_at(bytes, file.central_header_start(), 16..20)
513					.map(as_array_ref)?,
514				status: ChecksumStatus::default(),
515			};
516			// Store file view by file index.
517			archive.files.insert(index, file);
518			// Increment index of non-compressed and non-encrypted files.
519			index += 1;
520		}
521		Ok(archive)
522	}
523
524	/// Returns `true` iff the `.npz` file doesn't contain any viewable arrays.
525	///
526	/// Viewable arrays are neither directories, nor compressed, nor encrypted.
527	#[must_use]
528	pub fn is_empty(&self) -> bool {
529		self.names.is_empty()
530	}
531
532	/// Returns the number of viewable arrays in the `.npz` file.
533	///
534	/// Viewable arrays are neither directories, nor compressed, nor encrypted.
535	#[must_use]
536	pub fn len(&self) -> usize {
537		self.names.len()
538	}
539
540	/// Returns the names of all of the viewable arrays in the `.npz` file.
541	///
542	/// Viewable arrays are neither directories, nor compressed, nor encrypted.
543	pub fn names(&self) -> impl Iterator<Item = &str> {
544		self.names.keys().map(String::as_str)
545	}
546	/// Returns the names of all of the directories in the `.npz` file.
547	pub fn directory_names(&self) -> impl Iterator<Item = &str> {
548		self.directory_names.iter().map(String::as_str)
549	}
550	/// Returns the names of all of the compressed files in the `.npz` file.
551	pub fn compressed_names(&self) -> impl Iterator<Item = &str> {
552		self.compressed_names.iter().map(String::as_str)
553	}
554	/// Returns the names of all of the encrypted files in the `.npz` file.
555	pub fn encrypted_names(&self) -> impl Iterator<Item = &str> {
556		self.encrypted_names.iter().map(String::as_str)
557	}
558
559	/// Returns an immutable `.npy` file view by name.
560	///
561	/// # Errors
562	///
563	/// Viewing an `.npy` file can fail with [`ViewNpyError`]. Trying to view a directory,
564	/// compressed file, or encrypted file, fails with [`ViewNpzError::Directory`],
565	/// [`ViewNpzError::CompressedFile`], or [`ViewNpzError::CompressedFile`]. Fails with
566	/// [`ZipError::FileNotFound`] if the `name` is not found.
567	pub fn by_name(&self, name: &str) -> Result<NpyView<'a>, ViewNpzError> {
568		self.by_index(self.names.get(name).copied().ok_or_else(|| {
569			if self.directory_names.contains(name) {
570				ViewNpzError::Directory
571			} else if self.compressed_names.contains(name) {
572				ViewNpzError::CompressedFile
573			} else if self.encrypted_names.contains(name) {
574				ViewNpzError::EncryptedFile
575			} else {
576				ZipError::FileNotFound.into()
577			}
578		})?)
579	}
580
581	/// Returns an immutable `.npy` file view by index in `0..len()`.
582	///
583	/// The index **does not** necessarily correspond to the index of the zip archive as
584	/// directories, compressed files, and encrypted files are skipped.
585	///
586	/// # Errors
587	///
588	/// Viewing an `.npy` file can fail with [`ViewNpyError`]. Fails with [`ZipError::FileNotFound`]
589	/// if the `index` is not found.
590	pub fn by_index(&self, index: usize) -> Result<NpyView<'a>, ViewNpzError> {
591		self.files
592			.get(&index)
593			.copied()
594			.ok_or_else(|| ZipError::FileNotFound.into())
595	}
596}
597
598/// Immutable view of memory-mapped `.npy` files within an `.npz` file.
599///
600/// Does **not** automatically [verify](`Self::verify`) CRC-32 checksum.
601#[derive(Debug, Clone, Copy)]
602pub struct NpyView<'a> {
603	data: &'a [u8],
604	central_crc32: &'a [u8; 4],
605	status: ChecksumStatus,
606}
607
608impl NpyView<'_> {
609	/// CRC-32 checksum status.
610	#[must_use]
611	pub fn status(&self) -> ChecksumStatus {
612		self.status
613	}
614	/// Verifies and returns CRC-32 checksum by reading the whole array.
615	///
616	/// Changes checksum [`status`](`Self::status()`) to [`Outdated`](`ChecksumStatus::Outdated`)
617	/// if invalid or to [`Correct`](`ChecksumStatus::Correct`) if valid.
618	///
619	/// # Errors
620	///
621	/// Fails with [`ZipError::Io`] if the checksum is invalid.
622	pub fn verify(&mut self) -> Result<u32, ViewNpzError> {
623		self.status = ChecksumStatus::Outdated;
624		// Like the `zip` crate, verify only against central CRC-32.
625		let crc32 = crc32_verify(self.data, *self.central_crc32)?;
626		self.status = ChecksumStatus::Correct;
627		Ok(crc32)
628	}
629
630	/// Returns an immutable view of a memory-mapped `.npy` file.
631	///
632	/// Iterates over `bool` array to ensure `0x00`/`0x01` values.
633	///
634	/// # Errors
635	///
636	/// Viewing an `.npy` file can fail with [`ViewNpyError`].
637	pub fn view<A, D>(&self) -> Result<ArrayView<A, D>, ViewNpzError>
638	where
639		A: ViewElement,
640		D: Dimension,
641	{
642		Ok(ArrayView::view_npy(self.data)?)
643	}
644}
645
646/// Mutable view for memory-mapped `.npz` files.
647///
648/// The primary use-case for this is modifying `.npy` files within a
649/// memory-mapped `.npz` archive. Modifying the elements in the view will modify
650/// the file. Modifying the shape/strides of the view will *not* modify the
651/// shape/strides of the array in the file.
652///
653/// # Notes
654///
655/// - For types for which not all bit patterns are valid, such as `bool`, the
656///   implementation iterates over all of the elements when creating the view
657///   to ensure they have a valid bit pattern.
658/// - The data in the buffer containing an `.npz` archive must be properly
659///   aligned for its `.npy` file with the maximum alignment requirement for its
660///   element type. Typically, this should not be a concern for memory-mapped
661///   files (unless an option like `MAP_FIXED` is used), since memory mappings
662///   are usually aligned to a page boundary.
663/// - The `.npy` files within the `.npz` archive must be properly aligned for
664///   their element type. Archives not created by this crate can be aligned with
665///   the help of the CLI tool [`rezip`] as in `rezip in.npz -o out.npz`.
666///
667/// [`rezip`]: https://crates.io/crates/rezip
668///
669/// # Example
670///
671/// This is an example of opening a mutably memory-mapped `.npz` archive as an
672/// [`NpzViewMut`] providing an [`NpyViewMut`] for each non-compressed and
673/// non-encrypted `.npy` file within the archive which can be accessed via
674/// [`NpyViewMut::view`] as immutable [`ArrayView`] or via
675/// [`NpyViewMut::view_mut`] as mutable [`ArrayViewMut`]. Changes to the data in
676/// the view will modify the underlying file within the archive.
677///
678/// This example uses the [`memmap2`](https://crates.io/crates/memmap2) crate
679/// because that appears to be the best-maintained memory-mapping crate at the
680/// moment, but [`Self::new`] takes a `&mut [u8]` instead of a file so that you
681/// can use the memory-mapping crate you're most comfortable with.
682///
683/// # Example
684///
685/// ```
686/// # if !cfg!(miri) { // Miri doesn't support mmap.
687/// use std::fs::OpenOptions;
688///
689/// use memmap2::MmapOptions;
690/// use ndarray::Ix1;
691/// use ndarray_npz::{NpzViewMut, ViewNpzError};
692///
693/// // Open `.npz` archive of non-compressed and non-encrypted `.npy` files in
694/// // native endian.
695/// #[cfg(target_endian = "little")]
696/// let mut file = OpenOptions::new()
697/// 	.read(true)
698/// 	.write(true)
699/// 	.open("tests/examples_little_endian_64_byte_aligned.npz")
700/// 	.unwrap();
701/// #[cfg(target_endian = "big")]
702/// let mut file = OpenOptions::new()
703/// 	.read(true)
704/// 	.write(true)
705/// 	.open("tests/examples_big_endian_64_byte_aligned.npz")
706/// 	.unwrap();
707/// // Memory-map `.npz` archive of 64-byte aligned `.npy` files.
708/// let mut mmap = unsafe { MmapOptions::new().map_mut(&file).unwrap() };
709/// let mut npz = NpzViewMut::new(&mut mmap)?;
710/// // List non-compressed and non-encrypted files only.
711/// for npy in npz.names() {
712/// 	println!("{}", npy);
713/// }
714/// // Get mutable `.npy` views of both arrays at the same time.
715/// let mut x_npy_view_mut = npz.by_name("i64.npy")?;
716/// let mut y_npy_view_mut = npz.by_name("f64.npy")?;
717/// // Optionally verify CRC-32 checksums.
718/// x_npy_view_mut.verify()?;
719/// y_npy_view_mut.verify()?;
720/// // Get and print mutable `ArrayViewMut`s.
721/// let x_array_view_mut = x_npy_view_mut.view_mut::<i64, Ix1>()?;
722/// let y_array_view_mut = y_npy_view_mut.view_mut::<f64, Ix1>()?;
723/// println!("{}", x_array_view_mut);
724/// println!("{}", y_array_view_mut);
725/// // Update CRC-32 checksums after changes. Automatically updated on `drop()`
726/// // if outdated.
727/// x_npy_view_mut.update();
728/// y_npy_view_mut.update();
729/// # }
730/// # Ok::<(), ndarray_npz::ViewNpzError>(())
731/// ```
732#[derive(Debug)]
733pub struct NpzViewMut<'a> {
734	files: HashMap<usize, NpyViewMut<'a>>,
735	names: HashMap<String, usize>,
736	directory_names: HashSet<String>,
737	compressed_names: HashSet<String>,
738	encrypted_names: HashSet<String>,
739}
740
741impl<'a> NpzViewMut<'a> {
742	/// Creates a new mutable view of a memory-mapped `.npz` file.
743	///
744	/// # Errors
745	///
746	/// Viewing an archive can fail with [`ZipError`].
747	pub fn new(mut bytes: &'a mut [u8]) -> Result<Self, ViewNpzError> {
748		let mut zip = ZipArchive::new(Cursor::new(&bytes))?;
749		let mut archive = Self {
750			files: HashMap::new(),
751			names: HashMap::new(),
752			directory_names: HashSet::new(),
753			compressed_names: HashSet::new(),
754			encrypted_names: zip.file_names().map(From::from).collect(),
755		};
756		// Initially assume all files to be encrypted.
757		let mut ranges = HashMap::new();
758		let mut splits = BTreeMap::new();
759		let mut index = 0;
760		for zip_index in 0..zip.len() {
761			// Skip encrypted files.
762			let file = match zip.by_index(zip_index) {
763				Err(ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED)) => continue,
764				Err(err) => return Err(err.into()),
765				Ok(file) => file,
766			};
767			// File name of non-encrypted file.
768			let name = file.name().to_string();
769			// Remove file name from encrypted files.
770			archive.encrypted_names.remove(&name);
771			// Skip directories and compressed files.
772			if file.is_dir() {
773				archive.directory_names.insert(name);
774				continue;
775			}
776			if file.compression() != CompressionMethod::Stored {
777				archive.compressed_names.insert(name);
778				continue;
779			}
780			// Skip directories and compressed files.
781			if file.is_dir() || file.compression() != CompressionMethod::Stored {
782				continue;
783			}
784			// Store file index by file names.
785			archive.names.insert(name, index);
786			// Get data range.
787			let data_range = range_at(file.data_start(), 0..file.size())?;
788			// Get central general purpose bit flag range.
789			let central_flag_range = range_at(file.central_header_start(), 8..10)?;
790			// Parse central general purpose bit flag range.
791			let central_flag = u16_at(bytes, central_flag_range);
792			// Get central CRC-32 range.
793			let central_crc32_range = range_at(file.central_header_start(), 16..20)?;
794			// Whether local CRC-32 is located in header or data descriptor.
795			let use_data_descriptor = central_flag & (1 << 3) != 0;
796			// Get local CRC-32 range.
797			let crc32_range = if use_data_descriptor {
798				// Get local CRC-32 range in data descriptor.
799				let crc32_range = range_at(data_range.end, 0..4)?;
800				// Parse local CRC-32.
801				let crc32 = u32_at(bytes, crc32_range.clone());
802				// Whether local CRC-32 equals optional data descriptor signature.
803				if crc32 == 0x0807_4b50 {
804					// Parse central CRC-32.
805					let central_crc32 = u32_at(bytes, central_crc32_range.clone());
806					// Whether CRC-32 coincides with data descriptor signature.
807					if crc32 == central_crc32 {
808						return Err(ZipError::InvalidArchive(
809							"Ambiguous CRC-32 location in data descriptor".into(),
810						)
811						.into());
812					}
813					// Skip data descriptor signature and get local CRC-32 range in data descriptor.
814					range_at(data_range.end, 4..8)?
815				} else {
816					crc32_range
817				}
818			} else {
819				// Get local CRC-32 range in header.
820				range_at(file.header_start(), 14..18)?
821			};
822			// Sort ranges by their starts.
823			splits.insert(crc32_range.start, crc32_range.end);
824			splits.insert(data_range.start, data_range.end);
825			splits.insert(central_crc32_range.start, central_crc32_range.end);
826			// Store ranges by file index.
827			ranges.insert(index, (data_range, crc32_range, central_crc32_range));
828			// Increment index of non-compressed and non-encrypted files.
829			index += 1;
830		}
831		// Split and store borrows by their range starts.
832		let mut offset = 0;
833		let mut slices = HashMap::new();
834		for (&start, &end) in &splits {
835			// Split off leading bytes.
836			let mid = start
837				.checked_sub(offset)
838				.ok_or(ZipError::InvalidArchive("Overlapping ranges".into()))?;
839			if mid > bytes.len() {
840				return Err(ZipError::InvalidArchive("Offset exceeds EOF".into()).into());
841			}
842			let (slice, remaining_bytes) = bytes.split_at_mut(mid);
843			offset += slice.len();
844			// Split off leading borrow of interest. Cannot underflow since `start < end`.
845			let mid = end - offset;
846			if mid > remaining_bytes.len() {
847				return Err(ZipError::InvalidArchive("Length exceeds EOF".into()).into());
848			}
849			let (slice, remaining_bytes) = remaining_bytes.split_at_mut(mid);
850			offset += slice.len();
851			// Store borrow by its range start.
852			slices.insert(start, slice);
853			// Store remaining bytes.
854			bytes = remaining_bytes;
855		}
856		// Collect split borrows as file views.
857		for (&index, (data_range, crc32_range, central_crc32_range)) in &ranges {
858			let ambiguous_offset = || ZipError::InvalidArchive("Ambiguous offsets".into());
859			let file = NpyViewMut {
860				data: slices
861					.remove(&data_range.start)
862					.ok_or_else(ambiguous_offset)?,
863				crc32: slices
864					.remove(&crc32_range.start)
865					.map(as_array_mut)
866					.ok_or_else(ambiguous_offset)?,
867				central_crc32: slices
868					.remove(&central_crc32_range.start)
869					.map(as_array_mut)
870					.ok_or_else(ambiguous_offset)?,
871				status: ChecksumStatus::default(),
872			};
873			archive.files.insert(index, file);
874		}
875		Ok(archive)
876	}
877
878	/// Returns `true` iff the `.npz` file doesn't contain any viewable arrays.
879	///
880	/// Viewable arrays are neither directories, nor compressed, nor encrypted.
881	#[must_use]
882	pub fn is_empty(&self) -> bool {
883		self.names.is_empty()
884	}
885
886	/// Returns the number of viewable arrays in the `.npz` file.
887	///
888	/// Viewable arrays are neither directories, nor compressed, nor encrypted.
889	#[must_use]
890	pub fn len(&self) -> usize {
891		self.names.len()
892	}
893
894	/// Returns the names of all of the viewable arrays in the `.npz` file.
895	///
896	/// Viewable arrays are neither directories, nor compressed, nor encrypted.
897	pub fn names(&self) -> impl Iterator<Item = &str> {
898		self.names.keys().map(String::as_str)
899	}
900	/// Returns the names of all of the directories in the `.npz` file.
901	pub fn directory_names(&self) -> impl Iterator<Item = &str> {
902		self.directory_names.iter().map(String::as_str)
903	}
904	/// Returns the names of all of the compressed files in the `.npz` file.
905	pub fn compressed_names(&self) -> impl Iterator<Item = &str> {
906		self.compressed_names.iter().map(String::as_str)
907	}
908	/// Returns the names of all of the encrypted files in the `.npz` file.
909	pub fn encrypted_names(&self) -> impl Iterator<Item = &str> {
910		self.encrypted_names.iter().map(String::as_str)
911	}
912
913	/// Moves a mutable `.npy` file view by name out of the `.npz` file view.
914	///
915	/// # Errors
916	///
917	/// Viewing an `.npy` file can fail with [`ViewNpyError`]. Trying to view a directory,
918	/// compressed file, or encrypted file, fails with [`ViewNpzError::Directory`],
919	/// [`ViewNpzError::CompressedFile`], or [`ViewNpzError::CompressedFile`]. Fails with
920	/// [`ZipError::FileNotFound`] if the `name` is not found.
921	pub fn by_name(&mut self, name: &str) -> Result<NpyViewMut<'a>, ViewNpzError> {
922		self.by_index(self.names.get(name).copied().ok_or_else(|| {
923			if self.directory_names.contains(name) {
924				ViewNpzError::Directory
925			} else if self.compressed_names.contains(name) {
926				ViewNpzError::CompressedFile
927			} else if self.encrypted_names.contains(name) {
928				ViewNpzError::EncryptedFile
929			} else {
930				ZipError::FileNotFound.into()
931			}
932		})?)
933	}
934
935	/// Moves a mutable `.npy` file view by index in `0..len()` out of the `.npz` file view.
936	///
937	/// The index **does not** necessarily correspond to the index of the zip archive as
938	/// directories, compressed files, and encrypted files are skipped.
939	///
940	/// # Errors
941	///
942	/// Viewing an `.npy` file can fail with [`ViewNpyError`]. Fails with [`ZipError::FileNotFound`]
943	/// if the `index` is not found. Fails with [`ViewNpzError::MovedNpyViewMut`] if the mutable
944	/// `.npy` file view has already been moved out of the `.npz` file view.
945	pub fn by_index(&mut self, index: usize) -> Result<NpyViewMut<'a>, ViewNpzError> {
946		if index > self.names.len() {
947			Err(ZipError::FileNotFound.into())
948		} else {
949			self.files
950				.remove(&index)
951				.ok_or(ViewNpzError::MovedNpyViewMut)
952		}
953	}
954}
955
/// Mutable view of memory-mapped `.npy` files within an `.npz` file.
///
/// Does **not** automatically [verify](`Self::verify`) the CRC-32 checksum but **does**
/// [update](`Self::update`) it on [`Drop::drop`] whenever the checksum
/// [`status`](`Self::status()`) is [`Outdated`](`ChecksumStatus::Outdated`). This is the case
/// after obtaining a mutable view via [`view_mut`](`Self::view_mut`), unless the checksum has
/// since been updated manually by invoking [`update`](`Self::update`). A failed
/// [`verify`](`Self::verify`) also leaves the status `Outdated`, so the stored checksum is then
/// overwritten on drop as well.
#[derive(Debug)]
pub struct NpyViewMut<'a> {
	// Bytes of the embedded `.npy` file; parsed by `view`/`view_mut` and hashed for the CRC-32.
	data: &'a mut [u8],
	// CRC-32 field kept in sync with `central_crc32` by `update`; presumably the copy in the
	// local file header (TODO confirm against the archive parsing code).
	crc32: &'a mut [u8; 4],
	// CRC-32 field of the central directory; the only one checked by `verify`.
	central_crc32: &'a mut [u8; 4],
	// Whether the stored checksum is known to match `data`; `Outdated` triggers an update on drop.
	status: ChecksumStatus,
}
968
969impl NpyViewMut<'_> {
970	/// CRC-32 checksum status.
971	#[must_use]
972	pub fn status(&self) -> ChecksumStatus {
973		self.status
974	}
975	/// Verifies and returns CRC-32 checksum by reading the whole array.
976	///
977	/// Changes checksum [`status`](`Self::status()`) to [`Outdated`](`ChecksumStatus::Outdated`)
978	/// if invalid or to [`Correct`](`ChecksumStatus::Correct`) if valid.
979	///
980	/// # Errors
981	///
982	/// Fails with [`ZipError::Io`] if the checksum is invalid.
983	pub fn verify(&mut self) -> Result<u32, ViewNpzError> {
984		self.status = ChecksumStatus::Outdated;
985		// Like the `zip` crate, verify only against central CRC-32.
986		let crc32 = crc32_verify(self.data, *self.central_crc32)?;
987		self.status = ChecksumStatus::Correct;
988		Ok(crc32)
989	}
990	/// Updates and returns CRC-32 checksum by reading the whole array.
991	///
992	/// Changes checksum [`status`](`Self::status()`) to [`Correct`](`ChecksumStatus::Correct`).
993	///
994	/// Automatically updated on [`Drop::drop`] iff checksum [`status`](`Self::status()`) is
995	/// [`Outdated`](`ChecksumStatus::Outdated`).
996	pub fn update(&mut self) -> u32 {
997		self.status = ChecksumStatus::Correct;
998		let crc32 = crc32_update(self.data);
999		*self.central_crc32 = crc32.to_le_bytes();
1000		*self.crc32 = *self.central_crc32;
1001		crc32
1002	}
1003
1004	/// Returns an immutable view of a memory-mapped `.npy` file.
1005	///
1006	/// Iterates over `bool` array to ensure `0x00`/`0x01` values.
1007	///
1008	/// # Errors
1009	///
1010	/// Viewing an `.npy` file can fail with [`ViewNpyError`].
1011	pub fn view<A, D>(&self) -> Result<ArrayView<A, D>, ViewNpzError>
1012	where
1013		A: ViewElement,
1014		D: Dimension,
1015	{
1016		Ok(ArrayView::<A, D>::view_npy(self.data)?)
1017	}
1018	/// Returns a mutable view of a memory-mapped `.npy` file.
1019	///
1020	/// Iterates over `bool` array to ensure `0x00`/`0x01` values.
1021	///
1022	/// Changes checksum [`status`](`Self::status()`) to [`Outdated`](`ChecksumStatus::Outdated`).
1023	///
1024	/// # Errors
1025	///
1026	/// Viewing an `.npy` file can fail with [`ViewNpyError`].
1027	pub fn view_mut<A, D>(&mut self) -> Result<ArrayViewMut<A, D>, ViewNpzError>
1028	where
1029		A: ViewMutElement,
1030		D: Dimension,
1031	{
1032		self.status = ChecksumStatus::Outdated;
1033		Ok(ArrayViewMut::<A, D>::view_mut_npy(self.data)?)
1034	}
1035}
1036
1037impl Drop for NpyViewMut<'_> {
1038	fn drop(&mut self) {
1039		if self.status == ChecksumStatus::Outdated {
1040			self.update();
1041		}
1042	}
1043}
1044
/// Checksum status of an [`NpyView`] or [`NpyViewMut`].
///
/// Defaults to [`Unverified`](`Self::Unverified`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChecksumStatus {
	/// The checksum has not been computed and the data has not changed.
	#[default]
	Unverified,
	/// The checksum is correct and the data has not changed.
	Correct,
	/// The data may have changed.
	Outdated,
}
1061
1062fn crc32_verify(bytes: &[u8], crc32: [u8; 4]) -> Result<u32, ZipError> {
1063	let crc32 = u32::from_le_bytes(crc32);
1064	if crc32_update(bytes) == crc32 {
1065		Ok(crc32)
1066	} else {
1067		Err(ZipError::Io(io::Error::other("Invalid checksum")))
1068	}
1069}
1070
1071#[must_use]
1072fn crc32_update(bytes: &[u8]) -> u32 {
1073	let mut hasher = crc32fast::Hasher::new();
1074	hasher.update(bytes);
1075	hasher.finalize()
1076}
1077
1078fn range_at<T>(index: T, range: Range<T>) -> Result<Range<usize>, ZipError>
1079where
1080	T: TryInto<usize> + Copy,
1081{
1082	index
1083		.try_into()
1084		.ok()
1085		.and_then(|index| {
1086			let start = range.start.try_into().ok()?.checked_add(index)?;
1087			let end = range.end.try_into().ok()?.checked_add(index)?;
1088			Some(start..end)
1089		})
1090		.ok_or(ZipError::InvalidArchive("Range overflow".into()))
1091}
1092
1093fn slice_at<T>(bytes: &[u8], index: T, range: Range<T>) -> Result<&[u8], ZipError>
1094where
1095	T: TryInto<usize> + Copy,
1096{
1097	let range = range_at(index, range)?;
1098	bytes
1099		.get(range)
1100		.ok_or(ZipError::InvalidArchive("Range exceeds EOF".into()))
1101}
1102
/// Reads a little-endian `u16` from `bytes[range]`.
///
/// Panics if `range` is out of bounds or does not span exactly two bytes.
#[must_use]
fn u16_at(bytes: &[u8], range: Range<usize>) -> u16 {
	let field = bytes
		.get(range)
		.and_then(|slice| <[u8; 2]>::try_from(slice).ok())
		.unwrap();
	u16::from_le_bytes(field)
}
1107
/// Reads a little-endian `u32` from `bytes[range]`.
///
/// Panics if `range` is out of bounds or does not span exactly four bytes.
#[must_use]
fn u32_at(bytes: &[u8], range: Range<usize>) -> u32 {
	let field = bytes
		.get(range)
		.and_then(|slice| <[u8; 4]>::try_from(slice).ok())
		.unwrap();
	u32::from_le_bytes(field)
}
1112
/// Reinterprets a four-byte slice as a reference to a `[u8; 4]` array.
///
/// Panics if `slice` is not exactly four bytes long.
#[must_use]
fn as_array_ref(slice: &[u8]) -> &[u8; 4] {
	<&[u8; 4]>::try_from(slice).unwrap()
}
1117
/// Reinterprets a mutable four-byte slice as a mutable reference to a `[u8; 4]` array.
///
/// Panics if `slice` is not exactly four bytes long.
#[must_use]
fn as_array_mut(slice: &mut [u8]) -> &mut [u8; 4] {
	<&mut [u8; 4]>::try_from(slice).unwrap()
}