1use ghostflow_core::Tensor;
4use rand::prelude::*;
5
/// K-fold cross-validation splitter: partitions sample indices into
/// `n_splits` consecutive folds, optionally shuffling the indices first.
pub struct KFold {
    /// Number of folds to produce.
    pub n_splits: usize,
    /// Whether to shuffle the indices before splitting.
    pub shuffle: bool,
    /// Seed for the shuffle RNG; `None` draws entropy from the OS.
    pub random_state: Option<u64>,
}
12
13impl KFold {
14 pub fn new(n_splits: usize) -> Self {
15 KFold {
16 n_splits,
17 shuffle: false,
18 random_state: None,
19 }
20 }
21
22 pub fn shuffle(mut self, shuffle: bool) -> Self {
23 self.shuffle = shuffle;
24 self
25 }
26
27 pub fn random_state(mut self, seed: u64) -> Self {
28 self.random_state = Some(seed);
29 self
30 }
31
32 pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
33 let mut indices: Vec<usize> = (0..n_samples).collect();
34
35 if self.shuffle {
36 let mut rng = match self.random_state {
37 Some(seed) => StdRng::seed_from_u64(seed),
38 None => StdRng::from_entropy(),
39 };
40 indices.shuffle(&mut rng);
41 }
42
43 let fold_size = n_samples / self.n_splits;
44 let remainder = n_samples % self.n_splits;
45
46 let mut folds = Vec::with_capacity(self.n_splits);
47 let mut start = 0;
48
49 for i in 0..self.n_splits {
50 let extra = if i < remainder { 1 } else { 0 };
51 let end = start + fold_size + extra;
52
53 let test_indices: Vec<usize> = indices[start..end].to_vec();
54 let train_indices: Vec<usize> = indices[..start].iter()
55 .chain(indices[end..].iter())
56 .cloned()
57 .collect();
58
59 folds.push((train_indices, test_indices));
60 start = end;
61 }
62
63 folds
64 }
65}
66
/// Stratified K-fold splitter: preserves the per-class proportion of
/// samples in each fold by splitting every class's indices separately.
pub struct StratifiedKFold {
    /// Number of folds to produce.
    pub n_splits: usize,
    /// Whether to shuffle each class's indices before splitting.
    pub shuffle: bool,
    /// Seed for the shuffle RNG; `None` draws entropy from the OS.
    pub random_state: Option<u64>,
}
73
74impl StratifiedKFold {
75 pub fn new(n_splits: usize) -> Self {
76 StratifiedKFold {
77 n_splits,
78 shuffle: false,
79 random_state: None,
80 }
81 }
82
83 pub fn shuffle(mut self, shuffle: bool) -> Self {
84 self.shuffle = shuffle;
85 self
86 }
87
88 pub fn split(&self, y: &Tensor) -> Vec<(Vec<usize>, Vec<usize>)> {
89 let y_data = y.data_f32();
90 let _n_samples = y_data.len();
91
92 let n_classes = y_data.iter().map(|&v| v as usize).max().unwrap_or(0) + 1;
94 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
95
96 for (i, &label) in y_data.iter().enumerate() {
97 class_indices[label as usize].push(i);
98 }
99
100 if self.shuffle {
102 let mut rng = match self.random_state {
103 Some(seed) => StdRng::seed_from_u64(seed),
104 None => StdRng::from_entropy(),
105 };
106 for indices in &mut class_indices {
107 indices.shuffle(&mut rng);
108 }
109 }
110
111 let mut folds: Vec<(Vec<usize>, Vec<usize>)> = (0..self.n_splits)
113 .map(|_| (Vec::new(), Vec::new()))
114 .collect();
115
116 for class_idx in &class_indices {
117 let n_class = class_idx.len();
118 let fold_size = n_class / self.n_splits;
119 let remainder = n_class % self.n_splits;
120
121 let mut start = 0;
122 for i in 0..self.n_splits {
123 let extra = if i < remainder { 1 } else { 0 };
124 let end = start + fold_size + extra;
125
126 folds[i].1.extend(&class_idx[start..end]);
128
129 for j in 0..self.n_splits {
131 if j != i {
132 folds[j].0.extend(&class_idx[start..end]);
133 }
134 }
135
136 start = end;
137 }
138 }
139
140 folds
141 }
142}
143
/// Leave-one-out cross-validation: each sample serves as the test set once.
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Creates a new leave-one-out splitter.
    pub fn new() -> Self {
        LeaveOneOut
    }

    /// Returns `n_samples` folds; fold `i` tests on sample `i` and trains
    /// on every other sample.
    pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
        let mut folds = Vec::with_capacity(n_samples);
        for held_out in 0..n_samples {
            let mut train = Vec::with_capacity(n_samples.saturating_sub(1));
            for idx in 0..n_samples {
                if idx != held_out {
                    train.push(idx);
                }
            }
            folds.push((train, vec![held_out]));
        }
        folds
    }
}

impl Default for LeaveOneOut {
    fn default() -> Self {
        Self::new()
    }
}
168
/// Time-series cross-validator: each fold trains on a window of samples
/// that strictly precedes its test window, optionally separated by a gap.
pub struct TimeSeriesSplit {
    /// Number of `(train, test)` splits to attempt.
    pub n_splits: usize,
    /// Optional cap on the training-window length.
    pub max_train_size: Option<usize>,
    /// Test-window length; defaults to `n_samples / (n_splits + 1)`.
    pub test_size: Option<usize>,
    /// Number of samples skipped between the train and test windows.
    pub gap: usize,
}

impl TimeSeriesSplit {
    /// Creates a splitter producing `n_splits` folds with no gap, no
    /// train-size cap, and an automatically derived test size.
    pub fn new(n_splits: usize) -> Self {
        TimeSeriesSplit {
            n_splits,
            max_train_size: None,
            test_size: None,
            gap: 0,
        }
    }

    /// Builder: leaves `gap` samples between the train and test windows.
    pub fn gap(mut self, gap: usize) -> Self {
        self.gap = gap;
        self
    }

    /// Splits `0..n_samples` into ordered `(train, test)` index pairs.
    ///
    /// Folds whose windows cannot fit — too few samples for the early test
    /// windows, or a gap that swallows the entire history — are skipped,
    /// so the result may contain fewer than `n_splits` folds. (Both cases
    /// previously panicked on usize underflow.)
    pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
        let test_size = self.test_size.unwrap_or(n_samples / (self.n_splits + 1));

        let mut folds = Vec::with_capacity(self.n_splits);

        for i in 0..self.n_splits {
            // checked_sub: with small n_samples the earliest test windows
            // would begin before index 0; skip them instead of underflowing.
            let test_start = match n_samples.checked_sub((self.n_splits - i) * test_size) {
                Some(start) => start,
                None => continue,
            };
            let test_end = test_start + test_size;

            // saturating_sub: a gap larger than the available history now
            // yields an empty train window (filtered out below) rather
            // than a usize-underflow panic.
            let train_end = test_start.saturating_sub(self.gap);
            let train_start = match self.max_train_size {
                Some(max) => train_end.saturating_sub(max),
                None => 0,
            };

            if train_start < train_end && test_start < test_end {
                let train: Vec<usize> = (train_start..train_end).collect();
                let test: Vec<usize> = (test_start..test_end.min(n_samples)).collect();
                folds.push((train, test));
            }
        }

        folds
    }
}
217
218pub fn cross_val_score<F>(
220 x: &Tensor,
221 y: &Tensor,
222 cv: &[(Vec<usize>, Vec<usize>)],
223 fit_predict_score: F,
224) -> Vec<f32>
225where
226 F: Fn(&Tensor, &Tensor, &Tensor, &Tensor) -> f32,
227{
228 let x_data = x.data_f32();
229 let y_data = y.data_f32();
230 let n_features = x.dims()[1];
231
232 cv.iter()
233 .map(|(train_idx, test_idx)| {
234 let x_train: Vec<f32> = train_idx.iter()
236 .flat_map(|&i| x_data[i * n_features..(i + 1) * n_features].to_vec())
237 .collect();
238 let y_train: Vec<f32> = train_idx.iter().map(|&i| y_data[i]).collect();
239
240 let x_test: Vec<f32> = test_idx.iter()
242 .flat_map(|&i| x_data[i * n_features..(i + 1) * n_features].to_vec())
243 .collect();
244 let y_test: Vec<f32> = test_idx.iter().map(|&i| y_data[i]).collect();
245
246 let x_train_tensor = Tensor::from_slice(&x_train, &[train_idx.len(), n_features]).unwrap();
247 let y_train_tensor = Tensor::from_slice(&y_train, &[train_idx.len()]).unwrap();
248 let x_test_tensor = Tensor::from_slice(&x_test, &[test_idx.len(), n_features]).unwrap();
249 let y_test_tensor = Tensor::from_slice(&y_test, &[test_idx.len()]).unwrap();
250
251 fit_predict_score(&x_train_tensor, &y_train_tensor, &x_test_tensor, &y_test_tensor)
252 })
253 .collect()
254}
255
/// Outcome of a hyper-parameter grid search.
pub struct GridSearchResult {
    /// The `(name, value)` combination that achieved `best_score`.
    pub best_params: Vec<(String, f32)>,
    /// Score of the best parameter combination.
    pub best_score: f32,
    /// Every evaluated `(params, score)` pair.
    pub cv_results: Vec<(Vec<(String, f32)>, f32)>,
}
262
/// Expands named hyper-parameter value lists into the Cartesian product
/// of all `(name, value)` combinations.
///
/// Earlier parameters vary slowest; an empty `params` slice yields a
/// single empty combination.
pub fn parameter_grid(params: &[(String, Vec<f32>)]) -> Vec<Vec<(String, f32)>> {
    // Start from one empty combination and extend it one parameter at a
    // time, replacing the recursive formulation.
    let mut grid: Vec<Vec<(String, f32)>> = vec![Vec::new()];

    for (name, values) in params {
        let mut expanded = Vec::with_capacity(grid.len() * values.len());
        for combo in &grid {
            for &value in values {
                let mut extended = combo.clone();
                extended.push((name.clone(), value));
                expanded.push(extended);
            }
        }
        grid = expanded;
    }

    grid
}
283
284pub fn shuffle_split(
286 x: &Tensor,
287 y: &Tensor,
288 test_size: f32,
289 random_state: Option<u64>,
290) -> (Tensor, Tensor, Tensor, Tensor) {
291 let x_data = x.data_f32();
292 let y_data = y.data_f32();
293 let n_samples = x.dims()[0];
294 let n_features = x.dims()[1];
295
296 let mut indices: Vec<usize> = (0..n_samples).collect();
297
298 let mut rng = match random_state {
299 Some(seed) => StdRng::seed_from_u64(seed),
300 None => StdRng::from_entropy(),
301 };
302 indices.shuffle(&mut rng);
303
304 let n_test = (n_samples as f32 * test_size).round() as usize;
305 let n_train = n_samples - n_test;
306
307 let train_indices = &indices[..n_train];
308 let test_indices = &indices[n_train..];
309
310 let x_train: Vec<f32> = train_indices.iter()
311 .flat_map(|&i| x_data[i * n_features..(i + 1) * n_features].to_vec())
312 .collect();
313 let y_train: Vec<f32> = train_indices.iter().map(|&i| y_data[i]).collect();
314
315 let x_test: Vec<f32> = test_indices.iter()
316 .flat_map(|&i| x_data[i * n_features..(i + 1) * n_features].to_vec())
317 .collect();
318 let y_test: Vec<f32> = test_indices.iter().map(|&i| y_data[i]).collect();
319
320 (
321 Tensor::from_slice(&x_train, &[n_train, n_features]).unwrap(),
322 Tensor::from_slice(&x_test, &[n_test, n_features]).unwrap(),
323 Tensor::from_slice(&y_train, &[n_train]).unwrap(),
324 Tensor::from_slice(&y_test, &[n_test]).unwrap(),
325 )
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfold() {
        let folds = KFold::new(5).split(100);

        assert_eq!(folds.len(), 5);
        for (train, test) in &folds {
            // Every sample lands in exactly one of train/test per fold.
            assert_eq!(train.len() + test.len(), 100);
            assert_eq!(test.len(), 20);
        }
    }

    #[test]
    fn test_stratified_kfold() {
        let labels =
            Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], &[8]).unwrap();

        let folds = StratifiedKFold::new(2).split(&labels);
        assert_eq!(folds.len(), 2);
    }

    #[test]
    fn test_time_series_split() {
        let folds = TimeSeriesSplit::new(3).split(100);

        assert_eq!(folds.len(), 3);
        for (train, test) in &folds {
            // Training indices must strictly precede test indices.
            let last_train = train.iter().max().unwrap_or(&0);
            let first_test = test.iter().min().unwrap_or(&100);
            assert!(last_train < first_test);
        }
    }

    #[test]
    fn test_parameter_grid() {
        let params = vec![
            ("alpha".to_string(), vec![0.1, 1.0]),
            ("beta".to_string(), vec![0.01, 0.1]),
        ];

        // 2 alphas x 2 betas = 4 combinations.
        assert_eq!(parameter_grid(&params).len(), 4);
    }
}
381
382