use crate::error::DataError;
use scirs2_core::Distribution; use torsh_tensor::Tensor;
use scirs2_datasets::toy::{
load_boston as scirs2_load_boston, load_breast_cancer as scirs2_load_breast_cancer,
load_diabetes as scirs2_load_diabetes, load_digits as scirs2_load_digits,
load_iris as scirs2_load_iris,
};
/// Identifiers for the builtin toy datasets accepted by [`load_builtin_dataset`].
#[derive(Debug, Clone)]
pub enum BuiltinDataset {
    /// Iris flower classification dataset (loaded via `scirs2_datasets`).
    Iris,
    /// Boston housing regression dataset (loaded via `scirs2_datasets`).
    Boston,
    /// Diabetes regression dataset (loaded via `scirs2_datasets`).
    Diabetes,
    /// Wine classification dataset — generated synthetically with a fixed
    /// seed (see `load_wine_dataset`), not loaded from `scirs2_datasets`.
    Wine,
    /// Breast cancer classification dataset (loaded via `scirs2_datasets`).
    BreastCancer,
    /// Handwritten digits classification dataset (loaded via `scirs2_datasets`).
    Digits,
}
/// Generic configuration for synthetic data generation.
#[derive(Debug, Clone)]
pub struct SyntheticDataConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) per sample.
    pub n_features: usize,
    /// Number of target classes, when generating classification data.
    pub n_classes: Option<usize>,
    /// RNG seed for reproducibility; `None` uses an unseeded RNG.
    pub seed: Option<u64>,
    /// Standard deviation of additive noise.
    pub noise: Option<f64>,
    /// Optional feature scaling to apply after generation.
    pub scale: Option<ScalingMethod>,
}
/// Feature-scaling strategies (names mirror scikit-learn's scalers).
#[derive(Debug, Clone)]
pub enum ScalingMethod {
    /// Zero mean, unit variance.
    StandardScaler,
    /// Rescale each feature to a fixed range (typically [0, 1]).
    MinMaxScaler,
    /// Scale using statistics robust to outliers (median/IQR).
    RobustScaler,
    /// Scale each sample to unit norm.
    Normalizer,
}
/// Configuration for [`make_regression`].
#[derive(Debug, Clone)]
pub struct RegressionConfig {
    /// Number of samples to generate.
    pub n_samples: usize,
    /// Total number of features per sample.
    pub n_features: usize,
    /// Number of features that actually influence the target;
    /// defaults to `n_features` when `None`. Must not exceed `n_features`.
    pub n_informative: Option<usize>,
    /// Standard deviation of Gaussian noise added to targets (default 0.0).
    pub noise: Option<f64>,
    /// Constant bias term added to every target (default 0.0).
    pub bias: Option<f64>,
    /// RNG seed for reproducibility; `None` uses an unseeded RNG.
    pub random_state: Option<u64>,
}
/// Configuration for [`make_classification`].
#[derive(Debug, Clone)]
pub struct ClassificationConfig {
    /// Number of samples to generate.
    pub n_samples: usize,
    /// Total number of features per sample.
    pub n_features: usize,
    /// Number of target classes.
    pub n_classes: usize,
    /// Informative feature count; defaults to `min(n_features, 2)`.
    pub n_informative: Option<usize>,
    /// Redundant features (linear combinations of informative ones); default 0.
    pub n_redundant: Option<usize>,
    /// Gaussian clusters generated per class; default 1.
    pub n_clusters_per_class: Option<usize>,
    /// Controls spread between cluster centers; default 1.0.
    pub class_sep: Option<f64>,
    /// RNG seed for reproducibility; `None` uses an unseeded RNG.
    pub random_state: Option<u64>,
}
/// Configuration for [`make_blobs`].
#[derive(Debug, Clone)]
pub struct ClusteringConfig {
    /// Number of samples to generate.
    pub n_samples: usize,
    /// Number of cluster centers.
    pub centers: usize,
    /// Feature dimensionality; defaults to 2 when `None`.
    pub n_features: Option<usize>,
    /// Standard deviation of each cluster; default 1.0.
    pub cluster_std: Option<f64>,
    /// (min, max) bounding box for random center placement; default (-10, 10).
    /// `min` must be strictly less than `max`.
    pub center_box: Option<(f64, f64)>,
    /// RNG seed for reproducibility; `None` uses an unseeded RNG.
    pub random_state: Option<u64>,
}
/// A loaded or generated dataset: feature matrix, target vector, and metadata.
#[derive(Debug, Clone)]
pub struct DatasetResult {
    /// Feature matrix of shape [n_samples, n_features] (f32).
    pub features: Tensor,
    /// Target vector of shape [n_samples]; may be an empty tensor when the
    /// source dataset has no targets.
    pub targets: Tensor,
    /// Optional human-readable feature names (one per column).
    pub feature_names: Option<Vec<String>>,
    /// Optional class/cluster/target names.
    pub target_names: Option<Vec<String>>,
    /// Free-form description of the dataset's origin and parameters.
    pub description: String,
}
impl Default for SyntheticDataConfig {
    /// Defaults: 100 samples, 2 features, 2 classes, noise std 0.1,
    /// standard scaling, unseeded RNG.
    fn default() -> Self {
        Self {
            n_samples: 100,
            n_features: 2,
            n_classes: Some(2),
            seed: None,
            noise: Some(0.1),
            scale: Some(ScalingMethod::StandardScaler),
        }
    }
}
/// Load one of the builtin datasets.
///
/// All variants delegate to `scirs2_datasets` loaders except `Wine`, which is
/// generated synthetically with a fixed seed (see `load_wine_dataset`).
///
/// # Errors
/// Propagates any loading or conversion failure as a `DataError`.
pub fn load_builtin_dataset(dataset: BuiltinDataset) -> Result<DatasetResult, DataError> {
    match dataset {
        BuiltinDataset::Iris => load_iris_dataset(),
        BuiltinDataset::Boston => load_boston_dataset(),
        BuiltinDataset::Diabetes => load_diabetes_dataset(),
        BuiltinDataset::Wine => load_wine_dataset(),
        BuiltinDataset::BreastCancer => load_breast_cancer_dataset(),
        BuiltinDataset::Digits => load_digits_dataset(),
    }
}
/// Generate a synthetic regression dataset: `y = X·w + bias + noise`.
///
/// Features are drawn from a standard normal distribution. Only the first
/// `n_informative` columns contribute to the target (with coefficients drawn
/// uniformly from [-100, 100)); the remaining columns are pure noise features.
///
/// # Errors
/// Returns a `DataError` when `n_informative` exceeds `n_features`, or when
/// tensor construction fails.
pub fn make_regression(config: RegressionConfig) -> Result<DatasetResult, DataError> {
    use scirs2_core::random::{Normal, SeedableRng, StdRng};
    let n_informative = config.n_informative.unwrap_or(config.n_features);
    let noise_std = config.noise.unwrap_or(0.0);
    let bias = config.bias.unwrap_or(0.0);
    if n_informative > config.n_features {
        return Err(DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!(
                "n_informative ({}) cannot exceed n_features ({})",
                n_informative, config.n_features
            ),
        ));
    }
    // Seeded RNG for reproducibility when `random_state` is provided.
    let mut rng = if let Some(seed) = config.random_state {
        StdRng::seed_from_u64(seed)
    } else {
        let mut thread_rng = scirs2_core::random::thread_rng();
        StdRng::from_rng(&mut thread_rng)
    };
    let normal = Normal::new(0.0, 1.0).expect("valid Normal parameters");
    // Row-major [n_samples, n_features] feature buffer.
    let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
        .map(|_| normal.sample(&mut rng) as f32)
        .collect();
    let coefficients: Vec<f32> = (0..n_informative)
        .map(|_| rng.gen_range(-100.0..100.0))
        .collect();
    let noise_dist = Normal::new(0.0, noise_std).expect("valid Normal parameters");
    // Compute targets BEFORE moving `features_data` into the tensor so the
    // whole feature buffer does not need to be cloned (previously this code
    // cloned the Vec just to keep reading it here).
    let targets_data: Vec<f32> = (0..config.n_samples)
        .map(|i| {
            let mut target = bias as f32;
            for j in 0..n_informative {
                let idx = i * config.n_features + j;
                target += coefficients[j] * features_data[idx];
            }
            if noise_std > 0.0 {
                target += noise_dist.sample(&mut rng) as f32;
            }
            target
        })
        .collect();
    let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
    Ok(DatasetResult {
        features,
        targets,
        feature_names: Some(
            (0..config.n_features)
                .map(|i| {
                    if i < n_informative {
                        format!("informative_{}", i)
                    } else {
                        format!("noise_{}", i - n_informative)
                    }
                })
                .collect(),
        ),
        target_names: Some(vec!["target".to_string()]),
        description: format!(
            "Synthetic regression dataset: {} samples, {} features ({} informative), noise_std={:.2}, bias={:.2}",
            config.n_samples, config.n_features, n_informative, noise_std, bias
        ),
    })
}
/// Generate a synthetic classification dataset of Gaussian clusters.
///
/// Each class receives `n_clusters_per_class` cluster centers placed uniformly
/// within ±`class_sep`·10 along the informative dimensions. Each sample is its
/// cluster center plus unit Gaussian noise, followed by `n_redundant` random
/// linear combinations of the informative features and uniform noise features
/// filling the remaining columns. Samples are distributed as evenly as
/// possible across clusters (earlier clusters absorb the remainder).
///
/// # Errors
/// Returns a `DataError` when `n_informative + n_redundant` exceeds
/// `n_features`, or when tensor construction fails.
pub fn make_classification(config: ClassificationConfig) -> Result<DatasetResult, DataError> {
    use scirs2_core::random::{Normal, SeedableRng, StdRng};
    let n_informative = config.n_informative.unwrap_or(config.n_features.min(2));
    let n_redundant = config.n_redundant.unwrap_or(0);
    let n_clusters_per_class = config.n_clusters_per_class.unwrap_or(1);
    let class_sep = config.class_sep.unwrap_or(1.0);
    if n_informative + n_redundant > config.n_features {
        return Err(DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!(
                "n_informative ({}) + n_redundant ({}) cannot exceed n_features ({})",
                n_informative, n_redundant, config.n_features
            ),
        ));
    }
    // Seeded RNG for reproducibility when `random_state` is provided.
    let mut rng = if let Some(seed) = config.random_state {
        StdRng::seed_from_u64(seed)
    } else {
        let mut thread_rng = scirs2_core::random::thread_rng();
        StdRng::from_rng(&mut thread_rng)
    };
    let total_clusters = config.n_classes * n_clusters_per_class;
    let mut cluster_centers: Vec<Vec<f32>> = Vec::new();
    let mut cluster_labels: Vec<usize> = Vec::new();
    for class_id in 0..config.n_classes {
        for _ in 0..n_clusters_per_class {
            // NOTE(review): `gen_range` panics on an empty range, so this
            // presumably requires class_sep > 0 — confirm and validate upstream.
            let center: Vec<f32> = (0..n_informative)
                .map(|_| rng.gen_range(-class_sep as f32..class_sep as f32) * 10.0)
                .collect();
            cluster_centers.push(center);
            cluster_labels.push(class_id);
        }
    }
    let samples_per_cluster = config.n_samples / total_clusters;
    let remainder = config.n_samples % total_clusters;
    let mut features_data = Vec::new();
    let mut targets_data = Vec::new();
    let normal = Normal::new(0.0, 1.0).expect("valid Normal parameters");
    for (cluster_idx, (center, &class_label)) in cluster_centers
        .iter()
        .zip(cluster_labels.iter())
        .enumerate()
    {
        let n_samples_this_cluster =
            samples_per_cluster + if cluster_idx < remainder { 1 } else { 0 };
        for _ in 0..n_samples_this_cluster {
            // Informative features: cluster center plus unit Gaussian noise.
            // (Fixed: `&center_val` had been mangled into `¢er_val` by an
            // HTML-entity decode of `&cent`, which broke compilation.)
            for &center_val in center.iter() {
                let noise = normal.sample(&mut rng) as f32;
                features_data.push(center_val + noise);
            }
            // Redundant features: random linear combinations of this sample's
            // informative features, which start at `start_idx`.
            let start_idx = features_data.len() - n_informative;
            for _ in 0..n_redundant {
                let mut redundant = 0.0f32;
                for j in 0..n_informative {
                    let weight = rng.gen_range(-1.0..1.0);
                    redundant += weight * features_data[start_idx + j];
                }
                features_data.push(redundant);
            }
            // Remaining columns are pure uniform noise.
            let n_noise = config.n_features - n_informative - n_redundant;
            for _ in 0..n_noise {
                features_data.push(rng.gen_range(-10.0..10.0));
            }
            targets_data.push(class_label as f32);
        }
    }
    let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
    Ok(DatasetResult {
        features,
        targets,
        feature_names: Some(
            (0..config.n_features)
                .map(|i| {
                    if i < n_informative {
                        format!("informative_{}", i)
                    } else if i < n_informative + n_redundant {
                        format!("redundant_{}", i - n_informative)
                    } else {
                        format!("noise_{}", i - n_informative - n_redundant)
                    }
                })
                .collect(),
        ),
        target_names: Some(
            (0..config.n_classes)
                .map(|i| format!("class_{}", i))
                .collect(),
        ),
        description: format!(
            "Synthetic classification dataset: {} samples, {} features ({} informative, {} redundant), {} classes, class_sep={:.2}",
            config.n_samples, config.n_features, n_informative, n_redundant, config.n_classes, class_sep
        ),
    })
}
/// Generate isotropic Gaussian blobs for clustering.
///
/// `centers` cluster centers are placed uniformly inside `center_box`, and each
/// sample is its center plus Gaussian noise with `cluster_std`. Targets are the
/// cluster index. Samples are distributed as evenly as possible across clusters
/// (earlier clusters absorb the remainder).
///
/// # Errors
/// Returns a `DataError` when `center_box.0 >= center_box.1`, or when tensor
/// construction fails.
pub fn make_blobs(config: ClusteringConfig) -> Result<DatasetResult, DataError> {
    use scirs2_core::random::{Normal, SeedableRng, StdRng};
    // Seeded RNG for reproducibility when `random_state` is provided.
    let mut rng = if let Some(seed) = config.random_state {
        StdRng::seed_from_u64(seed)
    } else {
        let mut thread_rng = scirs2_core::random::thread_rng();
        StdRng::from_rng(&mut thread_rng)
    };
    let n_features = config.n_features.unwrap_or(2);
    let cluster_std = config.cluster_std.unwrap_or(1.0);
    let (box_min, box_max) = config.center_box.unwrap_or((-10.0, 10.0));
    if box_min >= box_max {
        return Err(DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!(
                "center_box min ({}) must be less than max ({})",
                box_min, box_max
            ),
        ));
    }
    let centers: Vec<Vec<f32>> = (0..config.centers)
        .map(|_| {
            (0..n_features)
                .map(|_| rng.gen_range(box_min as f32..box_max as f32))
                .collect()
        })
        .collect();
    let samples_per_cluster = config.n_samples / config.centers;
    let remainder = config.n_samples % config.centers;
    let mut features_data = Vec::new();
    let mut targets_data = Vec::new();
    let normal = Normal::new(0.0, cluster_std).expect("valid Normal parameters");
    for (cluster_id, center) in centers.iter().enumerate() {
        let n_samples_this_cluster =
            samples_per_cluster + if cluster_id < remainder { 1 } else { 0 };
        for _ in 0..n_samples_this_cluster {
            // Each coordinate is the center coordinate plus Gaussian noise.
            // (Fixed: `&center_coord` had been mangled into `¢er_coord` by an
            // HTML-entity decode of `&cent`, which broke compilation.)
            for &center_coord in center {
                let noise = normal.sample(&mut rng) as f32;
                features_data.push(center_coord + noise);
            }
            targets_data.push(cluster_id as f32);
        }
    }
    let features = Tensor::from_vec(features_data, &[config.n_samples, n_features])?;
    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
    Ok(DatasetResult {
        features,
        targets,
        feature_names: Some((0..n_features).map(|i| format!("feature_{}", i)).collect()),
        target_names: Some(
            (0..config.centers)
                .map(|i| format!("cluster_{}", i))
                .collect(),
        ),
        description: format!(
            "Synthetic clustering dataset (blobs): {} samples, {} features, {} clusters, cluster_std={:.2}",
            config.n_samples, n_features, config.centers, cluster_std
        ),
    })
}
/// Convert a `scirs2_datasets` `Dataset` into this crate's [`DatasetResult`].
///
/// Data is narrowed from f64 to f32. When the source dataset carries no target
/// array, an empty 1-D tensor is used as a placeholder.
fn convert_scirs2_dataset(
    scirs2_dataset: scirs2_datasets::utils::Dataset,
) -> Result<DatasetResult, DataError> {
    // NOTE(review): assumes `data` is 2-D — `shape[0]`/`shape[1]` below would
    // panic on a lower-rank array; confirm against scirs2's Dataset contract.
    let shape = scirs2_dataset.data.shape();
    let features_data: Vec<f32> = scirs2_dataset.data.iter().map(|&x| x as f32).collect();
    let features = Tensor::from_vec(features_data, &[shape[0], shape[1]])?;
    let targets = if let Some(target_array) = scirs2_dataset.target {
        let target_data: Vec<f32> = target_array.iter().map(|&x| x as f32).collect();
        Tensor::from_vec(target_data, &[target_array.len()])?
    } else {
        // No targets in the source: represent as an empty tensor.
        Tensor::from_vec(vec![], &[0])?
    };
    Ok(DatasetResult {
        features,
        targets,
        feature_names: scirs2_dataset.featurenames,
        target_names: scirs2_dataset.targetnames,
        description: scirs2_dataset
            .description
            .unwrap_or_else(|| "Dataset loaded from scirs2".to_string()),
    })
}
/// Load the Iris dataset from `scirs2_datasets` and convert it.
fn load_iris_dataset() -> Result<DatasetResult, DataError> {
    let scirs2_dataset = scirs2_load_iris().map_err(|e| {
        DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!("Failed to load Iris dataset from scirs2_datasets: {}", e),
        )
    })?;
    convert_scirs2_dataset(scirs2_dataset)
}
/// Load the Boston housing dataset from `scirs2_datasets` and convert it.
fn load_boston_dataset() -> Result<DatasetResult, DataError> {
    let scirs2_dataset = scirs2_load_boston().map_err(|e| {
        DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!("Failed to load Boston dataset from scirs2_datasets: {}", e),
        )
    })?;
    convert_scirs2_dataset(scirs2_dataset)
}
/// Load the Diabetes dataset from `scirs2_datasets` and convert it.
fn load_diabetes_dataset() -> Result<DatasetResult, DataError> {
    let scirs2_dataset = scirs2_load_diabetes().map_err(|e| {
        DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!(
                "Failed to load Diabetes dataset from scirs2_datasets: {}",
                e
            ),
        )
    })?;
    convert_scirs2_dataset(scirs2_dataset)
}
/// "Load" the Wine dataset.
///
/// Unlike the other builtin loaders this does NOT come from `scirs2_datasets`:
/// it synthesizes a stand-in with the classic Wine dimensions (178 samples,
/// 13 features, 3 classes) via `make_classification`, using a fixed seed so
/// the result is deterministic.
fn load_wine_dataset() -> Result<DatasetResult, DataError> {
    make_classification(ClassificationConfig {
        n_samples: 178,
        n_features: 13,
        n_classes: 3,
        n_informative: Some(13),
        random_state: Some(42),
        ..Default::default()
    })
}
/// Load the Breast Cancer dataset from `scirs2_datasets` and convert it.
fn load_breast_cancer_dataset() -> Result<DatasetResult, DataError> {
    let scirs2_dataset = scirs2_load_breast_cancer().map_err(|e| {
        DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!(
                "Failed to load Breast Cancer dataset from scirs2_datasets: {}",
                e
            ),
        )
    })?;
    convert_scirs2_dataset(scirs2_dataset)
}
/// Load the Digits dataset from `scirs2_datasets` and convert it.
fn load_digits_dataset() -> Result<DatasetResult, DataError> {
    let scirs2_dataset = scirs2_load_digits().map_err(|e| {
        DataError::dataset(
            crate::error::DatasetErrorKind::CorruptedData,
            format!("Failed to load Digits dataset from scirs2_datasets: {}", e),
        )
    })?;
    convert_scirs2_dataset(scirs2_dataset)
}
impl Default for RegressionConfig {
    /// Defaults: 100 samples, 1 feature (all informative), noise std 0.1,
    /// zero bias, unseeded RNG.
    fn default() -> Self {
        Self {
            n_samples: 100,
            n_features: 1,
            n_informative: None,
            noise: Some(0.1),
            bias: Some(0.0),
            random_state: None,
        }
    }
}
impl Default for ClassificationConfig {
    /// Defaults: 100 samples, 2 features, 2 classes, class separation 1.0,
    /// one cluster per class, no redundant features, unseeded RNG.
    fn default() -> Self {
        Self {
            n_samples: 100,
            n_features: 2,
            n_classes: 2,
            n_informative: None,
            n_redundant: None,
            n_clusters_per_class: None,
            class_sep: Some(1.0),
            random_state: None,
        }
    }
}
impl Default for ClusteringConfig {
    /// Defaults: 100 samples, 3 centers in 2 dimensions, cluster std 1.0,
    /// centers drawn from (-10, 10), unseeded RNG.
    fn default() -> Self {
        Self {
            n_samples: 100,
            centers: 3,
            n_features: Some(2),
            cluster_std: Some(1.0),
            center_box: Some((-10.0, 10.0)),
            random_state: None,
        }
    }
}
/// Registry over the builtin datasets, supporting lookup by name.
///
/// Note: `Default` derives an EMPTY registry; use [`DatasetRegistry::new`]
/// to get one pre-populated with all builtin datasets.
#[derive(Debug, Default)]
pub struct DatasetRegistry {
    // Datasets advertised by `list_builtin`.
    builtin_datasets: Vec<BuiltinDataset>,
}
impl DatasetRegistry {
    /// Create a registry pre-populated with every builtin dataset.
    pub fn new() -> Self {
        let builtin_datasets = vec![
            BuiltinDataset::Iris,
            BuiltinDataset::Boston,
            BuiltinDataset::Diabetes,
            BuiltinDataset::Wine,
            BuiltinDataset::BreastCancer,
            BuiltinDataset::Digits,
        ];
        Self { builtin_datasets }
    }

    /// All datasets this registry knows about.
    pub fn list_builtin(&self) -> &[BuiltinDataset] {
        self.builtin_datasets.as_slice()
    }

    /// Resolve a dataset by case-insensitive name and load it.
    ///
    /// # Errors
    /// Returns an `UnsupportedFormat` error for unknown names, and propagates
    /// any loading failure.
    pub fn load_by_name(&self, name: &str) -> Result<DatasetResult, DataError> {
        let key = name.to_lowercase();
        // Map the normalized name to a variant; `None` means "unknown".
        let resolved = match key.as_str() {
            "iris" => Some(BuiltinDataset::Iris),
            "boston" => Some(BuiltinDataset::Boston),
            "diabetes" => Some(BuiltinDataset::Diabetes),
            "wine" => Some(BuiltinDataset::Wine),
            "breast_cancer" | "breastcancer" => Some(BuiltinDataset::BreastCancer),
            "digits" => Some(BuiltinDataset::Digits),
            _ => None,
        };
        match resolved {
            Some(dataset) => load_builtin_dataset(dataset),
            None => Err(DataError::dataset(
                crate::error::DatasetErrorKind::UnsupportedFormat,
                format!("Unknown dataset: {}", name),
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the expected shapes below (e.g. Boston 30x5, Digits 50x16)
    // are smaller than the classic datasets — presumably they mirror the toy
    // subsets vendored by `scirs2_datasets`; confirm against that crate.
    #[test]
    fn test_load_iris_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Iris);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 150);
        assert_eq!(dataset.features.size(1).unwrap(), 4);
        assert_eq!(dataset.targets.size(0).unwrap(), 150);
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());
        let feature_names = dataset.feature_names.unwrap();
        assert_eq!(feature_names.len(), 4);
        assert!(feature_names.contains(&"sepal_length".to_string()));
        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 3);
    }
    #[test]
    fn test_load_boston_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Boston);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 30);
        assert!(dataset.feature_names.is_some());
        assert!(!dataset.description.is_empty());
    }
    #[test]
    fn test_load_diabetes_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Diabetes);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 442);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert_eq!(dataset.targets.size(0).unwrap(), 442);
        assert!(dataset.feature_names.is_some());
        assert!(!dataset.description.is_empty());
        let feature_names = dataset.feature_names.unwrap();
        assert_eq!(feature_names.len(), 10);
        assert!(feature_names.contains(&"age".to_string()));
        assert!(feature_names.contains(&"bmi".to_string()));
    }
    #[test]
    fn test_load_breast_cancer_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::BreastCancer);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 30);
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());
        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 2); assert!(target_names.contains(&"malignant".to_string()));
        assert!(target_names.contains(&"benign".to_string()));
    }
    #[test]
    fn test_load_digits_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Digits);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 50);
        assert_eq!(dataset.features.size(1).unwrap(), 16);
        assert_eq!(dataset.targets.size(0).unwrap(), 50);
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());
        let target_names = dataset.target_names.unwrap();
        // One target name per digit class (0-9).
        assert_eq!(target_names.len(), 10); }
    #[test]
    fn test_load_wine_dataset() {
        // Wine is synthesized with a fixed seed, so the shape is deterministic.
        let result = load_builtin_dataset(BuiltinDataset::Wine);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 178);
        assert_eq!(dataset.features.size(1).unwrap(), 13);
        assert_eq!(dataset.targets.size(0).unwrap(), 178);
        assert!(!dataset.description.is_empty());
    }
    #[test]
    fn test_dataset_registry() {
        let registry = DatasetRegistry::new();
        let builtin_datasets = registry.list_builtin();
        assert_eq!(builtin_datasets.len(), 6);
    }
    #[test]
    fn test_load_by_name() {
        let registry = DatasetRegistry::new();
        assert!(registry.load_by_name("iris").is_ok());
        assert!(registry.load_by_name("boston").is_ok());
        assert!(registry.load_by_name("diabetes").is_ok());
        assert!(registry.load_by_name("wine").is_ok());
        assert!(registry.load_by_name("breast_cancer").is_ok());
        assert!(registry.load_by_name("breastcancer").is_ok()); assert!(registry.load_by_name("digits").is_ok());
        // Lookup is case-insensitive.
        assert!(registry.load_by_name("IRIS").is_ok());
        assert!(registry.load_by_name("Diabetes").is_ok());
        assert!(registry.load_by_name("unknown").is_err());
    }
    #[test]
    fn test_make_regression() {
        let config = RegressionConfig {
            n_samples: 100,
            n_features: 5,
            n_informative: Some(3),
            noise: Some(0.1),
            bias: Some(1.0),
            random_state: Some(42),
        };
        let result = make_regression(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 100);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 100);
    }
    #[test]
    fn test_make_classification() {
        let config = ClassificationConfig {
            n_samples: 200,
            n_features: 10,
            n_classes: 3,
            n_informative: Some(5),
            random_state: Some(42),
            ..Default::default()
        };
        let result = make_classification(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 200);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert_eq!(dataset.targets.size(0).unwrap(), 200);
    }
    #[test]
    fn test_make_blobs() {
        let config = ClusteringConfig {
            n_samples: 150,
            centers: 3,
            n_features: Some(2),
            cluster_std: Some(0.5),
            random_state: Some(42),
            ..Default::default()
        };
        let result = make_blobs(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 150);
        assert_eq!(dataset.features.size(1).unwrap(), 2);
        assert_eq!(dataset.targets.size(0).unwrap(), 150);
    }
    #[test]
    fn test_regression_config_validation() {
        // n_informative > n_features must be rejected.
        let config = RegressionConfig {
            n_samples: 100,
            n_features: 5,
            n_informative: Some(10), noise: Some(0.1),
            bias: Some(0.0),
            random_state: Some(42),
        };
        let result = make_regression(config);
        assert!(result.is_err());
    }
    #[test]
    fn test_scirs2_integration_diabetes() {
        let result = load_builtin_dataset(BuiltinDataset::Diabetes);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 442);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert!(
            dataset.description.contains("diabetes") || dataset.description.contains("Diabetes")
        );
    }
    #[test]
    fn test_scirs2_integration_breast_cancer() {
        let result = load_builtin_dataset(BuiltinDataset::BreastCancer);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
    }
    #[test]
    fn test_scirs2_integration_digits() {
        let result = load_builtin_dataset(BuiltinDataset::Digits);
        assert!(result.is_ok());
        let dataset = result.unwrap();
        assert_eq!(dataset.features.size(0).unwrap(), 50);
        assert_eq!(dataset.features.size(1).unwrap(), 16);
        assert!(dataset.target_names.is_some());
        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 10);
    }
}