use ndarray::Array2;
use num_traits::Float;
/// A collection of samples in which every sample has the same number of features.
///
/// Implementors report their shape and whether their storage is sparse.
/// No method is generic, so the trait can be used as a trait object
/// (`&dyn Dataset`).
pub trait Dataset {
    /// How many samples the dataset holds.
    fn n_samples(&self) -> usize;

    /// How many features each sample carries.
    fn n_features(&self) -> usize;

    /// `true` when the implementor uses a sparse representation.
    fn is_sparse(&self) -> bool;
}
/// A dense `Array2` is a dataset laid out with one sample per row (axis 0)
/// and one feature per column (axis 1).
// NOTE(review): the method bodies do not need the `Float + Send + Sync + 'static`
// bounds; presumably they restrict datasets to thread-shareable float data.
// Confirm the intent with callers before relaxing them.
impl<F: Float + Send + Sync + 'static> Dataset for Array2<F> {
    /// Length of axis 0, i.e. the row count.
    fn n_samples(&self) -> usize {
        let (rows, _) = self.dim();
        rows
    }

    /// Length of axis 1, i.e. the column count.
    fn n_features(&self) -> usize {
        let (_, cols) = self.dim();
        cols
    }

    /// Always `false`: an `Array2` is treated as dense storage.
    fn is_sparse(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the full `Dataset` view of a dense array: both shape methods and
    /// the density flag. `?Sized` lets it accept `&dyn Dataset`, so calls made
    /// through this helper can exercise dynamic dispatch.
    fn assert_dense_shape<D: Dataset + ?Sized>(data: &D, samples: usize, features: usize) {
        assert_eq!(data.n_samples(), samples);
        assert_eq!(data.n_features(), features);
        assert!(!data.is_sparse());
    }

    #[test]
    fn test_array2_f64_dataset() {
        let data = Array2::<f64>::zeros((50, 12));
        assert_dense_shape(&data, 50, 12);
    }

    #[test]
    fn test_array2_f32_dataset() {
        let data = Array2::<f32>::zeros((200, 5));
        assert_dense_shape(&data, 200, 5);
    }

    #[test]
    fn test_empty_array_dataset() {
        let data = Array2::<f64>::zeros((0, 0));
        assert_dense_shape(&data, 0, 0);
    }

    #[test]
    fn test_empty_samples_keep_feature_count() {
        // A (0, 0) array cannot reveal swapped axes; zero rows with a non-zero
        // column count can.
        let data = Array2::<f64>::zeros((0, 7));
        assert_dense_shape(&data, 0, 7);
    }

    #[test]
    fn test_single_sample_dataset() {
        let data = Array2::<f64>::zeros((1, 100));
        assert_dense_shape(&data, 1, 100);
    }

    #[test]
    fn test_dataset_trait_is_object_safe() {
        let data = Array2::<f64>::zeros((10, 3));
        let dyn_data: &dyn Dataset = &data;
        // Call the methods through the trait object, not just coerce to it,
        // so the vtable is actually exercised.
        assert_dense_shape(dyn_data, 10, 3);
    }
}