use crate::Dataset;
use std::marker::PhantomData;
use tenflowers_core::{Result, Tensor};
pub mod augmentation;
pub mod feature_engineering;
pub mod noise;
pub mod normalization;
pub mod pipeline;
pub mod profiling;
pub mod vision;
pub use noise::*;
pub use normalization::*;
/// A per-sample transformation over a pair of tensors with element type `T`.
///
/// Implementors take ownership of one sample and hand back a (possibly
/// modified) sample of the same shape of type: a two-tensor tuple.
// NOTE(review): the tuple is presumably (features, labels) — confirm against
// the `Dataset` trait definition, which is not visible in this file.
pub trait Transform<T> {
/// Transforms a single sample.
///
/// # Errors
///
/// Returns an error when the implementor cannot process `sample`; the
/// concrete failure conditions are defined by each implementation.
fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)>;
}
/// A dataset adaptor that pairs an inner dataset with a [`Transform`].
///
/// The `Dataset` implementation for this type fetches each sample from the
/// wrapped dataset and passes it through the transform before returning it.
pub struct TransformedDataset<T, D, Tr>
where
    D: Dataset<T>,
    Tr: Transform<T>,
{
    /// The wrapped dataset that samples are read from.
    dataset: D,
    /// The transform applied to every sample read from `dataset`.
    transform: Tr,
    /// `T` only appears in the bounds on `D` and `Tr`; this marker makes the
    /// type parameter count as used without storing a value of type `T`.
    _phantom: PhantomData<T>,
}
impl<T, D: Dataset<T>, Tr: Transform<T>> TransformedDataset<T, D, Tr> {
pub fn new(dataset: D, transform: Tr) -> Self {
Self {
dataset,
transform,
_phantom: PhantomData,
}
}
}
impl<T, D, Tr> Dataset<T> for TransformedDataset<T, D, Tr>
where
    D: Dataset<T>,
    Tr: Transform<T>,
{
    /// Reports the length of the wrapped dataset; the transform is not
    /// consulted, so the sample count is unchanged by wrapping.
    fn len(&self) -> usize {
        let inner = &self.dataset;
        inner.len()
    }

    /// Fetches the sample at `index` from the wrapped dataset and runs it
    /// through the transform.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped dataset's `get`; in that case
    /// the transform is not invoked. Otherwise returns whatever the
    /// transform's `apply` returns, including its error.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        self.dataset
            .get(index)
            .and_then(|fetched| self.transform.apply(fetched))
    }
}
/// Extension methods for any sized [`Dataset`], enabling method-call syntax
/// for wrapping a dataset in adaptors.
pub trait DatasetExt<T>
where
    Self: Dataset<T> + Sized,
{
    /// Consumes this dataset and returns a [`TransformedDataset`] that holds
    /// it together with `transform`.
    fn transform<Tr>(self, transform: Tr) -> TransformedDataset<T, Self, Tr>
    where
        Tr: Transform<T>,
    {
        TransformedDataset::<T, Self, Tr>::new(self, transform)
    }
}
// Blanket implementation: every sized type implementing `Dataset<T>` gets the
// `DatasetExt<T>` provided methods without any per-type impl.
impl<T, D> DatasetExt<T> for D where D: Dataset<T> {}