// fdars_core/cv.rs
1//! Cross-validation utilities and unified CV framework.
2//!
3//! This module provides:
4//! - Shared fold assignment utilities used across all CV functions
5//! - [`cv_fdata`]: Generic k-fold + repeated CV framework (R's `cv.fdata`)
6
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10
11// ─── Fold Utilities ─────────────────────────────────────────────────────────
12
13/// Assign observations to folds (deterministic given seed).
14///
15/// Returns a vector of length `n` where element `i` is the fold index (0..n_folds)
16/// that observation `i` belongs to.
17pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
18    let n_folds = n_folds.max(1);
19    let mut rng = StdRng::seed_from_u64(seed);
20    let mut indices: Vec<usize> = (0..n).collect();
21    indices.shuffle(&mut rng);
22
23    let mut folds = vec![0usize; n];
24    for (rank, &idx) in indices.iter().enumerate() {
25        folds[idx] = rank % n_folds;
26    }
27    folds
28}
29
30/// Assign observations to stratified folds (classification).
31///
32/// Ensures each fold has approximately the same class distribution.
33pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
34    let n_folds = n_folds.max(1);
35    let mut rng = StdRng::seed_from_u64(seed);
36    let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
37
38    let mut folds = vec![0usize; n];
39
40    // Group indices by class
41    let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
42    for i in 0..n {
43        if y[i] < n_classes {
44            class_indices[y[i]].push(i);
45        }
46    }
47
48    // Shuffle within each class, then assign folds round-robin
49    for indices in &mut class_indices {
50        indices.shuffle(&mut rng);
51        for (rank, &idx) in indices.iter().enumerate() {
52            folds[idx] = rank % n_folds;
53        }
54    }
55
56    folds
57}
58
/// Split indices into train and test sets for a given fold.
///
/// Returns `(train_indices, test_indices)`.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    // Single pass: indices whose fold differs go to train, the rest to test.
    // `partition` preserves the ascending index order in both halves.
    (0..folds.len()).partition::<Vec<usize>, _>(|&i| folds[i] != fold)
}
67
68/// Extract a sub-matrix from an FdMatrix by selecting specific row indices.
69pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
70    let m = data.ncols();
71    let n_sub = indices.len();
72    let mut sub = FdMatrix::zeros(n_sub, m);
73    for (new_i, &orig_i) in indices.iter().enumerate() {
74        for j in 0..m {
75            sub[(new_i, j)] = data[(orig_i, j)];
76        }
77    }
78    sub
79}
80
/// Extract elements from a slice by indices.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        out.push(v[i]);
    }
    out
}
85
86// ─── CV Metrics ─────────────────────────────────────────────────────────────
87
/// Type of cross-validation task.
///
/// Determines both how stratified folds are built (class labels vs. response
/// quantile bins) and which [`CvMetrics`] variant is produced.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CvType {
    /// Continuous response; metrics are RMSE / MAE / R-squared.
    Regression,
    /// Discrete class labels encoded as `f64`; metrics are accuracy plus a
    /// confusion matrix.
    Classification,
}
95
/// Cross-validation metrics.
///
/// One variant per [`CvType`]; produced per fold, per repetition, and overall.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum CvMetrics {
    /// Regression metrics: root mean squared error, mean absolute error,
    /// and the coefficient of determination (R-squared).
    Regression { rmse: f64, mae: f64, r_squared: f64 },
    /// Classification metrics: overall accuracy and a confusion matrix
    /// indexed as `confusion[true_class][predicted_class]`.
    Classification {
        accuracy: f64,
        confusion: Vec<Vec<usize>>,
    },
}
108
/// Result of unified cross-validation, produced by [`cv_fdata`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CvFdataResult {
    /// Out-of-fold predictions (length n); for repeated CV, averaged across reps.
    pub oof_predictions: Vec<f64>,
    /// Overall metrics, computed from `oof_predictions` against the full `y`.
    pub metrics: CvMetrics,
    /// Per-fold metrics (from the last repetition when `nrep > 1`).
    pub fold_metrics: Vec<CvMetrics>,
    /// Fold assignments from the last (or only) repetition.
    pub folds: Vec<usize>,
    /// Type of CV task.
    pub cv_type: CvType,
    /// Number of repetitions.
    pub nrep: usize,
    /// Standard deviation of OOF predictions across repetitions (only when nrep > 1).
    pub oof_sd: Option<Vec<f64>>,
    /// Per-repetition overall metrics (only when nrep > 1).
    pub rep_metrics: Option<Vec<CvMetrics>>,
}
130
131// ─── Unified CV Framework ────────────────────────────────────────────────────
132
133/// Create CV folds based on strategy (stratified or random).
134fn create_cv_folds(
135    n: usize,
136    y: &[f64],
137    n_folds: usize,
138    cv_type: CvType,
139    stratified: bool,
140    seed: u64,
141) -> Vec<usize> {
142    if stratified {
143        match cv_type {
144            CvType::Classification => {
145                let y_class: Vec<usize> = y
146                    .iter()
147                    .map(|&v| crate::utility::f64_to_usize_clamped(v))
148                    .collect();
149                create_stratified_folds(n, &y_class, n_folds, seed)
150            }
151            CvType::Regression => {
152                let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
153                sorted_y.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
154                let n_bins = n_folds.min(n);
155                let bin_labels: Vec<usize> = {
156                    let mut labels = vec![0usize; n];
157                    for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
158                        labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
159                    }
160                    labels
161                };
162                create_stratified_folds(n, &bin_labels, n_folds, seed)
163            }
164        }
165    } else {
166        create_folds(n, n_folds, seed)
167    }
168}
169
/// Aggregate out-of-fold predictions across repetitions (mean and SD).
///
/// With a single repetition the predictions are returned unchanged and the
/// SD is `None`; otherwise each element gets its mean and sample standard
/// deviation (divisor `nrep - 1`) across repetitions.
fn aggregate_oof_predictions(all_oof: Vec<Vec<f64>>, n: usize) -> (Vec<f64>, Option<Vec<f64>>) {
    let nrep = all_oof.len();
    if nrep == 1 {
        let only = all_oof.into_iter().next().expect("non-empty iterator");
        return (only, None);
    }

    // Element-wise mean across repetitions.
    let mean: Vec<f64> = (0..n)
        .map(|i| all_oof.iter().map(|rep| rep[i]).sum::<f64>() / nrep as f64)
        .collect();

    // Element-wise sample standard deviation around that mean.
    let sd: Vec<f64> = (0..n)
        .map(|i| {
            let ss: f64 = all_oof
                .iter()
                .map(|rep| {
                    let d = rep[i] - mean[i];
                    d * d
                })
                .sum();
            (ss / (nrep as f64 - 1.0).max(1.0)).sqrt()
        })
        .collect();

    (mean, Some(sd))
}
202
/// Generic k-fold + repeated cross-validation framework (R's `cv.fdata`).
///
/// Runs `nrep` repetitions of `n_folds`-fold CV. In each fold, `fit_fn`
/// trains on the held-in rows and `predict_fn` scores the held-out rows,
/// producing out-of-fold (OOF) predictions for every observation; repeated
/// runs average the OOF predictions element-wise.
///
/// * `data` — one observation per row; `y` — response, indexed in parallel.
/// * `fit_fn` — returns a model as `Box<dyn Any>` so any model type fits
///   through the same framework; `predict_fn` downcasts and scores it.
/// * `n_folds` — clamped to `2..=n`; `nrep` — clamped to at least 1.
/// * `stratified` — stratify folds (by class or response quantile bins,
///   depending on `cv_type`).
/// * `seed` — base seed; repetition `r` uses `seed.wrapping_add(r)` so each
///   repetition gets distinct but reproducible folds.
pub fn cv_fdata<F, P>(
    data: &FdMatrix,
    y: &[f64],
    fit_fn: F,
    predict_fn: P,
    n_folds: usize,
    nrep: usize,
    cv_type: CvType,
    stratified: bool,
    seed: u64,
) -> CvFdataResult
where
    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
{
    let n = data.nrows();
    let nrep = nrep.max(1);
    let n_folds = n_folds.max(2).min(n);

    let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
    let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
    let mut last_folds = vec![0usize; n];
    let mut last_fold_metrics = Vec::new();

    for r in 0..nrep {
        // Distinct reproducible seed per repetition.
        let rep_seed = seed.wrapping_add(r as u64);
        let folds = create_cv_folds(n, y, n_folds, cv_type, stratified, rep_seed);

        let mut oof_preds = vec![0.0; n];
        let mut fold_metrics = Vec::with_capacity(n_folds);

        for fold in 0..n_folds {
            let (train_idx, test_idx) = fold_indices(&folds, fold);
            // Degenerate folds are skipped; their observations keep the 0.0
            // initialization in `oof_preds`.
            if train_idx.is_empty() || test_idx.is_empty() {
                continue;
            }

            let train_data = subset_rows(data, &train_idx);
            let train_y = subset_vec(y, &train_idx);
            let test_data = subset_rows(data, &test_idx);
            let test_y = subset_vec(y, &test_idx);

            let model = fit_fn(&train_data, &train_y);
            let preds = predict_fn(&*model, &test_data);

            // Scatter held-out predictions back to original positions; guard
            // against a predict_fn that returns fewer values than rows.
            for (local_i, &orig_i) in test_idx.iter().enumerate() {
                if local_i < preds.len() {
                    oof_preds[orig_i] = preds[local_i];
                }
            }

            fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
        }

        // Per-repetition metrics on this repetition's full OOF vector.
        let rep_metric = compute_metrics(y, &oof_preds, cv_type);
        all_oof.push(oof_preds);
        all_rep_metrics.push(rep_metric);
        last_folds = folds;
        last_fold_metrics = fold_metrics;
    }

    // Mean OOF across repetitions (plus SD when nrep > 1) drives the
    // overall metrics.
    let (final_oof, oof_sd) = aggregate_oof_predictions(all_oof, n);
    let overall_metrics = compute_metrics(y, &final_oof, cv_type);

    CvFdataResult {
        oof_predictions: final_oof,
        metrics: overall_metrics,
        fold_metrics: last_fold_metrics,
        folds: last_folds,
        cv_type,
        nrep,
        oof_sd,
        rep_metrics: if nrep > 1 {
            Some(all_rep_metrics)
        } else {
            None
        },
    }
}
283
284/// Compute metrics from true and predicted values.
285fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
286    let n = y_true.len().min(y_pred.len());
287    if n == 0 {
288        return match cv_type {
289            CvType::Regression => CvMetrics::Regression {
290                rmse: f64::NAN,
291                mae: f64::NAN,
292                r_squared: f64::NAN,
293            },
294            CvType::Classification => CvMetrics::Classification {
295                accuracy: 0.0,
296                confusion: Vec::new(),
297            },
298        };
299    }
300
301    match cv_type {
302        CvType::Regression => {
303            let mean_y = y_true.iter().sum::<f64>() / n as f64;
304            let mut ss_res = 0.0;
305            let mut ss_tot = 0.0;
306            let mut mae_sum = 0.0;
307            for i in 0..n {
308                let resid = y_true[i] - y_pred[i];
309                ss_res += resid * resid;
310                ss_tot += (y_true[i] - mean_y).powi(2);
311                mae_sum += resid.abs();
312            }
313            let rmse = (ss_res / n as f64).sqrt();
314            let mae = mae_sum / n as f64;
315            let r_squared = if ss_tot > 1e-15 {
316                1.0 - ss_res / ss_tot
317            } else {
318                0.0
319            };
320            CvMetrics::Regression {
321                rmse,
322                mae,
323                r_squared,
324            }
325        }
326        CvType::Classification => {
327            let n_classes = y_true
328                .iter()
329                .chain(y_pred.iter())
330                .map(|&v| v as usize)
331                .max()
332                .unwrap_or(0)
333                + 1;
334            let mut confusion = vec![vec![0usize; n_classes]; n_classes];
335            let mut correct = 0usize;
336            for i in 0..n {
337                let true_c = y_true[i] as usize;
338                let pred_c = y_pred[i].round() as usize;
339                if true_c < n_classes && pred_c < n_classes {
340                    confusion[true_c][pred_c] += 1;
341                }
342                if true_c == pred_c {
343                    correct += 1;
344                }
345            }
346            let accuracy = correct as f64 / n as f64;
347            CvMetrics::Classification {
348                accuracy,
349                confusion,
350            }
351        }
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FdarError;

    // 10 observations into 5 folds must be perfectly balanced: 2 per fold.
    #[test]
    fn test_create_folds_basic() {
        let folds = create_folds(10, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold should have 2 members
        for f in 0..5 {
            let count = folds.iter().filter(|&&x| x == f).count();
            assert_eq!(count, 2);
        }
    }

    // Same seed must reproduce identical assignments (seeded StdRng).
    #[test]
    fn test_create_folds_deterministic() {
        let f1 = create_folds(20, 5, 123);
        let f2 = create_folds(20, 5, 123);
        assert_eq!(f1, f2);
    }

    // With two balanced classes and 5 folds, stratification must place
    // exactly one member of each class in every fold.
    #[test]
    fn test_stratified_folds() {
        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let folds = create_stratified_folds(10, &y, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold should have 1 from each class
        for f in 0..5 {
            let class0_count = (0..10).filter(|&i| folds[i] == f && y[i] == 0).count();
            let class1_count = (0..10).filter(|&i| folds[i] == f && y[i] == 1).count();
            assert_eq!(class0_count, 1);
            assert_eq!(class1_count, 1);
        }
    }

    // Train/test split preserves ascending index order in both halves.
    #[test]
    fn test_fold_indices() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_indices(&folds, 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    // Row subsetting copies the requested rows in the requested order.
    #[test]
    fn test_subset_rows() {
        let mut data = FdMatrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                data[(i, j)] = (i * 10 + j) as f64;
            }
        }
        let sub = subset_rows(&data, &[1, 3]);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 3);
        assert!((sub[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((sub[(1, 0)] - 30.0).abs() < 1e-10);
    }

    // End-to-end regression CV with a trivial mean-predictor model passed
    // through the Box<dyn Any> fit/predict interface.
    #[test]
    fn test_cv_fdata_regression() -> Result<(), FdarError> {
        // Simple test: predict mean
        let n = 20;
        let m = 5;
        let mut data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = y[i] + j as f64 * 0.1;
            }
        }

        let result = cv_fdata(
            &data,
            &y,
            |_train_data, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        assert!(result.oof_sd.is_none());
        match &result.metrics {
            CvMetrics::Regression { rmse, .. } => assert!(*rmse > 0.0),
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "cv_fdata_regression",
                    detail: "expected regression metrics".into(),
                });
            }
        }
        Ok(())
    }

    // Repeated CV (nrep > 1) must populate oof_sd and per-rep metrics.
    #[test]
    fn test_cv_fdata_repeated() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        assert!(result.rep_metrics.is_some());
        assert_eq!(result.rep_metrics.as_ref().unwrap().len(), 3);
    }

    // Accuracy and confusion-matrix counts for a 4-sample binary problem
    // with a single misclassification.
    #[test]
    fn test_compute_metrics_classification() -> Result<(), FdarError> {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0]; // 1 misclassification
        let m = compute_metrics(&y_true, &y_pred, CvType::Classification);
        match m {
            CvMetrics::Classification {
                accuracy,
                confusion,
            } => {
                assert!((accuracy - 0.75).abs() < 1e-10);
                assert_eq!(confusion[0][0], 1); // true 0, pred 0
                assert_eq!(confusion[0][1], 1); // true 0, pred 1
                assert_eq!(confusion[1][1], 2); // true 1, pred 1
            }
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "compute_metrics_classification",
                    detail: "expected classification metrics".into(),
                });
            }
        }
        Ok(())
    }
}