mod cifar10;
mod cifar100;
mod fashion_mnist;
mod iris;
mod mnist;
pub use cifar10::{cifar10, Cifar10Dataset, CIFAR10_CLASSES};
pub use cifar100::{cifar100, Cifar100Dataset, CIFAR100_COARSE_CLASSES, CIFAR100_FINE_CLASSES};
pub use fashion_mnist::{fashion_mnist, FashionMnistDataset, FASHION_MNIST_CLASSES};
pub use iris::{iris, IrisDataset};
pub use mnist::{mnist, MnistDataset};
use crate::{ArrowDataset, Dataset};
/// Common interface for the bundled, well-known datasets (Iris, MNIST, CIFAR, ...).
///
/// Implementors expose their backing [`ArrowDataset`] plus descriptive metadata.
/// `len` and `is_empty` are provided in terms of [`CanonicalDataset::data`].
pub trait CanonicalDataset {
    /// Borrows the underlying Arrow-backed dataset.
    fn data(&self) -> &ArrowDataset;

    /// Number of features, as reported by the implementor.
    fn num_features(&self) -> usize;

    /// Number of target classes, as reported by the implementor.
    fn num_classes(&self) -> usize;

    /// Static list of feature names.
    fn feature_names(&self) -> &'static [&'static str];

    /// Name of the target column.
    fn target_name(&self) -> &'static str;

    /// Static human-readable description of the dataset.
    fn description(&self) -> &'static str;

    /// Number of rows, delegated to the length of [`CanonicalDataset::data`].
    fn len(&self) -> usize {
        let backing = self.data();
        backing.len()
    }

    /// Returns `true` exactly when [`CanonicalDataset::len`] is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A train/test pair of [`ArrowDataset`]s.
///
/// Both halves are public fields; the [`DatasetSplit::train`] and
/// [`DatasetSplit::test`] accessors borrow the same values.
#[derive(Clone, Debug)]
pub struct DatasetSplit {
    /// The training portion of the split.
    pub train: ArrowDataset,
    /// The test portion of the split.
    pub test: ArrowDataset,
}

impl DatasetSplit {
    /// Bundles the given training and test datasets, taking ownership of both.
    pub fn new(train: ArrowDataset, test: ArrowDataset) -> Self {
        DatasetSplit { train, test }
    }

    /// Borrows the training dataset.
    pub fn train(&self) -> &ArrowDataset {
        let DatasetSplit { train, .. } = self;
        train
    }

    /// Borrows the test dataset.
    pub fn test(&self) -> &ArrowDataset {
        let DatasetSplit { test, .. } = self;
        test
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the bundled Iris dataset, failing the calling test if loading errors.
    ///
    /// Keeps the `.ok().unwrap_or_else(|| panic!(..))` form used by the original
    /// tests instead of `expect`. NOTE(review): presumably this avoids a crate-level
    /// clippy lint on `expect`/`unwrap`; confirm before switching.
    fn load_iris() -> impl CanonicalDataset {
        iris::iris()
            .ok()
            .unwrap_or_else(|| panic!("Should load iris"))
    }

    /// Returns an owned copy of the Iris data, for building splits.
    fn iris_data() -> ArrowDataset {
        load_iris().data().clone()
    }

    #[test]
    fn test_dataset_split_new() {
        let train = iris_data();
        let test = train.clone();
        // Capture lengths up front so the datasets can be moved without re-cloning.
        let (train_len, test_len) = (train.len(), test.len());
        let split = DatasetSplit::new(train, test);
        assert_eq!(split.train().len(), train_len);
        assert_eq!(split.test().len(), test_len);
        // Accessors must expose the same data as the public fields.
        assert_eq!(split.train.len(), split.train().len());
        assert_eq!(split.test.len(), split.test().len());
    }

    #[test]
    fn test_dataset_split_debug() {
        let train = iris_data();
        let test = train.clone();
        let split = DatasetSplit::new(train, test);
        let debug = format!("{:?}", split);
        assert!(debug.contains("DatasetSplit"));
    }

    #[test]
    fn test_dataset_split_clone() {
        let train = iris_data();
        let test = train.clone();
        let split = DatasetSplit::new(train, test);
        let cloned = split.clone();
        assert_eq!(cloned.train().len(), split.train().len());
        assert_eq!(cloned.test().len(), split.test().len());
    }

    #[test]
    fn test_canonical_dataset_is_empty() {
        let iris = load_iris();
        assert!(!iris.is_empty());
        // The provided `len` must agree with the underlying Arrow data.
        assert_eq!(iris.len(), iris.data().len());
    }
}