use crate::cache::RegistryEntry;
use crate::error::{DatasetsError, Result};
use std::collections::HashMap;
/// Descriptive metadata for a dataset: its size, intended task and optional label names.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DatasetMetadata {
    /// Human-readable display name (e.g. `"Iris"`).
    pub name: String,
    /// Short free-text description of the dataset.
    pub description: String,
    /// Number of samples.
    pub n_samples: usize,
    /// Number of features per sample.
    pub n_features: usize,
    /// Task the dataset is intended for, e.g. `"classification"` or `"regression"`.
    pub task_type: String,
    /// Names of the target classes, when known.
    pub targetnames: Option<Vec<String>>,
    /// Names of the features, when known.
    pub featurenames: Option<Vec<String>>,
    /// Optional source URL.
    pub url: Option<String>,
    /// Optional checksum for the source data.
    pub checksum: Option<String>,
}
/// In-memory mapping from dataset name to its [`RegistryEntry`] (source URL and SHA-256).
pub struct DatasetRegistry {
    entries: HashMap<String, RegistryEntry>,
}

/// The default registry is pre-populated with the bundled example files and the built-in
/// toy datasets; use [`DatasetRegistry::new`] for an empty registry.
impl Default for DatasetRegistry {
    fn default() -> Self {
        let mut populated = DatasetRegistry::new();
        populated.populate_default_datasets();
        populated
    }
}
impl DatasetRegistry {
    /// Creates an empty registry with no datasets registered.
    ///
    /// Use [`DatasetRegistry::default`] for a registry pre-populated with the bundled
    /// example files and the built-in toy datasets.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Registers `entry` under `name`, replacing any entry previously stored under that name.
    pub fn register(&mut self, name: String, entry: RegistryEntry) {
        self.entries.insert(name, entry);
    }

    /// Returns the entry registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&RegistryEntry> {
        self.entries.get(name)
    }

    /// Returns the names of all registered datasets, sorted alphabetically.
    ///
    /// Entries are stored in a `HashMap`, whose iteration order is unspecified; sorting keeps
    /// the output stable across runs (for example in error messages that list the names).
    pub fn list_datasets(&self) -> Vec<String> {
        let mut names: Vec<String> = self.entries.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns `true` if a dataset is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.entries.contains_key(name)
    }

    /// Returns hard-coded descriptive metadata for one of the well-known toy datasets.
    ///
    /// The metadata does not consult the registry: names registered at runtime, including
    /// the default `"example"` and `"sample_data"` entries, are not covered.
    ///
    /// # Errors
    /// Returns [`DatasetsError::Other`] for any name other than `iris`, `boston`, `digits`,
    /// `wine`, `breast_cancer` or `diabetes`.
    pub fn get_metadata(&self, name: &str) -> Result<DatasetMetadata> {
        /// Converts a list of string literals into the owned form stored in the metadata.
        fn strings(items: &[&str]) -> Option<Vec<String>> {
            Some(items.iter().map(|s| (*s).to_string()).collect())
        }

        // NOTE(review): the sample/feature counts below differ from the `BUILTIN_TABLE`
        // descriptions for `breast_cancer` and `digits`; confirm which matches the loaders.
        let metadata = match name {
            "iris" => DatasetMetadata {
                name: "Iris".to_string(),
                description: "Classic iris flower dataset for classification".to_string(),
                n_samples: 150,
                n_features: 4,
                task_type: "classification".to_string(),
                targetnames: strings(&["setosa", "versicolor", "virginica"]),
                featurenames: strings(&[
                    "sepal_length",
                    "sepal_width",
                    "petal_length",
                    "petal_width",
                ]),
                ..Default::default()
            },
            "boston" => DatasetMetadata {
                name: "Boston Housing".to_string(),
                description: "Boston housing prices dataset for regression".to_string(),
                n_samples: 506,
                n_features: 13,
                task_type: "regression".to_string(),
                ..Default::default()
            },
            "digits" => DatasetMetadata {
                name: "Digits".to_string(),
                description: "Hand-written digits dataset for image classification".to_string(),
                n_samples: 1797,
                n_features: 64,
                task_type: "classification".to_string(),
                // Class labels are the digits "0" through "9".
                targetnames: Some((0..10).map(|digit| digit.to_string()).collect()),
                ..Default::default()
            },
            "wine" => DatasetMetadata {
                name: "Wine".to_string(),
                description: "Wine recognition dataset for classification".to_string(),
                n_samples: 178,
                n_features: 13,
                task_type: "classification".to_string(),
                targetnames: strings(&["class_0", "class_1", "class_2"]),
                ..Default::default()
            },
            "breast_cancer" => DatasetMetadata {
                name: "Breast Cancer".to_string(),
                description: "Breast cancer wisconsin dataset for classification".to_string(),
                n_samples: 569,
                n_features: 30,
                task_type: "classification".to_string(),
                targetnames: strings(&["malignant", "benign"]),
                ..Default::default()
            },
            "diabetes" => DatasetMetadata {
                name: "Diabetes".to_string(),
                description: "Diabetes dataset for regression".to_string(),
                n_samples: 442,
                n_features: 10,
                task_type: "regression".to_string(),
                ..Default::default()
            },
            _ => return Err(DatasetsError::Other(format!("Unknown dataset: {name}"))),
        };
        Ok(metadata)
    }

    /// Registers the bundled CSV files and the built-in toy datasets.
    fn populate_default_datasets(&mut self) {
        self.register(
            "example".to_string(),
            RegistryEntry {
                url: "file://data/example.csv",
                sha256: "c51c3ff2e8a5db28b1baed809a2ba29f29643e5a26ad476448eb3889996173d6",
            },
        );
        self.register(
            "sample_data".to_string(),
            RegistryEntry {
                url: "file://examples/sample_data.csv",
                sha256: "59cceb2c80692ee2c1c3b607335d1feb983ceed24214d1ffc2eace9f3ce5ab47",
            },
        );

        // Built-in datasets use a `builtin://` URL and the `"builtin"` sentinel in place of
        // a real checksum; `load_local_dataset` skips verification for that sentinel.
        const TOY_DATASETS: [(&str, &str); 6] = [
            ("iris", "builtin://iris"),
            ("boston", "builtin://boston"),
            ("digits", "builtin://digits"),
            ("wine", "builtin://wine"),
            ("breast_cancer", "builtin://breast_cancer"),
            ("diabetes", "builtin://diabetes"),
        ];
        for &(name, url) in TOY_DATASETS.iter() {
            self.register(
                name.to_string(),
                RegistryEntry {
                    url,
                    sha256: "builtin",
                },
            );
        }
    }
}
#[allow(dead_code)]
pub fn get_registry() -> DatasetRegistry {
DatasetRegistry::default()
}
/// Signature shared by every built-in dataset loader.
type BuiltinLoader = fn() -> Result<crate::utils::Dataset>;

/// One row of the built-in dataset table: a lookup name, a one-line description and the
/// function that loads the dataset.
struct BuiltinEntry {
    /// Name used for lookup by `load_dataset_by_name`.
    name: &'static str,
    /// Human-readable summary of the dataset.
    description: &'static str,
    /// Function that produces the dataset.
    loader: BuiltinLoader,
}
/// Loads the wine dataset via `crate::standard::load_wine` and repackages it as a
/// [`crate::utils::Dataset`], carrying over feature names, target names and description.
///
/// # Errors
/// Propagates any error returned by `crate::standard::load_wine`.
fn wine_loader() -> Result<crate::utils::Dataset> {
    let raw = crate::standard::load_wine()?;
    let dataset = crate::utils::Dataset::new(raw.data, Some(raw.target))
        .with_featurenames(raw.feature_names)
        .with_targetnames(raw.target_names)
        .with_description(raw.description);
    Ok(dataset)
}
/// Built-in datasets served by [`load_dataset_by_name`]. [`list_datasets`] reports the names
/// in this order, so the order of the entries is observable to callers.
// NOTE(review): the sample/feature counts in the `breast_cancer` (30 x 5) and `digits`
// (50 x 16) descriptions differ from `DatasetRegistry::get_metadata` (569 x 30 and 1797 x 64);
// presumably these loaders return reduced versions — confirm against `crate::toy`.
static BUILTIN_TABLE: &[BuiltinEntry] = &[
BuiltinEntry {
name: "iris",
description: "Classic iris flower dataset for classification (150 samples, 4 features)",
loader: crate::toy::load_iris,
},
// Wine goes through a local adapter that converts the `crate::standard` result type.
BuiltinEntry {
name: "wine",
description: "Wine recognition dataset for classification (178 samples, 13 features)",
loader: wine_loader,
},
BuiltinEntry {
name: "breast_cancer",
description: "Breast cancer Wisconsin dataset for classification (30 samples, 5 features)",
loader: crate::toy::load_breast_cancer,
},
BuiltinEntry {
name: "boston",
description: "Boston housing prices dataset for regression (506 samples, 13 features)",
loader: crate::toy::load_boston,
},
BuiltinEntry {
name: "diabetes",
description: "Diabetes dataset for regression (442 samples, 10 features)",
loader: crate::toy::load_diabetes,
},
BuiltinEntry {
name: "digits",
description: "Handwritten digits dataset for classification (50 samples, 16 features)",
loader: crate::toy::load_digits,
},
];
pub fn list_datasets() -> Vec<&'static str> {
BUILTIN_TABLE.iter().map(|e| e.name).collect()
}
/// Loads one of the built-in datasets (see [`list_datasets`]) by name.
///
/// # Errors
/// Returns [`DatasetsError::NotFound`], listing the available names, when `name` is not a
/// built-in dataset; otherwise propagates any error from that dataset's loader.
pub fn load_dataset_by_name(name: &str) -> Result<crate::utils::Dataset> {
    match BUILTIN_TABLE.iter().find(|entry| entry.name == name) {
        Some(entry) => (entry.loader)(),
        None => Err(DatasetsError::NotFound(format!(
            "Unknown dataset '{}'. Available: {:?}",
            name,
            list_datasets()
        ))),
    }
}
/// Loads a dataset through the default [`DatasetRegistry`], dispatching on the URL scheme of
/// its registry entry.
///
/// * `builtin://` entries are served by [`load_dataset_by_name`], the same table used by the
///   non-registry API, so both entry points return identical data.
/// * `file://` entries are read as CSV from a path relative to the crate root and checked
///   against the registered SHA-256.
/// * URLs starting with `http` are routed to a matching remote loader; only
///   `california_housing` uses `forcedownload`.
///
/// # Errors
/// Returns an error if `name` is not registered, its URL scheme is unsupported, no loader
/// exists for a remote entry, or the underlying loader fails.
#[cfg(feature = "download")]
// Public API with no in-crate callers.
#[allow(dead_code)]
pub fn load_dataset_byname(name: &str, forcedownload: bool) -> Result<crate::utils::Dataset> {
    let registry = get_registry();
    let entry = registry.get(name).ok_or_else(|| {
        DatasetsError::Other(format!(
            "Unknown dataset: '{}'. Available datasets: {:?}",
            name,
            registry.list_datasets()
        ))
    })?;

    if entry.url.starts_with("builtin://") {
        // Single source of truth for built-in loaders: the same table `load_dataset_by_name`
        // uses, instead of a second hand-maintained dispatch that can drift out of sync.
        load_dataset_by_name(name)
    } else if let Some(relativepath) = entry.url.strip_prefix("file://") {
        load_local_dataset(name, relativepath, entry.sha256)
    } else if entry.url.starts_with("http") {
        match name {
            "california_housing" => crate::sample::load_california_housing(forcedownload),
            "electrocardiogram" => crate::time_series::electrocardiogram(),
            "stock_market" => crate::time_series::stock_market(false),
            "weather" => crate::time_series::weather(None),
            _ => Err(DatasetsError::Other(format!(
                "Remote dataset '{}' not yet implemented for loading",
                name
            ))),
        }
    } else {
        Err(DatasetsError::Other(format!(
            "Unsupported URL scheme for dataset '{}': {}",
            name, entry.url
        )))
    }
}
/// Loads a CSV dataset whose path is relative to this crate's manifest directory.
///
/// Unless `expected_sha256` is the `"builtin"` sentinel, the file's SHA-256 is computed and
/// must equal `expected_sha256` before the file is parsed.
///
/// # Errors
/// Returns an error if the file does not exist, its hash cannot be computed, the hash does
/// not match, or the CSV cannot be loaded.
#[cfg(feature = "download")]
// Only reachable through the feature-gated `load_dataset_byname`.
#[allow(dead_code)]
fn load_local_dataset(
    name: &str,
    relativepath: &str,
    expected_sha256: &str,
) -> Result<crate::utils::Dataset> {
    use crate::loaders::{load_csv, CsvConfig};
    use std::path::Path;
    // Resolved at compile time: paths are relative to this crate's source tree.
    let workspace_root = env!("CARGO_MANIFEST_DIR");
    let filepath = Path::new(workspace_root).join(relativepath);
    if !filepath.exists() {
        return Err(DatasetsError::Other(format!(
            "Local dataset file not found: {}",
            filepath.display()
        )));
    }
    if expected_sha256 != "builtin" {
        // A failure to hash must not be treated as a successful verification; previously
        // the error was discarded and the unverified file was loaded anyway.
        let actual_hash = crate::cache::sha256_hash_file(&filepath).map_err(|e| {
            DatasetsError::Other(format!(
                "Failed to hash local dataset '{}' at {}: {}",
                name,
                filepath.display(),
                e
            ))
        })?;
        if actual_hash != expected_sha256 {
            return Err(DatasetsError::Other(format!(
                "Hash verification failed for dataset '{}'. Expected: {}, Got: {}",
                name, expected_sha256, actual_hash
            )));
        }
    }
    let config = CsvConfig::default().with_header(true);
    let dataset = load_csv(&filepath, config)?;
    Ok(dataset.with_description(format!("Local dataset: {}", name)))
}
/// Fallback used when the crate is built without the `download` feature.
///
/// # Errors
/// Always returns [`DatasetsError::Other`] telling the caller how to enable the feature.
#[cfg(not(feature = "download"))]
// Public API with no in-crate callers.
#[allow(dead_code)]
pub fn load_dataset_byname(_name: &str, _forcedownload: bool) -> Result<crate::utils::Dataset> {
    // The feature is named `download` (see the `cfg` above); the old hint pointed at a
    // nonexistent `_download` feature.
    Err(DatasetsError::Other(
        "Download feature is not enabled. Recompile with --features download".to_string(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_creation() {
        let registry = DatasetRegistry::new();
        assert!(registry.entries.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DatasetRegistry::default();
        assert!(!registry.entries.is_empty());
        for name in [
            "example",
            "sample_data",
            "iris",
            "boston",
            "wine",
            "digits",
            "breast_cancer",
            "diabetes",
        ]
        .iter()
        {
            assert!(registry.contains(name), "default registry should contain '{name}'");
        }
    }

    #[test]
    fn test_registry_operations() {
        let mut registry = DatasetRegistry::new();
        let entry = RegistryEntry {
            url: "https://example.com/test.csv",
            sha256: "abcd1234",
        };
        registry.register("test_dataset".to_string(), entry);
        assert!(registry.contains("test_dataset"));
        assert!(!registry.contains("nonexistent"));
        let retrieved = registry
            .get("test_dataset")
            .expect("test_dataset should be registered");
        assert_eq!(retrieved.url, "https://example.com/test.csv");
        assert_eq!(retrieved.sha256, "abcd1234");
        let datasets = registry.list_datasets();
        assert_eq!(datasets.len(), 1);
        assert!(datasets.contains(&"test_dataset".to_string()));
    }

    #[test]
    fn test_get_registry() {
        let registry = get_registry();
        assert!(!registry.list_datasets().is_empty());
    }

    #[test]
    fn test_registry_url_schemes() {
        let registry = DatasetRegistry::default();
        // `expect` instead of `if let`: a missing entry must fail the test, not skip it.
        let iris_entry = registry.get("iris").expect("iris should be registered");
        assert_eq!(iris_entry.url, "builtin://iris");
        assert_eq!(iris_entry.sha256, "builtin");
        let example_entry = registry
            .get("example")
            .expect("example should be registered");
        assert_eq!(example_entry.url, "file://data/example.csv");
        assert_eq!(
            example_entry.sha256,
            "c51c3ff2e8a5db28b1baed809a2ba29f29643e5a26ad476448eb3889996173d6"
        );
    }

    #[test]
    fn test_dataset_count() {
        let registry = DatasetRegistry::default();
        let datasets = registry.list_datasets();
        assert_eq!(datasets.len(), 8);
        let expected_datasets = [
            "example",
            "sample_data",
            "iris",
            "boston",
            "digits",
            "wine",
            "breast_cancer",
            "diabetes",
        ];
        for expected in expected_datasets.iter() {
            assert!(
                datasets.contains(&expected.to_string()),
                "Dataset '{expected}' not found in registry"
            );
        }
    }

    #[test]
    fn test_list_datasets_contains_iris_and_wine() {
        let names = super::list_datasets();
        assert!(names.contains(&"iris"), "list_datasets() should include 'iris'");
        assert!(names.contains(&"wine"), "list_datasets() should include 'wine'");
    }

    #[test]
    fn test_list_datasets_minimum_count() {
        let names = super::list_datasets();
        assert!(
            names.len() >= 5,
            "Expected at least 5 built-in datasets, got {}",
            names.len()
        );
    }

    #[test]
    fn test_load_dataset_by_name_iris_succeeds() {
        let result = super::load_dataset_by_name("iris");
        assert!(result.is_ok(), "load_dataset_by_name('iris') should succeed");
        let ds = result.expect("iris loaded");
        assert!(ds.n_samples() > 0, "iris should have at least one sample");
    }

    #[test]
    fn test_load_dataset_by_name_unknown_returns_err() {
        let result = super::load_dataset_by_name("unknown_xyz");
        assert!(
            result.is_err(),
            "load_dataset_by_name('unknown_xyz') should return Err"
        );
    }

    #[test]
    fn test_load_dataset_by_name_iris_feature_count() {
        let ds = super::load_dataset_by_name("iris").expect("iris loaded");
        assert_eq!(
            ds.n_features(),
            4,
            "iris should have 4 features, got {}",
            ds.n_features()
        );
    }

    #[test]
    fn test_load_dataset_by_name_wine_roundtrip() {
        let ds = super::load_dataset_by_name("wine").expect("wine loaded");
        assert!(ds.n_samples() > 0, "wine should have samples");
        assert_eq!(
            ds.n_features(),
            13,
            "wine should have 13 features, got {}",
            ds.n_features()
        );
    }

    #[test]
    fn test_load_dataset_by_name_all_known_succeed() {
        for name in super::list_datasets() {
            let result = super::load_dataset_by_name(name);
            assert!(
                result.is_ok(),
                "load_dataset_by_name('{name}') failed: {:?}",
                result.err()
            );
        }
    }
}