1use crate::nab_array::NDArray;
2use crate::nab_io::save_nab;
3
/// Namespace type for MNIST-specific helpers.
///
/// The struct carries no data; it only groups associated functions such as
/// `NabMnist::mnist_csv_to_nab`. The derives make the marker usable in
/// diagnostics and generic code at no cost, since the type is zero-sized.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NabMnist;
5
6
7impl NabMnist {
8 #[allow(dead_code)]
17 pub fn mnist_csv_to_nab(
18 csv_path: &str,
19 images_path: &str,
20 labels_path: &str,
21 image_shape: Vec<usize>
22 ) -> std::io::Result<()> {
23 let mut rdr = csv::Reader::from_path(csv_path)?;
24 let mut images = Vec::new();
25 let mut labels = Vec::new();
26 let mut sample_count = 0;
27
28 for result in rdr.records() {
29 let record = result?;
30 sample_count += 1;
31
32 if let Some(label) = record.get(0) {
33 labels.push(label.parse::<f64>().map_err(|e| {
34 std::io::Error::new(std::io::ErrorKind::InvalidData, e)
35 })?);
36 }
37
38 for value in record.iter().skip(1) {
39 let pixel: f64 = value.parse().map_err(|e| {
40 std::io::Error::new(std::io::ErrorKind::InvalidData, e)
41 })?;
42 images.push(pixel);
43 }
44 }
45
46 let mut full_image_shape = vec![sample_count];
47 full_image_shape.extend(image_shape);
48 let images_array = NDArray::new(images, full_image_shape);
49 save_nab(images_path, &images_array)?;
50
51 let labels_array = NDArray::new(labels, vec![sample_count]);
52 save_nab(labels_path, &labels_array)?;
53
54 Ok(())
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nab_io::load_nab;
    use crate::nab_utils::NabUtils;
    use std::io;
    use std::sync::Once;

    /// Guards the one-time conversion of the CSV fixture into `.nab` files.
    static DATASET_INIT: Once = Once::new();

    /// Makes sure `datasets/mnist_test_images.nab` and
    /// `datasets/mnist_test_labels.nab` exist before a test loads them.
    ///
    /// The conversion runs at most once per test process. Tests run in
    /// parallel by default, so rebuilding the files in every test could
    /// rewrite them while another test is reading them; `Once` also blocks
    /// concurrent callers until the first conversion has finished.
    ///
    /// Panics (failing the calling test) if the fixture cannot be built.
    fn ensure_dataset() {
        DATASET_INIT.call_once(|| {
            std::fs::create_dir_all("datasets")
                .expect("failed to create the `datasets` directory");

            NabMnist::mnist_csv_to_nab(
                "csv/mnist_test.csv",
                "datasets/mnist_test_images.nab",
                "datasets/mnist_test_labels.nab",
                vec![28, 28],
            )
            .expect("failed to convert csv/mnist_test.csv to .nab files");
        });
    }

    /// The train and test splits together must account for every sample.
    #[test]
    fn test_mnist_load_and_split_dataset() -> io::Result<()> {
        ensure_dataset();

        let ((train_images, train_labels), (test_images, test_labels)) =
            NabUtils::load_and_split_dataset("datasets/mnist_test", 80.0)?;

        assert_eq!(train_images.shape()[0] + test_images.shape()[0], 999);
        assert_eq!(train_labels.shape()[0] + test_labels.shape()[0], 999);

        Ok(())
    }

    /// Converting the CSV with `NabUtils::csv_to_nab` and loading the result
    /// must yield the expected array shape.
    #[test]
    fn test_mnist_csv_to_nab_conversion() -> io::Result<()> {
        let csv_path = "csv/mnist_test.csv";
        let nab_path = "datasets/mnist_test";
        let expected_shape = vec![999, 28, 28];

        // The output lives under `datasets/`; do not rely on another test
        // having created the directory first.
        std::fs::create_dir_all("datasets")?;

        println!("Starting test with CSV: {}", csv_path);

        NabUtils::csv_to_nab(csv_path, nab_path, expected_shape.clone(), true)?;

        let images = load_nab(nab_path)?;
        println!("Loaded NAB file with shape: {:?}", images.shape());

        assert_eq!(
            images.shape(),
            &expected_shape,
            "Shape mismatch: expected {:?}, got {:?}",
            expected_shape,
            images.shape()
        );

        Ok(())
    }

    /// Smoke test: extracts one training sample and prints its label and image.
    #[test]
    fn test_extract_and_print_sample() -> io::Result<()> {
        ensure_dataset();

        let ((train_images, train_labels), _) =
            NabUtils::load_and_split_dataset("datasets/mnist_test", 80.0)?;

        println!("Label of 42nd entry: {}", train_labels.get(42));
        println!("Image of 42nd entry:");
        let image_42: NDArray = train_images.extract_sample(42);
        image_42.pretty_print(0);

        Ok(())
    }

    /// Smoke test: normalizes the training images with range `0.0..255.0` and
    /// prints one sample.
    #[test]
    fn test_mnist_normalize() -> io::Result<()> {
        // Previously this test only passed when another test had already
        // produced the dataset files.
        ensure_dataset();

        let ((mut train_images, _), _) =
            NabUtils::load_and_split_dataset("datasets/mnist_test", 80.0)?;

        NabUtils::normalize_with_range(&mut train_images, 0.0, 255.0);

        let gray_image_42 = train_images.extract_sample(42);
        println!("First few raw values: {:?}", &gray_image_42.data()[..5]);
        gray_image_42.pretty_print(4);

        Ok(())
    }
}
161