1use std::sync::Arc;
7
8use arrow::{
9 array::{Float32Array, Int32Array, RecordBatch},
10 datatypes::{DataType, Field, Schema},
11};
12
13use super::{CanonicalDataset, DatasetSplit};
14use crate::{
15 transform::{Skip, Take, Transform},
16 ArrowDataset, Dataset, Result,
17};
18
/// Names of the 100 fine-grained CIFAR-100 classes, listed alphabetically.
///
/// A name's position in the array is its fine label: index 0 is `"apple"`
/// and index 99 is `"worm"`.
pub const CIFAR100_FINE_CLASSES: [&str; 100] = [
    "apple",
    "aquarium_fish",
    "baby",
    "bear",
    "beaver",
    "bed",
    "bee",
    "beetle",
    "bicycle",
    "bottle",
    "bowl",
    "boy",
    "bridge",
    "bus",
    "butterfly",
    "camel",
    "can",
    "castle",
    "caterpillar",
    "cattle",
    "chair",
    "chimpanzee",
    "clock",
    "cloud",
    "cockroach",
    "couch",
    "crab",
    "crocodile",
    "cup",
    "dinosaur",
    "dolphin",
    "elephant",
    "flatfish",
    "forest",
    "fox",
    "girl",
    "hamster",
    "house",
    "kangaroo",
    "keyboard",
    "lamp",
    "lawn_mower",
    "leopard",
    "lion",
    "lizard",
    "lobster",
    "man",
    "maple_tree",
    "motorcycle",
    "mountain",
    "mouse",
    "mushroom",
    "oak_tree",
    "orange",
    "orchid",
    "otter",
    "palm_tree",
    "pear",
    "pickup_truck",
    "pine_tree",
    "plain",
    "plate",
    "poppy",
    "porcupine",
    "possum",
    "rabbit",
    "raccoon",
    "ray",
    "road",
    "rocket",
    "rose",
    "sea",
    "seal",
    "shark",
    "shrew",
    "skunk",
    "skyscraper",
    "snail",
    "snake",
    "spider",
    "squirrel",
    "streetcar",
    "sunflower",
    "sweet_pepper",
    "table",
    "tank",
    "telephone",
    "television",
    "tiger",
    "tractor",
    "train",
    "trout",
    "tulip",
    "turtle",
    "wardrobe",
    "whale",
    "willow_tree",
    "wolf",
    "woman",
    "worm",
];
122
/// Names of the 20 coarse CIFAR-100 classes (superclasses), listed alphabetically.
///
/// A name's position in the array is its coarse label: index 0 is
/// `"aquatic_mammals"` and index 19 is `"vehicles_2"`.
pub const CIFAR100_COARSE_CLASSES: [&str; 20] = [
    "aquatic_mammals",
    "fish",
    "flowers",
    "food_containers",
    "fruit_and_vegetables",
    "household_electrical_devices",
    "household_furniture",
    "insects",
    "large_carnivores",
    "large_man-made_outdoor_things",
    "large_natural_outdoor_scenes",
    "large_omnivores_and_herbivores",
    "medium_mammals",
    "non-insect_invertebrates",
    "people",
    "reptiles",
    "small_mammals",
    "trees",
    "vehicles_1",
    "vehicles_2",
];
146
/// Loads the embedded CIFAR-100 sample.
///
/// Convenience wrapper that simply forwards to [`Cifar100Dataset::load`].
///
/// # Errors
///
/// Propagates any error returned by [`Cifar100Dataset::load`].
pub fn cifar100() -> Result<Cifar100Dataset> {
    Cifar100Dataset::load()
}
155
/// CIFAR-100 image-classification dataset stored as an [`ArrowDataset`].
#[derive(Debug, Clone)]
pub struct Cifar100Dataset {
    // Underlying Arrow data. When built by `load`, it holds 3072 `pixel_*`
    // columns followed by the `fine_label` and `coarse_label` columns.
    data: ArrowDataset,
}
161
162impl Cifar100Dataset {
163 pub fn load() -> Result<Self> {
169 let mut fields: Vec<Field> = (0..3072)
170 .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
171 .collect();
172 fields.push(Field::new("fine_label", DataType::Int32, false));
173 fields.push(Field::new("coarse_label", DataType::Int32, false));
174 let schema = Arc::new(Schema::new(fields));
175
176 let (pixels, fine_labels, coarse_labels) = embedded_cifar100_sample();
177 let num_samples = fine_labels.len();
178
179 let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(3074);
180 for pixel_idx in 0..3072 {
181 let pixel_data: Vec<f32> = (0..num_samples)
182 .map(|s| pixels[s * 3072 + pixel_idx])
183 .collect();
184 columns.push(Arc::new(Float32Array::from(pixel_data)));
185 }
186 columns.push(Arc::new(Int32Array::from(fine_labels)));
187 columns.push(Arc::new(Int32Array::from(coarse_labels)));
188
189 let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
190 let data = ArrowDataset::from_batch(batch)?;
191
192 Ok(Self { data })
193 }
194
195 #[cfg(feature = "hf-hub")]
197 pub fn load_full() -> Result<Self> {
198 use crate::hf_hub::HfDataset;
199 let hf = HfDataset::builder("uoft-cs/cifar100")
200 .split("train")
201 .build()?;
202 let data = hf.download()?;
203 Ok(Self { data })
204 }
205
206 pub fn split(&self) -> Result<DatasetSplit> {
212 let len = self.data.len();
213 let train_size = (len * 8) / 10;
214
215 let batch = self
216 .data
217 .get_batch(0)
218 .ok_or_else(|| crate::Error::empty_dataset("CIFAR-100"))?;
219
220 let train_batch = Take::new(train_size).apply(batch.clone())?;
221 let test_batch = Skip::new(train_size).apply(batch.clone())?;
222
223 Ok(DatasetSplit::new(
224 ArrowDataset::from_batch(train_batch)?,
225 ArrowDataset::from_batch(test_batch)?,
226 ))
227 }
228
229 #[must_use]
231 pub fn fine_class_name(label: i32) -> Option<&'static str> {
232 if label < 0 {
233 return None;
234 }
235 CIFAR100_FINE_CLASSES
236 .get(usize::try_from(label).ok()?)
237 .copied()
238 }
239
240 #[must_use]
242 pub fn coarse_class_name(label: i32) -> Option<&'static str> {
243 if label < 0 {
244 return None;
245 }
246 CIFAR100_COARSE_CLASSES
247 .get(usize::try_from(label).ok()?)
248 .copied()
249 }
250}
251
/// Static metadata describing CIFAR-100 through the [`CanonicalDataset`] interface.
impl CanonicalDataset for Cifar100Dataset {
    // Borrows the underlying Arrow data.
    fn data(&self) -> &ArrowDataset {
        &self.data
    }
    // One feature per pixel column (`load` creates 3072 of them).
    fn num_features(&self) -> usize {
        3072
    }
    // Number of fine classes; the target column is the fine label.
    fn num_classes(&self) -> usize {
        100
    }
    // No feature names are listed here; the result is always an empty slice.
    fn feature_names(&self) -> &'static [&'static str] {
        &[]
    }
    // Name of the label column used as the target.
    fn target_name(&self) -> &'static str {
        "fine_label"
    }
    // Short human-readable summary of the dataset.
    fn description(&self) -> &'static str {
        "CIFAR-100 (Krizhevsky 2009). 100 fine classes, 20 coarse. Embedded: 100. Full: 60k."
    }
}
272
/// Coarse label for each fine label: entry `i` is the coarse class index
/// (0..20) that fine class `i` belongs to.
const FINE_TO_COARSE: [usize; 100] = [
    4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 6, 11, 5, 10, 7, 6, 13, 15,
    3, 15, 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16, 4, 17, 4,
    2, 0, 17, 4, 18, 17, 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
    16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
];
280
281#[allow(clippy::cast_precision_loss)]
283fn embedded_cifar100_sample() -> (Vec<f32>, Vec<i32>, Vec<i32>) {
284 let mut pixels = Vec::with_capacity(100 * 3072);
285 let mut fine_labels = Vec::with_capacity(100);
286 let mut coarse_labels = Vec::with_capacity(100);
287
288 for (class_idx, &coarse_idx) in FINE_TO_COARSE.iter().enumerate() {
290 let r = ((class_idx * 37) % 100) as f32 / 100.0;
292 let g = ((class_idx * 59) % 100) as f32 / 100.0;
293 let b = ((class_idx * 73) % 100) as f32 / 100.0;
294
295 for _ in 0..1024 {
297 pixels.push(r);
298 }
299 for _ in 0..1024 {
300 pixels.push(g);
301 }
302 for _ in 0..1024 {
303 pixels.push(b);
304 }
305
306 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
307 {
308 fine_labels.push(class_idx as i32);
309 coarse_labels.push(coarse_idx as i32);
310 }
311 }
312
313 (pixels, fine_labels, coarse_labels)
314}
315
// Unit tests for the embedded sample, the label lookups and the dataset metadata.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    // The embedded sample loads and reports 100 rows and 100 classes.
    #[test]
    fn test_cifar100_load() {
        let dataset = cifar100().unwrap();
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 100);
    }

    // Splitting the 100-row sample yields 80 train rows and 20 test rows.
    #[test]
    fn test_cifar100_split() {
        let dataset = cifar100().unwrap();
        let split = dataset.split().unwrap();
        assert_eq!(split.train.len(), 80);
        assert_eq!(split.test.len(), 20);
    }

    // Fine-name lookup covers both ends of the table and rejects out-of-range labels.
    #[test]
    fn test_cifar100_fine_class_names() {
        assert_eq!(Cifar100Dataset::fine_class_name(0), Some("apple"));
        assert_eq!(Cifar100Dataset::fine_class_name(99), Some("worm"));
        assert_eq!(Cifar100Dataset::fine_class_name(100), None);
        assert_eq!(Cifar100Dataset::fine_class_name(-1), None);
    }

    // Coarse-name lookup covers both ends of the table and rejects label 20.
    #[test]
    fn test_cifar100_coarse_class_names() {
        assert_eq!(
            Cifar100Dataset::coarse_class_name(0),
            Some("aquatic_mammals")
        );
        assert_eq!(Cifar100Dataset::coarse_class_name(19), Some("vehicles_2"));
        assert_eq!(Cifar100Dataset::coarse_class_name(20), None);
    }

    // The schema exposes both label columns by name.
    #[test]
    fn test_cifar100_has_both_labels() {
        let dataset = cifar100().unwrap();
        let schema = dataset.data().schema();
        assert!(schema.field_with_name("fine_label").is_ok());
        assert!(schema.field_with_name("coarse_label").is_ok());
    }

    // Negative coarse labels map to `None` rather than panicking.
    #[test]
    fn test_cifar100_coarse_class_name_negative() {
        assert_eq!(Cifar100Dataset::coarse_class_name(-1), None);
        assert_eq!(Cifar100Dataset::coarse_class_name(-100), None);
    }

    // The reported feature count matches the number of pixel columns.
    #[test]
    fn test_cifar100_num_features() {
        let dataset = cifar100().unwrap();
        assert_eq!(dataset.num_features(), 3072);
    }

    // No feature names are listed for this dataset.
    #[test]
    fn test_cifar100_feature_names() {
        let dataset = cifar100().unwrap();
        assert!(dataset.feature_names().is_empty());
    }

    // The target column is the fine label.
    #[test]
    fn test_cifar100_target_name() {
        let dataset = cifar100().unwrap();
        assert_eq!(dataset.target_name(), "fine_label");
    }

    // The description mentions the dataset name and the fine class count.
    #[test]
    fn test_cifar100_description() {
        let dataset = cifar100().unwrap();
        let desc = dataset.description();
        assert!(desc.contains("CIFAR-100"));
        assert!(desc.contains("100 fine classes"));
    }

    // The underlying Arrow dataset is reachable and has 100 rows.
    #[test]
    fn test_cifar100_data_access() {
        let dataset = cifar100().unwrap();
        let data = dataset.data();
        assert_eq!(data.len(), 100);
    }

    // The first batch has 3072 pixel columns plus the 2 label columns.
    #[test]
    fn test_cifar100_schema_columns() {
        let dataset = cifar100().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        assert_eq!(batch.num_columns(), 3074);
    }

    // Every value in the fine-label column (index 3072) lies in 0..100.
    #[test]
    fn test_cifar100_fine_labels_in_range() {
        let dataset = cifar100().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let label_col = batch
            .column(3072)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for i in 0..label_col.len() {
            let label = label_col.value(i);
            assert!(
                (0..100).contains(&label),
                "Fine label {} out of range",
                label
            );
        }
    }

    // Every value in the coarse-label column (index 3073) lies in 0..20.
    #[test]
    fn test_cifar100_coarse_labels_in_range() {
        let dataset = cifar100().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let label_col = batch
            .column(3073)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for i in 0..label_col.len() {
            let label = label_col.value(i);
            assert!(
                (0..20).contains(&label),
                "Coarse label {} out of range",
                label
            );
        }
    }

    // A clone reports the same length as the original.
    #[test]
    fn test_cifar100_clone() {
        let dataset = cifar100().unwrap();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    // The derived `Debug` output names the type.
    #[test]
    fn test_cifar100_debug() {
        let dataset = cifar100().unwrap();
        let debug = format!("{:?}", dataset);
        assert!(debug.contains("Cifar100Dataset"));
    }

    // The fine class table has 100 entries with the expected first and last names.
    #[test]
    fn test_cifar100_fine_classes_constant() {
        assert_eq!(CIFAR100_FINE_CLASSES.len(), 100);
        assert_eq!(CIFAR100_FINE_CLASSES[0], "apple");
        assert_eq!(CIFAR100_FINE_CLASSES[99], "worm");
    }

    // The coarse class table has 20 entries with the expected first and last names.
    #[test]
    fn test_cifar100_coarse_classes_constant() {
        assert_eq!(CIFAR100_COARSE_CLASSES.len(), 20);
        assert_eq!(CIFAR100_COARSE_CLASSES[0], "aquatic_mammals");
        assert_eq!(CIFAR100_COARSE_CLASSES[19], "vehicles_2");
    }

    // Every fine-to-coarse entry is a valid coarse index.
    #[test]
    fn test_fine_to_coarse_mapping_valid() {
        for &coarse_idx in &FINE_TO_COARSE {
            assert!(coarse_idx < 20, "Coarse index {} out of range", coarse_idx);
        }
    }

    // The raw sample has 100 images of 3072 values and one label pair per image.
    #[test]
    fn test_embedded_cifar100_sample() {
        let (pixels, fine_labels, coarse_labels) = embedded_cifar100_sample();
        assert_eq!(pixels.len(), 100 * 3072);
        assert_eq!(fine_labels.len(), 100);
        assert_eq!(coarse_labels.len(), 100);
    }

    // Labels in the raw sample stay within the fine (0..100) and coarse (0..20) ranges.
    #[test]
    fn test_embedded_cifar100_sample_labels_valid() {
        let (_, fine_labels, coarse_labels) = embedded_cifar100_sample();
        for (i, &fine) in fine_labels.iter().enumerate() {
            assert!(
                (0..100).contains(&fine),
                "Fine label {} at {} out of range",
                fine,
                i
            );
        }
        for (i, &coarse) in coarse_labels.iter().enumerate() {
            assert!(
                (0..20).contains(&coarse),
                "Coarse label {} at {} out of range",
                coarse,
                i
            );
        }
    }
}