1use ferray_core::dtype::DType;
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use super::dtype_parse::{self, Endianness};
10use crate::format;
11
12#[derive(Debug, Clone)]
14pub struct NpyHeader {
15 pub descr: String,
17 pub dtype: DType,
19 pub endianness: Endianness,
21 pub fortran_order: bool,
23 pub shape: Vec<usize>,
25 pub version: (u8, u8),
27}
28
29pub fn read_header<R: std::io::Read>(reader: &mut R) -> FerrayResult<NpyHeader> {
33 let mut magic = [0u8; format::NPY_MAGIC_LEN];
35 reader
36 .read_exact(&mut magic)
37 .map_err(|e| FerrayError::io_error(format!("failed to read .npy magic: {e}")))?;
38
39 if magic != *format::NPY_MAGIC {
40 return Err(FerrayError::io_error(
41 "not a valid .npy file: bad magic number",
42 ));
43 }
44
45 let mut version = [0u8; 2];
47 reader
48 .read_exact(&mut version)
49 .map_err(|e| FerrayError::io_error(format!("failed to read .npy version: {e}")))?;
50
51 let major = version[0];
52 let minor = version[1];
53
54 if !matches!((major, minor), (1, 0) | (2, 0) | (3, 0)) {
55 return Err(FerrayError::io_error(format!(
56 "unsupported .npy format version {major}.{minor}"
57 )));
58 }
59
60 let header_len = if major == 1 {
62 let mut buf = [0u8; 2];
63 reader
64 .read_exact(&mut buf)
65 .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
66 u16::from_le_bytes(buf) as usize
67 } else {
68 let mut buf = [0u8; 4];
69 reader
70 .read_exact(&mut buf)
71 .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
72 u32::from_le_bytes(buf) as usize
73 };
74
75 let mut header_bytes = vec![0u8; header_len];
77 reader
78 .read_exact(&mut header_bytes)
79 .map_err(|e| FerrayError::io_error(format!("failed to read header: {e}")))?;
80
81 let header_str = std::str::from_utf8(&header_bytes)
82 .map_err(|e| FerrayError::io_error(format!("header is not valid UTF-8: {e}")))?;
83
84 let (descr, fortran_order, shape) = parse_header_dict(header_str)?;
86 let (dtype, endianness) = dtype_parse::parse_dtype_str(&descr)?;
87
88 Ok(NpyHeader {
89 descr,
90 dtype,
91 endianness,
92 fortran_order,
93 shape,
94 version: (major, minor),
95 })
96}
97
98pub fn write_header<W: std::io::Write>(
100 writer: &mut W,
101 dtype: DType,
102 shape: &[usize],
103 fortran_order: bool,
104) -> FerrayResult<()> {
105 let descr = dtype_parse::dtype_to_native_descr(dtype)?;
106 let fortran_str = if fortran_order { "True" } else { "False" };
107
108 let shape_str = format_shape(shape);
109
110 let dict =
111 format!("{{'descr': '{descr}', 'fortran_order': {fortran_str}, 'shape': {shape_str}, }}");
112
113 let preamble_v1 = format::NPY_MAGIC_LEN + 2 + 2; let padding_needed_v1 = compute_padding(preamble_v1 + dict.len() + 1); let total_header_v1 = dict.len() + padding_needed_v1 + 1;
119
120 if total_header_v1 <= format::MAX_HEADER_LEN_V1 {
121 writer.write_all(format::NPY_MAGIC)?;
123 writer.write_all(&[1, 0])?;
124 writer.write_all(&(total_header_v1 as u16).to_le_bytes())?;
125 writer.write_all(dict.as_bytes())?;
126 write_padding(writer, padding_needed_v1)?;
127 writer.write_all(b"\n")?;
128 } else {
129 let preamble_v2 = format::NPY_MAGIC_LEN + 2 + 4; let padding_needed_v2 = compute_padding(preamble_v2 + dict.len() + 1);
132 let total_header_v2 = dict.len() + padding_needed_v2 + 1;
133
134 writer.write_all(format::NPY_MAGIC)?;
135 writer.write_all(&[2, 0])?;
136 writer.write_all(&(total_header_v2 as u32).to_le_bytes())?;
137 writer.write_all(dict.as_bytes())?;
138 write_padding(writer, padding_needed_v2)?;
139 writer.write_all(b"\n")?;
140 }
141
142 Ok(())
143}
144
145pub fn compute_data_offset(version: (u8, u8), header_len: usize) -> usize {
148 let preamble = format::NPY_MAGIC_LEN + 2 + if version.0 == 1 { 2 } else { 4 };
149 preamble + header_len
150}
151
152fn compute_padding(current_total: usize) -> usize {
153 let remainder = current_total % format::HEADER_ALIGNMENT;
154 if remainder == 0 {
155 0
156 } else {
157 format::HEADER_ALIGNMENT - remainder
158 }
159}
160
161fn write_padding<W: std::io::Write>(writer: &mut W, count: usize) -> FerrayResult<()> {
162 for _ in 0..count {
163 writer.write_all(b" ")?;
164 }
165 Ok(())
166}
167
/// Renders `shape` using Python tuple syntax.
///
/// An empty shape becomes `()`, a single dimension keeps the trailing comma
/// (`(5,)`), and longer shapes are joined with `", "` (`(3, 4)`).
fn format_shape(shape: &[usize]) -> String {
    match shape {
        [] => String::from("()"),
        [only] => format!("({only},)"),
        dims => {
            let mut text = String::from("(");
            for (index, dim) in dims.iter().enumerate() {
                if index > 0 {
                    text.push_str(", ");
                }
                text.push_str(&dim.to_string());
            }
            text.push(')');
            text
        }
    }
}
178
179fn parse_header_dict(header: &str) -> FerrayResult<(String, bool, Vec<usize>)> {
184 let header = header.trim();
185
186 let inner = header
188 .strip_prefix('{')
189 .and_then(|s| s.strip_suffix('}'))
190 .ok_or_else(|| FerrayError::io_error("header dict missing braces"))?
191 .trim();
192
193 let descr = extract_string_value(inner, "descr")?;
194 let fortran_order = extract_bool_value(inner, "fortran_order")?;
195 let shape = extract_shape_value(inner, "shape")?;
196
197 Ok((descr, fortran_order, shape))
198}
199
200fn extract_string_value(dict_body: &str, key: &str) -> FerrayResult<String> {
202 let pattern = format!("'{key}':");
204 let pos = dict_body
205 .find(&pattern)
206 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
207
208 let after_key = &dict_body[pos + pattern.len()..].trim_start();
209
210 let quote_char = after_key
212 .as_bytes()
213 .first()
214 .ok_or_else(|| FerrayError::io_error(format!("missing value for key '{key}'")))?;
215
216 if *quote_char != b'\'' && *quote_char != b'"' {
217 return Err(FerrayError::io_error(format!(
218 "expected string value for key '{key}'"
219 )));
220 }
221
222 let qc = *quote_char as char;
223 let value_start = &after_key[1..];
224 let end = value_start
225 .find(qc)
226 .ok_or_else(|| FerrayError::io_error(format!("unterminated string for key '{key}'")))?;
227
228 Ok(value_start[..end].to_string())
229}
230
231fn extract_bool_value(dict_body: &str, key: &str) -> FerrayResult<bool> {
233 let pattern = format!("'{key}':");
234 let pos = dict_body
235 .find(&pattern)
236 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
237
238 let after_key = dict_body[pos + pattern.len()..].trim_start();
239
240 if after_key.starts_with("True") {
241 Ok(true)
242 } else if after_key.starts_with("False") {
243 Ok(false)
244 } else {
245 Err(FerrayError::io_error(format!(
246 "expected True/False for key '{key}'"
247 )))
248 }
249}
250
251fn extract_shape_value(dict_body: &str, key: &str) -> FerrayResult<Vec<usize>> {
253 let pattern = format!("'{key}':");
254 let pos = dict_body
255 .find(&pattern)
256 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
257
258 let after_key = dict_body[pos + pattern.len()..].trim_start();
259
260 if !after_key.starts_with('(') {
262 return Err(FerrayError::io_error(format!(
263 "expected tuple for key '{key}'"
264 )));
265 }
266
267 let close = after_key
268 .find(')')
269 .ok_or_else(|| FerrayError::io_error(format!("unterminated tuple for key '{key}'")))?;
270
271 let tuple_inner = &after_key[1..close];
272 let tuple_inner = tuple_inner.trim();
273
274 if tuple_inner.is_empty() {
275 return Ok(vec![]);
276 }
277
278 let parts: FerrayResult<Vec<usize>> = tuple_inner
279 .split(',')
280 .filter(|s| !s.trim().is_empty())
281 .map(|s| {
282 s.trim()
283 .parse::<usize>()
284 .map_err(|e| FerrayError::io_error(format!("invalid shape dimension '{s}': {e}")))
285 })
286 .collect();
287
288 parts
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `header` with `parse_header_dict`, panicking if it is rejected.
    fn parsed(header: &str) -> (String, bool, Vec<usize>) {
        parse_header_dict(header).unwrap()
    }

    #[test]
    fn parse_simple_header() {
        let (descr, fortran, shape) =
            parsed("{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }");
        assert_eq!(descr, "<f8");
        assert!(!fortran);
        assert_eq!(shape, vec![3, 4]);
    }

    #[test]
    fn parse_1d_header() {
        let (descr, fortran, shape) =
            parsed("{'descr': '<i4', 'fortran_order': False, 'shape': (10,), }");
        assert_eq!(descr, "<i4");
        assert!(!fortran);
        assert_eq!(shape, vec![10]);
    }

    #[test]
    fn parse_scalar_header() {
        let (descr, fortran, shape) =
            parsed("{'descr': '<f4', 'fortran_order': False, 'shape': (), }");
        assert_eq!(descr, "<f4");
        assert!(!fortran);
        assert!(shape.is_empty());
    }

    #[test]
    fn parse_fortran_order() {
        let (_, fortran, _) =
            parsed("{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }");
        assert!(fortran);
    }

    #[test]
    fn format_shape_empty() {
        assert_eq!(format_shape(&[]), "()");
    }

    #[test]
    fn format_shape_1d() {
        assert_eq!(format_shape(&[5]), "(5,)");
    }

    #[test]
    fn format_shape_2d() {
        assert_eq!(format_shape(&[3, 4]), "(3, 4)");
    }

    #[test]
    fn write_read_roundtrip() {
        let mut bytes = Vec::new();
        write_header(&mut bytes, DType::F64, &[3, 4], false).unwrap();

        let header = read_header(&mut std::io::Cursor::new(bytes)).unwrap();

        assert_eq!(header.dtype, DType::F64);
        assert_eq!(header.shape, vec![3, 4]);
        assert!(!header.fortran_order);
    }

    #[test]
    fn header_alignment() {
        // The full preamble + header must end on an alignment boundary.
        let shapes: [&[usize]; 3] = [&[3, 4], &[100, 200, 300], &[1]];
        for shape in shapes.iter().copied() {
            let mut bytes = Vec::new();
            write_header(&mut bytes, DType::F64, shape, false).unwrap();
            assert_eq!(
                bytes.len() % format::HEADER_ALIGNMENT,
                0,
                "header not aligned for shape {shape:?}"
            );
        }
    }
}