scirs2_neural/data/
mod.rs1use crate::error::{NeuralError, Result};
7use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
9use scirs2_core::random::rngs::SmallRng;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use std::fmt::Debug;
13
14mod augmentation;
15mod dataloader;
16mod dataset;
17mod memory_pool;
18mod transforms;
19mod utils;
20
21pub use augmentation::*;
22pub use dataloader::*;
23pub use dataset::*;
24pub use memory_pool::*;
25pub use transforms::*;
26pub use utils::*;
27
28pub trait Dataset<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync>:
30 Send + Sync
31{
32 fn len(&self) -> usize;
34
35 fn is_empty(&self) -> bool {
37 self.len() == 0
38 }
39
40 fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
42
43 fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync>;
45}
46
47impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Dataset<F>
49 for Box<dyn Dataset<F> + Send + Sync>
50{
51 fn len(&self) -> usize {
52 (**self).len()
53 }
54
55 fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
56 (**self).get(index)
57 }
58
59 fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
60 (**self).box_clone()
61 }
62}
63
64#[derive(Debug, Clone)]
66pub struct InMemoryDataset<
67 F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
68> {
69 pub features: Array<F, IxDyn>,
71 pub labels: Array<F, IxDyn>,
73}
74
75impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync>
76 InMemoryDataset<F>
77{
78 pub fn new(features: Array<F, IxDyn>, labels: Array<F, IxDyn>) -> Result<Self> {
80 if features.shape()[0] != labels.shape()[0] {
81 return Err(NeuralError::InferenceError(format!(
82 "Features and labels have different number of samples: {} vs {}",
83 features.shape()[0],
84 labels.shape()[0]
85 )));
86 }
87
88 Ok(Self { features, labels })
89 }
90
91 pub fn train_test_split(&self, test_size: f64) -> Result<(Self, Self)> {
93 if test_size <= 0.0 || test_size >= 1.0 {
94 return Err(NeuralError::InferenceError(
95 "test_size must be between 0 and 1".to_string(),
96 ));
97 }
98
99 let n_samples = self.len();
100 let n_test = (n_samples as f64 * test_size).round() as usize;
101 let n_train = n_samples - n_test;
102
103 if n_train == 0 || n_test == 0 {
104 return Err(NeuralError::InferenceError(
105 "Split would result in empty training or test set".to_string(),
106 ));
107 }
108
109 let mut indices: Vec<usize> = (0..n_samples).collect();
111 let mut rng = SmallRng::from_rng(&mut thread_rng());
112 indices.shuffle(&mut rng);
113
114 let train_indices = &indices[0..n_train];
116 let test_indices = &indices[n_train..];
117
118 let train_features = self.features.select(Axis(0), train_indices);
120 let train_labels = self.labels.select(Axis(0), train_indices);
121
122 let test_features = self.features.select(Axis(0), test_indices);
124 let test_labels = self.labels.select(Axis(0), test_indices);
125
126 Ok((
127 Self::new(train_features, train_labels)?,
128 Self::new(test_features, test_labels)?,
129 ))
130 }
131}
132
133impl<F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync> Dataset<F>
134 for InMemoryDataset<F>
135{
136 fn len(&self) -> usize {
137 self.features.shape()[0]
138 }
139
140 fn get(&self, index: usize) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
141 if index >= self.len() {
142 return Err(NeuralError::InferenceError(format!(
143 "Index {} out of bounds for dataset with length {}",
144 index,
145 self.len()
146 )));
147 }
148
149 let x_slice = self.features.slice(scirs2_core::ndarray::s![index, ..]);
151 let y_slice = self.labels.slice(scirs2_core::ndarray::s![index, ..]);
152
153 let x_shape = x_slice.shape().to_vec();
154 let y_shape = y_slice.shape().to_vec();
155
156 let x = x_slice
157 .to_owned()
158 .into_shape_with_order(IxDyn(&x_shape))
159 .expect("Operation failed");
160 let y = y_slice
161 .to_owned()
162 .into_shape_with_order(IxDyn(&y_shape))
163 .expect("Operation failed");
164
165 Ok((x, y))
166 }
167
168 fn box_clone(&self) -> Box<dyn Dataset<F> + Send + Sync> {
169 Box::new(InMemoryDataset {
170 features: self.features.clone(),
171 labels: self.labels.clone(),
172 })
173 }
174}