Skip to main content

fdars_core/
cv.rs

1//! Cross-validation utilities and unified CV framework.
2//!
3//! This module provides:
4//! - Shared fold assignment utilities used across all CV functions
5//! - [`cv_fdata`]: Generic k-fold + repeated CV framework (R's `cv.fdata`)
6
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10use std::collections::HashMap;
11
12// ─── Fold Utilities ─────────────────────────────────────────────────────────
13
14/// Assign observations to folds (deterministic given seed).
15///
16/// Returns a vector of length `n` where element `i` is the fold index (0..n_folds)
17/// that observation `i` belongs to.
18pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
19    let n_folds = n_folds.max(1);
20    let mut rng = StdRng::seed_from_u64(seed);
21    let mut indices: Vec<usize> = (0..n).collect();
22    indices.shuffle(&mut rng);
23
24    let mut folds = vec![0usize; n];
25    for (rank, &idx) in indices.iter().enumerate() {
26        folds[idx] = rank % n_folds;
27    }
28    folds
29}
30
31/// Assign observations to stratified folds (classification).
32///
33/// Ensures each fold has approximately the same class distribution.
34pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
35    let n_folds = n_folds.max(1);
36    let mut rng = StdRng::seed_from_u64(seed);
37    let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
38
39    let mut folds = vec![0usize; n];
40
41    // Group indices by class
42    let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
43    for i in 0..n {
44        if y[i] < n_classes {
45            class_indices[y[i]].push(i);
46        }
47    }
48
49    // Shuffle within each class, then assign folds round-robin
50    for indices in &mut class_indices {
51        indices.shuffle(&mut rng);
52        for (rank, &idx) in indices.iter().enumerate() {
53            folds[idx] = rank % n_folds;
54        }
55    }
56
57    folds
58}
59
/// Split observation indices into training and test sets for one fold.
///
/// Observation `i` belongs to the test set when `folds[i] == fold` and to the
/// training set otherwise. Both lists are in ascending index order.
///
/// Returns `(train_indices, test_indices)`.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    let (test, train): (Vec<usize>, Vec<usize>) =
        (0..folds.len()).partition(|&i| folds[i] == fold);
    (train, test)
}
68
69/// Extract a sub-matrix from an FdMatrix by selecting specific row indices.
70pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
71    let m = data.ncols();
72    let n_sub = indices.len();
73    let mut sub = FdMatrix::zeros(n_sub, m);
74    for (new_i, &orig_i) in indices.iter().enumerate() {
75        for j in 0..m {
76            sub[(new_i, j)] = data[(orig_i, j)];
77        }
78    }
79    sub
80}
81
/// Gather `v[i]` for every `i` in `indices`, in the order given.
///
/// # Panics
///
/// Panics if any index is out of bounds for `v`.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    let mut picked = Vec::with_capacity(indices.len());
    for &i in indices {
        picked.push(v[i]);
    }
    picked
}
86
87// ─── CV Metrics ─────────────────────────────────────────────────────────────
88
/// Type of cross-validation task.
///
/// Determines how folds are stratified (when requested) and which built-in
/// metrics are computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CvType {
    /// Continuous response; built-in metrics are RMSE, MAE and R².
    Regression,
    /// Class-label response (non-negative integer labels stored as `f64`);
    /// built-in metrics are accuracy and a confusion matrix.
    Classification,
}
96
/// Built-in cross-validation metrics, one variant per task type.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum CvMetrics {
    /// Regression metrics.
    Regression {
        /// Root mean squared error.
        rmse: f64,
        /// Mean absolute error.
        mae: f64,
        /// Coefficient of determination, `1 - SS_res / SS_tot`.
        r_squared: f64,
    },
    /// Classification metrics.
    Classification {
        /// Fraction of observations whose predicted class equals the true class.
        accuracy: f64,
        /// Confusion matrix indexed as `confusion[true_class][predicted_class]`.
        confusion: Vec<Vec<usize>>,
    },
}
109
/// A named metric: a `(name, function)` pair.
///
/// The function maps `(y_true, y_pred)` to a single score. The name becomes
/// the key under which the score is stored in the custom-metric maps.
pub type MetricFn = (&'static str, fn(y_true: &[f64], y_pred: &[f64]) -> f64);
112
113// ─── Built-in Regression Metrics ────────────────────────────────────────────
114
/// Root Mean Squared Error, `sqrt(mean((y_true - y_pred)^2))`.
///
/// Only the first `min(y_true.len(), y_pred.len())` pairs are used. Returns
/// `NaN` when there are none.
pub fn metric_rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let count = y_true.len().min(y_pred.len());
    if count == 0 {
        return f64::NAN;
    }
    let sum_sq: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(&t, &p)| (t - p).powi(2))
        .sum();
    (sum_sq / count as f64).sqrt()
}
124
/// Mean Absolute Error, `mean(|y_true - y_pred|)`.
///
/// Only the first `min(y_true.len(), y_pred.len())` pairs are used. Returns
/// `NaN` when there are none.
pub fn metric_mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let count = y_true.len().min(y_pred.len());
    if count == 0 {
        return f64::NAN;
    }
    let total_abs: f64 = y_true.iter().zip(y_pred).map(|(t, p)| (t - p).abs()).sum();
    total_abs / count as f64
}
133
/// Coefficient of determination (R-squared), `1 - SS_res / SS_tot`.
///
/// Only the first `n = min(y_true.len(), y_pred.len())` pairs are used. This
/// applies to the residuals *and* to the mean of `y_true`, so trailing
/// unmatched values cannot distort the result.
///
/// Returns `NaN` when `n == 0`. Returns `0.0` when the used `y_true` values
/// are numerically constant (`SS_tot <= 1e-15`).
pub fn metric_r_squared(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return f64::NAN;
    }
    // Restrict both slices to the compared pairs before computing anything.
    let (y_true, y_pred) = (&y_true[..n], &y_pred[..n]);
    let mean = y_true.iter().sum::<f64>() / n as f64;
    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum();
    let ss_tot: f64 = y_true.iter().map(|t| (t - mean).powi(2)).sum();
    if ss_tot > 1e-15 {
        1.0 - ss_res / ss_tot
    } else {
        0.0
    }
}
149
150/// Default regression metric set: RMSE, MAE, R-squared.
151pub fn regression_metrics() -> Vec<MetricFn> {
152    vec![
153        ("rmse", metric_rmse as fn(&[f64], &[f64]) -> f64),
154        ("mae", metric_mae),
155        ("r_squared", metric_r_squared),
156    ]
157}
158
159// ─── Built-in Classification Metrics ────────────────────────────────────────
160
/// Classification accuracy: the fraction of pairs whose labels agree.
///
/// True labels are cast to `usize` (truncating). Predictions are rounded to
/// the nearest integer before the cast, so a soft score such as `0.8` counts
/// as class `1`. Because the casts saturate, negative values map to class `0`.
///
/// Only the first `min(y_true.len(), y_pred.len())` pairs are used. Returns
/// `NaN` when there are none.
pub fn metric_accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let total = y_true.len().min(y_pred.len());
    if total == 0 {
        return f64::NAN;
    }
    let hits = y_true
        .iter()
        .zip(y_pred)
        .filter(|&(&t, &p)| t as usize == p.round() as usize)
        .count();
    hits as f64 / total as f64
}
172
/// Binary precision for the positive class `1`: `TP / (TP + FP)`.
///
/// Predictions are rounded to the nearest integer and true labels truncated,
/// then compared against class `1`. Every other label counts as negative.
/// Only the first `min(y_true.len(), y_pred.len())` pairs are used.
///
/// Returns `0.0` when nothing is predicted positive, including empty input.
pub fn metric_precision(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let (tp, fp) = y_true
        .iter()
        .zip(y_pred)
        .filter(|&(_, &p)| p.round() as usize == 1)
        .fold((0usize, 0usize), |(tp, fp), (&t, _)| {
            if t as usize == 1 {
                (tp + 1, fp)
            } else {
                (tp, fp + 1)
            }
        });
    let predicted_positive = tp + fp;
    if predicted_positive == 0 {
        0.0
    } else {
        tp as f64 / predicted_positive as f64
    }
}
195
/// Binary recall for the positive class `1`: `TP / (TP + FN)`.
///
/// True labels are truncated and predictions rounded to the nearest integer
/// before comparison with class `1`. Only the first
/// `min(y_true.len(), y_pred.len())` pairs are used.
///
/// Returns `0.0` when there are no true positives-class observations,
/// including empty input.
pub fn metric_recall(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let mut hits = 0usize;
    let mut positives = 0usize;
    for (&t, &p) in y_true.iter().zip(y_pred) {
        if t as usize != 1 {
            continue;
        }
        positives += 1;
        if p.round() as usize == 1 {
            hits += 1;
        }
    }
    if positives > 0 {
        hits as f64 / positives as f64
    } else {
        0.0
    }
}
218
219/// F1 score (harmonic mean of precision and recall).
220pub fn metric_f1(y_true: &[f64], y_pred: &[f64]) -> f64 {
221    let p = metric_precision(y_true, y_pred);
222    let r = metric_recall(y_true, y_pred);
223    if p + r > 0.0 {
224        2.0 * p * r / (p + r)
225    } else {
226        0.0
227    }
228}
229
230/// Default classification metric set: accuracy, precision, recall, F1.
231pub fn classification_metrics() -> Vec<MetricFn> {
232    vec![
233        ("accuracy", metric_accuracy as fn(&[f64], &[f64]) -> f64),
234        ("precision", metric_precision),
235        ("recall", metric_recall),
236        ("f1", metric_f1),
237    ]
238}
239
240/// Evaluate a set of metric functions on (y_true, y_pred).
241fn evaluate_metrics(
242    y_true: &[f64],
243    y_pred: &[f64],
244    metric_fns: &[MetricFn],
245) -> HashMap<String, f64> {
246    metric_fns
247        .iter()
248        .map(|(name, f)| ((*name).to_string(), f(y_true, y_pred)))
249        .collect()
250}
251
252/// Result of unified cross-validation.
253#[derive(Debug, Clone, PartialEq)]
254#[non_exhaustive]
255pub struct CvFdataResult {
256    /// Out-of-fold predictions (length n); for repeated CV, averaged across reps.
257    pub oof_predictions: Vec<f64>,
258    /// Overall metrics (built-in).
259    pub metrics: CvMetrics,
260    /// Per-fold metrics (built-in).
261    pub fold_metrics: Vec<CvMetrics>,
262    /// Fold assignments from the last (or only) repetition.
263    pub folds: Vec<usize>,
264    /// Type of CV task.
265    pub cv_type: CvType,
266    /// Number of repetitions.
267    pub nrep: usize,
268    /// Standard deviation of OOF predictions across repetitions (only when nrep > 1).
269    pub oof_sd: Option<Vec<f64>>,
270    /// Per-repetition overall metrics (only when nrep > 1).
271    pub rep_metrics: Option<Vec<CvMetrics>>,
272    /// Custom metrics evaluated on OOF predictions (name -> value).
273    pub custom_metrics: HashMap<String, f64>,
274    /// Per-fold custom metrics.
275    pub fold_custom_metrics: Vec<HashMap<String, f64>>,
276}
277
278// ─── Unified CV Framework ────────────────────────────────────────────────────
279
280/// Create CV folds based on strategy (stratified or random).
281fn create_cv_folds(
282    n: usize,
283    y: &[f64],
284    n_folds: usize,
285    cv_type: CvType,
286    stratified: bool,
287    seed: u64,
288) -> Vec<usize> {
289    if stratified {
290        match cv_type {
291            CvType::Classification => {
292                let y_class: Vec<usize> = y
293                    .iter()
294                    .map(|&v| crate::utility::f64_to_usize_clamped(v))
295                    .collect();
296                create_stratified_folds(n, &y_class, n_folds, seed)
297            }
298            CvType::Regression => {
299                let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
300                sorted_y.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
301                let n_bins = n_folds.min(n);
302                let bin_labels: Vec<usize> = {
303                    let mut labels = vec![0usize; n];
304                    for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
305                        labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
306                    }
307                    labels
308                };
309                create_stratified_folds(n, &bin_labels, n_folds, seed)
310            }
311        }
312    } else {
313        create_folds(n, n_folds, seed)
314    }
315}
316
/// Combine the out-of-fold predictions of all repetitions.
///
/// With a single repetition, its predictions are returned unchanged and the
/// standard deviation is `None`.
///
/// Otherwise the function returns two vectors of length `n`:
/// * the per-observation mean over repetitions, and
/// * the per-observation sample standard deviation (divisor `nrep - 1`,
///   floored at 1).
///
/// # Panics
///
/// Panics if there is more than one repetition and any of them holds fewer
/// than `n` predictions.
fn aggregate_oof_predictions(all_oof: Vec<Vec<f64>>, n: usize) -> (Vec<f64>, Option<Vec<f64>>) {
    let reps = all_oof.len();
    if reps == 1 {
        let only = all_oof.into_iter().next().expect("non-empty iterator");
        return (only, None);
    }

    let reps_f = reps as f64;
    let mean: Vec<f64> = (0..n)
        .map(|i| all_oof.iter().fold(0.0, |acc, rep| acc + rep[i]) / reps_f)
        .collect();

    let divisor = (reps_f - 1.0).max(1.0);
    let sd: Vec<f64> = mean
        .iter()
        .enumerate()
        .map(|(i, &mu)| {
            let sum_sq = all_oof.iter().fold(0.0, |acc, rep| {
                let d = rep[i] - mu;
                acc + d * d
            });
            (sum_sq / divisor).sqrt()
        })
        .collect();

    (mean, Some(sd))
}
349
350/// Generic k-fold + repeated cross-validation framework (R's `cv.fdata`).
351///
352/// Uses built-in metrics (RMSE/MAE/R² for regression, accuracy/confusion for
353/// classification). For custom metrics, use [`cv_fdata_with_metrics`].
354pub fn cv_fdata<F, P>(
355    data: &FdMatrix,
356    y: &[f64],
357    fit_fn: F,
358    predict_fn: P,
359    n_folds: usize,
360    nrep: usize,
361    cv_type: CvType,
362    stratified: bool,
363    seed: u64,
364) -> CvFdataResult
365where
366    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
367    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
368{
369    cv_fdata_with_metrics(
370        data,
371        y,
372        fit_fn,
373        predict_fn,
374        n_folds,
375        nrep,
376        cv_type,
377        stratified,
378        seed,
379        &[],
380    )
381}
382
383/// Generic k-fold + repeated CV with user-defined metrics.
384///
385/// Same as [`cv_fdata`] but accepts a slice of [`MetricFn`] that are evaluated
386/// on each fold's (y_true, y_pred) and on the overall OOF predictions.
387///
388/// # Examples
389///
390/// ```
391/// use fdars_core::cv::*;
392/// use fdars_core::matrix::FdMatrix;
393/// use std::any::Any;
394///
395/// let data = FdMatrix::zeros(20, 5);
396/// let y: Vec<f64> = (0..20).map(|i| i as f64).collect();
397///
398/// // Custom metric: median absolute error
399/// fn median_ae(y_true: &[f64], y_pred: &[f64]) -> f64 {
400///     let mut errs: Vec<f64> = y_true.iter().zip(y_pred)
401///         .map(|(&a, &b)| (a - b).abs()).collect();
402///     errs.sort_by(|a, b| a.partial_cmp(b).unwrap());
403///     errs[errs.len() / 2]
404/// }
405///
406/// let mut metrics = regression_metrics();
407/// metrics.push(("median_ae", median_ae));
408///
409/// let result = cv_fdata_with_metrics(
410///     &data, &y,
411///     |_d, y| Box::new(y.iter().sum::<f64>() / y.len() as f64),
412///     |m, td| { let v = *m.downcast_ref::<f64>().unwrap(); vec![v; td.nrows()] },
413///     5, 1, CvType::Regression, false, 42,
414///     &metrics,
415/// );
416/// assert!(result.custom_metrics.contains_key("rmse"));
417/// assert!(result.custom_metrics.contains_key("median_ae"));
418/// ```
419pub fn cv_fdata_with_metrics<F, P>(
420    data: &FdMatrix,
421    y: &[f64],
422    fit_fn: F,
423    predict_fn: P,
424    n_folds: usize,
425    nrep: usize,
426    cv_type: CvType,
427    stratified: bool,
428    seed: u64,
429    metric_fns: &[MetricFn],
430) -> CvFdataResult
431where
432    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
433    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
434{
435    let n = data.nrows();
436    let nrep = nrep.max(1);
437    let n_folds = n_folds.max(2).min(n);
438
439    let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
440    let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
441    let mut last_folds = vec![0usize; n];
442    let mut last_fold_metrics = Vec::new();
443    let mut last_fold_custom = Vec::new();
444
445    for r in 0..nrep {
446        let rep_seed = seed.wrapping_add(r as u64);
447        let folds = create_cv_folds(n, y, n_folds, cv_type, stratified, rep_seed);
448
449        let mut oof_preds = vec![0.0; n];
450        let mut fold_metrics = Vec::with_capacity(n_folds);
451        let mut fold_custom = Vec::with_capacity(n_folds);
452
453        for fold in 0..n_folds {
454            let (train_idx, test_idx) = fold_indices(&folds, fold);
455            if train_idx.is_empty() || test_idx.is_empty() {
456                continue;
457            }
458
459            let train_data = subset_rows(data, &train_idx);
460            let train_y = subset_vec(y, &train_idx);
461            let test_data = subset_rows(data, &test_idx);
462            let test_y = subset_vec(y, &test_idx);
463
464            let model = fit_fn(&train_data, &train_y);
465            let preds = predict_fn(&*model, &test_data);
466
467            for (local_i, &orig_i) in test_idx.iter().enumerate() {
468                if local_i < preds.len() {
469                    oof_preds[orig_i] = preds[local_i];
470                }
471            }
472
473            fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
474            if !metric_fns.is_empty() {
475                fold_custom.push(evaluate_metrics(&test_y, &preds, metric_fns));
476            }
477        }
478
479        let rep_metric = compute_metrics(y, &oof_preds, cv_type);
480        all_oof.push(oof_preds);
481        all_rep_metrics.push(rep_metric);
482        last_folds = folds;
483        last_fold_metrics = fold_metrics;
484        last_fold_custom = fold_custom;
485    }
486
487    let (final_oof, oof_sd) = aggregate_oof_predictions(all_oof, n);
488    let overall_metrics = compute_metrics(y, &final_oof, cv_type);
489    let custom_metrics = if metric_fns.is_empty() {
490        HashMap::new()
491    } else {
492        evaluate_metrics(y, &final_oof, metric_fns)
493    };
494
495    CvFdataResult {
496        oof_predictions: final_oof,
497        metrics: overall_metrics,
498        fold_metrics: last_fold_metrics,
499        folds: last_folds,
500        cv_type,
501        nrep,
502        oof_sd,
503        rep_metrics: if nrep > 1 {
504            Some(all_rep_metrics)
505        } else {
506            None
507        },
508        custom_metrics,
509        fold_custom_metrics: last_fold_custom,
510    }
511}
512
513/// Compute metrics from true and predicted values.
514fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
515    let n = y_true.len().min(y_pred.len());
516    if n == 0 {
517        return match cv_type {
518            CvType::Regression => CvMetrics::Regression {
519                rmse: f64::NAN,
520                mae: f64::NAN,
521                r_squared: f64::NAN,
522            },
523            CvType::Classification => CvMetrics::Classification {
524                accuracy: 0.0,
525                confusion: Vec::new(),
526            },
527        };
528    }
529
530    match cv_type {
531        CvType::Regression => {
532            let mean_y = y_true.iter().sum::<f64>() / n as f64;
533            let mut ss_res = 0.0;
534            let mut ss_tot = 0.0;
535            let mut mae_sum = 0.0;
536            for i in 0..n {
537                let resid = y_true[i] - y_pred[i];
538                ss_res += resid * resid;
539                ss_tot += (y_true[i] - mean_y).powi(2);
540                mae_sum += resid.abs();
541            }
542            let rmse = (ss_res / n as f64).sqrt();
543            let mae = mae_sum / n as f64;
544            let r_squared = if ss_tot > 1e-15 {
545                1.0 - ss_res / ss_tot
546            } else {
547                0.0
548            };
549            CvMetrics::Regression {
550                rmse,
551                mae,
552                r_squared,
553            }
554        }
555        CvType::Classification => {
556            let n_classes = y_true
557                .iter()
558                .chain(y_pred.iter())
559                .map(|&v| v as usize)
560                .max()
561                .unwrap_or(0)
562                + 1;
563            let mut confusion = vec![vec![0usize; n_classes]; n_classes];
564            let mut correct = 0usize;
565            for i in 0..n {
566                let true_c = y_true[i] as usize;
567                let pred_c = y_pred[i].round() as usize;
568                if true_c < n_classes && pred_c < n_classes {
569                    confusion[true_c][pred_c] += 1;
570                }
571                if true_c == pred_c {
572                    correct += 1;
573                }
574            }
575            let accuracy = correct as f64 / n as f64;
576            CvMetrics::Classification {
577                accuracy,
578                confusion,
579            }
580        }
581    }
582}
583
584// ─── Generic CV Selection Result ────────────────────────────────────────────
585
/// Generic cross-validation result for hyperparameter selection.
///
/// Provides a type-safe way to represent the outcome of any CV-based parameter
/// search. Existing specialised types ([`super::scalar_on_function::FregreCvResult`],
/// etc.) remain unchanged; new code can use `CvSelectionResult<f64>` for lambda
/// or bandwidth CV, `CvSelectionResult<usize>` for component-count CV, and so on.
///
/// # Examples
///
/// ```
/// use fdars_core::cv::CvSelectionResult;
///
/// let candidates: Vec<f64> = vec![0.01, 0.1, 1.0, 10.0];
/// let cv_errors = vec![2.5, 1.2, 0.8, 1.5];
/// let result = CvSelectionResult::from_search(candidates, cv_errors).unwrap();
/// assert!((result.optimal - 1.0_f64).abs() < 1e-15);
/// assert!((result.min_error - 0.8).abs() < 1e-15);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CvSelectionResult<T: Clone> {
    /// Candidate parameter values tested.
    pub candidates: Vec<T>,
    /// CV error (e.g., MSE) for each candidate, aligned with `candidates`.
    pub cv_errors: Vec<f64>,
    /// Optimal parameter value (minimising CV error).
    pub optimal: T,
    /// Minimum CV error (the error of `optimal`).
    pub min_error: f64,
}

impl<T: Clone> CvSelectionResult<T> {
    /// Create from candidates and errors, selecting the minimum.
    ///
    /// Ties go to the earliest candidate. A NaN error never beats a non-NaN
    /// error, whatever its sign bit. A plain `f64::total_cmp` would rank NaNs
    /// with the sign bit set below every number, and arithmetic can produce
    /// such NaNs. If every error is NaN, the first candidate is returned with
    /// `min_error` equal to NaN.
    ///
    /// Returns `None` if `candidates` is empty or lengths differ.
    #[must_use]
    pub fn from_search(candidates: Vec<T>, cv_errors: Vec<f64>) -> Option<Self> {
        if candidates.is_empty() || candidates.len() != cv_errors.len() {
            return None;
        }
        let (idx, &min_error) = cv_errors
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| match (a.is_nan(), b.is_nan()) {
                (false, false) => a.total_cmp(b),
                // Any NaN (either sign) ranks after every number.
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                (true, true) => std::cmp::Ordering::Equal,
            })?;
        Some(Self {
            optimal: candidates[idx].clone(),
            candidates,
            cv_errors,
            min_error,
        })
    }
}
638
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FdarError;

    /// Number of observations that `folds` places in fold `fold`.
    fn members_of(folds: &[usize], fold: usize) -> usize {
        folds.iter().filter(|&&f| f == fold).count()
    }

    /// Toy "model": the mean of the training response.
    fn fit_mean(_train_data: &FdMatrix, train_y: &[f64]) -> Box<dyn Any> {
        Box::new(train_y.iter().sum::<f64>() / train_y.len() as f64)
    }

    /// Predict the stored training mean for every test row.
    fn predict_mean(model: &dyn Any, test_data: &FdMatrix) -> Vec<f64> {
        let mean = *model.downcast_ref::<f64>().unwrap();
        vec![mean; test_data.nrows()]
    }

    #[test]
    fn test_create_folds_basic() {
        let folds = create_folds(10, 5, 42);
        assert_eq!(folds.len(), 10);
        // 10 observations over 5 folds: every fold gets exactly 2.
        for fold in 0..5 {
            assert_eq!(members_of(&folds, fold), 2);
        }
    }

    #[test]
    fn test_create_folds_deterministic() {
        assert_eq!(create_folds(20, 5, 123), create_folds(20, 5, 123));
    }

    #[test]
    fn test_stratified_folds() {
        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let folds = create_stratified_folds(10, &y, 5, 42);
        assert_eq!(folds.len(), 10);
        // Five of each class over five folds: one of each class per fold.
        for fold in 0..5 {
            for class in 0..2 {
                let count = folds
                    .iter()
                    .zip(&y)
                    .filter(|&(&f, &c)| f == fold && c == class)
                    .count();
                assert_eq!(count, 1);
            }
        }
    }

    #[test]
    fn test_fold_indices() {
        let (train, test) = fold_indices(&[0, 1, 2, 0, 1, 2], 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_subset_rows() {
        let mut data = FdMatrix::zeros(4, 3);
        for row in 0..4 {
            for col in 0..3 {
                data[(row, col)] = (row * 10 + col) as f64;
            }
        }
        let sub = subset_rows(&data, &[1, 3]);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 3);
        assert!((sub[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((sub[(1, 0)] - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_cv_fdata_regression() -> Result<(), FdarError> {
        // Simple test: predict the training mean.
        let n = 20;
        let m = 5;
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        for (i, &yi) in y.iter().enumerate() {
            for j in 0..m {
                data[(i, j)] = yi + j as f64 * 0.1;
            }
        }

        let result = cv_fdata(
            &data,
            &y,
            fit_mean,
            predict_mean,
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        assert!(result.oof_sd.is_none());
        match &result.metrics {
            CvMetrics::Regression { rmse, .. } => {
                assert!(*rmse > 0.0);
                Ok(())
            }
            _ => Err(FdarError::ComputationFailed {
                operation: "cv_fdata_regression",
                detail: "expected regression metrics".into(),
            }),
        }
    }

    #[test]
    fn test_cv_fdata_repeated() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        let reps = result
            .rep_metrics
            .as_ref()
            .expect("rep_metrics present when nrep > 1");
        assert_eq!(reps.len(), 3);
    }

    #[test]
    fn test_custom_metrics() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let metrics = regression_metrics();

        let result = cv_fdata_with_metrics(
            &data,
            &y,
            fit_mean,
            predict_mean,
            5,
            1,
            CvType::Regression,
            false,
            42,
            &metrics,
        );

        for key in ["rmse", "mae", "r_squared"].iter() {
            assert!(result.custom_metrics.contains_key(*key));
        }
        assert!(result.custom_metrics["rmse"] > 0.0);
        assert_eq!(result.fold_custom_metrics.len(), 5);
    }

    #[test]
    fn test_classification_metrics_standalone() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0, 0.0];
        let close = |a: f64, b: f64| (a - b).abs() < 1e-10;
        assert!(close(metric_accuracy(&y_true, &y_pred), 0.6));
        assert!(close(metric_precision(&y_true, &y_pred), 2.0 / 3.0)); // TP=2, FP=1
        assert!(close(metric_recall(&y_true, &y_pred), 2.0 / 3.0)); // TP=2, FN=1
        // Precision equals recall, so F1 equals both.
        assert!(close(metric_f1(&y_true, &y_pred), 2.0 / 3.0));
    }

    #[test]
    fn test_compute_metrics_classification() -> Result<(), FdarError> {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0]; // one misclassification
        if let CvMetrics::Classification {
            accuracy,
            confusion,
        } = compute_metrics(&y_true, &y_pred, CvType::Classification)
        {
            assert!((accuracy - 0.75).abs() < 1e-10);
            assert_eq!(confusion[0][0], 1); // true 0, pred 0
            assert_eq!(confusion[0][1], 1); // true 0, pred 1
            assert_eq!(confusion[1][1], 2); // true 1, pred 1
            Ok(())
        } else {
            Err(FdarError::ComputationFailed {
                operation: "compute_metrics_classification",
                detail: "expected classification metrics".into(),
            })
        }
    }

    // ── CvSelectionResult ───────────────────────────────────────────────

    #[test]
    fn cv_selection_basic() {
        let result =
            CvSelectionResult::from_search(vec![0.01_f64, 0.1, 1.0, 10.0], vec![2.5, 1.2, 0.8, 1.5])
                .unwrap();
        assert!((result.optimal - 1.0_f64).abs() < 1e-15);
        assert!((result.min_error - 0.8).abs() < 1e-15);
        assert_eq!(result.candidates.len(), 4);
        assert_eq!(result.cv_errors.len(), 4);
    }

    #[test]
    fn cv_selection_usize() {
        let result =
            CvSelectionResult::from_search(vec![1usize, 2, 3, 4, 5], vec![3.0, 2.0, 1.0, 1.5, 2.5])
                .unwrap();
        assert_eq!(result.optimal, 3);
        assert!((result.min_error - 1.0).abs() < 1e-15);
    }

    #[test]
    fn cv_selection_empty() {
        assert!(CvSelectionResult::<f64>::from_search(vec![], vec![]).is_none());
    }

    #[test]
    fn cv_selection_length_mismatch() {
        assert!(CvSelectionResult::<f64>::from_search(vec![1.0, 2.0], vec![1.0]).is_none());
    }

    #[test]
    fn cv_selection_nan_handling() {
        // NaN errors must rank after finite values.
        let result =
            CvSelectionResult::from_search(vec![1.0_f64, 2.0, 3.0], vec![f64::NAN, 0.5, f64::NAN])
                .unwrap();
        assert!((result.optimal - 2.0_f64).abs() < 1e-15);
        assert!((result.min_error - 0.5).abs() < 1e-15);
    }
}