treeboost 0.1.0

High-performance Gradient Boosted Decision Tree engine for large-scale tabular data
Documentation
//! Dataset splitting utilities for training, validation, and cross-validation
//!
//! Provides deterministic, cache-friendly index splitting for:
//! - Holdout validation (train/val/calib three-way split)
//! - K-fold cross-validation
//!
//! All splits are:
//! - Deterministic: Same seed produces identical splits
//! - Cache-friendly: Indices are sorted after random selection (~47% speedup)
//! - Disjoint: No overlap between partitions

use rand::seq::SliceRandom;
use rand::SeedableRng;

/// Result of a holdout split
#[derive(Debug, Clone)]
pub struct HoldoutSplit {
    /// Training set indices (sorted for cache locality)
    pub train: Vec<usize>,
    /// Validation set indices (sorted for cache locality)
    pub validation: Vec<usize>,
    /// Calibration set indices (sorted for cache locality)
    pub calibration: Vec<usize>,
}

impl HoldoutSplit {
    /// Get the number of training samples
    pub fn train_len(&self) -> usize {
        self.train.len()
    }

    /// Get the number of validation samples
    pub fn val_len(&self) -> usize {
        self.validation.len()
    }

    /// Get the number of calibration samples
    pub fn calib_len(&self) -> usize {
        self.calibration.len()
    }
}

/// Result of a K-fold split
#[derive(Debug, Clone)]
pub struct KFoldSplit {
    /// All indices partitioned into k disjoint folds
    pub folds: Vec<Vec<usize>>,
}

impl KFoldSplit {
    /// Get the number of folds
    pub fn k(&self) -> usize {
        self.folds.len()
    }

    /// Get train and validation indices for a specific fold
    ///
    /// # Arguments
    /// * `fold_idx` - The fold to use as validation (0-indexed)
    ///
    /// # Returns
    /// * `(train_indices, val_indices)` - Both sorted for cache locality
    pub fn get_fold(&self, fold_idx: usize) -> (Vec<usize>, Vec<usize>) {
        assert!(fold_idx < self.folds.len(), "fold_idx out of range");

        // Validation: the selected fold
        let mut validation = self.folds[fold_idx].clone();

        // Training: all other folds combined
        let mut train: Vec<usize> = self
            .folds
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != fold_idx)
            .flat_map(|(_, fold)| fold.iter().cloned())
            .collect();

        // Sort for cache locality
        train.sort_unstable();
        validation.sort_unstable();

        (train, validation)
    }
}

/// Split indices for holdout validation
///
/// Produces a three-way split into training, validation, and calibration sets.
/// Indices are sorted after random selection for cache-friendly sequential access.
///
/// # Arguments
/// * `num_rows` - Total number of samples
/// * `validation_ratio` - Fraction for validation (0.0 to skip)
/// * `calibration_ratio` - Fraction for calibration (0.0 to skip)
/// * `seed` - Random seed for reproducibility
///
/// # Example
/// ```ignore
/// let split = split_holdout(1000, 0.2, 0.0, 42);
/// assert_eq!(split.train.len() + split.validation.len(), 1000);
/// ```
pub fn split_holdout(
    num_rows: usize,
    validation_ratio: f32,
    calibration_ratio: f32,
    seed: u64,
) -> HoldoutSplit {
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut indices: Vec<usize> = (0..num_rows).collect();
    indices.shuffle(&mut rng);

    // First split off calibration set
    let n_calibration = if calibration_ratio > 0.0 {
        ((num_rows as f32) * calibration_ratio).ceil() as usize
    } else {
        0
    };
    let mut calibration: Vec<usize> = indices.drain(..n_calibration).collect();

    // Then split off validation set from remaining
    let n_validation = if validation_ratio > 0.0 {
        ((indices.len() as f32) * validation_ratio / (1.0 - calibration_ratio)).ceil() as usize
    } else {
        0
    };
    let mut validation: Vec<usize> = indices.drain(..n_validation).collect();

    // Remaining is training set
    let mut train = indices;

    // Sort all index vectors for cache-friendly sequential access
    // This maintains random selection but enables sequential memory access patterns
    train.sort_unstable();
    validation.sort_unstable();
    calibration.sort_unstable();

    HoldoutSplit {
        train,
        validation,
        calibration,
    }
}

/// Split indices for K-fold cross-validation
///
/// Partitions indices into k disjoint folds of approximately equal size.
/// Each fold's indices are sorted for cache-friendly access.
///
/// # Arguments
/// * `num_rows` - Total number of samples
/// * `k` - Number of folds (must be >= 2)
/// * `seed` - Random seed for reproducibility
///
/// # Example
/// ```ignore
/// let split = split_kfold(1000, 5, 42);
/// for i in 0..5 {
///     let (train, val) = split.get_fold(i);
///     assert_eq!(train.len() + val.len(), 1000);
/// }
/// ```
pub fn split_kfold(num_rows: usize, k: usize, seed: u64) -> KFoldSplit {
    assert!(k >= 2, "K-fold requires at least 2 folds");

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut indices: Vec<usize> = (0..num_rows).collect();
    indices.shuffle(&mut rng);

    // Partition into k folds
    let fold_size = num_rows / k;
    let remainder = num_rows % k;

    let mut folds = Vec::with_capacity(k);
    let mut start = 0;

    for i in 0..k {
        // First `remainder` folds get one extra sample
        let extra = if i < remainder { 1 } else { 0 };
        let end = start + fold_size + extra;

        let mut fold: Vec<usize> = indices[start..end].to_vec();
        fold.sort_unstable(); // Sort for cache locality
        folds.push(fold);

        start = end;
    }

    KFoldSplit { folds }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_split_holdout_basic() {
        let split = split_holdout(1000, 0.2, 0.0, 42);

        // Check sizes (validation should be ~200)
        assert_eq!(split.train.len() + split.validation.len(), 1000);
        assert!(split.validation.len() >= 190 && split.validation.len() <= 210);
        assert_eq!(split.calibration.len(), 0);
    }

    #[test]
    fn test_split_holdout_three_way() {
        let split = split_holdout(1000, 0.2, 0.1, 42);

        // Check all indices are accounted for
        let total = split.train.len() + split.validation.len() + split.calibration.len();
        assert_eq!(total, 1000);

        // Check no overlap
        let train_set: HashSet<_> = split.train.iter().collect();
        let val_set: HashSet<_> = split.validation.iter().collect();
        let calib_set: HashSet<_> = split.calibration.iter().collect();

        assert!(train_set.is_disjoint(&val_set));
        assert!(train_set.is_disjoint(&calib_set));
        assert!(val_set.is_disjoint(&calib_set));
    }

    #[test]
    fn test_split_holdout_sorted() {
        let split = split_holdout(1000, 0.2, 0.1, 42);

        // Check all sets are sorted
        assert!(split.train.windows(2).all(|w| w[0] < w[1]));
        assert!(split.validation.windows(2).all(|w| w[0] < w[1]));
        if split.calibration.len() > 1 {
            assert!(split.calibration.windows(2).all(|w| w[0] < w[1]));
        }
    }

    #[test]
    fn test_split_holdout_deterministic() {
        let split1 = split_holdout(1000, 0.2, 0.0, 42);
        let split2 = split_holdout(1000, 0.2, 0.0, 42);

        assert_eq!(split1.train, split2.train);
        assert_eq!(split1.validation, split2.validation);

        // Different seed should produce different split
        let split3 = split_holdout(1000, 0.2, 0.0, 43);
        assert_ne!(split1.train, split3.train);
    }

    #[test]
    fn test_split_kfold_basic() {
        let split = split_kfold(100, 5, 42);

        assert_eq!(split.k(), 5);

        // Check all folds combined cover all indices
        let all_indices: HashSet<_> = split.folds.iter().flatten().cloned().collect();
        assert_eq!(all_indices.len(), 100);
    }

    #[test]
    fn test_split_kfold_disjoint() {
        let split = split_kfold(100, 5, 42);

        // Check folds are disjoint
        for i in 0..5 {
            for j in (i + 1)..5 {
                let set_i: HashSet<_> = split.folds[i].iter().collect();
                let set_j: HashSet<_> = split.folds[j].iter().collect();
                assert!(set_i.is_disjoint(&set_j), "Folds {} and {} overlap", i, j);
            }
        }
    }

    #[test]
    fn test_split_kfold_fold_sizes() {
        // 100 samples, 5 folds = 20 each
        let split = split_kfold(100, 5, 42);
        for fold in &split.folds {
            assert_eq!(fold.len(), 20);
        }

        // 103 samples, 5 folds = 20+1, 20+1, 20+1, 20, 20
        let split = split_kfold(103, 5, 42);
        let sizes: Vec<_> = split.folds.iter().map(|f| f.len()).collect();
        assert_eq!(sizes.iter().sum::<usize>(), 103);
        assert!(sizes.iter().all(|&s| s == 20 || s == 21));
    }

    #[test]
    fn test_split_kfold_get_fold() {
        let split = split_kfold(100, 5, 42);

        for i in 0..5 {
            let (train, val) = split.get_fold(i);

            // Total should be 100
            assert_eq!(train.len() + val.len(), 100);

            // Validation should be one fold's worth (~20)
            assert_eq!(val.len(), 20);

            // No overlap
            let train_set: HashSet<_> = train.iter().collect();
            let val_set: HashSet<_> = val.iter().collect();
            assert!(train_set.is_disjoint(&val_set));

            // Both sorted
            assert!(train.windows(2).all(|w| w[0] < w[1]));
            assert!(val.windows(2).all(|w| w[0] < w[1]));
        }
    }

    #[test]
    fn test_split_kfold_deterministic() {
        let split1 = split_kfold(100, 5, 42);
        let split2 = split_kfold(100, 5, 42);

        for i in 0..5 {
            assert_eq!(split1.folds[i], split2.folds[i]);
        }

        // Different seed should produce different split
        let split3 = split_kfold(100, 5, 43);
        assert_ne!(split1.folds[0], split3.folds[0]);
    }

    #[test]
    fn test_split_kfold_sorted() {
        let split = split_kfold(100, 5, 42);

        for fold in &split.folds {
            assert!(fold.windows(2).all(|w| w[0] < w[1]));
        }
    }

    #[test]
    #[should_panic(expected = "at least 2 folds")]
    fn test_split_kfold_invalid_k() {
        split_kfold(100, 1, 42);
    }

    #[test]
    fn test_split_holdout_no_validation() {
        let split = split_holdout(1000, 0.0, 0.0, 42);
        assert_eq!(split.train.len(), 1000);
        assert_eq!(split.validation.len(), 0);
        assert_eq!(split.calibration.len(), 0);
    }
}