ndarray_npy/npy/
header.rs

1//! Types and methods for (de)serializing the header of an `.npy` file.
2//!
3//! In most cases, users do not need this module, since they can use the more convenient,
4//! higher-level functionality instead.
5
6use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
7use num_traits::ToPrimitive;
8use py_literal::{
9    FormatError as PyValueFormatError, ParseError as PyValueParseError, Value as PyValue,
10};
11use std::convert::TryFrom;
12use std::error::Error;
13use std::fmt;
14use std::io;
15
/// Byte sequence that every `.npy` file starts with (`0x93` followed by the
/// ASCII text `NUMPY`).
const MAGIC_STRING: &[u8] = b"\x93NUMPY";
18
/// Alignment of the complete header, in bytes.
///
/// The total header length (magic string + version number + header length
/// value + array format description + padding + final newline) must be a
/// multiple of this value.
// If this changes, update the docs of `ViewNpyExt` and `ViewMutNpyExt`.
const HEADER_DIVISOR: usize = 64;
24
25/// Error parsing an `.npy` header.
26#[derive(Debug)]
27pub enum ParseHeaderError {
28    /// The first several bytes are not the expected magic string.
29    MagicString,
30    /// The version number specified in the header is unsupported.
31    Version { major: u8, minor: u8 },
32    /// The `HEADER_LEN` doesn't fit in `usize`.
33    HeaderLengthOverflow(u32),
34    /// The array format string contains non-ASCII characters.
35    ///
36    /// This is an error for .npy format versions 1.0 and 2.0.
37    NonAscii,
38    /// Error parsing the array format string as UTF-8.
39    ///
40    /// This does not apply to .npy format versions 1.0 and 2.0, which require the array format
41    /// string to be ASCII.
42    Utf8Parse(std::str::Utf8Error),
43    /// The Python dictionary in the header contains an unexpected key.
44    UnknownKey(PyValue),
45    /// The Python dictionary in the header is missing an expected key.
46    MissingKey(String),
47    /// The value corresponding to an expected key is illegal (e.g., the wrong type).
48    IllegalValue {
49        /// The key in the header dictionary.
50        key: String,
51        /// The corresponding (illegal) value.
52        value: PyValue,
53    },
54    /// Error parsing the dictionary in the header.
55    DictParse(PyValueParseError),
56    /// The metadata in the header is not a dictionary.
57    MetaNotDict(PyValue),
58    /// There is no newline at the end of the header.
59    MissingNewline,
60}
61
62impl Error for ParseHeaderError {
63    fn source(&self) -> Option<&(dyn Error + 'static)> {
64        use ParseHeaderError::*;
65        match self {
66            MagicString => None,
67            Version { .. } => None,
68            HeaderLengthOverflow(_) => None,
69            NonAscii => None,
70            Utf8Parse(err) => Some(err),
71            UnknownKey(_) => None,
72            MissingKey(_) => None,
73            IllegalValue { .. } => None,
74            DictParse(err) => Some(err),
75            MetaNotDict(_) => None,
76            MissingNewline => None,
77        }
78    }
79}
80
81impl fmt::Display for ParseHeaderError {
82    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83        use ParseHeaderError::*;
84        match self {
85            MagicString => write!(f, "start does not match magic string"),
86            Version { major, minor } => write!(f, "unknown version number: {}.{}", major, minor),
87            HeaderLengthOverflow(header_len) => write!(f, "HEADER_LEN {} does not fit in `usize`", header_len),
88            NonAscii => write!(f, "non-ascii in array format string; this is not supported in .npy format versions 1.0 and 2.0"),
89            Utf8Parse(err) => write!(f, "error parsing array format string as UTF-8: {}", err),
90            UnknownKey(key) => write!(f, "unknown key: {}", key),
91            MissingKey(key) => write!(f, "missing key: {}", key),
92            IllegalValue { key, value } => write!(f, "illegal value for key {}: {}", key, value),
93            DictParse(err) => write!(f, "error parsing metadata dict: {}", err),
94            MetaNotDict(value) => write!(f, "metadata is not a dict: {}", value),
95            MissingNewline => write!(f, "newline missing at end of header"),
96        }
97    }
98}
99
100impl From<std::str::Utf8Error> for ParseHeaderError {
101    fn from(err: std::str::Utf8Error) -> ParseHeaderError {
102        ParseHeaderError::Utf8Parse(err)
103    }
104}
105
106impl From<PyValueParseError> for ParseHeaderError {
107    fn from(err: PyValueParseError) -> ParseHeaderError {
108        ParseHeaderError::DictParse(err)
109    }
110}
111
112/// Error reading an `.npy` header.
113#[derive(Debug)]
114pub enum ReadHeaderError {
115    /// I/O error.
116    Io(io::Error),
117    /// Error parsing the header.
118    Parse(ParseHeaderError),
119}
120
121impl Error for ReadHeaderError {
122    fn source(&self) -> Option<&(dyn Error + 'static)> {
123        match self {
124            ReadHeaderError::Io(err) => Some(err),
125            ReadHeaderError::Parse(err) => Some(err),
126        }
127    }
128}
129
130impl fmt::Display for ReadHeaderError {
131    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132        match self {
133            ReadHeaderError::Io(err) => write!(f, "I/O error: {}", err),
134            ReadHeaderError::Parse(err) => write!(f, "error parsing header: {}", err),
135        }
136    }
137}
138
139impl From<io::Error> for ReadHeaderError {
140    fn from(err: io::Error) -> ReadHeaderError {
141        ReadHeaderError::Io(err)
142    }
143}
144
145impl From<ParseHeaderError> for ReadHeaderError {
146    fn from(err: ParseHeaderError) -> ReadHeaderError {
147        ReadHeaderError::Parse(err)
148    }
149}
150
/// The `.npy` format versions this module can read.
#[derive(Clone, Copy)]
// Variant names mirror the `major.minor` version numbers, hence the underscore.
#[allow(non_camel_case_types)]
enum Version {
    /// Format version 1.0.
    V1_0,
    /// Format version 2.0.
    V2_0,
    /// Format version 3.0.
    V3_0,
}
158
159impl Version {
160    /// Number of bytes taken up by version number (1 byte for major version, 1
161    /// byte for minor version).
162    const VERSION_NUM_BYTES: usize = 2;
163
164    fn from_bytes(bytes: &[u8]) -> Result<Self, ParseHeaderError> {
165        debug_assert_eq!(bytes.len(), Self::VERSION_NUM_BYTES);
166        match (bytes[0], bytes[1]) {
167            (0x01, 0x00) => Ok(Version::V1_0),
168            (0x02, 0x00) => Ok(Version::V2_0),
169            (0x03, 0x00) => Ok(Version::V3_0),
170            (major, minor) => Err(ParseHeaderError::Version { major, minor }),
171        }
172    }
173
174    /// Major version number.
175    fn major_version(self) -> u8 {
176        match self {
177            Version::V1_0 => 1,
178            Version::V2_0 => 2,
179            Version::V3_0 => 3,
180        }
181    }
182
183    /// Major version number.
184    fn minor_version(self) -> u8 {
185        match self {
186            Version::V1_0 => 0,
187            Version::V2_0 => 0,
188            Version::V3_0 => 0,
189        }
190    }
191
192    /// Number of bytes in representation of header length.
193    fn header_len_num_bytes(self) -> usize {
194        match self {
195            Version::V1_0 => 2,
196            Version::V2_0 | Version::V3_0 => 4,
197        }
198    }
199
200    /// Read header length.
201    fn read_header_len<R: io::Read>(self, reader: &mut R) -> Result<usize, ReadHeaderError> {
202        match self {
203            Version::V1_0 => Ok(usize::from(reader.read_u16::<LittleEndian>()?)),
204            Version::V2_0 | Version::V3_0 => {
205                let header_len: u32 = reader.read_u32::<LittleEndian>()?;
206                Ok(usize::try_from(header_len)
207                    .map_err(|_| ParseHeaderError::HeaderLengthOverflow(header_len))?)
208            }
209        }
210    }
211
212    /// Format header length as bytes for writing to file.
213    ///
214    /// Returns `None` if the value of `header_len` is too large for this .npy version.
215    fn format_header_len(self, header_len: usize) -> Option<Vec<u8>> {
216        match self {
217            Version::V1_0 => {
218                let header_len: u16 = u16::try_from(header_len).ok()?;
219                let mut out = vec![0; self.header_len_num_bytes()];
220                LittleEndian::write_u16(&mut out, header_len);
221                Some(out)
222            }
223            Version::V2_0 | Version::V3_0 => {
224                let header_len: u32 = u32::try_from(header_len).ok()?;
225                let mut out = vec![0; self.header_len_num_bytes()];
226                LittleEndian::write_u32(&mut out, header_len);
227                Some(out)
228            }
229        }
230    }
231
232    /// Computes the total header length, formatted `HEADER_LEN` value, and
233    /// padding length for this .npy version.
234    ///
235    /// `unpadded_arr_format` is the Python literal describing the array
236    /// format, formatted as an ASCII string without any padding.
237    ///
238    /// Returns `None` if the total header length overflows `usize` or if the
239    /// value of `HEADER_LEN` is too large for this .npy version.
240    fn compute_lengths(self, unpadded_arr_format: &[u8]) -> Option<HeaderLengthInfo> {
241        /// Length of a '\n' char in bytes.
242        const NEWLINE_LEN: usize = 1;
243
244        let prefix_len: usize =
245            MAGIC_STRING.len() + Version::VERSION_NUM_BYTES + self.header_len_num_bytes();
246        let unpadded_total_len: usize = prefix_len
247            .checked_add(unpadded_arr_format.len())?
248            .checked_add(NEWLINE_LEN)?;
249        let padding_len: usize = HEADER_DIVISOR - unpadded_total_len % HEADER_DIVISOR;
250        let total_len: usize = unpadded_total_len.checked_add(padding_len)?;
251        let header_len: usize = total_len - prefix_len;
252        let formatted_header_len = self.format_header_len(header_len)?;
253        Some(HeaderLengthInfo {
254            total_len,
255            formatted_header_len,
256        })
257    }
258}
259
/// Length information needed to serialize a header.
struct HeaderLengthInfo {
    /// Length of the whole header in bytes: magic string, version number,
    /// header length value, array format description, padding, and final
    /// newline.
    total_len: usize,
    /// The `HEADER_LEN` value already encoded as bytes. `HEADER_LEN` counts
    /// the array format description, padding, and final newline.
    formatted_header_len: Vec<u8>,
}
268
269/// Error formatting an `.npy` header.
270#[derive(Debug)]
271pub enum FormatHeaderError {
272    /// Error formatting the header's metadata dictionary.
273    PyValue(PyValueFormatError),
274    /// The total header length overflows `usize`, or `HEADER_LEN` exceeds the
275    /// maximum encodable value.
276    HeaderTooLong,
277}
278
279impl Error for FormatHeaderError {
280    fn source(&self) -> Option<&(dyn Error + 'static)> {
281        match self {
282            FormatHeaderError::PyValue(err) => Some(err),
283            FormatHeaderError::HeaderTooLong => None,
284        }
285    }
286}
287
288impl fmt::Display for FormatHeaderError {
289    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
290        match self {
291            FormatHeaderError::PyValue(err) => write!(f, "error formatting Python value: {}", err),
292            FormatHeaderError::HeaderTooLong => write!(f, "the header is too long"),
293        }
294    }
295}
296
297impl From<PyValueFormatError> for FormatHeaderError {
298    fn from(err: PyValueFormatError) -> FormatHeaderError {
299        FormatHeaderError::PyValue(err)
300    }
301}
302
303/// Error writing an `.npy` header.
304#[derive(Debug)]
305pub enum WriteHeaderError {
306    /// I/O error.
307    Io(io::Error),
308    /// Error formatting the header.
309    Format(FormatHeaderError),
310}
311
312impl Error for WriteHeaderError {
313    fn source(&self) -> Option<&(dyn Error + 'static)> {
314        match self {
315            WriteHeaderError::Io(err) => Some(err),
316            WriteHeaderError::Format(err) => Some(err),
317        }
318    }
319}
320
321impl fmt::Display for WriteHeaderError {
322    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
323        match self {
324            WriteHeaderError::Io(err) => write!(f, "I/O error: {}", err),
325            WriteHeaderError::Format(err) => write!(f, "error formatting header: {}", err),
326        }
327    }
328}
329
330impl From<io::Error> for WriteHeaderError {
331    fn from(err: io::Error) -> WriteHeaderError {
332        WriteHeaderError::Io(err)
333    }
334}
335
336impl From<FormatHeaderError> for WriteHeaderError {
337    fn from(err: FormatHeaderError) -> WriteHeaderError {
338        WriteHeaderError::Format(err)
339    }
340}
341
/// Memory order of the array data stored in an `.npy` file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Layout {
    /// C order (standard layout).
    Standard,
    /// Fortran order.
    Fortran,
}
350
351impl Layout {
352    /// Returns `true` if the layout is [`Fortran`](Self::Fortran).
353    #[inline]
354    pub fn is_fortran(&self) -> bool {
355        matches!(*self, Layout::Fortran)
356    }
357}
358
359/// Header of an `.npy` file.
360#[derive(Clone, Debug)]
361pub struct Header {
362    /// A Python literal which can be passed as an argument to the `numpy.dtype` constructor to
363    /// create the array's dtype.
364    pub type_descriptor: PyValue,
365    /// The layout of the array.
366    pub layout: Layout,
367    /// The shape of the array.
368    pub shape: Vec<usize>,
369}
370
371impl fmt::Display for Header {
372    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
373        write!(f, "{}", self.to_py_value())
374    }
375}
376
377impl Header {
378    fn from_py_value(value: PyValue) -> Result<Self, ParseHeaderError> {
379        if let PyValue::Dict(dict) = value {
380            let mut type_descriptor: Option<PyValue> = None;
381            let mut is_fortran: Option<bool> = None;
382            let mut shape: Option<Vec<usize>> = None;
383            for (key, value) in dict {
384                match key {
385                    PyValue::String(ref k) if k == "descr" => {
386                        type_descriptor = Some(value);
387                    }
388                    PyValue::String(ref k) if k == "fortran_order" => {
389                        if let PyValue::Boolean(b) = value {
390                            is_fortran = Some(b);
391                        } else {
392                            return Err(ParseHeaderError::IllegalValue {
393                                key: "fortran_order".to_owned(),
394                                value,
395                            });
396                        }
397                    }
398                    PyValue::String(ref k) if k == "shape" => {
399                        fn parse_shape(value: &PyValue) -> Option<Vec<usize>> {
400                            value
401                                .as_tuple()?
402                                .iter()
403                                .map(|elem| elem.as_integer()?.to_usize())
404                                .collect()
405                        }
406                        if let Some(s) = parse_shape(&value) {
407                            shape = Some(s);
408                        } else {
409                            return Err(ParseHeaderError::IllegalValue {
410                                key: "shape".to_owned(),
411                                value,
412                            });
413                        }
414                    }
415                    k => return Err(ParseHeaderError::UnknownKey(k)),
416                }
417            }
418            match (type_descriptor, is_fortran, shape) {
419                (Some(type_descriptor), Some(is_fortran), Some(shape)) => {
420                    let layout = if is_fortran {
421                        Layout::Fortran
422                    } else {
423                        Layout::Standard
424                    };
425                    Ok(Header {
426                        type_descriptor,
427                        layout,
428                        shape,
429                    })
430                }
431                (None, _, _) => Err(ParseHeaderError::MissingKey("descr".to_owned())),
432                (_, None, _) => Err(ParseHeaderError::MissingKey("fortran_order".to_owned())),
433                (_, _, None) => Err(ParseHeaderError::MissingKey("shaper".to_owned())),
434            }
435        } else {
436            Err(ParseHeaderError::MetaNotDict(value))
437        }
438    }
439
440    /// Deserializes a header from the provided reader.
441    pub fn from_reader<R: io::Read>(reader: &mut R) -> Result<Self, ReadHeaderError> {
442        // Check for magic string.
443        let mut buf = vec![0; MAGIC_STRING.len()];
444        reader.read_exact(&mut buf)?;
445        if buf != MAGIC_STRING {
446            return Err(ParseHeaderError::MagicString.into());
447        }
448
449        // Get version number.
450        let mut buf = [0; Version::VERSION_NUM_BYTES];
451        reader.read_exact(&mut buf)?;
452        let version = Version::from_bytes(&buf)?;
453
454        // Get `HEADER_LEN`.
455        let header_len = version.read_header_len(reader)?;
456
457        // Parse the dictionary describing the array's format.
458        let mut buf = vec![0; header_len];
459        reader.read_exact(&mut buf)?;
460        let without_newline = match buf.split_last() {
461            Some((&b'\n', rest)) => rest,
462            Some(_) | None => return Err(ParseHeaderError::MissingNewline.into()),
463        };
464        let header_str = match version {
465            Version::V1_0 | Version::V2_0 => {
466                if without_newline.is_ascii() {
467                    // ASCII strings are always valid UTF-8.
468                    unsafe { std::str::from_utf8_unchecked(without_newline) }
469                } else {
470                    return Err(ParseHeaderError::NonAscii.into());
471                }
472            }
473            Version::V3_0 => {
474                std::str::from_utf8(without_newline).map_err(ParseHeaderError::from)?
475            }
476        };
477        let arr_format: PyValue = header_str.parse().map_err(ParseHeaderError::from)?;
478        Ok(Header::from_py_value(arr_format)?)
479    }
480
481    fn to_py_value(&self) -> PyValue {
482        PyValue::Dict(vec![
483            (
484                PyValue::String("descr".into()),
485                self.type_descriptor.clone(),
486            ),
487            (
488                PyValue::String("fortran_order".into()),
489                PyValue::Boolean(self.layout.is_fortran()),
490            ),
491            (
492                PyValue::String("shape".into()),
493                PyValue::Tuple(
494                    self.shape
495                        .iter()
496                        .map(|&elem| PyValue::Integer(elem.into()))
497                        .collect(),
498                ),
499            ),
500        ])
501    }
502
503    /// Returns the serialized representation of the header.
504    pub fn to_bytes(&self) -> Result<Vec<u8>, FormatHeaderError> {
505        // Metadata describing array's format as ASCII string.
506        let mut arr_format = Vec::new();
507        self.to_py_value().write_ascii(&mut arr_format)?;
508
509        // Determine appropriate version based on header length, and compute
510        // length information.
511        let (version, length_info) = [Version::V1_0, Version::V2_0]
512            .iter()
513            .find_map(|&version| Some((version, version.compute_lengths(&arr_format)?)))
514            .ok_or(FormatHeaderError::HeaderTooLong)?;
515
516        // Write the header.
517        let mut out = Vec::with_capacity(length_info.total_len);
518        out.extend_from_slice(MAGIC_STRING);
519        out.push(version.major_version());
520        out.push(version.minor_version());
521        out.extend_from_slice(&length_info.formatted_header_len);
522        out.extend_from_slice(&arr_format);
523        out.resize(length_info.total_len - 1, b' ');
524        out.push(b'\n');
525
526        // Verify the length of the header.
527        debug_assert_eq!(out.len(), length_info.total_len);
528        debug_assert_eq!(out.len() % HEADER_DIVISOR, 0);
529
530        Ok(out)
531    }
532
533    /// Writes the serialized representation of the header to the provided writer.
534    pub fn write<W: io::Write>(&self, mut writer: W) -> Result<(), WriteHeaderError> {
535        let bytes = self.to_bytes()?;
536        writer.write_all(&bytes)?;
537        Ok(())
538    }
539}