1use ferray_core::dtype::DType;
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use super::dtype_parse::{self, Endianness};
10use crate::format;
11
/// Parsed contents of a `.npy` file header, as returned by [`read_header`].
#[derive(Debug, Clone)]
pub struct NpyHeader {
    /// The raw `'descr'` string exactly as it appeared in the header dict (e.g. `<f8`).
    pub descr: String,
    /// Element type parsed from `descr` by `dtype_parse::parse_dtype_str`.
    pub dtype: DType,
    /// Byte order parsed from `descr` by `dtype_parse::parse_dtype_str`.
    pub endianness: Endianness,
    /// Value of the header dict's `'fortran_order'` key (`True` -> `true`).
    pub fortran_order: bool,
    /// Dimensions from the header dict's `'shape'` tuple; empty for a `()` (0-d) shape.
    pub shape: Vec<usize>,
    /// Format version `(major, minor)` read from the file; `read_header` only accepts
    /// (1, 0), (2, 0) and (3, 0).
    pub version: (u8, u8),
}
28
29pub fn read_header<R: std::io::Read>(reader: &mut R) -> FerrayResult<NpyHeader> {
33 let mut magic = [0u8; format::NPY_MAGIC_LEN];
35 reader
36 .read_exact(&mut magic)
37 .map_err(|e| FerrayError::io_error(format!("failed to read .npy magic: {e}")))?;
38
39 if magic != *format::NPY_MAGIC {
40 return Err(FerrayError::io_error(
41 "not a valid .npy file: bad magic number",
42 ));
43 }
44
45 let mut version = [0u8; 2];
47 reader
48 .read_exact(&mut version)
49 .map_err(|e| FerrayError::io_error(format!("failed to read .npy version: {e}")))?;
50
51 let major = version[0];
52 let minor = version[1];
53
54 if !matches!((major, minor), (1, 0) | (2, 0) | (3, 0)) {
55 return Err(FerrayError::io_error(format!(
56 "unsupported .npy format version {major}.{minor}"
57 )));
58 }
59
60 let header_len = if major == 1 {
62 let mut buf = [0u8; 2];
63 reader
64 .read_exact(&mut buf)
65 .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
66 u16::from_le_bytes(buf) as usize
67 } else {
68 let mut buf = [0u8; 4];
69 reader
70 .read_exact(&mut buf)
71 .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
72 let raw_len = u32::from_le_bytes(buf) as usize;
73 const MAX_HEADER_LEN: usize = 1_048_576;
76 if raw_len > MAX_HEADER_LEN {
77 return Err(FerrayError::io_error(format!(
78 "header length {raw_len} exceeds maximum allowed size ({MAX_HEADER_LEN} bytes)"
79 )));
80 }
81 raw_len
82 };
83
84 let mut header_bytes = vec![0u8; header_len];
86 reader
87 .read_exact(&mut header_bytes)
88 .map_err(|e| FerrayError::io_error(format!("failed to read header: {e}")))?;
89
90 let header_str = std::str::from_utf8(&header_bytes)
91 .map_err(|e| FerrayError::io_error(format!("header is not valid UTF-8: {e}")))?;
92
93 let (descr, fortran_order, shape) = parse_header_dict(header_str)?;
95 let (dtype, endianness) = dtype_parse::parse_dtype_str(&descr)?;
96
97 Ok(NpyHeader {
98 descr,
99 dtype,
100 endianness,
101 fortran_order,
102 shape,
103 version: (major, minor),
104 })
105}
106
107pub fn write_header<W: std::io::Write>(
109 writer: &mut W,
110 dtype: DType,
111 shape: &[usize],
112 fortran_order: bool,
113) -> FerrayResult<()> {
114 let descr = dtype_parse::dtype_to_native_descr(dtype)?;
115 let fortran_str = if fortran_order { "True" } else { "False" };
116
117 let shape_str = format_shape(shape);
118
119 let dict =
120 format!("{{'descr': '{descr}', 'fortran_order': {fortran_str}, 'shape': {shape_str}, }}");
121
122 let preamble_v1 = format::NPY_MAGIC_LEN + 2 + 2; let padding_needed_v1 = compute_padding(preamble_v1 + dict.len() + 1); let total_header_v1 = dict.len() + padding_needed_v1 + 1;
128
129 if total_header_v1 <= format::MAX_HEADER_LEN_V1 {
130 writer.write_all(format::NPY_MAGIC)?;
132 writer.write_all(&[1, 0])?;
133 writer.write_all(&(total_header_v1 as u16).to_le_bytes())?;
134 writer.write_all(dict.as_bytes())?;
135 write_padding(writer, padding_needed_v1)?;
136 writer.write_all(b"\n")?;
137 } else {
138 let preamble_v2 = format::NPY_MAGIC_LEN + 2 + 4; let padding_needed_v2 = compute_padding(preamble_v2 + dict.len() + 1);
141 let total_header_v2 = dict.len() + padding_needed_v2 + 1;
142
143 writer.write_all(format::NPY_MAGIC)?;
144 writer.write_all(&[2, 0])?;
145 writer.write_all(&(total_header_v2 as u32).to_le_bytes())?;
146 writer.write_all(dict.as_bytes())?;
147 write_padding(writer, padding_needed_v2)?;
148 writer.write_all(b"\n")?;
149 }
150
151 Ok(())
152}
153
154pub fn compute_data_offset(version: (u8, u8), header_len: usize) -> usize {
157 let preamble = format::NPY_MAGIC_LEN + 2 + if version.0 == 1 { 2 } else { 4 };
158 preamble + header_len
159}
160
161fn compute_padding(current_total: usize) -> usize {
162 let remainder = current_total % format::HEADER_ALIGNMENT;
163 if remainder == 0 {
164 0
165 } else {
166 format::HEADER_ALIGNMENT - remainder
167 }
168}
169
170fn write_padding<W: std::io::Write>(writer: &mut W, count: usize) -> FerrayResult<()> {
173 const MAX_STACK: usize = 128;
176 if count <= MAX_STACK {
177 let buf = [b' '; MAX_STACK];
178 writer.write_all(&buf[..count])?;
179 } else {
180 let buf = vec![b' '; count];
181 writer.write_all(&buf)?;
182 }
183 Ok(())
184}
185
/// Render `shape` as a Python tuple literal: `()`, `(n,)` or `(a, b, ...)`.
///
/// A one-element tuple needs the trailing comma so Python does not read it as a
/// parenthesised integer.
fn format_shape(shape: &[usize]) -> String {
    if let [only] = shape {
        return format!("({only},)");
    }
    let joined = shape
        .iter()
        .map(|dim| dim.to_string())
        .collect::<Vec<_>>()
        .join(", ");
    format!("({joined})")
}
196
197fn parse_header_dict(header: &str) -> FerrayResult<(String, bool, Vec<usize>)> {
202 let header = header.trim();
203
204 let inner = header
206 .strip_prefix('{')
207 .and_then(|s| s.strip_suffix('}'))
208 .ok_or_else(|| FerrayError::io_error("header dict missing braces"))?
209 .trim();
210
211 let descr = extract_string_value(inner, "descr")?;
212 let fortran_order = extract_bool_value(inner, "fortran_order")?;
213 let shape = extract_shape_value(inner, "shape")?;
214
215 Ok((descr, fortran_order, shape))
216}
217
218fn extract_string_value(dict_body: &str, key: &str) -> FerrayResult<String> {
220 let pattern = format!("'{key}':");
222 let pos = dict_body
223 .find(&pattern)
224 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
225
226 let after_key = &dict_body[pos + pattern.len()..].trim_start();
227
228 let quote_char = after_key
230 .as_bytes()
231 .first()
232 .ok_or_else(|| FerrayError::io_error(format!("missing value for key '{key}'")))?;
233
234 if *quote_char != b'\'' && *quote_char != b'"' {
235 return Err(FerrayError::io_error(format!(
236 "expected string value for key '{key}'"
237 )));
238 }
239
240 let qc = *quote_char as char;
241 let value_start = &after_key[1..];
242 let end = value_start
243 .find(qc)
244 .ok_or_else(|| FerrayError::io_error(format!("unterminated string for key '{key}'")))?;
245
246 Ok(value_start[..end].to_string())
247}
248
249fn extract_bool_value(dict_body: &str, key: &str) -> FerrayResult<bool> {
251 let pattern = format!("'{key}':");
252 let pos = dict_body
253 .find(&pattern)
254 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
255
256 let after_key = dict_body[pos + pattern.len()..].trim_start();
257
258 if after_key.starts_with("True") {
259 Ok(true)
260 } else if after_key.starts_with("False") {
261 Ok(false)
262 } else {
263 Err(FerrayError::io_error(format!(
264 "expected True/False for key '{key}'"
265 )))
266 }
267}
268
269fn extract_shape_value(dict_body: &str, key: &str) -> FerrayResult<Vec<usize>> {
271 let pattern = format!("'{key}':");
272 let pos = dict_body
273 .find(&pattern)
274 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
275
276 let after_key = dict_body[pos + pattern.len()..].trim_start();
277
278 if !after_key.starts_with('(') {
280 return Err(FerrayError::io_error(format!(
281 "expected tuple for key '{key}'"
282 )));
283 }
284
285 let close = after_key
286 .find(')')
287 .ok_or_else(|| FerrayError::io_error(format!("unterminated tuple for key '{key}'")))?;
288
289 let tuple_inner = &after_key[1..close];
290 let tuple_inner = tuple_inner.trim();
291
292 if tuple_inner.is_empty() {
293 return Ok(vec![]);
294 }
295
296 let parts: FerrayResult<Vec<usize>> = tuple_inner
297 .split(',')
298 .filter(|s| !s.trim().is_empty())
299 .map(|s| {
300 s.trim()
301 .parse::<usize>()
302 .map_err(|e| FerrayError::io_error(format!("invalid shape dimension '{s}': {e}")))
303 })
304 .collect();
305
306 parts
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `header`, failing the test if the header is rejected.
    fn parse_ok(header: &str) -> (String, bool, Vec<usize>) {
        parse_header_dict(header).expect("header dict should parse")
    }

    #[test]
    fn parse_simple_header() {
        let (descr, fortran, shape) =
            parse_ok("{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }");
        assert_eq!(descr, "<f8");
        assert!(!fortran);
        assert_eq!(shape, vec![3, 4]);
    }

    #[test]
    fn parse_1d_header() {
        let (descr, fortran, shape) =
            parse_ok("{'descr': '<i4', 'fortran_order': False, 'shape': (10,), }");
        assert_eq!(descr, "<i4");
        assert!(!fortran);
        assert_eq!(shape, vec![10]);
    }

    #[test]
    fn parse_scalar_header() {
        let (descr, fortran, shape) =
            parse_ok("{'descr': '<f4', 'fortran_order': False, 'shape': (), }");
        assert_eq!(descr, "<f4");
        assert!(!fortran);
        assert!(shape.is_empty());
    }

    #[test]
    fn parse_fortran_order() {
        let parsed = parse_ok("{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }");
        assert!(parsed.1);
    }

    #[test]
    fn format_shape_empty() {
        assert_eq!(format_shape(&[]), "()");
    }

    #[test]
    fn format_shape_1d() {
        assert_eq!(format_shape(&[5]), "(5,)");
    }

    #[test]
    fn format_shape_2d() {
        assert_eq!(format_shape(&[3, 4]), "(3, 4)");
    }

    #[test]
    fn write_read_roundtrip() {
        let mut bytes = Vec::new();
        write_header(&mut bytes, DType::F64, &[3, 4], false).unwrap();

        let header = read_header(&mut bytes.as_slice()).unwrap();

        assert_eq!(header.dtype, DType::F64);
        assert_eq!(header.shape, vec![3, 4]);
        assert!(!header.fortran_order);
    }

    #[test]
    fn header_alignment() {
        let shapes: [&[usize]; 3] = [&[3, 4], &[100, 200, 300], &[1]];
        for &shape in &shapes {
            let mut bytes = Vec::new();
            write_header(&mut bytes, DType::F64, shape, false).unwrap();
            assert_eq!(
                bytes.len() % format::HEADER_ALIGNMENT,
                0,
                "header not aligned for shape {shape:?}"
            );
        }
    }
}