1use std::sync::Arc;
7
8use arrow::{
9 array::{Float32Array, Int32Array, RecordBatch},
10 datatypes::{DataType, Field, Schema},
11};
12
13use super::{CanonicalDataset, DatasetSplit};
14use crate::{
15 transform::{Skip, Take, Transform},
16 ArrowDataset, Dataset, Result,
17};
18
/// Human-readable CIFAR-10 class names, indexed by integer label.
///
/// `CIFAR10_CLASSES[label]` is the name for `label` in `0..10`; see
/// [`Cifar10Dataset::class_name`] for a checked lookup.
pub const CIFAR10_CLASSES: [&str; 10] = [
    "airplane",   // 0
    "automobile", // 1
    "bird",       // 2
    "cat",        // 3
    "deer",       // 4
    "dog",        // 5
    "frog",       // 6
    "horse",      // 7
    "ship",       // 8
    "truck",      // 9
];
32
33pub fn cifar10() -> Result<Cifar10Dataset> {
39 Cifar10Dataset::load()
40}
41
42#[derive(Debug, Clone)]
44pub struct Cifar10Dataset {
45 data: ArrowDataset,
46}
47
48impl Cifar10Dataset {
49 pub fn load() -> Result<Self> {
55 let mut fields: Vec<Field> = (0..3072)
57 .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
58 .collect();
59 fields.push(Field::new("label", DataType::Int32, false));
60 let schema = Arc::new(Schema::new(fields));
61
62 let (pixels, labels) = embedded_cifar10_sample();
63 let num_samples = labels.len();
64
65 let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(3073);
66 for pixel_idx in 0..3072 {
67 let pixel_data: Vec<f32> = (0..num_samples)
68 .map(|s| pixels[s * 3072 + pixel_idx])
69 .collect();
70 columns.push(Arc::new(Float32Array::from(pixel_data)));
71 }
72 columns.push(Arc::new(Int32Array::from(labels)));
73
74 let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
75 let data = ArrowDataset::from_batch(batch)?;
76
77 Ok(Self { data })
78 }
79
80 #[cfg(feature = "hf-hub")]
82 pub fn load_full() -> Result<Self> {
83 use crate::hf_hub::HfDataset;
84 let hf = HfDataset::builder("uoft-cs/cifar10")
85 .split("train")
86 .build()?;
87 let data = hf.download()?;
88 Ok(Self { data })
89 }
90
91 pub fn split(&self) -> Result<DatasetSplit> {
97 let len = self.data.len();
98 let train_size = (len * 8) / 10;
99
100 let batch = self
101 .data
102 .get_batch(0)
103 .ok_or_else(|| crate::Error::empty_dataset("CIFAR-10"))?;
104
105 let train_batch = Take::new(train_size).apply(batch.clone())?;
106 let test_batch = Skip::new(train_size).apply(batch.clone())?;
107
108 Ok(DatasetSplit::new(
109 ArrowDataset::from_batch(train_batch)?,
110 ArrowDataset::from_batch(test_batch)?,
111 ))
112 }
113
114 #[must_use]
116 pub fn class_name(label: i32) -> Option<&'static str> {
117 if label < 0 {
118 return None;
119 }
120 CIFAR10_CLASSES.get(usize::try_from(label).ok()?).copied()
121 }
122}
123
124impl CanonicalDataset for Cifar10Dataset {
125 fn data(&self) -> &ArrowDataset {
126 &self.data
127 }
128 fn num_features(&self) -> usize {
129 3072
130 }
131 fn num_classes(&self) -> usize {
132 10
133 }
134 fn feature_names(&self) -> &'static [&'static str] {
135 &[]
136 }
137 fn target_name(&self) -> &'static str {
138 "label"
139 }
140 fn description(&self) -> &'static str {
141 "CIFAR-10 (Krizhevsky 2009). Embedded: 100 samples. Full: 60k (requires hf-hub)."
142 }
143}
144
/// Builds the small synthetic sample used as the embedded CIFAR-10 data.
///
/// Returns `(pixels, labels)`:
///
/// * `pixels` holds 100 samples of 3072 values each, stored sample after
///   sample. Every sample is three consecutive 1024-value planes, each filled
///   with one value: the class colour component plus a per-sample offset of
///   `sample * 0.02`, capped at `1.0`.
/// * `labels` holds the class index (`0..10`) of each sample, ten samples per
///   class in ascending class order.
fn embedded_cifar10_sample() -> (Vec<f32>, Vec<i32>) {
    const PLANE_LEN: usize = 1024;
    const SAMPLES_PER_CLASS: i16 = 10;
    const TOTAL_SAMPLES: usize = 100;

    let mut pixels = Vec::with_capacity(TOTAL_SAMPLES * 3 * PLANE_LEN);
    let mut labels = Vec::with_capacity(TOTAL_SAMPLES);

    // One base (r, g, b) colour per class, indexed by class label.
    let class_colors: [(f32, f32, f32); 10] = [
        (0.5, 0.7, 0.9),
        (0.3, 0.3, 0.3),
        (0.6, 0.4, 0.2),
        (0.8, 0.6, 0.4),
        (0.4, 0.3, 0.2),
        (0.7, 0.5, 0.3),
        (0.2, 0.8, 0.2),
        (0.5, 0.3, 0.2),
        (0.2, 0.3, 0.5),
        (0.6, 0.2, 0.2),
    ];

    // Pair each colour with an `i32` label directly, so no lossy
    // `usize as i32` cast (and no lint suppression) is needed.
    for (label, &(r, g, b)) in (0_i32..).zip(class_colors.iter()) {
        for sample in 0..SAMPLES_PER_CLASS {
            // Slightly brighter for each successive sample of a class.
            let var = f32::from(sample) * 0.02;
            for channel in [r, g, b] {
                let value = (channel + var).min(1.0);
                // Append one whole plane filled with `value`.
                pixels.resize(pixels.len() + PLANE_LEN, value);
            }
            labels.push(label);
        }
    }

    (pixels, labels)
}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    /// Loads the embedded dataset, panicking if loading fails.
    fn embedded() -> Cifar10Dataset {
        cifar10().unwrap()
    }

    #[test]
    fn test_cifar10_load() {
        let ds = embedded();
        assert_eq!(ds.len(), 100);
        assert_eq!(ds.num_classes(), 10);
    }

    #[test]
    fn test_cifar10_split() {
        let parts = embedded().split().unwrap();
        assert_eq!(parts.train.len(), 80);
        assert_eq!(parts.test.len(), 20);
    }

    #[test]
    fn test_cifar10_class_names() {
        let cases = [(0, Some("airplane")), (9, Some("truck")), (10, None)];
        for (label, expected) in cases {
            assert_eq!(Cifar10Dataset::class_name(label), expected);
        }
    }

    #[test]
    fn test_cifar10_class_name_negative() {
        for label in [-1, -100] {
            assert_eq!(Cifar10Dataset::class_name(label), None);
        }
    }

    #[test]
    fn test_cifar10_all_class_names() {
        // Walk labels and names in lockstep; no integer cast required.
        for (label, expected) in (0i32..).zip(CIFAR10_CLASSES.iter().copied()) {
            assert_eq!(Cifar10Dataset::class_name(label), Some(expected));
        }
    }

    #[test]
    fn test_cifar10_num_features() {
        assert_eq!(embedded().num_features(), 3072);
    }

    #[test]
    fn test_cifar10_feature_names() {
        assert!(embedded().feature_names().is_empty());
    }

    #[test]
    fn test_cifar10_target_name() {
        assert_eq!(embedded().target_name(), "label");
    }

    #[test]
    fn test_cifar10_description() {
        let ds = embedded();
        let text = ds.description();
        assert!(text.contains("CIFAR-10"));
        assert!(text.contains("100 samples"));
    }

    #[test]
    fn test_cifar10_data_access() {
        let ds = embedded();
        assert_eq!(ds.data().len(), 100);
    }

    #[test]
    fn test_cifar10_schema_columns() {
        let ds = embedded();
        let batch = ds.data().get_batch(0).unwrap();
        // 3072 pixel columns plus the label column.
        assert_eq!(batch.num_columns(), 3073);
    }

    #[test]
    fn test_cifar10_pixel_values_normalized() {
        let ds = embedded();
        let batch = ds.data().get_batch(0).unwrap();
        let first_pixel = batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        for val in (0..first_pixel.len()).map(|row| first_pixel.value(row)) {
            assert!(
                (0.0..=1.0).contains(&val),
                "Pixel value {} out of range",
                val
            );
        }
    }

    #[test]
    fn test_cifar10_labels_in_range() {
        let ds = embedded();
        let batch = ds.data().get_batch(0).unwrap();
        let label_values = batch
            .column(3072)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for label in (0..label_values.len()).map(|row| label_values.value(row)) {
            assert!((0..10).contains(&label), "Label {} out of range", label);
        }
    }

    #[test]
    fn test_cifar10_clone() {
        let original = embedded();
        let copy = original.clone();
        assert_eq!(copy.len(), original.len());
    }

    #[test]
    fn test_cifar10_debug() {
        let rendered = format!("{:?}", embedded());
        assert!(rendered.contains("Cifar10Dataset"));
    }

    #[test]
    fn test_embedded_cifar10_sample() {
        let (pixels, labels) = embedded_cifar10_sample();
        assert_eq!(pixels.len(), 100 * 3072);
        assert_eq!(labels.len(), 100);
    }

    #[test]
    fn test_embedded_cifar10_sample_labels_balanced() {
        let (_, labels) = embedded_cifar10_sample();
        let mut per_class = [0i32; 10];
        labels
            .into_iter()
            .map(|label| usize::try_from(label).unwrap())
            .for_each(|class| per_class[class] += 1);
        for (i, count) in per_class.iter().copied().enumerate() {
            assert_eq!(count, 10, "Class {} should have 10 samples", i);
        }
    }

    #[test]
    fn test_cifar10_classes_constant() {
        assert_eq!(CIFAR10_CLASSES.len(), 10);
        assert_eq!(CIFAR10_CLASSES.first().copied(), Some("airplane"));
        assert_eq!(CIFAR10_CLASSES.last().copied(), Some("truck"));
    }
}