Skip to main content

alimentar/datasets/
cifar100.rs

1//! CIFAR-100 dataset loader
2//!
3//! Embedded sample (1 per class = 100 total) works offline.
4//! Full dataset (60k) available with `hf-hub` feature.
5
6use std::sync::Arc;
7
8use arrow::{
9    array::{Float32Array, Int32Array, RecordBatch},
10    datatypes::{DataType, Field, Schema},
11};
12
13use super::{CanonicalDataset, DatasetSplit};
14use crate::{
15    transform::{Skip, Take, Transform},
16    ArrowDataset, Dataset, Result,
17};
18
/// CIFAR-100 fine class names (100 classes)
///
/// The array is indexed by fine label: entry `i` is the name reported for fine label `i`
/// (see [`Cifar100Dataset::fine_class_name`]). Names are listed in alphabetical order, so
/// label 0 is `"apple"` and label 99 is `"worm"`.
pub const CIFAR100_FINE_CLASSES: [&str; 100] = [
    "apple",
    "aquarium_fish",
    "baby",
    "bear",
    "beaver",
    "bed",
    "bee",
    "beetle",
    "bicycle",
    "bottle",
    "bowl",
    "boy",
    "bridge",
    "bus",
    "butterfly",
    "camel",
    "can",
    "castle",
    "caterpillar",
    "cattle",
    "chair",
    "chimpanzee",
    "clock",
    "cloud",
    "cockroach",
    "couch",
    "crab",
    "crocodile",
    "cup",
    "dinosaur",
    "dolphin",
    "elephant",
    "flatfish",
    "forest",
    "fox",
    "girl",
    "hamster",
    "house",
    "kangaroo",
    "keyboard",
    "lamp",
    "lawn_mower",
    "leopard",
    "lion",
    "lizard",
    "lobster",
    "man",
    "maple_tree",
    "motorcycle",
    "mountain",
    "mouse",
    "mushroom",
    "oak_tree",
    "orange",
    "orchid",
    "otter",
    "palm_tree",
    "pear",
    "pickup_truck",
    "pine_tree",
    "plain",
    "plate",
    "poppy",
    "porcupine",
    "possum",
    "rabbit",
    "raccoon",
    "ray",
    "road",
    "rocket",
    "rose",
    "sea",
    "seal",
    "shark",
    "shrew",
    "skunk",
    "skyscraper",
    "snail",
    "snake",
    "spider",
    "squirrel",
    "streetcar",
    "sunflower",
    "sweet_pepper",
    "table",
    "tank",
    "telephone",
    "television",
    "tiger",
    "tractor",
    "train",
    "trout",
    "tulip",
    "turtle",
    "wardrobe",
    "whale",
    "willow_tree",
    "wolf",
    "woman",
    "worm",
];
122
/// CIFAR-100 coarse class names (20 superclasses)
///
/// The array is indexed by coarse label: entry `i` is the name reported for coarse label `i`
/// (see [`Cifar100Dataset::coarse_class_name`]). Label 0 is `"aquatic_mammals"` and label 19
/// is `"vehicles_2"`.
pub const CIFAR100_COARSE_CLASSES: [&str; 20] = [
    "aquatic_mammals",
    "fish",
    "flowers",
    "food_containers",
    "fruit_and_vegetables",
    "household_electrical_devices",
    "household_furniture",
    "insects",
    "large_carnivores",
    "large_man-made_outdoor_things",
    "large_natural_outdoor_scenes",
    "large_omnivores_and_herbivores",
    "medium_mammals",
    "non-insect_invertebrates",
    "people",
    "reptiles",
    "small_mammals",
    "trees",
    "vehicles_1",
    "vehicles_2",
];
146
147/// Load CIFAR-100 dataset (embedded 100-sample subset)
148///
149/// # Errors
150///
151/// Returns an error if dataset construction fails.
152pub fn cifar100() -> Result<Cifar100Dataset> {
153    Cifar100Dataset::load()
154}
155
156/// CIFAR-100 image classification dataset
157#[derive(Debug, Clone)]
158pub struct Cifar100Dataset {
159    data: ArrowDataset,
160}
161
162impl Cifar100Dataset {
163    /// Load embedded CIFAR-100 sample
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if construction fails.
168    pub fn load() -> Result<Self> {
169        let mut fields: Vec<Field> = (0..3072)
170            .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
171            .collect();
172        fields.push(Field::new("fine_label", DataType::Int32, false));
173        fields.push(Field::new("coarse_label", DataType::Int32, false));
174        let schema = Arc::new(Schema::new(fields));
175
176        let (pixels, fine_labels, coarse_labels) = embedded_cifar100_sample();
177        let num_samples = fine_labels.len();
178
179        let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(3074);
180        for pixel_idx in 0..3072 {
181            let pixel_data: Vec<f32> = (0..num_samples)
182                .map(|s| pixels[s * 3072 + pixel_idx])
183                .collect();
184            columns.push(Arc::new(Float32Array::from(pixel_data)));
185        }
186        columns.push(Arc::new(Int32Array::from(fine_labels)));
187        columns.push(Arc::new(Int32Array::from(coarse_labels)));
188
189        let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
190        let data = ArrowDataset::from_batch(batch)?;
191
192        Ok(Self { data })
193    }
194
195    /// Load full CIFAR-100 from HuggingFace Hub
196    #[cfg(feature = "hf-hub")]
197    pub fn load_full() -> Result<Self> {
198        use crate::hf_hub::HfDataset;
199        let hf = HfDataset::builder("uoft-cs/cifar100")
200            .split("train")
201            .build()?;
202        let data = hf.download()?;
203        Ok(Self { data })
204    }
205
206    /// Get train/test split (80/20)
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if the dataset is empty or split fails.
211    pub fn split(&self) -> Result<DatasetSplit> {
212        let len = self.data.len();
213        let train_size = (len * 8) / 10;
214
215        let batch = self
216            .data
217            .get_batch(0)
218            .ok_or_else(|| crate::Error::empty_dataset("CIFAR-100"))?;
219
220        let train_batch = Take::new(train_size).apply(batch.clone())?;
221        let test_batch = Skip::new(train_size).apply(batch.clone())?;
222
223        Ok(DatasetSplit::new(
224            ArrowDataset::from_batch(train_batch)?,
225            ArrowDataset::from_batch(test_batch)?,
226        ))
227    }
228
229    /// Get fine class name for a label (100 classes)
230    #[must_use]
231    pub fn fine_class_name(label: i32) -> Option<&'static str> {
232        if label < 0 {
233            return None;
234        }
235        CIFAR100_FINE_CLASSES
236            .get(usize::try_from(label).ok()?)
237            .copied()
238    }
239
240    /// Get coarse class name for a label (20 superclasses)
241    #[must_use]
242    pub fn coarse_class_name(label: i32) -> Option<&'static str> {
243        if label < 0 {
244            return None;
245        }
246        CIFAR100_COARSE_CLASSES
247            .get(usize::try_from(label).ok()?)
248            .copied()
249    }
250}
251
252impl CanonicalDataset for Cifar100Dataset {
253    fn data(&self) -> &ArrowDataset {
254        &self.data
255    }
256    fn num_features(&self) -> usize {
257        3072
258    }
259    fn num_classes(&self) -> usize {
260        100
261    }
262    fn feature_names(&self) -> &'static [&'static str] {
263        &[]
264    }
265    fn target_name(&self) -> &'static str {
266        "fine_label"
267    }
268    fn description(&self) -> &'static str {
269        "CIFAR-100 (Krizhevsky 2009). 100 fine classes, 20 coarse. Embedded: 100. Full: 60k."
270    }
271}
272
/// Fine-to-coarse label mapping
///
/// Indexed by fine label: entry `i` is the coarse (superclass) label assigned to fine
/// label `i`. Every entry is below 20.
const FINE_TO_COARSE: [usize; 100] = [
    4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 6, 11, 5, 10, 7, 6, 13, 15,
    3, 15, 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16, 4, 17, 4,
    2, 0, 17, 4, 18, 17, 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
    16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
];

/// Embedded CIFAR-100 sample - 1 per fine class
///
/// Returns `(pixels, fine_labels, coarse_labels)`:
///
/// * `pixels` holds 3072 values per sample, sample after sample. Within a sample the
///   first 1024 values are the red channel, the next 1024 green and the last 1024 blue.
///   Each sample is a single solid colour derived from its class index, with every value
///   in `0.0..=0.99`.
/// * `fine_labels` is `0, 1, .., 99` - one sample per fine class, in class order.
/// * `coarse_labels` is the matching entry of `FINE_TO_COARSE` for each sample.
// The only integers cast to `f32` are remainders modulo 100, which `f32` represents exactly.
#[allow(clippy::cast_precision_loss)]
fn embedded_cifar100_sample() -> (Vec<f32>, Vec<i32>, Vec<i32>) {
    /// Values per colour channel of one sample.
    const CHANNEL_LEN: usize = 1024;
    /// Per-channel multipliers (red, green, blue) that spread class indices over colours.
    const CHANNEL_MULTIPLIERS: [usize; 3] = [37, 59, 73];

    let num_classes = FINE_TO_COARSE.len();
    let mut pixels = Vec::with_capacity(num_classes * CHANNEL_MULTIPLIERS.len() * CHANNEL_LEN);
    let mut fine_labels = Vec::with_capacity(num_classes);
    let mut coarse_labels = Vec::with_capacity(num_classes);

    for (class_idx, &coarse_idx) in FINE_TO_COARSE.iter().enumerate() {
        // Deterministic colour per class: each channel is (class * multiplier) mod 100,
        // scaled into 0.0..=0.99. The channels are written one full plane at a time.
        for multiplier in CHANNEL_MULTIPLIERS {
            let intensity = ((class_idx * multiplier) % 100) as f32 / 100.0;
            pixels.extend(std::iter::repeat(intensity).take(CHANNEL_LEN));
        }

        // Class indices are below 100 and coarse indices below 20, so both fit in `i32`.
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        {
            fine_labels.push(class_idx as i32);
            coarse_labels.push(coarse_idx as i32);
        }
    }

    (pixels, fine_labels, coarse_labels)
}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    /// Read the `Int32` column at `column_index` of the embedded dataset's first batch.
    fn label_values(column_index: usize) -> Vec<i32> {
        let dataset = cifar100().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let labels = batch
            .column(column_index)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        (0..labels.len()).map(|row| labels.value(row)).collect()
    }

    #[test]
    fn test_cifar100_load() {
        let loaded = cifar100().unwrap();
        assert_eq!(loaded.len(), 100);
        assert_eq!(loaded.num_classes(), 100);
    }

    #[test]
    fn test_cifar100_split() {
        let parts = cifar100().unwrap().split().unwrap();
        assert_eq!(parts.train.len(), 80);
        assert_eq!(parts.test.len(), 20);
    }

    #[test]
    fn test_cifar100_fine_class_names() {
        // Both ends of the valid range, then one past the end and a negative label.
        let cases = [
            (0, Some("apple")),
            (99, Some("worm")),
            (100, None),
            (-1, None),
        ];
        for (label, expected) in cases {
            assert_eq!(Cifar100Dataset::fine_class_name(label), expected);
        }
    }

    #[test]
    fn test_cifar100_coarse_class_names() {
        let cases = [
            (0, Some("aquatic_mammals")),
            (19, Some("vehicles_2")),
            (20, None),
        ];
        for (label, expected) in cases {
            assert_eq!(Cifar100Dataset::coarse_class_name(label), expected);
        }
    }

    #[test]
    fn test_cifar100_has_both_labels() {
        let dataset = cifar100().unwrap();
        let schema = dataset.data().schema();
        for column in ["fine_label", "coarse_label"] {
            assert!(schema.field_with_name(column).is_ok());
        }
    }

    #[test]
    fn test_cifar100_coarse_class_name_negative() {
        for label in [-1, -100] {
            assert_eq!(Cifar100Dataset::coarse_class_name(label), None);
        }
    }

    #[test]
    fn test_cifar100_num_features() {
        assert_eq!(cifar100().unwrap().num_features(), 3072);
    }

    #[test]
    fn test_cifar100_feature_names() {
        assert!(cifar100().unwrap().feature_names().is_empty());
    }

    #[test]
    fn test_cifar100_target_name() {
        assert_eq!(cifar100().unwrap().target_name(), "fine_label");
    }

    #[test]
    fn test_cifar100_description() {
        let dataset = cifar100().unwrap();
        let text = dataset.description();
        assert!(text.contains("CIFAR-100"));
        assert!(text.contains("100 fine classes"));
    }

    #[test]
    fn test_cifar100_data_access() {
        let dataset = cifar100().unwrap();
        assert_eq!(dataset.data().len(), 100);
    }

    #[test]
    fn test_cifar100_schema_columns() {
        let dataset = cifar100().unwrap();
        let first_batch = dataset.data().get_batch(0).unwrap();
        // 3072 pixel columns followed by the two label columns.
        assert_eq!(first_batch.num_columns(), 3074);
    }

    #[test]
    fn test_cifar100_fine_labels_in_range() {
        // The fine label column sits directly after the 3072 pixel columns.
        for label in label_values(3072) {
            assert!(
                (0..100).contains(&label),
                "Fine label {} out of range",
                label
            );
        }
    }

    #[test]
    fn test_cifar100_coarse_labels_in_range() {
        // The coarse label column is the last one, after the fine label column.
        for label in label_values(3073) {
            assert!(
                (0..20).contains(&label),
                "Coarse label {} out of range",
                label
            );
        }
    }

    #[test]
    fn test_cifar100_clone() {
        let original = cifar100().unwrap();
        let copy = original.clone();
        assert_eq!(copy.len(), original.len());
    }

    #[test]
    fn test_cifar100_debug() {
        let dataset = cifar100().unwrap();
        let rendered = format!("{dataset:?}");
        assert!(rendered.contains("Cifar100Dataset"));
    }

    #[test]
    fn test_cifar100_fine_classes_constant() {
        assert_eq!(CIFAR100_FINE_CLASSES.len(), 100);
        assert_eq!(CIFAR100_FINE_CLASSES[0], "apple");
        assert_eq!(CIFAR100_FINE_CLASSES[99], "worm");
    }

    #[test]
    fn test_cifar100_coarse_classes_constant() {
        assert_eq!(CIFAR100_COARSE_CLASSES.len(), 20);
        assert_eq!(CIFAR100_COARSE_CLASSES[0], "aquatic_mammals");
        assert_eq!(CIFAR100_COARSE_CLASSES[19], "vehicles_2");
    }

    #[test]
    fn test_fine_to_coarse_mapping_valid() {
        for coarse_idx in FINE_TO_COARSE.iter() {
            assert!(*coarse_idx < 20, "Coarse index {} out of range", coarse_idx);
        }
    }

    #[test]
    fn test_embedded_cifar100_sample() {
        let (pixels, fine_labels, coarse_labels) = embedded_cifar100_sample();
        assert_eq!(pixels.len(), 100 * 3072);
        assert_eq!(fine_labels.len(), 100);
        assert_eq!(coarse_labels.len(), 100);
    }

    #[test]
    fn test_embedded_cifar100_sample_labels_valid() {
        let (_, fine_labels, coarse_labels) = embedded_cifar100_sample();
        for (position, fine) in fine_labels.iter().enumerate() {
            assert!(
                (0..100).contains(fine),
                "Fine label {} at {} out of range",
                fine,
                position
            );
        }
        for (position, coarse) in coarse_labels.iter().enumerate() {
            assert!(
                (0..20).contains(coarse),
                "Coarse label {} at {} out of range",
                coarse,
                position
            );
        }
    }
}