1use crate::data::{Dataset, Transform};
4use crate::error::{NeuralError, Result};
5use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
6use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
7use std::fmt::Debug;
8use std::marker::PhantomData;
9use std::path::Path;
10
11#[derive(Debug)]
13pub struct CSVDataset<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
14 features: Array<F, IxDyn>,
16 labels: Array<F, IxDyn>,
18 feature_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
20 label_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
22}
23
24impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Clone
25 for CSVDataset<F>
26{
27 fn clone(&self) -> Self {
28 Self {
30 features: self.features.clone(),
31 labels: self.labels.clone(),
32 feature_transform: match &self.feature_transform {
33 Some(t) => Some(t.box_clone()),
34 None => None,
35 },
36 label_transform: match &self.label_transform {
37 Some(t) => Some(t.box_clone()),
38 None => None,
39 },
40 }
41 }
42}
43
44impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> CSVDataset<F> {
45 pub fn from_csv<P: AsRef<Path>>(
47 _path: P,
48 _has_header: bool,
49 _feature_cols: &[usize],
50 _label_cols: &[usize],
51 _delimiter: char,
52 ) -> Result<Self> {
53 Err(NeuralError::InferenceError(
56 "CSV loading not yet implemented".to_string(),
57 ))
58 }
59
60 pub fn with_feature_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
62 self.feature_transform = Some(Box::new(transform));
63 self
64 }
65
66 pub fn with_label_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
68 self.label_transform = Some(Box::new(transform));
69 self
70 }
71}
72
73impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Dataset<F>
74 for CSVDataset<F>
75{
76 fn len(&self) -> usize {
77 self.features.shape()[0]
78 }
79
80 fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
81 if index >= self.len() {
82 return Err(NeuralError::InferenceError(format!(
83 "Index {} out of bounds for dataset with length {}",
84 index,
85 self.len()
86 )));
87 }
88
89 let x_slice = self.features.slice(scirs2_core::ndarray::s![index, ..]);
91 let y_slice = self.labels.slice(scirs2_core::ndarray::s![index, ..]);
92
93 let xshape = x_slice.shape().to_vec();
95 let yshape = y_slice.shape().to_vec();
96
97 let mut x = x_slice
98 .to_owned()
99 .into_shape_with_order(IxDyn(&xshape))
100 .expect("Operation failed");
101 let mut y = y_slice
102 .to_owned()
103 .into_shape_with_order(IxDyn(&yshape))
104 .expect("Operation failed");
105
106 if let Some(ref transform) = self.feature_transform {
108 x = transform.apply(&x)?;
109 }
110
111 if let Some(ref transform) = self.label_transform {
112 y = transform.apply(&y)?;
113 }
114
115 Ok((x, y))
116 }
117
118 fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
119 Box::new(self.clone())
120 }
121}
122
123#[derive(Debug)]
125pub struct TransformedDataset<
126 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
127 D: Dataset<F> + Clone,
128> {
129 dataset: D,
131 feature_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
133 label_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
135 _phantom: PhantomData<F>,
137}
138
139impl<
140 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
141 D: Dataset<F> + Clone,
142 > Clone for TransformedDataset<F, D>
143{
144 fn clone(&self) -> Self {
145 Self {
147 dataset: self.dataset.clone(),
148 feature_transform: match &self.feature_transform {
149 Some(t) => Some(t.box_clone()),
150 None => None,
151 },
152 label_transform: match &self.label_transform {
153 Some(t) => Some(t.box_clone()),
154 None => None,
155 },
156 _phantom: PhantomData,
157 }
158 }
159}
160
161impl<
162 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
163 D: Dataset<F> + Clone,
164 > TransformedDataset<F, D>
165{
166 pub fn new(dataset: D) -> Self {
168 Self {
169 dataset,
170 feature_transform: None,
171 label_transform: None,
172 _phantom: PhantomData,
173 }
174 }
175
176 pub fn with_feature_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
178 self.feature_transform = Some(Box::new(transform));
179 self
180 }
181
182 pub fn with_label_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
184 self.label_transform = Some(Box::new(transform));
185 self
186 }
187}
188
189impl<
190 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
191 D: Dataset<F> + Clone + 'static,
192 > Dataset<F> for TransformedDataset<F, D>
193{
194 fn len(&self) -> usize {
195 self.dataset.len()
196 }
197
198 fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
199 let (mut x, mut y) = self.dataset.get(index)?;
201
202 if let Some(ref transform) = self.feature_transform {
204 x = transform.apply(&x)?;
205 }
206
207 if let Some(ref transform) = self.label_transform {
208 y = transform.apply(&y)?;
209 }
210
211 Ok((x, y))
212 }
213
214 fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
215 Box::new(self.clone())
216 }
217}
218
219#[derive(Debug, Clone)]
221pub struct SubsetDataset<
222 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
223 D: Dataset<F> + Clone,
224> {
225 dataset: D,
227 indices: Vec<usize>,
229 _phantom: PhantomData<F>,
231}
232
233impl<
234 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
235 D: Dataset<F> + Clone,
236 > SubsetDataset<F, D>
237{
238 pub fn new(dataset: D, indices: Vec<usize>) -> Result<Self> {
240 for &idx in &indices {
242 if idx >= dataset.len() {
243 return Err(NeuralError::InferenceError(format!(
244 "Index {} out of bounds for dataset with length {}",
245 idx,
246 dataset.len()
247 )));
248 }
249 }
250
251 Ok(Self {
252 dataset,
253 indices,
254 _phantom: PhantomData,
255 })
256 }
257}
258
259impl<
260 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
261 D: Dataset<F> + Clone + 'static,
262 > Dataset<F> for SubsetDataset<F, D>
263{
264 fn len(&self) -> usize {
265 self.indices.len()
266 }
267
268 fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
269 if index >= self.len() {
270 return Err(NeuralError::InferenceError(format!(
271 "Index {} out of bounds for subset dataset with length {}",
272 index,
273 self.len()
274 )));
275 }
276
277 let dataset_index = self.indices[index];
278 self.dataset.get(dataset_index)
279 }
280
281 fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
282 Box::new(self.clone())
283 }
284}