1#![allow(clippy::type_complexity)]
6#[allow(unused_imports)]
7use super::functions::*;
8use super::functions::{NPY_MAGIC, NPY_MAJOR, NPY_MINOR};
9#[allow(unused_imports)]
10use super::functions_2::*;
11
12#[allow(dead_code)]
17pub struct NpySlice<'a> {
18 pub data: &'a [f64],
20 pub shape: Vec<usize>,
22}
23#[allow(dead_code)]
24impl<'a> NpySlice<'a> {
25 pub fn new(data: &'a [f64], shape: Vec<usize>) -> std::result::Result<Self, String> {
27 let expected: usize = shape.iter().product();
28 if expected != data.len() {
29 return Err(format!(
30 "NpySlice: data length {} != shape product {}",
31 data.len(),
32 expected
33 ));
34 }
35 Ok(NpySlice { data, shape })
36 }
37 pub fn ndim(&self) -> usize {
39 self.shape.len()
40 }
41 pub fn numel(&self) -> usize {
43 self.shape.iter().product()
44 }
45 pub fn row(&self, row_idx: usize) -> std::result::Result<&[f64], String> {
47 if self.shape.len() != 2 {
48 return Err(format!(
49 "row() requires 2-D slice, got {}D",
50 self.shape.len()
51 ));
52 }
53 let ncols = self.shape[1];
54 if row_idx >= self.shape[0] {
55 return Err(format!(
56 "row {} out of bounds (shape[0]={})",
57 row_idx, self.shape[0]
58 ));
59 }
60 Ok(&self.data[row_idx * ncols..(row_idx + 1) * ncols])
61 }
62 pub fn get(&self, indices: &[usize]) -> std::result::Result<f64, String> {
64 let flat = flat_index(indices, &self.shape)?;
65 Ok(self.data[flat])
66 }
67}
/// Array of `f64` values with a parallel boolean mask; `true` marks an element as masked out.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct NpyMaskedArray {
    /// Flat element storage.
    pub data: Vec<f64>,
    /// One flag per element; `true` means the element is masked.
    pub mask: Vec<bool>,
    /// Value substituted for masked elements by `get_filled` and `filled`.
    pub fill_value: f64,
    /// Extent of every dimension.
    pub shape: Vec<usize>,
}

#[allow(dead_code)]
impl NpyMaskedArray {
    /// Builds a masked array from explicit data, mask, shape and fill value.
    ///
    /// # Errors
    /// Fails when `data` or `mask` does not contain exactly `shape.iter().product()` entries.
    pub fn new(
        data: Vec<f64>,
        mask: Vec<bool>,
        shape: Vec<usize>,
        fill_value: f64,
    ) -> std::result::Result<Self, String> {
        let expected = shape.iter().product::<usize>();
        // `data` is validated before `mask`, so a double mismatch reports the data length.
        if data.len() != expected {
            return Err(format!("data length {} != shape product {}", data.len(), expected));
        }
        if mask.len() != expected {
            return Err(format!("mask length {} != shape product {}", mask.len(), expected));
        }
        Ok(Self {
            shape,
            fill_value,
            mask,
            data,
        })
    }

    /// Builds a fully unmasked array with the fill value set to `1e20`.
    ///
    /// # Errors
    /// Fails when `data` does not contain exactly `shape.iter().product()` entries.
    pub fn from_data(data: Vec<f64>, shape: Vec<usize>) -> std::result::Result<Self, String> {
        let expected = shape.iter().product::<usize>();
        if data.len() != expected {
            return Err(format!("data length {} != shape product {}", data.len(), expected));
        }
        Ok(Self {
            mask: vec![false; expected],
            fill_value: 1e20,
            data,
            shape,
        })
    }

    /// Returns the element at flat index `idx`, or `fill_value` when it is masked.
    ///
    /// # Panics
    /// Panics when `idx` is out of range for `mask` (or for `data` when unmasked).
    pub fn get_filled(&self, idx: usize) -> f64 {
        match self.mask[idx] {
            true => self.fill_value,
            false => self.data[idx],
        }
    }

    /// Number of elements whose mask flag is `false`.
    pub fn count_valid(&self) -> usize {
        self.mask.iter().filter(|flag| !**flag).count()
    }

    /// Arithmetic mean of the unmasked elements, or `None` when every element is masked.
    pub fn mean_valid(&self) -> Option<f64> {
        let mut total = 0.0_f64;
        let mut seen = 0_usize;
        for (value, masked) in self.data.iter().zip(&self.mask) {
            if !*masked {
                total += *value;
                seen += 1;
            }
        }
        if seen > 0 {
            Some(total / seen as f64)
        } else {
            None
        }
    }

    /// Copies the data, replacing every masked element with `fill_value`.
    pub fn filled(&self) -> Vec<f64> {
        let mut out = Vec::with_capacity(self.data.len().min(self.mask.len()));
        for (value, masked) in self.data.iter().zip(&self.mask) {
            out.push(if *masked { self.fill_value } else { *value });
        }
        out
    }

    /// Masks every element whose absolute value exceeds `threshold`.
    ///
    /// Elements that are already masked stay masked.
    pub fn mask_greater_than(&mut self, threshold: f64) {
        for (flag, value) in self.mask.iter_mut().zip(&self.data) {
            *flag = *flag || value.abs() > threshold;
        }
    }

    /// Clears the mask so that every element counts as valid again.
    pub fn unmask_all(&mut self) {
        for flag in &mut self.mask {
            *flag = false;
        }
    }
}
166#[allow(dead_code)]
169#[derive(Debug, Clone)]
170pub struct NpyRecordArray {
171 pub fields: Vec<NpyField>,
173 pub columns: Vec<Vec<f64>>,
175 pub n_records: usize,
177}
178#[allow(dead_code)]
179impl NpyRecordArray {
180 pub fn new(fields: Vec<NpyField>) -> Self {
182 let columns = vec![Vec::new(); fields.len()];
183 Self {
184 fields,
185 columns,
186 n_records: 0,
187 }
188 }
189 pub fn push_record(&mut self, values: &[f64]) -> std::result::Result<(), String> {
192 let total: usize = self.fields.iter().map(|f| f.count).sum();
193 if values.len() != total {
194 return Err(format!(
195 "push_record: expected {total} values, got {}",
196 values.len()
197 ));
198 }
199 let mut offset = 0;
200 for (col, field) in self.columns.iter_mut().zip(self.fields.iter()) {
201 col.extend_from_slice(&values[offset..offset + field.count]);
202 offset += field.count;
203 }
204 self.n_records += 1;
205 Ok(())
206 }
207 pub fn column(&self, name: &str) -> Option<&[f64]> {
209 self.fields
210 .iter()
211 .position(|f| f.name == name)
212 .map(|i| self.columns[i].as_slice())
213 }
214 pub fn get_scalar(&self, record: usize, name: &str) -> std::result::Result<f64, String> {
216 let fi = self
217 .fields
218 .iter()
219 .position(|f| f.name == name)
220 .ok_or_else(|| format!("field '{name}' not found"))?;
221 let field = &self.fields[fi];
222 if field.count != 1 {
223 return Err(format!(
224 "field '{name}' is not scalar (count={})",
225 field.count
226 ));
227 }
228 if record >= self.n_records {
229 return Err(format!(
230 "record {record} out of range (n_records={})",
231 self.n_records
232 ));
233 }
234 Ok(self.columns[fi][record])
235 }
236}
/// Element types that can be named in an NPY header `descr` string.
#[derive(Debug, Clone, PartialEq)]
pub enum NpyDtype {
    /// 8-byte float, written as `<f8`.
    Float64,
    /// 4-byte float, written as `<f4`.
    Float32,
    /// 4-byte signed integer, written as `<i4`.
    Int32,
    /// 8-byte signed integer, written as `<i8`.
    Int64,
    /// 1-byte boolean, written as `?`.
    Bool,
    /// 1-byte unsigned integer, written as `|u1`.
    Uint8,
}

impl NpyDtype {
    /// Returns the NumPy type string this module writes for the dtype.
    pub fn numpy_str(&self) -> &str {
        match self {
            NpyDtype::Float64 => "<f8",
            NpyDtype::Float32 => "<f4",
            NpyDtype::Int32 => "<i4",
            NpyDtype::Int64 => "<i8",
            NpyDtype::Bool => "?",
            NpyDtype::Uint8 => "|u1",
        }
    }

    /// Size of one element in bytes.
    pub fn element_size(&self) -> usize {
        match self {
            NpyDtype::Float64 | NpyDtype::Int64 => 8,
            NpyDtype::Float32 | NpyDtype::Int32 => 4,
            NpyDtype::Bool | NpyDtype::Uint8 => 1,
        }
    }

    /// Parses a NumPy type string.
    ///
    /// Every string produced by [`NpyDtype::numpy_str`] is accepted. In addition
    /// `|b1` is recognised as [`NpyDtype::Bool`], because that is the descriptor
    /// NumPy itself stores in the header of a boolean array; without it, bool
    /// files written by NumPy could not be identified.
    ///
    /// # Errors
    /// Returns `unsupported dtype: '<s>'` for any other string.
    pub fn from_numpy_str(s: &str) -> Result<Self, String> {
        match s {
            "<f8" => Ok(NpyDtype::Float64),
            "<f4" => Ok(NpyDtype::Float32),
            "<i4" => Ok(NpyDtype::Int32),
            "<i8" => Ok(NpyDtype::Int64),
            // `?` is the short character code, `|b1` the canonical array-protocol form.
            "?" | "|b1" => Ok(NpyDtype::Bool),
            "|u1" => Ok(NpyDtype::Uint8),
            _ => Err(format!("unsupported dtype: '{s}'")),
        }
    }
}
289#[derive(Debug, Clone)]
294pub struct NpyArray {
295 pub dtype: NpyDtype,
297 pub shape: Vec<usize>,
299 pub data_f64: Vec<f64>,
301 pub data_f32: Vec<f32>,
303 pub data_i32: Vec<i32>,
305}
306impl NpyArray {
307 pub fn numel(&self) -> usize {
309 self.shape.iter().product()
310 }
311 pub fn ndim(&self) -> usize {
313 self.shape.len()
314 }
315 pub fn validate(&self) -> Result<(), String> {
317 let expected = self.numel();
318 let actual = match self.dtype {
319 NpyDtype::Float64 => self.data_f64.len(),
320 NpyDtype::Float32 => self.data_f32.len(),
321 NpyDtype::Int32 => self.data_i32.len(),
322 _ => expected,
323 };
324 if actual != expected {
325 Err(format!(
326 "shape {:?} expects {} elements, but data has {}",
327 self.shape, expected, actual
328 ))
329 } else {
330 Ok(())
331 }
332 }
333 pub fn from_f64(shape: Vec<usize>, data: Vec<f64>) -> Self {
335 Self {
336 dtype: NpyDtype::Float64,
337 shape,
338 data_f64: data,
339 data_f32: Vec::new(),
340 data_i32: Vec::new(),
341 }
342 }
343 pub fn from_f32(shape: Vec<usize>, data: Vec<f32>) -> Self {
345 Self {
346 dtype: NpyDtype::Float32,
347 shape,
348 data_f64: Vec::new(),
349 data_f32: data,
350 data_i32: Vec::new(),
351 }
352 }
353 pub fn from_i32(shape: Vec<usize>, data: Vec<i32>) -> Self {
355 Self {
356 dtype: NpyDtype::Int32,
357 shape,
358 data_f64: Vec::new(),
359 data_f32: Vec::new(),
360 data_i32: data,
361 }
362 }
363 pub fn reshape(&mut self, new_shape: Vec<usize>) -> Result<(), String> {
365 let old_numel = self.numel();
366 let new_numel: usize = new_shape.iter().product();
367 if old_numel != new_numel {
368 return Err(format!(
369 "cannot reshape: old numel={old_numel}, new numel={new_numel}"
370 ));
371 }
372 self.shape = new_shape;
373 Ok(())
374 }
375}
376impl NpyArray {
377 #[allow(dead_code)]
385 pub fn save_structured(
386 fields: &[(&str, &str)],
387 n_records: usize,
388 data_bytes: &[u8],
389 ) -> std::result::Result<Vec<u8>, String> {
390 if fields.is_empty() {
391 return Err("save_structured: field list is empty".into());
392 }
393 let dtype_parts: Vec<String> = fields
394 .iter()
395 .map(|(name, dt)| format!("('{}', '{}')", name, dt))
396 .collect();
397 let dtype_str = format!("[{}]", dtype_parts.join(", "));
398 let header_dict = format!(
399 "{{'descr': {}, 'fortran_order': False, 'shape': ({},), }}",
400 dtype_str, n_records
401 );
402 let raw_len = header_dict.len() + 1;
403 let pad_to = raw_len.div_ceil(64) * 64;
404 let padding = pad_to - raw_len;
405 let mut header_bytes = header_dict.into_bytes();
406 header_bytes.extend(std::iter::repeat_n(b' ', padding));
407 header_bytes.push(b'\n');
408 let header_len = header_bytes.len() as u16;
409 let mut out = Vec::new();
410 out.extend_from_slice(NPY_MAGIC);
411 out.push(NPY_MAJOR);
412 out.push(NPY_MINOR);
413 out.extend_from_slice(&header_len.to_le_bytes());
414 out.extend_from_slice(&header_bytes);
415 out.extend_from_slice(data_bytes);
416 Ok(out)
417 }
418}
419#[allow(dead_code)]
424#[derive(Debug, Clone, Default)]
425pub struct NpzArchive {
426 pub arrays: Vec<(String, NpyArray)>,
428}
429#[allow(dead_code)]
430impl NpzArchive {
431 pub fn new() -> Self {
433 Self::default()
434 }
435 pub fn insert(&mut self, name: &str, array: NpyArray) {
437 self.arrays.push((name.to_string(), array));
438 }
439 pub fn get(&self, name: &str) -> Option<&NpyArray> {
441 self.arrays.iter().find(|(n, _)| n == name).map(|(_, a)| a)
442 }
443 pub fn names(&self) -> Vec<&str> {
445 self.arrays.iter().map(|(n, _)| n.as_str()).collect()
446 }
447 pub fn remove(&mut self, name: &str) -> bool {
449 let before = self.arrays.len();
450 self.arrays.retain(|(n, _)| n != name);
451 self.arrays.len() < before
452 }
453 pub fn len(&self) -> usize {
455 self.arrays.len()
456 }
457 pub fn is_empty(&self) -> bool {
459 self.arrays.is_empty()
460 }
461 pub fn to_bytes(&self) -> std::result::Result<Vec<u8>, String> {
463 let mut writer = NpzWriter::new();
464 for (name, array) in &self.arrays {
465 match array.dtype {
466 NpyDtype::Float64 => {
467 writer.add_array_f64(name, &array.shape, &array.data_f64);
468 }
469 NpyDtype::Float32 => {
470 writer.add_array_f32(name, &array.shape, &array.data_f32);
471 }
472 NpyDtype::Int32 => {
473 writer.add_array_i32(name, &array.shape, &array.data_i32);
474 }
475 _ => {
476 return Err(format!(
477 "NpzArchive::to_bytes: unsupported dtype {:?}",
478 array.dtype
479 ));
480 }
481 }
482 }
483 Ok(writer.to_bytes())
484 }
485 pub fn from_bytes(data: &[u8]) -> std::result::Result<Self, String> {
487 let writer = NpzWriter::from_bytes(data)?;
488 let mut archive = NpzArchive::new();
489 for (name, npy_bytes) in &writer.files {
490 let dtype = detect_npy_dtype(npy_bytes)?;
491 let array = match dtype {
492 NpyDtype::Float64 => {
493 let (shape, data_f64) = read_npy_f64(npy_bytes)?;
494 NpyArray::from_f64(shape, data_f64)
495 }
496 NpyDtype::Float32 => {
497 let (shape, data_f32) = read_npy_f32(npy_bytes)?;
498 NpyArray::from_f32(shape, data_f32)
499 }
500 NpyDtype::Int32 => {
501 let (shape, data_i32) = read_npy_i32(npy_bytes)?;
502 NpyArray::from_i32(shape, data_i32)
503 }
504 other => {
505 return Err(format!(
506 "NpzArchive::from_bytes: unsupported dtype {:?} in '{name}'",
507 other
508 ));
509 }
510 };
511 archive.insert(name, array);
512 }
513 Ok(archive)
514 }
515}
516impl NpzArchive {
517 #[allow(dead_code)]
520 pub fn add_array(&mut self, name: &str, array: NpyArray) {
521 self.arrays.retain(|(n, _)| n.as_str() != name);
522 self.arrays.push((name.to_string(), array));
523 }
524 #[allow(dead_code)]
529 pub fn load_all(data: &[u8]) -> std::result::Result<Self, String> {
530 Self::from_bytes(data)
531 }
532 #[allow(dead_code)]
534 pub fn iter(&self) -> impl Iterator<Item = (&str, &NpyArray)> {
535 self.arrays.iter().map(|(n, a)| (n.as_str(), a))
536 }
537 #[allow(dead_code)]
539 pub fn merge(&mut self, other: NpzArchive) {
540 for (name, array) in other.arrays {
541 self.add_array(&name, array);
542 }
543 }
544 #[allow(dead_code)]
546 pub fn total_elements(&self) -> usize {
547 self.arrays.iter().map(|(_, a)| a.numel()).sum()
548 }
549}
550#[derive(Debug, Clone)]
560pub struct NpzWriter {
561 pub files: Vec<(String, Vec<u8>)>,
563}
564impl NpzWriter {
565 pub fn new() -> Self {
567 NpzWriter { files: Vec::new() }
568 }
569 pub fn add_array_f64(&mut self, name: &str, shape: &[usize], data: &[f64]) {
571 let npy = write_npy_f64(shape, data);
572 self.files.push((name.to_string(), npy));
573 }
574 pub fn add_array_f32(&mut self, name: &str, shape: &[usize], data: &[f32]) {
576 let npy = write_npy_f32(shape, data);
577 self.files.push((name.to_string(), npy));
578 }
579 pub fn add_array_i32(&mut self, name: &str, shape: &[usize], data: &[i32]) {
581 let npy = write_npy_i32(shape, data);
582 self.files.push((name.to_string(), npy));
583 }
584 pub fn add_array_i64(&mut self, name: &str, shape: &[usize], data: &[i64]) {
586 let npy = write_npy_i64(shape, data);
587 self.files.push((name.to_string(), npy));
588 }
589 pub fn len(&self) -> usize {
591 self.files.len()
592 }
593 pub fn is_empty(&self) -> bool {
595 self.files.is_empty()
596 }
597 pub fn names(&self) -> Vec<&str> {
599 self.files.iter().map(|(n, _)| n.as_str()).collect()
600 }
601 pub fn contains(&self, name: &str) -> bool {
603 self.files.iter().any(|(n, _)| n == name)
604 }
605 pub fn remove(&mut self, name: &str) -> bool {
607 let before = self.files.len();
608 self.files.retain(|(n, _)| n != name);
609 self.files.len() < before
610 }
611 pub fn to_bytes(&self) -> Vec<u8> {
613 let mut out: Vec<u8> = Vec::new();
614 out.extend_from_slice(&(self.files.len() as u32).to_le_bytes());
615 for (name, npy) in &self.files {
616 let name_bytes = name.as_bytes();
617 out.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
618 out.extend_from_slice(name_bytes);
619 out.extend_from_slice(&(npy.len() as u32).to_le_bytes());
620 out.extend_from_slice(npy);
621 }
622 out
623 }
624 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
626 let mut pos = 0usize;
627 let count = read_u32(data, &mut pos)? as usize;
628 let mut files = Vec::with_capacity(count);
629 for _ in 0..count {
630 let name_len = read_u32(data, &mut pos)? as usize;
631 if pos + name_len > data.len() {
632 return Err("name out of bounds".to_string());
633 }
634 let name = std::str::from_utf8(&data[pos..pos + name_len])
635 .map_err(|e| format!("invalid UTF-8 in name: {e}"))?
636 .to_string();
637 pos += name_len;
638 let npy_len = read_u32(data, &mut pos)? as usize;
639 if pos + npy_len > data.len() {
640 return Err("npy payload out of bounds".to_string());
641 }
642 let npy = data[pos..pos + npy_len].to_vec();
643 pos += npy_len;
644 files.push((name, npy));
645 }
646 Ok(NpzWriter { files })
647 }
648 pub fn get_f64(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<f64>), String>> {
650 self.files
651 .iter()
652 .find(|(n, _)| n == name)
653 .map(|(_, npy)| read_npy_f64(npy))
654 }
655 pub fn get_f32(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<f32>), String>> {
657 self.files
658 .iter()
659 .find(|(n, _)| n == name)
660 .map(|(_, npy)| read_npy_f32(npy))
661 }
662 pub fn get_i32(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<i32>), String>> {
664 self.files
665 .iter()
666 .find(|(n, _)| n == name)
667 .map(|(_, npy)| read_npy_i32(npy))
668 }
669 pub fn get_i64(&self, name: &str) -> Option<Result<(Vec<usize>, Vec<i64>), String>> {
671 self.files
672 .iter()
673 .find(|(n, _)| n == name)
674 .map(|(_, npy)| read_npy_i64(npy))
675 }
676}
677#[allow(dead_code)]
679#[derive(Debug, Clone)]
680pub struct NpyField {
681 pub name: String,
683 pub dtype: NpyDtype,
685 pub count: usize,
687}
688#[allow(dead_code)]
689impl NpyField {
690 pub fn scalar(name: &str, dtype: NpyDtype) -> Self {
692 Self {
693 name: name.to_string(),
694 dtype,
695 count: 1,
696 }
697 }
698 pub fn vector(name: &str, dtype: NpyDtype, count: usize) -> Self {
700 Self {
701 name: name.to_string(),
702 dtype,
703 count,
704 }
705 }
706 pub fn byte_size(&self) -> usize {
708 self.dtype.element_size() * self.count
709 }
710}