quantrs2_ml/utils/split.rs

//! Data splitting utilities for cross-validation and train/test splits

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

use super::*;

/// Split data into train and test sets
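///
/// # Example
///
/// A minimal usage sketch (illustrative only; assumes the `array!` macro and this
/// function are in scope, which may differ from the crate's actual re-exports):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1, Array2};
///
/// let features: Array2<f64> = array![[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.2, 0.8]];
/// let labels: Array1<usize> = array![0, 1, 0, 1];
/// // Hold out 25% of the samples, shuffling before the split.
/// let (train_x, _train_y, test_x, _test_y) =
///     train_test_split(&features, &labels, 0.25, true).unwrap();
/// assert_eq!(train_x.nrows(), 3);
/// assert_eq!(test_x.nrows(), 1);
/// ```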
pub fn train_test_split(
    features: &Array2<f64>,
    labels: &Array1<usize>,
    test_ratio: f64,
    shuffle: bool,
) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
    if features.nrows() != labels.len() {
        return Err(MLError::InvalidInput(
            "Features and labels must have same number of samples".to_string(),
        ));
    }
    if test_ratio <= 0.0 || test_ratio >= 1.0 {
        return Err(MLError::InvalidInput(
            "Test ratio must be between 0 and 1".to_string(),
        ));
    }
    let n_samples = features.nrows();
    let n_test = (n_samples as f64 * test_ratio) as usize;
    let n_train = n_samples - n_test;
    let mut indices: Vec<usize> = (0..n_samples).collect();
    if shuffle {
        let mut rng = thread_rng();
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }
    }
    let mut train_features = Array2::zeros((n_train, features.ncols()));
    let mut train_labels = Array1::zeros(n_train);
    let mut test_features = Array2::zeros((n_test, features.ncols()));
    let mut test_labels = Array1::zeros(n_test);
    for (i, &idx) in indices[..n_train].iter().enumerate() {
        train_features.row_mut(i).assign(&features.row(idx));
        train_labels[i] = labels[idx];
    }
    for (i, &idx) in indices[n_train..].iter().enumerate() {
        test_features.row_mut(i).assign(&features.row(idx));
        test_labels[i] = labels[idx];
    }
    Ok((train_features, train_labels, test_features, test_labels))
}

/// Split regression data into train and test sets (with continuous labels)
pub fn train_test_split_regression(
    features: &Array2<f64>,
    labels: &Array1<f64>,
    test_ratio: f64,
    shuffle: bool,
) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>, Array1<f64>)> {
    if features.nrows() != labels.len() {
        return Err(MLError::InvalidInput(
            "Features and labels must have same number of samples".to_string(),
        ));
    }
    if test_ratio <= 0.0 || test_ratio >= 1.0 {
        return Err(MLError::InvalidInput(
            "Test ratio must be between 0 and 1".to_string(),
        ));
    }
    let n_samples = features.nrows();
    let n_test = (n_samples as f64 * test_ratio) as usize;
    let n_train = n_samples - n_test;
    let mut indices: Vec<usize> = (0..n_samples).collect();
    if shuffle {
        let mut rng = thread_rng();
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }
    }
    let mut train_features = Array2::zeros((n_train, features.ncols()));
    let mut train_labels = Array1::zeros(n_train);
    let mut test_features = Array2::zeros((n_test, features.ncols()));
    let mut test_labels = Array1::zeros(n_test);
    for (i, &idx) in indices[..n_train].iter().enumerate() {
        train_features.row_mut(i).assign(&features.row(idx));
        train_labels[i] = labels[idx];
    }
    for (i, &idx) in indices[n_train..].iter().enumerate() {
        test_features.row_mut(i).assign(&features.row(idx));
        test_labels[i] = labels[idx];
    }
    Ok((train_features, train_labels, test_features, test_labels))
}

/// K-Fold cross-validation split indices generator
#[derive(Debug, Clone)]
pub struct KFold {
    n_splits: usize,
    shuffle: bool,
    indices: Vec<usize>,
}

impl KFold {
    /// Create a new K-Fold splitter
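    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; assumes the `array!` macro and
    /// the crate's `Array1`/`Array2` re-exports are available):
    ///
    /// ```ignore
    /// let features: Array2<f64> = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
    /// let labels: Array1<usize> = array![0, 1, 0, 1, 0, 1];
    ///
    /// let kfold = KFold::new(features.nrows(), 3, true).unwrap();
    /// for fold in 0..kfold.n_splits() {
    ///     let (train_x, _train_y, test_x, _test_y) =
    ///         kfold.split(&features, &labels, fold).unwrap();
    ///     // With 6 samples and 3 splits, each fold holds out 2 samples.
    ///     assert_eq!(test_x.nrows(), 2);
    ///     assert_eq!(train_x.nrows(), 4);
    /// }
    /// ```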
    pub fn new(n_samples: usize, n_splits: usize, shuffle: bool) -> Result<Self> {
        if n_splits < 2 {
            return Err(MLError::InvalidInput(
                "Number of splits must be at least 2".to_string(),
            ));
        }
        if n_samples < n_splits {
            return Err(MLError::InvalidInput(format!(
                "Cannot have {} splits for {} samples",
                n_splits, n_samples
            )));
        }
        let mut indices: Vec<usize> = (0..n_samples).collect();
        if shuffle {
            let mut rng = thread_rng();
            for i in (1..indices.len()).rev() {
                let j = rng.gen_range(0..=i);
                indices.swap(i, j);
            }
        }
        Ok(Self {
            n_splits,
            shuffle,
            indices,
        })
    }

    /// Get the number of splits
    pub fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Get whether shuffling is enabled
    pub fn shuffle(&self) -> bool {
        self.shuffle
    }

    /// Get train and test indices for a specific fold
    pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(MLError::InvalidInput(format!(
                "Fold {} out of range for {} splits",
                fold, self.n_splits
            )));
        }
        let n_samples = self.indices.len();
        let fold_size = n_samples / self.n_splits;
        let n_larger_folds = n_samples % self.n_splits;
        let start = if fold < n_larger_folds {
            fold * (fold_size + 1)
        } else {
            n_larger_folds * (fold_size + 1) + (fold - n_larger_folds) * fold_size
        };
        let end = if fold < n_larger_folds {
            start + fold_size + 1
        } else {
            start + fold_size
        };
        let test_indices: Vec<usize> = self.indices[start..end].to_vec();
        let train_indices: Vec<usize> = self.indices[..start]
            .iter()
            .chain(self.indices[end..].iter())
            .cloned()
            .collect();
        Ok((train_indices, test_indices))
    }

    /// Split features and labels for a specific fold
    pub fn split(
        &self,
        features: &Array2<f64>,
        labels: &Array1<usize>,
        fold: usize,
    ) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
        let (train_idx, test_idx) = self.get_fold(fold)?;
        let n_train = train_idx.len();
        let n_test = test_idx.len();
        let n_features = features.ncols();
        let mut train_features = Array2::zeros((n_train, n_features));
        let mut train_labels = Array1::zeros(n_train);
        let mut test_features = Array2::zeros((n_test, n_features));
        let mut test_labels = Array1::zeros(n_test);
        for (i, &idx) in train_idx.iter().enumerate() {
            train_features.row_mut(i).assign(&features.row(idx));
            train_labels[i] = labels[idx];
        }
        for (i, &idx) in test_idx.iter().enumerate() {
            test_features.row_mut(i).assign(&features.row(idx));
            test_labels[i] = labels[idx];
        }
        Ok((train_features, train_labels, test_features, test_labels))
    }
}

/// Stratified K-Fold cross-validation split indices generator
///
/// Ensures each fold has approximately the same percentage of samples of each class.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    n_splits: usize,
    fold_indices: Vec<Vec<usize>>,
}

impl StratifiedKFold {
    /// Create a new Stratified K-Fold splitter
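    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; assumes the `array!` macro and
    /// the crate's `Array1` re-export are available):
    ///
    /// ```ignore
    /// // Eight samples, four per class; with 2 splits each fold receives
    /// // two samples of class 0 and two of class 1.
    /// let labels: Array1<usize> = array![0, 0, 1, 1, 0, 0, 1, 1];
    /// let skf = StratifiedKFold::new(&labels, 2, false).unwrap();
    /// let (train_idx, test_idx) = skf.get_fold(0).unwrap();
    /// assert_eq!(test_idx.len(), 4);
    /// assert_eq!(train_idx.len(), 4);
    /// ```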
    pub fn new(labels: &Array1<usize>, n_splits: usize, shuffle: bool) -> Result<Self> {
        if n_splits < 2 {
            return Err(MLError::InvalidInput(
                "Number of splits must be at least 2".to_string(),
            ));
        }
        let n_samples = labels.len();
        if n_samples < n_splits {
            return Err(MLError::InvalidInput(format!(
                "Cannot have {} splits for {} samples",
                n_splits, n_samples
            )));
        }
        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
        for (idx, &label) in labels.iter().enumerate() {
            class_indices.entry(label).or_default().push(idx);
        }
        if shuffle {
            let mut rng = thread_rng();
            for indices in class_indices.values_mut() {
                for i in (1..indices.len()).rev() {
                    let j = rng.gen_range(0..=i);
                    indices.swap(i, j);
                }
            }
        }
        let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); n_splits];
        for indices in class_indices.values() {
            let n_class = indices.len();
            let fold_size = n_class / n_splits;
            let remainder = n_class % n_splits;
            let mut current_idx = 0;
            for fold in 0..n_splits {
                let size = if fold < remainder {
                    fold_size + 1
                } else {
                    fold_size
                };
                for &idx in &indices[current_idx..current_idx + size] {
                    fold_indices[fold].push(idx);
                }
                current_idx += size;
            }
        }
        Ok(Self {
            n_splits,
            fold_indices,
        })
    }

    /// Get the number of splits
    pub fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Get train and test indices for a specific fold
    pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(MLError::InvalidInput(format!(
                "Fold {} out of range for {} splits",
                fold, self.n_splits
            )));
        }
        let test_indices = self.fold_indices[fold].clone();
        let train_indices: Vec<usize> = self
            .fold_indices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != fold)
            .flat_map(|(_, indices)| indices.iter().cloned())
            .collect();
        Ok((train_indices, test_indices))
    }

    /// Split features and labels for a specific fold
    pub fn split(
        &self,
        features: &Array2<f64>,
        labels: &Array1<usize>,
        fold: usize,
    ) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
        let (train_idx, test_idx) = self.get_fold(fold)?;
        let n_train = train_idx.len();
        let n_test = test_idx.len();
        let n_features = features.ncols();
        let mut train_features = Array2::zeros((n_train, n_features));
        let mut train_labels = Array1::zeros(n_train);
        let mut test_features = Array2::zeros((n_test, n_features));
        let mut test_labels = Array1::zeros(n_test);
        for (i, &idx) in train_idx.iter().enumerate() {
            train_features.row_mut(i).assign(&features.row(idx));
            train_labels[i] = labels[idx];
        }
        for (i, &idx) in test_idx.iter().enumerate() {
            test_features.row_mut(i).assign(&features.row(idx));
            test_labels[i] = labels[idx];
        }
        Ok((train_features, train_labels, test_features, test_labels))
    }
}

/// Leave-One-Out cross-validation
#[derive(Debug, Clone)]
pub struct LeaveOneOut {
    n_samples: usize,
}

impl LeaveOneOut {
    /// Create a new Leave-One-Out splitter
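    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; operates purely on indices):
    ///
    /// ```ignore
    /// let loo = LeaveOneOut::new(5);
    /// assert_eq!(loo.n_splits(), 5);
    /// // Fold 2 holds out sample 2 and trains on all the others.
    /// let (train_idx, test_idx) = loo.get_fold(2).unwrap();
    /// assert_eq!(test_idx, vec![2]);
    /// assert_eq!(train_idx, vec![0, 1, 3, 4]);
    /// ```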
    pub fn new(n_samples: usize) -> Self {
        Self { n_samples }
    }

    /// Get the number of splits (equal to number of samples)
    pub fn n_splits(&self) -> usize {
        self.n_samples
    }

    /// Get train and test indices for a specific fold
    pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_samples {
            return Err(MLError::InvalidInput(format!(
                "Fold {} out of range for {} samples",
                fold, self.n_samples
            )));
        }
        let test_indices = vec![fold];
        let train_indices: Vec<usize> = (0..self.n_samples).filter(|&i| i != fold).collect();
        Ok((train_indices, test_indices))
    }
}

/// Repeated K-Fold cross-validation
///
/// Each repeat uses its own shuffled K-Fold, so every sample appears in exactly
/// one test fold per repeat.
#[derive(Debug, Clone)]
pub struct RepeatedKFold {
    n_splits: usize,
    n_repeats: usize,
    kfolds: Vec<KFold>,
}

impl RepeatedKFold {
    /// Create a new Repeated K-Fold splitter
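    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; operates purely on indices):
    ///
    /// ```ignore
    /// // 10 samples, 5 folds, repeated 3 times => 15 train/test iterations.
    /// let rkf = RepeatedKFold::new(10, 5, 3).unwrap();
    /// assert_eq!(rkf.total_splits(), 15);
    /// for iteration in 0..rkf.total_splits() {
    ///     let (train_idx, test_idx) = rkf.get_iteration(iteration).unwrap();
    ///     assert_eq!(test_idx.len(), 2);
    ///     assert_eq!(train_idx.len(), 8);
    /// }
    /// ```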
    pub fn new(n_samples: usize, n_splits: usize, n_repeats: usize) -> Result<Self> {
        if n_splits < 2 {
            return Err(MLError::InvalidInput(
                "Number of splits must be at least 2".to_string(),
            ));
        }
        if n_repeats < 1 {
            return Err(MLError::InvalidInput(
                "Number of repeats must be at least 1".to_string(),
            ));
        }
        if n_samples < n_splits {
            return Err(MLError::InvalidInput(format!(
                "Cannot have {} splits for {} samples",
                n_splits, n_samples
            )));
        }
        // Build one shuffled K-Fold per repeat up front so that all folds within a
        // repeat come from the same permutation. (Re-shuffling on every call would
        // let a sample land in several test folds of the same repeat.)
        let mut kfolds = Vec::with_capacity(n_repeats);
        for _ in 0..n_repeats {
            kfolds.push(KFold::new(n_samples, n_splits, true)?);
        }
        Ok(Self {
            n_splits,
            n_repeats,
            kfolds,
        })
    }

    /// Get total number of splits across all repeats
    pub fn total_splits(&self) -> usize {
        self.n_splits * self.n_repeats
    }

    /// Get train and test indices for a specific iteration
    ///
    /// The iteration index is `repeat * n_splits + fold`.
    pub fn get_iteration(&self, iteration: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        if iteration >= self.total_splits() {
            return Err(MLError::InvalidInput(format!(
                "Iteration {} out of range for {} total splits",
                iteration,
                self.total_splits()
            )));
        }
        let repeat = iteration / self.n_splits;
        let fold = iteration % self.n_splits;
        self.kfolds[repeat].get_fold(fold)
    }
}

/// Time Series Split for temporal data
///
/// Provides train/test indices to split time series data while preserving temporal order.
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    n_splits: usize,
    n_samples: usize,
    max_train_size: Option<usize>,
    test_size: Option<usize>,
    gap: usize,
}

impl TimeSeriesSplit {
    /// Create a new Time Series Split
    ///
    /// # Arguments
    /// * `n_samples` - Total number of samples
    /// * `n_splits` - Number of splits (must be at least 2)
    /// * `max_train_size` - Maximum size of training set (None for no limit)
    /// * `test_size` - Fixed test set size (None for equal splits)
    /// * `gap` - Number of samples to exclude between train and test
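    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; operates purely on indices):
    ///
    /// ```ignore
    /// // 12 time-ordered samples, 3 splits, no cap on training size,
    /// // default test size, and a gap of 1 sample between train and test.
    /// let tss = TimeSeriesSplit::new(12, 3, None, None, 1).unwrap();
    /// let (train_idx, test_idx) = tss.get_fold(0).unwrap();
    /// assert_eq!(train_idx, vec![0, 1]);
    /// assert_eq!(test_idx, vec![3, 4]);
    /// // Training samples always precede the test window.
    /// assert!(train_idx.iter().all(|&i| i < test_idx[0]));
    /// ```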
    pub fn new(
        n_samples: usize,
        n_splits: usize,
        max_train_size: Option<usize>,
        test_size: Option<usize>,
        gap: usize,
    ) -> Result<Self> {
        if n_splits < 2 {
            return Err(MLError::InvalidInput(
                "Number of splits must be at least 2".to_string(),
            ));
        }
        if n_samples < n_splits + 1 {
            return Err(MLError::InvalidInput(format!(
                "Cannot have {} splits for {} samples",
                n_splits, n_samples
            )));
        }
        Ok(Self {
            n_splits,
            n_samples,
            max_train_size,
            test_size,
            gap,
        })
    }

    /// Get the number of splits
    pub fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Get train and test indices for a specific fold
    pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(MLError::InvalidInput(format!(
                "Fold {} out of range for {} splits",
                fold, self.n_splits
            )));
        }
        let test_size = self
            .test_size
            .unwrap_or(self.n_samples.saturating_sub(self.gap) / (self.n_splits + 1));
        // Clamp the window boundaries so that an over-large `test_size` or `gap`
        // cannot produce out-of-range indices.
        let test_start = ((fold + 1) * test_size + self.gap).min(self.n_samples);
        let test_end = (test_start + test_size).min(self.n_samples);
        let train_end = test_start.saturating_sub(self.gap);
        let train_start = if let Some(max_size) = self.max_train_size {
            train_end.saturating_sub(max_size)
        } else {
            0
        };
        let train_indices: Vec<usize> = (train_start..train_end).collect();
        let test_indices: Vec<usize> = (test_start..test_end).collect();
        Ok((train_indices, test_indices))
    }

    /// Split features and labels for a specific fold
    pub fn split(
        &self,
        features: &Array2<f64>,
        labels: &Array1<usize>,
        fold: usize,
    ) -> Result<(Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>)> {
        let (train_idx, test_idx) = self.get_fold(fold)?;
        let n_train = train_idx.len();
        let n_test = test_idx.len();
        let n_features = features.ncols();
        let mut train_features = Array2::zeros((n_train, n_features));
        let mut train_labels = Array1::zeros(n_train);
        let mut test_features = Array2::zeros((n_test, n_features));
        let mut test_labels = Array1::zeros(n_test);
        for (i, &idx) in train_idx.iter().enumerate() {
            train_features.row_mut(i).assign(&features.row(idx));
            train_labels[i] = labels[idx];
        }
        for (i, &idx) in test_idx.iter().enumerate() {
            test_features.row_mut(i).assign(&features.row(idx));
            test_labels[i] = labels[idx];
        }
        Ok((train_features, train_labels, test_features, test_labels))
    }

    /// Split regression features and labels for a specific fold
    pub fn split_regression(
        &self,
        features: &Array2<f64>,
        labels: &Array1<f64>,
        fold: usize,
    ) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>, Array1<f64>)> {
        let (train_idx, test_idx) = self.get_fold(fold)?;
        let n_train = train_idx.len();
        let n_test = test_idx.len();
        let n_features = features.ncols();
        let mut train_features = Array2::zeros((n_train, n_features));
        let mut train_labels = Array1::zeros(n_train);
        let mut test_features = Array2::zeros((n_test, n_features));
        let mut test_labels = Array1::zeros(n_test);
        for (i, &idx) in train_idx.iter().enumerate() {
            train_features.row_mut(i).assign(&features.row(idx));
            train_labels[i] = labels[idx];
        }
        for (i, &idx) in test_idx.iter().enumerate() {
            test_features.row_mut(i).assign(&features.row(idx));
            test_labels[i] = labels[idx];
        }
        Ok((train_features, train_labels, test_features, test_labels))
    }
}

/// Blocked Time Series Split (for grouped temporal data)
///
/// Splits data into train/test while respecting group boundaries.
#[derive(Debug, Clone)]
pub struct BlockedTimeSeriesSplit {
    n_splits: usize,
    group_boundaries: Vec<usize>,
}

impl BlockedTimeSeriesSplit {
    /// Create a new Blocked Time Series Split
    ///
    /// # Arguments
    /// * `group_sizes` - Sizes of each temporal group/block
    /// * `n_splits` - Number of splits
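    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; operates purely on indices):
    ///
    /// ```ignore
    /// // Four temporal blocks of sizes 3, 3, 2, 2 (10 samples), with 2 splits.
    /// let bts = BlockedTimeSeriesSplit::new(&[3, 3, 2, 2], 2).unwrap();
    /// // Fold 0 trains on the first block and tests on the second,
    /// // never splitting a block between train and test.
    /// let (train_idx, test_idx) = bts.get_fold(0).unwrap();
    /// assert_eq!(train_idx, vec![0, 1, 2]);
    /// assert_eq!(test_idx, vec![3, 4, 5]);
    /// ```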
    pub fn new(group_sizes: &[usize], n_splits: usize) -> Result<Self> {
        if n_splits < 2 {
            return Err(MLError::InvalidInput(
                "Number of splits must be at least 2".to_string(),
            ));
        }
        if group_sizes.len() < n_splits + 1 {
            return Err(MLError::InvalidInput(format!(
                "Need at least {} groups for {} splits",
                n_splits + 1,
                n_splits
            )));
        }
        let mut boundaries = vec![0];
        let mut cumsum = 0;
        for &size in group_sizes {
            cumsum += size;
            boundaries.push(cumsum);
        }
        Ok(Self {
            n_splits,
            group_boundaries: boundaries,
        })
    }

    /// Get the number of splits
    pub fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Get train and test indices for a specific fold
    pub fn get_fold(&self, fold: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(MLError::InvalidInput(format!(
                "Fold {} out of range for {} splits",
                fold, self.n_splits
            )));
        }
        let n_groups = self.group_boundaries.len() - 1;
        let groups_per_fold = n_groups / (self.n_splits + 1);
        let train_end_group = (fold + 1) * groups_per_fold;
        let test_end_group = (train_end_group + groups_per_fold).min(n_groups);
        let train_start = self.group_boundaries[0];
        let train_end = self.group_boundaries[train_end_group];
        let test_start = train_end;
        let test_end = self.group_boundaries[test_end_group];
        let train_indices: Vec<usize> = (train_start..train_end).collect();
        let test_indices: Vec<usize> = (test_start..test_end).collect();
        Ok((train_indices, test_indices))
    }
}