use crate::types::{Array1, Array2, Float};
/// A dataset bundling a data container, a target container and descriptive metadata.
///
/// `X` defaults to `Array2<Float>` and `Y` to `Array1<Float>`. The struct itself places
/// no trait bounds on either parameter; the shape accessors (`n_samples`, `n_features`)
/// are only available when `X: HasShape`.
#[derive(Debug, Clone)]
pub struct Dataset<X = Array2<Float>, Y = Array1<Float>> {
/// Data container. When `X: HasShape`, its shape is read as `(n_samples, n_features)`.
pub data: X,
/// Target container paired with `data`.
/// NOTE(review): presumably one entry per sample; nothing here enforces that.
pub target: Y,
/// Names attached to the features; `Dataset::new` leaves this empty.
/// NOTE(review): presumably one name per feature column; the length is not validated here.
pub feature_names: Vec<String>,
/// Optional names attached to the target; `Dataset::new` sets this to `None`.
pub target_names: Option<Vec<String>>,
/// Free-form description; `Dataset::new` sets this to the empty string.
pub description: String,
}
impl<X, Y> Dataset<X, Y> {
    /// Creates a dataset from `data` and `target` with no metadata attached.
    ///
    /// The feature-name list and the description start out empty, and no target
    /// names are set.
    pub fn new(data: X, target: Y) -> Self {
        let feature_names = vec![];
        let description = String::default();
        Self {
            data,
            target,
            feature_names,
            target_names: None,
            description,
        }
    }

    /// Returns a freshly constructed `DatasetBuilder` for this `X`/`Y` pair.
    ///
    /// The builder is typed with the `NoData` / `NoTarget` markers.
    /// NOTE(review): these look like typestate markers meaning "nothing supplied
    /// yet" — confirm against the `builder` module.
    pub fn builder() -> crate::dataset::builder::DatasetBuilder<
        X,
        Y,
        crate::dataset::builder::NoData,
        crate::dataset::builder::NoTarget,
    > {
        use crate::dataset::builder::DatasetBuilder;

        DatasetBuilder::new()
    }

    /// Consumes the dataset and returns it with `feature_names` replaced by `names`.
    pub fn with_feature_names(self, names: Vec<String>) -> Self {
        Self {
            feature_names: names,
            ..self
        }
    }

    /// Consumes the dataset and returns it with `target_names` set to `Some(names)`.
    pub fn with_target_names(self, names: Vec<String>) -> Self {
        Self {
            target_names: Some(names),
            ..self
        }
    }

    /// Consumes the dataset and returns it with `description` replaced.
    pub fn with_description(self, description: String) -> Self {
        Self {
            description,
            ..self
        }
    }

    /// Returns the first component of the data's shape, or `None` when the
    /// data container reports no shape.
    pub fn n_samples(&self) -> Option<usize>
    where
        X: HasShape,
    {
        let (rows, _) = self.data.shape()?;
        Some(rows)
    }

    /// Returns the second component of the data's shape, or `None` when the
    /// data container reports no shape.
    pub fn n_features(&self) -> Option<usize>
    where
        X: HasShape,
    {
        let (_, columns) = self.data.shape()?;
        Some(columns)
    }
}
/// Exposes an optional two-dimensional shape for a data container.
///
/// `Dataset::n_samples` and `Dataset::n_features` are built on this trait: they read
/// the first and second component of the returned pair respectively.
pub trait HasShape {
/// Returns the container's shape as a pair, or `None` if no shape is available.
///
/// `Dataset` treats the first component as the sample count and the second as
/// the feature count.
fn shape(&self) -> Option<(usize, usize)>;
}
impl HasShape for Array2<Float> {
    /// Reports the array's two dimensions, in the order returned by `dim()`.
    ///
    /// This implementation always returns `Some`.
    fn shape(&self) -> Option<(usize, usize)> {
        let (first_axis, second_axis) = self.dim();
        Some((first_axis, second_axis))
    }
}
// NOTE(review): the `non_snake_case` allowance is kept from the original; no
// non-snake-case identifier is visible in this module, so it may be removable.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A dataset built with `new` and the chained setters keeps its metadata
    /// and reports the shape of the underlying data array.
    #[test]
    fn test_dataset_creation() {
        let feature_names: Vec<String> = ["f1", "f2", "f3"]
            .iter()
            .map(|name| (*name).to_string())
            .collect();

        let dataset = Dataset::new(Array2::<f64>::zeros((10, 3)), Array1::<f64>::zeros(10))
            .with_description(String::from("Test dataset"))
            .with_feature_names(feature_names);

        assert_eq!(dataset.description, "Test dataset");
        assert_eq!(dataset.feature_names.len(), 3);
        assert_eq!(dataset.n_samples(), Some(10));
        assert_eq!(dataset.n_features(), Some(3));
    }

    /// `with_target_names` stores the supplied names inside `Some`.
    #[test]
    fn test_dataset_with_target_names() {
        let class_names = vec![String::from("class_a"), String::from("class_b")];

        let dataset = Dataset::new(Array2::<f64>::zeros((5, 2)), Array1::<f64>::zeros(5))
            .with_target_names(class_names);

        assert!(dataset.target_names.is_some());

        let stored_names = dataset
            .target_names
            .as_ref()
            .expect("value should be present");
        assert_eq!(stored_names.len(), 2);
    }
}