use std::collections::HashMap;
use std::fmt;
use std::io::{Read, Seek};
use std::path::Path;
use crate::error::AnamnesisError;
use crate::limits::Budget;
use crate::parse::utils::{byteswap_inplace, PREALLOC_SOFT_CAP};
use crate::ParseLimits;
/// Magic prefix of every `.npy` stream: byte `0x93` followed by ASCII `NUMPY`.
const NPY_MAGIC: &[u8; 6] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];
/// Largest NPY header length (in bytes) that `parse_npy_header` will allocate: 1 MiB.
const NPY_MAX_HEADER_BYTES: usize = 1024 * 1024;
/// Largest single array payload (in bytes) that `read_array_data` will allocate: 8 GiB.
const NPZ_MAX_ARRAY_BYTES: u64 = 8 * 1024 * 1024 * 1024;
/// Element type of an array stored in an NPZ archive.
///
/// Decoded from the NPY `descr` type code by `parse_descr`; the code each
/// variant corresponds to is noted on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NpzDtype {
    /// Boolean, one byte per element (`b1`).
    Bool,
    /// Unsigned 8-bit integer (`u1`).
    U8,
    /// Signed 8-bit integer (`i1`).
    I8,
    /// Unsigned 16-bit integer (`u2`).
    U16,
    /// Signed 16-bit integer (`i2`).
    I16,
    /// Unsigned 32-bit integer (`u4`).
    U32,
    /// Signed 32-bit integer (`i4`).
    I32,
    /// Unsigned 64-bit integer (`u8`).
    U64,
    /// Signed 64-bit integer (`i8`).
    I64,
    /// IEEE half-precision float (`f2`).
    F16,
    /// bfloat16, read from a two-byte void descriptor (`V2`).
    BF16,
    /// IEEE single-precision float (`f4`).
    F32,
    /// IEEE double-precision float (`f8`).
    F64,
}

impl NpzDtype {
    /// Size of one element of this type, in bytes.
    #[must_use]
    pub const fn byte_size(self) -> usize {
        match self {
            Self::Bool | Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 | Self::BF16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
        }
    }
}

impl fmt::Display for NpzDtype {
    /// Writes the upper-case dtype name (`F32`, `BF16`, `BOOL`, ...).
    ///
    /// Goes through [`fmt::Formatter::pad`] so width, fill, alignment and
    /// precision flags (e.g. `{:>6}` in a table) are honoured. Without flags
    /// the output is exactly the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Bool => "BOOL",
            Self::U8 => "U8",
            Self::I8 => "I8",
            Self::U16 => "U16",
            Self::I16 => "I16",
            Self::U32 => "U32",
            Self::I32 => "I32",
            Self::U64 => "U64",
            Self::I64 => "I64",
            Self::F16 => "F16",
            Self::BF16 => "BF16",
            Self::F32 => "F32",
            Self::F64 => "F64",
        };
        f.pad(name)
    }
}
/// One array decoded from an NPZ archive by `parse_npz` / `parse_npz_with_limits`.
#[derive(Debug, Clone)]
pub struct NpzTensor {
    /// Archive member name with the trailing `.npy` removed.
    pub name: String,
    /// Array dimensions; empty for a scalar `()` shape. Fortran-ordered
    /// arrays are rejected by the parser, so this describes C (row-major) data.
    pub shape: Vec<usize>,
    /// Element type decoded from the NPY `descr` field.
    pub dtype: NpzDtype,
    /// Raw element bytes. Multi-byte elements from a big-endian (`>`) source
    /// are byte-swapped while reading, so e.g. `>f4` bytes `3F 80 00 00`
    /// arrive here as `00 00 80 3F`.
    pub data: Vec<u8>,
}
/// Fields decoded from an NPY header dictionary by `parse_npy_header`.
struct NpyHeader {
    /// Element type parsed from the `descr` field.
    dtype: NpzDtype,
    /// `true` when the `descr` byte-order character is `>`.
    big_endian: bool,
    /// Whether `fortran_order` is `True`; `false` when the key is absent.
    fortran_order: bool,
    /// Dimensions from the `shape` tuple; empty for a scalar `()`.
    shape: Vec<usize>,
}
/// Reads the NPY preamble and header dictionary from `reader`.
///
/// Expected layout: the 6-byte [`NPY_MAGIC`], a major and a minor version
/// byte, a little-endian header length (`u16` for major version 1, `u32` for
/// major versions 2 and 3), then that many bytes of header text. On success
/// the reader is positioned just past the header, at the first byte of array
/// data.
///
/// The header length is capped at [`NPY_MAX_HEADER_BYTES`] and charged to
/// `budget` before the header buffer is allocated.
///
/// # Errors
///
/// - [`AnamnesisError::Parse`] for short reads, bad magic, an over-cap header
///   length, non-UTF-8 header text, or a malformed `descr`/`shape` field.
/// - [`AnamnesisError::Unsupported`] for a major version other than 1–3 or an
///   unsupported dtype descriptor.
/// - Any error returned by `budget.charge_alloc`.
fn parse_npy_header(reader: &mut impl Read, budget: &mut Budget) -> crate::Result<NpyHeader> {
    let mut preamble = [0u8; 8];
    reader
        .read_exact(&mut preamble)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("NPY preamble read failed: {e}"),
        })?;
    if !preamble.starts_with(NPY_MAGIC) {
        return Err(AnamnesisError::Parse {
            reason: "invalid NPY magic bytes".into(),
        });
    }
    // Byte 6 is the major version; byte 7 (minor version) is not inspected.
    let [.., major, _minor] = preamble;

    let header_len: usize = match major {
        1 => {
            let mut buf = [0u8; 2];
            reader
                .read_exact(&mut buf)
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("NPY v1 header length read failed: {e}"),
                })?;
            usize::from(u16::from_le_bytes(buf))
        }
        2 | 3 => {
            let mut buf = [0u8; 4];
            reader
                .read_exact(&mut buf)
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("NPY v{major} header length read failed: {e}"),
                })?;
            // Saturate instead of truncating where `usize` is narrower than
            // 32 bits, so an oversized length is still caught by the cap below.
            usize::try_from(u32::from_le_bytes(buf)).unwrap_or(usize::MAX)
        }
        _ => {
            return Err(AnamnesisError::Unsupported {
                format: "NPY".into(),
                detail: format!("unsupported NPY version {major}"),
            });
        }
    };
    if header_len > NPY_MAX_HEADER_BYTES {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "NPY header length {header_len} bytes exceeds the \
                 {NPY_MAX_HEADER_BYTES}-byte cap"
            ),
        });
    }
    // Lossless here: `header_len` is at most NPY_MAX_HEADER_BYTES.
    let header_charge = u64::try_from(header_len).unwrap_or(u64::MAX);
    budget.charge_alloc(header_charge, "NPY header")?;

    let mut header_buf = vec![0u8; header_len];
    reader
        .read_exact(&mut header_buf)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("NPY header data read failed: {e}"),
        })?;
    let header_str = std::str::from_utf8(&header_buf).map_err(|e| AnamnesisError::Parse {
        reason: format!("NPY header is not valid UTF-8: {e}"),
    })?;

    let (dtype, big_endian) = extract_descr(header_str)?;
    let fortran_order = extract_fortran_order(header_str);
    let shape = extract_shape(header_str)?;
    Ok(NpyHeader {
        dtype,
        big_endian,
        fortran_order,
        shape,
    })
}
/// Pulls the quoted `descr` value out of an NPY header dictionary and decodes
/// it with `parse_descr`.
///
/// The key may be single- or double-quoted, and the value may use either
/// quote character independently of the key. Whitespace between the `:` and
/// the opening quote is skipped.
///
/// # Errors
///
/// [`AnamnesisError::Parse`] when the key, its `:`, or a properly quoted value
/// is missing; otherwise whatever `parse_descr` returns.
fn extract_descr(header: &str) -> crate::Result<(NpzDtype, bool)> {
    let parse_err = |reason: &str| AnamnesisError::Parse {
        reason: reason.into(),
    };

    let key_pos = header
        .find("'descr'")
        .or_else(|| header.find("\"descr\""))
        .ok_or_else(|| parse_err("NPY header missing 'descr' field"))?;
    let colon_offset = header
        .get(key_pos..)
        .and_then(|rest| rest.find(':'))
        .ok_or_else(|| parse_err("NPY header 'descr' field has no value"))?;
    let value = header
        .get(key_pos + colon_offset + 1..)
        .ok_or_else(|| parse_err("NPY header truncated after 'descr'"))?
        .trim_start();

    let quote = match value.as_bytes().first() {
        Some(b'\'') => '\'',
        Some(b'"') => '"',
        _ => return Err(parse_err("NPY header 'descr' value not quoted")),
    };
    let quoted_body = value
        .get(1..)
        .ok_or_else(|| parse_err("NPY header 'descr' value truncated after opening quote"))?;
    let close = quoted_body
        .find(quote)
        .ok_or_else(|| parse_err("NPY header 'descr' value missing closing quote"))?;
    let descr = quoted_body
        .get(..close)
        .ok_or_else(|| parse_err("NPY header 'descr' extraction failed"))?;

    parse_descr(descr)
}
/// Maps an NPY `descr` string such as `<f4` to an [`NpzDtype`] and a
/// big-endian flag.
///
/// The first character is the byte-order marker. `>` means big-endian.
/// `<`, `|` (byte order not applicable) and `=` (native) are treated as
/// little-endian. Any other marker is rejected instead of being silently
/// read as little-endian.
///
/// The rest must be a supported type code: `b1`, `u1`/`u2`/`u4`/`u8`,
/// `i1`/`i2`/`i4`/`i8`, `f2`/`f4`/`f8`, or `V2` (interpreted as bfloat16).
///
/// # Errors
///
/// [`AnamnesisError::Unsupported`] when:
/// - the descriptor is shorter than two bytes,
/// - its first character is multi-byte,
/// - the byte-order marker is not one of `<`, `>`, `|`, `=`, or
/// - the type code is not supported.
fn parse_descr(descr: &str) -> crate::Result<(NpzDtype, bool)> {
    let unsupported = |detail: String| AnamnesisError::Unsupported {
        format: "NPY".into(),
        detail,
    };
    if descr.len() < 2 {
        return Err(unsupported(format!(
            "dtype descriptor too short: '{descr}'"
        )));
    }
    // `get(1..)` is `None` when the first character is multi-byte, which turns
    // non-ASCII input into an error instead of a slicing panic.
    let Some(type_str) = descr.get(1..) else {
        return Err(unsupported(format!(
            "dtype descriptor is not ASCII: '{descr}'"
        )));
    };
    let big_endian = match descr.as_bytes().first() {
        Some(b'>') => true,
        // NOTE(review): '=' means host-native in NumPy; it is treated as
        // little-endian here. NumPy normally writes an explicit '<' or '>'
        // into .npy headers, so '=' is not expected in practice.
        Some(b'<' | b'|' | b'=') => false,
        _ => {
            return Err(unsupported(format!(
                "unsupported byte order in dtype descriptor '{descr}'"
            )));
        }
    };
    let dtype = match type_str {
        "b1" => NpzDtype::Bool,
        "u1" => NpzDtype::U8,
        "u2" => NpzDtype::U16,
        "u4" => NpzDtype::U32,
        "u8" => NpzDtype::U64,
        "i1" => NpzDtype::I8,
        "i2" => NpzDtype::I16,
        "i4" => NpzDtype::I32,
        "i8" => NpzDtype::I64,
        "f2" => NpzDtype::F16,
        "f4" => NpzDtype::F32,
        "f8" => NpzDtype::F64,
        "V2" => NpzDtype::BF16,
        _ => {
            return Err(unsupported(format!(
                "unsupported dtype descriptor '{descr}'"
            )));
        }
    };
    Ok((dtype, big_endian))
}
/// Reports whether an NPY header dictionary declares `fortran_order: True`.
///
/// The key is looked up in single-quoted form first, then double-quoted.
/// The text after the first `:` following the key has leading whitespace
/// skipped and must begin with `True`. A missing key or missing `:` yields
/// `false`.
fn extract_fortran_order(header: &str) -> bool {
    let key_pos = match header.find("'fortran_order'") {
        Some(pos) => pos,
        None => match header.find("\"fortran_order\"") {
            Some(pos) => pos,
            None => return false,
        },
    };
    let Some(rest) = header.get(key_pos..) else {
        return false;
    };
    let Some(colon) = rest.find(':') else {
        return false;
    };
    match rest.get(colon + 1..) {
        Some(value) => value.trim_start().starts_with("True"),
        None => false,
    }
}
/// Extracts the `shape` tuple from an NPY header dictionary.
///
/// Locates the `shape` key (single- or double-quoted), takes the text between
/// the first `(` after it and the next `)`, and parses each comma-separated,
/// non-empty item as a `usize`. `()` yields an empty vector (a scalar), and a
/// trailing comma such as `(3,)` is accepted. One trailing `L` per item
/// (a Python 2 `long` literal, e.g. `(10L, 20L)`) is tolerated.
///
/// # Errors
///
/// [`AnamnesisError::Parse`] when the key or either parenthesis is missing, or
/// when an item is not a non-negative integer.
fn extract_shape(header: &str) -> crate::Result<Vec<usize>> {
    let shape_start = header
        .find("'shape'")
        .or_else(|| header.find("\"shape\""))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "NPY header missing 'shape' field".into(),
        })?;
    let after_key = header
        .get(shape_start..)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "NPY header truncated at 'shape'".into(),
        })?;
    let paren_open = after_key.find('(').ok_or_else(|| AnamnesisError::Parse {
        reason: "NPY header 'shape' value missing opening paren".into(),
    })?;
    let inner_start = after_key
        .get(paren_open + 1..)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "NPY header 'shape' truncated after paren".into(),
        })?;
    let paren_close = inner_start.find(')').ok_or_else(|| AnamnesisError::Parse {
        reason: "NPY header 'shape' value missing closing paren".into(),
    })?;
    let inner = inner_start
        .get(..paren_close)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "NPY header 'shape' extraction failed".into(),
        })?;
    inner
        .split(',')
        .map(str::trim)
        .filter(|dim| !dim.is_empty())
        .map(|dim| {
            // Headers written under Python 2 may spell dimensions as `long`
            // literals (`10L`); drop a single trailing `L` before parsing.
            let digits = dim.strip_suffix('L').unwrap_or(dim);
            digits.parse::<usize>().map_err(|e| AnamnesisError::Parse {
                reason: format!("NPY shape dimension parse error: {e}"),
            })
        })
        .collect()
}
/// Reads the array payload that follows an NPY header.
///
/// The payload size is the product of `header.shape` times the element size.
/// An empty (scalar) shape counts as one element. Before allocating, the size
/// is checked against the ZIP entry's declared uncompressed size and against
/// [`NPZ_MAX_ARRAY_BYTES`], then charged to `budget`. Multi-byte elements from
/// a big-endian source are byte-swapped in place before returning.
///
/// # Errors
///
/// [`AnamnesisError::Parse`] on element- or byte-count overflow, a size
/// above `entry_size` or the array cap, or a short read. Any error returned
/// by `budget.charge_alloc` is propagated.
fn read_array_data(
    reader: &mut impl Read,
    header: &NpyHeader,
    entry_size: u64,
    budget: &mut Budget,
) -> crate::Result<Vec<u8>> {
    let elem_size = header.dtype.byte_size();

    let mut element_count = 1usize;
    for &dim in &header.shape {
        element_count = element_count
            .checked_mul(dim)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "element count overflow".into(),
            })?;
    }
    let byte_count = element_count
        .checked_mul(elem_size)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "data byte count overflow".into(),
        })?;
    #[allow(clippy::as_conversions)]
    let byte_count_u64 = byte_count as u64;

    // The header must not promise more bytes than the archive entry can hold.
    if byte_count_u64 > entry_size {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "NPY array size {byte_count} bytes exceeds the declared ZIP \
                 entry size {entry_size} bytes"
            ),
        });
    }
    if byte_count_u64 > NPZ_MAX_ARRAY_BYTES {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "NPY array size {byte_count} bytes exceeds the \
                 {NPZ_MAX_ARRAY_BYTES}-byte cap"
            ),
        });
    }
    budget.charge_alloc(byte_count_u64, "NPZ array data")?;

    let mut data = vec![0u8; byte_count];
    reader
        .read_exact(&mut data)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("array data read failed ({byte_count} bytes): {e}"),
        })?;
    if header.big_endian && elem_size > 1 {
        byteswap_inplace(&mut data, elem_size);
    }
    Ok(data)
}
/// Opens a reader over the decoded contents of one ZIP entry.
///
/// Stored entries are returned as-is. Deflate entries are wrapped in a
/// [`flate2::read::DeflateDecoder`]. `name` is used only in the error message.
///
/// # Errors
///
/// Whatever `entry_data_reader` returns, or [`AnamnesisError::Unsupported`]
/// for any other compression method.
fn open_npz_entry_reader<'a, R: Read + Seek>(
    src: &'a mut crate::parse::zip::ReaderSource<R>,
    entry: &crate::parse::zip::ZipEntry,
    name: &str,
) -> crate::Result<Box<dyn Read + 'a>> {
    use crate::parse::zip::Compression;

    let raw = src.entry_data_reader(entry)?;
    let decoded: Box<dyn Read + 'a> = match entry.method {
        Compression::Stored => Box::new(raw),
        Compression::Deflate => Box::new(flate2::read::DeflateDecoder::new(raw)),
        Compression::Unsupported(method) => {
            return Err(AnamnesisError::Unsupported {
                format: "NPZ".into(),
                detail: format!(
                    "array '{name}' uses unsupported ZIP compression method {method}"
                ),
            });
        }
    };
    Ok(decoded)
}
/// Metadata for one array, as reported by `inspect_npz` / `inspect_npz_from_reader`
/// without reading the array data.
#[derive(Debug, Clone)]
pub struct NpzTensorInfo {
    /// Archive member name with the trailing `.npy` removed.
    pub name: String,
    /// Array dimensions from the NPY header; empty for a scalar.
    pub shape: Vec<usize>,
    /// Element type decoded from the NPY `descr` field.
    pub dtype: NpzDtype,
    /// Element count times element size, saturating at `usize::MAX` on overflow.
    pub byte_len: usize,
}
/// Summary of an NPZ archive produced by `inspect_npz` / `inspect_npz_from_reader`.
#[derive(Debug, Clone)]
#[must_use]
pub struct NpzInspectInfo {
    /// Per-array metadata, in the order the ZIP central-directory reader
    /// returned the `.npy` entries.
    pub tensors: Vec<NpzTensorInfo>,
    /// Saturating sum of every tensor's `byte_len`.
    pub total_bytes: u64,
    /// Distinct dtypes, in order of first appearance.
    pub dtypes: Vec<NpzDtype>,
}
impl fmt::Display for NpzInspectInfo {
    /// Renders a four-line summary with no trailing newline:
    ///
    /// - the format name,
    /// - the tensor count,
    /// - the total size (via `crate::inspect::format_bytes`), and
    /// - a comma-separated list of dtypes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Format: NPZ archive\nTensors: {}\nTotal size: {}\nDtypes: ",
            self.tensors.len(),
            crate::inspect::format_bytes(self.total_bytes)
        )?;
        for (idx, dtype) in self.dtypes.iter().enumerate() {
            if idx > 0 {
                f.write_str(", ")?;
            }
            // Fresh format spec per dtype, so flags applied to the summary
            // itself do not leak into the individual names.
            write!(f, "{dtype}")?;
        }
        Ok(())
    }
}
/// Inspects the NPZ archive at `path` without reading array data.
///
/// Opens the file and delegates to [`inspect_npz_from_reader`].
///
/// # Errors
///
/// I/O errors from opening the file, plus everything
/// [`inspect_npz_from_reader`] can return.
pub fn inspect_npz(path: impl AsRef<Path>) -> crate::Result<NpzInspectInfo> {
    inspect_npz_from_reader(std::fs::File::open(path.as_ref())?)
}
/// Inspects an NPZ archive from any seekable reader without reading array data.
///
/// Only the NPY header of each `.npy` member is decoded; other members are
/// skipped. The central directory is read with [`ParseLimits::unbounded`] and
/// each header with an unbounded [`Budget`]. Headers are still capped at
/// [`NPY_MAX_HEADER_BYTES`] by `parse_npy_header`. Per-tensor `byte_len` and
/// the overall `total_bytes` saturate instead of overflowing.
///
/// # Errors
///
/// Malformed ZIP or NPY data, unsupported compression methods or dtypes, and
/// [`AnamnesisError::Unsupported`] for Fortran-ordered arrays.
pub fn inspect_npz_from_reader<R: Read + Seek>(reader: R) -> crate::Result<NpzInspectInfo> {
    let mut src = crate::parse::zip::ReaderSource::new(reader)?;
    let entries =
        crate::parse::zip::read_central_directory(&mut src, &ParseLimits::unbounded())?;

    let mut info = NpzInspectInfo {
        tensors: Vec::with_capacity(entries.len().min(PREALLOC_SOFT_CAP)),
        total_bytes: 0,
        dtypes: Vec::new(),
    };
    for entry in &entries {
        let Some(stem) = entry.name.strip_suffix(".npy") else {
            continue;
        };
        let name = stem.to_owned();
        let mut entry_reader = open_npz_entry_reader(&mut src, entry, &name)?;
        let header = parse_npy_header(&mut entry_reader, &mut Budget::unbounded())?;
        if header.fortran_order {
            return Err(AnamnesisError::Unsupported {
                format: "NPZ".into(),
                detail: format!(
                    "fortran-order arrays not supported (array '{name}'). \
                     ML frameworks save C-order by default"
                ),
            });
        }

        // Element count times element size; any overflow saturates to usize::MAX.
        let byte_len = header
            .shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .map_or(usize::MAX, |count| {
                count.saturating_mul(header.dtype.byte_size())
            });
        #[allow(clippy::as_conversions)]
        let byte_len_u64 = byte_len as u64;
        info.total_bytes = info.total_bytes.saturating_add(byte_len_u64);

        if !info.dtypes.contains(&header.dtype) {
            info.dtypes.push(header.dtype);
        }
        info.tensors.push(NpzTensorInfo {
            name,
            shape: header.shape,
            dtype: header.dtype,
            byte_len,
        });
    }
    Ok(info)
}
/// Parses every `.npy` member of the NPZ archive at `path` using
/// [`ParseLimits::default`].
///
/// # Errors
///
/// Everything [`parse_npz_with_limits`] can return.
pub fn parse_npz(path: impl AsRef<Path>) -> crate::Result<HashMap<String, NpzTensor>> {
    let default_limits = ParseLimits::default();
    parse_npz_with_limits(path, &default_limits)
}
pub fn parse_npz_with_limits(
path: impl AsRef<Path>,
limits: &ParseLimits,
) -> crate::Result<HashMap<String, NpzTensor>> {
let file = std::fs::File::open(path.as_ref())?;
let mut src = crate::parse::zip::ReaderSource::new(file)?;
let entries = crate::parse::zip::read_central_directory(&mut src, limits)?;
let mut budget = Budget::new(limits);
let mut result = HashMap::with_capacity(entries.len().min(PREALLOC_SOFT_CAP));
for entry in &entries {
let name = match entry.name.strip_suffix(".npy") {
Some(n) => n.to_owned(),
None => continue,
};
limits.check_decompression_ratio(entry.uncompressed_size, entry.compressed_size, &name)?;
let entry_size = entry.uncompressed_size;
let mut entry_reader = open_npz_entry_reader(&mut src, entry, &name)?;
let header = parse_npy_header(&mut entry_reader, &mut budget)?;
if header.fortran_order {
return Err(AnamnesisError::Unsupported {
format: "NPZ".into(),
detail: format!(
"fortran-order arrays not supported (array '{name}'). \
ML frameworks save C-order by default"
),
});
}
let data = read_array_data(&mut entry_reader, &header, entry_size, &mut budget)?;
result.insert(
name.clone(),
NpzTensor {
name,
shape: header.shape,
dtype: header.dtype,
data,
},
);
}
Ok(result)
}
// Tests for the NPY/NPZ parser. Most cases call the private helpers directly;
// the end-to-end cases build real archives with the `zip` and `tempfile` crates.
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::float_cmp
)]
mod tests {
    use std::io::Write;
    use super::*;
    #[test]
    fn byte_size_1() {
        assert_eq!(NpzDtype::Bool.byte_size(), 1);
        assert_eq!(NpzDtype::U8.byte_size(), 1);
        assert_eq!(NpzDtype::I8.byte_size(), 1);
    }
    #[test]
    fn byte_size_2() {
        assert_eq!(NpzDtype::U16.byte_size(), 2);
        assert_eq!(NpzDtype::I16.byte_size(), 2);
        assert_eq!(NpzDtype::F16.byte_size(), 2);
        assert_eq!(NpzDtype::BF16.byte_size(), 2);
    }
    #[test]
    fn byte_size_4() {
        assert_eq!(NpzDtype::U32.byte_size(), 4);
        assert_eq!(NpzDtype::I32.byte_size(), 4);
        assert_eq!(NpzDtype::F32.byte_size(), 4);
    }
    #[test]
    fn byte_size_8() {
        assert_eq!(NpzDtype::U64.byte_size(), 8);
        assert_eq!(NpzDtype::I64.byte_size(), 8);
        assert_eq!(NpzDtype::F64.byte_size(), 8);
    }
    #[test]
    fn display() {
        assert_eq!(NpzDtype::F32.to_string(), "F32");
        assert_eq!(NpzDtype::BF16.to_string(), "BF16");
        assert_eq!(NpzDtype::I64.to_string(), "I64");
        assert_eq!(NpzDtype::Bool.to_string(), "BOOL");
    }
    #[test]
    fn parse_descr_float_types() {
        assert_eq!(parse_descr("<f2").unwrap(), (NpzDtype::F16, false));
        assert_eq!(parse_descr("<f4").unwrap(), (NpzDtype::F32, false));
        assert_eq!(parse_descr("<f8").unwrap(), (NpzDtype::F64, false));
        assert_eq!(parse_descr(">f4").unwrap(), (NpzDtype::F32, true));
    }
    #[test]
    fn parse_descr_int_types() {
        assert_eq!(parse_descr("|i1").unwrap(), (NpzDtype::I8, false));
        assert_eq!(parse_descr("<i2").unwrap(), (NpzDtype::I16, false));
        assert_eq!(parse_descr("<i4").unwrap(), (NpzDtype::I32, false));
        assert_eq!(parse_descr("<i8").unwrap(), (NpzDtype::I64, false));
        assert_eq!(parse_descr(">i4").unwrap(), (NpzDtype::I32, true));
    }
    #[test]
    fn parse_descr_uint_types() {
        assert_eq!(parse_descr("|u1").unwrap(), (NpzDtype::U8, false));
        assert_eq!(parse_descr("<u2").unwrap(), (NpzDtype::U16, false));
        assert_eq!(parse_descr("<u4").unwrap(), (NpzDtype::U32, false));
        assert_eq!(parse_descr("<u8").unwrap(), (NpzDtype::U64, false));
    }
    #[test]
    fn parse_descr_bool() {
        assert_eq!(parse_descr("|b1").unwrap(), (NpzDtype::Bool, false));
    }
    // A two-byte void descriptor is mapped to bfloat16.
    #[test]
    fn parse_descr_bf16_void() {
        assert_eq!(parse_descr("|V2").unwrap(), (NpzDtype::BF16, false));
    }
    #[test]
    fn parse_descr_unsupported() {
        // complex64 is not a supported type code.
        assert!(parse_descr("<c8").is_err());
        // Fixed-width unicode strings are not supported.
        assert!(parse_descr("<U4").is_err());
        // Shorter than the two-byte minimum.
        assert!(parse_descr("x").is_err());
    }
    // A multi-byte first character must produce an error, not a slicing panic.
    #[test]
    fn parse_descr_multibyte_first_char_is_clean_error() {
        assert!(parse_descr("\u{359}f4").is_err());
        assert!(parse_descr("é4").is_err());
        assert!(parse_descr("\u{1F600}").is_err());
    }
    #[test]
    fn extract_descr_from_header() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }";
        let (dtype, be) = extract_descr(header).unwrap();
        assert_eq!(dtype, NpzDtype::F32);
        assert!(!be);
    }
    #[test]
    fn extract_descr_double_quotes() {
        let header = "{\"descr\": \"<i4\", \"fortran_order\": False, \"shape\": (10,), }";
        let (dtype, _) = extract_descr(header).unwrap();
        assert_eq!(dtype, NpzDtype::I32);
    }
    // Key and value may use different quote characters.
    #[test]
    fn extract_descr_mixed_quotes() {
        let header = "{'descr': \"<f4\", 'fortran_order': False, 'shape': (2, 3), }";
        let (dtype, be) = extract_descr(header).unwrap();
        assert_eq!(dtype, NpzDtype::F32);
        assert!(!be);
    }
    #[test]
    fn fortran_order_false() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }";
        assert!(!extract_fortran_order(header));
    }
    #[test]
    fn fortran_order_true() {
        let header = "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 3), }";
        assert!(extract_fortran_order(header));
    }
    // A missing key defaults to C order.
    #[test]
    fn fortran_order_missing() {
        let header = "{'descr': '<f4', 'shape': (2, 3), }";
        assert!(!extract_fortran_order(header));
    }
    // `()` is a scalar: zero dimensions.
    #[test]
    fn shape_scalar() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (), }";
        let shape = extract_shape(header).unwrap();
        assert!(shape.is_empty());
    }
    // One-element tuples carry a trailing comma, which must be ignored.
    #[test]
    fn shape_1d() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (16384,), }";
        let shape = extract_shape(header).unwrap();
        assert_eq!(shape, vec![16384]);
    }
    #[test]
    fn shape_2d() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2304, 16384), }";
        let shape = extract_shape(header).unwrap();
        assert_eq!(shape, vec![2304, 16384]);
    }
    #[test]
    fn shape_3d() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3, 4), }";
        let shape = extract_shape(header).unwrap();
        assert_eq!(shape, vec![2, 3, 4]);
    }
    // Builds a version 1.0 NPY byte stream, laid out as follows:
    // - magic, major 1, minor 0;
    // - a u16 header length;
    // - the header text, padded with spaces and a final '\n' so that the
    //   10-byte preamble plus header is a multiple of 64 bytes (no padding,
    //   and therefore no newline, when it already is);
    // - `data`.
    fn make_npy_v1(header_str: &str, data: &[u8]) -> Vec<u8> {
        let header_bytes = header_str.as_bytes();
        // 6 magic bytes + 2 version bytes + 2 length bytes.
        let total_before_pad = 10 + header_bytes.len();
        let padding = (64 - (total_before_pad % 64)) % 64;
        let padded_len = header_bytes.len() + padding;
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        // Major version.
        npy.push(1);
        // Minor version.
        npy.push(0);
        npy.extend_from_slice(&(padded_len as u16).to_le_bytes());
        npy.extend_from_slice(header_bytes);
        if padding > 0 {
            npy.extend(std::iter::repeat_n(b' ', padding - 1));
            npy.push(b'\n');
        }
        npy.extend_from_slice(data);
        npy
    }
    #[test]
    fn roundtrip_f32_npy_v1() {
        let values: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let mut data = Vec::new();
        for v in &values {
            data.extend_from_slice(&v.to_le_bytes());
        }
        let npy = make_npy_v1(
            "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 2), }",
            &data,
        );
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader, &mut Budget::unbounded()).unwrap();
        assert_eq!(header.dtype, NpzDtype::F32);
        assert!(!header.big_endian);
        assert!(!header.fortran_order);
        assert_eq!(header.shape, vec![2, 2]);
        let result =
            read_array_data(&mut reader, &header, u64::MAX, &mut Budget::unbounded()).unwrap();
        assert_eq!(result, data);
    }
    // `>f4` 1.0 (3F 80 00 00) must come back byte-swapped to little-endian.
    #[test]
    fn roundtrip_f32_big_endian() {
        let data_be: Vec<u8> = vec![0x3F, 0x80, 0x00, 0x00];
        let npy = make_npy_v1(
            "{'descr': '>f4', 'fortran_order': False, 'shape': (1,), }",
            &data_be,
        );
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader, &mut Budget::unbounded()).unwrap();
        assert!(header.big_endian);
        let result =
            read_array_data(&mut reader, &header, u64::MAX, &mut Budget::unbounded()).unwrap();
        assert_eq!(result, vec![0x00, 0x00, 0x80, 0x3F]);
        let val = f32::from_le_bytes([result[0], result[1], result[2], result[3]]);
        assert_eq!(val, 1.0);
    }
    // Version 2.0 uses a 4-byte header length, so the preamble is 12 bytes.
    #[test]
    fn npy_v2_header() {
        let header_str = "{'descr': '<f8', 'fortran_order': False, 'shape': (1,), }";
        let header_bytes = header_str.as_bytes();
        let total_before_pad = 12 + header_bytes.len();
        let padding = (64 - (total_before_pad % 64)) % 64;
        let padded_len = header_bytes.len() + padding;
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.push(2);
        npy.push(0);
        npy.extend_from_slice(&(padded_len as u32).to_le_bytes());
        npy.extend_from_slice(header_bytes);
        if padding > 0 {
            npy.extend(std::iter::repeat_n(b' ', padding - 1));
            npy.push(b'\n');
        }
        npy.extend_from_slice(&42.5_f64.to_le_bytes());
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader, &mut Budget::unbounded()).unwrap();
        assert_eq!(header.dtype, NpzDtype::F64);
        assert_eq!(header.shape, vec![1]);
        let result =
            read_array_data(&mut reader, &header, u64::MAX, &mut Budget::unbounded()).unwrap();
        assert_eq!(result, 42.5_f64.to_le_bytes());
    }
    #[test]
    fn invalid_magic_rejected() {
        let data = b"NOT_NUMPY_DATA_AT_ALL";
        let mut reader = std::io::Cursor::new(data);
        assert!(parse_npy_header(&mut reader, &mut Budget::unbounded()).is_err());
    }
    // NOTE(review): despite the name, this only checks `extract_fortran_order`
    // and never calls `parse_npz`; the end-to-end rejection is covered by
    // `fortran_order_rejected_end_to_end` below.
    #[test]
    fn fortran_order_rejected_in_parse_npz() {
        let header = "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 3), }";
        assert!(extract_fortran_order(header));
    }
    #[test]
    fn fortran_order_rejected_end_to_end() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("arr.npy", options).unwrap();
            let header_str = "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 2), }";
            let npy = make_npy_v1(header_str, &[0u8; 16]);
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let err = parse_npz(tmp.path()).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Fortran-order") || msg.contains("fortran"),
            "expected Fortran-order error, got: {msg}"
        );
    }
    // An archive with no entries parses and inspects to empty results.
    #[test]
    fn empty_npz_archive() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let zip = zip::ZipWriter::new(file);
            zip.finish().unwrap();
        }
        let result = parse_npz(tmp.path()).unwrap();
        assert!(result.is_empty(), "empty NPZ should return empty map");
        let info = inspect_npz(tmp.path()).unwrap();
        assert!(info.tensors.is_empty());
        assert_eq!(info.total_bytes, 0);
    }
    #[test]
    fn parse_descr_native_endian() {
        let (dtype, be) = parse_descr("=f4").unwrap();
        assert_eq!(dtype, NpzDtype::F32);
        assert!(!be, "'=' should not be treated as big-endian");
    }
    // Big-endian data is byte-swapped on the full `parse_npz` path as well.
    #[test]
    fn big_endian_through_parse_npz() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("val.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '>f4', 'fortran_order': False, 'shape': (1,), }",
                &[0x3F, 0x80, 0x00, 0x00],
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let tensors = parse_npz(tmp.path()).unwrap();
        let t = tensors.get("val").expect("val not found");
        assert_eq!(t.dtype, NpzDtype::F32);
        assert_eq!(t.data, vec![0x00, 0x00, 0x80, 0x3F]);
    }
    // Per-allocation and item-count limits. The 1000-element f32 array needs
    // exactly 4000 bytes of data.
    #[test]
    fn parse_npz_respects_parse_limits() {
        let data = vec![0u8; 4000];
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("val.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '<f4', 'fortran_order': False, 'shape': (1000,), }",
                &data,
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        // Default limits match `parse_npz`.
        let baseline = parse_npz(tmp.path()).unwrap();
        let with_default = parse_npz_with_limits(tmp.path(), &ParseLimits::default()).unwrap();
        assert_eq!(with_default.get("val").unwrap().data.len(), 4000);
        assert_eq!(
            baseline.get("val").unwrap().data,
            with_default.get("val").unwrap().data
        );
        // One byte short of the array allocation fails on the data charge.
        let err = parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_single_alloc(3999),
        )
        .unwrap_err();
        assert!(
            matches!(err, AnamnesisError::Parse { ref reason } if reason.contains("NPZ array data")),
            "expected array single-alloc limit error, got: {err}"
        );
        assert!(parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_single_alloc(4000)
        )
        .is_ok());
        // A 1-byte limit trips earlier, on the container or header allocation.
        let err =
            parse_npz_with_limits(tmp.path(), &ParseLimits::default().with_max_single_alloc(1))
                .unwrap_err();
        assert!(
            matches!(err, AnamnesisError::Parse { ref reason }
                if reason.contains("central directory") || reason.contains("NPY header")),
            "expected container/header single-alloc limit error, got: {err}"
        );
        // Item-count limit: zero items rejects the single entry, one allows it.
        let err = parse_npz_with_limits(tmp.path(), &ParseLimits::default().with_max_item_count(0))
            .unwrap_err();
        assert!(
            matches!(err, AnamnesisError::Parse { ref reason } if reason.contains("max_item_count")),
            "expected item-count limit error, got: {err}"
        );
        assert!(
            parse_npz_with_limits(tmp.path(), &ParseLimits::default().with_max_item_count(1))
                .is_ok()
        );
    }
    // Three 4000-byte arrays each fit a 4096-byte per-item limit but together
    // exceed an 8000-byte aggregate budget.
    #[test]
    fn parse_npz_aggregate_budget() {
        let data = vec![0u8; 4000];
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            for name in ["a.npy", "b.npy", "c.npy"] {
                zip.start_file(name, options).unwrap();
                let npy = make_npy_v1(
                    "{'descr': '<f4', 'fortran_order': False, 'shape': (1000,), }",
                    &data,
                );
                zip.write_all(&npy).unwrap();
            }
            zip.finish().unwrap();
        }
        let per_item_only = ParseLimits::default().with_max_single_alloc(4096);
        assert!(parse_npz_with_limits(tmp.path(), &per_item_only).is_ok());
        let with_aggregate = per_item_only.with_max_total_bytes(8000);
        let err = parse_npz_with_limits(tmp.path(), &with_aggregate).unwrap_err();
        assert!(
            matches!(err, AnamnesisError::Parse { ref reason } if reason.contains("max_total_bytes")),
            "expected aggregate limit error, got: {err}"
        );
        assert!(parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_total_bytes(1 << 20)
        )
        .is_ok());
        assert!(parse_npz_with_limits(tmp.path(), &ParseLimits::default()).is_ok());
    }
    // 40 000 zero bytes deflate to a tiny entry, so a low decompression-ratio
    // cap rejects it. A stored entry (ratio 1) passes even with a cap of 1.
    #[test]
    fn parse_npz_decompression_ratio_cap() {
        let zeros = vec![0u8; 40_000];
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Deflated);
            zip.start_file("z.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '<u1', 'fortran_order': False, 'shape': (40000,), }",
                &zeros,
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let err = parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_decompression_ratio(2),
        )
        .unwrap_err();
        assert!(
            matches!(err, AnamnesisError::Parse { ref reason } if reason.contains("max_decompression_ratio")),
            "expected ratio limit error, got: {err}"
        );
        assert!(parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_decompression_ratio(100_000)
        )
        .is_ok());
        assert!(parse_npz_with_limits(tmp.path(), &ParseLimits::default()).is_ok());
        let stored = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(stored.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("s.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '<u1', 'fortran_order': False, 'shape': (16,), }",
                &[0u8; 16],
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        assert!(parse_npz_with_limits(
            stored.path(),
            &ParseLimits::default().with_max_decompression_ratio(1)
        )
        .is_ok());
    }
    // DEFLATE-compressed entries inflate back to the exact original bytes.
    #[test]
    fn parse_npz_deflate_roundtrip_values() {
        let values: [f32; 6] = [1.5, -2.25, 3.0, 0.0, 42.5, -100.0];
        let mut raw = Vec::new();
        for v in &values {
            raw.extend_from_slice(&v.to_le_bytes());
        }
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Deflated);
            zip.start_file("w.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '<f4', 'fortran_order': False, 'shape': (6,), }",
                &raw,
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let tensors = parse_npz(tmp.path()).unwrap();
        let t = tensors.get("w").expect("w not found");
        assert_eq!(t.dtype, NpzDtype::F32);
        assert_eq!(t.shape, vec![6]);
        assert_eq!(
            t.data, raw,
            "inflated DEFLATE bytes must round-trip exactly"
        );
        let info = inspect_npz(tmp.path()).unwrap();
        assert_eq!(info.tensors.len(), 1);
        assert_eq!(info.tensors[0].shape, vec![6]);
        assert_eq!(info.tensors[0].dtype, NpzDtype::F32);
    }
    // The HashMap preallocation is clamped to PREALLOC_SOFT_CAP. This presumes
    // N exceeds that cap (TODO confirm), so entries past the clamp must still
    // be kept.
    #[test]
    fn parse_npz_more_entries_than_prealloc_cap() {
        const N: usize = 300;
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            for i in 0..N {
                zip.start_file(format!("a{i}.npy"), options).unwrap();
                let npy = make_npy_v1(
                    "{'descr': '<u1', 'fortran_order': False, 'shape': (1,), }",
                    &[0u8],
                );
                zip.write_all(&npy).unwrap();
            }
            zip.finish().unwrap();
        }
        let tensors = parse_npz(tmp.path()).unwrap();
        assert_eq!(
            tensors.len(),
            N,
            "every entry must round-trip past the clamp"
        );
    }
    // `total_bytes` from inspection predicts the outcome of parsing: a
    // max_total_bytes below it is rejected, a generous one succeeds.
    #[test]
    fn inspect_npz_total_predicts_parse_limits_gate() {
        let data = vec![0u8; 4000];
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("w.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '<f4', 'fortran_order': False, 'shape': (1000,), }",
                &data,
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let info = inspect_npz(tmp.path()).unwrap();
        assert_eq!(info.total_bytes, 4000);
        assert_eq!(info.tensors.len(), 1);
        let policy_budget = info.total_bytes - 1;
        assert!(info.total_bytes > policy_budget);
        assert!(parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_total_bytes(policy_budget)
        )
        .is_err());
        assert!(parse_npz_with_limits(
            tmp.path(),
            &ParseLimits::default().with_max_total_bytes(info.total_bytes * 2)
        )
        .is_ok());
    }
    // A shape whose byte length overflows usize is reported as usize::MAX by
    // inspection rather than panicking.
    #[test]
    fn inspect_npz_overflow_saturates() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("huge.npy", options).unwrap();
            let shape_str = format!(
                "{{'descr': '<f4', 'fortran_order': False, 'shape': ({}, 2), }}",
                usize::MAX / 2 + 1
            );
            let npy = make_npy_v1(&shape_str, &[]);
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let info = inspect_npz(tmp.path()).unwrap();
        assert_eq!(info.tensors.len(), 1);
        assert_eq!(info.tensors[0].byte_len, usize::MAX);
    }
    // Builds an in-memory NPZ with a single stored entry `{arr_name}.npy`.
    fn make_in_memory_npz(arr_name: &str, header_str: &str, data: &[u8]) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let mut zip = zip::ZipWriter::new(cursor);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            let entry_name = format!("{arr_name}.npy");
            zip.start_file(&entry_name, options).unwrap();
            let npy = make_npy_v1(header_str, data);
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        buf
    }
    // The path-based and reader-based inspection entry points must agree.
    #[test]
    fn inspect_from_reader_matches_path() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let mut zip = zip::ZipWriter::new(cursor);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("weights.npy", options).unwrap();
            let npy1 = make_npy_v1(
                "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }",
                &[0u8; 24],
            );
            zip.write_all(&npy1).unwrap();
            zip.start_file("indices.npy", options).unwrap();
            let npy2 = make_npy_v1(
                "{'descr': '<i8', 'fortran_order': False, 'shape': (4,), }",
                &[0u8; 32],
            );
            zip.write_all(&npy2).unwrap();
            zip.finish().unwrap();
        }
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), &buf).unwrap();
        let path_info = inspect_npz(tmp.path()).unwrap();
        let reader_info = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap();
        assert_eq!(path_info.tensors.len(), reader_info.tensors.len());
        assert_eq!(path_info.total_bytes, reader_info.total_bytes);
        assert_eq!(path_info.dtypes, reader_info.dtypes);
        for (a, b) in path_info.tensors.iter().zip(reader_info.tensors.iter()) {
            assert_eq!(a.name, b.name);
            assert_eq!(a.shape, b.shape);
            assert_eq!(a.dtype, b.dtype);
            assert_eq!(a.byte_len, b.byte_len);
        }
        assert_eq!(reader_info.tensors.len(), 2);
        assert_eq!(reader_info.total_bytes, 24 + 32);
    }
    #[test]
    fn inspect_from_reader_empty_archive() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let zip = zip::ZipWriter::new(cursor);
            zip.finish().unwrap();
        }
        let info = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap();
        assert!(info.tensors.is_empty());
        assert_eq!(info.total_bytes, 0);
        assert!(info.dtypes.is_empty());
    }
    #[test]
    fn inspect_from_reader_rejects_fortran_order() {
        let buf = make_in_memory_npz(
            "arr",
            "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 2), }",
            &[0u8; 16],
        );
        let err = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Fortran-order") || msg.contains("fortran"),
            "expected Fortran-order error, got: {msg}"
        );
    }
    // A v2 header length of u32::MAX must be rejected by the cap before any
    // allocation is attempted.
    #[test]
    fn header_len_cap_rejects_oversized_v2_header() {
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.push(2);
        npy.push(0);
        npy.extend_from_slice(&u32::MAX.to_le_bytes());
        let mut reader = std::io::Cursor::new(&npy);
        let Err(err) = parse_npy_header(&mut reader, &mut Budget::unbounded()) else {
            panic!("expected error for oversized header length");
        };
        let msg = err.to_string();
        assert!(
            msg.contains("header length") && msg.contains("cap"),
            "expected header-length cap error, got: {msg}"
        );
    }
    // A length exactly equal to NPY_MAX_HEADER_BYTES passes the cap; parsing
    // then fails on the missing header body instead.
    #[test]
    fn header_len_cap_accepts_boundary() {
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.push(2);
        npy.push(0);
        npy.extend_from_slice(&(NPY_MAX_HEADER_BYTES as u32).to_le_bytes());
        let mut reader = std::io::Cursor::new(&npy);
        let Err(err) = parse_npy_header(&mut reader, &mut Budget::unbounded()) else {
            panic!("expected truncated-body error past the cap gate");
        };
        let msg = err.to_string();
        assert!(
            !msg.contains("cap"),
            "boundary value must pass the cap gate, got cap error: {msg}"
        );
    }
    // 3e9 f32 elements declare about 12 GB, above NPZ_MAX_ARRAY_BYTES (8 GiB).
    // Where usize is 32 bits the byte count overflows instead; either way the
    // call must fail.
    #[test]
    fn array_bytes_cap_rejects_oversized_shape() {
        let header = NpyHeader {
            dtype: NpzDtype::F32,
            big_endian: false,
            fortran_order: false,
            shape: vec![3_000_000_000usize],
        };
        let mut empty = std::io::Cursor::new(Vec::new());
        let result = read_array_data(&mut empty, &header, u64::MAX, &mut Budget::unbounded());
        assert!(
            result.is_err(),
            "oversized declared array must be rejected, got Ok"
        );
    }
    // A shape needing 4000 bytes cannot come from a 16-byte ZIP entry.
    #[test]
    fn array_bytes_rejected_above_entry_size() {
        let header = NpyHeader {
            dtype: NpzDtype::F32,
            big_endian: false,
            fortran_order: false,
            shape: vec![1000],
        };
        let mut empty = std::io::Cursor::new(Vec::new());
        let Err(err) = read_array_data(&mut empty, &header, 16, &mut Budget::unbounded()) else {
            panic!("over-declared shape vs entry size must be rejected");
        };
        let msg = err.to_string();
        assert!(
            msg.contains("declared ZIP entry size"),
            "expected entry-size rejection, got: {msg}"
        );
    }
}