use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
/// A loaded dataset: a feature matrix, a target vector, and descriptive metadata.
///
/// This is the value returned (inside `Result`) by every `load_*` function in
/// this module.
#[derive(Debug, Clone)]
pub struct DatasetResult {
/// Feature matrix. Its rows are reported by `n_samples()` and its columns by
/// `n_features()`.
pub data: Array2<f64>,
/// Target values stored as `f64`. The tests in this module expect exactly one
/// entry per row of `data`.
pub target: Array1<f64>,
/// Names of the features. The tests in this module expect one name per column
/// of `data`.
pub feature_names: Vec<String>,
/// Names associated with the target values — presumably class names for the
/// classification datasets (the tests expect 3 for iris, 2 for breast cancer,
/// 10 for digits); confirm against the individual data modules.
pub target_names: Vec<String>,
/// Free-form description of the dataset. The tests in this module expect it to
/// be non-empty.
pub description: String,
}
impl DatasetResult {
    /// Returns the number of samples, i.e. the number of rows in `data`.
    pub fn n_samples(&self) -> usize {
        let (rows, _) = self.data.dim();
        rows
    }

    /// Returns the number of features, i.e. the number of columns in `data`.
    pub fn n_features(&self) -> usize {
        let (_, cols) = self.data.dim();
        cols
    }

    /// Returns the dimensions of `data` as `(n_samples, n_features)`.
    pub fn shape(&self) -> (usize, usize) {
        self.data.dim()
    }
}
// Private submodules, one per dataset. Each one provides the `load()` function
// that the matching public `load_*` wrapper in this file calls.
mod boston_data;
mod breast_cancer_data;
mod digits_data;
mod iris_data;
mod wine_data;
/// Loads the iris dataset.
///
/// Thin public wrapper that forwards to the private `iris_data::load()`.
/// The tests in this file expect 150 samples with 4 features, 3 target names,
/// and the labels `0.0`, `1.0` and `2.0` occurring 50 times each.
///
/// # Errors
///
/// Returns whatever error `iris_data::load()` produces, unchanged.
pub fn load_iris() -> Result<DatasetResult> {
iris_data::load()
}
/// Loads the wine dataset.
///
/// Thin public wrapper that forwards to the private `wine_data::load()`.
/// The tests in this file expect 178 samples with 13 features, 3 target names,
/// and every label to be `0.0`, `1.0` or `2.0`.
///
/// # Errors
///
/// Returns whatever error `wine_data::load()` produces, unchanged.
pub fn load_wine() -> Result<DatasetResult> {
wine_data::load()
}
/// Loads the breast cancer dataset.
///
/// Thin public wrapper that forwards to the private `breast_cancer_data::load()`.
/// The tests in this file expect 569 samples with 30 features, 2 target names,
/// and every label to be `0.0` or `1.0`.
///
/// # Errors
///
/// Returns whatever error `breast_cancer_data::load()` produces, unchanged.
pub fn load_breast_cancer() -> Result<DatasetResult> {
breast_cancer_data::load()
}
/// Loads the digits dataset.
///
/// Thin public wrapper that forwards to the private `digits_data::load()`.
/// The tests in this file expect 1797 samples with 64 features, 10 target
/// names, whole-number labels in `0.0..=9.0`, and every feature value to lie
/// in `0.0..=16.0`.
///
/// # Errors
///
/// Returns whatever error `digits_data::load()` produces, unchanged.
pub fn load_digits() -> Result<DatasetResult> {
digits_data::load()
}
/// Loads the Boston dataset.
///
/// Thin public wrapper that forwards to the private `boston_data::load()`.
/// The tests in this file expect 506 samples with 13 features and every target
/// value to be strictly positive.
///
/// # Errors
///
/// Returns whatever error `boston_data::load()` produces, unchanged.
pub fn load_boston() -> Result<DatasetResult> {
boston_data::load()
}
#[cfg(test)]
mod tests {
    // Sanity checks for the dataset loaders: expected dimensions, label/value
    // ranges, and internal consistency between data, targets and metadata.
    //
    // Every loader call uses an `expect` message that states what should have
    // happened, so a failing load reads as "<loader> should return Ok: <error>"
    // instead of the misleading "ok: <error>".
    use super::*;

    /// Iris must have 150 samples, 4 features and 3 target names.
    #[test]
    fn test_iris_shape() {
        let ds = load_iris().expect("load_iris() should return Ok");
        assert_eq!(ds.n_samples(), 150);
        assert_eq!(ds.n_features(), 4);
        assert_eq!(ds.target.len(), 150);
        assert_eq!(ds.feature_names.len(), 4);
        assert_eq!(ds.target_names.len(), 3);
    }

    /// Iris labels must be 0, 1 or 2, with exactly 50 samples per label.
    #[test]
    fn test_iris_labels() {
        let ds = load_iris().expect("load_iris() should return Ok");
        for &v in ds.target.iter() {
            assert!(v == 0.0 || v == 1.0 || v == 2.0, "Invalid iris label {v}");
        }
        // Labels are stored as exact small integers in f64, so `==` is safe here.
        let count_of = |label: f64| ds.target.iter().filter(|&&v| v == label).count();
        assert_eq!(count_of(0.0), 50);
        assert_eq!(count_of(1.0), 50);
        assert_eq!(count_of(2.0), 50);
    }

    /// Wine must have 178 samples, 13 features and 3 target names.
    #[test]
    fn test_wine_shape() {
        let ds = load_wine().expect("load_wine() should return Ok");
        assert_eq!(ds.n_samples(), 178);
        assert_eq!(ds.n_features(), 13);
        assert_eq!(ds.target.len(), 178);
        assert_eq!(ds.feature_names.len(), 13);
        assert_eq!(ds.target_names.len(), 3);
    }

    /// Wine labels must be 0, 1 or 2.
    #[test]
    fn test_wine_labels() {
        let ds = load_wine().expect("load_wine() should return Ok");
        for &v in ds.target.iter() {
            assert!(v == 0.0 || v == 1.0 || v == 2.0, "Invalid wine label {v}");
        }
    }

    /// Breast cancer must have 569 samples, 30 features and 2 target names.
    #[test]
    fn test_breast_cancer_shape() {
        let ds = load_breast_cancer().expect("load_breast_cancer() should return Ok");
        assert_eq!(ds.n_samples(), 569);
        assert_eq!(ds.n_features(), 30);
        assert_eq!(ds.target.len(), 569);
        assert_eq!(ds.feature_names.len(), 30);
        assert_eq!(ds.target_names.len(), 2);
    }

    /// Breast cancer labels must be 0 or 1.
    #[test]
    fn test_breast_cancer_labels() {
        let ds = load_breast_cancer().expect("load_breast_cancer() should return Ok");
        for &v in ds.target.iter() {
            assert!(v == 0.0 || v == 1.0, "Invalid breast cancer label {v}");
        }
    }

    /// Digits must have 1797 samples, 64 features and 10 target names.
    #[test]
    fn test_digits_shape() {
        let ds = load_digits().expect("load_digits() should return Ok");
        assert_eq!(ds.n_samples(), 1797);
        assert_eq!(ds.n_features(), 64);
        assert_eq!(ds.target.len(), 1797);
        assert_eq!(ds.feature_names.len(), 64);
        assert_eq!(ds.target_names.len(), 10);
    }

    /// Digit labels must be whole numbers in `0..=9`.
    #[test]
    fn test_digits_labels() {
        let ds = load_digits().expect("load_digits() should return Ok");
        for &v in ds.target.iter() {
            assert!(
                (0.0..=9.0).contains(&v) && v == v.floor(),
                "Invalid digit label {v}"
            );
        }
    }

    /// Every digits feature value must lie in `0..=16`.
    #[test]
    fn test_digits_pixel_range() {
        let ds = load_digits().expect("load_digits() should return Ok");
        // Visit every element of the matrix; row boundaries are irrelevant here.
        for &v in ds.data.iter() {
            assert!(
                (0.0..=16.0).contains(&v),
                "Pixel value {v} out of range [0, 16]"
            );
        }
    }

    /// Boston must have 506 samples and 13 features.
    #[test]
    fn test_boston_shape() {
        let ds = load_boston().expect("load_boston() should return Ok");
        assert_eq!(ds.n_samples(), 506);
        assert_eq!(ds.n_features(), 13);
        assert_eq!(ds.target.len(), 506);
        assert_eq!(ds.feature_names.len(), 13);
    }

    /// Every Boston target value must be strictly positive.
    #[test]
    fn test_boston_target_positive() {
        let ds = load_boston().expect("load_boston() should return Ok");
        for &v in ds.target.iter() {
            assert!(v > 0.0, "Boston target should be positive, got {v}");
        }
    }

    /// `shape()` must agree with the known iris dimensions and the description
    /// must be populated.
    #[test]
    fn test_dataset_result_methods() {
        let ds = load_iris().expect("load_iris() should return Ok");
        assert_eq!(ds.shape(), (150, 4));
        assert!(!ds.description.is_empty());
    }

    /// For every dataset: rows match targets, columns match feature names, and
    /// the description is non-empty.
    #[test]
    fn test_all_datasets_consistent() {
        let datasets: [(&str, DatasetResult); 5] = [
            ("iris", load_iris().expect("load_iris() should return Ok")),
            ("wine", load_wine().expect("load_wine() should return Ok")),
            (
                "breast_cancer",
                load_breast_cancer().expect("load_breast_cancer() should return Ok"),
            ),
            (
                "digits",
                load_digits().expect("load_digits() should return Ok"),
            ),
            (
                "boston",
                load_boston().expect("load_boston() should return Ok"),
            ),
        ];
        for (name, ds) in &datasets {
            assert_eq!(
                ds.data.nrows(),
                ds.target.len(),
                "{name}: data rows != target len"
            );
            assert_eq!(
                ds.data.ncols(),
                ds.feature_names.len(),
                "{name}: data cols != feature_names len"
            );
            assert!(
                !ds.description.is_empty(),
                "{name}: description should not be empty"
            );
        }
    }
}