ndarray_npy/npy/
mod.rs

1mod elements;
2pub mod header;
3
4use self::header::{
5    FormatHeaderError, Header, ParseHeaderError, ReadHeaderError, WriteHeaderError,
6};
7use ndarray::prelude::*;
8use ndarray::{Data, DataOwned, IntoDimension};
9use py_literal::Value as PyValue;
10use std::convert::TryInto;
11use std::error::Error;
12use std::fmt;
13use std::fs::File;
14use std::io::{self, BufWriter, Seek};
15use std::mem;
16
17/// Read an `.npy` file located at the specified path.
18///
19/// This is a convience function for using `File::open` followed by
20/// [`ReadNpyExt::read_npy`](trait.ReadNpyExt.html#tymethod.read_npy).
21///
22/// # Example
23///
24/// ```
25/// use ndarray::Array2;
26/// use ndarray_npy::read_npy;
27/// # use ndarray_npy::ReadNpyError;
28///
29/// let arr: Array2<i32> = read_npy("resources/array.npy")?;
30/// # println!("arr = {}", arr);
31/// # Ok::<_, ReadNpyError>(())
32/// ```
33pub fn read_npy<P, T>(path: P) -> Result<T, ReadNpyError>
34where
35    P: AsRef<std::path::Path>,
36    T: ReadNpyExt,
37{
38    T::read_npy(std::fs::File::open(path)?)
39}
40
41/// Writes an array to an `.npy` file at the specified path.
42///
43/// This function will create the file if it does not exist, or overwrite it if
44/// it does.
45///
46/// This is a convenience function for `BufWriter::new(File::create(path)?)`
47/// followed by [`WriteNpyExt::write_npy`].
48///
49/// # Example
50///
51/// ```no_run
52/// use ndarray::array;
53/// use ndarray_npy::write_npy;
54/// # use ndarray_npy::WriteNpyError;
55///
56/// let arr = array![[1, 2, 3], [4, 5, 6]];
57/// write_npy("array.npy", &arr)?;
58/// # Ok::<_, WriteNpyError>(())
59/// ```
60pub fn write_npy<P, T>(path: P, array: &T) -> Result<(), WriteNpyError>
61where
62    P: AsRef<std::path::Path>,
63    T: WriteNpyExt,
64{
65    array.write_npy(BufWriter::new(File::create(path)?))
66}
67
68/// Writes an `.npy` file (sparse if possible) with bitwise-zero-filled data.
69///
70/// The `.npy` file represents an array with element type `A` and shape
71/// specified by `shape`, with all elements of the array represented by an
72/// all-zero byte-pattern. The file is written starting at the current cursor
73/// location and truncated such that there are no additional bytes after the
74/// `.npy` data.
75///
76/// This function is primarily useful for creating an `.npy` file for an array
77/// larger than available memory. The file can then be memory-mapped and
78/// modified using [`ViewMutNpyExt`].
79///
80/// # Panics
81///
82/// May panic if any of the following overflow `isize` or `u64`:
83///
84/// - the number of elements in the array
85/// - the size of the array in bytes
86/// - the size of the resulting file in bytes
87///
88/// # Considerations
89///
90/// ## Data is zeroed bytes
91///
92/// The data consists of all zeroed bytes, so this function is useful only for
93/// element types for which an all-zero byte-pattern is a valid representation.
94///
95/// ## Sparse file
96///
97/// On filesystems which support [sparse files], most of the data should be
98/// handled by empty blocks, i.e. not allocated on disk. If you plan to
99/// memory-map the file to modify it and know that most blocks of the file will
100/// ultimately contain some nonzero data, then it may be beneficial to allocate
101/// space for the file on disk before modifying it in order to avoid
102/// fragmentation. For example, on POSIX-like systems, you can do this by
103/// calling `fallocate` on the file.
104///
105/// [sparse files]: https://en.wikipedia.org/wiki/Sparse_file
106///
107/// ## Alternatives
108///
109/// If all you want to do is create an array larger than the available memory
110/// and don't care about actually writing the data to disk, it may be worth
111/// considering alternative options:
112///
113/// - Add more swap space to your system, using swap file(s) if necessary, so
114///   that you can allocate the array as normal.
115///
116/// - If you know the data will be sparse:
117///
118///   - Use a sparse data structure instead of `ndarray`'s array types. For
119///     example, the [`sprs`](https://crates.io/crates/sprs) crate provides
120///     sparse matrices.
121///
122///   - Rely on memory overcommitment. In other words, configure the operating
123///     system to allocate more memory than actually exists. However, this
124///     risks the system running out of memory if the data is not as sparse as
125///     you expect.
126///
127/// # Example
128///
129/// In this example, a file containing 64 GiB of zeroed `f64` elements is
130/// created. Then, an `ArrayViewMut` is created by memory-mapping the file.
131/// Modifications to the data in the `ArrayViewMut` will be applied to the
132/// backing file. This works even on systems with less than 64 GiB of physical
133/// memory. On filesystems which support [sparse files], the disk space that's
134/// actually used depends on how much data is modified.
135///
136/// ```no_run
137/// use memmap2::MmapMut;
138/// use ndarray::ArrayViewMut3;
139/// use ndarray_npy::{write_zeroed_npy, ViewMutNpyExt};
140/// use std::fs::{File, OpenOptions};
141///
142/// let path = "array.npy";
143///
144/// // Create a (sparse if supported) file containing 64 GiB of zeroed data.
145/// let file = File::create(path)?;
146/// write_zeroed_npy::<f64, _>(&file, (1024, 2048, 4096))?;
147///
148/// // Memory-map the file and create the mutable view.
149/// let file = OpenOptions::new().read(true).write(true).open(path)?;
150/// let mut mmap = unsafe { MmapMut::map_mut(&file)? };
151/// let mut view_mut = ArrayViewMut3::<f64>::view_mut_npy(&mut mmap)?;
152///
153/// // Modify an element in the view.
154/// view_mut[[500, 1000, 2000]] = 3.14;
155/// #
156/// # Ok::<_, Box<dyn std::error::Error>>(())
157/// ```
158pub fn write_zeroed_npy<A, D>(mut file: &File, shape: D) -> Result<(), WriteNpyError>
159where
160    A: WritableElement,
161    D: IntoDimension,
162{
163    let dim = shape.into_dimension();
164    let data_bytes_len: u64 = dim
165        .size_checked()
166        .expect("overflow computing number of elements")
167        .checked_mul(mem::size_of::<A>())
168        .expect("overflow computing length of data")
169        .try_into()
170        .expect("overflow converting length of data to u64");
171    Header {
172        type_descriptor: A::type_descriptor(),
173        fortran_order: false,
174        shape: dim.as_array_view().to_vec(),
175    }
176    .write(file)?;
177    let current_offset = file.stream_position()?;
178    // First, truncate the file to the current offset.
179    file.set_len(current_offset)?;
180    // Then, zero-extend the length to represent the data (sparse if possible).
181    file.set_len(
182        current_offset
183            .checked_add(data_bytes_len)
184            .expect("overflow computing file length"),
185    )?;
186    Ok(())
187}
188
/// An error writing array data.
///
/// This is the error type returned by the [`WritableElement`] methods. It
/// converts into [`WriteNpyError`] through the `From` impl, with each variant
/// mapped to the variant of the same name.
#[derive(Debug)]
pub enum WriteDataError {
    /// An error caused by I/O.
    Io(io::Error),
    /// An error formatting the data.
    FormatData(Box<dyn Error + Send + Sync + 'static>),
}
197
198impl Error for WriteDataError {
199    fn source(&self) -> Option<&(dyn Error + 'static)> {
200        match self {
201            WriteDataError::Io(err) => Some(err),
202            WriteDataError::FormatData(err) => Some(&**err),
203        }
204    }
205}
206
207impl fmt::Display for WriteDataError {
208    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
209        match self {
210            WriteDataError::Io(err) => write!(f, "I/O error: {}", err),
211            WriteDataError::FormatData(err) => write!(f, "error formatting data: {}", err),
212        }
213    }
214}
215
216impl From<io::Error> for WriteDataError {
217    fn from(err: io::Error) -> WriteDataError {
218        WriteDataError::Io(err)
219    }
220}
221
/// An array element type that can be written to an `.npy` or `.npz` file.
///
/// Implementors provide the header type descriptor for the element type as
/// well as the routines that serialize one element or a slice of elements.
pub trait WritableElement: Sized {
    /// Returns a descriptor of the type that can be used in the header.
    ///
    /// The value is stored in the header's `type_descriptor` field when an
    /// array of this element type is written.
    fn type_descriptor() -> PyValue;

    /// Writes a single instance of `Self` to the writer.
    ///
    /// Used when an array is written element by element, i.e. when its data
    /// cannot be written as one slice.
    fn write<W: io::Write>(&self, writer: W) -> Result<(), WriteDataError>;

    /// Writes a slice of `Self` to the writer.
    ///
    /// Used when the array data is available as a single slice.
    fn write_slice<W: io::Write>(slice: &[Self], writer: W) -> Result<(), WriteDataError>;
}
233
/// An error writing a `.npy` file.
///
/// `From` conversions exist for `io::Error`, `WriteHeaderError`,
/// `FormatHeaderError`, and [`WriteDataError`], so those errors can be
/// propagated with `?` in functions returning this type.
#[derive(Debug)]
pub enum WriteNpyError {
    /// An error caused by I/O.
    Io(io::Error),
    /// An error formatting the header.
    FormatHeader(FormatHeaderError),
    /// An error formatting the data.
    FormatData(Box<dyn Error + Send + Sync + 'static>),
}
244
245impl Error for WriteNpyError {
246    fn source(&self) -> Option<&(dyn Error + 'static)> {
247        match self {
248            WriteNpyError::Io(err) => Some(err),
249            WriteNpyError::FormatHeader(err) => Some(err),
250            WriteNpyError::FormatData(err) => Some(&**err),
251        }
252    }
253}
254
255impl fmt::Display for WriteNpyError {
256    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
257        match self {
258            WriteNpyError::Io(err) => write!(f, "I/O error: {}", err),
259            WriteNpyError::FormatHeader(err) => write!(f, "error formatting header: {}", err),
260            WriteNpyError::FormatData(err) => write!(f, "error formatting data: {}", err),
261        }
262    }
263}
264
265impl From<io::Error> for WriteNpyError {
266    fn from(err: io::Error) -> WriteNpyError {
267        WriteNpyError::Io(err)
268    }
269}
270
271impl From<WriteHeaderError> for WriteNpyError {
272    fn from(err: WriteHeaderError) -> WriteNpyError {
273        match err {
274            WriteHeaderError::Io(err) => WriteNpyError::Io(err),
275            WriteHeaderError::Format(err) => WriteNpyError::FormatHeader(err),
276        }
277    }
278}
279
280impl From<FormatHeaderError> for WriteNpyError {
281    fn from(err: FormatHeaderError) -> WriteNpyError {
282        WriteNpyError::FormatHeader(err)
283    }
284}
285
286impl From<WriteDataError> for WriteNpyError {
287    fn from(err: WriteDataError) -> WriteNpyError {
288        match err {
289            WriteDataError::Io(err) => WriteNpyError::Io(err),
290            WriteDataError::FormatData(err) => WriteNpyError::FormatData(err),
291        }
292    }
293}
294
/// Extension trait for writing [`ArrayBase`] to `.npy` files.
///
/// If writes are expensive (e.g. for a file or network socket) and the layout
/// of the array is not known to be in standard or Fortran layout, it is
/// strongly recommended to wrap the writer in a [`BufWriter`], because arrays
/// in any other layout are written one element at a time. For the sake of
/// convenience, `write_npy` calls [`.flush()`](io::Write::flush) on the writer
/// before returning.
///
/// # Example
///
/// ```no_run
/// use ndarray::{array, Array2};
/// use ndarray_npy::WriteNpyExt;
/// use std::fs::File;
/// use std::io::BufWriter;
/// # use ndarray_npy::WriteNpyError;
///
/// let arr: Array2<i32> = array![[1, 2, 3], [4, 5, 6]];
/// let writer = BufWriter::new(File::create("array.npy")?);
/// arr.write_npy(writer)?;
/// # Ok::<_, WriteNpyError>(())
/// ```
pub trait WriteNpyExt {
    /// Writes the array to `writer` in [`.npy`
    /// format](https://docs.scipy.org/doc/numpy/reference/generated/numpy.lib.format.html).
    ///
    /// This function is the Rust equivalent of
    /// [`numpy.save`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.save.html).
    ///
    /// # Errors
    ///
    /// Returns a [`WriteNpyError`] if writing the header or the data fails.
    fn write_npy<W: io::Write>(&self, writer: W) -> Result<(), WriteNpyError>;
}
325
326impl<A, S, D> WriteNpyExt for ArrayBase<S, D>
327where
328    A: WritableElement,
329    S: Data<Elem = A>,
330    D: Dimension,
331{
332    fn write_npy<W: io::Write>(&self, mut writer: W) -> Result<(), WriteNpyError> {
333        let write_contiguous = |mut writer: W, fortran_order: bool| {
334            Header {
335                type_descriptor: A::type_descriptor(),
336                fortran_order,
337                shape: self.shape().to_owned(),
338            }
339            .write(&mut writer)?;
340            A::write_slice(self.as_slice_memory_order().unwrap(), &mut writer)?;
341            writer.flush()?;
342            Ok(())
343        };
344        if self.is_standard_layout() {
345            write_contiguous(writer, false)
346        } else if self.view().reversed_axes().is_standard_layout() {
347            write_contiguous(writer, true)
348        } else {
349            Header {
350                type_descriptor: A::type_descriptor(),
351                fortran_order: false,
352                shape: self.shape().to_owned(),
353            }
354            .write(&mut writer)?;
355            for elem in self.iter() {
356                elem.write(&mut writer)?;
357            }
358            writer.flush()?;
359            Ok(())
360        }
361    }
362}
363
/// An error reading array data.
///
/// This is the error type returned by
/// [`ReadableElement::read_to_end_exact_vec`]. It converts into
/// [`ReadNpyError`] through the `From` impl, with each variant mapped to the
/// variant of the same name.
#[derive(Debug)]
pub enum ReadDataError {
    /// An error caused by I/O.
    Io(io::Error),
    /// The type descriptor does not match the element type.
    ///
    /// The payload is the descriptor in question.
    WrongDescriptor(PyValue),
    /// The file does not contain all the data described in the header.
    MissingData,
    /// Extra bytes are present between the end of the data and the end of the
    /// file.
    ///
    /// The payload is the number of extra bytes.
    ExtraBytes(usize),
    /// An error parsing the data.
    ParseData(Box<dyn Error + Send + Sync + 'static>),
}
379
380impl Error for ReadDataError {
381    fn source(&self) -> Option<&(dyn Error + 'static)> {
382        match self {
383            ReadDataError::Io(err) => Some(err),
384            ReadDataError::WrongDescriptor(_) => None,
385            ReadDataError::MissingData => None,
386            ReadDataError::ExtraBytes(_) => None,
387            ReadDataError::ParseData(err) => Some(&**err),
388        }
389    }
390}
391
392impl fmt::Display for ReadDataError {
393    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
394        match self {
395            ReadDataError::Io(err) => write!(f, "I/O error: {}", err),
396            ReadDataError::WrongDescriptor(desc) => {
397                write!(f, "incorrect descriptor ({}) for this type", desc)
398            }
399            ReadDataError::MissingData => write!(f, "reached EOF before reading all data"),
400            ReadDataError::ExtraBytes(num_extra_bytes) => {
401                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
402            }
403            ReadDataError::ParseData(err) => write!(f, "error parsing data: {}", err),
404        }
405    }
406}
407
408impl From<io::Error> for ReadDataError {
409    /// Performs the conversion.
410    ///
411    /// If the error kind is `UnexpectedEof`, the `MissingData` variant is
412    /// returned. Otherwise, the `Io` variant is returned.
413    fn from(err: io::Error) -> ReadDataError {
414        if err.kind() == io::ErrorKind::UnexpectedEof {
415            ReadDataError::MissingData
416        } else {
417            ReadDataError::Io(err)
418        }
419    }
420}
421
/// An array element type that can be read from an `.npy` or `.npz` file.
pub trait ReadableElement: Sized {
    /// Reads to the end of the `reader`, creating a `Vec` of length `len`.
    ///
    /// `type_desc` is the type descriptor taken from the file header, and
    /// `len` is the number of elements expected.
    ///
    /// This method should return `Err(_)` in at least the following cases:
    ///
    /// * if the `type_desc` does not match `Self`
    /// * if the `reader` has fewer elements than `len`
    /// * if the `reader` has extra bytes after reading `len` elements
    fn read_to_end_exact_vec<R: io::Read>(
        reader: R,
        type_desc: &PyValue,
        len: usize,
    ) -> Result<Vec<Self>, ReadDataError>;
}
437
/// An error reading a `.npy` file.
///
/// `From` conversions exist for `io::Error`, `ReadHeaderError`,
/// `ParseHeaderError`, and [`ReadDataError`], so those errors can be
/// propagated with `?` in functions returning this type.
#[derive(Debug)]
pub enum ReadNpyError {
    /// An error caused by I/O.
    Io(io::Error),
    /// An error parsing the file header.
    ParseHeader(ParseHeaderError),
    /// An error parsing the data.
    ParseData(Box<dyn Error + Send + Sync + 'static>),
    /// Overflow while computing the length of the array (in units of bytes or
    /// the number of elements) from the shape described in the file header.
    LengthOverflow,
    /// An error caused by incorrect `Dimension` type.
    ///
    /// The first field is the `NDIM` of the requested `Dimension` type, and
    /// the second is the number of dimensions of the array in the file.
    WrongNdim(Option<usize>, usize),
    /// The type descriptor does not match the element type.
    ///
    /// The payload is the descriptor in question.
    WrongDescriptor(PyValue),
    /// The file does not contain all the data described in the header.
    MissingData,
    /// Extra bytes are present between the end of the data and the end of the
    /// file.
    ///
    /// The payload is the number of extra bytes.
    ExtraBytes(usize),
}
460
461impl Error for ReadNpyError {
462    fn source(&self) -> Option<&(dyn Error + 'static)> {
463        match self {
464            ReadNpyError::Io(err) => Some(err),
465            ReadNpyError::ParseHeader(err) => Some(err),
466            ReadNpyError::ParseData(err) => Some(&**err),
467            ReadNpyError::LengthOverflow => None,
468            ReadNpyError::WrongNdim(_, _) => None,
469            ReadNpyError::WrongDescriptor(_) => None,
470            ReadNpyError::MissingData => None,
471            ReadNpyError::ExtraBytes(_) => None,
472        }
473    }
474}
475
476impl fmt::Display for ReadNpyError {
477    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
478        match self {
479            ReadNpyError::Io(err) => write!(f, "I/O error: {}", err),
480            ReadNpyError::ParseHeader(err) => write!(f, "error parsing header: {}", err),
481            ReadNpyError::ParseData(err) => write!(f, "error parsing data: {}", err),
482            ReadNpyError::LengthOverflow => write!(f, "overflow computing length from shape"),
483            ReadNpyError::WrongNdim(expected, actual) => write!(
484                f,
485                "ndim {} of array did not match Dimension type with NDIM = {:?}",
486                actual, expected
487            ),
488            ReadNpyError::WrongDescriptor(desc) => {
489                write!(f, "incorrect descriptor ({}) for this type", desc)
490            }
491            ReadNpyError::MissingData => write!(f, "reached EOF before reading all data"),
492            ReadNpyError::ExtraBytes(num_extra_bytes) => {
493                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
494            }
495        }
496    }
497}
498
499impl From<io::Error> for ReadNpyError {
500    fn from(err: io::Error) -> ReadNpyError {
501        ReadNpyError::Io(err)
502    }
503}
504
505impl From<ReadHeaderError> for ReadNpyError {
506    fn from(err: ReadHeaderError) -> ReadNpyError {
507        match err {
508            ReadHeaderError::Io(err) => ReadNpyError::Io(err),
509            ReadHeaderError::Parse(err) => ReadNpyError::ParseHeader(err),
510        }
511    }
512}
513
514impl From<ParseHeaderError> for ReadNpyError {
515    fn from(err: ParseHeaderError) -> ReadNpyError {
516        ReadNpyError::ParseHeader(err)
517    }
518}
519
520impl From<ReadDataError> for ReadNpyError {
521    fn from(err: ReadDataError) -> ReadNpyError {
522        match err {
523            ReadDataError::Io(err) => ReadNpyError::Io(err),
524            ReadDataError::WrongDescriptor(desc) => ReadNpyError::WrongDescriptor(desc),
525            ReadDataError::MissingData => ReadNpyError::MissingData,
526            ReadDataError::ExtraBytes(nbytes) => ReadNpyError::ExtraBytes(nbytes),
527            ReadDataError::ParseData(err) => ReadNpyError::ParseData(err),
528        }
529    }
530}
531
/// Extension trait for reading `Array` from `.npy` files.
///
/// # Example
///
/// ```
/// use ndarray::Array2;
/// use ndarray_npy::ReadNpyExt;
/// use std::fs::File;
/// # use ndarray_npy::ReadNpyError;
///
/// let reader = File::open("resources/array.npy")?;
/// let arr = Array2::<i32>::read_npy(reader)?;
/// # println!("arr = {}", arr);
/// # Ok::<_, ReadNpyError>(())
/// ```
pub trait ReadNpyExt: Sized {
    /// Reads the array from `reader` in [`.npy`
    /// format](https://docs.scipy.org/doc/numpy/reference/generated/numpy.lib.format.html).
    ///
    /// This function is the Rust equivalent of
    /// [`numpy.load`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.load.html)
    /// for `.npy` files.
    ///
    /// # Errors
    ///
    /// Returns a [`ReadNpyError`] if reading the header or the data fails.
    fn read_npy<R: io::Read>(reader: R) -> Result<Self, ReadNpyError>;
}
556
557impl<A, S, D> ReadNpyExt for ArrayBase<S, D>
558where
559    A: ReadableElement,
560    S: DataOwned<Elem = A>,
561    D: Dimension,
562{
563    fn read_npy<R: io::Read>(mut reader: R) -> Result<Self, ReadNpyError> {
564        let header = Header::from_reader(&mut reader)?;
565        let shape = header.shape.into_dimension();
566        let ndim = shape.ndim();
567        let len = shape_length_checked::<A>(&shape).ok_or(ReadNpyError::LengthOverflow)?;
568        let data = A::read_to_end_exact_vec(&mut reader, &header.type_descriptor, len)?;
569        ArrayBase::from_shape_vec(shape.set_f(header.fortran_order), data)
570            .unwrap()
571            .into_dimensionality()
572            .map_err(|_| ReadNpyError::WrongNdim(D::NDIM, ndim))
573    }
574}
575
/// An error viewing a `.npy` file.
///
/// `From` conversions exist for `ReadHeaderError`, `ParseHeaderError`, and
/// [`ViewDataError`], so those errors can be propagated with `?` in functions
/// returning this type.
#[derive(Debug)]
#[non_exhaustive]
pub enum ViewNpyError {
    /// An error caused by I/O.
    Io(io::Error),
    /// An error parsing the file header.
    ParseHeader(ParseHeaderError),
    /// Some of the data is invalid for the element type.
    InvalidData(Box<dyn Error + Send + Sync + 'static>),
    /// Overflow while computing the length of the array (in units of bytes or
    /// the number of elements) from the shape described in the file header.
    LengthOverflow,
    /// An error caused by incorrect `Dimension` type.
    ///
    /// The first field is the `NDIM` of the requested `Dimension` type, and
    /// the second is the number of dimensions of the array in the file.
    WrongNdim(Option<usize>, usize),
    /// The type descriptor does not match the element type.
    ///
    /// The payload is the descriptor in question.
    WrongDescriptor(PyValue),
    /// The type descriptor does not match the native endianness.
    NonNativeEndian,
    /// The start of the data is not properly aligned for the element type.
    MisalignedData,
    /// The file does not contain all the data described in the header.
    ///
    /// The payload is the number of missing bytes.
    MissingBytes(usize),
    /// Extra bytes are present between the end of the data and the end of the
    /// file.
    ///
    /// The payload is the number of extra bytes.
    ExtraBytes(usize),
}
603
604impl Error for ViewNpyError {
605    fn source(&self) -> Option<&(dyn Error + 'static)> {
606        match self {
607            ViewNpyError::Io(err) => Some(err),
608            ViewNpyError::ParseHeader(err) => Some(err),
609            ViewNpyError::InvalidData(err) => Some(&**err),
610            ViewNpyError::LengthOverflow => None,
611            ViewNpyError::WrongNdim(_, _) => None,
612            ViewNpyError::WrongDescriptor(_) => None,
613            ViewNpyError::NonNativeEndian => None,
614            ViewNpyError::MisalignedData => None,
615            ViewNpyError::MissingBytes(_) => None,
616            ViewNpyError::ExtraBytes(_) => None,
617        }
618    }
619}
620
621impl fmt::Display for ViewNpyError {
622    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
623        match self {
624            ViewNpyError::Io(err) => write!(f, "I/O error: {}", err),
625            ViewNpyError::ParseHeader(err) => write!(f, "error parsing header: {}", err),
626            ViewNpyError::InvalidData(err) => write!(f, "invalid data for element type: {}", err),
627            ViewNpyError::LengthOverflow => write!(f, "overflow computing length from shape"),
628            ViewNpyError::WrongNdim(expected, actual) => write!(
629                f,
630                "ndim {} of array did not match Dimension type with NDIM = {:?}",
631                actual, expected
632            ),
633            ViewNpyError::WrongDescriptor(desc) => {
634                write!(f, "incorrect descriptor ({}) for this type", desc)
635            }
636            ViewNpyError::NonNativeEndian => {
637                write!(f, "descriptor does not match native endianness")
638            }
639            ViewNpyError::MisalignedData => write!(
640                f,
641                "start of data is not properly aligned for the element type"
642            ),
643            ViewNpyError::MissingBytes(num_missing_bytes) => write!(
644                f,
645                "missing {} bytes of data specified in header",
646                num_missing_bytes
647            ),
648            ViewNpyError::ExtraBytes(num_extra_bytes) => {
649                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
650            }
651        }
652    }
653}
654
655impl From<ReadHeaderError> for ViewNpyError {
656    fn from(err: ReadHeaderError) -> ViewNpyError {
657        match err {
658            ReadHeaderError::Io(err) => ViewNpyError::Io(err),
659            ReadHeaderError::Parse(err) => ViewNpyError::ParseHeader(err),
660        }
661    }
662}
663
664impl From<ParseHeaderError> for ViewNpyError {
665    fn from(err: ParseHeaderError) -> ViewNpyError {
666        ViewNpyError::ParseHeader(err)
667    }
668}
669
670impl From<ViewDataError> for ViewNpyError {
671    fn from(err: ViewDataError) -> ViewNpyError {
672        match err {
673            ViewDataError::WrongDescriptor(desc) => ViewNpyError::WrongDescriptor(desc),
674            ViewDataError::NonNativeEndian => ViewNpyError::NonNativeEndian,
675            ViewDataError::Misaligned => ViewNpyError::MisalignedData,
676            ViewDataError::MissingBytes(nbytes) => ViewNpyError::MissingBytes(nbytes),
677            ViewDataError::ExtraBytes(nbytes) => ViewNpyError::ExtraBytes(nbytes),
678            ViewDataError::InvalidData(err) => ViewNpyError::InvalidData(err),
679        }
680    }
681}
682
/// Extension trait for creating an [`ArrayView`] from a buffer containing an
/// `.npy` file.
///
/// The primary use-case for this is viewing a memory-mapped `.npy` file.
///
/// # Notes
///
/// - For types for which not all bit patterns are valid, such as `bool`, the
///   implementation iterates over all of the elements when creating the view
///   to ensure they have a valid bit pattern.
///
/// - The data in the buffer must be properly aligned for the element type.
///   Typically, this should not be a concern for memory-mapped files (unless
///   an option like `MAP_FIXED` is used), since memory mappings are usually
///   aligned to a page boundary, and the `.npy` format has padding such that
///   the header size is a multiple of 64 bytes.
///
/// # Example
///
/// This is an example of opening a readonly memory-mapped file as an
/// [`ArrayView`].
///
/// This example uses the [`memmap2`](https://crates.io/crates/memmap2) crate
/// because that appears to be the best-maintained memory-mapping crate at the
/// moment, but `view_npy` takes a `&[u8]` instead of a file so that you can
/// use the memory-mapping crate you're most comfortable with.
///
/// ```
/// # if !cfg!(miri) { // Miri doesn't support mmap.
/// use memmap2::Mmap;
/// use ndarray::ArrayView2;
/// use ndarray_npy::ViewNpyExt;
/// use std::fs::File;
///
/// let file = File::open("resources/array.npy")?;
/// let mmap = unsafe { Mmap::map(&file)? };
/// let view = ArrayView2::<i32>::view_npy(&mmap)?;
/// # println!("view = {}", view);
/// # }
/// # Ok::<_, Box<dyn std::error::Error>>(())
/// ```
pub trait ViewNpyExt<'a>: Sized {
    /// Creates an `ArrayView` from a buffer containing an `.npy` file.
    ///
    /// The returned view borrows from `buf` for the lifetime `'a`.
    ///
    /// # Errors
    ///
    /// Returns a [`ViewNpyError`] if the view cannot be created.
    fn view_npy(buf: &'a [u8]) -> Result<Self, ViewNpyError>;
}
728
/// Extension trait for creating an [`ArrayViewMut`] from a mutable buffer
/// containing an `.npy` file.
///
/// The primary use-case for this is modifying a memory-mapped `.npy` file.
/// Modifying the elements in the view will modify the file. Modifying the
/// shape/strides of the view will *not* modify the shape/strides of the array
/// in the file.
///
/// # Notes
///
/// - For types for which not all bit patterns are valid, such as `bool`, the
///   implementation iterates over all of the elements when creating the view
///   to ensure they have a valid bit pattern.
///
/// - The data in the buffer must be properly aligned for the element type.
///   Typically, this should not be a concern for memory-mapped files (unless
///   an option like `MAP_FIXED` is used), since memory mappings are usually
///   aligned to a page boundary, and the `.npy` format has padding such that
///   the header size is a multiple of 64 bytes.
///
/// # Example
///
/// This is an example of opening a writable memory-mapped file as an
/// [`ArrayViewMut`]. Changes to the data in the view will modify the
/// underlying file.
///
/// This example uses the [`memmap2`](https://crates.io/crates/memmap2) crate
/// because that appears to be the best-maintained memory-mapping crate at the
/// moment, but `view_mut_npy` takes a `&mut [u8]` instead of a file so that
/// you can use the memory-mapping crate you're most comfortable with.
///
/// ```
/// # if !cfg!(miri) { // Miri doesn't support mmap.
/// use memmap2::MmapMut;
/// use ndarray::ArrayViewMut2;
/// use ndarray_npy::ViewMutNpyExt;
/// use std::fs;
///
/// let file = fs::OpenOptions::new()
///     .read(true)
///     .write(true)
///     .open("resources/array.npy")?;
/// let mut mmap = unsafe { MmapMut::map_mut(&file)? };
/// let view_mut = ArrayViewMut2::<i32>::view_mut_npy(&mut mmap)?;
/// # println!("view_mut = {}", view_mut);
/// # }
/// # Ok::<_, Box<dyn std::error::Error>>(())
/// ```
pub trait ViewMutNpyExt<'a>: Sized {
    /// Creates an `ArrayViewMut` from a mutable buffer containing an `.npy`
    /// file.
    ///
    /// The returned view mutably borrows from `buf` for the lifetime `'a`.
    ///
    /// # Errors
    ///
    /// Returns a [`ViewNpyError`] if the view cannot be created.
    fn view_mut_npy(buf: &'a mut [u8]) -> Result<Self, ViewNpyError>;
}
782
783impl<'a, A, D> ViewNpyExt<'a> for ArrayView<'a, A, D>
784where
785    A: ViewElement,
786    D: Dimension,
787{
788    fn view_npy(buf: &'a [u8]) -> Result<Self, ViewNpyError> {
789        let mut reader = buf;
790        let header = Header::from_reader(&mut reader)?;
791        let shape = header.shape.into_dimension();
792        let ndim = shape.ndim();
793        let len = shape_length_checked::<A>(&shape).ok_or(ViewNpyError::LengthOverflow)?;
794        let data = A::bytes_as_slice(reader, &header.type_descriptor, len)?;
795        ArrayView::from_shape(shape.set_f(header.fortran_order), data)
796            .unwrap()
797            .into_dimensionality()
798            .map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))
799    }
800}
801
802impl<'a, A, D> ViewMutNpyExt<'a> for ArrayViewMut<'a, A, D>
803where
804    A: ViewMutElement,
805    D: Dimension,
806{
807    fn view_mut_npy(buf: &'a mut [u8]) -> Result<Self, ViewNpyError> {
808        let mut reader = &*buf;
809        let header = Header::from_reader(&mut reader)?;
810        let shape = header.shape.into_dimension();
811        let ndim = shape.ndim();
812        let len = shape_length_checked::<A>(&shape).ok_or(ViewNpyError::LengthOverflow)?;
813        let mid = buf.len() - reader.len();
814        let data = A::bytes_as_mut_slice(&mut buf[mid..], &header.type_descriptor, len)?;
815        ArrayViewMut::from_shape(shape.set_f(header.fortran_order), data)
816            .unwrap()
817            .into_dimensionality()
818            .map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))
819    }
820}
821
/// An error viewing array data.
///
/// It converts into [`ViewNpyError`] through the `From` impl; `Misaligned`
/// maps to `ViewNpyError::MisalignedData`, and every other variant maps to
/// the variant of the same name.
#[derive(Debug)]
#[non_exhaustive]
pub enum ViewDataError {
    /// The type descriptor does not match the element type.
    ///
    /// The payload is the descriptor in question.
    WrongDescriptor(PyValue),
    /// The type descriptor does not match the native endianness.
    NonNativeEndian,
    /// The start of the data is not properly aligned for the element type.
    Misaligned,
    /// The file does not contain all the data described in the header.
    ///
    /// The payload is the number of missing bytes.
    MissingBytes(usize),
    /// Extra bytes are present between the end of the data and the end of the
    /// file.
    ///
    /// The payload is the number of extra bytes.
    ExtraBytes(usize),
    /// Some of the data is invalid for the element type.
    InvalidData(Box<dyn Error + Send + Sync + 'static>),
}
840
841impl Error for ViewDataError {
842    fn source(&self) -> Option<&(dyn Error + 'static)> {
843        match self {
844            ViewDataError::WrongDescriptor(_) => None,
845            ViewDataError::NonNativeEndian => None,
846            ViewDataError::Misaligned => None,
847            ViewDataError::MissingBytes(_) => None,
848            ViewDataError::ExtraBytes(_) => None,
849            ViewDataError::InvalidData(err) => Some(&**err),
850        }
851    }
852}
853
854impl fmt::Display for ViewDataError {
855    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
856        match self {
857            ViewDataError::WrongDescriptor(desc) => {
858                write!(f, "incorrect descriptor ({}) for this type", desc)
859            }
860            ViewDataError::NonNativeEndian => {
861                write!(f, "descriptor does not match native endianness")
862            }
863            ViewDataError::Misaligned => write!(
864                f,
865                "start of data is not properly aligned for the element type"
866            ),
867            ViewDataError::MissingBytes(num_missing_bytes) => write!(
868                f,
869                "missing {} bytes of data specified in header",
870                num_missing_bytes
871            ),
872            ViewDataError::ExtraBytes(num_extra_bytes) => {
873                write!(f, "file had {} extra bytes before EOF", num_extra_bytes)
874            }
875            ViewDataError::InvalidData(err) => write!(f, "invalid data for element type: {}", err),
876        }
877    }
878}
879
880/// An array element type that can be viewed (without copying) in an `.npy`
881/// file.
882pub trait ViewElement: Sized {
883    /// Casts `bytes` into a slice of elements of length `len`.
884    ///
885    /// Returns `Err(_)` in at least the following cases:
886    ///
887    ///   * if the `type_desc` does not match `Self` with native endianness
888    ///   * if the `bytes` slice is misaligned for elements of type `Self`
889    ///   * if the `bytes` slice is too short for `len` elements
890    ///   * if the `bytes` slice has extra bytes after `len` elements
891    ///
892    /// May panic if `len * size_of::<Self>()` overflows.
893    fn bytes_as_slice<'a>(
894        bytes: &'a [u8],
895        type_desc: &PyValue,
896        len: usize,
897    ) -> Result<&'a [Self], ViewDataError>;
898}
899
900/// An array element type that can be mutably viewed (without copying) in an
901/// `.npy` file.
902pub trait ViewMutElement: Sized {
903    /// Casts `bytes` into a mutable slice of elements of length `len`.
904    ///
905    /// Returns `Err(_)` in at least the following cases:
906    ///
907    ///   * if the `type_desc` does not match `Self` with native endianness
908    ///   * if the `bytes` slice is misaligned for elements of type `Self`
909    ///   * if the `bytes` slice is too short for `len` elements
910    ///   * if the `bytes` slice has extra bytes after `len` elements
911    ///
912    /// May panic if `len * size_of::<Self>()` overflows.
913    fn bytes_as_mut_slice<'a>(
914        bytes: &'a mut [u8],
915        type_desc: &PyValue,
916        len: usize,
917    ) -> Result<&'a mut [Self], ViewDataError>;
918}
919
920/// Computes the length associated with the shape (i.e. the product of the axis
921/// lengths), where the element type is `T`.
922///
923/// Returns `None` if the number of elements or the length in bytes would
924/// overflow `isize`.
925fn shape_length_checked<T>(shape: &IxDyn) -> Option<usize> {
926    let len = shape.size_checked()?;
927    if len > isize::MAX as usize {
928        return None;
929    }
930    let bytes = len.checked_mul(mem::size_of::<T>())?;
931    if bytes > isize::MAX as usize {
932        return None;
933    }
934    Some(len)
935}