// ghostflow_data/datasets.rs

//! Common Dataset Loaders
//!
//! Pre-built loaders for popular ML datasets (MNIST, CIFAR-10), plus a
//! generic [`Dataset`] trait and an in-memory implementation of it.
4
5use ghostflow_core::tensor::Tensor;
6use std::path::Path;
7use std::fs::File;
8use std::io::{Read, BufReader};
9
/// MNIST dataset loader.
///
/// Holds the MNIST handwritten-digit dataset in memory. Each image is stored
/// as a flattened `Vec<f32>` of pixel intensities scaled to `[0, 1]`; each
/// label is the raw byte read from the corresponding label file.
#[derive(Debug, Clone)]
pub struct MNIST {
    /// Training images, one flattened image per entry.
    train_images: Vec<Vec<f32>>,
    /// Training labels, index-aligned with `train_images`.
    train_labels: Vec<u8>,
    /// Test images, one flattened image per entry.
    test_images: Vec<Vec<f32>>,
    /// Test labels, index-aligned with `test_images`.
    test_labels: Vec<u8>,
    /// Image dimensions as a `(rows, cols)` pair.
    image_size: (usize, usize),
}
20
21impl MNIST {
22    /// Load MNIST from directory containing the 4 files
23    pub fn load<P: AsRef<Path>>(data_dir: P) -> std::io::Result<Self> {
24        let data_dir = data_dir.as_ref();
25
26        let train_images = Self::load_images(
27            data_dir.join("train-images-idx3-ubyte"),
28        )?;
29        let train_labels = Self::load_labels(
30            data_dir.join("train-labels-idx1-ubyte"),
31        )?;
32        let test_images = Self::load_images(
33            data_dir.join("t10k-images-idx3-ubyte"),
34        )?;
35        let test_labels = Self::load_labels(
36            data_dir.join("t10k-labels-idx1-ubyte"),
37        )?;
38
39        Ok(Self {
40            train_images,
41            train_labels,
42            test_images,
43            test_labels,
44            image_size: (28, 28),
45        })
46    }
47
48    fn load_images<P: AsRef<Path>>(path: P) -> std::io::Result<Vec<Vec<f32>>> {
49        let file = File::open(path)?;
50        let mut reader = BufReader::new(file);
51
52        // Read magic number
53        let mut magic = [0u8; 4];
54        reader.read_exact(&mut magic)?;
55        let magic_num = u32::from_be_bytes(magic);
56        
57        if magic_num != 2051 {
58            return Err(std::io::Error::new(
59                std::io::ErrorKind::InvalidData,
60                "Invalid MNIST image file",
61            ));
62        }
63
64        // Read dimensions
65        let mut dims = [0u8; 12];
66        reader.read_exact(&mut dims)?;
67        let num_images = u32::from_be_bytes([dims[0], dims[1], dims[2], dims[3]]) as usize;
68        let rows = u32::from_be_bytes([dims[4], dims[5], dims[6], dims[7]]) as usize;
69        let cols = u32::from_be_bytes([dims[8], dims[9], dims[10], dims[11]]) as usize;
70
71        // Read images
72        let mut images = Vec::with_capacity(num_images);
73        let image_size = rows * cols;
74
75        for _ in 0..num_images {
76            let mut image_bytes = vec![0u8; image_size];
77            reader.read_exact(&mut image_bytes)?;
78            
79            // Normalize to [0, 1]
80            let image: Vec<f32> = image_bytes
81                .iter()
82                .map(|&b| b as f32 / 255.0)
83                .collect();
84            
85            images.push(image);
86        }
87
88        Ok(images)
89    }
90
91    fn load_labels<P: AsRef<Path>>(path: P) -> std::io::Result<Vec<u8>> {
92        let file = File::open(path)?;
93        let mut reader = BufReader::new(file);
94
95        // Read magic number
96        let mut magic = [0u8; 4];
97        reader.read_exact(&mut magic)?;
98        let magic_num = u32::from_be_bytes(magic);
99        
100        if magic_num != 2049 {
101            return Err(std::io::Error::new(
102                std::io::ErrorKind::InvalidData,
103                "Invalid MNIST label file",
104            ));
105        }
106
107        // Read number of labels
108        let mut num_bytes = [0u8; 4];
109        reader.read_exact(&mut num_bytes)?;
110        let num_labels = u32::from_be_bytes(num_bytes) as usize;
111
112        // Read labels
113        let mut labels = vec![0u8; num_labels];
114        reader.read_exact(&mut labels)?;
115
116        Ok(labels)
117    }
118
119    /// Get training data
120    pub fn train_data(&self) -> (&[Vec<f32>], &[u8]) {
121        (&self.train_images, &self.train_labels)
122    }
123
124    /// Get test data
125    pub fn test_data(&self) -> (&[Vec<f32>], &[u8]) {
126        (&self.test_images, &self.test_labels)
127    }
128
129    /// Get a batch of training data as tensors
130    pub fn train_batch(&self, start: usize, batch_size: usize) -> (Tensor, Tensor) {
131        let end = (start + batch_size).min(self.train_images.len());
132        let batch_images: Vec<f32> = self.train_images[start..end]
133            .iter()
134            .flat_map(|img| img.iter().copied())
135            .collect();
136        
137        let batch_labels: Vec<f32> = self.train_labels[start..end]
138            .iter()
139            .map(|&label| label as f32)
140            .collect();
141
142        let images_tensor = Tensor::from_slice(
143            &batch_images,
144            &[end - start, 1, 28, 28],
145        ).unwrap();
146
147        let labels_tensor = Tensor::from_slice(
148            &batch_labels,
149            &[end - start],
150        ).unwrap();
151
152        (images_tensor, labels_tensor)
153    }
154
155    /// Get number of training samples
156    pub fn train_size(&self) -> usize {
157        self.train_images.len()
158    }
159
160    /// Get number of test samples
161    pub fn test_size(&self) -> usize {
162        self.test_images.len()
163    }
164}
165
/// CIFAR-10 dataset loader.
///
/// Holds the binary-format CIFAR-10 dataset in memory. Each image is a flat
/// `Vec<f32>` of 3072 values (3 x 32 x 32) scaled to `[0, 1]`; each label is
/// the raw byte that precedes that image in its batch file.
#[derive(Debug, Clone)]
pub struct CIFAR10 {
    /// Training images from all five training batches, in file order.
    train_images: Vec<Vec<f32>>,
    /// Training labels, index-aligned with `train_images`.
    train_labels: Vec<u8>,
    /// Test images from the test batch.
    test_images: Vec<Vec<f32>>,
    /// Test labels, index-aligned with `test_images`.
    test_labels: Vec<u8>,
}
173
174impl CIFAR10 {
175    /// Load CIFAR-10 from directory
176    pub fn load<P: AsRef<Path>>(data_dir: P) -> std::io::Result<Self> {
177        let data_dir = data_dir.as_ref();
178
179        let mut train_images = Vec::new();
180        let mut train_labels = Vec::new();
181
182        // Load training batches
183        for i in 1..=5 {
184            let batch_file = data_dir.join(format!("data_batch_{}.bin", i));
185            let (images, labels) = Self::load_batch(&batch_file)?;
186            train_images.extend(images);
187            train_labels.extend(labels);
188        }
189
190        // Load test batch
191        let test_file = data_dir.join("test_batch.bin");
192        let (test_images, test_labels) = Self::load_batch(&test_file)?;
193
194        Ok(Self {
195            train_images,
196            train_labels,
197            test_images,
198            test_labels,
199        })
200    }
201
202    fn load_batch<P: AsRef<Path>>(path: P) -> std::io::Result<(Vec<Vec<f32>>, Vec<u8>)> {
203        let file = File::open(path)?;
204        let mut reader = BufReader::new(file);
205
206        let mut images = Vec::new();
207        let mut labels = Vec::new();
208
209        // Each record is 3073 bytes: 1 label + 3072 pixels (32x32x3)
210        loop {
211            let mut label = [0u8; 1];
212            match reader.read_exact(&mut label) {
213                Ok(_) => {},
214                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
215                Err(e) => return Err(e),
216            }
217
218            let mut image_bytes = [0u8; 3072];
219            reader.read_exact(&mut image_bytes)?;
220
221            // Normalize to [0, 1]
222            let image: Vec<f32> = image_bytes
223                .iter()
224                .map(|&b| b as f32 / 255.0)
225                .collect();
226
227            images.push(image);
228            labels.push(label[0]);
229        }
230
231        Ok((images, labels))
232    }
233
234    /// Get training data
235    pub fn train_data(&self) -> (&[Vec<f32>], &[u8]) {
236        (&self.train_images, &self.train_labels)
237    }
238
239    /// Get test data
240    pub fn test_data(&self) -> (&[Vec<f32>], &[u8]) {
241        (&self.test_images, &self.test_labels)
242    }
243
244    /// Get a batch of training data as tensors
245    pub fn train_batch(&self, start: usize, batch_size: usize) -> (Tensor, Tensor) {
246        let end = (start + batch_size).min(self.train_images.len());
247        let batch_images: Vec<f32> = self.train_images[start..end]
248            .iter()
249            .flat_map(|img| img.iter().copied())
250            .collect();
251        
252        let batch_labels: Vec<f32> = self.train_labels[start..end]
253            .iter()
254            .map(|&label| label as f32)
255            .collect();
256
257        let images_tensor = Tensor::from_slice(
258            &batch_images,
259            &[end - start, 3, 32, 32],
260        ).unwrap();
261
262        let labels_tensor = Tensor::from_slice(
263            &batch_labels,
264            &[end - start],
265        ).unwrap();
266
267        (images_tensor, labels_tensor)
268    }
269
270    /// Get number of training samples
271    pub fn train_size(&self) -> usize {
272        self.train_images.len()
273    }
274
275    /// Get number of test samples
276    pub fn test_size(&self) -> usize {
277        self.test_images.len()
278    }
279}
280
281/// Generic dataset trait
282pub trait Dataset {
283    /// Get a single sample
284    fn get(&self, index: usize) -> (Tensor, Tensor);
285    
286    /// Get dataset size
287    fn len(&self) -> usize;
288    
289    /// Check if dataset is empty
290    fn is_empty(&self) -> bool {
291        self.len() == 0
292    }
293}
294
295/// In-memory dataset
296pub struct InMemoryDataset {
297    data: Vec<(Tensor, Tensor)>,
298}
299
300impl InMemoryDataset {
301    pub fn new(data: Vec<(Tensor, Tensor)>) -> Self {
302        Self { data }
303    }
304}
305
306impl Dataset for InMemoryDataset {
307    fn get(&self, index: usize) -> (Tensor, Tensor) {
308        self.data[index].clone()
309    }
310
311    fn len(&self) -> usize {
312        self.data.len()
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a two-sample dataset: length-2 features with one-element labels.
    fn two_sample_dataset() -> InMemoryDataset {
        InMemoryDataset::new(vec![
            (
                Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap(),
                Tensor::from_slice(&[0.0f32], &[1]).unwrap(),
            ),
            (
                Tensor::from_slice(&[3.0f32, 4.0], &[2]).unwrap(),
                Tensor::from_slice(&[1.0f32], &[1]).unwrap(),
            ),
        ])
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = two_sample_dataset();

        assert_eq!(dataset.len(), 2);
        assert!(!dataset.is_empty());

        // Every sample, not just the first, keeps its shapes.
        for index in 0..dataset.len() {
            let (x, y) = dataset.get(index);
            assert_eq!(x.shape().dims(), &[2]);
            assert_eq!(y.shape().dims(), &[1]);
        }
    }

    #[test]
    fn test_empty_in_memory_dataset() {
        let dataset = InMemoryDataset::new(Vec::new());
        assert_eq!(dataset.len(), 0);
        assert!(dataset.is_empty());
    }

    #[test]
    #[should_panic]
    fn test_get_out_of_bounds_panics() {
        let _ = two_sample_dataset().get(2);
    }
}