use crate::error::{IoError, Result};
pub const NPY_MAGIC: &[u8; 6] = b"\x93NUMPY";
pub const NPY_MAJOR_VERSION: u8 = 1;
pub const NPY_MINOR_VERSION: u8 = 0;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NpyDtype {
Float32,
Float64,
Int32,
Int64,
}
impl NpyDtype {
pub fn element_size(&self) -> usize {
match self {
NpyDtype::Float32 => 4,
NpyDtype::Float64 => 8,
NpyDtype::Int32 => 4,
NpyDtype::Int64 => 8,
}
}
pub fn npy_str_le(&self) -> &'static str {
match self {
NpyDtype::Float32 => "<f4",
NpyDtype::Float64 => "<f8",
NpyDtype::Int32 => "<i4",
NpyDtype::Int64 => "<i8",
}
}
pub fn npy_str_be(&self) -> &'static str {
match self {
NpyDtype::Float32 => ">f4",
NpyDtype::Float64 => ">f8",
NpyDtype::Int32 => ">i4",
NpyDtype::Int64 => ">i8",
}
}
pub fn from_descr(descr: &str) -> Result<(Self, ByteOrder)> {
let descr = descr.trim().trim_matches('\'').trim_matches('"');
if descr.len() < 3 {
return Err(IoError::FormatError(format!(
"Invalid dtype descriptor: '{}'",
descr
)));
}
let endian_char = descr.as_bytes()[0];
let type_char = descr.as_bytes()[1];
let size_str = &descr[2..];
let byte_order = match endian_char {
b'<' | b'=' => ByteOrder::LittleEndian,
b'>' => ByteOrder::BigEndian,
b'|' => ByteOrder::NotApplicable,
_ => {
return Err(IoError::FormatError(format!(
"Unknown endian prefix: '{}'",
endian_char as char
)))
}
};
let size: usize = size_str
.parse()
.map_err(|_| IoError::FormatError(format!("Invalid dtype size: '{}'", size_str)))?;
let dtype = match (type_char, size) {
(b'f', 4) => NpyDtype::Float32,
(b'f', 8) => NpyDtype::Float64,
(b'i', 4) => NpyDtype::Int32,
(b'i', 8) => NpyDtype::Int64,
_ => {
return Err(IoError::FormatError(format!(
"Unsupported dtype: type='{}', size={}",
type_char as char, size
)))
}
};
Ok((dtype, byte_order))
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ByteOrder {
LittleEndian,
BigEndian,
NotApplicable,
}
#[derive(Debug, Clone)]
pub struct NpyHeader {
pub dtype: NpyDtype,
pub byte_order: ByteOrder,
pub fortran_order: bool,
pub shape: Vec<usize>,
}
impl NpyHeader {
pub fn num_elements(&self) -> usize {
self.shape.iter().product()
}
pub fn to_header_string(&self) -> String {
let descr = if cfg!(target_endian = "little") {
self.dtype.npy_str_le()
} else {
self.dtype.npy_str_be()
};
let fortran_str = if self.fortran_order { "True" } else { "False" };
let shape_str = if self.shape.len() == 1 {
format!("({},)", self.shape[0])
} else {
let parts: Vec<String> = self.shape.iter().map(|s| s.to_string()).collect();
format!("({})", parts.join(", "))
};
format!(
"{{'descr': '{}', 'fortran_order': {}, 'shape': {}, }}",
descr, fortran_str, shape_str
)
}
}
#[derive(Debug, Clone)]
pub enum NpyArray {
Float32 {
data: Vec<f32>,
shape: Vec<usize>,
},
Float64 {
data: Vec<f64>,
shape: Vec<usize>,
},
Int32 {
data: Vec<i32>,
shape: Vec<usize>,
},
Int64 {
data: Vec<i64>,
shape: Vec<usize>,
},
}
impl NpyArray {
pub fn shape(&self) -> &[usize] {
match self {
NpyArray::Float32 { shape, .. } => shape,
NpyArray::Float64 { shape, .. } => shape,
NpyArray::Int32 { shape, .. } => shape,
NpyArray::Int64 { shape, .. } => shape,
}
}
pub fn dtype(&self) -> NpyDtype {
match self {
NpyArray::Float32 { .. } => NpyDtype::Float32,
NpyArray::Float64 { .. } => NpyDtype::Float64,
NpyArray::Int32 { .. } => NpyDtype::Int32,
NpyArray::Int64 { .. } => NpyDtype::Int64,
}
}
pub fn num_elements(&self) -> usize {
self.shape().iter().product()
}
pub fn as_f64(&self) -> Result<&[f64]> {
match self {
NpyArray::Float64 { data, .. } => Ok(data),
_ => Err(IoError::ConversionError(format!(
"Array is {:?}, not Float64",
self.dtype()
))),
}
}
pub fn as_f32(&self) -> Result<&[f32]> {
match self {
NpyArray::Float32 { data, .. } => Ok(data),
_ => Err(IoError::ConversionError(format!(
"Array is {:?}, not Float32",
self.dtype()
))),
}
}
pub fn as_i32(&self) -> Result<&[i32]> {
match self {
NpyArray::Int32 { data, .. } => Ok(data),
_ => Err(IoError::ConversionError(format!(
"Array is {:?}, not Int32",
self.dtype()
))),
}
}
pub fn as_i64(&self) -> Result<&[i64]> {
match self {
NpyArray::Int64 { data, .. } => Ok(data),
_ => Err(IoError::ConversionError(format!(
"Array is {:?}, not Int64",
self.dtype()
))),
}
}
}
pub fn parse_header_dict(header_str: &str) -> Result<NpyHeader> {
let header_str = header_str
.trim()
.trim_end_matches('\n')
.trim_end_matches('\0');
let descr = extract_dict_value(header_str, "descr")?;
let (dtype, byte_order) = NpyDtype::from_descr(&descr)?;
let fortran_str = extract_dict_value(header_str, "fortran_order")?;
let fortran_order = fortran_str.trim() == "True";
let shape_str = extract_dict_value(header_str, "shape")?;
let shape = parse_shape(&shape_str)?;
Ok(NpyHeader {
dtype,
byte_order,
fortran_order,
shape,
})
}
fn extract_dict_value(dict_str: &str, key: &str) -> Result<String> {
let search = format!("'{}': ", key);
let pos = dict_str.find(&search).or_else(|| {
let alt_search = format!("\"{}\":", key);
dict_str.find(&alt_search)
});
let start = match pos {
Some(p) => p + search.len(),
None => {
let alt = format!("'{}':", key);
match dict_str.find(&alt) {
Some(p) => p + alt.len(),
None => {
return Err(IoError::FormatError(format!(
"Key '{}' not found in header: {}",
key, dict_str
)))
}
}
}
};
let remaining = dict_str[start..].trim_start();
if remaining.starts_with('\'') || remaining.starts_with('"') {
let quote = remaining.as_bytes()[0];
let end = remaining[1..]
.find(|c: char| c as u8 == quote)
.ok_or_else(|| {
IoError::FormatError(format!("Unterminated string for key '{}'", key))
})?;
Ok(remaining[1..end + 1].to_string())
} else if remaining.starts_with('(') {
let end = remaining
.find(')')
.ok_or_else(|| IoError::FormatError(format!("Unterminated tuple for key '{}'", key)))?;
Ok(remaining[..end + 1].to_string())
} else {
let end = remaining.find([',', '}']).unwrap_or(remaining.len());
Ok(remaining[..end].trim().to_string())
}
}
fn parse_shape(shape_str: &str) -> Result<Vec<usize>> {
let inner = shape_str
.trim()
.trim_start_matches('(')
.trim_end_matches(')');
if inner.is_empty() {
return Ok(vec![]); }
let mut shape = Vec::new();
for part in inner.split(',') {
let part = part.trim();
if part.is_empty() {
continue;
}
let dim: usize = part
.parse()
.map_err(|_| IoError::FormatError(format!("Invalid shape dimension: '{}'", part)))?;
shape.push(dim);
}
Ok(shape)
}