1use std::sync::Arc;
7
8use arrow::{
9 array::{Float32Array, Int32Array, RecordBatch},
10 datatypes::{DataType, Field, Schema},
11};
12
13use super::{CanonicalDataset, DatasetSplit};
14use crate::{
15 transform::{Skip, Take, Transform},
16 ArrowDataset, Dataset, Result,
17};
18
/// Human-readable Fashion-MNIST class names, indexed by integer label.
///
/// Label `i` (0 through 9) maps to `FASHION_MNIST_CLASSES[i]`.
pub const FASHION_MNIST_CLASSES: [&str; 10] = [
    "t-shirt/top", // 0
    "trouser",     // 1
    "pullover",    // 2
    "dress",       // 3
    "coat",        // 4
    "sandal",      // 5
    "shirt",       // 6
    "sneaker",     // 7
    "bag",         // 8
    "ankle boot",  // 9
];
32
33pub fn fashion_mnist() -> Result<FashionMnistDataset> {
39 FashionMnistDataset::load()
40}
41
42#[derive(Debug, Clone)]
44pub struct FashionMnistDataset {
45 data: ArrowDataset,
46}
47
48impl FashionMnistDataset {
49 pub fn load() -> Result<Self> {
55 let mut fields: Vec<Field> = (0..784)
56 .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
57 .collect();
58 fields.push(Field::new("label", DataType::Int32, false));
59 let schema = Arc::new(Schema::new(fields));
60
61 let (pixels, labels) = embedded_fashion_mnist_sample();
62 let num_samples = labels.len();
63
64 let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(785);
65 for pixel_idx in 0..784 {
66 let pixel_data: Vec<f32> = (0..num_samples)
67 .map(|s| pixels[s * 784 + pixel_idx])
68 .collect();
69 columns.push(Arc::new(Float32Array::from(pixel_data)));
70 }
71 columns.push(Arc::new(Int32Array::from(labels)));
72
73 let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
74 let data = ArrowDataset::from_batch(batch)?;
75
76 Ok(Self { data })
77 }
78
79 #[cfg(feature = "hf-hub")]
81 pub fn load_full() -> Result<Self> {
82 use crate::hf_hub::HfDataset;
83 let hf = HfDataset::builder("zalando-datasets/fashion_mnist")
84 .split("train")
85 .build()?;
86 let data = hf.download()?;
87 Ok(Self { data })
88 }
89
90 pub fn split(&self) -> Result<DatasetSplit> {
96 let len = self.data.len();
97 let train_size = (len * 8) / 10;
98
99 let batch = self
100 .data
101 .get_batch(0)
102 .ok_or_else(|| crate::Error::empty_dataset("Fashion-MNIST"))?;
103
104 let train_batch = Take::new(train_size).apply(batch.clone())?;
105 let test_batch = Skip::new(train_size).apply(batch.clone())?;
106
107 Ok(DatasetSplit::new(
108 ArrowDataset::from_batch(train_batch)?,
109 ArrowDataset::from_batch(test_batch)?,
110 ))
111 }
112
113 #[must_use]
115 pub fn class_name(label: i32) -> Option<&'static str> {
116 if label < 0 {
117 return None;
118 }
119 FASHION_MNIST_CLASSES
120 .get(usize::try_from(label).ok()?)
121 .copied()
122 }
123}
124
125impl CanonicalDataset for FashionMnistDataset {
126 fn data(&self) -> &ArrowDataset {
127 &self.data
128 }
129 fn num_features(&self) -> usize {
130 784
131 }
132 fn num_classes(&self) -> usize {
133 10
134 }
135 fn feature_names(&self) -> &'static [&'static str] {
136 &[]
137 }
138 fn target_name(&self) -> &'static str {
139 "label"
140 }
141 fn description(&self) -> &'static str {
142 "Fashion-MNIST (Xiao et al. 2017). Embedded: 100 samples. Full: 70k (requires hf-hub)."
143 }
144}
145
146fn embedded_fashion_mnist_sample() -> (Vec<f32>, Vec<i32>) {
148 let mut pixels = Vec::with_capacity(100 * 784);
149 let mut labels = Vec::with_capacity(100);
150
151 for class_idx in 0..10 {
152 for sample in 0..10i16 {
153 let pattern = generate_fashion_pattern(class_idx, sample);
154 pixels.extend(pattern);
155 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
156 labels.push(class_idx as i32);
157 }
158 }
159
160 (pixels, labels)
161}
162
163fn generate_fashion_pattern(class: usize, variation: i16) -> Vec<f32> {
165 let mut img = vec![0.0f32; 784];
166 let var = f32::from(variation) * 0.02;
167
168 match class {
169 0 => draw_tshirt(&mut img, var), 1 => draw_trouser(&mut img, var), 2 => draw_pullover(&mut img, var), 3 => draw_dress(&mut img, var), 4 => draw_coat(&mut img, var), 5 => draw_sandal(&mut img, var), 6 => draw_shirt(&mut img, var), 7 => draw_sneaker(&mut img, var), 8 => draw_bag(&mut img, var), 9 => draw_ankle_boot(&mut img, var), _ => {}
180 }
181
182 img
183}
184
/// Writes `val` into a row-major 28x28 image at column `x`, row `y`.
///
/// Coordinates outside the 28x28 grid are ignored, as are positions beyond the
/// end of an undersized `img` buffer. The stored value is clamped to
/// `[0.0, 1.0]` so images stay normalised even when a caller's shade offset
/// (e.g. from a negative variation) would push it out of range.
fn set_pixel(img: &mut [f32], x: usize, y: usize, val: f32) {
    const SIDE: usize = 28;
    if x < SIDE && y < SIDE {
        if let Some(pixel) = img.get_mut(y * SIDE + x) {
            *pixel = val.clamp(0.0, 1.0);
        }
    }
}
190
191fn draw_tshirt(img: &mut [f32], var: f32) {
192 for y in 8..22 {
194 for x in 8..20 {
195 set_pixel(img, x, y, (0.8 + var).min(1.0));
196 }
197 }
198 for y in 8..12 {
200 for x in 4..8 {
201 set_pixel(img, x, y, (0.7 + var).min(1.0));
202 }
203 for x in 20..24 {
204 set_pixel(img, x, y, (0.7 + var).min(1.0));
205 }
206 }
207}
208
209fn draw_trouser(img: &mut [f32], var: f32) {
210 for y in 4..24 {
212 for x in 8..13 {
213 set_pixel(img, x, y, (0.6 + var).min(1.0));
214 }
215 }
216 for y in 4..24 {
218 for x in 15..20 {
219 set_pixel(img, x, y, (0.6 + var).min(1.0));
220 }
221 }
222 for x in 8..20 {
224 for y in 4..7 {
225 set_pixel(img, x, y, (0.7 + var).min(1.0));
226 }
227 }
228}
229
230fn draw_pullover(img: &mut [f32], var: f32) {
231 draw_tshirt(img, var);
232 for y in 12..16 {
234 for x in 4..8 {
235 set_pixel(img, x, y, (0.7 + var).min(1.0));
236 }
237 for x in 20..24 {
238 set_pixel(img, x, y, (0.7 + var).min(1.0));
239 }
240 }
241}
242
243fn draw_dress(img: &mut [f32], var: f32) {
244 for y in 6..12 {
246 for x in 10..18 {
247 set_pixel(img, x, y, (0.8 + var).min(1.0));
248 }
249 }
250 for y in 12..24 {
252 let width = 4 + (y - 12) / 2;
253 for x in (14 - width)..(14 + width) {
254 set_pixel(img, x, y, (0.8 + var).min(1.0));
255 }
256 }
257}
258
259fn draw_coat(img: &mut [f32], var: f32) {
260 draw_tshirt(img, var);
261 for y in 22..26 {
263 for x in 8..20 {
264 set_pixel(img, x, y, (0.8 + var).min(1.0));
265 }
266 }
267}
268
269fn draw_sandal(img: &mut [f32], var: f32) {
270 for x in 6..22 {
272 for y in 20..24 {
273 set_pixel(img, x, y, (0.5 + var).min(1.0));
274 }
275 }
276 for x in 8..20 {
278 set_pixel(img, x, 16, (0.7 + var).min(1.0));
279 set_pixel(img, x, 12, (0.7 + var).min(1.0));
280 }
281}
282
283fn draw_shirt(img: &mut [f32], var: f32) {
284 draw_tshirt(img, var);
285 for x in 12..16 {
287 set_pixel(img, x, 7, (0.9 + var).min(1.0));
288 }
289}
290
291fn draw_sneaker(img: &mut [f32], var: f32) {
292 for x in 4..24 {
294 for y in 18..22 {
295 set_pixel(img, x, y, (0.4 + var).min(1.0));
296 }
297 }
298 for x in 6..22 {
300 for y in 12..18 {
301 set_pixel(img, x, y, (0.8 + var).min(1.0));
302 }
303 }
304}
305
306fn draw_bag(img: &mut [f32], var: f32) {
307 for y in 10..24 {
309 for x in 8..20 {
310 set_pixel(img, x, y, (0.7 + var).min(1.0));
311 }
312 }
313 for x in 10..18 {
315 set_pixel(img, x, 6, (0.6 + var).min(1.0));
316 set_pixel(img, x, 8, (0.6 + var).min(1.0));
317 }
318 set_pixel(img, 10, 7, (0.6 + var).min(1.0));
319 set_pixel(img, 17, 7, (0.6 + var).min(1.0));
320}
321
322fn draw_ankle_boot(img: &mut [f32], var: f32) {
323 for x in 6..22 {
325 for y in 20..24 {
326 set_pixel(img, x, y, (0.3 + var).min(1.0));
327 }
328 }
329 for x in 8..20 {
331 for y in 8..20 {
332 set_pixel(img, x, y, (0.6 + var).min(1.0));
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use arrow::array::Float32Array;

    use super::*;
    use crate::Dataset;

    #[test]
    fn test_fashion_mnist_load() {
        let dataset = fashion_mnist().unwrap();
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 10);
    }

    #[test]
    fn test_fashion_mnist_split() {
        let dataset = fashion_mnist().unwrap();
        let split = dataset.split().unwrap();
        assert_eq!(split.train.len(), 80);
        assert_eq!(split.test.len(), 20);
        // Every row lands in exactly one partition.
        assert_eq!(split.train.len() + split.test.len(), dataset.len());
    }

    #[test]
    fn test_fashion_mnist_class_names() {
        assert_eq!(FashionMnistDataset::class_name(0), Some("t-shirt/top"));
        assert_eq!(FashionMnistDataset::class_name(9), Some("ankle boot"));
        assert_eq!(FashionMnistDataset::class_name(10), None);
        assert_eq!(FashionMnistDataset::class_name(-1), None);
        assert_eq!(FashionMnistDataset::class_name(i32::MIN), None);
        assert_eq!(FashionMnistDataset::class_name(i32::MAX), None);
    }

    #[test]
    fn test_fashion_mnist_all_class_names() {
        for (idx, &expected) in FASHION_MNIST_CLASSES.iter().enumerate() {
            let label = i32::try_from(idx).unwrap();
            assert_eq!(FashionMnistDataset::class_name(label), Some(expected));
        }
    }

    #[test]
    fn test_fashion_mnist_num_features() {
        let dataset = fashion_mnist().unwrap();
        assert_eq!(dataset.num_features(), 784);
    }

    #[test]
    fn test_fashion_mnist_feature_names() {
        let dataset = fashion_mnist().unwrap();
        assert!(dataset.feature_names().is_empty());
    }

    #[test]
    fn test_fashion_mnist_target_name() {
        let dataset = fashion_mnist().unwrap();
        assert_eq!(dataset.target_name(), "label");
    }

    #[test]
    fn test_fashion_mnist_description() {
        let dataset = fashion_mnist().unwrap();
        let desc = dataset.description();
        assert!(desc.contains("Fashion-MNIST"));
        assert!(desc.contains("Xiao"));
    }

    #[test]
    fn test_fashion_mnist_data_access() {
        let dataset = fashion_mnist().unwrap();
        let data = dataset.data();
        assert_eq!(data.len(), 100);
    }

    #[test]
    fn test_fashion_mnist_schema_columns() {
        let dataset = fashion_mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        // 784 pixel columns plus the label column.
        assert_eq!(batch.num_columns(), 785);
    }

    #[test]
    fn test_fashion_mnist_labels_in_range() {
        let dataset = fashion_mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let label_col = batch
            .column(784)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for i in 0..label_col.len() {
            let label = label_col.value(i);
            assert!((0..10).contains(&label), "Label {} out of range", label);
        }
    }

    #[test]
    fn test_fashion_mnist_pixel_values_normalized() {
        // Pixel (0, 0) is never drawn by any pattern, so checking only the
        // first column would be vacuous; check every pixel column instead.
        let dataset = fashion_mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        for col_idx in 0..784 {
            let pixel_col = batch
                .column(col_idx)
                .as_any()
                .downcast_ref::<Float32Array>()
                .unwrap();
            for i in 0..pixel_col.len() {
                let val = pixel_col.value(i);
                assert!(
                    (0.0..=1.0).contains(&val),
                    "Pixel value {} in column {} out of range",
                    val,
                    col_idx
                );
            }
        }
    }

    #[test]
    fn test_fashion_mnist_clone() {
        let dataset = fashion_mnist().unwrap();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    #[test]
    fn test_fashion_mnist_debug() {
        let dataset = fashion_mnist().unwrap();
        let debug = format!("{:?}", dataset);
        assert!(debug.contains("FashionMnistDataset"));
    }

    #[test]
    fn test_embedded_fashion_mnist_sample() {
        let (pixels, labels) = embedded_fashion_mnist_sample();
        assert_eq!(pixels.len(), 100 * 784);
        assert_eq!(labels.len(), 100);
    }

    #[test]
    fn test_embedded_fashion_mnist_sample_labels_balanced() {
        let (_, labels) = embedded_fashion_mnist_sample();
        let mut counts = [0i32; 10];
        for label in labels {
            counts[usize::try_from(label).unwrap()] += 1;
        }
        for (class, &count) in counts.iter().enumerate() {
            assert_eq!(count, 10, "Class {} should have 10 samples", class);
        }
    }

    #[test]
    fn test_generate_fashion_pattern_all_classes() {
        for class in 0..10 {
            let pattern = generate_fashion_pattern(class, 0);
            assert_eq!(pattern.len(), 784, "Class {} pattern wrong size", class);
            let non_zero: usize = pattern.iter().filter(|&&p| p > 0.0).count();
            assert!(
                non_zero > 0,
                "Class {} pattern should have non-zero pixels",
                class
            );
        }
    }

    #[test]
    fn test_generate_fashion_pattern_classes_distinct() {
        let patterns: Vec<Vec<f32>> = (0..10).map(|c| generate_fashion_pattern(c, 0)).collect();
        for a in 0..patterns.len() {
            for b in (a + 1)..patterns.len() {
                assert_ne!(patterns[a], patterns[b], "Classes {} and {} render identically", a, b);
            }
        }
    }

    #[test]
    fn test_generate_fashion_pattern_with_variation() {
        let pattern1 = generate_fashion_pattern(0, 0);
        let pattern2 = generate_fashion_pattern(0, 5);
        let different = pattern1
            .iter()
            .zip(pattern2.iter())
            .any(|(a, b)| (a - b).abs() > 0.001);
        assert!(
            different,
            "Patterns with different variations should differ"
        );
    }

    #[test]
    fn test_generate_fashion_pattern_unknown() {
        let pattern = generate_fashion_pattern(99, 0);
        assert_eq!(pattern.len(), 784);
        let non_zero: usize = pattern.iter().filter(|&&p| p > 0.0).count();
        assert_eq!(non_zero, 0, "Unknown class should have all zeros");
    }

    #[test]
    fn test_set_pixel_in_bounds() {
        let mut img = vec![0.0f32; 784];
        set_pixel(&mut img, 14, 14, 1.0);
        assert_eq!(img[14 * 28 + 14], 1.0);
    }

    #[test]
    fn test_set_pixel_out_of_bounds() {
        let mut img = vec![0.0f32; 784];
        // x past the right edge, then y past the bottom edge: both ignored.
        set_pixel(&mut img, 30, 14, 1.0);
        set_pixel(&mut img, 14, 30, 1.0);
        let non_zero: usize = img.iter().filter(|&&p| p > 0.0).count();
        assert_eq!(non_zero, 0);
    }

    #[test]
    fn test_fashion_mnist_classes_constant() {
        assert_eq!(FASHION_MNIST_CLASSES.len(), 10);
        assert_eq!(FASHION_MNIST_CLASSES[0], "t-shirt/top");
        assert_eq!(FASHION_MNIST_CLASSES[9], "ankle boot");
    }
}