use ghostflow_core::Tensor;
/// A map-style dataset of `(input, target)` sample pairs.
///
/// Implementors provide random access to samples by index. The
/// `Send + Sync` bounds allow a dataset to be shared across
/// data-loading threads.
pub trait Dataset: Send + Sync {
    /// Returns the `(data, target)` pair at `index`.
    ///
    /// NOTE(review): implementations in this file panic on an
    /// out-of-range index — confirm that is the intended contract.
    fn get(&self, index: usize) -> (Tensor, Tensor);
    /// Number of samples in the dataset.
    fn len(&self) -> usize;
    /// True when the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A dataset backed by two in-memory tensors, where dimension 0 of each
/// tensor indexes samples (`data[i]` pairs with `targets[i]`).
pub struct TensorDataset {
    // Stacked inputs; dims()[0] is the number of samples.
    data: Tensor,
    // Stacked targets; dims()[0] must match `data` (checked in `new`).
    targets: Tensor,
}
impl TensorDataset {
pub fn new(data: Tensor, targets: Tensor) -> Self {
assert_eq!(data.dims()[0], targets.dims()[0],
"Data and targets must have same number of samples");
TensorDataset { data, targets }
}
}
/// Copies the `index`-th leading-dimension slice out of `tensor` and
/// rebuilds it as a standalone tensor with the per-sample shape.
///
/// A 1-D tensor yields single-element samples of shape `[1]`. (The
/// original code only applied this rule to targets; 1-D data produced
/// an empty shape `[]`, inconsistent with the target path and liable
/// to make `from_slice` fail — unified here.)
fn sample_slice(tensor: &Tensor, index: usize) -> Tensor {
    let dims = tensor.dims();
    // Per-sample shape is everything after the leading (sample) axis.
    let shape: Vec<usize> = if dims.len() > 1 {
        dims[1..].to_vec()
    } else {
        vec![1]
    };
    // Flat length of one sample; empty product is 1, so this is also
    // correct for shape [1].
    let size: usize = shape.iter().product();
    let flat = tensor.data_f32();
    let start = index * size;
    // Out-of-range `index` panics here via slice bounds checking.
    let sample = &flat[start..start + size];
    Tensor::from_slice(sample, &shape).unwrap()
}
impl Dataset for TensorDataset {
    /// Extracts sample `index` by copying the corresponding slice out of
    /// each backing tensor.
    ///
    /// # Panics
    /// Panics when `index` is out of range or when `Tensor::from_slice`
    /// fails.
    fn get(&self, index: usize) -> (Tensor, Tensor) {
        (
            sample_slice(&self.data, index),
            sample_slice(&self.targets, index),
        )
    }
    fn len(&self) -> usize {
        // Dimension 0 is the sample axis by construction.
        self.data.dims()[0]
    }
}
/// A view over a subset of another dataset, selected by explicit indices.
pub struct Subset<D: Dataset> {
    // The underlying dataset being viewed.
    dataset: D,
    // Position i of this subset maps to `indices[i]` of `dataset`.
    indices: Vec<usize>,
}
impl<D: Dataset> Subset<D> {
    /// Wraps `dataset`, exposing only the samples named by `indices`
    /// (in that order).
    ///
    /// Indices are not validated here; an out-of-range entry surfaces
    /// when `get` forwards it to the underlying dataset.
    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
        Self { dataset, indices }
    }
}
impl<D: Dataset> Dataset for Subset<D> {
    /// Translates the subset-local `index` through the index table,
    /// then delegates to the wrapped dataset.
    fn get(&self, index: usize) -> (Tensor, Tensor) {
        let underlying = self.indices[index];
        self.dataset.get(underlying)
    }
    /// The subset's length is the number of selected indices.
    fn len(&self) -> usize {
        self.indices.len()
    }
}
/// Concatenation of several datasets of the same type, presented as one
/// contiguous dataset.
pub struct ConcatDataset<D: Dataset> {
    // The constituent datasets, in concatenation order.
    datasets: Vec<D>,
    // cumulative_sizes[i] is the total sample count of datasets[0..=i];
    // used to map a flat index to the owning dataset.
    cumulative_sizes: Vec<usize>,
}
impl<D: Dataset> ConcatDataset<D> {
    /// Builds a concatenated view over `datasets`, precomputing the
    /// running sample totals used for index lookup.
    pub fn new(datasets: Vec<D>) -> Self {
        // Running total: cumulative_sizes[i] covers datasets[0..=i].
        let cumulative_sizes: Vec<usize> = datasets
            .iter()
            .scan(0usize, |total, ds| {
                *total += ds.len();
                Some(*total)
            })
            .collect();
        ConcatDataset {
            datasets,
            cumulative_sizes,
        }
    }
}
impl<D: Dataset> Dataset for ConcatDataset<D> {
    /// Maps the flat `index` to the owning dataset and its local index.
    ///
    /// Uses binary search (`partition_point`) over the cumulative sizes
    /// instead of the original linear scan. The original also silently
    /// forwarded an out-of-range `index` to `datasets[0]` (the loop never
    /// matched, leaving `dataset_idx = 0`); now it panics with a clear
    /// message.
    ///
    /// # Panics
    /// Panics when `index >= self.len()`.
    fn get(&self, index: usize) -> (Tensor, Tensor) {
        // First position whose cumulative size exceeds `index` is the
        // dataset owning this sample; cumulative_sizes is non-decreasing,
        // so partition_point's predicate is properly partitioned.
        let dataset_idx = self.cumulative_sizes.partition_point(|&size| size <= index);
        assert!(
            dataset_idx < self.datasets.len(),
            "index {index} out of range for ConcatDataset of len {}",
            self.len()
        );
        // Subtract the samples contributed by all earlier datasets.
        let offset = if dataset_idx > 0 {
            self.cumulative_sizes[dataset_idx - 1]
        } else {
            0
        };
        self.datasets[dataset_idx].get(index - offset)
    }
    /// Total number of samples across all constituent datasets.
    fn len(&self) -> usize {
        *self.cumulative_sizes.last().unwrap_or(&0)
    }
}