Skip to main content

ferray_io/npy/
header.rs

1// ferray-io: .npy header parsing and writing
2//
3// The .npy header is a Python dict literal with keys 'descr', 'fortran_order', 'shape'.
4// Example: "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }"
5
6use ferray_core::dtype::DType;
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use super::dtype_parse::{self, Endianness};
10use crate::format;
11
12/// Parsed .npy file header.
13#[derive(Debug, Clone)]
14pub struct NpyHeader {
15    /// The dtype descriptor string (e.g., "<f8").
16    pub descr: String,
17    /// Parsed dtype.
18    pub dtype: DType,
19    /// Parsed endianness.
20    pub endianness: Endianness,
21    /// Whether the data is stored in Fortran (column-major) order.
22    pub fortran_order: bool,
23    /// Shape of the array.
24    pub shape: Vec<usize>,
25    /// Format version (major, minor).
26    pub version: (u8, u8),
27}
28
29/// Read and parse a .npy header from a reader.
30///
31/// After this function returns, the reader is positioned at the start of the data.
32pub fn read_header<R: std::io::Read>(reader: &mut R) -> FerrayResult<NpyHeader> {
33    // Read magic
34    let mut magic = [0u8; format::NPY_MAGIC_LEN];
35    reader
36        .read_exact(&mut magic)
37        .map_err(|e| FerrayError::io_error(format!("failed to read .npy magic: {e}")))?;
38
39    if magic != *format::NPY_MAGIC {
40        return Err(FerrayError::io_error(
41            "not a valid .npy file: bad magic number",
42        ));
43    }
44
45    // Read version
46    let mut version = [0u8; 2];
47    reader
48        .read_exact(&mut version)
49        .map_err(|e| FerrayError::io_error(format!("failed to read .npy version: {e}")))?;
50
51    let major = version[0];
52    let minor = version[1];
53
54    if !matches!((major, minor), (1, 0) | (2, 0) | (3, 0)) {
55        return Err(FerrayError::io_error(format!(
56            "unsupported .npy format version {major}.{minor}"
57        )));
58    }
59
60    // Read header length
61    let header_len = if major == 1 {
62        let mut buf = [0u8; 2];
63        reader
64            .read_exact(&mut buf)
65            .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
66        u16::from_le_bytes(buf) as usize
67    } else {
68        let mut buf = [0u8; 4];
69        reader
70            .read_exact(&mut buf)
71            .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
72        let raw_len = u32::from_le_bytes(buf) as usize;
73        // Cap header length at 1 MB to prevent unbounded allocation from
74        // untrusted files. Legitimate .npy headers are typically < 1 KB.
75        const MAX_HEADER_LEN: usize = 1_048_576;
76        if raw_len > MAX_HEADER_LEN {
77            return Err(FerrayError::io_error(format!(
78                "header length {raw_len} exceeds maximum allowed size ({MAX_HEADER_LEN} bytes)"
79            )));
80        }
81        raw_len
82    };
83
84    // Read header string
85    let mut header_bytes = vec![0u8; header_len];
86    reader
87        .read_exact(&mut header_bytes)
88        .map_err(|e| FerrayError::io_error(format!("failed to read header: {e}")))?;
89
90    let header_str = std::str::from_utf8(&header_bytes)
91        .map_err(|e| FerrayError::io_error(format!("header is not valid UTF-8: {e}")))?;
92
93    // Parse the header dict
94    let (descr, fortran_order, shape) = parse_header_dict(header_str)?;
95    let (dtype, endianness) = dtype_parse::parse_dtype_str(&descr)?;
96
97    Ok(NpyHeader {
98        descr,
99        dtype,
100        endianness,
101        fortran_order,
102        shape,
103        version: (major, minor),
104    })
105}
106
107/// Write a .npy header to a writer. Returns the total preamble+header size.
108pub fn write_header<W: std::io::Write>(
109    writer: &mut W,
110    dtype: DType,
111    shape: &[usize],
112    fortran_order: bool,
113) -> FerrayResult<()> {
114    let descr = dtype_parse::dtype_to_native_descr(dtype)?;
115    let fortran_str = if fortran_order { "True" } else { "False" };
116
117    let shape_str = format_shape(shape);
118
119    let dict =
120        format!("{{'descr': '{descr}', 'fortran_order': {fortran_str}, 'shape': {shape_str}, }}");
121
122    // Try version 1.0 first (header length fits in u16)
123    // Preamble: magic(6) + version(2) + header_len(2) = 10 for v1
124    // Preamble: magic(6) + version(2) + header_len(4) = 12 for v2
125    let preamble_v1 = format::NPY_MAGIC_LEN + 2 + 2; // 10
126    let padding_needed_v1 = compute_padding(preamble_v1 + dict.len() + 1); // +1 for newline
127    let total_header_v1 = dict.len() + padding_needed_v1 + 1;
128
129    if total_header_v1 <= format::MAX_HEADER_LEN_V1 {
130        // Version 1.0
131        writer.write_all(format::NPY_MAGIC)?;
132        writer.write_all(&[1, 0])?;
133        writer.write_all(&(total_header_v1 as u16).to_le_bytes())?;
134        writer.write_all(dict.as_bytes())?;
135        write_padding(writer, padding_needed_v1)?;
136        writer.write_all(b"\n")?;
137    } else {
138        // Version 2.0
139        let preamble_v2 = format::NPY_MAGIC_LEN + 2 + 4; // 12
140        let padding_needed_v2 = compute_padding(preamble_v2 + dict.len() + 1);
141        let total_header_v2 = dict.len() + padding_needed_v2 + 1;
142
143        writer.write_all(format::NPY_MAGIC)?;
144        writer.write_all(&[2, 0])?;
145        writer.write_all(&(total_header_v2 as u32).to_le_bytes())?;
146        writer.write_all(dict.as_bytes())?;
147        write_padding(writer, padding_needed_v2)?;
148        writer.write_all(b"\n")?;
149    }
150
151    Ok(())
152}
153
154/// Compute the header size (preamble + header dict + padding + newline) for reading purposes.
155/// Returns the byte offset where data begins.
156pub fn compute_data_offset(version: (u8, u8), header_len: usize) -> usize {
157    let preamble = format::NPY_MAGIC_LEN + 2 + if version.0 == 1 { 2 } else { 4 };
158    preamble + header_len
159}
160
161fn compute_padding(current_total: usize) -> usize {
162    let remainder = current_total % format::HEADER_ALIGNMENT;
163    if remainder == 0 {
164        0
165    } else {
166        format::HEADER_ALIGNMENT - remainder
167    }
168}
169
170/// Write `count` space bytes as header padding. Uses a single
171/// `write_all` call instead of one per byte (#237).
172fn write_padding<W: std::io::Write>(writer: &mut W, count: usize) -> FerrayResult<()> {
173    // Stack buffer covers all realistic header padding; the .npy
174    // spec aligns to 64 bytes, so count is always < 64.
175    const MAX_STACK: usize = 128;
176    if count <= MAX_STACK {
177        let buf = [b' '; MAX_STACK];
178        writer.write_all(&buf[..count])?;
179    } else {
180        let buf = vec![b' '; count];
181        writer.write_all(&buf)?;
182    }
183    Ok(())
184}
185
/// Render `shape` as a Python tuple literal: `()` for no dimensions, `(n,)`
/// for exactly one, and `(a, b, ...)` for two or more.
fn format_shape(shape: &[usize]) -> String {
    // A one-element tuple needs the trailing comma.
    if let [only] = shape {
        return format!("({only},)");
    }
    let dims: Vec<String> = shape.iter().map(ToString::to_string).collect();
    format!("({})", dims.join(", "))
}
196
197/// Parse the Python dict-like header string.
198///
199/// We do simple string parsing here rather than pulling in a full Python parser.
200/// The format is well-defined: `{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }`
201fn parse_header_dict(header: &str) -> FerrayResult<(String, bool, Vec<usize>)> {
202    let header = header.trim();
203
204    // Strip outer braces
205    let inner = header
206        .strip_prefix('{')
207        .and_then(|s| s.strip_suffix('}'))
208        .ok_or_else(|| FerrayError::io_error("header dict missing braces"))?
209        .trim();
210
211    let descr = extract_string_value(inner, "descr")?;
212    let fortran_order = extract_bool_value(inner, "fortran_order")?;
213    let shape = extract_shape_value(inner, "shape")?;
214
215    Ok((descr, fortran_order, shape))
216}
217
218/// Extract a string value for a given key from the dict body.
219fn extract_string_value(dict_body: &str, key: &str) -> FerrayResult<String> {
220    // Look for 'key': 'value'
221    let pattern = format!("'{key}':");
222    let pos = dict_body
223        .find(&pattern)
224        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
225
226    let after_key = &dict_body[pos + pattern.len()..].trim_start();
227
228    // Find the opening quote
229    let quote_char = after_key
230        .as_bytes()
231        .first()
232        .ok_or_else(|| FerrayError::io_error(format!("missing value for key '{key}'")))?;
233
234    if *quote_char != b'\'' && *quote_char != b'"' {
235        return Err(FerrayError::io_error(format!(
236            "expected string value for key '{key}'"
237        )));
238    }
239
240    let qc = *quote_char as char;
241    let value_start = &after_key[1..];
242    let end = value_start
243        .find(qc)
244        .ok_or_else(|| FerrayError::io_error(format!("unterminated string for key '{key}'")))?;
245
246    Ok(value_start[..end].to_string())
247}
248
249/// Extract a boolean value for a given key from the dict body.
250fn extract_bool_value(dict_body: &str, key: &str) -> FerrayResult<bool> {
251    let pattern = format!("'{key}':");
252    let pos = dict_body
253        .find(&pattern)
254        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
255
256    let after_key = dict_body[pos + pattern.len()..].trim_start();
257
258    if after_key.starts_with("True") {
259        Ok(true)
260    } else if after_key.starts_with("False") {
261        Ok(false)
262    } else {
263        Err(FerrayError::io_error(format!(
264            "expected True/False for key '{key}'"
265        )))
266    }
267}
268
269/// Extract a tuple shape value for a given key from the dict body.
270fn extract_shape_value(dict_body: &str, key: &str) -> FerrayResult<Vec<usize>> {
271    let pattern = format!("'{key}':");
272    let pos = dict_body
273        .find(&pattern)
274        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
275
276    let after_key = dict_body[pos + pattern.len()..].trim_start();
277
278    // Find the opening paren
279    if !after_key.starts_with('(') {
280        return Err(FerrayError::io_error(format!(
281            "expected tuple for key '{key}'"
282        )));
283    }
284
285    let close = after_key
286        .find(')')
287        .ok_or_else(|| FerrayError::io_error(format!("unterminated tuple for key '{key}'")))?;
288
289    let tuple_inner = &after_key[1..close];
290    let tuple_inner = tuple_inner.trim();
291
292    if tuple_inner.is_empty() {
293        return Ok(vec![]);
294    }
295
296    let parts: FerrayResult<Vec<usize>> = tuple_inner
297        .split(',')
298        .filter(|s| !s.trim().is_empty())
299        .map(|s| {
300            s.trim()
301                .parse::<usize>()
302                .map_err(|e| FerrayError::io_error(format!("invalid shape dimension '{s}': {e}")))
303        })
304        .collect();
305
306    parts
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_header_dict ----

    #[test]
    fn parse_simple_header() {
        let parsed =
            parse_header_dict("{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }")
                .unwrap();
        assert_eq!(parsed, ("<f8".to_string(), false, vec![3, 4]));
    }

    #[test]
    fn parse_1d_header() {
        let parsed =
            parse_header_dict("{'descr': '<i4', 'fortran_order': False, 'shape': (10,), }")
                .unwrap();
        assert_eq!(parsed, ("<i4".to_string(), false, vec![10]));
    }

    #[test]
    fn parse_scalar_header() {
        let parsed =
            parse_header_dict("{'descr': '<f4', 'fortran_order': False, 'shape': (), }").unwrap();
        assert_eq!(parsed, ("<f4".to_string(), false, Vec::<usize>::new()));
    }

    #[test]
    fn parse_fortran_order() {
        let (_descr, fortran_order, _shape) =
            parse_header_dict("{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }")
                .unwrap();
        assert!(fortran_order);
    }

    // ---- format_shape ----

    #[test]
    fn format_shape_empty() {
        let rendered = format_shape(&[]);
        assert_eq!(rendered, "()");
    }

    #[test]
    fn format_shape_1d() {
        let rendered = format_shape(&[5]);
        assert_eq!(rendered, "(5,)");
    }

    #[test]
    fn format_shape_2d() {
        let rendered = format_shape(&[3, 4]);
        assert_eq!(rendered, "(3, 4)");
    }

    // ---- write_header / read_header ----

    #[test]
    fn write_read_roundtrip() {
        let mut bytes = Vec::new();
        write_header(&mut bytes, DType::F64, &[3, 4], false).unwrap();

        // Reading back what was just written must reproduce the inputs.
        let parsed = read_header(&mut std::io::Cursor::new(bytes)).unwrap();

        assert_eq!(parsed.dtype, DType::F64);
        assert_eq!(parsed.shape, vec![3, 4]);
        assert!(!parsed.fortran_order);
    }

    #[test]
    fn header_alignment() {
        let shapes: [&[usize]; 3] = [&[3, 4], &[100, 200, 300], &[1]];
        for shape in shapes {
            let mut bytes = Vec::new();
            write_header(&mut bytes, DType::F64, shape, false).unwrap();
            // The whole preamble + header must end on an alignment boundary.
            assert_eq!(
                bytes.len() % format::HEADER_ALIGNMENT,
                0,
                "header not aligned for shape {shape:?}"
            );
        }
    }
}