use crate::utils::serialization;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A tabular dataset: a 2-D `f64` feature matrix plus optional targets,
/// names, a description, and string key/value metadata.
///
/// The accessor methods on this type treat the rows of `data` as samples and
/// its columns as features.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dataset {
/// Feature matrix. (De)serialized through the crate's
/// `serialization::serialize_array2` / `serialization::deserialize_array2`
/// helpers rather than the default serde implementation.
#[serde(
serialize_with = "serialization::serialize_array2",
deserialize_with = "serialization::deserialize_array2"
)]
pub data: Array2<f64>,
/// Optional target vector; left out of the serialized form when `None`.
// NOTE(review): presumably one entry per row of `data`; nothing here enforces it.
#[serde(skip_serializing_if = "Option::is_none")]
pub target: Option<Array1<f64>>,
/// Optional target names; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub targetnames: Option<Vec<String>>,
/// Optional feature names; left out of the serialized form when `None`.
// NOTE(review): presumably one name per column of `data`; not validated here.
#[serde(skip_serializing_if = "Option::is_none")]
pub featurenames: Option<Vec<String>>,
/// Optional per-feature descriptions; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub feature_descriptions: Option<Vec<String>>,
/// Optional free-text description of the dataset; left out of the
/// serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Arbitrary string key/value metadata attached to the dataset.
pub metadata: HashMap<String, String>,
}
impl Dataset {
    /// Creates a dataset from a feature matrix and an optional target vector.
    ///
    /// Every name/description field starts as `None` and the metadata map
    /// starts empty.
    pub fn new(data: Array2<f64>, target: Option<Array1<f64>>) -> Self {
        Self {
            data,
            target,
            targetnames: None,
            featurenames: None,
            feature_descriptions: None,
            description: None,
            metadata: HashMap::new(),
        }
    }

    /// Creates a dataset whose descriptive fields are taken from a registry
    /// metadata record.
    ///
    /// The record's `name`, `task_type`, `n_samples` and `n_features` are
    /// stored as strings in the metadata map under those same keys; its
    /// `targetnames` and `description` populate the corresponding fields.
    /// Feature names and feature descriptions are left unset.
    pub fn from_metadata(
        data: Array2<f64>,
        target: Option<Array1<f64>>,
        metadata: crate::registry::DatasetMetadata,
    ) -> Self {
        let info = HashMap::from([
            (String::from("name"), metadata.name),
            (String::from("task_type"), metadata.task_type),
            (String::from("n_samples"), metadata.n_samples.to_string()),
            (String::from("n_features"), metadata.n_features.to_string()),
        ]);

        Self {
            data,
            target,
            targetnames: metadata.targetnames,
            featurenames: None,
            feature_descriptions: None,
            description: Some(metadata.description),
            metadata: info,
        }
    }

    /// Builder-style setter: returns the dataset with `targetnames` set.
    pub fn with_targetnames(self, targetnames: Vec<String>) -> Self {
        Self {
            targetnames: Some(targetnames),
            ..self
        }
    }

    /// Builder-style setter: returns the dataset with `featurenames` set.
    pub fn with_featurenames(self, featurenames: Vec<String>) -> Self {
        Self {
            featurenames: Some(featurenames),
            ..self
        }
    }

    /// Builder-style setter: returns the dataset with `feature_descriptions` set.
    pub fn with_feature_descriptions(self, featuredescriptions: Vec<String>) -> Self {
        Self {
            feature_descriptions: Some(featuredescriptions),
            ..self
        }
    }

    /// Builder-style setter: returns the dataset with `description` set.
    pub fn with_description(self, description: String) -> Self {
        Self {
            description: Some(description),
            ..self
        }
    }

    /// Builder-style setter: returns the dataset with `key` mapped to `value`
    /// in the metadata map, replacing any previous value for that key.
    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
        self.set_metadata(key, value);
        self
    }

    /// Number of samples, i.e. the number of rows of `data`.
    pub fn n_samples(&self) -> usize {
        self.data.nrows()
    }

    /// Number of features, i.e. the number of columns of `data`.
    pub fn n_features(&self) -> usize {
        self.data.ncols()
    }

    /// Dimensions of `data` as `(n_samples, n_features)`.
    pub fn shape(&self) -> (usize, usize) {
        self.data.dim()
    }

    /// Returns `true` when a target vector is present.
    pub fn has_target(&self) -> bool {
        self.target.is_some()
    }

    /// Borrows the feature names, if any were set.
    pub fn featurenames(&self) -> Option<&Vec<String>> {
        self.featurenames.as_ref()
    }

    /// Borrows the target names, if any were set.
    pub fn targetnames(&self) -> Option<&Vec<String>> {
        self.targetnames.as_ref()
    }

    /// Borrows the description, if one was set.
    pub fn description(&self) -> Option<&String> {
        self.description.as_ref()
    }

    /// Borrows the whole metadata map.
    pub fn metadata(&self) -> &HashMap<String, String> {
        &self.metadata
    }

    /// Stores `value` under `key` in the metadata map, replacing any previous
    /// value for that key.
    pub fn set_metadata(&mut self, key: &str, value: &str) {
        self.metadata.insert(key.to_owned(), value.to_owned());
    }

    /// Looks up a metadata value; `None` when `key` is absent.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Construction keeps both arrays intact and reports matching dimensions.
    #[test]
    fn test_dataset_creation() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = Some(array![0.0, 1.0, 0.0]);

        let dataset = Dataset::new(features.clone(), labels.clone());

        assert_eq!(dataset.n_samples(), 3);
        assert_eq!(dataset.n_features(), 2);
        assert_eq!(dataset.shape(), (3, 2));
        assert!(dataset.has_target());
        assert_eq!(dataset.data, features);
        assert_eq!(dataset.target, labels);
    }

    /// Chained `with_*` calls each record their value on the dataset.
    #[test]
    fn test_dataset_builder_pattern() {
        let names = vec!["feat1".to_string(), "feat2".to_string()];

        let dataset = Dataset::new(array![[1.0, 2.0], [3.0, 4.0]], None)
            .with_featurenames(names)
            .with_description("Test dataset".to_string())
            .with_metadata("version", "1.0")
            .with_metadata("author", "test");

        let stored_names = dataset.featurenames().expect("Operation failed");
        assert_eq!(stored_names.len(), 2);

        let stored_description = dataset.description().expect("Operation failed");
        assert_eq!(stored_description, "Test dataset");

        for (key, expected) in [("version", "1.0"), ("author", "test")] {
            assert_eq!(
                dataset.get_metadata(key).expect("Operation failed"),
                expected
            );
        }
    }

    /// A dataset built with `None` as target reports that it has no target.
    #[test]
    fn test_dataset_without_target() {
        let dataset = Dataset::new(array![[1.0, 2.0], [3.0, 4.0]], None);

        assert!(!dataset.has_target());
        assert!(dataset.target.is_none());
    }

    /// Metadata can be inserted, read back, found missing, and overwritten.
    #[test]
    fn test_metadata_operations() {
        let mut dataset = Dataset::new(array![[1.0, 2.0]], None);

        let initial = [("key1", "value1"), ("key2", "value2")];
        for (key, value) in initial {
            dataset.set_metadata(key, value);
        }
        for (key, value) in initial {
            assert_eq!(
                dataset.get_metadata(key).expect("Operation failed"),
                value
            );
        }

        assert!(dataset.get_metadata("nonexistent").is_none());

        // Writing to an existing key replaces its value.
        dataset.set_metadata("key1", "updated_value");
        assert_eq!(
            dataset.get_metadata("key1").expect("Operation failed"),
            "updated_value"
        );
    }
}