use anyhow::{bail, Context, Result};
use std::{
collections::HashMap,
io::Read,
path::Path,
};
use zip::ZipArchive;
/// Parses an in-memory `.npy` file containing a C-order `float32` array.
///
/// Format versions 1.x (2-byte header length) and 2.x / 3.x (4-byte header
/// length) are accepted. The header dictionary is scanned for `descr`,
/// `fortran_order` and `shape`; element bytes are decoded as little-endian
/// unless `descr` starts with `>`.
///
/// Returns `(shape, values)` where `values` holds exactly the product of the
/// shape dimensions (a single element for an empty shape). Bytes that follow
/// the last element are ignored.
///
/// # Errors
///
/// Fails when the magic bytes or version are wrong, the header is truncated,
/// not UTF-8, or lacks `descr` / `shape`, the dtype is not a 4-byte float, the
/// array is Fortran-ordered, the element count does not fit in `usize`, or the
/// data section is shorter than the shape requires.
pub fn parse_npy(data: &[u8]) -> Result<(Vec<usize>, Vec<f32>)> {
    if data.len() < 10 || !data.starts_with(b"\x93NUMPY") {
        bail!("Not a valid NPY file (bad magic)");
    }
    let major = data[6];
    let minor = data[7];
    let (header_len, header_start) = match major {
        1 => (u16::from_le_bytes([data[8], data[9]]) as usize, 10),
        // Version 3.x keeps the 4-byte length field of 2.x; it only switches
        // the header text encoding to UTF-8, which is what we decode anyway.
        2 | 3 => {
            if data.len() < 12 {
                bail!("NPY v{} file too short", major);
            }
            let len = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
            (len, 12)
        }
        _ => bail!("Unsupported NPY version {}.{}", major, minor),
    };
    // `data.len() >= header_start` is guaranteed by the length checks above,
    // so comparing via subtraction avoids overflowing `header_start + header_len`.
    if data.len() - header_start < header_len {
        bail!("NPY file truncated in header");
    }
    let header_end = header_start + header_len;
    let header = std::str::from_utf8(&data[header_start..header_end])
        .context("NPY header is not valid UTF-8")?;

    let dtype = extract_header_field(header, "descr")
        .context("NPY header missing 'descr'")?;
    let dtype = dtype.trim().trim_matches('\'').trim_matches('"');
    if !matches!(dtype, "<f4" | "=f4" | "|f4" | ">f4") {
        bail!("Unsupported dtype '{}' — only float32 is supported", dtype);
    }
    let big_endian = dtype.starts_with('>');

    // A missing `fortran_order` key is treated as C order.
    let fortran = extract_header_field(header, "fortran_order").unwrap_or("False");
    if fortran.trim().eq_ignore_ascii_case("true") {
        bail!("Fortran-order arrays are not supported");
    }

    let shape_str = extract_header_field(header, "shape")
        .context("NPY header missing 'shape'")?;
    let shape = parse_shape(shape_str.trim())?;

    // The shape comes straight from the file, so the element and byte counts
    // must be computed with checked arithmetic: a wrapped product could slip
    // past the length check below and yield a shape that does not match the
    // returned data.
    let n_elements = if shape.contains(&0) {
        // Any zero dimension means an empty array, however large the others are.
        0
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .context("NPY shape is too large: element count overflows usize")?
    };
    let n_bytes = n_elements
        .checked_mul(4)
        .context("NPY shape is too large: byte count overflows usize")?;

    let data_bytes = &data[header_end..];
    if data_bytes.len() < n_bytes {
        bail!(
            "NPY data section too short: expected {} bytes, got {}",
            n_bytes,
            data_bytes.len()
        );
    }

    // Pick the byte-order decoder once instead of branching per element.
    let decode: fn([u8; 4]) -> f32 = if big_endian {
        f32::from_be_bytes
    } else {
        f32::from_le_bytes
    };
    let values: Vec<f32> = data_bytes[..n_bytes]
        .chunks_exact(4)
        .map(|b| decode([b[0], b[1], b[2], b[3]]))
        .collect();
    Ok((shape, values))
}
/// Looks up `field` in an NPY header dictionary and returns its raw value text.
///
/// The key is searched for as `'field':` first and as `"field":` second. The
/// value that follows (after leading whitespace) is returned as:
///
/// * the text up to and including the first `)` when it starts with `(`,
/// * the text between the quotes when it starts with `'` or `"`,
/// * otherwise the trimmed text up to the next `,` or `}` (or the end of the
///   header when neither occurs).
///
/// Returns `None` when the key is absent or when a tuple / quoted value is not
/// terminated.
fn extract_header_field<'a>(header: &'a str, field: &str) -> Option<&'a str> {
    // Candidate spellings of the key, in lookup priority order.
    let keys = [format!("'{}':", field), format!("\"{}\":", field)];
    let value_offset = keys
        .iter()
        .find_map(|key| header.find(key.as_str()).map(|at| at + key.len()))?;

    let tail = header[value_offset..].trim_start();
    match tail.chars().next() {
        // Tuple value: keep the parentheses in the returned slice.
        Some('(') => tail.find(')').map(|close| &tail[..=close]),
        // Quoted value: return the contents without the surrounding quotes.
        Some(quote) if quote == '\'' || quote == '"' => {
            let body = &tail[1..];
            body.find(quote).map(|close| &body[..close])
        }
        // Bare token such as `False`.
        _ => {
            let stop = tail
                .find(|c: char| c == ',' || c == '}')
                .unwrap_or(tail.len());
            Some(tail[..stop].trim())
        }
    }
}
/// Converts the textual shape tuple from an NPY header into dimension sizes.
///
/// Leading `(` and trailing `)` characters are stripped, the remainder is
/// split on commas, and blank components are skipped — so `(3,)` yields `[3]`
/// and `()` yields an empty vector.
///
/// # Errors
///
/// Fails on the first component that does not parse as `usize`.
fn parse_shape(s: &str) -> Result<Vec<usize>> {
    let body = s.trim_start_matches('(').trim_end_matches(')');
    let mut dims = Vec::new();
    for token in body.split(',') {
        let token = token.trim();
        // A trailing comma (or an empty tuple) leaves a blank component behind.
        if token.is_empty() {
            continue;
        }
        let dim = token
            .parse::<usize>()
            .with_context(|| format!("Bad shape dim: '{}'", token))?;
        dims.push(dim);
    }
    Ok(dims)
}
/// An `f32` array: a list of dimension sizes plus a flat element buffer.
///
/// The accessors below treat `data` as consecutive rows of `ncols()` elements
/// each.
#[derive(Debug, Clone, PartialEq)]
pub struct NpyArray {
    /// Size of each dimension; empty for a zero-dimensional array.
    pub shape: Vec<usize>,
    /// Elements stored contiguously in one flat buffer.
    pub data: Vec<f32>,
}

impl NpyArray {
    /// Returns the size of the first dimension, or `0` when `shape` is empty.
    pub fn nrows(&self) -> usize {
        self.shape.first().copied().unwrap_or(0)
    }

    /// Returns the size of the second dimension, or `1` when `shape` has fewer
    /// than two entries.
    pub fn ncols(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(1)
    }

    /// Returns the `i`-th run of `ncols()` consecutive elements of `data`.
    ///
    /// Only the second dimension is taken into account, so for shapes with
    /// more than two entries this is not a full slice along the first axis.
    ///
    /// # Panics
    ///
    /// Panics when the requested range lies outside `data`.
    pub fn row(&self, i: usize) -> &[f32] {
        let ncols = self.ncols();
        &self.data[i * ncols..(i + 1) * ncols]
    }
}
/// Loads every entry of an `.npz` archive as an [`NpyArray`].
///
/// The archive at `path` is opened as a ZIP file and each entry is read fully
/// into memory and decoded with [`parse_npy`]. Arrays are keyed by the entry
/// name with a single trailing `.npy` extension removed.
///
/// # Errors
///
/// Fails when the file cannot be opened, is not a readable ZIP archive, or any
/// entry cannot be read or parsed as an NPY array.
pub fn load_npz(path: &Path) -> Result<HashMap<String, NpyArray>> {
    // The uncompressed size recorded in the archive is untrusted metadata, so
    // it is used only as a bounded pre-allocation hint; `read_to_end` grows
    // the buffer as needed for larger entries.
    const MAX_PREALLOC_BYTES: u64 = 16 * 1024 * 1024;

    let file = std::fs::File::open(path)
        .with_context(|| format!("Cannot open NPZ file: {}", path.display()))?;
    let mut archive = ZipArchive::new(file)
        .with_context(|| format!("Cannot open ZIP archive: {}", path.display()))?;
    let mut arrays = HashMap::new();
    for i in 0..archive.len() {
        let mut entry = archive.by_index(i).context("Failed to read ZIP entry")?;
        // Remove exactly one `.npy` suffix: `trim_end_matches` would strip
        // repeated suffixes and map e.g. `a.npy.npy` onto the key `a`.
        let name = {
            let raw = entry.name();
            raw.strip_suffix(".npy").unwrap_or(raw).to_string()
        };
        // The value is at most MAX_PREALLOC_BYTES, so the cast cannot truncate.
        let capacity = entry.size().min(MAX_PREALLOC_BYTES) as usize;
        let mut buf = Vec::with_capacity(capacity);
        entry.read_to_end(&mut buf).context("Failed to read NPY entry")?;
        let (shape, data) = parse_npy(&buf)
            .with_context(|| format!("Failed to parse NPY entry '{}'", name))?;
        arrays.insert(name, NpyArray { shape, data });
    }
    Ok(arrays)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a version 1.0 NPY byte buffer holding little-endian `f32` values.
    ///
    /// The header declares `<f4`, C order and the given `shape` (written with a
    /// trailing comma, e.g. `(2, 3,)`), and is space-padded so that it is a
    /// multiple of 64 bytes long including its final newline.
    fn make_npy(shape: &[usize], values: &[f32]) -> Vec<u8> {
        let dims: Vec<String> = shape.iter().map(|dim| dim.to_string()).collect();
        let mut header = format!(
            "{{'descr': '<f4', 'fortran_order': False, 'shape': ({},), }}",
            dims.join(", ")
        );
        // Account for the newline terminator when computing the padding.
        let with_newline = header.len() + 1;
        let padding = (64 - with_newline % 64) % 64;
        header.push_str(&" ".repeat(padding));
        header.push('\n');

        // Magic string followed by the version bytes (major 1, minor 0).
        let mut bytes = b"\x93NUMPY\x01\x00".to_vec();
        bytes.extend_from_slice(&(header.len() as u16).to_le_bytes());
        bytes.extend_from_slice(header.as_bytes());
        values
            .iter()
            .for_each(|value| bytes.extend_from_slice(&value.to_le_bytes()));
        bytes
    }

    #[test]
    fn test_parse_npy_1d() {
        let expected = vec![1.0f32, 2.0, 3.0];
        let encoded = make_npy(&[3], &expected);
        let (shape, data) = parse_npy(&encoded).unwrap();
        assert_eq!(shape, vec![3]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_parse_npy_2d() {
        let expected: Vec<f32> = (0u8..6).map(f32::from).collect();
        let encoded = make_npy(&[2, 3], &expected);
        let (shape, data) = parse_npy(&encoded).unwrap();
        assert_eq!(shape, vec![2, 3]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_npy_array_row() {
        let expected: Vec<f32> = (0u8..6).map(f32::from).collect();
        let encoded = make_npy(&[2, 3], &expected);
        let (shape, data) = parse_npy(&encoded).unwrap();
        let array = NpyArray { shape, data };
        assert_eq!(array.row(0), &[0.0f32, 1.0, 2.0]);
        assert_eq!(array.row(1), &[3.0f32, 4.0, 5.0]);
    }

    #[test]
    fn test_bad_magic() {
        // Too short and lacking the `\x93NUMPY` magic prefix.
        assert!(parse_npy(b"NOTANPY").is_err());
    }
}