1use ghostflow_core::tensor::Tensor;
6use std::path::Path;
7use std::fs::File;
8use std::io::{Read, BufReader};
9
/// In-memory copy of an MNIST-style dataset: a training split and a test split.
///
/// Every image is stored as one flat `Vec<f32>` of pixel values and every label
/// as a raw `u8`. Images and labels of a split are paired by index.
///
/// `Debug` and `Clone` are derived so the dataset can be logged and duplicated
/// like any other value type; note that both walk every stored pixel.
#[derive(Debug, Clone)]
pub struct MNIST {
    /// Training images, one flat pixel vector per image.
    train_images: Vec<Vec<f32>>,
    /// Training labels; `train_labels[i]` belongs to `train_images[i]`.
    train_labels: Vec<u8>,
    /// Test images, one flat pixel vector per image.
    test_images: Vec<Vec<f32>>,
    /// Test labels; `test_labels[i]` belongs to `test_images[i]`.
    test_labels: Vec<u8>,
    /// Image dimensions as a pair of `usize` values.
    image_size: (usize, usize),
}
20
21impl MNIST {
22 pub fn load<P: AsRef<Path>>(data_dir: P) -> std::io::Result<Self> {
24 let data_dir = data_dir.as_ref();
25
26 let train_images = Self::load_images(
27 data_dir.join("train-images-idx3-ubyte"),
28 )?;
29 let train_labels = Self::load_labels(
30 data_dir.join("train-labels-idx1-ubyte"),
31 )?;
32 let test_images = Self::load_images(
33 data_dir.join("t10k-images-idx3-ubyte"),
34 )?;
35 let test_labels = Self::load_labels(
36 data_dir.join("t10k-labels-idx1-ubyte"),
37 )?;
38
39 Ok(Self {
40 train_images,
41 train_labels,
42 test_images,
43 test_labels,
44 image_size: (28, 28),
45 })
46 }
47
48 fn load_images<P: AsRef<Path>>(path: P) -> std::io::Result<Vec<Vec<f32>>> {
49 let file = File::open(path)?;
50 let mut reader = BufReader::new(file);
51
52 let mut magic = [0u8; 4];
54 reader.read_exact(&mut magic)?;
55 let magic_num = u32::from_be_bytes(magic);
56
57 if magic_num != 2051 {
58 return Err(std::io::Error::new(
59 std::io::ErrorKind::InvalidData,
60 "Invalid MNIST image file",
61 ));
62 }
63
64 let mut dims = [0u8; 12];
66 reader.read_exact(&mut dims)?;
67 let num_images = u32::from_be_bytes([dims[0], dims[1], dims[2], dims[3]]) as usize;
68 let rows = u32::from_be_bytes([dims[4], dims[5], dims[6], dims[7]]) as usize;
69 let cols = u32::from_be_bytes([dims[8], dims[9], dims[10], dims[11]]) as usize;
70
71 let mut images = Vec::with_capacity(num_images);
73 let image_size = rows * cols;
74
75 for _ in 0..num_images {
76 let mut image_bytes = vec![0u8; image_size];
77 reader.read_exact(&mut image_bytes)?;
78
79 let image: Vec<f32> = image_bytes
81 .iter()
82 .map(|&b| b as f32 / 255.0)
83 .collect();
84
85 images.push(image);
86 }
87
88 Ok(images)
89 }
90
91 fn load_labels<P: AsRef<Path>>(path: P) -> std::io::Result<Vec<u8>> {
92 let file = File::open(path)?;
93 let mut reader = BufReader::new(file);
94
95 let mut magic = [0u8; 4];
97 reader.read_exact(&mut magic)?;
98 let magic_num = u32::from_be_bytes(magic);
99
100 if magic_num != 2049 {
101 return Err(std::io::Error::new(
102 std::io::ErrorKind::InvalidData,
103 "Invalid MNIST label file",
104 ));
105 }
106
107 let mut num_bytes = [0u8; 4];
109 reader.read_exact(&mut num_bytes)?;
110 let num_labels = u32::from_be_bytes(num_bytes) as usize;
111
112 let mut labels = vec![0u8; num_labels];
114 reader.read_exact(&mut labels)?;
115
116 Ok(labels)
117 }
118
119 pub fn train_data(&self) -> (&[Vec<f32>], &[u8]) {
121 (&self.train_images, &self.train_labels)
122 }
123
124 pub fn test_data(&self) -> (&[Vec<f32>], &[u8]) {
126 (&self.test_images, &self.test_labels)
127 }
128
129 pub fn train_batch(&self, start: usize, batch_size: usize) -> (Tensor, Tensor) {
131 let end = (start + batch_size).min(self.train_images.len());
132 let batch_images: Vec<f32> = self.train_images[start..end]
133 .iter()
134 .flat_map(|img| img.iter().copied())
135 .collect();
136
137 let batch_labels: Vec<f32> = self.train_labels[start..end]
138 .iter()
139 .map(|&label| label as f32)
140 .collect();
141
142 let images_tensor = Tensor::from_slice(
143 &batch_images,
144 &[end - start, 1, 28, 28],
145 ).unwrap();
146
147 let labels_tensor = Tensor::from_slice(
148 &batch_labels,
149 &[end - start],
150 ).unwrap();
151
152 (images_tensor, labels_tensor)
153 }
154
155 pub fn train_size(&self) -> usize {
157 self.train_images.len()
158 }
159
160 pub fn test_size(&self) -> usize {
162 self.test_images.len()
163 }
164}
165
/// In-memory copy of a CIFAR-10-style dataset: a training split and a test split.
///
/// Every image is stored as one flat `Vec<f32>` of pixel values and every label
/// as a raw `u8`. Images and labels of a split are paired by index.
///
/// `Debug` and `Clone` are derived so the dataset can be logged and duplicated
/// like any other value type; note that both walk every stored pixel.
#[derive(Debug, Clone)]
pub struct CIFAR10 {
    /// Training images, one flat pixel vector per image.
    train_images: Vec<Vec<f32>>,
    /// Training labels; `train_labels[i]` belongs to `train_images[i]`.
    train_labels: Vec<u8>,
    /// Test images, one flat pixel vector per image.
    test_images: Vec<Vec<f32>>,
    /// Test labels; `test_labels[i]` belongs to `test_images[i]`.
    test_labels: Vec<u8>,
}
173
174impl CIFAR10 {
175 pub fn load<P: AsRef<Path>>(data_dir: P) -> std::io::Result<Self> {
177 let data_dir = data_dir.as_ref();
178
179 let mut train_images = Vec::new();
180 let mut train_labels = Vec::new();
181
182 for i in 1..=5 {
184 let batch_file = data_dir.join(format!("data_batch_{}.bin", i));
185 let (images, labels) = Self::load_batch(&batch_file)?;
186 train_images.extend(images);
187 train_labels.extend(labels);
188 }
189
190 let test_file = data_dir.join("test_batch.bin");
192 let (test_images, test_labels) = Self::load_batch(&test_file)?;
193
194 Ok(Self {
195 train_images,
196 train_labels,
197 test_images,
198 test_labels,
199 })
200 }
201
202 fn load_batch<P: AsRef<Path>>(path: P) -> std::io::Result<(Vec<Vec<f32>>, Vec<u8>)> {
203 let file = File::open(path)?;
204 let mut reader = BufReader::new(file);
205
206 let mut images = Vec::new();
207 let mut labels = Vec::new();
208
209 loop {
211 let mut label = [0u8; 1];
212 match reader.read_exact(&mut label) {
213 Ok(_) => {},
214 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
215 Err(e) => return Err(e),
216 }
217
218 let mut image_bytes = [0u8; 3072];
219 reader.read_exact(&mut image_bytes)?;
220
221 let image: Vec<f32> = image_bytes
223 .iter()
224 .map(|&b| b as f32 / 255.0)
225 .collect();
226
227 images.push(image);
228 labels.push(label[0]);
229 }
230
231 Ok((images, labels))
232 }
233
234 pub fn train_data(&self) -> (&[Vec<f32>], &[u8]) {
236 (&self.train_images, &self.train_labels)
237 }
238
239 pub fn test_data(&self) -> (&[Vec<f32>], &[u8]) {
241 (&self.test_images, &self.test_labels)
242 }
243
244 pub fn train_batch(&self, start: usize, batch_size: usize) -> (Tensor, Tensor) {
246 let end = (start + batch_size).min(self.train_images.len());
247 let batch_images: Vec<f32> = self.train_images[start..end]
248 .iter()
249 .flat_map(|img| img.iter().copied())
250 .collect();
251
252 let batch_labels: Vec<f32> = self.train_labels[start..end]
253 .iter()
254 .map(|&label| label as f32)
255 .collect();
256
257 let images_tensor = Tensor::from_slice(
258 &batch_images,
259 &[end - start, 3, 32, 32],
260 ).unwrap();
261
262 let labels_tensor = Tensor::from_slice(
263 &batch_labels,
264 &[end - start],
265 ).unwrap();
266
267 (images_tensor, labels_tensor)
268 }
269
270 pub fn train_size(&self) -> usize {
272 self.train_images.len()
273 }
274
275 pub fn test_size(&self) -> usize {
277 self.test_images.len()
278 }
279}
280
281pub trait Dataset {
283 fn get(&self, index: usize) -> (Tensor, Tensor);
285
286 fn len(&self) -> usize;
288
289 fn is_empty(&self) -> bool {
291 self.len() == 0
292 }
293}
294
295pub struct InMemoryDataset {
297 data: Vec<(Tensor, Tensor)>,
298}
299
300impl InMemoryDataset {
301 pub fn new(data: Vec<(Tensor, Tensor)>) -> Self {
302 Self { data }
303 }
304}
305
306impl Dataset for InMemoryDataset {
307 fn get(&self, index: usize) -> (Tensor, Tensor) {
308 self.data[index].clone()
309 }
310
311 fn len(&self) -> usize {
312 self.data.len()
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one dataset entry from a two-element first tensor and a
    /// one-element second tensor.
    fn pair(values: [f32; 2], tag: f32) -> (Tensor, Tensor) {
        let first = Tensor::from_slice(&values, &[2]).unwrap();
        let second = Tensor::from_slice(&[tag], &[1]).unwrap();
        (first, second)
    }

    /// An `InMemoryDataset` reports its size and hands back tensors with the
    /// shapes they were created with.
    #[test]
    fn test_in_memory_dataset() {
        let entries = vec![pair([1.0, 2.0], 0.0), pair([3.0, 4.0], 1.0)];
        let dataset = InMemoryDataset::new(entries);

        assert_eq!(dataset.len(), 2);
        assert!(!dataset.is_empty());

        let (x, y) = dataset.get(0);
        assert_eq!(x.shape().dims(), &[2]);
        assert_eq!(y.shape().dims(), &[1]);
    }
}
342}