1use ferray_core::dtype::DType;
7use ferray_core::error::{FerrayError, FerrayResult};
8
9use super::dtype_parse::{self, Endianness};
10use crate::format;
11
/// Parsed contents of an `.npy` header, as returned by `read_header`.
#[derive(Debug, Clone)]
pub struct NpyHeader {
    /// Raw string value of the `'descr'` key from the header dictionary.
    pub descr: String,
    /// Element type obtained by passing `descr` to `dtype_parse::parse_dtype_str`.
    pub dtype: DType,
    /// Byte order obtained from the same `dtype_parse::parse_dtype_str` call as `dtype`.
    pub endianness: Endianness,
    /// Value of the `'fortran_order'` key (`True`/`False` in the header).
    pub fortran_order: bool,
    /// Dimensions parsed from the `'shape'` tuple; empty for a `()` shape.
    pub shape: Vec<usize>,
    /// `(major, minor)` format version bytes read from the file.
    pub version: (u8, u8),
}
28
29pub fn read_header<R: std::io::Read>(reader: &mut R) -> FerrayResult<NpyHeader> {
33 let mut magic = [0u8; format::NPY_MAGIC_LEN];
35 reader
36 .read_exact(&mut magic)
37 .map_err(|e| FerrayError::io_error(format!("failed to read .npy magic: {e}")))?;
38
39 if magic != *format::NPY_MAGIC {
40 return Err(FerrayError::io_error(
41 "not a valid .npy file: bad magic number",
42 ));
43 }
44
45 let mut version = [0u8; 2];
47 reader
48 .read_exact(&mut version)
49 .map_err(|e| FerrayError::io_error(format!("failed to read .npy version: {e}")))?;
50
51 let major = version[0];
52 let minor = version[1];
53
54 if !matches!((major, minor), (1, 0) | (2, 0) | (3, 0)) {
55 return Err(FerrayError::io_error(format!(
56 "unsupported .npy format version {major}.{minor}"
57 )));
58 }
59
60 let header_len = if major == 1 {
62 let mut buf = [0u8; 2];
63 reader
64 .read_exact(&mut buf)
65 .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
66 u16::from_le_bytes(buf) as usize
67 } else {
68 let mut buf = [0u8; 4];
69 reader
70 .read_exact(&mut buf)
71 .map_err(|e| FerrayError::io_error(format!("failed to read header length: {e}")))?;
72 let raw_len = u32::from_le_bytes(buf) as usize;
73 const MAX_HEADER_LEN: usize = 1_048_576;
76 if raw_len > MAX_HEADER_LEN {
77 return Err(FerrayError::io_error(format!(
78 "header length {raw_len} exceeds maximum allowed size ({MAX_HEADER_LEN} bytes)"
79 )));
80 }
81 raw_len
82 };
83
84 let mut header_bytes = vec![0u8; header_len];
86 reader
87 .read_exact(&mut header_bytes)
88 .map_err(|e| FerrayError::io_error(format!("failed to read header: {e}")))?;
89
90 let header_str = std::str::from_utf8(&header_bytes)
91 .map_err(|e| FerrayError::io_error(format!("header is not valid UTF-8: {e}")))?;
92
93 let (descr, fortran_order, shape) = parse_header_dict(header_str)?;
95 let (dtype, endianness) = dtype_parse::parse_dtype_str(&descr)?;
96
97 Ok(NpyHeader {
98 descr,
99 dtype,
100 endianness,
101 fortran_order,
102 shape,
103 version: (major, minor),
104 })
105}
106
107pub fn write_header<W: std::io::Write>(
109 writer: &mut W,
110 dtype: DType,
111 shape: &[usize],
112 fortran_order: bool,
113) -> FerrayResult<()> {
114 let descr = dtype_parse::dtype_to_native_descr(dtype)?;
115 let fortran_str = if fortran_order { "True" } else { "False" };
116
117 let shape_str = format_shape(shape);
118
119 let dict =
120 format!("{{'descr': '{descr}', 'fortran_order': {fortran_str}, 'shape': {shape_str}, }}");
121
122 let preamble_v1 = format::NPY_MAGIC_LEN + 2 + 2; let padding_needed_v1 = compute_padding(preamble_v1 + dict.len() + 1); let total_header_v1 = dict.len() + padding_needed_v1 + 1;
128
129 if total_header_v1 <= format::MAX_HEADER_LEN_V1 {
130 writer.write_all(format::NPY_MAGIC)?;
132 writer.write_all(&[1, 0])?;
133 writer.write_all(&(total_header_v1 as u16).to_le_bytes())?;
134 writer.write_all(dict.as_bytes())?;
135 write_padding(writer, padding_needed_v1)?;
136 writer.write_all(b"\n")?;
137 } else {
138 let preamble_v2 = format::NPY_MAGIC_LEN + 2 + 4; let padding_needed_v2 = compute_padding(preamble_v2 + dict.len() + 1);
141 let total_header_v2 = dict.len() + padding_needed_v2 + 1;
142
143 writer.write_all(format::NPY_MAGIC)?;
144 writer.write_all(&[2, 0])?;
145 writer.write_all(&(total_header_v2 as u32).to_le_bytes())?;
146 writer.write_all(dict.as_bytes())?;
147 write_padding(writer, padding_needed_v2)?;
148 writer.write_all(b"\n")?;
149 }
150
151 Ok(())
152}
153
154pub fn compute_data_offset(version: (u8, u8), header_len: usize) -> usize {
157 let preamble = format::NPY_MAGIC_LEN + 2 + if version.0 == 1 { 2 } else { 4 };
158 preamble + header_len
159}
160
161fn compute_padding(current_total: usize) -> usize {
162 let remainder = current_total % format::HEADER_ALIGNMENT;
163 if remainder == 0 {
164 0
165 } else {
166 format::HEADER_ALIGNMENT - remainder
167 }
168}
169
170fn write_padding<W: std::io::Write>(writer: &mut W, count: usize) -> FerrayResult<()> {
171 for _ in 0..count {
172 writer.write_all(b" ")?;
173 }
174 Ok(())
175}
176
/// Renders `shape` as a Python tuple literal.
///
/// An empty shape becomes `()`, a single dimension becomes `(n,)` (with the
/// trailing comma Python requires for one-element tuples), and longer shapes
/// are joined with `", "`.
fn format_shape(shape: &[usize]) -> String {
    match shape {
        [] => String::from("()"),
        [only] => format!("({only},)"),
        dims => {
            let joined = dims
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<String>>()
                .join(", ");
            format!("({joined})")
        }
    }
}
187
188fn parse_header_dict(header: &str) -> FerrayResult<(String, bool, Vec<usize>)> {
193 let header = header.trim();
194
195 let inner = header
197 .strip_prefix('{')
198 .and_then(|s| s.strip_suffix('}'))
199 .ok_or_else(|| FerrayError::io_error("header dict missing braces"))?
200 .trim();
201
202 let descr = extract_string_value(inner, "descr")?;
203 let fortran_order = extract_bool_value(inner, "fortran_order")?;
204 let shape = extract_shape_value(inner, "shape")?;
205
206 Ok((descr, fortran_order, shape))
207}
208
209fn extract_string_value(dict_body: &str, key: &str) -> FerrayResult<String> {
211 let pattern = format!("'{key}':");
213 let pos = dict_body
214 .find(&pattern)
215 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
216
217 let after_key = &dict_body[pos + pattern.len()..].trim_start();
218
219 let quote_char = after_key
221 .as_bytes()
222 .first()
223 .ok_or_else(|| FerrayError::io_error(format!("missing value for key '{key}'")))?;
224
225 if *quote_char != b'\'' && *quote_char != b'"' {
226 return Err(FerrayError::io_error(format!(
227 "expected string value for key '{key}'"
228 )));
229 }
230
231 let qc = *quote_char as char;
232 let value_start = &after_key[1..];
233 let end = value_start
234 .find(qc)
235 .ok_or_else(|| FerrayError::io_error(format!("unterminated string for key '{key}'")))?;
236
237 Ok(value_start[..end].to_string())
238}
239
240fn extract_bool_value(dict_body: &str, key: &str) -> FerrayResult<bool> {
242 let pattern = format!("'{key}':");
243 let pos = dict_body
244 .find(&pattern)
245 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
246
247 let after_key = dict_body[pos + pattern.len()..].trim_start();
248
249 if after_key.starts_with("True") {
250 Ok(true)
251 } else if after_key.starts_with("False") {
252 Ok(false)
253 } else {
254 Err(FerrayError::io_error(format!(
255 "expected True/False for key '{key}'"
256 )))
257 }
258}
259
260fn extract_shape_value(dict_body: &str, key: &str) -> FerrayResult<Vec<usize>> {
262 let pattern = format!("'{key}':");
263 let pos = dict_body
264 .find(&pattern)
265 .ok_or_else(|| FerrayError::io_error(format!("header missing key '{key}'")))?;
266
267 let after_key = dict_body[pos + pattern.len()..].trim_start();
268
269 if !after_key.starts_with('(') {
271 return Err(FerrayError::io_error(format!(
272 "expected tuple for key '{key}'"
273 )));
274 }
275
276 let close = after_key
277 .find(')')
278 .ok_or_else(|| FerrayError::io_error(format!("unterminated tuple for key '{key}'")))?;
279
280 let tuple_inner = &after_key[1..close];
281 let tuple_inner = tuple_inner.trim();
282
283 if tuple_inner.is_empty() {
284 return Ok(vec![]);
285 }
286
287 let parts: FerrayResult<Vec<usize>> = tuple_inner
288 .split(',')
289 .filter(|s| !s.trim().is_empty())
290 .map(|s| {
291 s.trim()
292 .parse::<usize>()
293 .map_err(|e| FerrayError::io_error(format!("invalid shape dimension '{s}': {e}")))
294 })
295 .collect();
296
297 parts
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a header dictionary, panicking if it is rejected.
    fn parsed(header: &str) -> (String, bool, Vec<usize>) {
        parse_header_dict(header).unwrap()
    }

    /// Writes a C-order `F64` header for `shape` into a fresh buffer.
    fn written(shape: &[usize]) -> Vec<u8> {
        let mut bytes = Vec::new();
        write_header(&mut bytes, DType::F64, shape, false).unwrap();
        bytes
    }

    #[test]
    fn parse_simple_header() {
        let (descr, fortran, shape) =
            parsed("{'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }");
        assert_eq!(descr, "<f8");
        assert!(!fortran);
        assert_eq!(shape, vec![3, 4]);
    }

    #[test]
    fn parse_1d_header() {
        let (descr, fortran, shape) =
            parsed("{'descr': '<i4', 'fortran_order': False, 'shape': (10,), }");
        assert_eq!(descr, "<i4");
        assert!(!fortran);
        assert_eq!(shape, vec![10]);
    }

    #[test]
    fn parse_scalar_header() {
        let (descr, fortran, shape) =
            parsed("{'descr': '<f4', 'fortran_order': False, 'shape': (), }");
        assert_eq!(descr, "<f4");
        assert!(!fortran);
        assert!(shape.is_empty());
    }

    #[test]
    fn parse_fortran_order() {
        let (_, fortran, _) =
            parsed("{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }");
        assert!(fortran);
    }

    #[test]
    fn format_shape_empty() {
        assert_eq!(format_shape(&[]), "()");
    }

    #[test]
    fn format_shape_1d() {
        assert_eq!(format_shape(&[5]), "(5,)");
    }

    #[test]
    fn format_shape_2d() {
        assert_eq!(format_shape(&[3, 4]), "(3, 4)");
    }

    #[test]
    fn write_read_roundtrip() {
        let mut cursor = std::io::Cursor::new(written(&[3, 4]));
        let header = read_header(&mut cursor).unwrap();

        assert_eq!(header.dtype, DType::F64);
        assert_eq!(header.shape, vec![3, 4]);
        assert!(!header.fortran_order);
    }

    #[test]
    fn header_alignment() {
        let shapes: [&[usize]; 3] = [&[3, 4], &[100, 200, 300], &[1]];
        for shape in shapes.iter().copied() {
            let bytes = written(shape);
            assert_eq!(
                bytes.len() % format::HEADER_ALIGNMENT,
                0,
                "header not aligned for shape {shape:?}"
            );
        }
    }
}