use std::collections::HashMap;
use std::fmt;
use std::io::{Read, Seek};
use std::path::Path;
use crate::error::AnamnesisError;
use crate::parse::utils::byteswap_inplace;
/// Six-byte magic prefix (`\x93NUMPY`) that `parse_npy_header` requires at the
/// start of every `.npy` stream.
const NPY_MAGIC: &[u8; 6] = b"\x93NUMPY";
/// Element type of an array stored in an NPZ archive.
///
/// Each variant corresponds to one NumPy dtype descriptor accepted by
/// `parse_descr` (the descriptor's leading byte-order marker is handled
/// separately).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum NpzDtype {
    /// Boolean, descriptor `b1` (1 byte per element).
    Bool,
    /// Unsigned 8-bit integer, descriptor `u1`.
    U8,
    /// Signed 8-bit integer, descriptor `i1`.
    I8,
    /// Unsigned 16-bit integer, descriptor `u2`.
    U16,
    /// Signed 16-bit integer, descriptor `i2`.
    I16,
    /// Unsigned 32-bit integer, descriptor `u4`.
    U32,
    /// Signed 32-bit integer, descriptor `i4`.
    I32,
    /// Unsigned 64-bit integer, descriptor `u8`.
    U64,
    /// Signed 64-bit integer, descriptor `i8`.
    I64,
    /// IEEE half-precision float, descriptor `f2`.
    F16,
    /// bfloat16. This module maps the 2-byte void descriptor `V2` to this variant.
    BF16,
    /// IEEE single-precision float, descriptor `f4`.
    F32,
    /// IEEE double-precision float, descriptor `f8`.
    F64,
}
impl NpzDtype {
    /// Number of bytes occupied by a single element of this dtype.
    ///
    /// Used to turn an element count into a byte count when sizing array
    /// buffers and when deciding whether a byte swap is needed.
    #[must_use]
    pub const fn byte_size(self) -> usize {
        match self {
            Self::F64 | Self::I64 | Self::U64 => 8,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::BF16 | Self::F16 | Self::I16 | Self::U16 => 2,
            Self::I8 | Self::U8 | Self::Bool => 1,
        }
    }
}
impl fmt::Display for NpzDtype {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::Bool => "BOOL",
Self::U8 => "U8",
Self::I8 => "I8",
Self::U16 => "U16",
Self::I16 => "I16",
Self::U32 => "U32",
Self::I32 => "I32",
Self::U64 => "U64",
Self::I64 => "I64",
Self::F16 => "F16",
Self::BF16 => "BF16",
Self::F32 => "F32",
Self::F64 => "F64",
};
f.write_str(s)
}
}
/// A fully-loaded array from an NPZ archive, as produced by `parse_npz`.
#[derive(Debug, Clone)]
pub struct NpzTensor {
    /// Archive entry name with the trailing `.npy` removed.
    pub name: String,
    /// Dimensions from the NPY header (empty for a scalar).
    pub shape: Vec<usize>,
    /// Element type parsed from the header's `descr` field.
    pub dtype: NpzDtype,
    /// Raw element bytes in C order (Fortran-order arrays are rejected).
    /// Multi-byte elements whose header declared big-endian (`>`) have been
    /// byte-swapped in place.
    pub data: Vec<u8>,
}
/// Fields extracted from the header dictionary of a single `.npy` stream.
struct NpyHeader {
    /// Element type from the `descr` field.
    dtype: NpzDtype,
    /// `true` when the `descr` byte-order marker was `>`.
    big_endian: bool,
    /// `true` when the `fortran_order` value starts with `True`.
    fortran_order: bool,
    /// Dimensions parsed from the `shape` tuple.
    shape: Vec<usize>,
}
/// Parse the preamble and header dictionary at the start of a `.npy` stream.
///
/// Consumes the 6-byte magic string, the two version bytes, the
/// little-endian header-length field (2 bytes for v1, 4 bytes for v2/v3) and
/// the header text itself. On success, `reader` is left at the first byte of
/// array data.
///
/// # Errors
///
/// - [`AnamnesisError::Parse`] if the magic bytes do not match, the stream
///   ends before the declared header length, the header is not valid UTF-8,
///   or a required header field is malformed.
/// - [`AnamnesisError::Unsupported`] for a major version other than 1, 2 or 3,
///   or for a dtype descriptor this module does not handle.
fn parse_npy_header(reader: &mut impl Read) -> crate::Result<NpyHeader> {
    let mut preamble = [0u8; 8];
    reader
        .read_exact(&mut preamble)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("NPY preamble read failed: {e}"),
        })?;
    // `preamble` is a fixed 8-byte array, so these indices cannot go out of bounds.
    #[allow(clippy::indexing_slicing)]
    if &preamble[..6] != NPY_MAGIC {
        return Err(AnamnesisError::Parse {
            reason: "invalid NPY magic bytes".into(),
        });
    }
    // Byte 6 is the major version; the minor version (byte 7) is not used.
    #[allow(clippy::indexing_slicing)]
    let major = preamble[6];
    let header_len: usize = match major {
        1 => {
            let mut buf = [0u8; 2];
            reader
                .read_exact(&mut buf)
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("NPY v1 header length read failed: {e}"),
                })?;
            usize::from(u16::from_le_bytes(buf))
        }
        2 | 3 => {
            let mut buf = [0u8; 4];
            reader
                .read_exact(&mut buf)
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("NPY v{major} header length read failed: {e}"),
                })?;
            usize::try_from(u32::from_le_bytes(buf)).map_err(|_| AnamnesisError::Parse {
                reason: format!("NPY v{major} header length does not fit in usize"),
            })?
        }
        _ => {
            return Err(AnamnesisError::Unsupported {
                format: "NPY".into(),
                detail: format!("unsupported NPY version {major}"),
            });
        }
    };
    // The length is an untrusted field of up to 4 GiB. Reading through `take`
    // grows the buffer only as real bytes arrive, instead of allocating the
    // declared size up front.
    let mut header_buf = Vec::new();
    let limit = u64::try_from(header_len).unwrap_or(u64::MAX);
    reader
        .by_ref()
        .take(limit)
        .read_to_end(&mut header_buf)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("NPY header data read failed: {e}"),
        })?;
    if header_buf.len() != header_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "NPY header data read failed: expected {header_len} bytes, got {}",
                header_buf.len()
            ),
        });
    }
    let header_str = std::str::from_utf8(&header_buf).map_err(|e| AnamnesisError::Parse {
        reason: format!("NPY header is not valid UTF-8: {e}"),
    })?;
    let (dtype, big_endian) = extract_descr(header_str)?;
    let fortran_order = extract_fortran_order(header_str);
    let shape = extract_shape(header_str)?;
    Ok(NpyHeader {
        dtype,
        big_endian,
        fortran_order,
        shape,
    })
}
/// Locate the quoted `descr` value in an NPY header dictionary and decode it.
///
/// The key may be written with single or double quotes (single-quoted is
/// searched first). The value is the text between the first quote character
/// after the key's colon and the next occurrence of that same quote. It is
/// handed to `parse_descr`.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] when the key, colon, or quoted value is
/// missing or malformed. Any error from `parse_descr` is propagated.
fn extract_descr(header: &str) -> crate::Result<(NpzDtype, bool)> {
    let parse_err = |reason: &str| AnamnesisError::Parse {
        reason: reason.into(),
    };

    let key_pos = ["'descr'", "\"descr\""]
        .iter()
        .find_map(|&key| header.find(key))
        .ok_or_else(|| parse_err("NPY header missing 'descr' field"))?;

    let colon = header
        .get(key_pos..)
        .and_then(|rest| rest.find(':'))
        .ok_or_else(|| parse_err("NPY header 'descr' field has no value"))?;
    let value_start = key_pos + colon + 1;

    let value = header
        .get(value_start..)
        .ok_or_else(|| parse_err("NPY header truncated after 'descr'"))?
        .trim_start();

    let quote = match value.chars().next() {
        Some(q @ ('\'' | '"')) => q,
        _ => return Err(parse_err("NPY header 'descr' value not quoted")),
    };

    let quoted_body = value
        .get(1..)
        .ok_or_else(|| parse_err("NPY header 'descr' value truncated after opening quote"))?;
    let end = quoted_body
        .find(quote)
        .ok_or_else(|| parse_err("NPY header 'descr' value missing closing quote"))?;
    let descr = quoted_body
        .get(..end)
        .ok_or_else(|| parse_err("NPY header 'descr' extraction failed"))?;

    parse_descr(descr)
}
/// Translate a NumPy dtype descriptor such as `"<f4"` into an [`NpzDtype`]
/// and a big-endian flag.
///
/// The first character is the byte-order marker. `>` sets the big-endian flag.
/// `<`, `|` (byte order not applicable) and `=` are treated as not needing a
/// byte swap. The rest of the descriptor selects the element type. The 2-byte
/// void type `V2` is interpreted as BF16.
///
/// # Errors
///
/// Returns [`AnamnesisError::Unsupported`] when the descriptor is shorter than
/// two bytes, starts with any other byte-order marker, or names a type outside
/// the supported set.
fn parse_descr(descr: &str) -> crate::Result<(NpzDtype, bool)> {
    if descr.len() < 2 {
        return Err(AnamnesisError::Unsupported {
            format: "NPY".into(),
            detail: format!("dtype descriptor too short: '{descr}'"),
        });
    }
    // Split off the marker with `strip_prefix` rather than byte-slicing
    // `descr[1..]`. The descriptor comes from an untrusted header, and a
    // multi-byte first character would make the byte slice panic. Unknown
    // markers are rejected rather than silently treated as little-endian.
    let (big_endian, type_str) = if let Some(rest) = descr.strip_prefix('>') {
        (true, rest)
    } else if let Some(rest) = descr.strip_prefix(|c: char| matches!(c, '<' | '|' | '=')) {
        (false, rest)
    } else {
        return Err(AnamnesisError::Unsupported {
            format: "NPY".into(),
            detail: format!("unsupported byte-order marker in dtype descriptor '{descr}'"),
        });
    };
    let dtype = match type_str {
        "b1" => NpzDtype::Bool,
        "u1" => NpzDtype::U8,
        "u2" => NpzDtype::U16,
        "u4" => NpzDtype::U32,
        "u8" => NpzDtype::U64,
        "i1" => NpzDtype::I8,
        "i2" => NpzDtype::I16,
        "i4" => NpzDtype::I32,
        "i8" => NpzDtype::I64,
        "f2" => NpzDtype::F16,
        "f4" => NpzDtype::F32,
        "f8" => NpzDtype::F64,
        "V2" => NpzDtype::BF16,
        _ => {
            return Err(AnamnesisError::Unsupported {
                format: "NPY".into(),
                detail: format!("unsupported dtype descriptor '{descr}'"),
            });
        }
    };
    Ok((dtype, big_endian))
}
/// Report whether an NPY header declares `fortran_order: True`.
///
/// The key may be single- or double-quoted (single-quoted is searched first).
/// A missing key, a missing colon, or any value not beginning with `True`
/// (after leading whitespace) yields `false`.
fn extract_fortran_order(header: &str) -> bool {
    let key_pos = match header.find("'fortran_order'") {
        Some(pos) => Some(pos),
        None => header.find("\"fortran_order\""),
    };
    let Some(pos) = key_pos else {
        return false;
    };
    let Some(rest) = header.get(pos..) else {
        return false;
    };
    let Some(colon) = rest.find(':') else {
        return false;
    };
    match rest.get(colon + 1..) {
        Some(value) => value.trim_start().starts_with("True"),
        None => false,
    }
}
/// Parse the `shape` tuple from an NPY header dictionary.
///
/// The key may be single- or double-quoted (single-quoted is searched first).
/// The tuple is the text between the first `(` after the key and the next `)`.
/// Empty comma-separated items (as in `(16384,)` or `()`) are ignored, so a
/// scalar yields an empty vector.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] if the key or parentheses are missing,
/// or if any dimension is not a valid `usize`.
fn extract_shape(header: &str) -> crate::Result<Vec<usize>> {
    let parse_err = |reason: &str| AnamnesisError::Parse {
        reason: reason.into(),
    };

    let key_pos = ["'shape'", "\"shape\""]
        .iter()
        .find_map(|&key| header.find(key))
        .ok_or_else(|| parse_err("NPY header missing 'shape' field"))?;
    let rest = header
        .get(key_pos..)
        .ok_or_else(|| parse_err("NPY header truncated at 'shape'"))?;

    let open = rest
        .find('(')
        .ok_or_else(|| parse_err("NPY header 'shape' value missing opening paren"))?;
    let after_open = rest
        .get(open + 1..)
        .ok_or_else(|| parse_err("NPY header 'shape' truncated after paren"))?;
    let close = after_open
        .find(')')
        .ok_or_else(|| parse_err("NPY header 'shape' value missing closing paren"))?;
    let dims = after_open
        .get(..close)
        .ok_or_else(|| parse_err("NPY header 'shape' extraction failed"))?;

    let mut shape = Vec::new();
    for token in dims.split(',') {
        let token = token.trim();
        if token.is_empty() {
            continue;
        }
        let dim = token
            .parse::<usize>()
            .map_err(|e| AnamnesisError::Parse {
                reason: format!("NPY shape dimension parse error: {e}"),
            })?;
        shape.push(dim);
    }
    Ok(shape)
}
/// Read the raw element bytes that follow an NPY header.
///
/// Reads exactly `product(shape) * dtype.byte_size()` bytes from `reader`. If
/// the header declared big-endian storage and elements are wider than one
/// byte, the buffer is byte-swapped in place with `byteswap_inplace`.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] if:
/// - the element count or byte count overflows `usize`;
/// - the buffer cannot be allocated; or
/// - the stream ends before the declared number of bytes has been read.
fn read_array_data(reader: &mut impl Read, header: &NpyHeader) -> crate::Result<Vec<u8>> {
    let elem_size = header.dtype.byte_size();
    let n_elements: usize = header
        .shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "element count overflow".into(),
        })?;
    let data_bytes = n_elements
        .checked_mul(elem_size)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "data byte count overflow".into(),
        })?;
    // `data_bytes` comes from an untrusted shape. `vec![0u8; n]` aborts the
    // whole process when the allocation fails, so reserve fallibly and report
    // an error instead.
    let mut buf: Vec<u8> = Vec::new();
    buf.try_reserve_exact(data_bytes)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("cannot allocate {data_bytes} bytes for array data: {e}"),
        })?;
    // `take` caps consumption at the declared size; the length check below
    // turns a short stream into an error, just as `read_exact` would.
    let limit = u64::try_from(data_bytes).unwrap_or(u64::MAX);
    let got = reader
        .by_ref()
        .take(limit)
        .read_to_end(&mut buf)
        .map_err(|e| AnamnesisError::Parse {
            reason: format!("array data read failed ({data_bytes} bytes): {e}"),
        })?;
    if got != data_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "array data read failed ({data_bytes} bytes): stream ended after {got} bytes"
            ),
        });
    }
    if header.big_endian && elem_size > 1 {
        byteswap_inplace(&mut buf, elem_size);
    }
    Ok(buf)
}
/// Header-only description of one array, as produced by `inspect_npz`.
#[derive(Debug, Clone)]
pub struct NpzTensorInfo {
    /// Archive entry name with the trailing `.npy` removed.
    pub name: String,
    /// Dimensions from the NPY header (empty for a scalar).
    pub shape: Vec<usize>,
    /// Element type parsed from the header's `descr` field.
    pub dtype: NpzDtype,
    /// Element count times element size, saturating at `usize::MAX` on overflow.
    pub byte_len: usize,
}
/// Summary of an NPZ archive produced without reading array data.
#[derive(Debug, Clone)]
#[must_use]
pub struct NpzInspectInfo {
    /// One entry per `.npy` member, in archive order.
    pub tensors: Vec<NpzTensorInfo>,
    /// Saturating sum of every tensor's `byte_len`.
    pub total_bytes: u64,
    /// Distinct dtypes, in order of first appearance.
    pub dtypes: Vec<NpzDtype>,
}
impl fmt::Display for NpzInspectInfo {
    /// Renders a four-line human-readable summary: format, tensor count,
    /// total size (via `crate::inspect::format_bytes`), and a comma-separated
    /// dtype list. No trailing newline is written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dtype_names: Vec<String> = self.dtypes.iter().map(|d| d.to_string()).collect();
        write!(
            f,
            "Format: NPZ archive\nTensors: {count}\nTotal size: {size}\nDtypes: {list}",
            count = self.tensors.len(),
            size = crate::inspect::format_bytes(self.total_bytes),
            list = dtype_names.join(", "),
        )
    }
}
pub fn inspect_npz(path: impl AsRef<Path>) -> crate::Result<NpzInspectInfo> {
let file = std::fs::File::open(path.as_ref())?;
inspect_npz_from_reader(file)
}
/// Inspect an NPZ archive read from any seekable source, parsing only the NPY
/// headers.
///
/// Entries whose names do not end in `.npy` are skipped. Each tensor's
/// `byte_len` is element count times element size and saturates at
/// `usize::MAX` instead of failing on overflow.
///
/// # Errors
///
/// - Fails if the ZIP archive cannot be opened.
/// - Returns [`AnamnesisError::Parse`] if an entry cannot be read.
/// - Propagates header parse errors.
/// - Returns [`AnamnesisError::Unsupported`] for Fortran-order arrays.
pub fn inspect_npz_from_reader<R: Read + Seek>(reader: R) -> crate::Result<NpzInspectInfo> {
    let mut archive = zip::ZipArchive::new(reader)?;
    let entry_count = archive.len();
    let mut info = NpzInspectInfo {
        tensors: Vec::with_capacity(entry_count),
        total_bytes: 0,
        dtypes: Vec::new(),
    };

    for index in 0..entry_count {
        let mut entry = archive
            .by_index(index)
            .map_err(|e| AnamnesisError::Parse {
                reason: format!("failed to read ZIP entry {index}: {e}"),
            })?;
        let Some(stem) = entry.name().strip_suffix(".npy") else {
            continue;
        };
        let name = stem.to_owned();

        let header = parse_npy_header(&mut entry)?;
        if header.fortran_order {
            return Err(AnamnesisError::Unsupported {
                format: "NPZ".into(),
                detail: format!(
                    "fortran-order arrays not supported (array '{name}'). \
                     ML frameworks save C-order by default"
                ),
            });
        }

        // An overflowing element count saturates to usize::MAX, and so does
        // the byte length derived from it.
        let byte_len = header
            .shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .map_or(usize::MAX, |n| n.saturating_mul(header.dtype.byte_size()));

        info.total_bytes = info
            .total_bytes
            .saturating_add(u64::try_from(byte_len).unwrap_or(u64::MAX));
        if !info.dtypes.contains(&header.dtype) {
            info.dtypes.push(header.dtype);
        }
        info.tensors.push(NpzTensorInfo {
            name,
            shape: header.shape,
            dtype: header.dtype,
            byte_len,
        });
    }

    Ok(info)
}
pub fn parse_npz(path: impl AsRef<Path>) -> crate::Result<HashMap<String, NpzTensor>> {
let file = std::fs::File::open(path.as_ref())?;
let mut archive = zip::ZipArchive::new(file)?;
let mut result = HashMap::with_capacity(archive.len());
for i in 0..archive.len() {
let mut entry = archive.by_index(i).map_err(|e| AnamnesisError::Parse {
reason: format!("failed to read ZIP entry {i}: {e}"),
})?;
let full_name = entry.name().to_owned();
let name = match full_name.strip_suffix(".npy") {
Some(n) => n.to_owned(),
None => continue,
};
let header = parse_npy_header(&mut entry)?;
if header.fortran_order {
return Err(AnamnesisError::Unsupported {
format: "NPZ".into(),
detail: format!(
"fortran-order arrays not supported (array '{name}'). \
ML frameworks save C-order by default"
),
});
}
let data = read_array_data(&mut entry, &header)?;
result.insert(
name.clone(),
NpzTensor {
name,
shape: header.shape,
dtype: header.dtype,
data,
},
);
}
Ok(result)
}
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::float_cmp
)]
mod tests {
    use std::io::Write;

    use super::*;

    #[test]
    fn byte_size_1() {
        assert_eq!(NpzDtype::Bool.byte_size(), 1);
        assert_eq!(NpzDtype::U8.byte_size(), 1);
        assert_eq!(NpzDtype::I8.byte_size(), 1);
    }

    #[test]
    fn byte_size_2() {
        assert_eq!(NpzDtype::U16.byte_size(), 2);
        assert_eq!(NpzDtype::I16.byte_size(), 2);
        assert_eq!(NpzDtype::F16.byte_size(), 2);
        assert_eq!(NpzDtype::BF16.byte_size(), 2);
    }

    #[test]
    fn byte_size_4() {
        assert_eq!(NpzDtype::U32.byte_size(), 4);
        assert_eq!(NpzDtype::I32.byte_size(), 4);
        assert_eq!(NpzDtype::F32.byte_size(), 4);
    }

    #[test]
    fn byte_size_8() {
        assert_eq!(NpzDtype::U64.byte_size(), 8);
        assert_eq!(NpzDtype::I64.byte_size(), 8);
        assert_eq!(NpzDtype::F64.byte_size(), 8);
    }

    #[test]
    fn display() {
        assert_eq!(NpzDtype::F32.to_string(), "F32");
        assert_eq!(NpzDtype::BF16.to_string(), "BF16");
        assert_eq!(NpzDtype::I64.to_string(), "I64");
        assert_eq!(NpzDtype::Bool.to_string(), "BOOL");
    }

    #[test]
    fn parse_descr_float_types() {
        assert_eq!(parse_descr("<f2").unwrap(), (NpzDtype::F16, false));
        assert_eq!(parse_descr("<f4").unwrap(), (NpzDtype::F32, false));
        assert_eq!(parse_descr("<f8").unwrap(), (NpzDtype::F64, false));
        assert_eq!(parse_descr(">f4").unwrap(), (NpzDtype::F32, true));
    }

    #[test]
    fn parse_descr_int_types() {
        assert_eq!(parse_descr("|i1").unwrap(), (NpzDtype::I8, false));
        assert_eq!(parse_descr("<i2").unwrap(), (NpzDtype::I16, false));
        assert_eq!(parse_descr("<i4").unwrap(), (NpzDtype::I32, false));
        assert_eq!(parse_descr("<i8").unwrap(), (NpzDtype::I64, false));
        assert_eq!(parse_descr(">i4").unwrap(), (NpzDtype::I32, true));
    }

    #[test]
    fn parse_descr_uint_types() {
        assert_eq!(parse_descr("|u1").unwrap(), (NpzDtype::U8, false));
        assert_eq!(parse_descr("<u2").unwrap(), (NpzDtype::U16, false));
        assert_eq!(parse_descr("<u4").unwrap(), (NpzDtype::U32, false));
        assert_eq!(parse_descr("<u8").unwrap(), (NpzDtype::U64, false));
    }

    #[test]
    fn parse_descr_bool() {
        assert_eq!(parse_descr("|b1").unwrap(), (NpzDtype::Bool, false));
    }

    #[test]
    fn parse_descr_bf16_void() {
        assert_eq!(parse_descr("|V2").unwrap(), (NpzDtype::BF16, false));
    }

    #[test]
    fn parse_descr_unsupported() {
        // complex64
        assert!(parse_descr("<c8").is_err());
        // fixed-width unicode string
        assert!(parse_descr("<U4").is_err());
        // too short to carry a byte-order marker and a type
        assert!(parse_descr("x").is_err());
    }

    #[test]
    fn extract_descr_from_header() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }";
        let (dtype, be) = extract_descr(header).unwrap();
        assert_eq!(dtype, NpzDtype::F32);
        assert!(!be);
    }

    #[test]
    fn extract_descr_double_quotes() {
        let header = "{\"descr\": \"<i4\", \"fortran_order\": False, \"shape\": (10,), }";
        let (dtype, _) = extract_descr(header).unwrap();
        assert_eq!(dtype, NpzDtype::I32);
    }

    #[test]
    fn extract_descr_mixed_quotes() {
        let header = "{'descr': \"<f4\", 'fortran_order': False, 'shape': (2, 3), }";
        let (dtype, be) = extract_descr(header).unwrap();
        assert_eq!(dtype, NpzDtype::F32);
        assert!(!be);
    }

    #[test]
    fn fortran_order_false() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }";
        assert!(!extract_fortran_order(header));
    }

    #[test]
    fn fortran_order_true() {
        let header = "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 3), }";
        assert!(extract_fortran_order(header));
    }

    #[test]
    fn fortran_order_missing() {
        let header = "{'descr': '<f4', 'shape': (2, 3), }";
        assert!(!extract_fortran_order(header));
    }

    #[test]
    fn shape_scalar() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (), }";
        let shape = extract_shape(header).unwrap();
        assert!(shape.is_empty());
    }

    #[test]
    fn shape_1d() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (16384,), }";
        let shape = extract_shape(header).unwrap();
        assert_eq!(shape, vec![16384]);
    }

    #[test]
    fn shape_2d() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2304, 16384), }";
        let shape = extract_shape(header).unwrap();
        assert_eq!(shape, vec![2304, 16384]);
    }

    #[test]
    fn shape_3d() {
        let header = "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3, 4), }";
        let shape = extract_shape(header).unwrap();
        assert_eq!(shape, vec![2, 3, 4]);
    }

    /// Build an in-memory NPY v1 stream. The header is padded with spaces and
    /// a final newline (when padding is needed) so that the 10-byte preamble
    /// plus header is a multiple of 64 bytes. `data` follows the header.
    fn make_npy_v1(header_str: &str, data: &[u8]) -> Vec<u8> {
        let header_bytes = header_str.as_bytes();
        let total_before_pad = 10 + header_bytes.len();
        let padding = (64 - (total_before_pad % 64)) % 64;
        let padded_len = header_bytes.len() + padding;
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.push(1); // major version
        npy.push(0); // minor version
        npy.extend_from_slice(&(padded_len as u16).to_le_bytes());
        npy.extend_from_slice(header_bytes);
        if padding > 0 {
            npy.extend(std::iter::repeat_n(b' ', padding - 1));
            npy.push(b'\n');
        }
        npy.extend_from_slice(data);
        npy
    }

    #[test]
    fn roundtrip_f32_npy_v1() {
        let values: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let mut data = Vec::new();
        for v in &values {
            data.extend_from_slice(&v.to_le_bytes());
        }
        let npy = make_npy_v1(
            "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 2), }",
            &data,
        );
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader).unwrap();
        assert_eq!(header.dtype, NpzDtype::F32);
        assert!(!header.big_endian);
        assert!(!header.fortran_order);
        assert_eq!(header.shape, vec![2, 2]);
        let result = read_array_data(&mut reader, &header).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn roundtrip_f32_big_endian() {
        let data_be: Vec<u8> = vec![0x3F, 0x80, 0x00, 0x00];
        let npy = make_npy_v1(
            "{'descr': '>f4', 'fortran_order': False, 'shape': (1,), }",
            &data_be,
        );
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader).unwrap();
        assert!(header.big_endian);
        let result = read_array_data(&mut reader, &header).unwrap();
        assert_eq!(result, vec![0x00, 0x00, 0x80, 0x3F]);
        let val = f32::from_le_bytes([result[0], result[1], result[2], result[3]]);
        assert_eq!(val, 1.0);
    }

    #[test]
    fn npy_v2_header() {
        let header_str = "{'descr': '<f8', 'fortran_order': False, 'shape': (1,), }";
        let header_bytes = header_str.as_bytes();
        let total_before_pad = 12 + header_bytes.len();
        let padding = (64 - (total_before_pad % 64)) % 64;
        let padded_len = header_bytes.len() + padding;
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.push(2); // major version
        npy.push(0); // minor version
        npy.extend_from_slice(&(padded_len as u32).to_le_bytes());
        npy.extend_from_slice(header_bytes);
        if padding > 0 {
            npy.extend(std::iter::repeat_n(b' ', padding - 1));
            npy.push(b'\n');
        }
        npy.extend_from_slice(&42.5_f64.to_le_bytes());
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader).unwrap();
        assert_eq!(header.dtype, NpzDtype::F64);
        assert_eq!(header.shape, vec![1]);
        let result = read_array_data(&mut reader, &header).unwrap();
        assert_eq!(result, 42.5_f64.to_le_bytes());
    }

    #[test]
    fn invalid_magic_rejected() {
        let data = b"NOT_NUMPY_DATA_AT_ALL";
        let mut reader = std::io::Cursor::new(data);
        assert!(parse_npy_header(&mut reader).is_err());
    }

    #[test]
    fn unsupported_npy_version_rejected() {
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.extend_from_slice(&[4, 0, 0, 0]);
        match parse_npy_header(&mut std::io::Cursor::new(&npy)) {
            Err(AnamnesisError::Unsupported { .. }) => {}
            Err(other) => panic!("expected Unsupported error, got: {other}"),
            Ok(_) => panic!("NPY version 4 header was accepted"),
        }
    }

    #[test]
    fn truncated_header_rejected() {
        let mut npy = Vec::new();
        npy.extend_from_slice(NPY_MAGIC);
        npy.push(1);
        npy.push(0);
        // Declares 100 header bytes but supplies far fewer.
        npy.extend_from_slice(&100u16.to_le_bytes());
        npy.extend_from_slice(b"{'descr': '<f4'");
        assert!(parse_npy_header(&mut std::io::Cursor::new(&npy)).is_err());
    }

    #[test]
    fn truncated_array_data_rejected() {
        // Shape (4,) of f4 needs 16 bytes; only 8 are present.
        let npy = make_npy_v1(
            "{'descr': '<f4', 'fortran_order': False, 'shape': (4,), }",
            &[0u8; 8],
        );
        let mut reader = std::io::Cursor::new(&npy);
        let header = parse_npy_header(&mut reader).unwrap();
        assert!(read_array_data(&mut reader, &header).is_err());
    }

    #[test]
    fn fortran_order_reported_by_npy_header() {
        let npy = make_npy_v1(
            "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 3), }",
            &[0u8; 24],
        );
        let header = parse_npy_header(&mut std::io::Cursor::new(&npy)).unwrap();
        assert!(header.fortran_order);
        assert_eq!(header.shape, vec![2, 3]);
    }

    #[test]
    fn fortran_order_rejected_end_to_end() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("arr.npy", options).unwrap();
            let header_str = "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 2), }";
            let npy = make_npy_v1(header_str, &[0u8; 16]);
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let err = parse_npz(tmp.path()).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Fortran-order") || msg.contains("fortran"),
            "expected Fortran-order error, got: {msg}"
        );
    }

    #[test]
    fn empty_npz_archive() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let zip = zip::ZipWriter::new(file);
            zip.finish().unwrap();
        }
        let result = parse_npz(tmp.path()).unwrap();
        assert!(result.is_empty(), "empty NPZ should return empty map");
        let info = inspect_npz(tmp.path()).unwrap();
        assert!(info.tensors.is_empty());
        assert_eq!(info.total_bytes, 0);
    }

    #[test]
    fn parse_descr_native_endian() {
        let (dtype, be) = parse_descr("=f4").unwrap();
        assert_eq!(dtype, NpzDtype::F32);
        assert!(!be, "'=' should not be treated as big-endian");
    }

    #[test]
    fn big_endian_through_parse_npz() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("val.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '>f4', 'fortran_order': False, 'shape': (1,), }",
                &[0x3F, 0x80, 0x00, 0x00],
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let tensors = parse_npz(tmp.path()).unwrap();
        let t = tensors.get("val").expect("val not found");
        assert_eq!(t.dtype, NpzDtype::F32);
        assert_eq!(t.data, vec![0x00, 0x00, 0x80, 0x3F]);
    }

    #[test]
    fn non_npy_entries_skipped() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let mut zip = zip::ZipWriter::new(cursor);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("README.txt", options).unwrap();
            zip.write_all(b"not an array").unwrap();
            zip.start_file("a.npy", options).unwrap();
            let npy = make_npy_v1(
                "{'descr': '<i4', 'fortran_order': False, 'shape': (2,), }",
                &[1, 0, 0, 0, 2, 0, 0, 0],
            );
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), &buf).unwrap();

        let tensors = parse_npz(tmp.path()).unwrap();
        assert_eq!(tensors.len(), 1);
        let t = tensors.get("a").expect("a not found");
        assert_eq!(t.dtype, NpzDtype::I32);
        assert_eq!(t.shape, vec![2]);
        assert_eq!(t.data, vec![1, 0, 0, 0, 2, 0, 0, 0]);

        let info = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap();
        assert_eq!(info.tensors.len(), 1);
        assert_eq!(info.tensors[0].name, "a");
    }

    #[test]
    fn inspect_npz_overflow_saturates() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        {
            let file = std::fs::File::create(tmp.path()).unwrap();
            let mut zip = zip::ZipWriter::new(file);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("huge.npy", options).unwrap();
            let shape_str = format!(
                "{{'descr': '<f4', 'fortran_order': False, 'shape': ({}, 2), }}",
                usize::MAX / 2 + 1
            );
            let npy = make_npy_v1(&shape_str, &[]);
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        let info = inspect_npz(tmp.path()).unwrap();
        assert_eq!(info.tensors.len(), 1);
        assert_eq!(info.tensors[0].byte_len, usize::MAX);
    }

    /// Build an in-memory ZIP holding a single stored `{arr_name}.npy` entry.
    fn make_in_memory_npz(arr_name: &str, header_str: &str, data: &[u8]) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let mut zip = zip::ZipWriter::new(cursor);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            let entry_name = format!("{arr_name}.npy");
            zip.start_file(&entry_name, options).unwrap();
            let npy = make_npy_v1(header_str, data);
            zip.write_all(&npy).unwrap();
            zip.finish().unwrap();
        }
        buf
    }

    #[test]
    fn inspect_from_reader_matches_path() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let mut zip = zip::ZipWriter::new(cursor);
            let options = zip::write::SimpleFileOptions::default()
                .compression_method(zip::CompressionMethod::Stored);
            zip.start_file("weights.npy", options).unwrap();
            let npy1 = make_npy_v1(
                "{'descr': '<f4', 'fortran_order': False, 'shape': (2, 3), }",
                &[0u8; 24],
            );
            zip.write_all(&npy1).unwrap();
            zip.start_file("indices.npy", options).unwrap();
            let npy2 = make_npy_v1(
                "{'descr': '<i8', 'fortran_order': False, 'shape': (4,), }",
                &[0u8; 32],
            );
            zip.write_all(&npy2).unwrap();
            zip.finish().unwrap();
        }
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), &buf).unwrap();
        let path_info = inspect_npz(tmp.path()).unwrap();
        let reader_info = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap();
        assert_eq!(path_info.tensors.len(), reader_info.tensors.len());
        assert_eq!(path_info.total_bytes, reader_info.total_bytes);
        assert_eq!(path_info.dtypes, reader_info.dtypes);
        for (a, b) in path_info.tensors.iter().zip(reader_info.tensors.iter()) {
            assert_eq!(a.name, b.name);
            assert_eq!(a.shape, b.shape);
            assert_eq!(a.dtype, b.dtype);
            assert_eq!(a.byte_len, b.byte_len);
        }
        assert_eq!(reader_info.tensors.len(), 2);
        assert_eq!(reader_info.total_bytes, 24 + 32);
    }

    #[test]
    fn inspect_from_reader_empty_archive() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let cursor = std::io::Cursor::new(&mut buf);
            let zip = zip::ZipWriter::new(cursor);
            zip.finish().unwrap();
        }
        let info = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap();
        assert!(info.tensors.is_empty());
        assert_eq!(info.total_bytes, 0);
        assert!(info.dtypes.is_empty());
    }

    #[test]
    fn inspect_from_reader_rejects_fortran_order() {
        let buf = make_in_memory_npz(
            "arr",
            "{'descr': '<f4', 'fortran_order': True, 'shape': (2, 2), }",
            &[0u8; 16],
        );
        let err = inspect_npz_from_reader(std::io::Cursor::new(&buf)).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Fortran-order") || msg.contains("fortran"),
            "expected Fortran-order error, got: {msg}"
        );
    }
}