// Source: ferray_io/npy/header.rs

// ferray-io: .npy header parsing and writing
//
// The .npy header is a Python dict literal with keys 'descr', 'fortran_order', 'shape'.
// Example: "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }"
5
6use ferray_core::dtype::DType;
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use super::dtype_parse::{self, Endianness};
10use crate::format;
11
12/// Parsed .npy file header.
13#[derive(Debug, Clone)]
14pub struct NpyHeader {
15    /// The dtype descriptor string (e.g., "<f8").
16    pub descr: String,
17    /// Parsed dtype.
18    pub dtype: DType,
19    /// Parsed endianness.
20    pub endianness: Endianness,
21    /// Whether the data is stored in Fortran (column-major) order.
22    pub fortran_order: bool,
23    /// Shape of the array.
24    pub shape: Vec<usize>,
25    /// Format version (major, minor).
26    pub version: (u8, u8),
27}
28
29/// Read and parse a .npy header from a reader.
30///
31/// After this function returns, the reader is positioned at the start of the data.
32pub fn read_header<R: std::io::Read>(reader: &mut R) -> FerrayResult<NpyHeader> {
33    // Read magic
34    let mut magic = [0u8; format::NPY_MAGIC_LEN];
35    reader
36        .read_exact(&mut magic)
37        .map_err(|e| FerrayError::io_error(format!("failed to read .npy magic: {e}")))?;
38
39    if magic != *format::NPY_MAGIC {
40        return Err(FerrayError::io_error(
41            "not a valid .npy file: bad magic number",
42        ));
43    }
44
45    // Read version
46    let mut version = [0u8; 2];
47    reader
48        .read_exact(&mut version)
49        .map_err(|e| FerrayError::io_error(format!("failed to read .npy version: {e}")))?;
50
51    let major = version[0];
52    let minor = version[1];
53
54    if !matches!((major, minor), (1, 0) | (2, 0) | (3, 0)) {
55        return Err(FerrayError::io_error(format!(
56            "unsupported .npy format version {major}.{minor}"
57        )));
58    }
59
60    // Read header length
61    let header_len = if major == 1 {
62        let mut buf = [0u8; 2];
63        reader
64            .read_exact(&mut buf)
65            .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
66        u16::from_le_bytes(buf) as usize
67    } else {
68        let mut buf = [0u8; 4];
69        reader
70            .read_exact(&mut buf)
71            .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
72        let raw_len = u32::from_le_bytes(buf) as usize;
73        // Cap header length at 1 MB to prevent unbounded allocation from
74        // untrusted files. Legitimate .npy headers are typically < 1 KB.
75        const MAX_HEADER_LEN: usize = 1_048_576;
76        if raw_len > MAX_HEADER_LEN {
77            return Err(FerrayError::io_error(format!(
78                "header length {raw_len} exceeds maximum allowed size ({MAX_HEADER_LEN} bytes)"
79            )));
80        }
81        raw_len
82    };
83
84    // Read header string
85    let mut header_bytes = vec![0u8; header_len];
86    reader
87        .read_exact(&mut header_bytes)
88        .map_err(|e| FerrayError::io_error(format!("failed to read header: {e}")))?;
89
90    let header_str = std::str::from_utf8(&header_bytes)
91        .map_err(|e| FerrayError::io_error(format!("header is not valid UTF-8: {e}")))?;
92
93    // Parse the header dict
94    let (descr, fortran_order, shape) = parse_header_dict(header_str)?;
95    let (dtype, endianness) = dtype_parse::parse_dtype_str(&descr)?;
96
97    Ok(NpyHeader {
98        descr,
99        dtype,
100        endianness,
101        fortran_order,
102        shape,
103        version: (major, minor),
104    })
105}
106
107/// Write a .npy header to a writer. Returns the total preamble+header size.
108pub fn write_header<W: std::io::Write>(
109    writer: &mut W,
110    dtype: DType,
111    shape: &[usize],
112    fortran_order: bool,
113) -> FerrayResult<()> {
114    let descr = dtype_parse::dtype_to_native_descr(dtype)?;
115    let fortran_str = if fortran_order { "True" } else { "False" };
116
117    let shape_str = format_shape(shape);
118
119    let dict =
120        format!("{{'descr': '{descr}', 'fortran_order': {fortran_str}, 'shape': {shape_str}, }}");
121
122    // Try version 1.0 first (header length fits in u16)
123    // Preamble: magic(6) + version(2) + header_len(2) = 10 for v1
124    // Preamble: magic(6) + version(2) + header_len(4) = 12 for v2
125    let preamble_v1 = format::NPY_MAGIC_LEN + 2 + 2; // 10
126    let padding_needed_v1 = compute_padding(preamble_v1 + dict.len() + 1); // +1 for newline
127    let total_header_v1 = dict.len() + padding_needed_v1 + 1;
128
129    if total_header_v1 <= format::MAX_HEADER_LEN_V1 {
130        // Version 1.0
131        writer.write_all(format::NPY_MAGIC)?;
132        writer.write_all(&[1, 0])?;
133        writer.write_all(&(total_header_v1 as u16).to_le_bytes())?;
134        writer.write_all(dict.as_bytes())?;
135        write_padding(writer, padding_needed_v1)?;
136        writer.write_all(b"\n")?;
137    } else {
138        // Version 2.0
139        let preamble_v2 = format::NPY_MAGIC_LEN + 2 + 4; // 12
140        let padding_needed_v2 = compute_padding(preamble_v2 + dict.len() + 1);
141        let total_header_v2 = dict.len() + padding_needed_v2 + 1;
142
143        writer.write_all(format::NPY_MAGIC)?;
144        writer.write_all(&[2, 0])?;
145        writer.write_all(&(total_header_v2 as u32).to_le_bytes())?;
146        writer.write_all(dict.as_bytes())?;
147        write_padding(writer, padding_needed_v2)?;
148        writer.write_all(b"\n")?;
149    }
150
151    Ok(())
152}
153
154/// Compute the header size (preamble + header dict + padding + newline) for reading purposes.
155/// Returns the byte offset where data begins.
156pub fn compute_data_offset(version: (u8, u8), header_len: usize) -> usize {
157    let preamble = format::NPY_MAGIC_LEN + 2 + if version.0 == 1 { 2 } else { 4 };
158    preamble + header_len
159}
160
161fn compute_padding(current_total: usize) -> usize {
162    let remainder = current_total % format::HEADER_ALIGNMENT;
163    if remainder == 0 {
164        0
165    } else {
166        format::HEADER_ALIGNMENT - remainder
167    }
168}
169
170fn write_padding<W: std::io::Write>(writer: &mut W, count: usize) -> FerrayResult<()> {
171    for _ in 0..count {
172        writer.write_all(b" ")?;
173    }
174    Ok(())
175}
176
/// Render `shape` as a Python tuple literal: `()`, `(5,)`, `(3, 4)`.
fn format_shape(shape: &[usize]) -> String {
    let dims = shape
        .iter()
        .map(usize::to_string)
        .collect::<Vec<_>>()
        .join(", ");
    // A one-element Python tuple needs a trailing comma; `(5)` would be a plain int.
    let trailing_comma = if shape.len() == 1 { "," } else { "" };
    format!("({dims}{trailing_comma})")
}
187
188/// Parse the Python dict-like header string.
189///
190/// We do simple string parsing here rather than pulling in a full Python parser.
191/// The format is well-defined: `{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }`
192fn parse_header_dict(header: &str) -> FerrayResult<(String, bool, Vec<usize>)> {
193    let header = header.trim();
194
195    // Strip outer braces
196    let inner = header
197        .strip_prefix('{')
198        .and_then(|s| s.strip_suffix('}'))
199        .ok_or_else(|| FerrayError::io_error("header dict missing braces"))?
200        .trim();
201
202    let descr = extract_string_value(inner, "descr")?;
203    let fortran_order = extract_bool_value(inner, "fortran_order")?;
204    let shape = extract_shape_value(inner, "shape")?;
205
206    Ok((descr, fortran_order, shape))
207}
208
209/// Extract a string value for a given key from the dict body.
210fn extract_string_value(dict_body: &str, key: &str) -> FerrayResult<String> {
211    // Look for 'key': 'value'
212    let pattern = format!("'{key}':");
213    let pos = dict_body
214        .find(&pattern)
215        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
216
217    let after_key = &dict_body[pos + pattern.len()..].trim_start();
218
219    // Find the opening quote
220    let quote_char = after_key
221        .as_bytes()
222        .first()
223        .ok_or_else(|| FerrayError::io_error(format!("missing value for key '{key}'")))?;
224
225    if *quote_char != b'\'' && *quote_char != b'"' {
226        return Err(FerrayError::io_error(format!(
227            "expected string value for key '{key}'"
228        )));
229    }
230
231    let qc = *quote_char as char;
232    let value_start = &after_key[1..];
233    let end = value_start
234        .find(qc)
235        .ok_or_else(|| FerrayError::io_error(format!("unterminated string for key '{key}'")))?;
236
237    Ok(value_start[..end].to_string())
238}
239
240/// Extract a boolean value for a given key from the dict body.
241fn extract_bool_value(dict_body: &str, key: &str) -> FerrayResult<bool> {
242    let pattern = format!("'{key}':");
243    let pos = dict_body
244        .find(&pattern)
245        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
246
247    let after_key = dict_body[pos + pattern.len()..].trim_start();
248
249    if after_key.starts_with("True") {
250        Ok(true)
251    } else if after_key.starts_with("False") {
252        Ok(false)
253    } else {
254        Err(FerrayError::io_error(format!(
255            "expected True/False for key '{key}'"
256        )))
257    }
258}
259
260/// Extract a tuple shape value for a given key from the dict body.
261fn extract_shape_value(dict_body: &str, key: &str) -> FerrayResult<Vec<usize>> {
262    let pattern = format!("'{key}':");
263    let pos = dict_body
264        .find(&pattern)
265        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
266
267    let after_key = dict_body[pos + pattern.len()..].trim_start();
268
269    // Find the opening paren
270    if !after_key.starts_with('(') {
271        return Err(FerrayError::io_error(format!(
272            "expected tuple for key '{key}'"
273        )));
274    }
275
276    let close = after_key
277        .find(')')
278        .ok_or_else(|| FerrayError::io_error(format!("unterminated tuple for key '{key}'")))?;
279
280    let tuple_inner = &after_key[1..close];
281    let tuple_inner = tuple_inner.trim();
282
283    if tuple_inner.is_empty() {
284        return Ok(vec![]);
285    }
286
287    let parts: FerrayResult<Vec<usize>> = tuple_inner
288        .split(',')
289        .filter(|s| !s.trim().is_empty())
290        .map(|s| {
291            s.trim()
292                .parse::<usize>()
293                .map_err(|e| FerrayError::io_error(format!("invalid shape dimension '{s}': {e}")))
294        })
295        .collect();
296
297    parts
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn parse_simple_header() {
        let header = "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<f8");
        assert!(!fortran);
        assert_eq!(shape, vec![3, 4]);
    }

    #[test]
    fn parse_1d_header() {
        let header = "{'descr': '<i4', 'fortran_order': False, 'shape': (10,), }";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<i4");
        assert!(!fortran);
        assert_eq!(shape, vec![10]);
    }

    #[test]
    fn parse_scalar_header() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (), }";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<f4");
        assert!(!fortran);
        assert!(shape.is_empty());
    }

    #[test]
    fn parse_fortran_order() {
        let header = "{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }";
        let (_, fortran, _) = parse_header_dict(header).unwrap();
        assert!(fortran);
    }

    #[test]
    fn parse_padded_header_with_newline() {
        // Written headers are space-padded and newline-terminated.
        let header = "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }          \n";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<f8");
        assert!(!fortran);
        assert_eq!(shape, vec![3, 4]);
    }

    #[test]
    fn parse_shape_with_trailing_comma() {
        let header = "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4,), }";
        let (_, _, shape) = parse_header_dict(header).unwrap();
        assert_eq!(shape, vec![3, 4]);
    }

    #[test]
    fn parse_rejects_missing_braces() {
        assert!(parse_header_dict("'descr': '<f8', 'fortran_order': False, 'shape': ()").is_err());
    }

    #[test]
    fn format_shape_empty() {
        assert_eq!(format_shape(&[]), "()");
    }

    #[test]
    fn format_shape_1d() {
        assert_eq!(format_shape(&[5]), "(5,)");
    }

    #[test]
    fn format_shape_2d() {
        assert_eq!(format_shape(&[3, 4]), "(3, 4)");
    }

    #[test]
    fn write_read_roundtrip() {
        let mut buf = Vec::new();
        write_header(&mut buf, DType::F64, &[3, 4], false).unwrap();

        let mut cursor = Cursor::new(buf);
        let header = read_header(&mut cursor).unwrap();

        assert_eq!(header.dtype, DType::F64);
        assert_eq!(header.shape, vec![3, 4]);
        assert!(!header.fortran_order);
    }

    #[test]
    fn roundtrip_preserves_shape_and_order() {
        let cases: [(&[usize], bool); 3] = [(&[], false), (&[7], true), (&[2, 3, 4], true)];
        for (shape, fortran) in cases {
            let mut buf = Vec::new();
            write_header(&mut buf, DType::F64, shape, fortran).unwrap();
            let header = read_header(&mut Cursor::new(buf)).unwrap();
            assert_eq!(header.dtype, DType::F64);
            assert_eq!(header.shape, shape.to_vec());
            assert_eq!(header.fortran_order, fortran);
            assert_eq!(header.version, (1, 0));
        }
    }

    #[test]
    fn header_alignment() {
        for shape in [&[3, 4][..], &[100, 200, 300], &[1]] {
            let mut buf = Vec::new();
            write_header(&mut buf, DType::F64, shape, false).unwrap();
            assert_eq!(
                buf.len() % format::HEADER_ALIGNMENT,
                0,
                "header not aligned for shape {shape:?}"
            );
        }
    }

    #[test]
    fn reader_is_left_at_start_of_data() {
        let mut buf = Vec::new();
        write_header(&mut buf, DType::F64, &[2], false).unwrap();
        let header_size = buf.len();
        buf.extend_from_slice(&[0xAB; 16]);

        let mut cursor = Cursor::new(buf);
        read_header(&mut cursor).unwrap();
        assert_eq!(cursor.position(), header_size as u64);
    }

    #[test]
    fn data_offset_matches_written_header() {
        let mut buf = Vec::new();
        write_header(&mut buf, DType::F64, &[3, 4], false).unwrap();
        let len_at = format::NPY_MAGIC_LEN + 2;
        let header_len = usize::from(u16::from_le_bytes([buf[len_at], buf[len_at + 1]]));
        assert_eq!(compute_data_offset((1, 0), header_len), buf.len());
    }

    #[test]
    fn large_header_uses_version_2() {
        // Roughly 90 KB of "1, " entries cannot be described by a u16 length field.
        let shape = vec![1usize; 30_000];
        let mut buf = Vec::new();
        write_header(&mut buf, DType::F64, &shape, false).unwrap();
        let version_at = format::NPY_MAGIC_LEN;
        assert_eq!(buf[version_at..version_at + 2], [2, 0]);
        assert_eq!(buf.len() % format::HEADER_ALIGNMENT, 0);

        let header = read_header(&mut Cursor::new(buf)).unwrap();
        assert_eq!(header.version, (2, 0));
        assert_eq!(header.shape, shape);
    }

    #[test]
    fn read_rejects_bad_magic() {
        assert!(read_header(&mut Cursor::new(vec![0u8; 32])).is_err());
    }

    #[test]
    fn read_rejects_unsupported_version() {
        let mut bytes = format::NPY_MAGIC.to_vec();
        bytes.extend_from_slice(&[9, 9, 0, 0]);
        assert!(read_header(&mut Cursor::new(bytes)).is_err());
    }

    #[test]
    fn read_rejects_truncated_header() {
        let mut buf = Vec::new();
        write_header(&mut buf, DType::F64, &[3, 4], false).unwrap();
        buf.truncate(buf.len() - 1);
        assert!(read_header(&mut Cursor::new(buf)).is_err());
    }
}