nabla_ml/
nab_mnist.rs

1use crate::nab_array::NDArray;
2use crate::nab_io::save_nab;
3
/// Namespace type for MNIST dataset helpers.
///
/// The type carries no data; its functionality is exposed through associated
/// functions such as `NabMnist::mnist_csv_to_nab`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NabMnist;
5
6
7impl NabMnist {
8        /// Converts MNIST CSV data to image and label .nab files
9    /// 
10    /// # Arguments
11    /// 
12    /// * `csv_path` - Path to the CSV file
13    /// * `images_path` - Path where to save the images .nab file
14    /// * `labels_path` - Path where to save the labels .nab file
15    /// * `image_shape` - Shape of a single image (e.g., [28, 28])
16    #[allow(dead_code)]
17    pub fn mnist_csv_to_nab(
18        csv_path: &str,
19        images_path: &str,
20        labels_path: &str,
21        image_shape: Vec<usize>
22    ) -> std::io::Result<()> {
23        let mut rdr = csv::Reader::from_path(csv_path)?;
24        let mut images = Vec::new();
25        let mut labels = Vec::new();
26        let mut sample_count = 0;
27
28        for result in rdr.records() {
29            let record = result?;
30            sample_count += 1;
31
32            if let Some(label) = record.get(0) {
33                labels.push(label.parse::<f64>().map_err(|e| {
34                    std::io::Error::new(std::io::ErrorKind::InvalidData, e)
35                })?);
36            }
37
38            for value in record.iter().skip(1) {
39                let pixel: f64 = value.parse().map_err(|e| {
40                    std::io::Error::new(std::io::ErrorKind::InvalidData, e)
41                })?;
42                images.push(pixel);
43            }
44        }
45
46        let mut full_image_shape = vec![sample_count];
47        full_image_shape.extend(image_shape);
48        let images_array = NDArray::new(images, full_image_shape);
49        save_nab(images_path, &images_array)?;
50
51        let labels_array = NDArray::new(labels, vec![sample_count]);
52        save_nab(labels_path, &labels_array)?;
53
54        Ok(())
55    }
56}
57
#[cfg(test)]
mod tests {
    //! Integration-style tests that convert `csv/mnist_test.csv` into .nab files
    //! under `datasets/` and load them back.
    //!
    //! The generated .nab files are intentionally left in place after the tests run.

    use super::*;
    use std::io;
    use std::sync::{Mutex, MutexGuard};
    use crate::nab_io::load_nab;
    use crate::nab_utils::NabUtils;

    const MNIST_CSV: &str = "csv/mnist_test.csv";
    const DATASET_PREFIX: &str = "datasets/mnist_test";
    const IMAGES_NAB: &str = "datasets/mnist_test_images.nab";
    const LABELS_NAB: &str = "datasets/mnist_test_labels.nab";

    /// Serializes the tests that rewrite and read the shared dataset files.
    /// The test harness runs tests in parallel, so without this one test could
    /// read a file while another is in the middle of rewriting it.
    static DATASET_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the dataset lock. A poisoned lock (another test panicked while
    /// holding it) is still usable, because every test regenerates the files.
    fn lock_dataset() -> MutexGuard<'static, ()> {
        DATASET_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Creates the `datasets` directory and (re)generates the image/label .nab
    /// files from the CSV, so no test depends on another one having run first.
    fn write_mnist_dataset() -> io::Result<()> {
        std::fs::create_dir_all("datasets")?;
        NabMnist::mnist_csv_to_nab(MNIST_CSV, IMAGES_NAB, LABELS_NAB, vec![28, 28])
    }

    #[test]
    fn test_mnist_load_and_split_dataset() -> std::io::Result<()> {
        let _guard = lock_dataset();
        write_mnist_dataset()?;

        let ((train_images, train_labels), (test_images, test_labels)) =
            NabUtils::load_and_split_dataset(DATASET_PREFIX, 80.0)?;

        // The split must not lose or duplicate samples.
        assert_eq!(train_images.shape()[0] + test_images.shape()[0], 999);
        assert_eq!(train_labels.shape()[0] + test_labels.shape()[0], 999);

        Ok(())
    }

    #[test]
    fn test_mnist_csv_to_nab_conversion() -> io::Result<()> {
        // Define paths for the CSV and .nab files
        let csv_path = MNIST_CSV;
        let nab_path = DATASET_PREFIX;
        let expected_shape = vec![999, 28, 28];

        println!("Starting test with CSV: {}", csv_path);

        // The output lives under `datasets/`, so make sure the directory exists
        // even when this test is the first one to run.
        std::fs::create_dir_all("datasets")?;

        // Convert CSV to .nab, skipping the first column
        NabUtils::csv_to_nab(csv_path, nab_path, expected_shape.clone(), true)?;

        // Load the .nab file
        let images = load_nab(nab_path)?;
        println!("Loaded NAB file with shape: {:?}", images.shape());

        // Verify the shape of the data
        assert_eq!(images.shape(), &expected_shape,
            "Shape mismatch: expected {:?}, got {:?}", expected_shape, images.shape());

        Ok(())
    }

    #[test]
    fn test_extract_and_print_sample() -> io::Result<()> {
        let _guard = lock_dataset();
        write_mnist_dataset()?;

        // Load the dataset
        let ((train_images, train_labels), _) =
            NabUtils::load_and_split_dataset(DATASET_PREFIX, 80.0)?;

        // Extract and print the entry at index 42
        println!("Label of 42nd entry: {}", train_labels.get(42));
        println!("Image of 42nd entry:");
        let image_42: NDArray = train_images.extract_sample(42);
        image_42.pretty_print(0);

        Ok(())
    }

    #[test]
    fn test_mnist_normalize() -> std::io::Result<()> {
        let _guard = lock_dataset();
        // Generate the dataset here instead of relying on files left behind by
        // another test.
        write_mnist_dataset()?;

        let ((mut train_images, _), _) =
            NabUtils::load_and_split_dataset(DATASET_PREFIX, 80.0)?;

        NabUtils::normalize_with_range(&mut train_images, 0.0, 255.0);

        // Print a few values of one sample so the normalized data can be inspected.
        let gray_image_42 = train_images.extract_sample(42);
        println!("First few raw values: {:?}", &gray_image_42.data()[..5]);
        // NOTE(review): the argument is presumably the print precision — confirm.
        gray_image_42.pretty_print(4);

        Ok(())
    }
}
161