Skip to main content

scirs2_neural/data/
dataset.rs

1//! Dataset implementations for different data sources
2
3use crate::data::{Dataset, Transform};
4use crate::error::{NeuralError, Result};
5use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
6use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
7use std::fmt::Debug;
8use std::marker::PhantomData;
9use std::path::Path;
10
11/// CSV dataset implementation
12#[derive(Debug)]
13pub struct CSVDataset<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
14    /// Features (inputs)
15    features: Array<F, IxDyn>,
16    /// Labels (targets)
17    labels: Array<F, IxDyn>,
18    /// Transform to apply to features
19    feature_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
20    /// Transform to apply to labels
21    label_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
22}
23
24impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Clone
25    for CSVDataset<F>
26{
27    fn clone(&self) -> Self {
28        // Manual clone implementation that uses box_clone for dyn Transform<F> + Send + Sync
29        Self {
30            features: self.features.clone(),
31            labels: self.labels.clone(),
32            feature_transform: match &self.feature_transform {
33                Some(t) => Some(t.box_clone()),
34                None => None,
35            },
36            label_transform: match &self.label_transform {
37                Some(t) => Some(t.box_clone()),
38                None => None,
39            },
40        }
41    }
42}
43
44impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> CSVDataset<F> {
45    /// Create a new dataset from CSV file
46    pub fn from_csv<P: AsRef<Path>>(
47        _path: P,
48        _has_header: bool,
49        _feature_cols: &[usize],
50        _label_cols: &[usize],
51        _delimiter: char,
52    ) -> Result<Self> {
53        // In a real implementation, we'd use a CSV reader here
54        // For now, just return an error
55        Err(NeuralError::InferenceError(
56            "CSV loading not yet implemented".to_string(),
57        ))
58    }
59
60    /// Set feature transform
61    pub fn with_feature_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
62        self.feature_transform = Some(Box::new(transform));
63        self
64    }
65
66    /// Set label transform
67    pub fn with_label_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
68        self.label_transform = Some(Box::new(transform));
69        self
70    }
71}
72
73impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Dataset<F>
74    for CSVDataset<F>
75{
76    fn len(&self) -> usize {
77        self.features.shape()[0]
78    }
79
80    fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
81        if index >= self.len() {
82            return Err(NeuralError::InferenceError(format!(
83                "Index {} out of bounds for dataset with length {}",
84                index,
85                self.len()
86            )));
87        }
88
89        // Get slices of the data and convert to owned arrays
90        let x_slice = self.features.slice(scirs2_core::ndarray::s![index, ..]);
91        let y_slice = self.labels.slice(scirs2_core::ndarray::s![index, ..]);
92
93        // Convert to dynamic dimension arrays
94        let xshape = x_slice.shape().to_vec();
95        let yshape = y_slice.shape().to_vec();
96
97        let mut x = x_slice
98            .to_owned()
99            .into_shape_with_order(IxDyn(&xshape))
100            .expect("Operation failed");
101        let mut y = y_slice
102            .to_owned()
103            .into_shape_with_order(IxDyn(&yshape))
104            .expect("Operation failed");
105
106        // Apply transforms if available
107        if let Some(ref transform) = self.feature_transform {
108            x = transform.apply(&x)?;
109        }
110
111        if let Some(ref transform) = self.label_transform {
112            y = transform.apply(&y)?;
113        }
114
115        Ok((x, y))
116    }
117
118    fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
119        Box::new(self.clone())
120    }
121}
122
123/// Transformed dataset wrapper
124#[derive(Debug)]
125pub struct TransformedDataset<
126    F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
127    D: Dataset<F> + Clone,
128> {
129    /// Base dataset
130    dataset: D,
131    /// Transform to apply to features
132    feature_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
133    /// Transform to apply to labels
134    label_transform: Option<Box<dyn Transform<F> + Send + Sync>>,
135    /// Phantom data for float type
136    _phantom: PhantomData<F>,
137}
138
139impl<
140        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
141        D: Dataset<F> + Clone,
142    > Clone for TransformedDataset<F, D>
143{
144    fn clone(&self) -> Self {
145        // Manual clone implementation that uses box_clone for dyn Transform<F> + Send + Sync
146        Self {
147            dataset: self.dataset.clone(),
148            feature_transform: match &self.feature_transform {
149                Some(t) => Some(t.box_clone()),
150                None => None,
151            },
152            label_transform: match &self.label_transform {
153                Some(t) => Some(t.box_clone()),
154                None => None,
155            },
156            _phantom: PhantomData,
157        }
158    }
159}
160
161impl<
162        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
163        D: Dataset<F> + Clone,
164    > TransformedDataset<F, D>
165{
166    /// Create a new transformed dataset
167    pub fn new(dataset: D) -> Self {
168        Self {
169            dataset,
170            feature_transform: None,
171            label_transform: None,
172            _phantom: PhantomData,
173        }
174    }
175
176    /// Set feature transform
177    pub fn with_feature_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
178        self.feature_transform = Some(Box::new(transform));
179        self
180    }
181
182    /// Set label transform
183    pub fn with_label_transform<T: Transform<F> + 'static>(mut self, transform: T) -> Self {
184        self.label_transform = Some(Box::new(transform));
185        self
186    }
187}
188
189impl<
190        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
191        D: Dataset<F> + Clone + 'static,
192    > Dataset<F> for TransformedDataset<F, D>
193{
194    fn len(&self) -> usize {
195        self.dataset.len()
196    }
197
198    fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
199        // Get the data from the underlying dataset
200        let (mut x, mut y) = self.dataset.get(index)?;
201
202        // Apply transforms if available
203        if let Some(ref transform) = self.feature_transform {
204            x = transform.apply(&x)?;
205        }
206
207        if let Some(ref transform) = self.label_transform {
208            y = transform.apply(&y)?;
209        }
210
211        Ok((x, y))
212    }
213
214    fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
215        Box::new(self.clone())
216    }
217}
218
219/// Subset dataset wrapper
220#[derive(Debug, Clone)]
221pub struct SubsetDataset<
222    F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
223    D: Dataset<F> + Clone,
224> {
225    /// Base dataset
226    dataset: D,
227    /// Indices to include in the subset
228    indices: Vec<usize>,
229    /// Phantom data for float type
230    _phantom: PhantomData<F>,
231}
232
233impl<
234        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
235        D: Dataset<F> + Clone,
236    > SubsetDataset<F, D>
237{
238    /// Create a new subset dataset
239    pub fn new(dataset: D, indices: Vec<usize>) -> Result<Self> {
240        // Validate indices
241        for &idx in &indices {
242            if idx >= dataset.len() {
243                return Err(NeuralError::InferenceError(format!(
244                    "Index {} out of bounds for dataset with length {}",
245                    idx,
246                    dataset.len()
247                )));
248            }
249        }
250
251        Ok(Self {
252            dataset,
253            indices,
254            _phantom: PhantomData,
255        })
256    }
257}
258
259impl<
260        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
261        D: Dataset<F> + Clone + 'static,
262    > Dataset<F> for SubsetDataset<F, D>
263{
264    fn len(&self) -> usize {
265        self.indices.len()
266    }
267
268    fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
269        if index >= self.len() {
270            return Err(NeuralError::InferenceError(format!(
271                "Index {} out of bounds for subset dataset with length {}",
272                index,
273                self.len()
274            )));
275        }
276
277        let dataset_index = self.indices[index];
278        self.dataset.get(dataset_index)
279    }
280
281    fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
282        Box::new(self.clone())
283    }
284}