radiate_gp/regression/
data.rs1use radiate_core::random_provider;
2
/// One sample of a data set: a vector of input values together with a
/// vector of output values.
#[derive(Debug, Clone, Default)]
pub struct Row<T> {
    input: Vec<T>,
    output: Vec<T>,
}

impl<T> Row<T> {
    /// Creates a row that owns the given `input` and `output` vectors.
    pub fn new(input: Vec<T>, output: Vec<T>) -> Self {
        Self { input, output }
    }

    /// Borrows the input values of this row as a slice.
    pub fn input(&self) -> &[T] {
        self.input.as_slice()
    }

    /// Borrows the output values of this row as a slice.
    pub fn output(&self) -> &[T] {
        self.output.as_slice()
    }
}

impl<T> From<(Vec<T>, Vec<T>)> for Row<T> {
    /// Converts an `(input, output)` pair into a [`Row`].
    fn from(data: (Vec<T>, Vec<T>)) -> Self {
        let (input, output) = data;
        Self::new(input, output)
    }
}
28
29#[derive(Default, Clone)]
30pub struct DataSet<T> {
31 rows: Vec<Row<T>>,
32}
33
34impl<T> DataSet<T> {
35 pub fn new(inputs: Vec<Vec<T>>, outputs: Vec<Vec<T>>) -> Self {
36 let mut samples = Vec::new();
37 for (input, output) in inputs.into_iter().zip(outputs.into_iter()) {
38 samples.push(Row { input, output });
39 }
40
41 DataSet { rows: samples }
42 }
43
44 pub fn row(mut self, row: impl Into<Row<T>>) -> Self {
45 self.rows.push(row.into());
46 self
47 }
48
49 pub fn iter(&self) -> std::slice::Iter<'_, Row<T>> {
50 self.rows.iter()
51 }
52
53 pub fn len(&self) -> usize {
54 self.rows.len()
55 }
56
57 pub fn shuffle(mut self) -> Self {
58 random_provider::shuffle(&mut self.rows);
59 self
60 }
61
62 pub fn shape(&self) -> (usize, usize, usize) {
63 let num_samples = self.rows.len();
64 let input_dim = if num_samples > 0 {
65 self.rows[0].input.len()
66 } else {
67 0
68 };
69 let output_dim = if num_samples > 0 {
70 self.rows[0].output.len()
71 } else {
72 0
73 };
74
75 (num_samples, input_dim, output_dim)
76 }
77
78 #[inline]
79 pub fn features(&self) -> Vec<Vec<T>>
80 where
81 T: Clone,
82 {
83 self.rows.iter().map(|row| row.input.clone()).collect()
84 }
85
86 #[inline]
87 pub fn labels(&self) -> Vec<Vec<T>>
88 where
89 T: Clone,
90 {
91 self.rows.iter().map(|row| row.output.clone()).collect()
92 }
93
94 #[inline]
95 pub fn split(self, ratio: f32) -> (Self, Self)
96 where
97 T: Clone,
98 {
99 let ratio = ratio.clamp(0.0, 1.0);
100 let split = (self.len() as f32 * ratio).round() as usize;
101 let (left, right) = self.rows.split_at(split);
102
103 (
104 DataSet {
105 rows: left.to_vec(),
106 },
107 DataSet {
108 rows: right.to_vec(),
109 },
110 )
111 }
112}
113
114impl DataSet<f32> {
115 pub fn standardize(mut self) -> Self {
116 let mut means = vec![0.0; self.rows[0].input.len()];
117 let mut stds = vec![0.0; self.rows[0].input.len()];
118
119 for sample in self.rows.iter() {
120 for (i, &val) in sample.input.iter().enumerate() {
121 means[i] += val;
122 }
123 }
124
125 let n = self.len() as f32;
126 for mean in means.iter_mut() {
127 *mean /= n;
128 }
129
130 for sample in self.rows.iter() {
131 for (i, &val) in sample.input.iter().enumerate() {
132 stds[i] += (val - means[i]).powi(2);
133 }
134 }
135
136 for std in stds.iter_mut() {
137 *std = (*std / n).sqrt();
138 }
139
140 for sample in self.rows.iter_mut() {
141 for (i, val) in sample.input.iter_mut().enumerate() {
142 *val = (*val - means[i]) / stds[i];
143 }
144 }
145
146 self
147 }
148
149 pub fn normalize(mut self) -> Self {
150 let mut mins = vec![f32::MAX; self.rows[0].input.len()];
151 let mut maxs = vec![f32::MIN; self.rows[0].input.len()];
152
153 for sample in self.rows.iter() {
154 for (i, &val) in sample.input.iter().enumerate() {
155 if val < mins[i] {
156 mins[i] = val;
157 }
158
159 if val > maxs[i] {
160 maxs[i] = val;
161 }
162 }
163 }
164
165 for sample in self.rows.iter_mut() {
166 for (i, val) in sample.input.iter_mut().enumerate() {
167 *val = (*val - mins[i]) / (maxs[i] - mins[i]);
168 }
169 }
170
171 self
172 }
173}
174
175impl<T> From<Vec<Vec<Option<T>>>> for DataSet<T>
176where
177 T: Clone,
178{
179 fn from(data: Vec<Vec<Option<T>>>) -> Self {
180 let mut rows = Vec::new();
181 for row in data.into_iter() {
182 let input = row
183 .iter()
184 .filter_map(|v| v.as_ref())
185 .cloned()
186 .collect::<Vec<T>>();
187
188 rows.push(Row {
189 input,
190 output: Vec::new(),
191 });
192 }
193
194 DataSet { rows }
195 }
196}
197
198impl<T> From<(Vec<Vec<T>>, Vec<Vec<T>>)> for DataSet<T> {
199 fn from(data: (Vec<Vec<T>>, Vec<Vec<T>>)) -> Self {
200 DataSet::new(data.0, data.1)
201 }
202}