1use std::fs::File;
2use std::io::{self, BufRead, Read, Write};
3use flate2::{Compression, write::GzEncoder, read::GzDecoder};
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6
7use crate::nab_array::NDArray;
8
/// Serde-friendly mirror of an `NDArray`, used as the payload type that
/// `save_nab` encodes with bincode and `load_nab` decodes.
///
/// NOTE(review): bincode presumably encodes struct fields positionally, so
/// reordering or retyping these fields would break files written earlier —
/// confirm before changing the layout.
#[derive(Serialize, Deserialize)]
struct SerializableNDArray {
 /// Array elements, copied from `NDArray::data()` when saving.
 data: Vec<f64>,
 /// Array dimensions, copied from `NDArray::shape()` when saving.
 shape: Vec<usize>,
}
14
15pub fn save_nab(filename: &str, array: &NDArray) -> io::Result<()> {
17 let file = File::create(filename)?;
18 let mut encoder = GzEncoder::new(file, Compression::default());
19 let serializable_array = SerializableNDArray {
20 data: array.data().to_vec(),
21 shape: array.shape().to_vec(),
22 };
23 let serialized_data = bincode::serialize(&serializable_array).unwrap();
24 encoder.write_all(&serialized_data)?;
25 encoder.finish()?;
26 Ok(())
27}
28
29pub fn load_nab(filename: &str) -> io::Result<NDArray> {
31 let file = File::open(filename)?;
32 let mut decoder = GzDecoder::new(file);
33 let mut serialized_data = Vec::new();
34 decoder.read_to_end(&mut serialized_data)?;
35 let serializable_array: SerializableNDArray = bincode::deserialize(&serialized_data).unwrap();
36 Ok(NDArray::new(serializable_array.data, serializable_array.shape))
37}
38
39#[allow(dead_code)]
46pub fn savez_nab(filename: &str, arrays: Vec<(&str, &NDArray)>) -> io::Result<()> {
47 let mut file = File::create(filename)?;
48 for (name, array) in arrays {
49 let shape_str = array.shape().iter().map(|s| s.to_string()).collect::<Vec<_>>().join(",");
50 let data_str = array.data().iter().map(|d| d.to_string()).collect::<Vec<_>>().join(",");
51 writeln!(file, "{}:{};{}", name, shape_str, data_str)?;
52 }
53 Ok(())
54}
55
56#[allow(dead_code)]
57pub fn loadz_nab(filename: &str) -> io::Result<HashMap<String, NDArray>> {
58 let file = File::open(filename)?;
59 let mut arrays = HashMap::new();
60
61 for line in io::BufReader::new(file).lines() {
63 let line = line?;
64 let parts: Vec<&str> = line.split(':').collect();
66 if parts.len() != 2 {
67 continue;
68 }
69
70 let name = parts[0].to_string();
71 let shape_and_data: Vec<&str> = parts[1].split(';').collect();
72 if shape_and_data.len() != 2 {
73 continue;
74 }
75
76 let shape: Vec<usize> = shape_and_data[0]
78 .split(',')
79 .filter_map(|s| s.parse().ok())
80 .collect();
81
82 let data: Vec<f64> = shape_and_data[1]
84 .split(',')
85 .filter_map(|s| s.parse().ok())
86 .collect();
87
88 arrays.insert(name, NDArray::new(data, shape));
89 }
90
91 Ok(arrays)
92}