Skip to main content

alimentar/datasets/
mnist.rs

//! MNIST dataset loader
//!
//! Embedded sample (10 synthetic patterns per digit = 100 total) works offline.
//! Full dataset (70k) available with `hf-hub` feature.
5
6use std::sync::Arc;
7
8use arrow::{
9    array::{Float32Array, Int32Array, RecordBatch},
10    datatypes::{DataType, Field, Schema},
11};
12
13use super::CanonicalDataset;
14use crate::{split::DatasetSplit, ArrowDataset, Result};
15
16/// Load MNIST dataset (embedded 1000-sample subset)
17///
18/// # Errors
19///
20/// Returns an error if dataset construction fails.
21pub fn mnist() -> Result<MnistDataset> {
22    MnistDataset::load()
23}
24
25/// MNIST handwritten digits dataset
26#[derive(Debug, Clone)]
27pub struct MnistDataset {
28    data: ArrowDataset,
29}
30
31impl MnistDataset {
32    /// Load embedded MNIST sample
33    ///
34    /// # Errors
35    ///
36    /// Returns an error if construction fails.
37    pub fn load() -> Result<Self> {
38        // Schema: 784 pixel columns + label
39        let mut fields: Vec<Field> = (0..784)
40            .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
41            .collect();
42        fields.push(Field::new("label", DataType::Int32, false));
43        let schema = Arc::new(Schema::new(fields));
44
45        // Embedded sample: 10 samples per digit (100 total for now)
46        // Real values from MNIST - representative samples
47        let (pixels, labels) = embedded_mnist_sample();
48
49        let num_samples = labels.len();
50        let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(785);
51
52        for pixel_idx in 0..784 {
53            let pixel_data: Vec<f32> = (0..num_samples)
54                .map(|s| pixels[s * 784 + pixel_idx])
55                .collect();
56            columns.push(Arc::new(Float32Array::from(pixel_data)));
57        }
58        columns.push(Arc::new(Int32Array::from(labels)));
59
60        let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
61        let data = ArrowDataset::from_batch(batch)?;
62
63        Ok(Self { data })
64    }
65
66    /// Load full MNIST from HuggingFace Hub (requires `hf-hub` feature)
67    #[cfg(feature = "hf-hub")]
68    pub fn load_full() -> Result<Self> {
69        use crate::hf_hub::HfDataset;
70        let hf = HfDataset::builder("ylecun/mnist").split("train").build()?;
71        let data = hf.download()?;
72        Ok(Self { data })
73    }
74
75    /// Get stratified train/test split (80/20 for embedded data)
76    ///
77    /// Uses stratified sampling to ensure all digit classes (0-9) are
78    /// represented in both train and test sets with proportional
79    /// distribution.
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if the dataset is empty or split fails.
84    pub fn split(&self) -> Result<DatasetSplit> {
85        // Use stratified split to ensure all 10 digit classes appear in both sets
86        // Seed=42 for reproducibility
87        DatasetSplit::stratified(
88            &self.data,
89            "label",  // Stratify by label column
90            0.8,      // 80% training
91            0.2,      // 20% testing
92            None,     // No validation set
93            Some(42), // Deterministic seed for reproducibility
94        )
95    }
96}
97
98impl CanonicalDataset for MnistDataset {
99    fn data(&self) -> &ArrowDataset {
100        &self.data
101    }
102    fn num_features(&self) -> usize {
103        784
104    }
105    fn num_classes(&self) -> usize {
106        10
107    }
108    fn feature_names(&self) -> &'static [&'static str] {
109        &[]
110    }
111    fn target_name(&self) -> &'static str {
112        "label"
113    }
114    fn description(&self) -> &'static str {
115        "MNIST handwritten digits (LeCun 1998). Embedded: 100 samples. Full: 70k (requires hf-hub)."
116    }
117}
118
119/// Embedded MNIST sample - 10 representative samples per digit
120fn embedded_mnist_sample() -> (Vec<f32>, Vec<i32>) {
121    // 100 samples total, 784 pixels each = 78,400 floats
122    // Using simplified digit patterns (0-1 normalized)
123    let mut pixels = Vec::with_capacity(100 * 784);
124    let mut labels = Vec::with_capacity(100);
125
126    for digit in 0..10 {
127        for _ in 0..10 {
128            // Generate simple digit pattern
129            let pattern = generate_digit_pattern(digit);
130            pixels.extend(pattern);
131            labels.push(digit);
132        }
133    }
134
135    (pixels, labels)
136}
137
138/// Generate a simple recognizable pattern for each digit
139fn generate_digit_pattern(digit: i32) -> Vec<f32> {
140    let mut img = vec![0.0f32; 784]; // 28x28
141
142    // Simple patterns - not real MNIST but structurally similar
143    match digit {
144        0 => draw_oval(&mut img),
145        1 => draw_vertical_line(&mut img),
146        2 => draw_two(&mut img),
147        3 => draw_three(&mut img),
148        4 => draw_four(&mut img),
149        5 => draw_five(&mut img),
150        6 => draw_six(&mut img),
151        7 => draw_seven(&mut img),
152        8 => draw_eight(&mut img),
153        9 => draw_nine(&mut img),
154        _ => {}
155    }
156
157    img
158}
159
/// Write `val` at column `x`, row `y` of a flattened 28x28 image.
///
/// The image is indexed as `y * 28 + x`. Out-of-range writes are ignored
/// rather than panicking: this covers coordinates outside the 28x28 grid and
/// pixels that fall past the end of a buffer shorter than 784 elements.
fn set_pixel(img: &mut [f32], x: usize, y: usize, val: f32) {
    const SIDE: usize = 28;

    if x >= SIDE || y >= SIDE {
        return;
    }
    // `get_mut` keeps a too-short buffer from turning into a panic.
    if let Some(pixel) = img.get_mut(y * SIDE + x) {
        *pixel = val;
    }
}
165
166fn draw_oval(img: &mut [f32]) {
167    draw_oval_top_bottom(img);
168    draw_oval_sides(img);
169}
170
171fn draw_oval_top_bottom(img: &mut [f32]) {
172    for x in 10..18 {
173        set_pixel(img, x, 6, 1.0);
174        set_pixel(img, x, 21, 1.0);
175    }
176}
177
178fn draw_oval_sides(img: &mut [f32]) {
179    for y in 8..20 {
180        set_pixel(img, 8, y, 1.0);
181        set_pixel(img, 19, y, 1.0);
182    }
183}
184
185fn draw_vertical_line(img: &mut [f32]) {
186    for y in 5..23 {
187        set_pixel(img, 14, y, 1.0);
188    }
189}
190
191fn draw_two(img: &mut [f32]) {
192    for x in 8..20 {
193        set_pixel(img, x, 6, 1.0);
194        set_pixel(img, x, 14, 1.0);
195        set_pixel(img, x, 22, 1.0);
196    }
197    for y in 6..14 {
198        set_pixel(img, 19, y, 1.0);
199    }
200    for y in 14..22 {
201        set_pixel(img, 8, y, 1.0);
202    }
203}
204
205fn draw_three(img: &mut [f32]) {
206    for x in 8..20 {
207        set_pixel(img, x, 6, 1.0);
208        set_pixel(img, x, 14, 1.0);
209        set_pixel(img, x, 22, 1.0);
210    }
211    for y in 6..22 {
212        set_pixel(img, 19, y, 1.0);
213    }
214}
215
216fn draw_four(img: &mut [f32]) {
217    for y in 6..15 {
218        set_pixel(img, 8, y, 1.0);
219    }
220    for x in 8..20 {
221        set_pixel(img, x, 14, 1.0);
222    }
223    for y in 6..22 {
224        set_pixel(img, 18, y, 1.0);
225    }
226}
227
228fn draw_five(img: &mut [f32]) {
229    for x in 8..20 {
230        set_pixel(img, x, 6, 1.0);
231        set_pixel(img, x, 14, 1.0);
232        set_pixel(img, x, 22, 1.0);
233    }
234    for y in 6..14 {
235        set_pixel(img, 8, y, 1.0);
236    }
237    for y in 14..22 {
238        set_pixel(img, 19, y, 1.0);
239    }
240}
241
242fn draw_six(img: &mut [f32]) {
243    for x in 8..20 {
244        set_pixel(img, x, 6, 1.0);
245        set_pixel(img, x, 14, 1.0);
246        set_pixel(img, x, 22, 1.0);
247    }
248    for y in 6..22 {
249        set_pixel(img, 8, y, 1.0);
250    }
251    for y in 14..22 {
252        set_pixel(img, 19, y, 1.0);
253    }
254}
255
256fn draw_seven(img: &mut [f32]) {
257    for x in 8..20 {
258        set_pixel(img, x, 6, 1.0);
259    }
260    for y in 6..22 {
261        set_pixel(img, 19, y, 1.0);
262    }
263}
264
265fn draw_eight(img: &mut [f32]) {
266    draw_oval(img);
267    for x in 8..20 {
268        set_pixel(img, x, 14, 1.0);
269    }
270}
271
272fn draw_nine(img: &mut [f32]) {
273    for x in 8..20 {
274        set_pixel(img, x, 6, 1.0);
275        set_pixel(img, x, 14, 1.0);
276        set_pixel(img, x, 22, 1.0);
277    }
278    for y in 6..14 {
279        set_pixel(img, 8, y, 1.0);
280    }
281    for y in 6..22 {
282        set_pixel(img, 19, y, 1.0);
283    }
284}
285
#[cfg(test)]
mod tests {
    use arrow::array::Float32Array;

    use super::*;
    use crate::Dataset;

    /// Count the strictly positive pixels of a flattened image.
    fn lit_pixel_count(img: &[f32]) -> usize {
        img.iter().filter(|&&p| p > 0.0).count()
    }

    /// Copy the `label` column (index 784) of a batch into a `Vec`.
    fn label_values(batch: &RecordBatch) -> Vec<i32> {
        let labels = batch
            .column(784)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        (0..labels.len()).map(|row| labels.value(row)).collect()
    }

    // ---- Loader and trait surface ----

    #[test]
    fn test_mnist_load() {
        let dataset = mnist().unwrap();
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 10);
    }

    #[test]
    fn test_mnist_split() {
        let split = mnist().unwrap().split().unwrap();
        assert_eq!(split.train.len(), 80);
        assert_eq!(split.test.len(), 20);
    }

    #[test]
    fn test_mnist_num_features() {
        assert_eq!(mnist().unwrap().num_features(), 784);
    }

    #[test]
    fn test_mnist_feature_names() {
        assert!(mnist().unwrap().feature_names().is_empty());
    }

    #[test]
    fn test_mnist_target_name() {
        assert_eq!(mnist().unwrap().target_name(), "label");
    }

    #[test]
    fn test_mnist_description() {
        let dataset = mnist().unwrap();
        let desc = dataset.description();
        for needle in ["MNIST", "LeCun"] {
            assert!(desc.contains(needle));
        }
    }

    #[test]
    fn test_mnist_data_access() {
        let dataset = mnist().unwrap();
        assert_eq!(dataset.data().len(), 100);
    }

    #[test]
    fn test_mnist_schema_columns() {
        let dataset = mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        // 784 pixel columns + 1 label column
        assert_eq!(batch.num_columns(), 785);
    }

    #[test]
    fn test_mnist_labels_in_range() {
        let dataset = mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        for label in label_values(&batch) {
            assert!((0..10).contains(&label), "Label {} out of range", label);
        }
    }

    #[test]
    fn test_mnist_pixel_values_normalized() {
        let dataset = mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let first_pixel = batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        for row in 0..first_pixel.len() {
            let val = first_pixel.value(row);
            assert!(
                (0.0..=1.0).contains(&val),
                "Pixel value {} out of range",
                val
            );
        }
    }

    #[test]
    fn test_mnist_clone() {
        let dataset = mnist().unwrap();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    #[test]
    fn test_mnist_debug() {
        let rendered = format!("{:?}", mnist().unwrap());
        assert!(rendered.contains("MnistDataset"));
    }

    // ---- Embedded sample generation ----

    #[test]
    fn test_embedded_mnist_sample() {
        let (pixels, labels) = embedded_mnist_sample();
        assert_eq!(pixels.len(), 100 * 784);
        assert_eq!(labels.len(), 100);
    }

    #[test]
    fn test_embedded_mnist_sample_labels_balanced() {
        let (_, labels) = embedded_mnist_sample();
        let mut per_digit = [0i32; 10];
        for label in labels {
            let slot = usize::try_from(label).unwrap();
            per_digit[slot] += 1;
        }
        for (digit, &count) in per_digit.iter().enumerate() {
            assert_eq!(count, 10, "Digit {} should have 10 samples", digit);
        }
    }

    #[test]
    fn test_generate_digit_pattern_0() {
        let pattern = generate_digit_pattern(0);
        assert_eq!(pattern.len(), 784);
        // The oval must light at least one pixel.
        assert!(
            lit_pixel_count(&pattern) > 0,
            "Digit 0 pattern should have non-zero pixels"
        );
    }

    #[test]
    fn test_generate_digit_pattern_1() {
        let pattern = generate_digit_pattern(1);
        assert_eq!(pattern.len(), 784);
        assert!(
            lit_pixel_count(&pattern) > 0,
            "Digit 1 pattern should have non-zero pixels"
        );
    }

    #[test]
    fn test_generate_digit_patterns_all() {
        for digit in 0..10 {
            let pattern = generate_digit_pattern(digit);
            assert_eq!(pattern.len(), 784, "Digit {} pattern wrong size", digit);
            assert!(
                lit_pixel_count(&pattern) > 0,
                "Digit {} pattern should have non-zero pixels",
                digit
            );
        }
    }

    #[test]
    fn test_generate_digit_pattern_unknown() {
        let pattern = generate_digit_pattern(99);
        assert_eq!(pattern.len(), 784);
        // A value outside 0-9 leaves the canvas blank.
        assert_eq!(
            lit_pixel_count(&pattern),
            0,
            "Unknown digit should have all zeros"
        );
    }

    // ---- Pixel helper ----

    #[test]
    fn test_set_pixel_in_bounds() {
        let mut img = vec![0.0f32; 784];
        set_pixel(&mut img, 14, 14, 1.0);
        assert_eq!(img[14 * 28 + 14], 1.0);
    }

    #[test]
    fn test_set_pixel_out_of_bounds() {
        let mut img = vec![0.0f32; 784];
        // Neither call may panic or touch the image.
        set_pixel(&mut img, 30, 14, 1.0); // x out of bounds
        set_pixel(&mut img, 14, 30, 1.0); // y out of bounds
        assert_eq!(lit_pixel_count(&img), 0);
    }

    // ---- Stratified split ----

    /// Regression guard for stratification: both train and test must contain
    /// all 10 digit classes. The embedded sample is ordered by digit, so a
    /// plain sequential 80/20 cut would put digits 0-7 in train and leave
    /// 8-9 for test only.
    #[test]
    fn test_mnist_split_is_stratified() {
        use std::collections::HashSet;

        let dataset = mnist().unwrap();
        let split = dataset.split().unwrap();

        let train_batch = split.train.get_batch(0).unwrap();
        let train_label_set: HashSet<i32> = label_values(&train_batch).into_iter().collect();

        let test_batch = split.test.get_batch(0).unwrap();
        let test_label_set: HashSet<i32> = label_values(&test_batch).into_iter().collect();

        // Both sides must see every class.
        assert_eq!(
            train_label_set.len(),
            10,
            "Train set must contain all 10 digit classes, got {:?}",
            train_label_set
        );
        assert_eq!(
            test_label_set.len(),
            10,
            "Test set must contain all 10 digit classes, got {:?}",
            test_label_set
        );

        // And the classes must be exactly the digits 0-9.
        for digit in 0..10 {
            assert!(
                train_label_set.contains(&digit),
                "Train set missing digit {}",
                digit
            );
            assert!(
                test_label_set.contains(&digit),
                "Test set missing digit {}",
                digit
            );
        }
    }

    /// The stratified split must keep the classes approximately balanced.
    #[test]
    fn test_mnist_split_maintains_class_balance() {
        let dataset = mnist().unwrap();
        let split = dataset.split().unwrap();
        let train_batch = split.train.get_batch(0).unwrap();

        // Tally training rows per digit; labels outside 0-9 are not counted.
        let mut train_counts = [0usize; 10];
        for label in label_values(&train_batch) {
            let slot = usize::try_from(label)
                .ok()
                .and_then(|idx| train_counts.get_mut(idx));
            if let Some(count) = slot {
                *count += 1;
            }
        }

        // 10 samples per class at an 80% train share gives 8 per class;
        // allow one either way for rounding.
        for (digit, &count) in train_counts.iter().enumerate() {
            assert!(
                (7..=9).contains(&count),
                "Digit {} has {} training samples, expected ~8",
                digit,
                count
            );
        }
    }
}
569}