#![allow(clippy::type_complexity)]
#[allow(unused_imports)]
use super::functions::*;
use super::functions::{NPY_MAGIC, NPY_MAJOR, NPY_MINOR};
#[allow(unused_imports)]
use super::functions_2::*;
/// Read-only view of a borrowed, flat `f64` buffer interpreted with `shape`.
///
/// Elements are addressed through `flat_index` (see `get`) or, for 2-D views,
/// as contiguous rows of `shape[1]` elements (see `row`).
#[allow(dead_code)]
pub struct NpySlice<'a> {
    /// Borrowed element buffer. `new` checks that its length equals the
    /// product of `shape`.
    pub data: &'a [f64],
    /// Logical dimensions of the view.
    pub shape: Vec<usize>,
}
#[allow(dead_code)]
impl<'a> NpySlice<'a> {
    /// Creates a view of `data` with the given `shape`.
    ///
    /// # Errors
    /// Fails when `data.len()` differs from the product of `shape` (an empty
    /// `shape` has product 1, i.e. describes a single element).
    pub fn new(data: &'a [f64], shape: Vec<usize>) -> std::result::Result<Self, String> {
        let expected = shape.iter().product::<usize>();
        if data.len() == expected {
            return Ok(NpySlice { data, shape });
        }
        Err(format!(
            "NpySlice: data length {} != shape product {}",
            data.len(),
            expected
        ))
    }

    /// Number of dimensions (length of `shape`).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total element count, computed as the product of `shape`.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns row `row_idx` of a 2-D view as a contiguous sub-slice.
    ///
    /// Rows are taken as consecutive runs of `shape[1]` elements in `data`.
    ///
    /// # Errors
    /// Fails if the view is not 2-D or if `row_idx >= shape[0]`.
    ///
    /// # Panics
    /// The final slice can panic if `data`/`shape` (both public) were changed
    /// after construction so that their lengths no longer agree.
    pub fn row(&self, row_idx: usize) -> std::result::Result<&[f64], String> {
        let (n_rows, n_cols) = match self.shape.as_slice() {
            [rows, cols] => (*rows, *cols),
            other => {
                return Err(format!(
                    "row() requires 2-D slice, got {}D",
                    other.len()
                ))
            }
        };
        if row_idx >= n_rows {
            return Err(format!(
                "row {} out of bounds (shape[0]={})",
                row_idx, n_rows
            ));
        }
        let start = row_idx * n_cols;
        Ok(&self.data[start..start + n_cols])
    }

    /// Returns the element at the multi-dimensional position `indices`.
    ///
    /// Validation and flattening are delegated to `flat_index`; its error is
    /// propagated with `?`. The resulting offset indexes `data` directly.
    pub fn get(&self, indices: &[usize]) -> std::result::Result<f64, String> {
        let offset = flat_index(indices, &self.shape)?;
        Ok(self.data[offset])
    }
}
/// A flat `f64` array paired with a per-element mask.
///
/// `mask[i] == true` marks element `i` as masked: accessors such as
/// `get_filled` and `filled` substitute `fill_value` for it, and the
/// statistics helpers skip it.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct NpyMaskedArray {
    /// Raw element values, including masked ones.
    pub data: Vec<f64>,
    /// `true` where the corresponding element of `data` is masked.
    pub mask: Vec<bool>,
    /// Replacement value reported for masked elements.
    pub fill_value: f64,
    /// Logical dimensions; constructors check that their product equals
    /// both `data.len()` and `mask.len()`.
    pub shape: Vec<usize>,
}
#[allow(dead_code)]
impl NpyMaskedArray {
    /// Builds a masked array from explicit data, mask, shape and fill value.
    ///
    /// # Errors
    /// Fails if `data` (checked first) or `mask` does not hold exactly
    /// `shape.iter().product()` elements.
    pub fn new(
        data: Vec<f64>,
        mask: Vec<bool>,
        shape: Vec<usize>,
        fill_value: f64,
    ) -> std::result::Result<Self, String> {
        let expected = shape.iter().product::<usize>();
        match (data.len(), mask.len()) {
            (got, _) if got != expected => {
                Err(format!("data length {} != shape product {}", got, expected))
            }
            (_, got) if got != expected => {
                Err(format!("mask length {} != shape product {}", got, expected))
            }
            _ => Ok(Self {
                data,
                mask,
                fill_value,
                shape,
            }),
        }
    }

    /// Builds an array with nothing masked and a fill value of `1e20`.
    ///
    /// # Errors
    /// Fails if `data.len()` differs from the product of `shape`.
    pub fn from_data(data: Vec<f64>, shape: Vec<usize>) -> std::result::Result<Self, String> {
        let expected = shape.iter().product::<usize>();
        if data.len() == expected {
            Ok(Self {
                mask: vec![false; expected],
                fill_value: 1e20,
                data,
                shape,
            })
        } else {
            Err(format!(
                "data length {} != shape product {}",
                data.len(),
                expected
            ))
        }
    }

    /// Returns element `idx`, or `fill_value` if it is masked.
    ///
    /// # Panics
    /// Panics if `idx` is out of range for `mask` (or for `data` when the
    /// element is unmasked).
    pub fn get_filled(&self, idx: usize) -> f64 {
        match self.mask[idx] {
            true => self.fill_value,
            false => self.data[idx],
        }
    }

    /// Number of unmasked elements.
    pub fn count_valid(&self) -> usize {
        self.mask.iter().copied().filter(|masked| !masked).count()
    }

    /// Arithmetic mean of the unmasked elements, summed in index order.
    ///
    /// Returns `None` when every element is masked (or the array is empty).
    pub fn mean_valid(&self) -> Option<f64> {
        let mut total = 0.0_f64;
        let mut n = 0_usize;
        for (value, &masked) in self.data.iter().zip(self.mask.iter()) {
            if !masked {
                total += *value;
                n += 1;
            }
        }
        if n == 0 {
            return None;
        }
        Some(total / n as f64)
    }

    /// Copy of `data` with every masked element replaced by `fill_value`.
    pub fn filled(&self) -> Vec<f64> {
        let mut out = Vec::with_capacity(self.data.len().min(self.mask.len()));
        for (&value, &masked) in self.data.iter().zip(self.mask.iter()) {
            out.push(if masked { self.fill_value } else { value });
        }
        out
    }

    /// Masks every element whose *absolute value* exceeds `threshold`.
    ///
    /// Elements that are already masked stay masked. NaN values are never
    /// masked because the comparison is false for them.
    ///
    /// NOTE(review): this compares `|v|`, not `v`, so large negative values are
    /// masked too, unlike `numpy.ma.masked_greater`. Confirm this is intended.
    pub fn mask_greater_than(&mut self, threshold: f64) {
        self.mask
            .iter_mut()
            .zip(self.data.iter())
            .filter(|(_, value)| value.abs() > threshold)
            .for_each(|(slot, _)| *slot = true);
    }

    /// Clears the mask so every element is considered valid.
    pub fn unmask_all(&mut self) {
        self.mask.fill(false);
    }
}
/// Column-oriented storage for records laid out according to `fields`.
///
/// `columns[i]` holds the values of `fields[i]` for every record, appended
/// record by record. With `count = fields[i].count`, record `r` occupies
/// `columns[i][r * count..(r + 1) * count]`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct NpyRecordArray {
    /// Field descriptors, in record order.
    pub fields: Vec<NpyField>,
    /// One value column per field, parallel to `fields`.
    pub columns: Vec<Vec<f64>>,
    /// Number of records pushed so far.
    pub n_records: usize,
}
#[allow(dead_code)]
impl NpyRecordArray {
    /// Creates an empty record array with one empty column per field.
    pub fn new(fields: Vec<NpyField>) -> Self {
        Self {
            columns: fields.iter().map(|_| Vec::new()).collect(),
            fields,
            n_records: 0,
        }
    }

    /// Appends one record given as a flat list of values.
    ///
    /// `values` is split in field order: the first `fields[0].count` values go
    /// to column 0, the next `fields[1].count` to column 1, and so on.
    ///
    /// # Errors
    /// Fails if `values.len()` differs from the sum of all field counts.
    pub fn push_record(&mut self, values: &[f64]) -> std::result::Result<(), String> {
        let total: usize = self.fields.iter().map(|f| f.count).sum();
        if values.len() != total {
            return Err(format!(
                "push_record: expected {total} values, got {}",
                values.len()
            ));
        }
        let mut remaining = values;
        for (field, column) in self.fields.iter().zip(self.columns.iter_mut()) {
            let (chunk, rest) = remaining.split_at(field.count);
            column.extend_from_slice(chunk);
            remaining = rest;
        }
        self.n_records += 1;
        Ok(())
    }

    /// Returns the whole column for the first field named `name`, if any.
    pub fn column(&self, name: &str) -> Option<&[f64]> {
        let idx = self.fields.iter().position(|f| f.name == name)?;
        Some(self.columns[idx].as_slice())
    }

    /// Reads the scalar field `name` of record `record`.
    ///
    /// # Errors
    /// Checked in this order: unknown field, field with `count != 1`, and
    /// `record >= n_records`.
    pub fn get_scalar(&self, record: usize, name: &str) -> std::result::Result<f64, String> {
        let Some(fi) = self.fields.iter().position(|f| f.name == name) else {
            return Err(format!("field '{name}' not found"));
        };
        let count = self.fields[fi].count;
        if count != 1 {
            return Err(format!(
                "field '{name}' is not scalar (count={})",
                count
            ));
        }
        if record < self.n_records {
            Ok(self.columns[fi][record])
        } else {
            Err(format!(
                "record {record} out of range (n_records={})",
                self.n_records
            ))
        }
    }
}
/// Element types understood by the `.npy` helpers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum NpyDtype {
    /// 64-bit float, descriptor `<f8`.
    Float64,
    /// 32-bit float, descriptor `<f4`.
    Float32,
    /// 32-bit signed integer, descriptor `<i4`.
    Int32,
    /// 64-bit signed integer, descriptor `<i8`.
    Int64,
    /// Boolean, written as `?`; `|b1` is also accepted when parsing.
    Bool,
    /// 8-bit unsigned integer, descriptor `|u1`.
    Uint8,
}
impl NpyDtype {
    /// Returns the NumPy `descr` string emitted for this dtype.
    ///
    /// Every result is a string literal, so the return value is `'static` and
    /// may outlive `self`.
    pub fn numpy_str(&self) -> &'static str {
        match self {
            NpyDtype::Float64 => "<f8",
            NpyDtype::Float32 => "<f4",
            NpyDtype::Int32 => "<i4",
            NpyDtype::Int64 => "<i8",
            NpyDtype::Bool => "?",
            NpyDtype::Uint8 => "|u1",
        }
    }

    /// Size of one element in bytes.
    pub fn element_size(&self) -> usize {
        match self {
            NpyDtype::Float64 => 8,
            NpyDtype::Float32 => 4,
            NpyDtype::Int32 => 4,
            NpyDtype::Int64 => 8,
            NpyDtype::Bool => 1,
            NpyDtype::Uint8 => 1,
        }
    }

    /// Parses a NumPy `descr` string into a dtype.
    ///
    /// Only the spellings listed on the variants are recognised; in
    /// particular, big-endian descriptors such as `>f8` are rejected.
    /// Booleans are accepted both as `?` (what [`NpyDtype::numpy_str`] emits)
    /// and as `|b1`. NumPy itself writes `|b1` into the header of
    /// `.npy` files holding `bool` arrays.
    ///
    /// # Errors
    /// Returns `unsupported dtype: '<s>'` for any other string.
    pub fn from_numpy_str(s: &str) -> Result<Self, String> {
        match s {
            "<f8" => Ok(NpyDtype::Float64),
            "<f4" => Ok(NpyDtype::Float32),
            "<i4" => Ok(NpyDtype::Int32),
            "<i8" => Ok(NpyDtype::Int64),
            "?" | "|b1" => Ok(NpyDtype::Bool),
            "|u1" => Ok(NpyDtype::Uint8),
            _ => Err(format!("unsupported dtype: '{s}'")),
        }
    }
}
/// A dense N-dimensional array with one typed backing buffer.
///
/// Only the buffer that matches `dtype` is populated by the `from_*`
/// constructors (`data_f64` for `Float64`, `data_f32` for `Float32`,
/// `data_i32` for `Int32`); the other two are left empty. The remaining dtypes
/// have no backing buffer in this struct.
#[derive(Debug, Clone)]
pub struct NpyArray {
    /// Element type; selects which `data_*` buffer is meaningful.
    pub dtype: NpyDtype,
    /// Logical dimensions (empty for a 0-D array).
    pub shape: Vec<usize>,
    /// Flat storage for `Float64` arrays.
    pub data_f64: Vec<f64>,
    /// Flat storage for `Float32` arrays.
    pub data_f32: Vec<f32>,
    /// Flat storage for `Int32` arrays.
    pub data_i32: Vec<i32>,
}
impl NpyArray {
    /// Element count implied by `shape` (1 for an empty shape).
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Checks that the buffer selected by `dtype` holds `numel()` elements.
    ///
    /// Dtypes without a backing buffer here (`Int64`, `Bool`, `Uint8`) have
    /// nothing to compare and always validate.
    ///
    /// # Errors
    /// Describes the shape, the expected count and the actual buffer length.
    pub fn validate(&self) -> Result<(), String> {
        let expected = self.numel();
        let actual = match self.dtype {
            NpyDtype::Float64 => self.data_f64.len(),
            NpyDtype::Float32 => self.data_f32.len(),
            NpyDtype::Int32 => self.data_i32.len(),
            NpyDtype::Int64 | NpyDtype::Bool | NpyDtype::Uint8 => return Ok(()),
        };
        if actual == expected {
            return Ok(());
        }
        Err(format!(
            "shape {:?} expects {} elements, but data has {}",
            self.shape, expected, actual
        ))
    }

    /// Wraps `f64` data with `shape` (no length check; see `validate`).
    pub fn from_f64(shape: Vec<usize>, data: Vec<f64>) -> Self {
        NpyArray {
            dtype: NpyDtype::Float64,
            data_f64: data,
            data_f32: vec![],
            data_i32: vec![],
            shape,
        }
    }

    /// Wraps `f32` data with `shape` (no length check; see `validate`).
    pub fn from_f32(shape: Vec<usize>, data: Vec<f32>) -> Self {
        NpyArray {
            dtype: NpyDtype::Float32,
            data_f64: vec![],
            data_f32: data,
            data_i32: vec![],
            shape,
        }
    }

    /// Wraps `i32` data with `shape` (no length check; see `validate`).
    pub fn from_i32(shape: Vec<usize>, data: Vec<i32>) -> Self {
        NpyArray {
            dtype: NpyDtype::Int32,
            data_f64: vec![],
            data_f32: vec![],
            data_i32: data,
            shape,
        }
    }

    /// Replaces `shape` with `new_shape` if both describe the same element count.
    ///
    /// Only the shape changes; the flat buffers are left untouched.
    ///
    /// # Errors
    /// Fails, leaving `shape` unchanged, when the element counts differ.
    pub fn reshape(&mut self, new_shape: Vec<usize>) -> Result<(), String> {
        let old_numel = self.numel();
        let new_numel: usize = new_shape.iter().product();
        if old_numel == new_numel {
            self.shape = new_shape;
            Ok(())
        } else {
            Err(format!(
                "cannot reshape: old numel={old_numel}, new numel={new_numel}"
            ))
        }
    }
}
impl NpyArray {
    /// Serializes a structured (record) array into `.npy` bytes.
    ///
    /// `fields` lists `(name, numpy_dtype_str)` pairs that become the header's
    /// `descr` list, and `n_records` becomes the 1-D `shape`. `data_bytes` is
    /// appended verbatim after the header. It is assumed, not checked, to
    /// contain `n_records` packed records matching `fields`.
    ///
    /// The header is space-padded and newline-terminated so that the whole
    /// preamble (magic string, two version bytes, 2-byte header length) plus
    /// the header occupies a multiple of 64 bytes. This is the alignment the
    /// NPY format specification requires.
    ///
    /// # Errors
    /// Fails if `fields` is empty, or if the padded header is longer than the
    /// 16-bit header-length field written here can represent.
    #[allow(dead_code)]
    pub fn save_structured(
        fields: &[(&str, &str)],
        n_records: usize,
        data_bytes: &[u8],
    ) -> std::result::Result<Vec<u8>, String> {
        if fields.is_empty() {
            return Err("save_structured: field list is empty".into());
        }
        let dtype_parts: Vec<String> = fields
            .iter()
            .map(|(name, dt)| format!("('{}', '{}')", name, dt))
            .collect();
        let dtype_str = format!("[{}]", dtype_parts.join(", "));
        let header_dict = format!(
            "{{'descr': {}, 'fortran_order': False, 'shape': ({},), }}",
            dtype_str, n_records
        );
        // Bytes before the header: magic + major/minor version + u16 length.
        let preamble_len = NPY_MAGIC.len() + 2 + 2;
        // Alignment covers the preamble too; +1 is the terminating newline.
        let unpadded = preamble_len + header_dict.len() + 1;
        let padding = unpadded.div_ceil(64) * 64 - unpadded;
        let mut header_bytes = header_dict.into_bytes();
        header_bytes.extend(std::iter::repeat_n(b' ', padding));
        header_bytes.push(b'\n');
        // Refuse instead of silently truncating the length to 16 bits,
        // which would produce an unreadable file.
        if header_bytes.len() > u16::MAX as usize {
            return Err(format!(
                "save_structured: header of {} bytes exceeds the u16 header-length limit",
                header_bytes.len()
            ));
        }
        let header_len = header_bytes.len() as u16;
        let mut out = Vec::with_capacity(preamble_len + header_bytes.len() + data_bytes.len());
        out.extend_from_slice(NPY_MAGIC);
        out.push(NPY_MAJOR);
        out.push(NPY_MINOR);
        out.extend_from_slice(&header_len.to_le_bytes());
        out.extend_from_slice(&header_bytes);
        out.extend_from_slice(data_bytes);
        Ok(out)
    }
}
/// In-memory collection of named [`NpyArray`]s, kept in insertion order.
///
/// Names are not required to be unique: `insert` appends unconditionally and
/// `get` returns the first entry with a matching name.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct NpzArchive {
    /// `(name, array)` pairs in insertion order.
    pub arrays: Vec<(String, NpyArray)>,
}
#[allow(dead_code)]
impl NpzArchive {
    /// Creates an empty archive.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `array` under `name` without checking for an existing entry.
    pub fn insert(&mut self, name: &str, array: NpyArray) {
        self.arrays.push((name.to_owned(), array));
    }

    /// Returns the first array stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&NpyArray> {
        for (stored, array) in &self.arrays {
            if stored == name {
                return Some(array);
            }
        }
        None
    }

    /// Names of all entries, in insertion order (duplicates included).
    pub fn names(&self) -> Vec<&str> {
        self.arrays.iter().map(|entry| entry.0.as_str()).collect()
    }

    /// Removes every entry named `name`; returns whether anything was removed.
    pub fn remove(&mut self, name: &str) -> bool {
        let len_before = self.arrays.len();
        self.arrays.retain(|entry| entry.0 != name);
        len_before != self.arrays.len()
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.arrays.len()
    }

    /// `true` when the archive holds no entries.
    pub fn is_empty(&self) -> bool {
        self.arrays.is_empty()
    }

    /// Encodes every entry with [`NpzWriter`] and returns the resulting bytes.
    ///
    /// # Errors
    /// Fails on the first array whose dtype is not `Float64`, `Float32` or
    /// `Int32`, since `NpyArray` has no storage for the other dtypes.
    pub fn to_bytes(&self) -> std::result::Result<Vec<u8>, String> {
        let mut writer = NpzWriter::new();
        for (name, array) in self.arrays.iter() {
            let shape = &array.shape;
            match &array.dtype {
                NpyDtype::Float64 => writer.add_array_f64(name, shape, &array.data_f64),
                NpyDtype::Float32 => writer.add_array_f32(name, shape, &array.data_f32),
                NpyDtype::Int32 => writer.add_array_i32(name, shape, &array.data_i32),
                unsupported => {
                    return Err(format!(
                        "NpzArchive::to_bytes: unsupported dtype {:?}",
                        unsupported
                    ));
                }
            }
        }
        Ok(writer.to_bytes())
    }

    /// Decodes bytes produced by [`NpzArchive::to_bytes`] / [`NpzWriter::to_bytes`].
    ///
    /// Each payload's dtype is detected first, then it is parsed with the
    /// matching `read_npy_*` helper. Entries are inserted in stored order.
    ///
    /// # Errors
    /// Propagates container and payload parse errors, and fails on payloads
    /// whose dtype has no `NpyArray` storage.
    pub fn from_bytes(data: &[u8]) -> std::result::Result<Self, String> {
        let container = NpzWriter::from_bytes(data)?;
        let mut archive = NpzArchive::new();
        for (name, npy_bytes) in container.files.iter() {
            let array = match detect_npy_dtype(npy_bytes)? {
                NpyDtype::Float64 => {
                    let (shape, values) = read_npy_f64(npy_bytes)?;
                    NpyArray::from_f64(shape, values)
                }
                NpyDtype::Float32 => {
                    let (shape, values) = read_npy_f32(npy_bytes)?;
                    NpyArray::from_f32(shape, values)
                }
                NpyDtype::Int32 => {
                    let (shape, values) = read_npy_i32(npy_bytes)?;
                    NpyArray::from_i32(shape, values)
                }
                other => {
                    return Err(format!(
                        "NpzArchive::from_bytes: unsupported dtype {:?} in '{name}'",
                        other
                    ));
                }
            };
            archive.insert(name, array);
        }
        Ok(archive)
    }
}
impl NpzArchive {
#[allow(dead_code)]
pub fn add_array(&mut self, name: &str, array: NpyArray) {
self.arrays.retain(|(n, _)| n.as_str() != name);
self.arrays.push((name.to_string(), array));
}
#[allow(dead_code)]
pub fn load_all(data: &[u8]) -> std::result::Result<Self, String> {
Self::from_bytes(data)
}
#[allow(dead_code)]
pub fn iter(&self) -> impl Iterator<Item = (&str, &NpyArray)> {
self.arrays.iter().map(|(n, a)| (n.as_str(), a))
}
#[allow(dead_code)]
pub fn merge(&mut self, other: NpzArchive) {
for (name, array) in other.arrays {
self.add_array(&name, array);
}
}
#[allow(dead_code)]
pub fn total_elements(&self) -> usize {
self.arrays.iter().map(|(_, a)| a.numel()).sum()
}
}
/// Ordered collection of named, already-serialized `.npy` payloads.
///
/// [`NpzWriter::to_bytes`] uses a simple length-prefixed container, not a
/// ZIP-based `.npz`:
/// `u32 count`, then per entry `u32 name_len`, `name` (UTF-8),
/// `u32 npy_len`, `npy` bytes. All integers are little-endian.
///
/// The `add_array_*` methods append without checking for an existing name,
/// so duplicates are possible. `contains` and the `get_*` lookups act on the
/// first match.
#[derive(Debug, Clone)]
pub struct NpzWriter {
    /// `(name, npy_bytes)` pairs in insertion order.
    pub files: Vec<(String, Vec<u8>)>,
}
impl NpzWriter {
    /// Creates an empty writer.
    pub fn new() -> Self {
        NpzWriter { files: Vec::new() }
    }

    /// Serializes `data` as an `f64` `.npy` payload and appends it under `name`.
    pub fn add_array_f64(&mut self, name: &str, shape: &[usize], data: &[f64]) {
        let npy = write_npy_f64(shape, data);
        self.files.push((name.to_string(), npy));
    }

    /// Serializes `data` as an `f32` `.npy` payload and appends it under `name`.
    pub fn add_array_f32(&mut self, name: &str, shape: &[usize], data: &[f32]) {
        let npy = write_npy_f32(shape, data);
        self.files.push((name.to_string(), npy));
    }

    /// Serializes `data` as an `i32` `.npy` payload and appends it under `name`.
    pub fn add_array_i32(&mut self, name: &str, shape: &[usize], data: &[i32]) {
        let npy = write_npy_i32(shape, data);
        self.files.push((name.to_string(), npy));
    }

    /// Serializes `data` as an `i64` `.npy` payload and appends it under `name`.
    pub fn add_array_i64(&mut self, name: &str, shape: &[usize], data: &[i64]) {
        let npy = write_npy_i64(shape, data);
        self.files.push((name.to_string(), npy));
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.files.len()
    }

    /// `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.files.is_empty()
    }

    /// Entry names in insertion order (duplicates included).
    pub fn names(&self) -> Vec<&str> {
        self.files.iter().map(|(n, _)| n.as_str()).collect()
    }

    /// `true` if any entry is named `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.files.iter().any(|(n, _)| n == name)
    }

    /// Removes every entry named `name`; returns whether anything was removed.
    pub fn remove(&mut self, name: &str) -> bool {
        let before = self.files.len();
        self.files.retain(|(n, _)| n != name);
        self.files.len() < before
    }

    /// Encodes all entries in the length-prefixed container format.
    ///
    /// NOTE: lengths and the entry count are narrowed with `as u32`. A name,
    /// payload, or count larger than `u32::MAX` would therefore be truncated
    /// and produce a corrupt encoding.
    pub fn to_bytes(&self) -> Vec<u8> {
        let body_len: usize = self
            .files
            .iter()
            .map(|(name, npy)| 8 + name.len() + npy.len())
            .sum();
        let mut out: Vec<u8> = Vec::with_capacity(4 + body_len);
        out.extend_from_slice(&(self.files.len() as u32).to_le_bytes());
        for (name, npy) in &self.files {
            let name_bytes = name.as_bytes();
            out.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
            out.extend_from_slice(name_bytes);
            out.extend_from_slice(&(npy.len() as u32).to_le_bytes());
            out.extend_from_slice(npy);
        }
        out
    }

    /// Decodes the container produced by [`NpzWriter::to_bytes`].
    ///
    /// The input is treated as untrusted. The declared entry count only
    /// bounds the loop and never drives an up-front allocation larger than
    /// the input could describe. All length checks are overflow-safe.
    /// Trailing bytes after the last declared entry are ignored.
    ///
    /// # Errors
    /// Fails on truncated input, on out-of-bounds name/payload lengths, and on
    /// names that are not valid UTF-8.
    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
        let mut pos = 0usize;
        let count = read_u32(data, &mut pos)? as usize;
        // Each entry needs at least two 4-byte length prefixes, so never
        // reserve more slots than the remaining bytes could hold.
        let max_entries = data.len().saturating_sub(pos) / 8;
        let mut files = Vec::with_capacity(count.min(max_entries));
        for _ in 0..count {
            let name_len = read_u32(data, &mut pos)? as usize;
            let name_bytes = pos
                .checked_add(name_len)
                .and_then(|end| data.get(pos..end))
                .ok_or_else(|| "name out of bounds".to_string())?;
            let name = std::str::from_utf8(name_bytes)
                .map_err(|e| format!("invalid UTF-8 in name: {e}"))?
                .to_string();
            pos += name_len;
            let npy_len = read_u32(data, &mut pos)? as usize;
            let npy = pos
                .checked_add(npy_len)
                .and_then(|end| data.get(pos..end))
                .ok_or_else(|| "npy payload out of bounds".to_string())?
                .to_vec();
            pos += npy_len;
            files.push((name, npy));
        }
        Ok(NpzWriter { files })
    }

    /// Parses the first entry named `name` as an `f64` `.npy` payload.
    ///
    /// Returns `None` if no entry has that name.
    pub fn get_f64(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<f64>), String>> {
        self.files
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, npy)| read_npy_f64(npy))
    }

    /// Parses the first entry named `name` as an `f32` `.npy` payload.
    ///
    /// Returns `None` if no entry has that name.
    pub fn get_f32(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<f32>), String>> {
        self.files
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, npy)| read_npy_f32(npy))
    }

    /// Parses the first entry named `name` as an `i32` `.npy` payload.
    ///
    /// Returns `None` if no entry has that name.
    pub fn get_i32(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<i32>), String>> {
        self.files
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, npy)| read_npy_i32(npy))
    }

    /// Parses the first entry named `name` as an `i64` `.npy` payload.
    ///
    /// Returns `None` if no entry has that name.
    pub fn get_i64(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<i64>), String>> {
        self.files
            .iter()
            .find(|(n, _)| n == name)
            .map(|(_, npy)| read_npy_i64(npy))
    }
}
/// One named field of a structured record: `count` consecutive values of
/// type `dtype`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct NpyField {
    /// Field name.
    pub name: String,
    /// Element type of each value in the field.
    pub dtype: NpyDtype,
    /// Number of values per record (1 for a scalar field).
    pub count: usize,
}
#[allow(dead_code)]
impl NpyField {
    /// Creates a single-value field (`count == 1`).
    pub fn scalar(name: &str, dtype: NpyDtype) -> Self {
        Self::vector(name, dtype, 1)
    }

    /// Creates a field holding `count` values per record.
    pub fn vector(name: &str, dtype: NpyDtype, count: usize) -> Self {
        NpyField {
            name: String::from(name),
            dtype,
            count,
        }
    }

    /// Bytes occupied by this field in one record: `count * element_size`.
    pub fn byte_size(&self) -> usize {
        self.count * self.dtype.element_size()
    }
}