1use anyhow::{Context, Result, bail};
28use std::{collections::HashMap, io::Read, path::Path};
29use zip::ZipArchive;
30
31pub fn parse_npy(data: &[u8]) -> Result<(Vec<usize>, Vec<f32>)> {
38 if data.len() < 10 || &data[..6] != b"\x93NUMPY" {
40 bail!("Not a valid NPY file (bad magic)");
41 }
42
43 let major = data[6];
44 let minor = data[7];
45
46 let (header_len, header_start) = match (major, minor) {
48 (1, _) => {
49 let len = u16::from_le_bytes([data[8], data[9]]) as usize;
50 (len, 10)
51 }
52 (2, _) => {
53 if data.len() < 12 {
54 bail!("NPY v2 file too short");
55 }
56 let len = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
57 (len, 12)
58 }
59 _ => bail!("Unsupported NPY version {}.{}", major, minor),
60 };
61
62 let header_end = header_start + header_len;
63 if data.len() < header_end {
64 bail!("NPY file truncated in header");
65 }
66 let header = std::str::from_utf8(&data[header_start..header_end])
67 .context("NPY header is not valid UTF-8")?;
68
69 let dtype = extract_header_field(header, "descr").context("NPY header missing 'descr'")?;
71 let dtype = dtype.trim().trim_matches('\'').trim_matches('"');
72
73 let is_f32 = matches!(dtype, "<f4" | "=f4" | "|f4" | ">f4");
75 if !is_f32 {
76 bail!("Unsupported dtype '{}' — only float32 is supported", dtype);
77 }
78 let big_endian = dtype.starts_with('>');
79
80 let fortran = extract_header_field(header, "fortran_order")
82 .unwrap_or("False")
83 .trim()
84 .to_ascii_lowercase();
85 if fortran == "true" {
86 bail!("Fortran-order arrays are not supported");
87 }
88
89 let shape_str = extract_header_field(header, "shape").context("NPY header missing 'shape'")?;
91 let shape = parse_shape(shape_str.trim())?;
92
93 let n_elements: usize = shape.iter().product();
95
96 let data_bytes = &data[header_end..];
98 if data_bytes.len() < n_elements * 4 {
99 bail!(
100 "NPY data section too short: expected {} bytes, got {}",
101 n_elements * 4,
102 data_bytes.len()
103 );
104 }
105
106 let values: Vec<f32> = data_bytes[..n_elements * 4]
108 .chunks_exact(4)
109 .map(|b| {
110 let arr = [b[0], b[1], b[2], b[3]];
111 if big_endian {
112 f32::from_be_bytes(arr)
113 } else {
114 f32::from_le_bytes(arr)
115 }
116 })
117 .collect();
118
119 Ok((shape, values))
120}
121
/// Finds `field` in an NPY header dictionary literal and returns its raw value text.
///
/// The key is searched as `'field':` first and as `"field":` only when the
/// single-quoted form does not occur. Leading whitespace after the colon is
/// skipped, then the value is cut according to its first character:
///
/// * `(` — everything up to and including the first `)`;
/// * `'` or `"` — the text between that quote and the next identical quote;
/// * anything else — the text up to the first `,` or `}` (or the end of the
///   header), trimmed.
///
/// Returns `None` when the key is absent or a tuple/quoted value is not closed.
fn extract_header_field<'a>(header: &'a str, field: &str) -> Option<&'a str> {
    let single_quoted = format!("'{}':", field);
    let double_quoted = format!("\"{}\":", field);

    let value_offset = match header.find(single_quoted.as_str()) {
        Some(pos) => pos + single_quoted.len(),
        None => header.find(double_quoted.as_str())? + double_quoted.len(),
    };
    let value = header[value_offset..].trim_start();

    match value.chars().next() {
        Some('(') => {
            let close = value.find(')')?;
            Some(&value[..=close])
        }
        Some(quote @ ('\'' | '"')) => {
            let body = &value[1..];
            body.find(quote).map(|close| &body[..close])
        }
        _ => {
            let stop = value.find([',', '}']).unwrap_or(value.len());
            Some(value[..stop].trim())
        }
    }
}
154
155fn parse_shape(s: &str) -> Result<Vec<usize>> {
157 let inner = s.trim_start_matches('(').trim_end_matches(')');
158 if inner.trim().is_empty() {
159 return Ok(vec![]);
160 }
161 inner
162 .split(',')
163 .map(|t| t.trim())
164 .filter(|t| !t.is_empty())
165 .map(|t| {
166 t.parse::<usize>()
167 .with_context(|| format!("Bad shape dim: '{}'", t))
168 })
169 .collect()
170}
171
/// A float32 array stored as a flat element buffer plus its shape.
///
/// `row` assumes `data` is laid out in C (row-major) order, which is what
/// `parse_npy` produces since it rejects Fortran-ordered arrays.
#[derive(Debug, Clone, PartialEq)]
pub struct NpyArray {
    /// Dimension sizes, outermost first; empty for a 0-d array.
    pub shape: Vec<usize>,
    /// Flat element buffer.
    pub data: Vec<f32>,
}

impl NpyArray {
    /// Size of the first dimension, or 0 when the shape is empty.
    pub fn nrows(&self) -> usize {
        self.shape.first().copied().unwrap_or(0)
    }

    /// Size of the second dimension, or 1 when the array has fewer than two dimensions.
    pub fn ncols(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(1)
    }

    /// Returns the `i`-th slice along the first dimension as a flat slice.
    ///
    /// The slice length is the product of all dimensions after the first, so
    /// it equals `ncols()` for 1-D and 2-D arrays. For arrays with more
    /// dimensions it covers the whole flattened sub-array; using `ncols()`
    /// alone as the stride would return the wrong elements there.
    ///
    /// # Panics
    ///
    /// Panics if the requested range lies outside `data`.
    pub fn row(&self, i: usize) -> &[f32] {
        let stride: usize = self.shape.iter().skip(1).product();
        &self.data[i * stride..(i + 1) * stride]
    }
}
199
200pub fn load_npz(path: &Path) -> Result<HashMap<String, NpyArray>> {
202 let file = std::fs::File::open(path)
203 .with_context(|| format!("Cannot open NPZ file: {}", path.display()))?;
204 let mut archive = ZipArchive::new(file)
205 .with_context(|| format!("Cannot open ZIP archive: {}", path.display()))?;
206
207 let mut arrays = HashMap::new();
208
209 for i in 0..archive.len() {
210 let mut entry = archive.by_index(i).context("Failed to read ZIP entry")?;
211 let name = entry.name().trim_end_matches(".npy").to_string();
212
213 let mut buf = Vec::with_capacity(entry.size() as usize);
214 entry
215 .read_to_end(&mut buf)
216 .context("Failed to read NPY entry")?;
217
218 let (shape, data) =
219 parse_npy(&buf).with_context(|| format!("Failed to parse NPY entry '{}'", name))?;
220
221 arrays.insert(name, NpyArray { shape, data });
222 }
223
224 Ok(arrays)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an NPY v1.0 byte buffer with a little-endian float32 header for
    /// `shape`, followed by `values` encoded as little-endian bytes.
    fn make_npy(shape: &[usize], values: &[f32]) -> Vec<u8> {
        let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
        let mut header = format!(
            "{{'descr': '<f4', 'fortran_order': False, 'shape': ({},), }}",
            dims.join(", ")
        );

        // Space-pad so the header, including its closing newline, is a
        // multiple of 64 bytes long.
        let unpadded = header.len() + 1;
        let padding = unpadded.div_ceil(64) * 64 - unpadded;
        header.push_str(&" ".repeat(padding));
        header.push('\n');

        let mut buf = Vec::new();
        buf.extend_from_slice(b"\x93NUMPY");
        // Format version 1.0.
        buf.extend_from_slice(&[1, 0]);
        buf.extend_from_slice(&(header.len() as u16).to_le_bytes());
        buf.extend_from_slice(header.as_bytes());
        buf.extend(values.iter().flat_map(|v| v.to_le_bytes()));
        buf
    }

    #[test]
    fn test_parse_npy_1d() {
        let expected = vec![1.0f32, 2.0, 3.0];
        let bytes = make_npy(&[3], &expected);
        let (shape, data) = parse_npy(&bytes).unwrap();
        assert_eq!(shape, vec![3]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_parse_npy_2d() {
        let expected: Vec<f32> = (0..6).map(|x| x as f32).collect();
        let bytes = make_npy(&[2, 3], &expected);
        let (shape, data) = parse_npy(&bytes).unwrap();
        assert_eq!(shape, vec![2, 3]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_npy_array_row() {
        let elems: Vec<f32> = (0..6).map(|x| x as f32).collect();
        let bytes = make_npy(&[2, 3], &elems);
        let (shape, data) = parse_npy(&bytes).unwrap();
        let array = NpyArray { shape, data };
        assert_eq!(array.row(0), &[0.0, 1.0, 2.0]);
        assert_eq!(array.row(1), &[3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_bad_magic() {
        // Too short and wrong magic: must be rejected, not panic.
        assert!(parse_npy(b"NOTANPY").is_err());
    }
}