Skip to main content

ferray_io/npy/
header.rs

1// ferray-io: .npy header parsing and writing
2//
3// The .npy header is a Python dict literal with keys 'descr', 'fortran_order', 'shape'.
4// Example: "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }"
5
6use ferray_core::dtype::DType;
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use super::dtype_parse::{self, Endianness};
10use crate::format;
11
/// Parsed .npy file header.
///
/// Produced by [`read_header`]; bundles the raw header fields together with
/// the values derived from them.
#[derive(Debug, Clone)]
pub struct NpyHeader {
    /// The dtype descriptor string exactly as found under the `'descr'` key (e.g., "<f8").
    pub descr: String,
    /// Dtype obtained by parsing `descr`.
    pub dtype: DType,
    /// Endianness obtained by parsing `descr`.
    pub endianness: Endianness,
    /// Whether the data is stored in Fortran (column-major) order (`'fortran_order'` key).
    pub fortran_order: bool,
    /// Shape of the array (`'shape'` key); empty for a `()` shape tuple.
    pub shape: Vec<usize>,
    /// Format version (major, minor) as read from the two version bytes.
    pub version: (u8, u8),
}
28
29/// Read and parse a .npy header from a reader.
30///
31/// After this function returns, the reader is positioned at the start of the data.
32pub fn read_header<R: std::io::Read>(reader: &mut R) -> FerrayResult<NpyHeader> {
33    // Read magic
34    let mut magic = [0u8; format::NPY_MAGIC_LEN];
35    reader
36        .read_exact(&mut magic)
37        .map_err(|e| FerrayError::io_error(format!("failed to read .npy magic: {e}")))?;
38
39    if magic != *format::NPY_MAGIC {
40        return Err(FerrayError::io_error(
41            "not a valid .npy file: bad magic number",
42        ));
43    }
44
45    // Read version
46    let mut version = [0u8; 2];
47    reader
48        .read_exact(&mut version)
49        .map_err(|e| FerrayError::io_error(format!("failed to read .npy version: {e}")))?;
50
51    let major = version[0];
52    let minor = version[1];
53
54    if !matches!((major, minor), (1, 0) | (2, 0) | (3, 0)) {
55        return Err(FerrayError::io_error(format!(
56            "unsupported .npy format version {major}.{minor}"
57        )));
58    }
59
60    // Read header length
61    let header_len = if major == 1 {
62        let mut buf = [0u8; 2];
63        reader
64            .read_exact(&mut buf)
65            .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
66        u16::from_le_bytes(buf) as usize
67    } else {
68        let mut buf = [0u8; 4];
69        reader
70            .read_exact(&mut buf)
71            .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
72        u32::from_le_bytes(buf) as usize
73    };
74
75    // Read header string
76    let mut header_bytes = vec![0u8; header_len];
77    reader
78        .read_exact(&mut header_bytes)
79        .map_err(|e| FerrayError::io_error(format!("failed to read header: {e}")))?;
80
81    let header_str = std::str::from_utf8(&header_bytes)
82        .map_err(|e| FerrayError::io_error(format!("header is not valid UTF-8: {e}")))?;
83
84    // Parse the header dict
85    let (descr, fortran_order, shape) = parse_header_dict(header_str)?;
86    let (dtype, endianness) = dtype_parse::parse_dtype_str(&descr)?;
87
88    Ok(NpyHeader {
89        descr,
90        dtype,
91        endianness,
92        fortran_order,
93        shape,
94        version: (major, minor),
95    })
96}
97
98/// Write a .npy header to a writer. Returns the total preamble+header size.
99pub fn write_header<W: std::io::Write>(
100    writer: &mut W,
101    dtype: DType,
102    shape: &[usize],
103    fortran_order: bool,
104) -> FerrayResult<()> {
105    let descr = dtype_parse::dtype_to_native_descr(dtype)?;
106    let fortran_str = if fortran_order { "True" } else { "False" };
107
108    let shape_str = format_shape(shape);
109
110    let dict =
111        format!("{{'descr': '{descr}', 'fortran_order': {fortran_str}, 'shape': {shape_str}, }}");
112
113    // Try version 1.0 first (header length fits in u16)
114    // Preamble: magic(6) + version(2) + header_len(2) = 10 for v1
115    // Preamble: magic(6) + version(2) + header_len(4) = 12 for v2
116    let preamble_v1 = format::NPY_MAGIC_LEN + 2 + 2; // 10
117    let padding_needed_v1 = compute_padding(preamble_v1 + dict.len() + 1); // +1 for newline
118    let total_header_v1 = dict.len() + padding_needed_v1 + 1;
119
120    if total_header_v1 <= format::MAX_HEADER_LEN_V1 {
121        // Version 1.0
122        writer.write_all(format::NPY_MAGIC)?;
123        writer.write_all(&[1, 0])?;
124        writer.write_all(&(total_header_v1 as u16).to_le_bytes())?;
125        writer.write_all(dict.as_bytes())?;
126        write_padding(writer, padding_needed_v1)?;
127        writer.write_all(b"\n")?;
128    } else {
129        // Version 2.0
130        let preamble_v2 = format::NPY_MAGIC_LEN + 2 + 4; // 12
131        let padding_needed_v2 = compute_padding(preamble_v2 + dict.len() + 1);
132        let total_header_v2 = dict.len() + padding_needed_v2 + 1;
133
134        writer.write_all(format::NPY_MAGIC)?;
135        writer.write_all(&[2, 0])?;
136        writer.write_all(&(total_header_v2 as u32).to_le_bytes())?;
137        writer.write_all(dict.as_bytes())?;
138        write_padding(writer, padding_needed_v2)?;
139        writer.write_all(b"\n")?;
140    }
141
142    Ok(())
143}
144
145/// Compute the header size (preamble + header dict + padding + newline) for reading purposes.
146/// Returns the byte offset where data begins.
147pub fn compute_data_offset(version: (u8, u8), header_len: usize) -> usize {
148    let preamble = format::NPY_MAGIC_LEN + 2 + if version.0 == 1 { 2 } else { 4 };
149    preamble + header_len
150}
151
152fn compute_padding(current_total: usize) -> usize {
153    let remainder = current_total % format::HEADER_ALIGNMENT;
154    if remainder == 0 {
155        0
156    } else {
157        format::HEADER_ALIGNMENT - remainder
158    }
159}
160
161fn write_padding<W: std::io::Write>(writer: &mut W, count: usize) -> FerrayResult<()> {
162    for _ in 0..count {
163        writer.write_all(b" ")?;
164    }
165    Ok(())
166}
167
/// Render a shape as a Python tuple literal.
///
/// An empty shape becomes `()`, a single dimension keeps the trailing comma
/// that Python requires for one-element tuples (`(5,)`), and longer shapes
/// are joined with `", "` (`(3, 4)`).
fn format_shape(shape: &[usize]) -> String {
    match shape {
        [] => String::from("()"),
        [only] => format!("({only},)"),
        dims => {
            let joined = dims
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(", ");
            format!("({joined})")
        }
    }
}
178
179/// Parse the Python dict-like header string.
180///
181/// We do simple string parsing here rather than pulling in a full Python parser.
182/// The format is well-defined: `{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }`
183fn parse_header_dict(header: &str) -> FerrayResult<(String, bool, Vec<usize>)> {
184    let header = header.trim();
185
186    // Strip outer braces
187    let inner = header
188        .strip_prefix('{')
189        .and_then(|s| s.strip_suffix('}'))
190        .ok_or_else(|| FerrayError::io_error("header dict missing braces"))?
191        .trim();
192
193    let descr = extract_string_value(inner, "descr")?;
194    let fortran_order = extract_bool_value(inner, "fortran_order")?;
195    let shape = extract_shape_value(inner, "shape")?;
196
197    Ok((descr, fortran_order, shape))
198}
199
200/// Extract a string value for a given key from the dict body.
201fn extract_string_value(dict_body: &str, key: &str) -> FerrayResult<String> {
202    // Look for 'key': 'value'
203    let pattern = format!("'{key}':");
204    let pos = dict_body
205        .find(&pattern)
206        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
207
208    let after_key = &dict_body[pos + pattern.len()..].trim_start();
209
210    // Find the opening quote
211    let quote_char = after_key
212        .as_bytes()
213        .first()
214        .ok_or_else(|| FerrayError::io_error(format!("missing value for key '{key}'")))?;
215
216    if *quote_char != b'\'' && *quote_char != b'"' {
217        return Err(FerrayError::io_error(format!(
218            "expected string value for key '{key}'"
219        )));
220    }
221
222    let qc = *quote_char as char;
223    let value_start = &after_key[1..];
224    let end = value_start
225        .find(qc)
226        .ok_or_else(|| FerrayError::io_error(format!("unterminated string for key '{key}'")))?;
227
228    Ok(value_start[..end].to_string())
229}
230
231/// Extract a boolean value for a given key from the dict body.
232fn extract_bool_value(dict_body: &str, key: &str) -> FerrayResult<bool> {
233    let pattern = format!("'{key}':");
234    let pos = dict_body
235        .find(&pattern)
236        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
237
238    let after_key = dict_body[pos + pattern.len()..].trim_start();
239
240    if after_key.starts_with("True") {
241        Ok(true)
242    } else if after_key.starts_with("False") {
243        Ok(false)
244    } else {
245        Err(FerrayError::io_error(format!(
246            "expected True/False for key '{key}'"
247        )))
248    }
249}
250
251/// Extract a tuple shape value for a given key from the dict body.
252fn extract_shape_value(dict_body: &str, key: &str) -> FerrayResult<Vec<usize>> {
253    let pattern = format!("'{key}':");
254    let pos = dict_body
255        .find(&pattern)
256        .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
257
258    let after_key = dict_body[pos + pattern.len()..].trim_start();
259
260    // Find the opening paren
261    if !after_key.starts_with('(') {
262        return Err(FerrayError::io_error(format!(
263            "expected tuple for key '{key}'"
264        )));
265    }
266
267    let close = after_key
268        .find(')')
269        .ok_or_else(|| FerrayError::io_error(format!("unterminated tuple for key '{key}'")))?;
270
271    let tuple_inner = &after_key[1..close];
272    let tuple_inner = tuple_inner.trim();
273
274    if tuple_inner.is_empty() {
275        return Ok(vec![]);
276    }
277
278    let parts: FerrayResult<Vec<usize>> = tuple_inner
279        .split(',')
280        .filter(|s| !s.trim().is_empty())
281        .map(|s| {
282            s.trim()
283                .parse::<usize>()
284                .map_err(|e| FerrayError::io_error(format!("invalid shape dimension '{s}': {e}")))
285        })
286        .collect();
287
288    parts
289}
290
// Unit tests for header dict parsing, shape formatting and the
// write/read round trip.
#[cfg(test)]
mod tests {
    use super::*;

    // A two-dimensional, C-order header parses into all three fields.
    #[test]
    fn parse_simple_header() {
        let header = "{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<f8");
        assert!(!fortran);
        assert_eq!(shape, vec![3, 4]);
    }

    // A one-element tuple with its trailing comma yields a single dimension.
    #[test]
    fn parse_1d_header() {
        let header = "{'descr': '<i4', 'fortran_order': False, 'shape': (10,), }";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<i4");
        assert!(!fortran);
        assert_eq!(shape, vec![10]);
    }

    // An empty tuple `()` yields an empty shape.
    #[test]
    fn parse_scalar_header() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (), }";
        let (descr, fortran, shape) = parse_header_dict(header).unwrap();
        assert_eq!(descr, "<f4");
        assert!(!fortran);
        assert!(shape.is_empty());
    }

    // `True` for 'fortran_order' is reported as `true`.
    #[test]
    fn parse_fortran_order() {
        let header = "{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }";
        let (_, fortran, _) = parse_header_dict(header).unwrap();
        assert!(fortran);
    }

    // An empty shape formats as `()`.
    #[test]
    fn format_shape_empty() {
        assert_eq!(format_shape(&[]), "()");
    }

    // A single dimension keeps the trailing comma.
    #[test]
    fn format_shape_1d() {
        assert_eq!(format_shape(&[5]), "(5,)");
    }

    // Multiple dimensions are separated by ", " with no trailing comma.
    #[test]
    fn format_shape_2d() {
        assert_eq!(format_shape(&[3, 4]), "(3, 4)");
    }

    // A header produced by `write_header` is read back by `read_header`
    // with the same dtype, shape and ordering flag.
    #[test]
    fn write_read_roundtrip() {
        let mut buf = Vec::new();
        write_header(&mut buf, DType::F64, &[3, 4], false).unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let header = read_header(&mut cursor).unwrap();

        assert_eq!(header.dtype, DType::F64);
        assert_eq!(header.shape, vec![3, 4]);
        assert!(!header.fortran_order);
    }

    // The total number of bytes written (preamble + header) is a multiple of
    // the alignment for several shapes of differing text length.
    #[test]
    fn header_alignment() {
        for shape in [&[3, 4][..], &[100, 200, 300], &[1]] {
            let mut buf = Vec::new();
            write_header(&mut buf, DType::F64, shape, false).unwrap();
            assert_eq!(
                buf.len() % format::HEADER_ALIGNMENT,
                0,
                "header not aligned for shape {shape:?}"
            );
        }
    }
}