1use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
7use num_traits::ToPrimitive;
8use py_literal::{
9 FormatError as PyValueFormatError, ParseError as PyValueParseError, Value as PyValue,
10};
11use std::convert::TryFrom;
12use std::error::Error;
13use std::fmt;
14use std::io;
15
/// Byte sequence that every `.npy` file starts with (`\x93NUMPY`).
const MAGIC_STRING: &[u8] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];

/// The complete header (magic string, version, length field, array format,
/// padding and trailing newline) is padded to a multiple of this many bytes.
const HEADER_DIVISOR: usize = 64;
24
25#[derive(Debug)]
27pub enum ParseHeaderError {
28 MagicString,
30 Version { major: u8, minor: u8 },
32 HeaderLengthOverflow(u32),
34 NonAscii,
38 Utf8Parse(std::str::Utf8Error),
43 UnknownKey(PyValue),
45 MissingKey(String),
47 IllegalValue {
49 key: String,
51 value: PyValue,
53 },
54 DictParse(PyValueParseError),
56 MetaNotDict(PyValue),
58 MissingNewline,
60}
61
62impl Error for ParseHeaderError {
63 fn source(&self) -> Option<&(dyn Error + 'static)> {
64 use ParseHeaderError::*;
65 match self {
66 MagicString => None,
67 Version { .. } => None,
68 HeaderLengthOverflow(_) => None,
69 NonAscii => None,
70 Utf8Parse(err) => Some(err),
71 UnknownKey(_) => None,
72 MissingKey(_) => None,
73 IllegalValue { .. } => None,
74 DictParse(err) => Some(err),
75 MetaNotDict(_) => None,
76 MissingNewline => None,
77 }
78 }
79}
80
81impl fmt::Display for ParseHeaderError {
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 use ParseHeaderError::*;
84 match self {
85 MagicString => write!(f, "start does not match magic string"),
86 Version { major, minor } => write!(f, "unknown version number: {}.{}", major, minor),
87 HeaderLengthOverflow(header_len) => write!(f, "HEADER_LEN {} does not fit in `usize`", header_len),
88 NonAscii => write!(f, "non-ascii in array format string; this is not supported in .npy format versions 1.0 and 2.0"),
89 Utf8Parse(err) => write!(f, "error parsing array format string as UTF-8: {}", err),
90 UnknownKey(key) => write!(f, "unknown key: {}", key),
91 MissingKey(key) => write!(f, "missing key: {}", key),
92 IllegalValue { key, value } => write!(f, "illegal value for key {}: {}", key, value),
93 DictParse(err) => write!(f, "error parsing metadata dict: {}", err),
94 MetaNotDict(value) => write!(f, "metadata is not a dict: {}", value),
95 MissingNewline => write!(f, "newline missing at end of header"),
96 }
97 }
98}
99
100impl From<std::str::Utf8Error> for ParseHeaderError {
101 fn from(err: std::str::Utf8Error) -> ParseHeaderError {
102 ParseHeaderError::Utf8Parse(err)
103 }
104}
105
106impl From<PyValueParseError> for ParseHeaderError {
107 fn from(err: PyValueParseError) -> ParseHeaderError {
108 ParseHeaderError::DictParse(err)
109 }
110}
111
112#[derive(Debug)]
114pub enum ReadHeaderError {
115 Io(io::Error),
117 Parse(ParseHeaderError),
119}
120
121impl Error for ReadHeaderError {
122 fn source(&self) -> Option<&(dyn Error + 'static)> {
123 match self {
124 ReadHeaderError::Io(err) => Some(err),
125 ReadHeaderError::Parse(err) => Some(err),
126 }
127 }
128}
129
130impl fmt::Display for ReadHeaderError {
131 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132 match self {
133 ReadHeaderError::Io(err) => write!(f, "I/O error: {}", err),
134 ReadHeaderError::Parse(err) => write!(f, "error parsing header: {}", err),
135 }
136 }
137}
138
139impl From<io::Error> for ReadHeaderError {
140 fn from(err: io::Error) -> ReadHeaderError {
141 ReadHeaderError::Io(err)
142 }
143}
144
145impl From<ParseHeaderError> for ReadHeaderError {
146 fn from(err: ParseHeaderError) -> ReadHeaderError {
147 ReadHeaderError::Parse(err)
148 }
149}
150
151#[derive(Clone, Copy)]
152#[allow(non_camel_case_types)]
153enum Version {
154 V1_0,
155 V2_0,
156 V3_0,
157}
158
159impl Version {
160 const VERSION_NUM_BYTES: usize = 2;
163
164 fn from_bytes(bytes: &[u8]) -> Result<Self, ParseHeaderError> {
165 debug_assert_eq!(bytes.len(), Self::VERSION_NUM_BYTES);
166 match (bytes[0], bytes[1]) {
167 (0x01, 0x00) => Ok(Version::V1_0),
168 (0x02, 0x00) => Ok(Version::V2_0),
169 (0x03, 0x00) => Ok(Version::V3_0),
170 (major, minor) => Err(ParseHeaderError::Version { major, minor }),
171 }
172 }
173
174 fn major_version(self) -> u8 {
176 match self {
177 Version::V1_0 => 1,
178 Version::V2_0 => 2,
179 Version::V3_0 => 3,
180 }
181 }
182
183 fn minor_version(self) -> u8 {
185 match self {
186 Version::V1_0 => 0,
187 Version::V2_0 => 0,
188 Version::V3_0 => 0,
189 }
190 }
191
192 fn header_len_num_bytes(self) -> usize {
194 match self {
195 Version::V1_0 => 2,
196 Version::V2_0 | Version::V3_0 => 4,
197 }
198 }
199
200 fn read_header_len<R: io::Read>(self, reader: &mut R) -> Result<usize, ReadHeaderError> {
202 match self {
203 Version::V1_0 => Ok(usize::from(reader.read_u16::<LittleEndian>()?)),
204 Version::V2_0 | Version::V3_0 => {
205 let header_len: u32 = reader.read_u32::<LittleEndian>()?;
206 Ok(usize::try_from(header_len)
207 .map_err(|_| ParseHeaderError::HeaderLengthOverflow(header_len))?)
208 }
209 }
210 }
211
212 fn format_header_len(self, header_len: usize) -> Option<Vec<u8>> {
216 match self {
217 Version::V1_0 => {
218 let header_len: u16 = u16::try_from(header_len).ok()?;
219 let mut out = vec![0; self.header_len_num_bytes()];
220 LittleEndian::write_u16(&mut out, header_len);
221 Some(out)
222 }
223 Version::V2_0 | Version::V3_0 => {
224 let header_len: u32 = u32::try_from(header_len).ok()?;
225 let mut out = vec![0; self.header_len_num_bytes()];
226 LittleEndian::write_u32(&mut out, header_len);
227 Some(out)
228 }
229 }
230 }
231
232 fn compute_lengths(self, unpadded_arr_format: &[u8]) -> Option<HeaderLengthInfo> {
241 const NEWLINE_LEN: usize = 1;
243
244 let prefix_len: usize =
245 MAGIC_STRING.len() + Version::VERSION_NUM_BYTES + self.header_len_num_bytes();
246 let unpadded_total_len: usize = prefix_len
247 .checked_add(unpadded_arr_format.len())?
248 .checked_add(NEWLINE_LEN)?;
249 let padding_len: usize = HEADER_DIVISOR - unpadded_total_len % HEADER_DIVISOR;
250 let total_len: usize = unpadded_total_len.checked_add(padding_len)?;
251 let header_len: usize = total_len - prefix_len;
252 let formatted_header_len = self.format_header_len(header_len)?;
253 Some(HeaderLengthInfo {
254 total_len,
255 formatted_header_len,
256 })
257 }
258}
259
/// Lengths computed by `Version::compute_lengths` when formatting a header.
struct HeaderLengthInfo {
    /// Encoded little-endian `HEADER_LEN` field; its width depends on the version.
    formatted_header_len: Vec<u8>,
    /// Total header size in bytes: prefix, array format, padding and newline.
    total_len: usize,
}
268
269#[derive(Debug)]
271pub enum FormatHeaderError {
272 PyValue(PyValueFormatError),
274 HeaderTooLong,
277}
278
279impl Error for FormatHeaderError {
280 fn source(&self) -> Option<&(dyn Error + 'static)> {
281 match self {
282 FormatHeaderError::PyValue(err) => Some(err),
283 FormatHeaderError::HeaderTooLong => None,
284 }
285 }
286}
287
288impl fmt::Display for FormatHeaderError {
289 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
290 match self {
291 FormatHeaderError::PyValue(err) => write!(f, "error formatting Python value: {}", err),
292 FormatHeaderError::HeaderTooLong => write!(f, "the header is too long"),
293 }
294 }
295}
296
297impl From<PyValueFormatError> for FormatHeaderError {
298 fn from(err: PyValueFormatError) -> FormatHeaderError {
299 FormatHeaderError::PyValue(err)
300 }
301}
302
303#[derive(Debug)]
305pub enum WriteHeaderError {
306 Io(io::Error),
308 Format(FormatHeaderError),
310}
311
312impl Error for WriteHeaderError {
313 fn source(&self) -> Option<&(dyn Error + 'static)> {
314 match self {
315 WriteHeaderError::Io(err) => Some(err),
316 WriteHeaderError::Format(err) => Some(err),
317 }
318 }
319}
320
321impl fmt::Display for WriteHeaderError {
322 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
323 match self {
324 WriteHeaderError::Io(err) => write!(f, "I/O error: {}", err),
325 WriteHeaderError::Format(err) => write!(f, "error formatting header: {}", err),
326 }
327 }
328}
329
330impl From<io::Error> for WriteHeaderError {
331 fn from(err: io::Error) -> WriteHeaderError {
332 WriteHeaderError::Io(err)
333 }
334}
335
336impl From<FormatHeaderError> for WriteHeaderError {
337 fn from(err: FormatHeaderError) -> WriteHeaderError {
338 WriteHeaderError::Format(err)
339 }
340}
341
/// Layout of the array data, as recorded by the header's `fortran_order` key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Layout {
    /// Standard layout; stored as `'fortran_order': False`.
    Standard,
    /// Fortran layout; stored as `'fortran_order': True`.
    Fortran,
}

impl Layout {
    /// Returns `true` for [`Layout::Fortran`] and `false` for [`Layout::Standard`].
    #[inline]
    pub fn is_fortran(&self) -> bool {
        match self {
            Layout::Standard => false,
            Layout::Fortran => true,
        }
    }
}
358
359#[derive(Clone, Debug)]
361pub struct Header {
362 pub type_descriptor: PyValue,
365 pub layout: Layout,
367 pub shape: Vec<usize>,
369}
370
371impl fmt::Display for Header {
372 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
373 write!(f, "{}", self.to_py_value())
374 }
375}
376
377impl Header {
378 fn from_py_value(value: PyValue) -> Result<Self, ParseHeaderError> {
379 if let PyValue::Dict(dict) = value {
380 let mut type_descriptor: Option<PyValue> = None;
381 let mut is_fortran: Option<bool> = None;
382 let mut shape: Option<Vec<usize>> = None;
383 for (key, value) in dict {
384 match key {
385 PyValue::String(ref k) if k == "descr" => {
386 type_descriptor = Some(value);
387 }
388 PyValue::String(ref k) if k == "fortran_order" => {
389 if let PyValue::Boolean(b) = value {
390 is_fortran = Some(b);
391 } else {
392 return Err(ParseHeaderError::IllegalValue {
393 key: "fortran_order".to_owned(),
394 value,
395 });
396 }
397 }
398 PyValue::String(ref k) if k == "shape" => {
399 fn parse_shape(value: &PyValue) -> Option<Vec<usize>> {
400 value
401 .as_tuple()?
402 .iter()
403 .map(|elem| elem.as_integer()?.to_usize())
404 .collect()
405 }
406 if let Some(s) = parse_shape(&value) {
407 shape = Some(s);
408 } else {
409 return Err(ParseHeaderError::IllegalValue {
410 key: "shape".to_owned(),
411 value,
412 });
413 }
414 }
415 k => return Err(ParseHeaderError::UnknownKey(k)),
416 }
417 }
418 match (type_descriptor, is_fortran, shape) {
419 (Some(type_descriptor), Some(is_fortran), Some(shape)) => {
420 let layout = if is_fortran {
421 Layout::Fortran
422 } else {
423 Layout::Standard
424 };
425 Ok(Header {
426 type_descriptor,
427 layout,
428 shape,
429 })
430 }
431 (None, _, _) => Err(ParseHeaderError::MissingKey("descr".to_owned())),
432 (_, None, _) => Err(ParseHeaderError::MissingKey("fortran_order".to_owned())),
433 (_, _, None) => Err(ParseHeaderError::MissingKey("shaper".to_owned())),
434 }
435 } else {
436 Err(ParseHeaderError::MetaNotDict(value))
437 }
438 }
439
440 pub fn from_reader<R: io::Read>(reader: &mut R) -> Result<Self, ReadHeaderError> {
442 let mut buf = vec![0; MAGIC_STRING.len()];
444 reader.read_exact(&mut buf)?;
445 if buf != MAGIC_STRING {
446 return Err(ParseHeaderError::MagicString.into());
447 }
448
449 let mut buf = [0; Version::VERSION_NUM_BYTES];
451 reader.read_exact(&mut buf)?;
452 let version = Version::from_bytes(&buf)?;
453
454 let header_len = version.read_header_len(reader)?;
456
457 let mut buf = vec![0; header_len];
459 reader.read_exact(&mut buf)?;
460 let without_newline = match buf.split_last() {
461 Some((&b'\n', rest)) => rest,
462 Some(_) | None => return Err(ParseHeaderError::MissingNewline.into()),
463 };
464 let header_str = match version {
465 Version::V1_0 | Version::V2_0 => {
466 if without_newline.is_ascii() {
467 unsafe { std::str::from_utf8_unchecked(without_newline) }
469 } else {
470 return Err(ParseHeaderError::NonAscii.into());
471 }
472 }
473 Version::V3_0 => {
474 std::str::from_utf8(without_newline).map_err(ParseHeaderError::from)?
475 }
476 };
477 let arr_format: PyValue = header_str.parse().map_err(ParseHeaderError::from)?;
478 Ok(Header::from_py_value(arr_format)?)
479 }
480
481 fn to_py_value(&self) -> PyValue {
482 PyValue::Dict(vec![
483 (
484 PyValue::String("descr".into()),
485 self.type_descriptor.clone(),
486 ),
487 (
488 PyValue::String("fortran_order".into()),
489 PyValue::Boolean(self.layout.is_fortran()),
490 ),
491 (
492 PyValue::String("shape".into()),
493 PyValue::Tuple(
494 self.shape
495 .iter()
496 .map(|&elem| PyValue::Integer(elem.into()))
497 .collect(),
498 ),
499 ),
500 ])
501 }
502
503 pub fn to_bytes(&self) -> Result<Vec<u8>, FormatHeaderError> {
505 let mut arr_format = Vec::new();
507 self.to_py_value().write_ascii(&mut arr_format)?;
508
509 let (version, length_info) = [Version::V1_0, Version::V2_0]
512 .iter()
513 .find_map(|&version| Some((version, version.compute_lengths(&arr_format)?)))
514 .ok_or(FormatHeaderError::HeaderTooLong)?;
515
516 let mut out = Vec::with_capacity(length_info.total_len);
518 out.extend_from_slice(MAGIC_STRING);
519 out.push(version.major_version());
520 out.push(version.minor_version());
521 out.extend_from_slice(&length_info.formatted_header_len);
522 out.extend_from_slice(&arr_format);
523 out.resize(length_info.total_len - 1, b' ');
524 out.push(b'\n');
525
526 debug_assert_eq!(out.len(), length_info.total_len);
528 debug_assert_eq!(out.len() % HEADER_DIVISOR, 0);
529
530 Ok(out)
531 }
532
533 pub fn write<W: io::Write>(&self, mut writer: W) -> Result<(), WriteHeaderError> {
535 let bytes = self.to_bytes()?;
536 writer.write_all(&bytes)?;
537 Ok(())
538 }
539}