Skip to main content

alimentar/datasets/
fashion_mnist.rs

//! Fashion-MNIST dataset loader.
//!
//! An embedded sample (10 per class = 100 total) works offline.
//! The full dataset (70k) lives on the HuggingFace Hub; the `hf-hub` feature enables downloading its `train` split.
5
6use std::sync::Arc;
7
8use arrow::{
9    array::{Float32Array, Int32Array, RecordBatch},
10    datatypes::{DataType, Field, Schema},
11};
12
13use super::{CanonicalDataset, DatasetSplit};
14use crate::{
15    transform::{Skip, Take, Transform},
16    ArrowDataset, Dataset, Result,
17};
18
/// Human-readable names of the ten Fashion-MNIST classes.
///
/// The array is indexed by the integer class label, so entry `0` is the
/// name for label `0` and entry `9` is the name for label `9`.
pub const FASHION_MNIST_CLASSES: [&str; 10] = [
    "t-shirt/top",
    "trouser",
    "pullover",
    "dress",
    "coat",
    "sandal",
    "shirt",
    "sneaker",
    "bag",
    "ankle boot",
];
32
33/// Load Fashion-MNIST dataset (embedded 100-sample subset)
34///
35/// # Errors
36///
37/// Returns an error if dataset construction fails.
38pub fn fashion_mnist() -> Result<FashionMnistDataset> {
39    FashionMnistDataset::load()
40}
41
42/// Fashion-MNIST clothing classification dataset
43#[derive(Debug, Clone)]
44pub struct FashionMnistDataset {
45    data: ArrowDataset,
46}
47
48impl FashionMnistDataset {
49    /// Load embedded Fashion-MNIST sample
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if construction fails.
54    pub fn load() -> Result<Self> {
55        let mut fields: Vec<Field> = (0..784)
56            .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
57            .collect();
58        fields.push(Field::new("label", DataType::Int32, false));
59        let schema = Arc::new(Schema::new(fields));
60
61        let (pixels, labels) = embedded_fashion_mnist_sample();
62        let num_samples = labels.len();
63
64        let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(785);
65        for pixel_idx in 0..784 {
66            let pixel_data: Vec<f32> = (0..num_samples)
67                .map(|s| pixels[s * 784 + pixel_idx])
68                .collect();
69            columns.push(Arc::new(Float32Array::from(pixel_data)));
70        }
71        columns.push(Arc::new(Int32Array::from(labels)));
72
73        let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
74        let data = ArrowDataset::from_batch(batch)?;
75
76        Ok(Self { data })
77    }
78
79    /// Load full Fashion-MNIST from HuggingFace Hub
80    #[cfg(feature = "hf-hub")]
81    pub fn load_full() -> Result<Self> {
82        use crate::hf_hub::HfDataset;
83        let hf = HfDataset::builder("zalando-datasets/fashion_mnist")
84            .split("train")
85            .build()?;
86        let data = hf.download()?;
87        Ok(Self { data })
88    }
89
90    /// Get train/test split (80/20)
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the dataset is empty or split fails.
95    pub fn split(&self) -> Result<DatasetSplit> {
96        let len = self.data.len();
97        let train_size = (len * 8) / 10;
98
99        let batch = self
100            .data
101            .get_batch(0)
102            .ok_or_else(|| crate::Error::empty_dataset("Fashion-MNIST"))?;
103
104        let train_batch = Take::new(train_size).apply(batch.clone())?;
105        let test_batch = Skip::new(train_size).apply(batch.clone())?;
106
107        Ok(DatasetSplit::new(
108            ArrowDataset::from_batch(train_batch)?,
109            ArrowDataset::from_batch(test_batch)?,
110        ))
111    }
112
113    /// Get class name for a label
114    #[must_use]
115    pub fn class_name(label: i32) -> Option<&'static str> {
116        if label < 0 {
117            return None;
118        }
119        FASHION_MNIST_CLASSES
120            .get(usize::try_from(label).ok()?)
121            .copied()
122    }
123}
124
125impl CanonicalDataset for FashionMnistDataset {
126    fn data(&self) -> &ArrowDataset {
127        &self.data
128    }
129    fn num_features(&self) -> usize {
130        784
131    }
132    fn num_classes(&self) -> usize {
133        10
134    }
135    fn feature_names(&self) -> &'static [&'static str] {
136        &[]
137    }
138    fn target_name(&self) -> &'static str {
139        "label"
140    }
141    fn description(&self) -> &'static str {
142        "Fashion-MNIST (Xiao et al. 2017). Embedded: 100 samples. Full: 70k (requires hf-hub)."
143    }
144}
145
146/// Embedded Fashion-MNIST sample - 10 per class with simple patterns
147fn embedded_fashion_mnist_sample() -> (Vec<f32>, Vec<i32>) {
148    let mut pixels = Vec::with_capacity(100 * 784);
149    let mut labels = Vec::with_capacity(100);
150
151    for class_idx in 0..10 {
152        for sample in 0..10i16 {
153            let pattern = generate_fashion_pattern(class_idx, sample);
154            pixels.extend(pattern);
155            #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
156            labels.push(class_idx as i32);
157        }
158    }
159
160    (pixels, labels)
161}
162
163/// Generate simple fashion item patterns
164fn generate_fashion_pattern(class: usize, variation: i16) -> Vec<f32> {
165    let mut img = vec![0.0f32; 784];
166    let var = f32::from(variation) * 0.02;
167
168    match class {
169        0 => draw_tshirt(&mut img, var),     // t-shirt/top
170        1 => draw_trouser(&mut img, var),    // trouser
171        2 => draw_pullover(&mut img, var),   // pullover
172        3 => draw_dress(&mut img, var),      // dress
173        4 => draw_coat(&mut img, var),       // coat
174        5 => draw_sandal(&mut img, var),     // sandal
175        6 => draw_shirt(&mut img, var),      // shirt
176        7 => draw_sneaker(&mut img, var),    // sneaker
177        8 => draw_bag(&mut img, var),        // bag
178        9 => draw_ankle_boot(&mut img, var), // ankle boot
179        _ => {}
180    }
181
182    img
183}
184
/// Write `val` at column `x`, row `y` of a 28-wide row-major image.
///
/// Coordinates outside the 28 x 28 grid are silently ignored.
fn set_pixel(img: &mut [f32], x: usize, y: usize, val: f32) {
    let inside_grid = x < 28 && y < 28;
    if !inside_grid {
        return;
    }
    img[y * 28 + x] = val;
}
190
191fn draw_tshirt(img: &mut [f32], var: f32) {
192    // Body
193    for y in 8..22 {
194        for x in 8..20 {
195            set_pixel(img, x, y, (0.8 + var).min(1.0));
196        }
197    }
198    // Sleeves
199    for y in 8..12 {
200        for x in 4..8 {
201            set_pixel(img, x, y, (0.7 + var).min(1.0));
202        }
203        for x in 20..24 {
204            set_pixel(img, x, y, (0.7 + var).min(1.0));
205        }
206    }
207}
208
209fn draw_trouser(img: &mut [f32], var: f32) {
210    // Left leg
211    for y in 4..24 {
212        for x in 8..13 {
213            set_pixel(img, x, y, (0.6 + var).min(1.0));
214        }
215    }
216    // Right leg
217    for y in 4..24 {
218        for x in 15..20 {
219            set_pixel(img, x, y, (0.6 + var).min(1.0));
220        }
221    }
222    // Waist
223    for x in 8..20 {
224        for y in 4..7 {
225            set_pixel(img, x, y, (0.7 + var).min(1.0));
226        }
227    }
228}
229
230fn draw_pullover(img: &mut [f32], var: f32) {
231    draw_tshirt(img, var);
232    // Longer sleeves
233    for y in 12..16 {
234        for x in 4..8 {
235            set_pixel(img, x, y, (0.7 + var).min(1.0));
236        }
237        for x in 20..24 {
238            set_pixel(img, x, y, (0.7 + var).min(1.0));
239        }
240    }
241}
242
243fn draw_dress(img: &mut [f32], var: f32) {
244    // Top
245    for y in 6..12 {
246        for x in 10..18 {
247            set_pixel(img, x, y, (0.8 + var).min(1.0));
248        }
249    }
250    // Flared skirt
251    for y in 12..24 {
252        let width = 4 + (y - 12) / 2;
253        for x in (14 - width)..(14 + width) {
254            set_pixel(img, x, y, (0.8 + var).min(1.0));
255        }
256    }
257}
258
259fn draw_coat(img: &mut [f32], var: f32) {
260    draw_tshirt(img, var);
261    // Extend body
262    for y in 22..26 {
263        for x in 8..20 {
264            set_pixel(img, x, y, (0.8 + var).min(1.0));
265        }
266    }
267}
268
269fn draw_sandal(img: &mut [f32], var: f32) {
270    // Sole
271    for x in 6..22 {
272        for y in 20..24 {
273            set_pixel(img, x, y, (0.5 + var).min(1.0));
274        }
275    }
276    // Straps
277    for x in 8..20 {
278        set_pixel(img, x, 16, (0.7 + var).min(1.0));
279        set_pixel(img, x, 12, (0.7 + var).min(1.0));
280    }
281}
282
283fn draw_shirt(img: &mut [f32], var: f32) {
284    draw_tshirt(img, var);
285    // Collar
286    for x in 12..16 {
287        set_pixel(img, x, 7, (0.9 + var).min(1.0));
288    }
289}
290
291fn draw_sneaker(img: &mut [f32], var: f32) {
292    // Sole
293    for x in 4..24 {
294        for y in 18..22 {
295            set_pixel(img, x, y, (0.4 + var).min(1.0));
296        }
297    }
298    // Upper
299    for x in 6..22 {
300        for y in 12..18 {
301            set_pixel(img, x, y, (0.8 + var).min(1.0));
302        }
303    }
304}
305
306fn draw_bag(img: &mut [f32], var: f32) {
307    // Body
308    for y in 10..24 {
309        for x in 8..20 {
310            set_pixel(img, x, y, (0.7 + var).min(1.0));
311        }
312    }
313    // Handle
314    for x in 10..18 {
315        set_pixel(img, x, 6, (0.6 + var).min(1.0));
316        set_pixel(img, x, 8, (0.6 + var).min(1.0));
317    }
318    set_pixel(img, 10, 7, (0.6 + var).min(1.0));
319    set_pixel(img, 17, 7, (0.6 + var).min(1.0));
320}
321
322fn draw_ankle_boot(img: &mut [f32], var: f32) {
323    // Sole
324    for x in 6..22 {
325        for y in 20..24 {
326            set_pixel(img, x, y, (0.3 + var).min(1.0));
327        }
328    }
329    // Boot upper
330    for x in 8..20 {
331        for y in 8..20 {
332            set_pixel(img, x, y, (0.6 + var).min(1.0));
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use arrow::array::Float32Array;

    use super::*;
    use crate::Dataset;

    /// The embedded dataset has 100 rows and reports 10 classes.
    #[test]
    fn test_fashion_mnist_load() {
        let ds = fashion_mnist().unwrap();
        assert_eq!(ds.len(), 100);
        assert_eq!(ds.num_classes(), 10);
    }

    /// The 80/20 split yields 80 training rows and 20 test rows.
    #[test]
    fn test_fashion_mnist_split() {
        let parts = fashion_mnist().unwrap().split().unwrap();
        assert_eq!(parts.train.len(), 80);
        assert_eq!(parts.test.len(), 20);
    }

    /// Boundary labels map to names; out-of-range labels map to `None`.
    #[test]
    fn test_fashion_mnist_class_names() {
        assert_eq!(FashionMnistDataset::class_name(0), Some("t-shirt/top"));
        assert_eq!(FashionMnistDataset::class_name(9), Some("ankle boot"));
        assert_eq!(FashionMnistDataset::class_name(10), None);
        assert_eq!(FashionMnistDataset::class_name(-1), None);
    }

    /// Every entry of the class table is reachable through `class_name`.
    #[test]
    fn test_fashion_mnist_all_class_names() {
        for (label, &expected) in (0i32..).zip(FASHION_MNIST_CLASSES.iter()) {
            assert_eq!(FashionMnistDataset::class_name(label), Some(expected));
        }
    }

    /// Feature count is the number of pixels per image.
    #[test]
    fn test_fashion_mnist_num_features() {
        let ds = fashion_mnist().unwrap();
        assert_eq!(ds.num_features(), 784);
    }

    /// No feature names are published.
    #[test]
    fn test_fashion_mnist_feature_names() {
        let ds = fashion_mnist().unwrap();
        assert!(ds.feature_names().is_empty());
    }

    /// The target column is called `label`.
    #[test]
    fn test_fashion_mnist_target_name() {
        let ds = fashion_mnist().unwrap();
        assert_eq!(ds.target_name(), "label");
    }

    /// The description mentions the dataset name and its authors.
    #[test]
    fn test_fashion_mnist_description() {
        let ds = fashion_mnist().unwrap();
        let text = ds.description();
        assert!(text.contains("Fashion-MNIST"));
        assert!(text.contains("Xiao"));
    }

    /// The underlying Arrow dataset is reachable and has all rows.
    #[test]
    fn test_fashion_mnist_data_access() {
        let ds = fashion_mnist().unwrap();
        assert_eq!(ds.data().len(), 100);
    }

    /// The single batch has one column per pixel plus the label column.
    #[test]
    fn test_fashion_mnist_schema_columns() {
        let ds = fashion_mnist().unwrap();
        let first_batch = ds.data().get_batch(0).unwrap();
        assert_eq!(first_batch.num_columns(), 785); // 784 pixels + 1 label
    }

    /// Every stored label lies in `0..10`.
    #[test]
    fn test_fashion_mnist_labels_in_range() {
        let ds = fashion_mnist().unwrap();
        let first_batch = ds.data().get_batch(0).unwrap();
        let labels = first_batch
            .column(784)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for row in 0..labels.len() {
            let label = labels.value(row);
            assert!((0..10).contains(&label), "Label {} out of range", label);
        }
    }

    /// Values of the first pixel column stay within `[0.0, 1.0]`.
    #[test]
    fn test_fashion_mnist_pixel_values_normalized() {
        let ds = fashion_mnist().unwrap();
        let first_batch = ds.data().get_batch(0).unwrap();
        let first_pixel = first_batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        for row in 0..first_pixel.len() {
            let val = first_pixel.value(row);
            assert!(
                (0.0..=1.0).contains(&val),
                "Pixel value {} out of range",
                val
            );
        }
    }

    /// Cloning preserves the row count.
    #[test]
    fn test_fashion_mnist_clone() {
        let original = fashion_mnist().unwrap();
        let copy = original.clone();
        assert_eq!(copy.len(), original.len());
    }

    /// The `Debug` output names the type.
    #[test]
    fn test_fashion_mnist_debug() {
        let ds = fashion_mnist().unwrap();
        let rendered = format!("{:?}", ds);
        assert!(rendered.contains("FashionMnistDataset"));
    }

    /// The raw sample has 100 images' worth of pixels and 100 labels.
    #[test]
    fn test_embedded_fashion_mnist_sample() {
        let (pixels, labels) = embedded_fashion_mnist_sample();
        assert_eq!(pixels.len(), 100 * 784);
        assert_eq!(labels.len(), 100);
    }

    /// Each class contributes exactly ten samples.
    #[test]
    fn test_embedded_fashion_mnist_sample_labels_balanced() {
        let (_, labels) = embedded_fashion_mnist_sample();
        let mut per_class = [0i32; 10];
        labels
            .iter()
            .for_each(|&label| per_class[usize::try_from(label).unwrap()] += 1);
        for (class, &count) in per_class.iter().enumerate() {
            assert_eq!(count, 10, "Class {} should have 10 samples", class);
        }
    }

    /// Every known class produces a full-size, non-blank image.
    #[test]
    fn test_generate_fashion_pattern_all_classes() {
        for class in 0..10 {
            let image = generate_fashion_pattern(class, 0);
            assert_eq!(image.len(), 784, "Class {} pattern wrong size", class);
            let lit = image.iter().filter(|&&p| p > 0.0).count();
            assert!(
                lit > 0,
                "Class {} pattern should have non-zero pixels",
                class
            );
        }
    }

    /// Different variation values change at least one pixel.
    #[test]
    fn test_generate_fashion_pattern_with_variation() {
        let base = generate_fashion_pattern(0, 0);
        let shifted = generate_fashion_pattern(0, 5);
        // Patterns should differ due to variation
        let differs = base
            .iter()
            .zip(&shifted)
            .any(|(a, b)| (a - b).abs() > 0.001);
        assert!(
            differs,
            "Patterns with different variations should differ"
        );
    }

    /// An unknown class yields a blank image of the usual size.
    #[test]
    fn test_generate_fashion_pattern_unknown() {
        let image = generate_fashion_pattern(99, 0);
        assert_eq!(image.len(), 784);
        // Unknown class should be all zeros
        let lit = image.iter().filter(|&&p| p > 0.0).count();
        assert_eq!(lit, 0, "Unknown class should have all zeros");
    }

    /// An in-bounds write lands at `y * 28 + x`.
    #[test]
    fn test_set_pixel_in_bounds() {
        let mut canvas = vec![0.0f32; 784];
        set_pixel(&mut canvas, 14, 14, 1.0);
        assert_eq!(canvas[14 * 28 + 14], 1.0);
    }

    /// Out-of-bounds writes neither panic nor modify the image.
    #[test]
    fn test_set_pixel_out_of_bounds() {
        let mut canvas = vec![0.0f32; 784];
        set_pixel(&mut canvas, 30, 14, 1.0); // x out of bounds
        set_pixel(&mut canvas, 14, 30, 1.0); // y out of bounds
        let lit = canvas.iter().filter(|&&p| p > 0.0).count();
        assert_eq!(lit, 0);
    }

    /// The class table has ten entries with the expected first and last names.
    #[test]
    fn test_fashion_mnist_classes_constant() {
        assert_eq!(FASHION_MNIST_CLASSES.len(), 10);
        assert_eq!(FASHION_MNIST_CLASSES[0], "t-shirt/top");
        assert_eq!(FASHION_MNIST_CLASSES[9], "ankle boot");
    }
}
542}