// fdars_core/cv.rs
1//! Cross-validation utilities and unified CV framework.
2//!
3//! This module provides:
4//! - Shared fold assignment utilities used across all CV functions
5//! - [`cv_fdata`]: Generic k-fold + repeated CV framework (R's `cv.fdata`)
6
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10
11// ─── Fold Utilities ─────────────────────────────────────────────────────────
12
13/// Assign observations to folds (deterministic given seed).
14///
15/// Returns a vector of length `n` where element `i` is the fold index (0..n_folds)
16/// that observation `i` belongs to.
17pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
18    let n_folds = n_folds.max(1);
19    let mut rng = StdRng::seed_from_u64(seed);
20    let mut indices: Vec<usize> = (0..n).collect();
21    indices.shuffle(&mut rng);
22
23    let mut folds = vec![0usize; n];
24    for (rank, &idx) in indices.iter().enumerate() {
25        folds[idx] = rank % n_folds;
26    }
27    folds
28}
29
30/// Assign observations to stratified folds (classification).
31///
32/// Ensures each fold has approximately the same class distribution.
33pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
34    let n_folds = n_folds.max(1);
35    let mut rng = StdRng::seed_from_u64(seed);
36    let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
37
38    let mut folds = vec![0usize; n];
39
40    // Group indices by class
41    let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
42    for i in 0..n {
43        if y[i] < n_classes {
44            class_indices[y[i]].push(i);
45        }
46    }
47
48    // Shuffle within each class, then assign folds round-robin
49    for indices in &mut class_indices {
50        indices.shuffle(&mut rng);
51        for (rank, &idx) in indices.iter().enumerate() {
52            folds[idx] = rank % n_folds;
53        }
54    }
55
56    folds
57}
58
/// Split indices into train and test sets for a given fold.
///
/// Returns `(train_indices, test_indices)`.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    // Indices outside `fold` go to train, indices inside it to test.
    // `partition` preserves ascending order, matching two filter passes.
    (0..folds.len()).partition(|&i| folds[i] != fold)
}
67
68/// Extract a sub-matrix from an FdMatrix by selecting specific row indices.
69pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
70    let m = data.ncols();
71    let n_sub = indices.len();
72    let mut sub = FdMatrix::zeros(n_sub, m);
73    for (new_i, &orig_i) in indices.iter().enumerate() {
74        for j in 0..m {
75            sub[(new_i, j)] = data[(orig_i, j)];
76        }
77    }
78    sub
79}
80
/// Extract elements from a slice by indices.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    // Gather in index order; panics (like direct indexing) if an index is
    // out of bounds.
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        out.push(v[i]);
    }
    out
}
85
86// ─── CV Metrics ─────────────────────────────────────────────────────────────
87
/// Type of cross-validation task.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CvType {
    /// Continuous response; scored with RMSE / MAE / R².
    Regression,
    /// Integer class labels encoded as `f64`; scored with accuracy and a confusion matrix.
    Classification,
}
94
/// Cross-validation metrics.
#[derive(Debug, Clone, PartialEq)]
pub enum CvMetrics {
    /// Regression metrics: root-mean-square error, mean absolute error, and R².
    /// All fields are NaN when computed over an empty sample.
    Regression { rmse: f64, mae: f64, r_squared: f64 },
    /// Classification metrics.
    Classification {
        /// Fraction of observations whose (rounded) prediction equals the true class.
        accuracy: f64,
        /// Confusion counts indexed as `confusion[true_class][predicted_class]`.
        confusion: Vec<Vec<usize>>,
    },
}
106
/// Result of unified cross-validation.
#[derive(Debug, Clone, PartialEq)]
pub struct CvFdataResult {
    /// Out-of-fold predictions (length n); for repeated CV, averaged across reps.
    pub oof_predictions: Vec<f64>,
    /// Overall metrics computed on `oof_predictions` against the full response.
    pub metrics: CvMetrics,
    /// Per-fold metrics from the last (or only) repetition, in fold order.
    pub fold_metrics: Vec<CvMetrics>,
    /// Fold assignments from the last (or only) repetition.
    pub folds: Vec<usize>,
    /// Type of CV task.
    pub cv_type: CvType,
    /// Number of repetitions actually run (clamped to at least 1).
    pub nrep: usize,
    /// Standard deviation of OOF predictions across repetitions (only when nrep > 1).
    pub oof_sd: Option<Vec<f64>>,
    /// Per-repetition overall metrics (only when nrep > 1).
    pub rep_metrics: Option<Vec<CvMetrics>>,
}
127
128// ─── Unified CV Framework ────────────────────────────────────────────────────
129
130/// Create CV folds based on strategy (stratified or random).
131fn create_cv_folds(
132    n: usize,
133    y: &[f64],
134    n_folds: usize,
135    cv_type: CvType,
136    stratified: bool,
137    seed: u64,
138) -> Vec<usize> {
139    if stratified {
140        match cv_type {
141            CvType::Classification => {
142                let y_class: Vec<usize> = y
143                    .iter()
144                    .map(|&v| crate::utility::f64_to_usize_clamped(v))
145                    .collect();
146                create_stratified_folds(n, &y_class, n_folds, seed)
147            }
148            CvType::Regression => {
149                let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
150                sorted_y.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
151                let n_bins = n_folds.min(n);
152                let bin_labels: Vec<usize> = {
153                    let mut labels = vec![0usize; n];
154                    for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
155                        labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
156                    }
157                    labels
158                };
159                create_stratified_folds(n, &bin_labels, n_folds, seed)
160            }
161        }
162    } else {
163        create_folds(n, n_folds, seed)
164    }
165}
166
/// Aggregate out-of-fold predictions across repetitions (mean and SD).
///
/// With a single repetition the predictions pass through unchanged and no SD
/// is produced. Otherwise returns the per-observation mean and the
/// Bessel-corrected sample standard deviation across repetitions.
fn aggregate_oof_predictions(all_oof: Vec<Vec<f64>>, n: usize) -> (Vec<f64>, Option<Vec<f64>>) {
    let nrep = all_oof.len();
    if nrep == 1 {
        let only = all_oof.into_iter().next().expect("non-empty iterator");
        return (only, None);
    }

    // Per-observation mean across repetitions (summed in repetition order, so
    // floating-point results match a sequential accumulation).
    let mean: Vec<f64> = (0..n)
        .map(|i| all_oof.iter().map(|rep| rep[i]).sum::<f64>() / nrep as f64)
        .collect();

    // Per-observation sample SD; the max(1.0) keeps the divisor positive.
    let denom = (nrep as f64 - 1.0).max(1.0);
    let sd: Vec<f64> = (0..n)
        .map(|i| {
            let ss: f64 = all_oof
                .iter()
                .map(|rep| {
                    let d = rep[i] - mean[i];
                    d * d
                })
                .sum();
            (ss / denom).sqrt()
        })
        .collect();

    (mean, Some(sd))
}
199
/// Generic k-fold + repeated cross-validation framework (R's `cv.fdata`).
///
/// Splits `data`/`y` into `n_folds` folds, fits `fit_fn` on each training
/// split, predicts the held-out split with `predict_fn`, and aggregates
/// out-of-fold (OOF) predictions and metrics. With `nrep > 1` the whole
/// procedure repeats with different fold assignments (seed offset by the
/// repetition index) and OOF predictions are averaged across repetitions.
///
/// * `fit_fn` — trains a model on (train rows, train responses); the model is
///   type-erased as `Box<dyn Any>` so any model type can be plugged in.
/// * `predict_fn` — produces one prediction per row of the test matrix.
/// * `stratified` — stratify folds by class (classification) or by response
///   quantile bin (regression); see `create_cv_folds`.
///
/// Note: observations in a skipped fold (empty train or test split) keep a
/// default OOF prediction of `0.0`, which flows into the overall metrics.
pub fn cv_fdata<F, P>(
    data: &FdMatrix,
    y: &[f64],
    fit_fn: F,
    predict_fn: P,
    n_folds: usize,
    nrep: usize,
    cv_type: CvType,
    stratified: bool,
    seed: u64,
) -> CvFdataResult
where
    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
{
    let n = data.nrows();
    // Clamp to sane ranges: at least 1 repetition, between 2 and n folds.
    let nrep = nrep.max(1);
    let n_folds = n_folds.max(2).min(n);

    let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
    let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
    let mut last_folds = vec![0usize; n];
    let mut last_fold_metrics = Vec::new();

    for r in 0..nrep {
        // Offset the seed per repetition: repetitions get distinct fold
        // assignments while the whole run stays reproducible.
        let rep_seed = seed.wrapping_add(r as u64);
        let folds = create_cv_folds(n, y, n_folds, cv_type, stratified, rep_seed);

        let mut oof_preds = vec![0.0; n];
        let mut fold_metrics = Vec::with_capacity(n_folds);

        for fold in 0..n_folds {
            let (train_idx, test_idx) = fold_indices(&folds, fold);
            // Degenerate fold (possible with stratification on tiny data):
            // skip; its observations keep the 0.0 default OOF prediction.
            if train_idx.is_empty() || test_idx.is_empty() {
                continue;
            }

            let train_data = subset_rows(data, &train_idx);
            let train_y = subset_vec(y, &train_idx);
            let test_data = subset_rows(data, &test_idx);
            let test_y = subset_vec(y, &test_idx);

            let model = fit_fn(&train_data, &train_y);
            let preds = predict_fn(&*model, &test_data);

            // Scatter fold predictions back to their original row positions;
            // the length guard tolerates a predictor returning too few values.
            for (local_i, &orig_i) in test_idx.iter().enumerate() {
                if local_i < preds.len() {
                    oof_preds[orig_i] = preds[local_i];
                }
            }

            fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
        }

        // Per-repetition overall metrics on the assembled OOF vector.
        let rep_metric = compute_metrics(y, &oof_preds, cv_type);
        all_oof.push(oof_preds);
        all_rep_metrics.push(rep_metric);
        last_folds = folds;
        last_fold_metrics = fold_metrics;
    }

    // Mean OOF predictions across repetitions; SD only when nrep > 1.
    let (final_oof, oof_sd) = aggregate_oof_predictions(all_oof, n);
    let overall_metrics = compute_metrics(y, &final_oof, cv_type);

    CvFdataResult {
        oof_predictions: final_oof,
        metrics: overall_metrics,
        fold_metrics: last_fold_metrics,
        folds: last_folds,
        cv_type,
        nrep,
        oof_sd,
        rep_metrics: if nrep > 1 {
            Some(all_rep_metrics)
        } else {
            None
        },
    }
}
280
281/// Compute metrics from true and predicted values.
282fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
283    let n = y_true.len().min(y_pred.len());
284    if n == 0 {
285        return match cv_type {
286            CvType::Regression => CvMetrics::Regression {
287                rmse: f64::NAN,
288                mae: f64::NAN,
289                r_squared: f64::NAN,
290            },
291            CvType::Classification => CvMetrics::Classification {
292                accuracy: 0.0,
293                confusion: Vec::new(),
294            },
295        };
296    }
297
298    match cv_type {
299        CvType::Regression => {
300            let mean_y = y_true.iter().sum::<f64>() / n as f64;
301            let mut ss_res = 0.0;
302            let mut ss_tot = 0.0;
303            let mut mae_sum = 0.0;
304            for i in 0..n {
305                let resid = y_true[i] - y_pred[i];
306                ss_res += resid * resid;
307                ss_tot += (y_true[i] - mean_y).powi(2);
308                mae_sum += resid.abs();
309            }
310            let rmse = (ss_res / n as f64).sqrt();
311            let mae = mae_sum / n as f64;
312            let r_squared = if ss_tot > 1e-15 {
313                1.0 - ss_res / ss_tot
314            } else {
315                0.0
316            };
317            CvMetrics::Regression {
318                rmse,
319                mae,
320                r_squared,
321            }
322        }
323        CvType::Classification => {
324            let n_classes = y_true
325                .iter()
326                .chain(y_pred.iter())
327                .map(|&v| v as usize)
328                .max()
329                .unwrap_or(0)
330                + 1;
331            let mut confusion = vec![vec![0usize; n_classes]; n_classes];
332            let mut correct = 0usize;
333            for i in 0..n {
334                let true_c = y_true[i] as usize;
335                let pred_c = y_pred[i].round() as usize;
336                if true_c < n_classes && pred_c < n_classes {
337                    confusion[true_c][pred_c] += 1;
338                }
339                if true_c == pred_c {
340                    correct += 1;
341                }
342            }
343            let accuracy = correct as f64 / n as f64;
344            CvMetrics::Classification {
345                accuracy,
346                confusion,
347            }
348        }
349    }
350}
351
// Unit tests for fold construction, subsetting, the unified CV driver, and
// metric computation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FdarError;

    // Basic sanity: 10 observations over 5 folds gives exactly 2 per fold.
    #[test]
    fn test_create_folds_basic() {
        let folds = create_folds(10, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold should have 2 members
        for f in 0..5 {
            let count = folds.iter().filter(|&&x| x == f).count();
            assert_eq!(count, 2);
        }
    }

    // The same seed must reproduce the same fold assignment.
    #[test]
    fn test_create_folds_deterministic() {
        let f1 = create_folds(20, 5, 123);
        let f2 = create_folds(20, 5, 123);
        assert_eq!(f1, f2);
    }

    // Stratification preserves class balance: one member of each class per fold.
    #[test]
    fn test_stratified_folds() {
        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let folds = create_stratified_folds(10, &y, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold should have 1 from each class
        for f in 0..5 {
            let class0_count = (0..10).filter(|&i| folds[i] == f && y[i] == 0).count();
            let class1_count = (0..10).filter(|&i| folds[i] == f && y[i] == 1).count();
            assert_eq!(class0_count, 1);
            assert_eq!(class1_count, 1);
        }
    }

    // Members of the chosen fold go to test, everything else to train.
    #[test]
    fn test_fold_indices() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_indices(&folds, 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    // Row subsetting copies the selected rows in the order requested.
    #[test]
    fn test_subset_rows() {
        let mut data = FdMatrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                data[(i, j)] = (i * 10 + j) as f64;
            }
        }
        let sub = subset_rows(&data, &[1, 3]);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 3);
        assert!((sub[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((sub[(1, 0)] - 30.0).abs() < 1e-10);
    }

    // End-to-end single-rep regression CV with a trivial mean-predictor model.
    #[test]
    fn test_cv_fdata_regression() -> Result<(), FdarError> {
        // Simple test: predict mean
        let n = 20;
        let m = 5;
        let mut data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = y[i] + j as f64 * 0.1;
            }
        }

        let result = cv_fdata(
            &data,
            &y,
            |_train_data, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        // Single repetition: no cross-rep SD should be produced.
        assert!(result.oof_sd.is_none());
        match &result.metrics {
            CvMetrics::Regression { rmse, .. } => assert!(*rmse > 0.0),
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "cv_fdata_regression",
                    detail: "expected regression metrics".into(),
                });
            }
        }
        Ok(())
    }

    // Repeated CV (nrep > 1) must populate the SD and per-rep metric fields.
    #[test]
    fn test_cv_fdata_repeated() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        assert!(result.rep_metrics.is_some());
        assert_eq!(result.rep_metrics.as_ref().unwrap().len(), 3);
    }

    // Accuracy and confusion-matrix counts for a known 4-sample case.
    #[test]
    fn test_compute_metrics_classification() -> Result<(), FdarError> {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0]; // 1 misclassification
        let m = compute_metrics(&y_true, &y_pred, CvType::Classification);
        match m {
            CvMetrics::Classification {
                accuracy,
                confusion,
            } => {
                assert!((accuracy - 0.75).abs() < 1e-10);
                assert_eq!(confusion[0][0], 1); // true 0, pred 0
                assert_eq!(confusion[0][1], 1); // true 0, pred 1
                assert_eq!(confusion[1][1], 2); // true 1, pred 1
            }
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "compute_metrics_classification",
                    detail: "expected classification metrics".into(),
                });
            }
        }
        Ok(())
    }
}