nabla_ml/
nab_io.rs

1use std::fs::File;
2use std::io::{self, BufRead, Read, Write};
3use flate2::{Compression, write::GzEncoder, read::GzDecoder};
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6
7use crate::nab_array::NDArray;
8
9#[derive(Serialize, Deserialize)]
10struct SerializableNDArray {
11    data: Vec<f64>,
12    shape: Vec<usize>,
13}
14
15/// Saves an NDArray to a .nab file with compression
16pub fn save_nab(filename: &str, array: &NDArray) -> io::Result<()> {
17    let file = File::create(filename)?;
18    let mut encoder = GzEncoder::new(file, Compression::default());
19    let serializable_array = SerializableNDArray {
20        data: array.data().to_vec(),
21        shape: array.shape().to_vec(),
22    };
23    let serialized_data = bincode::serialize(&serializable_array).unwrap();
24    encoder.write_all(&serialized_data)?;
25    encoder.finish()?;
26    Ok(())
27}
28
29/// Loads an NDArray from a compressed .nab file
30pub fn load_nab(filename: &str) -> io::Result<NDArray> {
31    let file = File::open(filename)?;
32    let mut decoder = GzDecoder::new(file);
33    let mut serialized_data = Vec::new();
34    decoder.read_to_end(&mut serialized_data)?;
35    let serializable_array: SerializableNDArray = bincode::deserialize(&serialized_data).unwrap();
36    Ok(NDArray::new(serializable_array.data, serializable_array.shape))
37}
38
39/// Saves multiple NDArrays to a .nab file
40///
41/// # Arguments
42///
43/// * `filename` - The name of the file to save the arrays to.
44/// * `arrays` - A vector of tuples containing the name and NDArray to save.
45#[allow(dead_code)]
46pub fn savez_nab(filename: &str, arrays: Vec<(&str, &NDArray)>) -> io::Result<()> {
47    let mut file = File::create(filename)?;
48    for (name, array) in arrays {
49        let shape_str = array.shape().iter().map(|s| s.to_string()).collect::<Vec<_>>().join(",");
50        let data_str = array.data().iter().map(|d| d.to_string()).collect::<Vec<_>>().join(",");
51        writeln!(file, "{}:{};{}", name, shape_str, data_str)?;
52    }
53    Ok(())
54}
55
56#[allow(dead_code)]
57pub fn loadz_nab(filename: &str) -> io::Result<HashMap<String, NDArray>> {
58    let file = File::open(filename)?;
59    let mut arrays = HashMap::new();
60    
61    // Read the file line by line
62    for line in io::BufReader::new(file).lines() {
63        let line = line?;
64        // Split the line into name, shape, and data parts
65        let parts: Vec<&str> = line.split(':').collect();
66        if parts.len() != 2 {
67            continue;
68        }
69        
70        let name = parts[0].to_string();
71        let shape_and_data: Vec<&str> = parts[1].split(';').collect();
72        if shape_and_data.len() != 2 {
73            continue;
74        }
75        
76        // Parse shape
77        let shape: Vec<usize> = shape_and_data[0]
78            .split(',')
79            .filter_map(|s| s.parse().ok())
80            .collect();
81            
82        // Parse data
83        let data: Vec<f64> = shape_and_data[1]
84            .split(',')
85            .filter_map(|s| s.parse().ok())
86            .collect();
87            
88        arrays.insert(name, NDArray::new(data, shape));
89    }
90    
91    Ok(arrays)
92}