alimentar/datasets/
mod.rs1mod cifar10;
21mod cifar100;
22mod fashion_mnist;
23mod iris;
24mod mnist;
25
26pub use cifar10::{cifar10, Cifar10Dataset, CIFAR10_CLASSES};
27pub use cifar100::{cifar100, Cifar100Dataset, CIFAR100_COARSE_CLASSES, CIFAR100_FINE_CLASSES};
28pub use fashion_mnist::{fashion_mnist, FashionMnistDataset, FASHION_MNIST_CLASSES};
29pub use iris::{iris, IrisDataset};
30pub use mnist::{mnist, MnistDataset};
31
32use crate::{ArrowDataset, Dataset};
33
34pub trait CanonicalDataset {
36 fn data(&self) -> &ArrowDataset;
38
39 fn len(&self) -> usize {
41 self.data().len()
42 }
43
44 fn is_empty(&self) -> bool {
46 self.len() == 0
47 }
48
49 fn num_features(&self) -> usize;
51
52 fn num_classes(&self) -> usize;
54
55 fn feature_names(&self) -> &'static [&'static str];
57
58 fn target_name(&self) -> &'static str;
60
61 fn description(&self) -> &'static str;
63}
64
65#[derive(Debug, Clone)]
67pub struct DatasetSplit {
68 pub train: ArrowDataset,
70 pub test: ArrowDataset,
72}
73
74impl DatasetSplit {
75 pub fn new(train: ArrowDataset, test: ArrowDataset) -> Self {
77 Self { train, test }
78 }
79
80 pub fn train(&self) -> &ArrowDataset {
82 &self.train
83 }
84
85 pub fn test(&self) -> &ArrowDataset {
87 &self.test
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads iris and returns an owned copy of its backing dataset.
    ///
    /// Panics with "Should load iris" when loading fails; the error value
    /// itself is discarded by the `.ok()` step.
    fn iris_data() -> ArrowDataset {
        iris::iris()
            .ok()
            .unwrap_or_else(|| panic!("Should load iris"))
            .data()
            .clone()
    }

    /// `DatasetSplit::new` stores both datasets without changing their length.
    #[test]
    fn test_dataset_split_new() {
        let train = iris_data();
        let test = train.clone();
        // Record the expected sizes up front so the datasets can be moved
        // into the split instead of being cloned a second time.
        let (train_len, test_len) = (train.len(), test.len());

        let split = DatasetSplit::new(train, test);
        assert_eq!(split.train().len(), train_len);
        assert_eq!(split.test().len(), test_len);
    }

    /// The derived `Debug` output names the type.
    #[test]
    fn test_dataset_split_debug() {
        let train = iris_data();
        let test = train.clone();

        let split = DatasetSplit::new(train, test);
        let debug = format!("{:?}", split);
        assert!(debug.contains("DatasetSplit"));
    }

    /// A cloned split reports the same training length as its source.
    #[test]
    fn test_dataset_split_clone() {
        let train = iris_data();
        let test = train.clone();

        let split = DatasetSplit::new(train, test);
        let cloned = split.clone();
        assert_eq!(cloned.train().len(), split.train().len());
    }

    /// A freshly loaded iris dataset is not empty.
    #[test]
    fn test_canonical_dataset_is_empty() {
        let dataset = iris::iris()
            .ok()
            .unwrap_or_else(|| panic!("Should load iris"));
        assert!(!dataset.is_empty());
    }
}