ndarray_npy/npy/
mod.rs

1//! Functionality for `.npy` files.
2//!
3//! Most of this functionality is reexported at the top level of the crate.
4
5mod elements;
6pub mod header;
7
8use self::header::{
9    FormatHeaderError, Header, Layout, ParseHeaderError, ReadHeaderError, WriteHeaderError,
10};
11use ndarray::prelude::*;
12use ndarray::{Data, DataOwned, IntoDimension};
13use py_literal::Value as PyValue;
14use std::convert::TryInto;
15use std::error::Error;
16use std::fmt;
17use std::fs::File;
18use std::io::{self, BufWriter, Seek};
19use std::mem;
20
21/// Read an `.npy` file located at the specified path.
22///
23/// This is a convience function for using `File::open` followed by
24/// [`ReadNpyExt::read_npy`](trait.ReadNpyExt.html#tymethod.read_npy).
25///
26/// # Example
27///
28/// ```
29/// use ndarray::Array2;
30/// use ndarray_npy::read_npy;
31/// # use ndarray_npy::ReadNpyError;
32///
33/// let arr: Array2<i32> = read_npy("resources/array.npy")?;
34/// # println!("arr = {}", arr);
35/// # Ok::<_, ReadNpyError>(())
36/// ```
37pub fn read_npy<P, T>(path: P) -> Result<T, ReadNpyError>
38where
39    P: AsRef<std::path::Path>,
40    T: ReadNpyExt,
41{
42    T::read_npy(File::open(path)?)
43}
44
45/// Writes an array to an `.npy` file at the specified path.
46///
47/// This function will create the file if it does not exist, or overwrite it if
48/// it does.
49///
50/// This is a convenience function for `BufWriter::new(File::create(path)?)`
51/// followed by [`WriteNpyExt::write_npy`].
52///
53/// # Example
54///
55/// ```no_run
56/// use ndarray::array;
57/// use ndarray_npy::write_npy;
58/// # use ndarray_npy::WriteNpyError;
59///
60/// let arr = array![[1, 2, 3], [4, 5, 6]];
61/// write_npy("array.npy", &arr)?;
62/// # Ok::<_, WriteNpyError>(())
63/// ```
64pub fn write_npy<P, T>(path: P, array: &T) -> Result<(), WriteNpyError>
65where
66    P: AsRef<std::path::Path>,
67    T: WriteNpyExt + ?Sized,
68{
69    array.write_npy(BufWriter::new(File::create(path)?))
70}
71
72/// Writes an array to a new `.npy` file at the specified path; error if the file exists.
73///
74/// This is a convenience function for `BufWriter::new(File::create_new(path)?)` followed by
75/// [`WriteNpyExt::write_npy`].
76///
77/// # Example
78///
79/// ```no_run
80/// use ndarray::array;
81/// use ndarray_npy::create_new_npy;
82/// # use ndarray_npy::WriteNpyError;
83///
84/// let arr = array![[1, 2, 3], [4, 5, 6]];
85/// create_new_npy("new_array.npy", &arr)?;
86/// assert!(create_new_npy("new_array.npy", &arr).is_err());
87/// # Ok::<_, WriteNpyError>(())
88/// ```
89pub fn create_new_npy<P, T>(path: P, array: &T) -> Result<(), WriteNpyError>
90where
91    P: AsRef<std::path::Path>,
92    T: WriteNpyExt + ?Sized,
93{
94    array.write_npy(BufWriter::new(File::create_new(path)?))
95}
96
97/// Writes an `.npy` file (sparse if possible) with bitwise-zero-filled data.
98///
99/// The `.npy` file represents an array with element type `A` and shape
100/// specified by `shape`, with all elements of the array represented by an
101/// all-zero byte-pattern. The file is written starting at the current cursor
102/// location and truncated such that there are no additional bytes after the
103/// `.npy` data.
104///
105/// This function is primarily useful for creating an `.npy` file for an array
106/// larger than available memory. The file can then be memory-mapped and
107/// modified using [`ViewMutNpyExt`].
108///
109/// # Panics
110///
111/// May panic if any of the following overflow `isize` or `u64`:
112///
113/// - the number of elements in the array
114/// - the size of the array in bytes
115/// - the size of the resulting file in bytes
116///
117/// # Considerations
118///
119/// ## Data is zeroed bytes
120///
121/// The data consists of all zeroed bytes, so this function is useful only for
122/// element types for which an all-zero byte-pattern is a valid representation.
123///
124/// ## Sparse file
125///
126/// On filesystems which support [sparse files], most of the data should be
127/// handled by empty blocks, i.e. not allocated on disk. If you plan to
128/// memory-map the file to modify it and know that most blocks of the file will
129/// ultimately contain some nonzero data, then it may be beneficial to allocate
130/// space for the file on disk before modifying it in order to avoid
131/// fragmentation. For example, on POSIX-like systems, you can do this by
132/// calling `fallocate` on the file.
133///
134/// [sparse files]: https://en.wikipedia.org/wiki/Sparse_file
135///
136/// ## Alternatives
137///
138/// If all you want to do is create an array larger than the available memory
139/// and don't care about actually writing the data to disk, it may be worth
140/// considering alternative options:
141///
142/// - Add more swap space to your system, using swap file(s) if necessary, so
143///   that you can allocate the array as normal.
144///
145/// - If you know the data will be sparse:
146///
147///   - Use a sparse data structure instead of `ndarray`'s array types. For
148///     example, the [`sprs`](https://crates.io/crates/sprs) crate provides
149///     sparse matrices.
150///
151///   - Rely on memory overcommitment. In other words, configure the operating
152///     system to allocate more memory than actually exists. However, this
153///     risks the system running out of memory if the data is not as sparse as
154///     you expect.
155///
156/// # Example
157///
158/// In this example, a file containing 64 GiB of zeroed `f64` elements is
159/// created. Then, an `ArrayViewMut` is created by memory-mapping the file.
160/// Modifications to the data in the `ArrayViewMut` will be applied to the
161/// backing file. This works even on systems with less than 64 GiB of physical
162/// memory. On filesystems which support [sparse files], the disk space that's
163/// actually used depends on how much data is modified.
164///
165/// ```no_run
166/// use memmap2::MmapMut;
167/// use ndarray::ArrayViewMut3;
168/// use ndarray_npy::{write_zeroed_npy, ViewMutNpyExt};
169/// use std::fs::{File, OpenOptions};
170///
171/// let path = "array.npy";
172///
173/// // Create a (sparse if supported) file containing 64 GiB of zeroed data.
174/// let file = File::create(path)?;
175/// write_zeroed_npy::<f64, _>(&file, (1024, 2048, 4096))?;
176///
177/// // Memory-map the file and create the mutable view.
178/// let file = OpenOptions::new().read(true).write(true).open(path)?;
179/// let mut mmap = unsafe { MmapMut::map_mut(&file)? };
180/// let mut view_mut = ArrayViewMut3::<f64>::view_mut_npy(&mut mmap)?;
181///
182/// // Modify an element in the view.
183/// view_mut[[500, 1000, 2000]] = 3.14;
184/// #
185/// # Ok::<_, Box<dyn std::error::Error>>(())
186/// ```
187pub fn write_zeroed_npy<A, D>(mut file: &File, shape: D) -> Result<(), WriteNpyError>
188where
189    A: WritableElement,
190    D: IntoDimension,
191{
192    let dim = shape.into_dimension();
193    let data_bytes_len: u64 = dim
194        .size_checked()
195        .expect("overflow computing number of elements")
196        .checked_mul(mem::size_of::<A>())
197        .expect("overflow computing length of data")
198        .try_into()
199        .expect("overflow converting length of data to u64");
200    Header {
201        type_descriptor: A::type_descriptor(),
202        layout: Layout::Standard,
203        shape: dim.as_array_view().to_vec(),
204    }
205    .write(file)?;
206    let current_offset = file.stream_position()?;
207    // First, truncate the file to the current offset.
208    file.set_len(current_offset)?;
209    // Then, zero-extend the length to represent the data (sparse if possible).
210    file.set_len(
211        current_offset
212            .checked_add(data_bytes_len)
213            .expect("overflow computing file length"),
214    )?;
215    Ok(())
216}
217
/// The ways in which writing array data can fail.
#[derive(Debug)]
pub enum WriteDataError {
    /// The underlying writer reported an I/O failure.
    Io(io::Error),
    /// The data could not be formatted for output.
    FormatData(Box<dyn Error + Send + Sync + 'static>),
}
226
227impl Error for WriteDataError {
228    fn source(&self) -> Option<&(dyn Error + 'static)> {
229        match self {
230            WriteDataError::Io(err) => Some(err),
231            WriteDataError::FormatData(err) => Some(&**err),
232        }
233    }
234}
235
236impl fmt::Display for WriteDataError {
237    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
238        match self {
239            WriteDataError::Io(err) => write!(f, "I/O error: {}", err),
240            WriteDataError::FormatData(err) => write!(f, "error formatting data: {}", err),
241        }
242    }
243}
244
245impl From<io::Error> for WriteDataError {
246    fn from(err: io::Error) -> WriteDataError {
247        WriteDataError::Io(err)
248    }
249}
250
251/// An array element type that can be written to an `.npy` or `.npz` file.
252pub trait WritableElement: Sized {
253    /// Returns a descriptor of the type that can be used in the header.
254    fn type_descriptor() -> PyValue;
255
256    /// Writes a single instance of `Self` to the writer.
257    fn write<W: io::Write>(&self, writer: W) -> Result<(), WriteDataError>;
258
259    /// Writes a slice of `Self` to the writer.
260    fn write_slice<W: io::Write>(slice: &[Self], writer: W) -> Result<(), WriteDataError>;
261}
262
263/// An error writing a `.npy` file.
264#[derive(Debug)]
265pub enum WriteNpyError {
266    /// An error caused by I/O.
267    Io(io::Error),
268    /// An error formatting the header.
269    FormatHeader(FormatHeaderError),
270    /// An error formatting the data.
271    FormatData(Box<dyn Error + Send + Sync + 'static>),
272}
273
274impl Error for WriteNpyError {
275    fn source(&self) -> Option<&(dyn Error + 'static)> {
276        match self {
277            WriteNpyError::Io(err) => Some(err),
278            WriteNpyError::FormatHeader(err) => Some(err),
279            WriteNpyError::FormatData(err) => Some(&**err),
280        }
281    }
282}
283
284impl fmt::Display for WriteNpyError {
285    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
286        match self {
287            WriteNpyError::Io(err) => write!(f, "I/O error: {}", err),
288            WriteNpyError::FormatHeader(err) => write!(f, "error formatting header: {}", err),
289            WriteNpyError::FormatData(err) => write!(f, "error formatting data: {}", err),
290        }
291    }
292}
293
294impl From<io::Error> for WriteNpyError {
295    fn from(err: io::Error) -> WriteNpyError {
296        WriteNpyError::Io(err)
297    }
298}
299
300impl From<WriteHeaderError> for WriteNpyError {
301    fn from(err: WriteHeaderError) -> WriteNpyError {
302        match err {
303            WriteHeaderError::Io(err) => WriteNpyError::Io(err),
304            WriteHeaderError::Format(err) => WriteNpyError::FormatHeader(err),
305        }
306    }
307}
308
309impl From<FormatHeaderError> for WriteNpyError {
310    fn from(err: FormatHeaderError) -> WriteNpyError {
311        WriteNpyError::FormatHeader(err)
312    }
313}
314
315impl From<WriteDataError> for WriteNpyError {
316    fn from(err: WriteDataError) -> WriteNpyError {
317        match err {
318            WriteDataError::Io(err) => WriteNpyError::Io(err),
319            WriteDataError::FormatData(err) => WriteNpyError::FormatData(err),
320        }
321    }
322}
323
324/// Extension trait for writing [`ArrayBase`] to `.npy` files.
325///
326/// If writes are expensive (e.g. for a file or network socket) and the layout
327/// of the array is not known to be in standard or Fortran layout, it is
328/// strongly recommended to wrap the writer in a [`BufWriter`]. For the sake of
329/// convenience, this method calls [`.flush()`](io::Write::flush) on the writer
330/// before returning.
331///
332/// # Example
333///
334/// ```no_run
335/// use ndarray::{array, Array2};
336/// use ndarray_npy::WriteNpyExt;
337/// use std::fs::File;
338/// use std::io::BufWriter;
339/// # use ndarray_npy::WriteNpyError;
340///
341/// let arr: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
342/// let writer = BufWriter::new(File::create("array.npy")?);
343/// arr.write_npy(writer)?;
344/// # Ok::<_, WriteNpyError>(())
345/// ```
346pub trait WriteNpyExt {
347    /// Writes the array to `writer` in [`.npy`
348    /// format](https://docs.scipy.org/doc/numpy/reference/generated/numpy.lib.format.html).
349    ///
350    /// This function is the Rust equivalent of
351    /// [`numpy.save`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.save.html).
352    fn write_npy<W: io::Write>(&self, writer: W) -> Result<(), WriteNpyError>;
353}
354
355impl<A, D> WriteNpyExt for ArrayRef<A, D>
356where
357    A: WritableElement,
358    D: Dimension,
359{
360    fn write_npy<W: io::Write>(&self, mut writer: W) -> Result<(), WriteNpyError> {
361        let write_contiguous = |mut writer: W, layout: Layout| {
362            Header {
363                type_descriptor: A::type_descriptor(),
364                layout,
365                shape: self.shape().to_owned(),
366            }
367            .write(&mut writer)?;
368            A::write_slice(self.as_slice_memory_order().unwrap(), &mut writer)?;
369            writer.flush()?;
370            Ok(())
371        };
372        if self.is_standard_layout() {
373            write_contiguous(writer, Layout::Standard)
374        } else if self.view().reversed_axes().is_standard_layout() {
375            write_contiguous(writer, Layout::Fortran)
376        } else {
377            Header {
378                type_descriptor: A::type_descriptor(),
379                layout: Layout::Standard,
380                shape: self.shape().to_owned(),
381            }
382            .write(&mut writer)?;
383            for elem in self.iter() {
384                elem.write(&mut writer)?;
385            }
386            writer.flush()?;
387            Ok(())
388        }
389    }
390}
391
392impl<A, S, D> WriteNpyExt for ArrayBase<S, D>
393where
394    A: WritableElement,
395    S: Data<Elem = A>,
396    D: Dimension,
397{
398    fn write_npy<W: io::Write>(&self, writer: W) -> Result<(), WriteNpyError> {
399        let arr: &ArrayRef<A, D> = self;
400        arr.write_npy(writer)
401    }
402}
403
404/// An error reading array data.
405#[derive(Debug)]
406pub enum ReadDataError {
407    /// An error caused by I/O.
408    Io(io::Error),
409    /// The type descriptor does not match the element type.
410    WrongDescriptor(PyValue),
411    /// The file does not contain all the data described in the header.
412    MissingData,
413    /// Extra bytes are present between the end of the data and the end of the
414    /// file.
415    ExtraBytes(usize),
416    /// An error parsing the data.
417    ParseData(Box<dyn Error + Send + Sync + 'static>),
418}
419
420impl Error for ReadDataError {
421    fn source(&self) -> Option<&(dyn Error + 'static)> {
422        match self {
423            ReadDataError::Io(err) => Some(err),
424            ReadDataError::WrongDescriptor(_) => None,
425            ReadDataError::MissingData => None,
426            ReadDataError::ExtraBytes(_) => None,
427            ReadDataError::ParseData(err) => Some(&**err),
428        }
429    }
430}
431
432impl fmt::Display for ReadDataError {
433    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
434        match self {
435            ReadDataError::Io(err) => write!(f, "I/O error: {}", err),
436            ReadDataError::WrongDescriptor(desc) => {
437                write!(f, "incorrect descriptor ({}) for this type", desc)
438            }
439            ReadDataError::MissingData => write!(f, "reached EOF before reading all data"),
440            ReadDataError::ExtraBytes(num_extra_bytes) => {
441                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
442            }
443            ReadDataError::ParseData(err) => write!(f, "error parsing data: {}", err),
444        }
445    }
446}
447
448impl From<io::Error> for ReadDataError {
449    /// Performs the conversion.
450    ///
451    /// If the error kind is `UnexpectedEof`, the `MissingData` variant is
452    /// returned. Otherwise, the `Io` variant is returned.
453    fn from(err: io::Error) -> ReadDataError {
454        if err.kind() == io::ErrorKind::UnexpectedEof {
455            ReadDataError::MissingData
456        } else {
457            ReadDataError::Io(err)
458        }
459    }
460}
461
462/// An array element type that can be read from an `.npy` or `.npz` file.
463pub trait ReadableElement: Sized {
464    /// Reads to the end of the `reader`, creating a `Vec` of length `len`.
465    ///
466    /// This method should return `Err(_)` in at least the following cases:
467    ///
468    /// * if the `type_desc` does not match `Self`
469    /// * if the `reader` has fewer elements than `len`
470    /// * if the `reader` has extra bytes after reading `len` elements
471    fn read_to_end_exact_vec<R: io::Read>(
472        reader: R,
473        type_desc: &PyValue,
474        len: usize,
475    ) -> Result<Vec<Self>, ReadDataError>;
476}
477
478/// An error reading a `.npy` file.
479#[derive(Debug)]
480pub enum ReadNpyError {
481    /// An error caused by I/O.
482    Io(io::Error),
483    /// An error parsing the file header.
484    ParseHeader(ParseHeaderError),
485    /// An error parsing the data.
486    ParseData(Box<dyn Error + Send + Sync + 'static>),
487    /// Overflow while computing the length of the array (in units of bytes or
488    /// the number of elements) from the shape described in the file header.
489    LengthOverflow,
490    /// An error caused by incorrect `Dimension` type.
491    WrongNdim(Option<usize>, usize),
492    /// The type descriptor does not match the element type.
493    WrongDescriptor(PyValue),
494    /// The file does not contain all the data described in the header.
495    MissingData,
496    /// Extra bytes are present between the end of the data and the end of the
497    /// file.
498    ExtraBytes(usize),
499}
500
501impl Error for ReadNpyError {
502    fn source(&self) -> Option<&(dyn Error + 'static)> {
503        match self {
504            ReadNpyError::Io(err) => Some(err),
505            ReadNpyError::ParseHeader(err) => Some(err),
506            ReadNpyError::ParseData(err) => Some(&**err),
507            ReadNpyError::LengthOverflow => None,
508            ReadNpyError::WrongNdim(_, _) => None,
509            ReadNpyError::WrongDescriptor(_) => None,
510            ReadNpyError::MissingData => None,
511            ReadNpyError::ExtraBytes(_) => None,
512        }
513    }
514}
515
516impl fmt::Display for ReadNpyError {
517    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
518        match self {
519            ReadNpyError::Io(err) => write!(f, "I/O error: {}", err),
520            ReadNpyError::ParseHeader(err) => write!(f, "error parsing header: {}", err),
521            ReadNpyError::ParseData(err) => write!(f, "error parsing data: {}", err),
522            ReadNpyError::LengthOverflow => write!(f, "overflow computing length from shape"),
523            ReadNpyError::WrongNdim(expected, actual) => write!(
524                f,
525                "ndim {} of array did not match Dimension type with NDIM = {:?}",
526                actual, expected
527            ),
528            ReadNpyError::WrongDescriptor(desc) => {
529                write!(f, "incorrect descriptor ({}) for this type", desc)
530            }
531            ReadNpyError::MissingData => write!(f, "reached EOF before reading all data"),
532            ReadNpyError::ExtraBytes(num_extra_bytes) => {
533                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
534            }
535        }
536    }
537}
538
539impl From<io::Error> for ReadNpyError {
540    fn from(err: io::Error) -> ReadNpyError {
541        ReadNpyError::Io(err)
542    }
543}
544
545impl From<ReadHeaderError> for ReadNpyError {
546    fn from(err: ReadHeaderError) -> ReadNpyError {
547        match err {
548            ReadHeaderError::Io(err) => ReadNpyError::Io(err),
549            ReadHeaderError::Parse(err) => ReadNpyError::ParseHeader(err),
550        }
551    }
552}
553
554impl From<ParseHeaderError> for ReadNpyError {
555    fn from(err: ParseHeaderError) -> ReadNpyError {
556        ReadNpyError::ParseHeader(err)
557    }
558}
559
560impl From<ReadDataError> for ReadNpyError {
561    fn from(err: ReadDataError) -> ReadNpyError {
562        match err {
563            ReadDataError::Io(err) => ReadNpyError::Io(err),
564            ReadDataError::WrongDescriptor(desc) => ReadNpyError::WrongDescriptor(desc),
565            ReadDataError::MissingData => ReadNpyError::MissingData,
566            ReadDataError::ExtraBytes(nbytes) => ReadNpyError::ExtraBytes(nbytes),
567            ReadDataError::ParseData(err) => ReadNpyError::ParseData(err),
568        }
569    }
570}
571
572/// Extension trait for reading `Array` from `.npy` files.
573///
574/// # Example
575///
576/// ```
577/// use ndarray::Array2;
578/// use ndarray_npy::ReadNpyExt;
579/// use std::fs::File;
580/// # use ndarray_npy::ReadNpyError;
581///
582/// let reader = File::open("resources/array.npy")?;
583/// let arr = Array2::<i32>::read_npy(reader)?;
584/// # println!("arr = {}", arr);
585/// # Ok::<_, ReadNpyError>(())
586/// ```
587pub trait ReadNpyExt: Sized {
588    /// Reads the array from `reader` in [`.npy`
589    /// format](https://docs.scipy.org/doc/numpy/reference/generated/numpy.lib.format.html).
590    ///
591    /// This function is the Rust equivalent of
592    /// [`numpy.load`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.load.html)
593    /// for `.npy` files.
594    fn read_npy<R: io::Read>(reader: R) -> Result<Self, ReadNpyError>;
595}
596
597impl<A, S, D> ReadNpyExt for ArrayBase<S, D>
598where
599    A: ReadableElement,
600    S: DataOwned<Elem = A>,
601    D: Dimension,
602{
603    fn read_npy<R: io::Read>(mut reader: R) -> Result<Self, ReadNpyError> {
604        let header = Header::from_reader(&mut reader)?;
605        let shape = header.shape.into_dimension();
606        let ndim = shape.ndim();
607        let len = shape_length_checked::<A>(&shape).ok_or(ReadNpyError::LengthOverflow)?;
608        let data = A::read_to_end_exact_vec(&mut reader, &header.type_descriptor, len)?;
609        ArrayBase::from_shape_vec(shape.set_f(header.layout.is_fortran()), data)
610            .unwrap()
611            .into_dimensionality()
612            .map_err(|_| ReadNpyError::WrongNdim(D::NDIM, ndim))
613    }
614}
615
616/// An error viewing a `.npy` file.
617#[derive(Debug)]
618#[non_exhaustive]
619pub enum ViewNpyError {
620    /// An error caused by I/O.
621    Io(io::Error),
622    /// An error parsing the file header.
623    ParseHeader(ParseHeaderError),
624    /// Some of the data is invalid for the element type.
625    InvalidData(Box<dyn Error + Send + Sync + 'static>),
626    /// Overflow while computing the length of the array (in units of bytes or
627    /// the number of elements) from the shape described in the file header.
628    LengthOverflow,
629    /// An error caused by incorrect `Dimension` type.
630    WrongNdim(Option<usize>, usize),
631    /// The type descriptor does not match the element type.
632    WrongDescriptor(PyValue),
633    /// The type descriptor does not match the native endianness.
634    NonNativeEndian,
635    /// The start of the data is not properly aligned for the element type.
636    MisalignedData,
637    /// The file does not contain all the data described in the header.
638    MissingBytes(usize),
639    /// Extra bytes are present between the end of the data and the end of the
640    /// file.
641    ExtraBytes(usize),
642}
643
644impl Error for ViewNpyError {
645    fn source(&self) -> Option<&(dyn Error + 'static)> {
646        match self {
647            ViewNpyError::Io(err) => Some(err),
648            ViewNpyError::ParseHeader(err) => Some(err),
649            ViewNpyError::InvalidData(err) => Some(&**err),
650            ViewNpyError::LengthOverflow => None,
651            ViewNpyError::WrongNdim(_, _) => None,
652            ViewNpyError::WrongDescriptor(_) => None,
653            ViewNpyError::NonNativeEndian => None,
654            ViewNpyError::MisalignedData => None,
655            ViewNpyError::MissingBytes(_) => None,
656            ViewNpyError::ExtraBytes(_) => None,
657        }
658    }
659}
660
661impl fmt::Display for ViewNpyError {
662    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
663        match self {
664            ViewNpyError::Io(err) => write!(f, "I/O error: {}", err),
665            ViewNpyError::ParseHeader(err) => write!(f, "error parsing header: {}", err),
666            ViewNpyError::InvalidData(err) => write!(f, "invalid data for element type: {}", err),
667            ViewNpyError::LengthOverflow => write!(f, "overflow computing length from shape"),
668            ViewNpyError::WrongNdim(expected, actual) => write!(
669                f,
670                "ndim {} of array did not match Dimension type with NDIM = {:?}",
671                actual, expected
672            ),
673            ViewNpyError::WrongDescriptor(desc) => {
674                write!(f, "incorrect descriptor ({}) for this type", desc)
675            }
676            ViewNpyError::NonNativeEndian => {
677                write!(f, "descriptor does not match native endianness")
678            }
679            ViewNpyError::MisalignedData => write!(
680                f,
681                "start of data is not properly aligned for the element type"
682            ),
683            ViewNpyError::MissingBytes(num_missing_bytes) => write!(
684                f,
685                "missing {} bytes of data specified in header",
686                num_missing_bytes
687            ),
688            ViewNpyError::ExtraBytes(num_extra_bytes) => {
689                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
690            }
691        }
692    }
693}
694
695impl From<ReadHeaderError> for ViewNpyError {
696    fn from(err: ReadHeaderError) -> ViewNpyError {
697        match err {
698            ReadHeaderError::Io(err) => ViewNpyError::Io(err),
699            ReadHeaderError::Parse(err) => ViewNpyError::ParseHeader(err),
700        }
701    }
702}
703
704impl From<ParseHeaderError> for ViewNpyError {
705    fn from(err: ParseHeaderError) -> ViewNpyError {
706        ViewNpyError::ParseHeader(err)
707    }
708}
709
710impl From<ViewDataError> for ViewNpyError {
711    fn from(err: ViewDataError) -> ViewNpyError {
712        match err {
713            ViewDataError::WrongDescriptor(desc) => ViewNpyError::WrongDescriptor(desc),
714            ViewDataError::NonNativeEndian => ViewNpyError::NonNativeEndian,
715            ViewDataError::Misaligned => ViewNpyError::MisalignedData,
716            ViewDataError::MissingBytes(nbytes) => ViewNpyError::MissingBytes(nbytes),
717            ViewDataError::ExtraBytes(nbytes) => ViewNpyError::ExtraBytes(nbytes),
718            ViewDataError::InvalidData(err) => ViewNpyError::InvalidData(err),
719        }
720    }
721}
722
723/// Extension trait for creating an [`ArrayView`] from a buffer containing an
724/// `.npy` file.
725///
726/// The primary use-case for this is viewing a memory-mapped `.npy` file.
727///
728/// # Notes
729///
730/// - For types for which not all bit patterns are valid, such as `bool`, the
731///   implementation iterates over all of the elements when creating the view
732///   to ensure they have a valid bit pattern.
733///
734/// - Viewing an `.npy` file has more restrictions than reading it, due to
735///   memory layout. Specifically:
736///
737///   - An error is returned if the data in the buffer is not properly aligned
738///     for the element type. Typically, this should not be a concern for
739///     memory-mapped files (unless an option like `MAP_FIXED` is used), since
740///     memory mappings are usually aligned to a page boundary, and the `.npy`
741///     format has padding such that the header size is a multiple of 64 bytes.
742///
743///   - An error is returned if the endianness of the data does not match the
744///     endianness of the target element type. For example, multi-byte
745///     primitive types such as `f32` require that the data in the file match
746///     the native endianness of the machine.
747///
748/// # Example
749///
750/// This is an example of opening a readonly memory-mapped file as an
751/// [`ArrayView`].
752///
753/// This example uses the [`memmap2`](https://crates.io/crates/memmap2) crate
754/// because that appears to be the best-maintained memory-mapping crate at the
755/// moment, but `view_npy` takes a `&[u8]` instead of a file so that you can
756/// use the memory-mapping crate you're most comfortable with.
757///
758/// ```
759/// # // Miri doesn't support mmap, and the file is in little endian format.
760/// # if !cfg!(miri) && cfg!(target_endian = "little") {
761/// use memmap2::Mmap;
762/// use ndarray::ArrayView2;
763/// use ndarray_npy::ViewNpyExt;
764/// use std::fs::File;
765///
766/// let file = File::open("resources/array.npy")?;
767/// let mmap = unsafe { Mmap::map(&file)? };
768/// let view = ArrayView2::<i32>::view_npy(&mmap)?;
769/// # println!("view = {}", view);
770/// # }
771/// # Ok::<_, Box<dyn std::error::Error>>(())
772/// ```
773pub trait ViewNpyExt<'a>: Sized {
774    /// Creates an `ArrayView` from a buffer containing an `.npy` file.
775    fn view_npy(buf: &'a [u8]) -> Result<Self, ViewNpyError>;
776}
777
778/// Extension trait for creating an [`ArrayViewMut`] from a mutable buffer
779/// containing an `.npy` file.
780///
781/// The primary use-case for this is modifying a memory-mapped `.npy` file.
782/// Modifying the elements in the view will modify the file. Modifying the
783/// shape/strides of the view will *not* modify the shape/strides of the array
784/// in the file.
785///
786/// Notes:
787///
788/// - For types for which not all bit patterns are valid, such as `bool`, the
789///   implementation iterates over all of the elements when creating the view
790///   to ensure they have a valid bit pattern.
791///
792/// - Viewing an `.npy` file has more restrictions than reading it, due to
793///   memory layout. Specifically:
794///
795///   - An error is returned if the data in the buffer is not properly aligned
796///     for the element type. Typically, this should not be a concern for
797///     memory-mapped files (unless an option like `MAP_FIXED` is used), since
798///     memory mappings are usually aligned to a page boundary, and the `.npy`
799///     format has padding such that the header size is a multiple of 64 bytes.
800///
801///   - An error is returned if the endianness of the data does not match the
802///     endianness of the target element type. For example, multi-byte
803///     primitive types such as `f32` require that the data in the file match
804///     the native endianness of the machine.
805///
806/// # Example
807///
808/// This is an example of opening a writable memory-mapped file as an
809/// [`ArrayViewMut`]. Changes to the data in the view will modify the
810/// underlying file.
811///
812/// This example uses the [`memmap2`](https://crates.io/crates/memmap2) crate
813/// because that appears to be the best-maintained memory-mapping crate at the
814/// moment, but `view_mut_npy` takes a `&mut [u8]` instead of a file so that
815/// you can use the memory-mapping crate you're most comfortable with.
816///
817/// ```
818/// # // Miri doesn't support mmap, and the file is in little endian format.
819/// # if !cfg!(miri) && cfg!(target_endian = "little") {
820/// use memmap2::MmapMut;
821/// use ndarray::ArrayViewMut2;
822/// use ndarray_npy::ViewMutNpyExt;
823/// use std::fs;
824///
825/// let file = fs::OpenOptions::new()
826///     .read(true)
827///     .write(true)
828///     .open("resources/array.npy")?;
829/// let mut mmap = unsafe { MmapMut::map_mut(&file)? };
830/// let view_mut = ArrayViewMut2::<i32>::view_mut_npy(&mut mmap)?;
831/// # println!("view_mut = {}", view_mut);
832/// # }
833/// # Ok::<_, Box<dyn std::error::Error>>(())
834/// ```
835pub trait ViewMutNpyExt<'a>: Sized {
836    /// Creates an `ArrayViewMut` from a mutable buffer containing an `.npy`
837    /// file.
838    fn view_mut_npy(buf: &'a mut [u8]) -> Result<Self, ViewNpyError>;
839}
840
841impl<'a, A, D> ViewNpyExt<'a> for ArrayView<'a, A, D>
842where
843    A: ViewElement,
844    D: Dimension,
845{
846    fn view_npy(buf: &'a [u8]) -> Result<Self, ViewNpyError> {
847        let mut reader = buf;
848        let header = Header::from_reader(&mut reader)?;
849        let shape = header.shape.into_dimension();
850        let ndim = shape.ndim();
851        let len = shape_length_checked::<A>(&shape).ok_or(ViewNpyError::LengthOverflow)?;
852        let data = A::bytes_as_slice(reader, &header.type_descriptor, len)?;
853        ArrayView::from_shape(shape.set_f(header.layout.is_fortran()), data)
854            .unwrap()
855            .into_dimensionality()
856            .map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))
857    }
858}
859
860impl<'a, A, D> ViewMutNpyExt<'a> for ArrayViewMut<'a, A, D>
861where
862    A: ViewMutElement,
863    D: Dimension,
864{
865    fn view_mut_npy(buf: &'a mut [u8]) -> Result<Self, ViewNpyError> {
866        let mut reader = &*buf;
867        let header = Header::from_reader(&mut reader)?;
868        let shape = header.shape.into_dimension();
869        let ndim = shape.ndim();
870        let len = shape_length_checked::<A>(&shape).ok_or(ViewNpyError::LengthOverflow)?;
871        let mid = buf.len() - reader.len();
872        let data = A::bytes_as_mut_slice(&mut buf[mid..], &header.type_descriptor, len)?;
873        ArrayViewMut::from_shape(shape.set_f(header.layout.is_fortran()), data)
874            .unwrap()
875            .into_dimensionality()
876            .map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))
877    }
878}
879
880/// An error viewing array data.
881#[derive(Debug)]
882#[non_exhaustive]
883pub enum ViewDataError {
884    /// The type descriptor does not match the element type.
885    WrongDescriptor(PyValue),
886    /// The type descriptor does not match the native endianness.
887    NonNativeEndian,
888    /// The start of the data is not properly aligned for the element type.
889    Misaligned,
890    /// The file does not contain all the data described in the header.
891    MissingBytes(usize),
892    /// Extra bytes are present between the end of the data and the end of the
893    /// file.
894    ExtraBytes(usize),
895    /// Some of the data is invalid for the element type.
896    InvalidData(Box<dyn Error + Send + Sync + 'static>),
897}
898
899impl Error for ViewDataError {
900    fn source(&self) -> Option<&(dyn Error + 'static)> {
901        match self {
902            ViewDataError::WrongDescriptor(_) => None,
903            ViewDataError::NonNativeEndian => None,
904            ViewDataError::Misaligned => None,
905            ViewDataError::MissingBytes(_) => None,
906            ViewDataError::ExtraBytes(_) => None,
907            ViewDataError::InvalidData(err) => Some(&**err),
908        }
909    }
910}
911
912impl fmt::Display for ViewDataError {
913    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
914        match self {
915            ViewDataError::WrongDescriptor(desc) => {
916                write!(f, "incorrect descriptor ({}) for this type", desc)
917            }
918            ViewDataError::NonNativeEndian => {
919                write!(f, "descriptor does not match native endianness")
920            }
921            ViewDataError::Misaligned => write!(
922                f,
923                "start of data is not properly aligned for the element type"
924            ),
925            ViewDataError::MissingBytes(num_missing_bytes) => write!(
926                f,
927                "missing {} bytes of data specified in header",
928                num_missing_bytes
929            ),
930            ViewDataError::ExtraBytes(num_extra_bytes) => {
931                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
932            }
933            ViewDataError::InvalidData(err) => write!(f, "invalid data for element type: {}", err),
934        }
935    }
936}
937
938/// An array element type that can be viewed (without copying) in an `.npy`
939/// file.
940pub trait ViewElement: Sized {
941    /// Casts `bytes` into a slice of elements of length `len`.
942    ///
943    /// Returns `Err(_)` in at least the following cases:
944    ///
945    ///   * if the `type_desc` does not match `Self` with native endianness
946    ///   * if the `bytes` slice is misaligned for elements of type `Self`
947    ///   * if the `bytes` slice is too short for `len` elements
948    ///   * if the `bytes` slice has extra bytes after `len` elements
949    ///
950    /// May panic if `len * size_of::<Self>()` overflows.
951    fn bytes_as_slice<'a>(
952        bytes: &'a [u8],
953        type_desc: &PyValue,
954        len: usize,
955    ) -> Result<&'a [Self], ViewDataError>;
956}
957
958/// An array element type that can be mutably viewed (without copying) in an
959/// `.npy` file.
960pub trait ViewMutElement: Sized {
961    /// Casts `bytes` into a mutable slice of elements of length `len`.
962    ///
963    /// Returns `Err(_)` in at least the following cases:
964    ///
965    ///   * if the `type_desc` does not match `Self` with native endianness
966    ///   * if the `bytes` slice is misaligned for elements of type `Self`
967    ///   * if the `bytes` slice is too short for `len` elements
968    ///   * if the `bytes` slice has extra bytes after `len` elements
969    ///
970    /// May panic if `len * size_of::<Self>()` overflows.
971    fn bytes_as_mut_slice<'a>(
972        bytes: &'a mut [u8],
973        type_desc: &PyValue,
974        len: usize,
975    ) -> Result<&'a mut [Self], ViewDataError>;
976}
977
978/// Computes the length associated with the shape (i.e. the product of the axis
979/// lengths), where the element type is `T`.
980///
981/// Returns `None` if the number of elements or the length in bytes would
982/// overflow `isize`.
983fn shape_length_checked<T>(shape: &IxDyn) -> Option<usize> {
984    let len = shape.size_checked()?;
985    if len > isize::MAX as usize {
986        return None;
987    }
988    let bytes = len.checked_mul(mem::size_of::<T>())?;
989    if bytes > isize::MAX as usize {
990        return None;
991    }
992    Some(len)
993}