Skip to main content

alimentar/datasets/
mod.rs

//! Canonical ML dataset loaders
//!
//! Provides convenient one-liner access to well-known ML datasets
//! for tutorials, examples, and benchmarking.
//!
//! # Example
//!
//! ```ignore
//! use alimentar::datasets::{iris, mnist, cifar10};
//!
//! // Load the Iris dataset (embedded, no download)
//! let iris = iris()?;
//! println!("Iris: {} samples", iris.len());
//!
//! // Load MNIST (downloads from the Hugging Face Hub on first use)
//! let mnist = mnist()?;
//! let (train, test) = mnist.split()?;
//! ```
19
20mod cifar10;
21mod cifar100;
22mod fashion_mnist;
23mod iris;
24mod mnist;
25
26pub use cifar10::{cifar10, Cifar10Dataset, CIFAR10_CLASSES};
27pub use cifar100::{cifar100, Cifar100Dataset, CIFAR100_COARSE_CLASSES, CIFAR100_FINE_CLASSES};
28pub use fashion_mnist::{fashion_mnist, FashionMnistDataset, FASHION_MNIST_CLASSES};
29pub use iris::{iris, IrisDataset};
30pub use mnist::{mnist, MnistDataset};
31
32use crate::{ArrowDataset, Dataset};
33
34/// A canonical ML dataset with train/test split support
35pub trait CanonicalDataset {
36    /// Returns the full dataset
37    fn data(&self) -> &ArrowDataset;
38
39    /// Returns the number of samples
40    fn len(&self) -> usize {
41        self.data().len()
42    }
43
44    /// Returns true if the dataset is empty
45    fn is_empty(&self) -> bool {
46        self.len() == 0
47    }
48
49    /// Returns the number of features (excluding label)
50    fn num_features(&self) -> usize;
51
52    /// Returns the number of classes (for classification datasets)
53    fn num_classes(&self) -> usize;
54
55    /// Returns the feature column names
56    fn feature_names(&self) -> &'static [&'static str];
57
58    /// Returns the label/target column name
59    fn target_name(&self) -> &'static str;
60
61    /// Returns a description of the dataset
62    fn description(&self) -> &'static str;
63}
64
65/// Split information for train/test datasets
66#[derive(Debug, Clone)]
67pub struct DatasetSplit {
68    /// Training dataset
69    pub train: ArrowDataset,
70    /// Test dataset
71    pub test: ArrowDataset,
72}
73
74impl DatasetSplit {
75    /// Create a new dataset split
76    pub fn new(train: ArrowDataset, test: ArrowDataset) -> Self {
77        Self { train, test }
78    }
79
80    /// Get training data
81    pub fn train(&self) -> &ArrowDataset {
82        &self.train
83    }
84
85    /// Get test data
86    pub fn test(&self) -> &ArrowDataset {
87        &self.test
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the Iris dataset and returns an owned clone of its backing data.
    ///
    /// Panics with "Should load iris" when loading fails. The error value is
    /// dropped via `.ok()`, exactly as the individual tests did before this
    /// helper was factored out.
    /// NOTE(review): consider surfacing the error in the panic message if the
    /// loader's error type implements `Debug` — not visible from this file.
    fn iris_data() -> ArrowDataset {
        iris::iris()
            .ok()
            .unwrap_or_else(|| panic!("Should load iris"))
            .data()
            .clone()
    }

    #[test]
    fn test_dataset_split_new() {
        let train = iris_data();
        let test = train.clone();
        // Record the expected sizes up front so the datasets can be moved into
        // the split without extra clones.
        let (train_len, test_len) = (train.len(), test.len());

        let split = DatasetSplit::new(train, test);
        assert_eq!(split.train().len(), train_len);
        assert_eq!(split.test().len(), test_len);
    }

    #[test]
    fn test_dataset_split_debug() {
        let train = iris_data();
        let test = train.clone();

        let split = DatasetSplit::new(train, test);
        let debug = format!("{:?}", split);
        assert!(debug.contains("DatasetSplit"));
    }

    #[test]
    fn test_dataset_split_clone() {
        let train = iris_data();
        let test = train.clone();

        let split = DatasetSplit::new(train, test);
        let cloned = split.clone();
        assert_eq!(cloned.train().len(), split.train().len());
    }

    #[test]
    fn test_canonical_dataset_is_empty() {
        let iris = iris::iris()
            .ok()
            .unwrap_or_else(|| panic!("Should load iris"));
        assert!(!iris.is_empty());
    }
}
140            .ok()
141            .unwrap_or_else(|| panic!("Should load iris"));
142        assert!(!iris.is_empty());
143    }
144}