ghostflow_data/dataset.rs

//! Dataset trait and implementations

use ghostflow_core::Tensor;

/// Base trait for datasets
pub trait Dataset: Send + Sync {
    /// Get a single (data, target) item by index
    fn get(&self, index: usize) -> (Tensor, Tensor);

    /// Get the total number of items
    fn len(&self) -> usize;

    /// Check if the dataset is empty
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
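
// A minimal sketch (not part of the original module) of how a custom type
// might implement `Dataset`. The `DoublingDataset` name and values are
// illustrative; it only relies on `Tensor::from_slice` and `data_f32` as
// they are used elsewhere in this file.
#[cfg(test)]
mod dataset_trait_example {
    use super::*;

    /// Toy dataset that synthesizes (x, 2x) pairs on the fly.
    struct DoublingDataset {
        n: usize,
    }

    impl Dataset for DoublingDataset {
        fn get(&self, index: usize) -> (Tensor, Tensor) {
            let x = index as f32;
            (
                Tensor::from_slice(&[x], &[1]).unwrap(),
                Tensor::from_slice(&[2.0 * x], &[1]).unwrap(),
            )
        }

        fn len(&self) -> usize {
            self.n
        }
    }

    #[test]
    fn doubling_dataset_basics() {
        let ds = DoublingDataset { n: 3 };
        assert_eq!(ds.len(), 3);
        assert!(!ds.is_empty());
        let (x, y) = ds.get(2);
        assert_eq!(x.data_f32()[0], 2.0);
        assert_eq!(y.data_f32()[0], 4.0);
    }
}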

/// In-memory tensor dataset
pub struct TensorDataset {
    data: Tensor,
    targets: Tensor,
}

impl TensorDataset {
    pub fn new(data: Tensor, targets: Tensor) -> Self {
        assert_eq!(data.dims()[0], targets.dims()[0],
            "Data and targets must have same number of samples");
        TensorDataset { data, targets }
    }
}

impl Dataset for TensorDataset {
    fn get(&self, index: usize) -> (Tensor, Tensor) {
        // Slice out a single sample along the first (batch) dimension.
        let data_dims = self.data.dims();
        let target_dims = self.targets.dims();

        // Number of elements in one sample / one target.
        let data_slice_size: usize = data_dims[1..].iter().product();
        let target_slice_size: usize = if target_dims.len() > 1 {
            target_dims[1..].iter().product()
        } else {
            1
        };

        let data_vec = self.data.data_f32();
        let target_vec = self.targets.data_f32();

        let data_start = index * data_slice_size;
        let data_end = data_start + data_slice_size;
        let sample_data = &data_vec[data_start..data_end];

        let target_start = index * target_slice_size;
        let target_end = target_start + target_slice_size;
        let sample_target = &target_vec[target_start..target_end];

        // Per-sample shapes drop the leading batch dimension; 1-D tensors
        // fall back to a single-element shape.
        let data_shape: Vec<usize> = if data_dims.len() > 1 {
            data_dims[1..].to_vec()
        } else {
            vec![1]
        };
        let target_shape: Vec<usize> = if target_dims.len() > 1 {
            target_dims[1..].to_vec()
        } else {
            vec![1]
        };

        (
            Tensor::from_slice(sample_data, &data_shape).unwrap(),
            Tensor::from_slice(sample_target, &target_shape).unwrap(),
        )
    }

    fn len(&self) -> usize {
        self.data.dims()[0]
    }
}
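
// Illustrative usage sketch for `TensorDataset` (assumed shapes and values;
// relies only on `Tensor::from_slice`, `dims`, and `data_f32` as used above).
#[cfg(test)]
mod tensor_dataset_example {
    use super::*;

    #[test]
    fn indexes_rows_of_a_2d_tensor() {
        // 3 samples with 2 features each, plus one scalar target per sample.
        let data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_slice(&[0.0, 1.0, 0.0], &[3]).unwrap();
        let ds = TensorDataset::new(data, targets);

        assert_eq!(ds.len(), 3);
        let (x, y) = ds.get(1);
        // The per-sample shape drops the leading batch dimension.
        assert_eq!(x.dims()[0], 2);
        assert_eq!(x.data_f32()[0], 3.0);
        assert_eq!(y.data_f32()[0], 1.0);
    }
}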

/// Subset of a dataset, selected by a list of indices
pub struct Subset<D: Dataset> {
    dataset: D,
    indices: Vec<usize>,
}

impl<D: Dataset> Subset<D> {
    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
        Subset { dataset, indices }
    }
}

impl<D: Dataset> Dataset for Subset<D> {
    fn get(&self, index: usize) -> (Tensor, Tensor) {
        // Remap the subset index to an index into the underlying dataset.
        self.dataset.get(self.indices[index])
    }

    fn len(&self) -> usize {
        self.indices.len()
    }
}
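
// Illustrative usage sketch for `Subset`: selecting a hypothetical
// validation split out of a `TensorDataset` by index.
#[cfg(test)]
mod subset_example {
    use super::*;

    #[test]
    fn subset_remaps_indices() {
        let data = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], &[4, 1]).unwrap();
        let targets = Tensor::from_slice(&[0.0, 10.0, 20.0, 30.0], &[4]).unwrap();
        let ds = TensorDataset::new(data, targets);

        // Keep only samples 1 and 3, e.g. as a validation split.
        let val = Subset::new(ds, vec![1, 3]);
        assert_eq!(val.len(), 2);

        // Index 1 of the subset maps to index 3 of the underlying dataset.
        let (_, y) = val.get(1);
        assert_eq!(y.data_f32()[0], 30.0);
    }
}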

/// Concatenation of multiple datasets
pub struct ConcatDataset<D: Dataset> {
    datasets: Vec<D>,
    cumulative_sizes: Vec<usize>,
}

impl<D: Dataset> ConcatDataset<D> {
    pub fn new(datasets: Vec<D>) -> Self {
        // Precompute running totals so a global index can be mapped to a dataset.
        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
        let mut total = 0;

        for ds in &datasets {
            total += ds.len();
            cumulative_sizes.push(total);
        }

        ConcatDataset {
            datasets,
            cumulative_sizes,
        }
    }
}

impl<D: Dataset> Dataset for ConcatDataset<D> {
    fn get(&self, index: usize) -> (Tensor, Tensor) {
        // Find the first dataset whose cumulative size exceeds the index,
        // then translate the global index into a local one.
        let mut dataset_idx = 0;
        let mut sample_idx = index;

        for (i, &size) in self.cumulative_sizes.iter().enumerate() {
            if index < size {
                dataset_idx = i;
                if i > 0 {
                    sample_idx = index - self.cumulative_sizes[i - 1];
                }
                break;
            }
        }

        self.datasets[dataset_idx].get(sample_idx)
    }

    fn len(&self) -> usize {
        *self.cumulative_sizes.last().unwrap_or(&0)
    }
}
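
// Illustrative usage sketch for `ConcatDataset`: two small `TensorDataset`s
// concatenated, with an index that falls in the second one. The `make`
// helper is hypothetical and exists only for this example.
#[cfg(test)]
mod concat_dataset_example {
    use super::*;

    fn make(values: &[f32]) -> TensorDataset {
        let n = values.len();
        let data = Tensor::from_slice(values, &[n, 1]).unwrap();
        let targets = Tensor::from_slice(values, &[n]).unwrap();
        TensorDataset::new(data, targets)
    }

    #[test]
    fn indexes_across_dataset_boundaries() {
        let concat = ConcatDataset::new(vec![make(&[1.0, 2.0]), make(&[3.0, 4.0, 5.0])]);
        assert_eq!(concat.len(), 5);

        // Global index 3 is local index 1 of the second dataset.
        let (x, _) = concat.get(3);
        assert_eq!(x.data_f32()[0], 4.0);
    }
}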