Skip to main content

fdars_core/alignment/
lambda_cv.rs

//! Cross-validation for the elastic alignment regularisation parameter lambda.
2
3use super::karcher::karcher_mean;
4use super::pairwise::elastic_distance;
5use crate::cv::{create_folds, fold_indices, subset_rows};
6use crate::error::FdarError;
7use crate::matrix::FdMatrix;
8
// ─── Config / Result ─────────────────────────────────────────────────────────
10
/// Settings that control how [`lambda_cv`] searches for the best lambda.
///
/// Build one with struct-update syntax on top of [`Default`], e.g.
/// `LambdaCvConfig { n_folds: 0, ..LambdaCvConfig::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct LambdaCvConfig {
    /// Regularisation strengths to try; scores are reported in this order.
    pub lambdas: Vec<f64>,
    /// How many folds to split the data into; `0` requests leave-one-out.
    pub n_folds: usize,
    /// Iteration cap passed to the Karcher-mean solver for each training fold.
    pub max_iter: usize,
    /// Convergence tolerance passed to the Karcher-mean solver.
    pub tol: f64,
    /// Seed used when randomly assigning curves to folds.
    pub seed: u64,
}

impl Default for LambdaCvConfig {
    /// Five candidates spanning `0` to `10`, 5 folds, at most 15 Karcher
    /// iterations, tolerance `1e-3` and seed `42`.
    fn default() -> Self {
        const DEFAULT_LAMBDAS: [f64; 5] = [0.0, 0.01, 0.1, 1.0, 10.0];

        Self {
            lambdas: DEFAULT_LAMBDAS.to_vec(),
            n_folds: 5,
            max_iter: 15,
            tol: 1e-3,
            seed: 42,
        }
    }
}
37
/// Outcome of [`lambda_cv`]: the winning lambda plus the score of every candidate.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct LambdaCvResult {
    /// Candidate whose mean cross-validation score is smallest.
    pub best_lambda: f64,
    /// Mean cross-validation score per candidate, index-aligned with `lambdas`.
    pub cv_scores: Vec<f64>,
    /// The evaluated candidate lambdas, copied verbatim from the configuration.
    pub lambdas: Vec<f64>,
}
49
// ─── Cross-validation ────────────────────────────────────────────────────────
51
52/// Select the best elastic-alignment regularisation parameter via K-fold
53/// cross-validation.
54///
55/// For each candidate lambda the data are split into K folds. A Karcher mean
56/// is computed on the training set and every held-out curve is scored by its
57/// elastic distance to that mean. The lambda with the lowest average
58/// held-out distance wins.
59///
60/// # Arguments
61/// * `data`    — Functional data matrix (n x m).
62/// * `argvals` — Evaluation grid (length m).
63/// * `config`  — Cross-validation settings (lambdas, folds, iterations, …).
64///
65/// # Errors
66/// Returns `FdarError::InvalidDimension` if `data` has fewer than 4 rows
67/// or `argvals` length does not match `data.ncols()`.
68/// Returns `FdarError::InvalidParameter` if any lambda is negative or
69/// `n_folds` is 1.
70#[must_use = "expensive computation whose result should not be discarded"]
71pub fn lambda_cv(
72    data: &FdMatrix,
73    argvals: &[f64],
74    config: &LambdaCvConfig,
75) -> Result<LambdaCvResult, FdarError> {
76    let n = data.nrows();
77    let m = data.ncols();
78
79    // ── Validation ──────────────────────────────────────────────────────
80    if n < 4 {
81        return Err(FdarError::InvalidDimension {
82            parameter: "data",
83            expected: "at least 4 rows".to_string(),
84            actual: format!("{n} rows"),
85        });
86    }
87    if argvals.len() != m {
88        return Err(FdarError::InvalidDimension {
89            parameter: "argvals",
90            expected: format!("{m}"),
91            actual: format!("{}", argvals.len()),
92        });
93    }
94    if config.lambdas.iter().any(|&l| l < 0.0) {
95        return Err(FdarError::InvalidParameter {
96            parameter: "lambdas",
97            message: "all lambda values must be >= 0".to_string(),
98        });
99    }
100    if config.n_folds == 1 {
101        return Err(FdarError::InvalidParameter {
102            parameter: "n_folds",
103            message: "n_folds must be > 1 or 0 (leave-one-out)".to_string(),
104        });
105    }
106
107    let actual_folds = if config.n_folds == 0 {
108        n
109    } else {
110        config.n_folds
111    };
112    let folds = create_folds(n, actual_folds, config.seed);
113
114    // Number of distinct fold labels actually produced.
115    let k_max = *folds.iter().max().unwrap_or(&0) + 1;
116
117    // ── Evaluate each lambda ────────────────────────────────────────────
118    let mut cv_scores = Vec::with_capacity(config.lambdas.len());
119
120    for &lambda in &config.lambdas {
121        let mut fold_scores = Vec::with_capacity(k_max);
122
123        for k in 0..k_max {
124            let (train_idx, test_idx) = fold_indices(&folds, k);
125            if train_idx.is_empty() || test_idx.is_empty() {
126                continue;
127            }
128
129            let train_data = subset_rows(data, &train_idx);
130            let km = karcher_mean(&train_data, argvals, config.max_iter, config.tol, lambda);
131
132            let fold_dist: f64 = test_idx
133                .iter()
134                .map(|&idx| {
135                    let test_curve = data.row(idx);
136                    elastic_distance(&test_curve, &km.mean, argvals, lambda)
137                })
138                .sum::<f64>()
139                / test_idx.len() as f64;
140
141            fold_scores.push(fold_dist);
142        }
143
144        let mean_score = if fold_scores.is_empty() {
145            f64::INFINITY
146        } else {
147            fold_scores.iter().sum::<f64>() / fold_scores.len() as f64
148        };
149        cv_scores.push(mean_score);
150    }
151
152    // ── Pick best lambda ────────────────────────────────────────────────
153    let best_idx = cv_scores
154        .iter()
155        .enumerate()
156        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
157        .map(|(i, _)| i)
158        .unwrap_or(0);
159
160    Ok(LambdaCvResult {
161        best_lambda: config.lambdas[best_idx],
162        cv_scores,
163        lambdas: config.lambdas.clone(),
164    })
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Deterministic simulated test curves (seed 42) evaluated on `uniform_grid(m)`.
    fn make_test_data(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let data = sim_fundata(n, &t, 3, EFunType::Fourier, EValType::Exponential, Some(42));
        (data, t)
    }

    /// Assert that `result` echoes `lambdas` and that `best_lambda` is the
    /// candidate carrying the minimal CV score.
    fn assert_best_is_minimum(result: &LambdaCvResult, lambdas: &[f64]) {
        assert_eq!(result.lambdas, lambdas);
        assert_eq!(result.cv_scores.len(), lambdas.len());
        let best_pos = lambdas
            .iter()
            .position(|&l| l == result.best_lambda)
            .expect("best_lambda must be one of the candidates");
        let min_score = result.cv_scores.iter().copied().fold(f64::INFINITY, f64::min);
        assert_eq!(result.cv_scores[best_pos], min_score);
    }

    #[test]
    fn lambda_cv_default_config() {
        let (data, t) = make_test_data(8, 30);
        let config = LambdaCvConfig {
            max_iter: 5,
            tol: 1e-2,
            ..LambdaCvConfig::default()
        };
        let result = lambda_cv(&data, &t, &config).unwrap();
        assert!(result.best_lambda >= 0.0);
        assert!(result.cv_scores.iter().all(|&s| s.is_finite()));
        assert_best_is_minimum(&result, &config.lambdas);
    }

    #[test]
    fn lambda_cv_loo() {
        let (data, t) = make_test_data(6, 25);
        let config = LambdaCvConfig {
            lambdas: vec![0.0, 1.0],
            n_folds: 0,
            max_iter: 3,
            tol: 1e-2,
            seed: 7,
        };
        let result = lambda_cv(&data, &t, &config).unwrap();
        assert_eq!(result.cv_scores.len(), 2);
        assert!(result.cv_scores.iter().all(|&s| s.is_finite()));
        assert_best_is_minimum(&result, &config.lambdas);
    }

    #[test]
    fn lambda_cv_rejects_too_few_rows() {
        let t = uniform_grid(10);
        let data = sim_fundata(3, &t, 2, EFunType::Fourier, EValType::Exponential, Some(0));
        let config = LambdaCvConfig::default();
        let err = lambda_cv(&data, &t, &config).unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { parameter: "data", .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn lambda_cv_rejects_negative_lambda() {
        let (data, t) = make_test_data(8, 20);
        let config = LambdaCvConfig {
            lambdas: vec![-1.0, 0.0],
            ..LambdaCvConfig::default()
        };
        let err = lambda_cv(&data, &t, &config).unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidParameter { parameter: "lambdas", .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn lambda_cv_rejects_one_fold() {
        let (data, t) = make_test_data(8, 20);
        let config = LambdaCvConfig {
            n_folds: 1,
            ..LambdaCvConfig::default()
        };
        let err = lambda_cv(&data, &t, &config).unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidParameter { parameter: "n_folds", .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn lambda_cv_rejects_argval_mismatch() {
        let (data, _) = make_test_data(8, 20);
        let bad_t = uniform_grid(15);
        let config = LambdaCvConfig::default();
        let err = lambda_cv(&data, &bad_t, &config).unwrap_err();
        assert!(
            matches!(err, FdarError::InvalidDimension { parameter: "argvals", .. }),
            "unexpected error: {err:?}"
        );
    }
}