// fdars_core/cv.rs
1//! Cross-validation utilities and unified CV framework.
2//!
3//! This module provides:
4//! - Shared fold assignment utilities used across all CV functions
5//! - [`cv_fdata`]: Generic k-fold + repeated CV framework (R's `cv.fdata`)
6
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10
11// ─── Fold Utilities ─────────────────────────────────────────────────────────
12
13/// Assign observations to folds (deterministic given seed).
14///
15/// Returns a vector of length `n` where element `i` is the fold index (0..n_folds)
16/// that observation `i` belongs to.
17pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
18    let n_folds = n_folds.max(1);
19    let mut rng = StdRng::seed_from_u64(seed);
20    let mut indices: Vec<usize> = (0..n).collect();
21    indices.shuffle(&mut rng);
22
23    let mut folds = vec![0usize; n];
24    for (rank, &idx) in indices.iter().enumerate() {
25        folds[idx] = rank % n_folds;
26    }
27    folds
28}
29
30/// Assign observations to stratified folds (classification).
31///
32/// Ensures each fold has approximately the same class distribution.
33pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
34    let n_folds = n_folds.max(1);
35    let mut rng = StdRng::seed_from_u64(seed);
36    let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
37
38    let mut folds = vec![0usize; n];
39
40    // Group indices by class
41    let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
42    for i in 0..n {
43        if y[i] < n_classes {
44            class_indices[y[i]].push(i);
45        }
46    }
47
48    // Shuffle within each class, then assign folds round-robin
49    for indices in &mut class_indices {
50        indices.shuffle(&mut rng);
51        for (rank, &idx) in indices.iter().enumerate() {
52            folds[idx] = rank % n_folds;
53        }
54    }
55
56    folds
57}
58
/// Split indices into train and test sets for a given fold.
///
/// Returns `(train_indices, test_indices)`.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    // Single pass: partition every observation index by fold membership,
    // preserving ascending index order in both outputs.
    let mut train = Vec::with_capacity(folds.len());
    let mut test = Vec::new();
    for (i, &assigned) in folds.iter().enumerate() {
        if assigned == fold {
            test.push(i);
        } else {
            train.push(i);
        }
    }
    (train, test)
}
67
68/// Extract a sub-matrix from an FdMatrix by selecting specific row indices.
69pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
70    let m = data.ncols();
71    let n_sub = indices.len();
72    let mut sub = FdMatrix::zeros(n_sub, m);
73    for (new_i, &orig_i) in indices.iter().enumerate() {
74        for j in 0..m {
75            sub[(new_i, j)] = data[(orig_i, j)];
76        }
77    }
78    sub
79}
80
/// Extract elements from a slice by indices.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    // Preallocate to the exact output length, then gather in order.
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        out.push(v[i]);
    }
    out
}
85
86// ─── CV Metrics ─────────────────────────────────────────────────────────────
87
/// Type of cross-validation task.
///
/// Determines which [`CvMetrics`] variant the CV routines produce.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CvType {
    /// Continuous response; scored with RMSE / MAE / R².
    Regression,
    /// Integer-coded class response (0, 1, 2, …); scored with accuracy
    /// and a confusion matrix.
    Classification,
}
94
/// Cross-validation metrics.
///
/// The variant corresponds to the [`CvType`] of the task being evaluated;
/// callers typically `match` on the variant they expect.
#[derive(Debug, Clone)]
pub enum CvMetrics {
    /// Regression metrics.
    Regression { rmse: f64, mae: f64, r_squared: f64 },
    /// Classification metrics.
    Classification {
        /// Fraction of observations whose rounded prediction equals the true label.
        accuracy: f64,
        /// Square matrix of counts: row = true class, column = predicted class.
        confusion: Vec<Vec<usize>>,
    },
}
106
/// Result of unified cross-validation.
///
/// When `nrep == 1` the fields `oof_sd` and `rep_metrics` are `None`;
/// for repeated CV they summarize the spread across repetitions.
#[derive(Debug, Clone)]
pub struct CvFdataResult {
    /// Out-of-fold predictions (length n); for repeated CV, averaged across reps.
    pub oof_predictions: Vec<f64>,
    /// Overall metrics.
    pub metrics: CvMetrics,
    /// Per-fold metrics, from the last (or only) repetition.
    pub fold_metrics: Vec<CvMetrics>,
    /// Fold assignments from the last (or only) repetition.
    pub folds: Vec<usize>,
    /// Type of CV task.
    pub cv_type: CvType,
    /// Number of repetitions.
    pub nrep: usize,
    /// Standard deviation of OOF predictions across repetitions (only when nrep > 1).
    pub oof_sd: Option<Vec<f64>>,
    /// Per-repetition overall metrics (only when nrep > 1).
    pub rep_metrics: Option<Vec<CvMetrics>>,
}
127
128// ─── Unified CV Framework ────────────────────────────────────────────────────
129
/// Generic k-fold + repeated cross-validation framework (R's `cv.fdata`).
///
/// The user provides fit/predict closures so this works with any model.
///
/// # Arguments
/// * `data` — Functional data matrix (n × m)
/// * `y` — Response vector (length n); for classification, should be 0, 1, 2, …
/// * `fit_fn` — Closure that fits a model on training data and returns a boxed model
/// * `predict_fn` — Closure that predicts from a model on test data
/// * `n_folds` — Number of CV folds
/// * `nrep` — Number of repetitions (1 = single CV, >1 = repeated)
/// * `cv_type` — Whether this is regression or classification
/// * `stratified` — Whether to stratify folds
/// * `seed` — Random seed for fold assignment
pub fn cv_fdata<F, P>(
    data: &FdMatrix,
    y: &[f64],
    fit_fn: F,
    predict_fn: P,
    n_folds: usize,
    nrep: usize,
    cv_type: CvType,
    stratified: bool,
    seed: u64,
) -> CvFdataResult
where
    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
{
    let n = data.nrows();
    // Clamp inputs to usable ranges: at least one repetition, and between
    // 2 and n folds (each fold needs at least one held-out observation).
    let nrep = nrep.max(1);
    let n_folds = n_folds.max(2).min(n);

    // Storage for repeated CV
    let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
    let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
    let mut last_folds = vec![0usize; n];
    let mut last_fold_metrics = Vec::new();

    for r in 0..nrep {
        // Each repetition derives its own seed so fold assignments differ
        // across repetitions but the whole run stays deterministic.
        let rep_seed = seed.wrapping_add(r as u64);

        // Create folds
        let folds = if stratified {
            match cv_type {
                CvType::Classification => {
                    // Class labels are obtained by truncating y to usize.
                    let y_class: Vec<usize> = y.iter().map(|&v| v as usize).collect();
                    create_stratified_folds(n, &y_class, n_folds, rep_seed)
                }
                CvType::Regression => {
                    // Stratify by quantile bin: rank observations by y, cut the
                    // ranks into n_bins roughly equal groups, stratify on the bin.
                    let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
                    sorted_y
                        .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
                    let n_bins = n_folds.min(n);
                    let bin_labels: Vec<usize> = {
                        let mut labels = vec![0usize; n];
                        for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
                            labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
                        }
                        labels
                    };
                    create_stratified_folds(n, &bin_labels, n_folds, rep_seed)
                }
            }
        } else {
            create_folds(n, n_folds, rep_seed)
        };

        let mut oof_preds = vec![0.0; n];
        let mut fold_metrics = Vec::with_capacity(n_folds);

        for fold in 0..n_folds {
            let (train_idx, test_idx) = fold_indices(&folds, fold);
            // NOTE(review): a skipped fold leaves its observations' OOF
            // predictions at the 0.0 initialization, which then feeds into
            // the repetition metrics below. With n_folds clamped to <= n and
            // round-robin assignment this shouldn't occur — confirm whether
            // empty folds are actually reachable.
            if train_idx.is_empty() || test_idx.is_empty() {
                continue;
            }

            let train_data = subset_rows(data, &train_idx);
            let train_y = subset_vec(y, &train_idx);
            let test_data = subset_rows(data, &test_idx);
            let test_y = subset_vec(y, &test_idx);

            // Fit on the training split, predict the held-out split.
            let model = fit_fn(&train_data, &train_y);
            let preds = predict_fn(&*model, &test_data);

            // Scatter the fold's predictions back to their original positions;
            // the length guard tolerates a predict_fn returning too few values.
            for (local_i, &orig_i) in test_idx.iter().enumerate() {
                if local_i < preds.len() {
                    oof_preds[orig_i] = preds[local_i];
                }
            }

            fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
        }

        // Overall metrics for this repetition, computed on the full OOF vector.
        let rep_metric = compute_metrics(y, &oof_preds, cv_type);
        all_oof.push(oof_preds);
        all_rep_metrics.push(rep_metric);
        last_folds = folds;
        last_fold_metrics = fold_metrics;
    }

    // Aggregate across repetitions
    let (final_oof, oof_sd) = if nrep == 1 {
        (all_oof.into_iter().next().unwrap(), None)
    } else {
        // Per-observation mean of the OOF predictions across repetitions.
        let mut mean_oof = vec![0.0; n];
        for oof in &all_oof {
            for i in 0..n {
                mean_oof[i] += oof[i];
            }
        }
        for v in &mut mean_oof {
            *v /= nrep as f64;
        }

        // Per-observation sample standard deviation (n-1 denominator).
        let mut sd_oof = vec![0.0; n];
        for oof in &all_oof {
            for i in 0..n {
                let diff = oof[i] - mean_oof[i];
                sd_oof[i] += diff * diff;
            }
        }
        for v in &mut sd_oof {
            *v = (*v / (nrep as f64 - 1.0).max(1.0)).sqrt();
        }

        (mean_oof, Some(sd_oof))
    };

    // Final metrics on the (possibly averaged) OOF predictions. For repeated
    // classification CV these averaged predictions are rounded back to class
    // labels inside compute_metrics.
    let overall_metrics = compute_metrics(y, &final_oof, cv_type);

    CvFdataResult {
        oof_predictions: final_oof,
        metrics: overall_metrics,
        fold_metrics: last_fold_metrics,
        folds: last_folds,
        cv_type,
        nrep,
        oof_sd,
        rep_metrics: if nrep > 1 {
            Some(all_rep_metrics)
        } else {
            None
        },
    }
}
277
278/// Compute metrics from true and predicted values.
279fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
280    let n = y_true.len().min(y_pred.len());
281    if n == 0 {
282        return match cv_type {
283            CvType::Regression => CvMetrics::Regression {
284                rmse: f64::NAN,
285                mae: f64::NAN,
286                r_squared: f64::NAN,
287            },
288            CvType::Classification => CvMetrics::Classification {
289                accuracy: 0.0,
290                confusion: Vec::new(),
291            },
292        };
293    }
294
295    match cv_type {
296        CvType::Regression => {
297            let mean_y = y_true.iter().sum::<f64>() / n as f64;
298            let mut ss_res = 0.0;
299            let mut ss_tot = 0.0;
300            let mut mae_sum = 0.0;
301            for i in 0..n {
302                let resid = y_true[i] - y_pred[i];
303                ss_res += resid * resid;
304                ss_tot += (y_true[i] - mean_y).powi(2);
305                mae_sum += resid.abs();
306            }
307            let rmse = (ss_res / n as f64).sqrt();
308            let mae = mae_sum / n as f64;
309            let r_squared = if ss_tot > 1e-15 {
310                1.0 - ss_res / ss_tot
311            } else {
312                0.0
313            };
314            CvMetrics::Regression {
315                rmse,
316                mae,
317                r_squared,
318            }
319        }
320        CvType::Classification => {
321            let n_classes = y_true
322                .iter()
323                .chain(y_pred.iter())
324                .map(|&v| v as usize)
325                .max()
326                .unwrap_or(0)
327                + 1;
328            let mut confusion = vec![vec![0usize; n_classes]; n_classes];
329            let mut correct = 0usize;
330            for i in 0..n {
331                let true_c = y_true[i] as usize;
332                let pred_c = y_pred[i].round() as usize;
333                if true_c < n_classes && pred_c < n_classes {
334                    confusion[true_c][pred_c] += 1;
335                }
336                if true_c == pred_c {
337                    correct += 1;
338                }
339            }
340            let accuracy = correct as f64 / n as f64;
341            CvMetrics::Classification {
342                accuracy,
343                confusion,
344            }
345        }
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_folds_basic() {
        let folds = create_folds(10, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold should have 2 members: 10 observations dealt round-robin into 5 folds.
        for f in 0..5 {
            let count = folds.iter().filter(|&&x| x == f).count();
            assert_eq!(count, 2);
        }
    }

    // Same n, folds, and seed must reproduce the exact same assignment
    // (StdRng::seed_from_u64 is deterministic).
    #[test]
    fn test_create_folds_deterministic() {
        let f1 = create_folds(20, 5, 123);
        let f2 = create_folds(20, 5, 123);
        assert_eq!(f1, f2);
    }

    #[test]
    fn test_stratified_folds() {
        // Two balanced classes of 5 observations each.
        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let folds = create_stratified_folds(10, &y, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold should have 1 from each class (5 per class across 5 folds).
        for f in 0..5 {
            let class0_count = (0..10).filter(|&i| folds[i] == f && y[i] == 0).count();
            let class1_count = (0..10).filter(|&i| folds[i] == f && y[i] == 1).count();
            assert_eq!(class0_count, 1);
            assert_eq!(class1_count, 1);
        }
    }

    #[test]
    fn test_fold_indices() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_indices(&folds, 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_subset_rows() {
        // Fill a 4x3 matrix with distinguishable values (i*10 + j).
        let mut data = FdMatrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                data[(i, j)] = (i * 10 + j) as f64;
            }
        }
        let sub = subset_rows(&data, &[1, 3]);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 3);
        assert!((sub[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((sub[(1, 0)] - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_cv_fdata_regression() {
        // Simple test: the "model" just predicts the training-set mean,
        // boxed as dyn Any and downcast back in the predict closure.
        let n = 20;
        let m = 5;
        let mut data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = y[i] + j as f64 * 0.1;
            }
        }

        let result = cv_fdata(
            &data,
            &y,
            |_train_data, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        // Single repetition: no across-rep spread should be reported.
        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        assert!(result.oof_sd.is_none());
        match &result.metrics {
            CvMetrics::Regression { rmse, .. } => assert!(*rmse > 0.0),
            _ => panic!("Expected regression metrics"),
        }
    }

    #[test]
    fn test_cv_fdata_repeated() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        // Repeated CV must expose per-repetition metrics and the OOF spread.
        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        assert!(result.rep_metrics.is_some());
        assert_eq!(result.rep_metrics.as_ref().unwrap().len(), 3);
    }

    #[test]
    fn test_compute_metrics_classification() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0]; // 1 misclassification
        let m = compute_metrics(&y_true, &y_pred, CvType::Classification);
        match m {
            CvMetrics::Classification {
                accuracy,
                confusion,
            } => {
                assert!((accuracy - 0.75).abs() < 1e-10);
                assert_eq!(confusion[0][0], 1); // true 0, pred 0
                assert_eq!(confusion[0][1], 1); // true 0, pred 1
                assert_eq!(confusion[1][1], 2); // true 1, pred 1
            }
            _ => panic!("Expected classification metrics"),
        }
    }
}