1use crate::error::{MLError, Result};
4use scirs2_core::ndarray::{Array1, Array2, Axis};
5use scirs2_core::random::prelude::*;
6
7use super::*;
8pub fn train_test_split(
10 features: &Array2<f64>,
11 labels: &Array1<usize>,
12 test_ratio: f64,
13 shuffle: bool,
14) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
15 if features.nrows() != labels.len() {
16 return Err(MLError::InvalidInput(
17 "Features and labels must have same number of samples".to_string(),
18 ));
19 }
20 if test_ratio <= 0.0 || test_ratio >= 1.0 {
21 return Err(MLError::InvalidInput(
22 "Test ratio must be between 0 and 1".to_string(),
23 ));
24 }
25 let n_samples = features.nrows();
26 let n_test = (n_samples as f64 * test_ratio) as usize;
27 let n_train = n_samples - n_test;
28 let mut indices: Vec<usize> = (0..n_samples).collect();
29 if shuffle {
30 let mut rng = thread_rng();
31 for i in (1..indices.len()).rev() {
32 let j = rng.gen_range(0..=i);
33 indices.swap(i, j);
34 }
35 }
36 let mut train_features = Array2::zeros((n_train, features.ncols()));
37 let mut train_labels = Array1::zeros(n_train);
38 let mut test_features = Array2::zeros((n_test, features.ncols()));
39 let mut test_labels = Array1::zeros(n_test);
40 for (i, &idx) in indices[..n_train].iter().enumerate() {
41 train_features.row_mut(i).assign(&features.row(idx));
42 train_labels[i] = labels[idx];
43 }
44 for (i, &idx) in indices[n_train..].iter().enumerate() {
45 test_features.row_mut(i).assign(&features.row(idx));
46 test_labels[i] = labels[idx];
47 }
48 Ok((train_features, train_labels, test_features, test_labels))
49}
50pub fn train_test_split_regression(
52 features: &Array2<f64>,
53 labels: &Array1<f64>,
54 test_ratio: f64,
55 shuffle: bool,
56) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>, Array1<f64>)> {
57 if features.nrows() != labels.len() {
58 return Err(MLError::InvalidInput(
59 "Features and labels must have same number of samples".to_string(),
60 ));
61 }
62 if test_ratio <= 0.0 || test_ratio >= 1.0 {
63 return Err(MLError::InvalidInput(
64 "Test ratio must be between 0 and 1".to_string(),
65 ));
66 }
67 let n_samples = features.nrows();
68 let n_test = (n_samples as f64 * test_ratio) as usize;
69 let n_train = n_samples - n_test;
70 let mut indices: Vec<usize> = (0..n_samples).collect();
71 if shuffle {
72 let mut rng = thread_rng();
73 for i in (1..indices.len()).rev() {
74 let j = rng.gen_range(0..=i);
75 indices.swap(i, j);
76 }
77 }
78 let mut train_features = Array2::zeros((n_train, features.ncols()));
79 let mut train_labels = Array1::zeros(n_train);
80 let mut test_features = Array2::zeros((n_test, features.ncols()));
81 let mut test_labels = Array1::zeros(n_test);
82 for (i, &idx) in indices[..n_train].iter().enumerate() {
83 train_features.row_mut(i).assign(&features.row(idx));
84 train_labels[i] = labels[idx];
85 }
86 for (i, &idx) in indices[n_train..].iter().enumerate() {
87 test_features.row_mut(i).assign(&features.row(idx));
88 test_labels[i] = labels[idx];
89 }
90 Ok((train_features, train_labels, test_features, test_labels))
91}
92#[derive(Debug, Clone)]
94pub struct KFold {
95 n_splits: usize,
96 shuffle: bool,
97 indices: Vec<usize>,
98}
99impl KFold {
100 pub fn new(n_samples: usize, n_splits: usize, shuffle: bool) -> Result<Self> {
102 if n_splits < 2 {
103 return Err(MLError::InvalidInput(
104 "Number of splits must be at least 2".to_string(),
105 ));
106 }
107 if n_samples < n_splits {
108 return Err(MLError::InvalidInput(format!(
109 "Cannot have {} splits for {} samples",
110 n_splits, n_samples
111 )));
112 }
113 let mut indices: Vec<usize> = (0..n_samples).collect();
114 if shuffle {
115 let mut rng = thread_rng();
116 for i in (1..indices.len()).rev() {
117 let j = rng.gen_range(0..=i);
118 indices.swap(i, j);
119 }
120 }
121 Ok(Self {
122 n_splits,
123 shuffle,
124 indices,
125 })
126 }
127 pub fn n_splits(&self) -> usize {
129 self.n_splits
130 }
131 pub fn shuffle(&self) -> bool {
133 self.shuffle
134 }
135 pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
137 if fold >= self.n_splits {
138 return Err(MLError::InvalidInput(format!(
139 "Fold {} out of range for {} splits",
140 fold, self.n_splits
141 )));
142 }
143 let n_samples = self.indices.len();
144 let fold_size = n_samples / self.n_splits;
145 let n_larger_folds = n_samples % self.n_splits;
146 let start = if fold < n_larger_folds {
147 fold * (fold_size + 1)
148 } else {
149 n_larger_folds * (fold_size + 1) + (fold - n_larger_folds) * fold_size
150 };
151 let end = if fold < n_larger_folds {
152 start + fold_size + 1
153 } else {
154 start + fold_size
155 };
156 let test_indices: Vec<usize> = self.indices[start..end].to_vec();
157 let train_indices: Vec<usize> = self.indices[..start]
158 .iter()
159 .chain(self.indices[end..].iter())
160 .cloned()
161 .collect();
162 Ok((train_indices, test_indices))
163 }
164 pub fn split(
166 &self,
167 features: &Array2<f64>,
168 labels: &Array1<usize>,
169 fold: usize,
170 ) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
171 let (train_idx, test_idx) = self.get_fold(fold)?;
172 let n_train = train_idx.len();
173 let n_test = test_idx.len();
174 let n_features = features.ncols();
175 let mut train_features = Array2::zeros((n_train, n_features));
176 let mut train_labels = Array1::zeros(n_train);
177 let mut test_features = Array2::zeros((n_test, n_features));
178 let mut test_labels = Array1::zeros(n_test);
179 for (i, &idx) in train_idx.iter().enumerate() {
180 train_features.row_mut(i).assign(&features.row(idx));
181 train_labels[i] = labels[idx];
182 }
183 for (i, &idx) in test_idx.iter().enumerate() {
184 test_features.row_mut(i).assign(&features.row(idx));
185 test_labels[i] = labels[idx];
186 }
187 Ok((train_features, train_labels, test_features, test_labels))
188 }
189}
190#[derive(Debug, Clone)]
193pub struct StratifiedKFold {
194 n_splits: usize,
195 fold_indices: Vec<Vec<usize>>,
196}
197impl StratifiedKFold {
198 pub fn new(labels: &Array1<usize>, n_splits: usize, shuffle: bool) -> Result<Self> {
200 if n_splits < 2 {
201 return Err(MLError::InvalidInput(
202 "Number of splits must be at least 2".to_string(),
203 ));
204 }
205 let n_samples = labels.len();
206 if n_samples < n_splits {
207 return Err(MLError::InvalidInput(format!(
208 "Cannot have {} splits for {} samples",
209 n_splits, n_samples
210 )));
211 }
212 let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
213 for (idx, &label) in labels.iter().enumerate() {
214 class_indices.entry(label).or_default().push(idx);
215 }
216 if shuffle {
217 let mut rng = thread_rng();
218 for indices in class_indices.values_mut() {
219 for i in (1..indices.len()).rev() {
220 let j = rng.gen_range(0..=i);
221 indices.swap(i, j);
222 }
223 }
224 }
225 let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); n_splits];
226 for indices in class_indices.values() {
227 let n_class = indices.len();
228 let fold_size = n_class / n_splits;
229 let remainder = n_class % n_splits;
230 let mut current_idx = 0;
231 for fold in 0..n_splits {
232 let size = if fold < remainder {
233 fold_size + 1
234 } else {
235 fold_size
236 };
237 for &idx in &indices[current_idx..current_idx + size] {
238 fold_indices[fold].push(idx);
239 }
240 current_idx += size;
241 }
242 }
243 Ok(Self {
244 n_splits,
245 fold_indices,
246 })
247 }
248 pub fn n_splits(&self) -> usize {
250 self.n_splits
251 }
252 pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
254 if fold >= self.n_splits {
255 return Err(MLError::InvalidInput(format!(
256 "Fold {} out of range for {} splits",
257 fold, self.n_splits
258 )));
259 }
260 let test_indices = self.fold_indices[fold].clone();
261 let train_indices: Vec<usize> = self
262 .fold_indices
263 .iter()
264 .enumerate()
265 .filter(|(i, _)| *i != fold)
266 .flat_map(|(_, indices)| indices.iter().cloned())
267 .collect();
268 Ok((train_indices, test_indices))
269 }
270 pub fn split(
272 &self,
273 features: &Array2<f64>,
274 labels: &Array1<usize>,
275 fold: usize,
276 ) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
277 let (train_idx, test_idx) = self.get_fold(fold)?;
278 let n_train = train_idx.len();
279 let n_test = test_idx.len();
280 let n_features = features.ncols();
281 let mut train_features = Array2::zeros((n_train, n_features));
282 let mut train_labels = Array1::zeros(n_train);
283 let mut test_features = Array2::zeros((n_test, n_features));
284 let mut test_labels = Array1::zeros(n_test);
285 for (i, &idx) in train_idx.iter().enumerate() {
286 train_features.row_mut(i).assign(&features.row(idx));
287 train_labels[i] = labels[idx];
288 }
289 for (i, &idx) in test_idx.iter().enumerate() {
290 test_features.row_mut(i).assign(&features.row(idx));
291 test_labels[i] = labels[idx];
292 }
293 Ok((train_features, train_labels, test_features, test_labels))
294 }
295}
296pub struct LeaveOneOut {
298 n_samples: usize,
299}
300impl LeaveOneOut {
301 pub fn new(n_samples: usize) -> Self {
303 Self { n_samples }
304 }
305 pub fn n_splits(&self) -> usize {
307 self.n_samples
308 }
309 pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
311 if fold >= self.n_samples {
312 return Err(MLError::InvalidInput(format!(
313 "Fold {} out of range for {} samples",
314 fold, self.n_samples
315 )));
316 }
317 let test_indices = vec![fold];
318 let train_indices: Vec<usize> = (0..self.n_samples).filter(|&i| i != fold).collect();
319 Ok((train_indices, test_indices))
320 }
321}
322#[derive(Debug, Clone)]
324pub struct RepeatedKFold {
325 n_splits: usize,
326 n_repeats: usize,
327 n_samples: usize,
328}
329impl RepeatedKFold {
330 pub fn new(n_samples: usize, n_splits: usize, n_repeats: usize) -> Result<Self> {
332 if n_splits < 2 {
333 return Err(MLError::InvalidInput(
334 "Number of splits must be at least 2".to_string(),
335 ));
336 }
337 if n_repeats < 1 {
338 return Err(MLError::InvalidInput(
339 "Number of repeats must be at least 1".to_string(),
340 ));
341 }
342 if n_samples < n_splits {
343 return Err(MLError::InvalidInput(format!(
344 "Cannot have {} splits for {} samples",
345 n_splits, n_samples
346 )));
347 }
348 Ok(Self {
349 n_splits,
350 n_repeats,
351 n_samples,
352 })
353 }
354 pub fn total_splits(&self) -> usize {
356 self.n_splits * self.n_repeats
357 }
358 pub fn get_iteration(&self, iteration: usize) -> Result<(Vec<usize>, Vec<usize>)> {
361 if iteration >= self.total_splits() {
362 return Err(MLError::InvalidInput(format!(
363 "Iteration {} out of range for {} total splits",
364 iteration,
365 self.total_splits()
366 )));
367 }
368 let fold = iteration % self.n_splits;
369 let kfold = KFold::new(self.n_samples, self.n_splits, true)?;
370 kfold.get_fold(fold)
371 }
372}
373#[derive(Debug, Clone)]
376pub struct TimeSeriesSplit {
377 n_splits: usize,
378 n_samples: usize,
379 max_train_size: Option<usize>,
380 test_size: Option<usize>,
381 gap: usize,
382}
383impl TimeSeriesSplit {
384 pub fn new(
393 n_samples: usize,
394 n_splits: usize,
395 max_train_size: Option<usize>,
396 test_size: Option<usize>,
397 gap: usize,
398 ) -> Result<Self> {
399 if n_splits < 2 {
400 return Err(MLError::InvalidInput(
401 "Number of splits must be at least 2".to_string(),
402 ));
403 }
404 if n_samples < n_splits + 1 {
405 return Err(MLError::InvalidInput(format!(
406 "Cannot have {} splits for {} samples",
407 n_splits, n_samples
408 )));
409 }
410 Ok(Self {
411 n_splits,
412 n_samples,
413 max_train_size,
414 test_size,
415 gap,
416 })
417 }
418 pub fn n_splits(&self) -> usize {
420 self.n_splits
421 }
422 pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
424 if fold >= self.n_splits {
425 return Err(MLError::InvalidInput(format!(
426 "Fold {} out of range for {} splits",
427 fold, self.n_splits
428 )));
429 }
430 let test_size = self
431 .test_size
432 .unwrap_or((self.n_samples - self.gap) / (self.n_splits + 1));
433 let test_start = (fold + 1) * test_size + self.gap;
434 let test_end = (test_start + test_size).min(self.n_samples);
435 let train_end = test_start - self.gap;
436 let train_start = if let Some(max_size) = self.max_train_size {
437 train_end.saturating_sub(max_size)
438 } else {
439 0
440 };
441 let train_indices: Vec<usize> = (train_start..train_end).collect();
442 let test_indices: Vec<usize> = (test_start..test_end).collect();
443 Ok((train_indices, test_indices))
444 }
445 pub fn split(
447 &self,
448 features: &Array2<f64>,
449 labels: &Array1<usize>,
450 fold: usize,
451 ) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
452 let (train_idx, test_idx) = self.get_fold(fold)?;
453 let n_train = train_idx.len();
454 let n_test = test_idx.len();
455 let n_features = features.ncols();
456 let mut train_features = Array2::zeros((n_train, n_features));
457 let mut train_labels = Array1::zeros(n_train);
458 let mut test_features = Array2::zeros((n_test, n_features));
459 let mut test_labels = Array1::zeros(n_test);
460 for (i, &idx) in train_idx.iter().enumerate() {
461 train_features.row_mut(i).assign(&features.row(idx));
462 train_labels[i] = labels[idx];
463 }
464 for (i, &idx) in test_idx.iter().enumerate() {
465 test_features.row_mut(i).assign(&features.row(idx));
466 test_labels[i] = labels[idx];
467 }
468 Ok((train_features, train_labels, test_features, test_labels))
469 }
470 pub fn split_regression(
472 &self,
473 features: &Array2<f64>,
474 labels: &Array1<f64>,
475 fold: usize,
476 ) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>, Array1<f64>)> {
477 let (train_idx, test_idx) = self.get_fold(fold)?;
478 let n_train = train_idx.len();
479 let n_test = test_idx.len();
480 let n_features = features.ncols();
481 let mut train_features = Array2::zeros((n_train, n_features));
482 let mut train_labels = Array1::zeros(n_train);
483 let mut test_features = Array2::zeros((n_test, n_features));
484 let mut test_labels = Array1::zeros(n_test);
485 for (i, &idx) in train_idx.iter().enumerate() {
486 train_features.row_mut(i).assign(&features.row(idx));
487 train_labels[i] = labels[idx];
488 }
489 for (i, &idx) in test_idx.iter().enumerate() {
490 test_features.row_mut(i).assign(&features.row(idx));
491 test_labels[i] = labels[idx];
492 }
493 Ok((train_features, train_labels, test_features, test_labels))
494 }
495}
496#[derive(Debug, Clone)]
499pub struct BlockedTimeSeriesSplit {
500 n_splits: usize,
501 group_boundaries: Vec<usize>,
502}
503impl BlockedTimeSeriesSplit {
504 pub fn new(group_sizes: &[usize], n_splits: usize) -> Result<Self> {
510 if n_splits < 2 {
511 return Err(MLError::InvalidInput(
512 "Number of splits must be at least 2".to_string(),
513 ));
514 }
515 if group_sizes.len() < n_splits + 1 {
516 return Err(MLError::InvalidInput(format!(
517 "Need at least {} groups for {} splits",
518 n_splits + 1,
519 n_splits
520 )));
521 }
522 let mut boundaries = vec![0];
523 let mut cumsum = 0;
524 for &size in group_sizes {
525 cumsum += size;
526 boundaries.push(cumsum);
527 }
528 Ok(Self {
529 n_splits,
530 group_boundaries: boundaries,
531 })
532 }
533 pub fn n_splits(&self) -> usize {
535 self.n_splits
536 }
537 pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
539 if fold >= self.n_splits {
540 return Err(MLError::InvalidInput(format!(
541 "Fold {} out of range for {} splits",
542 fold, self.n_splits
543 )));
544 }
545 let n_groups = self.group_boundaries.len() - 1;
546 let groups_per_fold = n_groups / (self.n_splits + 1);
547 let train_end_group = (fold + 1) * groups_per_fold;
548 let test_end_group = (train_end_group + groups_per_fold).min(n_groups);
549 let train_start = self.group_boundaries[0];
550 let train_end = self.group_boundaries[train_end_group];
551 let test_start = train_end;
552 let test_end = self.group_boundaries[test_end_group];
553 let train_indices: Vec<usize> = (train_start..train_end).collect();
554 let test_indices: Vec<usize> = (test_start..test_end).collect();
555 Ok((train_indices, test_indices))
556 }
557}