use crate::data::{Dataset, Transform};
use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::path::Path;
/// In-memory dataset of feature and label arrays whose first axis indexes samples.
///
/// Named for CSV input, although `CSVDataset::from_csv` is currently a stub that
/// always returns an error. When transforms are configured, `Dataset::get` applies
/// them to each extracted feature/label sample before returning it.
#[derive(Debug)]
pub struct CSVDataset<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
/// Feature array; the size of its first axis is the dataset length.
features: Array<F, IxDyn>,
/// Label array, indexed along its first axis by the same sample index as `features`.
labels: Array<F, IxDyn>,
/// Optional transform applied to every extracted feature sample.
feature_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
/// Optional transform applied to every extracted label sample.
label_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
}
/// Deep-clones the feature/label arrays and duplicates any configured transforms.
///
/// `Clone` is implemented by hand because the boxed trait objects are not `Clone`;
/// each transform is copied through its `box_clone` method instead.
impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Clone
    for CSVDataset<F>
{
    fn clone(&self) -> Self {
        Self {
            features: self.features.clone(),
            labels: self.labels.clone(),
            feature_transform: self.feature_transform.as_ref().map(|t| t.box_clone()),
            label_transform: self.label_transform.as_ref().map(|t| t.box_clone()),
        }
    }
}
impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> CSVDataset<F> {
    /// Intended to load a dataset from a delimited text file.
    ///
    /// # Errors
    ///
    /// Loading is not implemented yet: every argument is ignored and this always
    /// returns `NeuralError::InferenceError`.
    pub fn from_csv<P: AsRef<Path>>(
        _path: P,
        _has_header: bool,
        _feature_cols: &[usize],
        _label_cols: &[usize],
        _delimiter: char,
    ) -> Result<Self> {
        let reason = String::from("CSV loading not yet implemented");
        Err(NeuralError::InferenceError(reason))
    }

    /// Builder-style setter: returns the dataset with `transform` installed as the
    /// feature transform, replacing any previously configured one.
    pub fn with_feature_transform<T: Transform<F> + 'static>(self, transform: T) -> Self {
        Self {
            feature_transform: Some(Box::new(transform)),
            ..self
        }
    }

    /// Builder-style setter: returns the dataset with `transform` installed as the
    /// label transform, replacing any previously configured one.
    pub fn with_label_transform<T: Transform<F> + 'static>(self, transform: T) -> Self {
        Self {
            label_transform: Some(Box::new(transform)),
            ..self
        }
    }
}
impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Dataset<F>
    for CSVDataset<F>
{
    /// Number of samples: the size of the first axis of `features`.
    ///
    /// A zero-dimensional feature array reports 0 instead of panicking on `shape()[0]`.
    fn len(&self) -> usize {
        self.features.shape().first().copied().unwrap_or(0)
    }

    /// Returns sample `index` as `(features, labels)`, after applying any configured
    /// feature/label transforms.
    ///
    /// Samples are taken with `index_axis(Axis(0), index)`, which works for arrays of any
    /// rank >= 1. A 2-D `(samples, cols)` array yields a 1-D row, and a 1-D label vector
    /// yields a 0-D array. `slice(s![index, ..])` would panic unless the `IxDyn` array is
    /// exactly 2-D.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::InferenceError` if `index` is not a valid row of the feature
    /// array or of the label array. Also propagates any error returned by a transform.
    fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
        if index >= self.len() {
            return Err(NeuralError::InferenceError(format!(
                "Index {} out of bounds for dataset with length {}",
                index,
                self.len()
            )));
        }
        // Labels are stored separately; if they have fewer rows than the features,
        // report an error rather than panicking inside `index_axis`.
        let label_rows = self.labels.shape().first().copied().unwrap_or(0);
        if index >= label_rows {
            return Err(NeuralError::InferenceError(format!(
                "Index {} out of bounds for labels with length {}",
                index, label_rows
            )));
        }

        let sample_axis = scirs2_core::ndarray::Axis(0);
        // For an IxDyn array the view returned by `index_axis` is already IxDyn,
        // so no reshape (and no panicking `expect`) is needed.
        let mut x = self.features.index_axis(sample_axis, index).to_owned();
        let mut y = self.labels.index_axis(sample_axis, index).to_owned();

        if let Some(ref transform) = self.feature_transform {
            x = transform.apply(&x)?;
        }
        if let Some(ref transform) = self.label_transform {
            y = transform.apply(&y)?;
        }
        Ok((x, y))
    }

    fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
        Box::new(self.clone())
    }
}
/// Wraps another dataset `D` and applies optional feature/label transforms to each
/// sample that `Dataset::get` returns from it.
#[derive(Debug)]
pub struct TransformedDataset<
F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
D: Dataset<F> + Clone,
> {
/// The wrapped dataset that samples are read from.
dataset: D,
/// Optional transform applied to each sample's features.
feature_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
/// Optional transform applied to each sample's labels.
label_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
/// Zero-sized marker for the element type `F`.
_phantom: PhantomData<F>,
}
/// Clones the wrapped dataset and duplicates any configured transforms.
///
/// `Clone` is implemented by hand because the boxed trait objects are not `Clone`;
/// each transform is copied through its `box_clone` method instead.
impl<
        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dataset<F> + Clone,
    > Clone for TransformedDataset<F, D>
{
    fn clone(&self) -> Self {
        Self {
            dataset: self.dataset.clone(),
            feature_transform: self.feature_transform.as_ref().map(|t| t.box_clone()),
            label_transform: self.label_transform.as_ref().map(|t| t.box_clone()),
            _phantom: PhantomData,
        }
    }
}
impl<
        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dataset<F> + Clone,
    > TransformedDataset<F, D>
{
    /// Wraps `dataset` with no transforms configured, so samples pass through
    /// unchanged until a transform is added with one of the `with_*` builders.
    pub fn new(dataset: D) -> Self {
        let (feature_transform, label_transform) = (None, None);
        Self {
            dataset,
            feature_transform,
            label_transform,
            _phantom: PhantomData,
        }
    }

    /// Builder-style setter: returns the wrapper with `transform` installed as the
    /// feature transform, replacing any previously configured one.
    pub fn with_feature_transform<T: Transform<F> + 'static>(self, transform: T) -> Self {
        Self {
            feature_transform: Some(Box::new(transform)),
            ..self
        }
    }

    /// Builder-style setter: returns the wrapper with `transform` installed as the
    /// label transform, replacing any previously configured one.
    pub fn with_label_transform<T: Transform<F> + 'static>(self, transform: T) -> Self {
        Self {
            label_transform: Some(Box::new(transform)),
            ..self
        }
    }
}
impl<
        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dataset<F> + Clone + 'static,
    > Dataset<F> for TransformedDataset<F, D>
{
    /// Same length as the wrapped dataset; wrapping never adds or removes samples.
    fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Fetches sample `index` from the wrapped dataset. The feature transform is applied
    /// first, then the label transform. The first error from the inner dataset or from
    /// either transform is returned.
    fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
        let (features, labels) = self.dataset.get(index)?;

        let features = match &self.feature_transform {
            Some(transform) => transform.apply(&features)?,
            None => features,
        };
        let labels = match &self.label_transform {
            Some(transform) => transform.apply(&labels)?,
            None => labels,
        };

        Ok((features, labels))
    }

    fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
        let copy: Self = Clone::clone(self);
        Box::new(copy)
    }
}
/// A view of dataset `D` restricted to a list of indices: sample `i` of the subset is
/// sample `indices[i]` of the wrapped dataset.
///
/// Indices may appear in any order or repeat; `SubsetDataset::new` only checks that
/// each one is in bounds for the wrapped dataset.
#[derive(Debug, Clone)]
pub struct SubsetDataset<
F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
D: Dataset<F> + Clone,
> {
/// The wrapped dataset that samples are read from.
dataset: D,
/// Positions in `dataset` that make up this subset, in subset order.
indices: Vec<usize>,
/// Zero-sized marker for the element type `F`.
_phantom: PhantomData<F>,
}
impl<
        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dataset<F> + Clone,
    > SubsetDataset<F, D>
{
    /// Builds a subset of `dataset` made of the samples at `indices`.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::InferenceError` naming the first index that is not below
    /// `dataset.len()`.
    pub fn new(dataset: D, indices: Vec<usize>) -> Result<Self> {
        let source_len = dataset.len();
        let first_invalid = indices
            .iter()
            .copied()
            .find(|&candidate| candidate >= source_len);

        if let Some(bad) = first_invalid {
            return Err(NeuralError::InferenceError(format!(
                "Index {} out of bounds for dataset with length {}",
                bad, source_len
            )));
        }

        Ok(Self {
            dataset,
            indices,
            _phantom: PhantomData,
        })
    }
}
impl<
        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dataset<F> + Clone + 'static,
    > Dataset<F> for SubsetDataset<F, D>
{
    /// Number of selected indices (duplicates count individually).
    fn len(&self) -> usize {
        self.indices.len()
    }

    /// Returns the wrapped dataset's sample at `indices[index]`.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::InferenceError` when `index` is not below the subset length.
    /// Errors from the wrapped dataset are propagated unchanged.
    fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
        let source_index = self.indices.get(index).copied().ok_or_else(|| {
            NeuralError::InferenceError(format!(
                "Index {} out of bounds for subset dataset with length {}",
                index,
                self.len()
            ))
        })?;
        self.dataset.get(source_index)
    }

    fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
        let copy: Self = Clone::clone(self);
        Box::new(copy)
    }
}