pub mod builder;
pub mod core;
pub mod synthetic;
#[cfg(feature = "mmap")]
pub mod mmap;
pub use builder::{DatasetBuilder, HasData, HasTarget, NoData, NoTarget};
pub use core::{Dataset, HasShape};
pub use synthetic::{load_iris, make_blobs, make_classification, make_regression};
#[cfg(feature = "mmap")]
pub use mmap::{
make_large_regression, MmapDataset, MmapDatasetBuilder, MmapDatasetBuilderConfig,
MmapSerializable,
};
/// Smoke tests for the dataset module's public surface: the synthetic
/// generators, the typed `Dataset::builder()`, and the metadata setters on
/// `Dataset`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Array1, Array2};

    /// Absolute tolerance under which two float labels are treated as the same class.
    const LABEL_EPS: f64 = 1e-10;

    /// Returns how many distinct label values `labels` contains.
    ///
    /// Values are sorted, then a value is dropped when it lies within
    /// `LABEL_EPS` of the previously retained value.
    fn count_distinct_labels(labels: impl IntoIterator<Item = f64>) -> usize {
        let mut values: Vec<f64> = labels.into_iter().collect();
        // Floats are not `Ord`; incomparable pairs (NaN) are treated as equal so
        // the sort never panics.
        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        values.dedup_by(|current, kept| (*current - *kept).abs() < LABEL_EPS);
        values.len()
    }

    /// A synthetic dataset and a builder-made dataset both report the shapes
    /// they were created with.
    #[test]
    fn test_module_integration() {
        let synthetic_dataset =
            synthetic::make_regression(20, 3, 0.1).expect("expected valid value");
        assert_eq!(synthetic_dataset.data.dim(), (20, 3));
        assert_eq!(synthetic_dataset.target.len(), 20);

        let builder_dataset = Dataset::builder()
            .data(Array2::<f64>::zeros((10, 2)))
            .target(Array1::<f64>::zeros(10))
            .description("Integration test".to_string())
            .build();
        assert_eq!(builder_dataset.description, "Integration test");
        assert_eq!(builder_dataset.data.dim(), (10, 2));
    }

    /// The bundled iris sample has 6 rows, 4 named features and 3 target names.
    #[test]
    fn test_iris_dataset() {
        let iris = synthetic::load_iris().expect("expected valid value");
        assert_eq!(iris.data.dim(), (6, 4));
        assert_eq!(iris.target.len(), 6);
        assert_eq!(iris.feature_names.len(), 4);

        let target_names = iris
            .target_names
            .as_ref()
            .expect("value should be present");
        assert_eq!(target_names.len(), 3);
    }

    /// `make_blobs` yields the requested shape and at most one label per center.
    #[test]
    fn test_blob_generation() {
        let blobs = synthetic::make_blobs(30, 2, 3, 1.0).expect("expected valid value");
        assert_eq!(blobs.data.dim(), (30, 2));
        assert_eq!(blobs.target.len(), 30);
        assert!(count_distinct_labels(blobs.target.iter().copied()) <= 3);
    }

    /// `make_classification` yields the requested shape and at most two classes.
    #[test]
    fn test_classification_generation() {
        let classification =
            synthetic::make_classification(40, 3, 2.0).expect("expected valid value");
        assert_eq!(classification.data.dim(), (40, 3));
        assert_eq!(classification.target.len(), 40);
        assert!(count_distinct_labels(classification.target.iter().copied()) <= 2);
    }

    /// The `with_*` setters store their metadata, and the size accessors
    /// reflect the underlying arrays.
    #[test]
    fn test_dataset_metadata() {
        let dataset = Dataset::new(Array2::<f64>::ones((5, 2)), Array1::<f64>::ones(5))
            .with_feature_names(vec!["x".to_string(), "y".to_string()])
            .with_target_names(vec!["class_a".to_string(), "class_b".to_string()])
            .with_description("Test metadata".to_string());

        assert_eq!(dataset.feature_names.len(), 2);
        let target_names = dataset
            .target_names
            .as_ref()
            .expect("value should be present");
        assert_eq!(target_names.len(), 2);
        assert_eq!(dataset.description, "Test metadata");
        assert_eq!(dataset.n_samples(), Some(5));
        assert_eq!(dataset.n_features(), Some(2));
    }
}