1use std::sync::Arc;
7
8use arrow::{
9 array::{Float32Array, Int32Array, RecordBatch},
10 datatypes::{DataType, Field, Schema},
11};
12
13use super::CanonicalDataset;
14use crate::{split::DatasetSplit, ArrowDataset, Result};
15
16pub fn mnist() -> Result<MnistDataset> {
22 MnistDataset::load()
23}
24
25#[derive(Debug, Clone)]
27pub struct MnistDataset {
28 data: ArrowDataset,
29}
30
31impl MnistDataset {
32 pub fn load() -> Result<Self> {
38 let mut fields: Vec<Field> = (0..784)
40 .map(|i| Field::new(format!("pixel_{i}"), DataType::Float32, false))
41 .collect();
42 fields.push(Field::new("label", DataType::Int32, false));
43 let schema = Arc::new(Schema::new(fields));
44
45 let (pixels, labels) = embedded_mnist_sample();
48
49 let num_samples = labels.len();
50 let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(785);
51
52 for pixel_idx in 0..784 {
53 let pixel_data: Vec<f32> = (0..num_samples)
54 .map(|s| pixels[s * 784 + pixel_idx])
55 .collect();
56 columns.push(Arc::new(Float32Array::from(pixel_data)));
57 }
58 columns.push(Arc::new(Int32Array::from(labels)));
59
60 let batch = RecordBatch::try_new(schema, columns).map_err(crate::Error::Arrow)?;
61 let data = ArrowDataset::from_batch(batch)?;
62
63 Ok(Self { data })
64 }
65
66 #[cfg(feature = "hf-hub")]
68 pub fn load_full() -> Result<Self> {
69 use crate::hf_hub::HfDataset;
70 let hf = HfDataset::builder("ylecun/mnist").split("train").build()?;
71 let data = hf.download()?;
72 Ok(Self { data })
73 }
74
75 pub fn split(&self) -> Result<DatasetSplit> {
85 DatasetSplit::stratified(
88 &self.data,
89 "label", 0.8, 0.2, None, Some(42), )
95 }
96}
97
98impl CanonicalDataset for MnistDataset {
99 fn data(&self) -> &ArrowDataset {
100 &self.data
101 }
102 fn num_features(&self) -> usize {
103 784
104 }
105 fn num_classes(&self) -> usize {
106 10
107 }
108 fn feature_names(&self) -> &'static [&'static str] {
109 &[]
110 }
111 fn target_name(&self) -> &'static str {
112 "label"
113 }
114 fn description(&self) -> &'static str {
115 "MNIST handwritten digits (LeCun 1998). Embedded: 100 samples. Full: 70k (requires hf-hub)."
116 }
117}
118
119fn embedded_mnist_sample() -> (Vec<f32>, Vec<i32>) {
121 let mut pixels = Vec::with_capacity(100 * 784);
124 let mut labels = Vec::with_capacity(100);
125
126 for digit in 0..10 {
127 for _ in 0..10 {
128 let pattern = generate_digit_pattern(digit);
130 pixels.extend(pattern);
131 labels.push(digit);
132 }
133 }
134
135 (pixels, labels)
136}
137
138fn generate_digit_pattern(digit: i32) -> Vec<f32> {
140 let mut img = vec![0.0f32; 784]; match digit {
144 0 => draw_oval(&mut img),
145 1 => draw_vertical_line(&mut img),
146 2 => draw_two(&mut img),
147 3 => draw_three(&mut img),
148 4 => draw_four(&mut img),
149 5 => draw_five(&mut img),
150 6 => draw_six(&mut img),
151 7 => draw_seven(&mut img),
152 8 => draw_eight(&mut img),
153 9 => draw_nine(&mut img),
154 _ => {}
155 }
156
157 img
158}
159
/// Writes `val` at column `x`, row `y` of a row-major 28x28 image buffer.
///
/// The write is skipped when `x` or `y` is outside `0..28`. It is also
/// skipped when the addressed cell lies beyond the end of `img`, so a buffer
/// shorter than 784 elements is ignored rather than causing a panic.
fn set_pixel(img: &mut [f32], x: usize, y: usize, val: f32) {
    const SIDE: usize = 28;

    if x >= SIDE || y >= SIDE {
        return;
    }
    // `get_mut` keeps the "silently ignore" contract for undersized buffers.
    if let Some(cell) = img.get_mut(y * SIDE + x) {
        *cell = val;
    }
}
165
166fn draw_oval(img: &mut [f32]) {
167 draw_oval_top_bottom(img);
168 draw_oval_sides(img);
169}
170
171fn draw_oval_top_bottom(img: &mut [f32]) {
172 for x in 10..18 {
173 set_pixel(img, x, 6, 1.0);
174 set_pixel(img, x, 21, 1.0);
175 }
176}
177
178fn draw_oval_sides(img: &mut [f32]) {
179 for y in 8..20 {
180 set_pixel(img, 8, y, 1.0);
181 set_pixel(img, 19, y, 1.0);
182 }
183}
184
185fn draw_vertical_line(img: &mut [f32]) {
186 for y in 5..23 {
187 set_pixel(img, 14, y, 1.0);
188 }
189}
190
191fn draw_two(img: &mut [f32]) {
192 for x in 8..20 {
193 set_pixel(img, x, 6, 1.0);
194 set_pixel(img, x, 14, 1.0);
195 set_pixel(img, x, 22, 1.0);
196 }
197 for y in 6..14 {
198 set_pixel(img, 19, y, 1.0);
199 }
200 for y in 14..22 {
201 set_pixel(img, 8, y, 1.0);
202 }
203}
204
205fn draw_three(img: &mut [f32]) {
206 for x in 8..20 {
207 set_pixel(img, x, 6, 1.0);
208 set_pixel(img, x, 14, 1.0);
209 set_pixel(img, x, 22, 1.0);
210 }
211 for y in 6..22 {
212 set_pixel(img, 19, y, 1.0);
213 }
214}
215
216fn draw_four(img: &mut [f32]) {
217 for y in 6..15 {
218 set_pixel(img, 8, y, 1.0);
219 }
220 for x in 8..20 {
221 set_pixel(img, x, 14, 1.0);
222 }
223 for y in 6..22 {
224 set_pixel(img, 18, y, 1.0);
225 }
226}
227
228fn draw_five(img: &mut [f32]) {
229 for x in 8..20 {
230 set_pixel(img, x, 6, 1.0);
231 set_pixel(img, x, 14, 1.0);
232 set_pixel(img, x, 22, 1.0);
233 }
234 for y in 6..14 {
235 set_pixel(img, 8, y, 1.0);
236 }
237 for y in 14..22 {
238 set_pixel(img, 19, y, 1.0);
239 }
240}
241
242fn draw_six(img: &mut [f32]) {
243 for x in 8..20 {
244 set_pixel(img, x, 6, 1.0);
245 set_pixel(img, x, 14, 1.0);
246 set_pixel(img, x, 22, 1.0);
247 }
248 for y in 6..22 {
249 set_pixel(img, 8, y, 1.0);
250 }
251 for y in 14..22 {
252 set_pixel(img, 19, y, 1.0);
253 }
254}
255
256fn draw_seven(img: &mut [f32]) {
257 for x in 8..20 {
258 set_pixel(img, x, 6, 1.0);
259 }
260 for y in 6..22 {
261 set_pixel(img, 19, y, 1.0);
262 }
263}
264
265fn draw_eight(img: &mut [f32]) {
266 draw_oval(img);
267 for x in 8..20 {
268 set_pixel(img, x, 14, 1.0);
269 }
270}
271
272fn draw_nine(img: &mut [f32]) {
273 for x in 8..20 {
274 set_pixel(img, x, 6, 1.0);
275 set_pixel(img, x, 14, 1.0);
276 set_pixel(img, x, 22, 1.0);
277 }
278 for y in 6..14 {
279 set_pixel(img, 8, y, 1.0);
280 }
281 for y in 6..22 {
282 set_pixel(img, 19, y, 1.0);
283 }
284}
285
#[cfg(test)]
mod tests {
    use arrow::array::Float32Array;

    use super::*;
    use crate::Dataset;

    /// Index of the `label` column; it follows the 784 pixel columns.
    const LABEL_COLUMN: usize = 784;

    /// Copies the `label` column of `batch` into a plain vector.
    fn label_values(batch: &RecordBatch) -> Vec<i32> {
        let labels = batch
            .column(LABEL_COLUMN)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        (0..labels.len()).map(|i| labels.value(i)).collect()
    }

    /// Counts the strictly positive entries of an image buffer.
    fn lit_pixels(img: &[f32]) -> usize {
        img.iter().filter(|&&p| p > 0.0).count()
    }

    #[test]
    fn test_mnist_load() {
        let dataset = mnist().unwrap();
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 10);
    }

    #[test]
    fn test_mnist_split() {
        let dataset = mnist().unwrap();
        let split = dataset.split().unwrap();
        assert_eq!(split.train.len(), 80);
        assert_eq!(split.test.len(), 20);
    }

    #[test]
    fn test_mnist_num_features() {
        let dataset = mnist().unwrap();
        assert_eq!(dataset.num_features(), 784);
    }

    #[test]
    fn test_mnist_feature_names() {
        let dataset = mnist().unwrap();
        assert!(dataset.feature_names().is_empty());
    }

    #[test]
    fn test_mnist_target_name() {
        let dataset = mnist().unwrap();
        assert_eq!(dataset.target_name(), "label");
    }

    #[test]
    fn test_mnist_description() {
        let dataset = mnist().unwrap();
        let desc = dataset.description();
        assert!(desc.contains("MNIST"));
        assert!(desc.contains("LeCun"));
    }

    #[test]
    fn test_mnist_data_access() {
        let dataset = mnist().unwrap();
        let data = dataset.data();
        assert_eq!(data.len(), 100);
    }

    #[test]
    fn test_mnist_schema_columns() {
        let dataset = mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        // 784 pixel columns plus the label column.
        assert_eq!(batch.num_columns(), 785);
    }

    #[test]
    fn test_mnist_labels_in_range() {
        let dataset = mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        for label in label_values(&batch) {
            assert!((0..10).contains(&label), "Label {} out of range", label);
        }
    }

    #[test]
    fn test_mnist_pixel_values_normalized() {
        let dataset = mnist().unwrap();
        let batch = dataset.data().get_batch(0).unwrap();
        let pixel_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        for i in 0..pixel_col.len() {
            let val = pixel_col.value(i);
            assert!(
                (0.0..=1.0).contains(&val),
                "Pixel value {} out of range",
                val
            );
        }
    }

    #[test]
    fn test_mnist_clone() {
        let dataset = mnist().unwrap();
        let cloned = dataset.clone();
        assert_eq!(cloned.len(), dataset.len());
    }

    #[test]
    fn test_mnist_debug() {
        let dataset = mnist().unwrap();
        let debug = format!("{:?}", dataset);
        assert!(debug.contains("MnistDataset"));
    }

    #[test]
    fn test_embedded_mnist_sample() {
        let (pixels, labels) = embedded_mnist_sample();
        assert_eq!(pixels.len(), 100 * 784);
        assert_eq!(labels.len(), 100);
    }

    #[test]
    fn test_embedded_mnist_sample_labels_balanced() {
        let (_, labels) = embedded_mnist_sample();
        let mut counts = [0i32; 10];
        for label in labels {
            counts[usize::try_from(label).unwrap()] += 1;
        }
        for (digit, &count) in counts.iter().enumerate() {
            assert_eq!(count, 10, "Digit {} should have 10 samples", digit);
        }
    }

    #[test]
    fn test_generate_digit_pattern_0() {
        let pattern = generate_digit_pattern(0);
        assert_eq!(pattern.len(), 784);
        assert!(
            lit_pixels(&pattern) > 0,
            "Digit 0 pattern should have non-zero pixels"
        );
    }

    #[test]
    fn test_generate_digit_pattern_1() {
        let pattern = generate_digit_pattern(1);
        assert_eq!(pattern.len(), 784);
        assert!(
            lit_pixels(&pattern) > 0,
            "Digit 1 pattern should have non-zero pixels"
        );
    }

    #[test]
    fn test_generate_digit_patterns_all() {
        for digit in 0..10 {
            let pattern = generate_digit_pattern(digit);
            assert_eq!(pattern.len(), 784, "Digit {} pattern wrong size", digit);
            assert!(
                lit_pixels(&pattern) > 0,
                "Digit {} pattern should have non-zero pixels",
                digit
            );
        }
    }

    #[test]
    fn test_generate_digit_pattern_unknown() {
        let pattern = generate_digit_pattern(99);
        assert_eq!(pattern.len(), 784);
        assert_eq!(
            lit_pixels(&pattern),
            0,
            "Unknown digit should have all zeros"
        );
    }

    #[test]
    fn test_set_pixel_in_bounds() {
        let mut img = vec![0.0f32; 784];
        set_pixel(&mut img, 14, 14, 1.0);
        assert_eq!(img[14 * 28 + 14], 1.0);
    }

    #[test]
    fn test_set_pixel_out_of_bounds() {
        let mut img = vec![0.0f32; 784];
        // Neither an out-of-range column nor an out-of-range row may write.
        set_pixel(&mut img, 30, 14, 1.0);
        set_pixel(&mut img, 14, 30, 1.0);
        assert_eq!(lit_pixels(&img), 0);
    }

    /// Both halves of the split must contain every digit class.
    #[test]
    fn test_mnist_split_is_stratified() {
        use std::collections::HashSet;

        let dataset = mnist().unwrap();
        let split = dataset.split().unwrap();

        let train_batch = split.train.get_batch(0).unwrap();
        let train_label_set: HashSet<i32> = label_values(&train_batch).into_iter().collect();

        let test_batch = split.test.get_batch(0).unwrap();
        let test_label_set: HashSet<i32> = label_values(&test_batch).into_iter().collect();

        assert_eq!(
            train_label_set.len(),
            10,
            "Train set must contain all 10 digit classes, got {:?}",
            train_label_set
        );
        assert_eq!(
            test_label_set.len(),
            10,
            "Test set must contain all 10 digit classes, got {:?}",
            test_label_set
        );

        for digit in 0..10 {
            assert!(
                train_label_set.contains(&digit),
                "Train set missing digit {}",
                digit
            );
            assert!(
                test_label_set.contains(&digit),
                "Test set missing digit {}",
                digit
            );
        }
    }

    /// Each digit should keep roughly 8 of its 10 samples in the train split.
    #[test]
    fn test_mnist_split_maintains_class_balance() {
        let dataset = mnist().unwrap();
        let split = dataset.split().unwrap();

        let train_batch = split.train.get_batch(0).unwrap();

        let mut train_counts = [0usize; 10];
        for label in label_values(&train_batch) {
            // Checked conversion instead of an `as` cast; labels outside
            // `0..10` are skipped, exactly as before.
            match usize::try_from(label) {
                Ok(idx) if idx < train_counts.len() => train_counts[idx] += 1,
                _ => {}
            }
        }

        for (digit, &count) in train_counts.iter().enumerate() {
            assert!(
                (7..=9).contains(&count),
                "Digit {} has {} training samples, expected ~8",
                digit,
                count
            );
        }
    }
}