Skip to main content

alimentar/datasets/
cifar10.rs

1//! CIFAR-10 dataset loader
2//!
3//! Embedded sample (10 per class = 100 total) works offline.
4//! Full dataset (60k) available with `hf-hub` feature.
5
6use std::sync::Arc;
7
8use arrow::{
9    array::{Float32Array, Int32Array, RecordBatch},
10    datatypes::{DataType, Field, Schema},
11};
12
13use super::{CanonicalDataset, DatasetSplit};
14use crate::{
15    transform::{Skip, Take, Transform},
16    ArrowDataset, Dataset, Result,
17};
18
/// CIFAR-10 class names
///
/// The position of a name in this array is its integer class label
/// (`0` = "airplane" ... `9` = "truck"); `Cifar10Dataset::class_name`
/// looks labels up by indexing into it.
pub const CIFAR10_CLASSES: [&str; 10] = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
];
32
/// Load CIFAR-10 dataset (embedded 100-sample subset)
///
/// Convenience free function that simply forwards to
/// [`Cifar10Dataset::load`].
///
/// # Errors
///
/// Returns an error if dataset construction fails.
pub fn cifar10() -> Result<Cifar10Dataset> {
    Cifar10Dataset::load()
}
41
/// CIFAR-10 image classification dataset
///
/// Wraps an [`ArrowDataset`] holding the image table. Instances are created
/// by [`Cifar10Dataset::load`] (embedded sample) or, with the `hf-hub`
/// feature enabled, by `Cifar10Dataset::load_full`.
#[derive(Debug, Clone)]
pub struct Cifar10Dataset {
    // Underlying Arrow-backed table; exposed read-only through
    // `CanonicalDataset::data`.
    data: ArrowDataset,
}
47
48impl Cifar10Dataset {
49    /// Load embedded CIFAR-10 sample
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if construction fails.
54    pub fn load() -> Result<Self> {
55        // Schema: 3072 pixel columns (32x32x3 RGB) + label
56        let mut fields: Vec<Field> = (0..3072)
57            .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
58            .collect();
59        fields.push(Field::new("label", DataType::Int32, false));
60        let schema = Arc::new(Schema::new(fields));
61
62        let (pixels, labels) = embedded_cifar10_sample();
63        let num_samples = labels.len();
64
65        let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(3073);
66        for pixel_idx in 0..3072 {
67            let pixel_data: Vec<f32> = (0..num_samples)
68                .map(|s| pixels[s * 3072 + pixel_idx])
69                .collect();
70            columns.push(Arc::new(Float32Array::from(pixel_data)));
71        }
72        columns.push(Arc::new(Int32Array::from(labels)));
73
74        let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
75        let data = ArrowDataset::from_batch(batch)?;
76
77        Ok(Self { data })
78    }
79
80    /// Load full CIFAR-10 from HuggingFace Hub
81    #[cfg(feature = "hf-hub")]
82    pub fn load_full() -> Result<Self> {
83        use crate::hf_hub::HfDataset;
84        let hf = HfDataset::builder("uoft-cs/cifar10")
85            .split("train")
86            .build()?;
87        let data = hf.download()?;
88        Ok(Self { data })
89    }
90
91    /// Get train/test split (80/20)
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the dataset is empty or split fails.
96    pub fn split(&self) -> Result<DatasetSplit> {
97        let len = self.data.len();
98        let train_size = (len * 8) / 10;
99
100        let batch = self
101            .data
102            .get_batch(0)
103            .ok_or_else(|| crate::Error::empty_dataset("CIFAR-10"))?;
104
105        let train_batch = Take::new(train_size).apply(batch.clone())?;
106        let test_batch = Skip::new(train_size).apply(batch.clone())?;
107
108        Ok(DatasetSplit::new(
109            ArrowDataset::from_batch(train_batch)?,
110            ArrowDataset::from_batch(test_batch)?,
111        ))
112    }
113
114    /// Get class name for a label
115    #[must_use]
116    pub fn class_name(label: i32) -> Option<&'static str> {
117        if label < 0 {
118            return None;
119        }
120        CIFAR10_CLASSES.get(usize::try_from(label).ok()?).copied()
121    }
122}
123
// Canonical-dataset metadata for CIFAR-10; every method returns a fixed value
// except `data`, which borrows the wrapped table.
impl CanonicalDataset for Cifar10Dataset {
    // Borrow the underlying Arrow table.
    fn data(&self) -> &ArrowDataset {
        &self.data
    }
    // 32x32 pixels x 3 RGB channels, matching the pixel columns built in `load`.
    fn num_features(&self) -> usize {
        3072
    }
    // One class per entry of `CIFAR10_CLASSES`.
    fn num_classes(&self) -> usize {
        10
    }
    // Returns an empty slice: the generated `pixel_{i}` column names are not
    // listed here.
    fn feature_names(&self) -> &'static [&'static str] {
        &[]
    }
    // Name of the label column appended after the pixel columns in `load`.
    fn target_name(&self) -> &'static str {
        "label"
    }
    // Human-readable summary of the dataset and its two loading modes.
    fn description(&self) -> &'static str {
        "CIFAR-10 (Krizhevsky 2009). Embedded: 100 samples. Full: 60k (requires hf-hub)."
    }
}
144
/// Embedded CIFAR-10 sample - 10 per class with simple color patterns
///
/// The images are synthetic: each one is a flat colour whose RGB value is the
/// class base colour plus a per-sample brightness offset
/// (`0.02 * sample_index`), clamped to at most `1.0`.
///
/// Returns `(pixels, labels)`:
/// - `pixels` holds the images back to back, 3072 values per image laid out
///   as 1024 red values, then 1024 green values, then 1024 blue values;
/// - `labels[i]` is the class index (0-9) of image `i`.
///
/// Images are emitted class by class, so `labels` is sorted ascending.
fn embedded_cifar10_sample() -> (Vec<f32>, Vec<i32>) {
    // Images generated per class; `u8` converts losslessly to `f32` and `usize`.
    const SAMPLES_PER_CLASS: u8 = 10;
    // Values in one colour plane of a 32x32 image.
    const PLANE_LEN: usize = 32 * 32;
    // Values in one full RGB image.
    const IMAGE_LEN: usize = 3 * PLANE_LEN;

    // Simple color patterns per class
    let class_colors: [(f32, f32, f32); 10] = [
        (0.5, 0.7, 0.9), // airplane - sky blue
        (0.3, 0.3, 0.3), // automobile - gray
        (0.6, 0.4, 0.2), // bird - brown
        (0.8, 0.6, 0.4), // cat - orange
        (0.4, 0.3, 0.2), // deer - brown
        (0.7, 0.5, 0.3), // dog - tan
        (0.2, 0.8, 0.2), // frog - green
        (0.5, 0.3, 0.2), // horse - brown
        (0.2, 0.3, 0.5), // ship - navy
        (0.6, 0.2, 0.2), // truck - red
    ];

    let total_images = class_colors.len() * usize::from(SAMPLES_PER_CLASS);
    let mut pixels = Vec::with_capacity(total_images * IMAGE_LEN);
    let mut labels = Vec::with_capacity(total_images);

    // Pairing the colours with an `i32` counter yields the labels directly,
    // so no `usize -> i32` cast (and no lint suppression) is needed.
    for (label, &(r, g, b)) in (0_i32..).zip(class_colors.iter()) {
        for sample in 0..SAMPLES_PER_CLASS {
            // Add variation per sample
            let var = f32::from(sample) * 0.02;
            // Planes are written in R, G, B order; each plane is one constant.
            for &base in &[r, g, b] {
                let value = (base + var).min(1.0);
                let filled_len = pixels.len() + PLANE_LEN;
                pixels.resize(filled_len, value);
            }
            labels.push(label);
        }
    }

    (pixels, labels)
}
186
// Unit tests for the embedded CIFAR-10 loader: dataset shape, split sizes,
// label/class-name mapping, trait metadata and the raw sample generator.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    // The embedded sample loads with 100 rows and reports 10 classes.
    #[test]
    fn test_cifar10_load() {
        let dataset = cifar10().unwrap();
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 10);
    }

    // An 80/20 split of 100 rows gives 80 train and 20 test rows.
    #[test]
    fn test_cifar10_split() {
        let dataset = cifar10().unwrap();
        let split = dataset.split().unwrap();
        assert_eq!(split.train.len(), 80);
        assert_eq!(split.test.len(), 20);
    }

    // Boundary labels map to the first/last class; 10 is out of range.
    #[test]
    fn test_cifar10_class_names() {
        assert_eq!(Cifar10Dataset::class_name(0), Some("airplane"));
        assert_eq!(Cifar10Dataset::class_name(9), Some("truck"));
        assert_eq!(Cifar10Dataset::class_name(10), None);
    }

    // Negative labels never resolve to a class name.
    #[test]
    fn test_cifar10_class_name_negative() {
        assert_eq!(Cifar10Dataset::class_name(-1), None);
        assert_eq!(Cifar10Dataset::class_name(-100), None);
    }

    // Every index of the class table round-trips through `class_name`.
    #[test]
    fn test_cifar10_all_class_names() {
        for (idx, &expected) in CIFAR10_CLASSES.iter().enumerate() {
            assert_eq!(Cifar10Dataset::class_name(idx as i32), Some(expected));
        }
    }

    // Feature count is 32 * 32 * 3.
    #[test]
    fn test_cifar10_num_features() {
        let dataset = cifar10().unwrap();
        assert_eq!(dataset.num_features(), 3072);
    }

    // No explicit feature names are exposed.
    #[test]
    fn test_cifar10_feature_names() {
        let dataset = cifar10().unwrap();
        assert!(dataset.feature_names().is_empty());
    }

    // The target column is called "label".
    #[test]
    fn test_cifar10_target_name() {
        let dataset = cifar10().unwrap();
        assert_eq!(dataset.target_name(), "label");
    }

    // The description mentions the dataset name and the embedded sample size.
    #[test]
    fn test_cifar10_description() {
        let dataset = cifar10().unwrap();
        let desc = dataset.description();
        assert!(desc.contains("CIFAR-10"));
        assert!(desc.contains("100 samples"));
    }

    // `data()` exposes the same 100-row table.
    #[test]
    fn test_cifar10_data_access() {
        let dataset = cifar10().unwrap();
        let data = dataset.data();
        assert_eq!(data.len(), 100);
    }

    // The first batch carries every pixel column plus the label column.
    #[test]
    fn test_cifar10_schema_columns() {
        let dataset = cifar10().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        assert_eq!(batch.num_columns(), 3073); // 3072 pixels + 1 label
    }

    // Values of the first pixel column stay within [0, 1].
    #[test]
    fn test_cifar10_pixel_values_normalized() {
        let dataset = cifar10().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let pixel_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        for i in 0..pixel_col.len() {
            let val = pixel_col.value(i);
            assert!(
                (0.0..=1.0).contains(&val),
                "Pixel value {} out of range",
                val
            );
        }
    }

    // The label column (index 3072, after the pixels) only holds 0..=9.
    #[test]
    fn test_cifar10_labels_in_range() {
        let dataset = cifar10().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let label_col = batch
            .column(3072)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for i in 0..label_col.len() {
            let label = label_col.value(i);
            assert!((0..10).contains(&label), "Label {} out of range", label);
        }
    }

    // Cloning preserves the row count.
    #[test]
    fn test_cifar10_clone() {
        let dataset = cifar10().unwrap();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    // The derived `Debug` output names the type.
    #[test]
    fn test_cifar10_debug() {
        let dataset = cifar10().unwrap();
        let debug = format!("{:?}", dataset);
        assert!(debug.contains("Cifar10Dataset"));
    }

    // The raw generator yields 100 images of 3072 values and 100 labels.
    #[test]
    fn test_embedded_cifar10_sample() {
        let (pixels, labels) = embedded_cifar10_sample();
        assert_eq!(pixels.len(), 100 * 3072);
        assert_eq!(labels.len(), 100);
    }

    // Each of the 10 classes appears exactly 10 times.
    #[test]
    fn test_embedded_cifar10_sample_labels_balanced() {
        let (_, labels) = embedded_cifar10_sample();
        let mut counts = [0i32; 10];
        for label in labels {
            counts[usize::try_from(label).unwrap()] += 1;
        }
        for (i, &count) in counts.iter().enumerate() {
            assert_eq!(count, 10, "Class {} should have 10 samples", i);
        }
    }

    // Sanity check on the class-name table itself.
    #[test]
    fn test_cifar10_classes_constant() {
        assert_eq!(CIFAR10_CLASSES.len(), 10);
        assert_eq!(CIFAR10_CLASSES[0], "airplane");
        assert_eq!(CIFAR10_CLASSES[9], "truck");
    }
}