1use super::karcher::karcher_mean;
4use super::pairwise::elastic_distance;
5use crate::cv::{create_folds, fold_indices, subset_rows};
6use crate::error::FdarError;
7use crate::matrix::FdMatrix;
8
/// Settings for [`lambda_cv`]: which penalty values to try and how the
/// cross-validation and the per-fold Karcher-mean estimation are run.
#[derive(Debug, Clone, PartialEq)]
pub struct LambdaCvConfig {
    /// Candidate penalty values, evaluated in this order. `lambda_cv`
    /// rejects negative entries.
    pub lambdas: Vec<f64>,
    /// Number of cross-validation folds. `0` requests leave-one-out (one fold
    /// per row); `1` is rejected by `lambda_cv`.
    pub n_folds: usize,
    /// Iteration cap forwarded to the Karcher-mean estimation of every fold.
    pub max_iter: usize,
    /// Convergence tolerance forwarded to the Karcher-mean estimation.
    pub tol: f64,
    /// Seed forwarded to the fold assignment, so repeated runs use the same
    /// folds.
    pub seed: u64,
}

impl Default for LambdaCvConfig {
    /// Five candidates spanning `0.0..=10.0` on a roughly log scale, 5-fold CV,
    /// at most 15 Karcher iterations with tolerance `1e-3`, and seed 42.
    fn default() -> Self {
        let lambdas = [0.0, 0.01, 0.1, 1.0, 10.0].to_vec();
        Self {
            lambdas,
            n_folds: 5,
            max_iter: 15,
            tol: 1e-3,
            seed: 42,
        }
    }
}
37
/// Outcome of [`lambda_cv`].
///
/// `cv_scores[i]` is the cross-validation score of `lambdas[i]`, so the two
/// vectors are parallel and have the same length.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct LambdaCvResult {
    /// The candidate whose CV score was smallest.
    pub best_lambda: f64,
    /// Mean held-out elastic distance for each candidate; `f64::INFINITY`
    /// when no fold could be evaluated for that candidate.
    pub cv_scores: Vec<f64>,
    /// The candidates that were evaluated, in evaluation order.
    pub lambdas: Vec<f64>,
}
49
50#[must_use = "expensive computation whose result should not be discarded"]
71pub fn lambda_cv(
72 data: &FdMatrix,
73 argvals: &[f64],
74 config: &LambdaCvConfig,
75) -> Result<LambdaCvResult, FdarError> {
76 let n = data.nrows();
77 let m = data.ncols();
78
79 if n < 4 {
81 return Err(FdarError::InvalidDimension {
82 parameter: "data",
83 expected: "at least 4 rows".to_string(),
84 actual: format!("{n} rows"),
85 });
86 }
87 if argvals.len() != m {
88 return Err(FdarError::InvalidDimension {
89 parameter: "argvals",
90 expected: format!("{m}"),
91 actual: format!("{}", argvals.len()),
92 });
93 }
94 if config.lambdas.iter().any(|&l| l < 0.0) {
95 return Err(FdarError::InvalidParameter {
96 parameter: "lambdas",
97 message: "all lambda values must be >= 0".to_string(),
98 });
99 }
100 if config.n_folds == 1 {
101 return Err(FdarError::InvalidParameter {
102 parameter: "n_folds",
103 message: "n_folds must be > 1 or 0 (leave-one-out)".to_string(),
104 });
105 }
106
107 let actual_folds = if config.n_folds == 0 {
108 n
109 } else {
110 config.n_folds
111 };
112 let folds = create_folds(n, actual_folds, config.seed);
113
114 let k_max = *folds.iter().max().unwrap_or(&0) + 1;
116
117 let mut cv_scores = Vec::with_capacity(config.lambdas.len());
119
120 for &lambda in &config.lambdas {
121 let mut fold_scores = Vec::with_capacity(k_max);
122
123 for k in 0..k_max {
124 let (train_idx, test_idx) = fold_indices(&folds, k);
125 if train_idx.is_empty() || test_idx.is_empty() {
126 continue;
127 }
128
129 let train_data = subset_rows(data, &train_idx);
130 let km = karcher_mean(&train_data, argvals, config.max_iter, config.tol, lambda);
131
132 let fold_dist: f64 = test_idx
133 .iter()
134 .map(|&idx| {
135 let test_curve = data.row(idx);
136 elastic_distance(&test_curve, &km.mean, argvals, lambda)
137 })
138 .sum::<f64>()
139 / test_idx.len() as f64;
140
141 fold_scores.push(fold_dist);
142 }
143
144 let mean_score = if fold_scores.is_empty() {
145 f64::INFINITY
146 } else {
147 fold_scores.iter().sum::<f64>() / fold_scores.len() as f64
148 };
149 cv_scores.push(mean_score);
150 }
151
152 let best_idx = cv_scores
154 .iter()
155 .enumerate()
156 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
157 .map(|(i, _)| i)
158 .unwrap_or(0);
159
160 Ok(LambdaCvResult {
161 best_lambda: config.lambdas[best_idx],
162 cv_scores,
163 lambdas: config.lambdas.clone(),
164 })
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::{sim_fundata, EFunType, EValType};
    use crate::test_helpers::uniform_grid;

    /// Seeded Fourier-basis sample: `n` curves on a uniform `m`-point grid,
    /// returned together with that grid.
    fn simulated_curves(n: usize, m: usize) -> (FdMatrix, Vec<f64>) {
        let grid = uniform_grid(m);
        let curves = sim_fundata(n, &grid, 3, EFunType::Fourier, EValType::Exponential, Some(42));
        (curves, grid)
    }

    /// Checks that `lambda_cv` refuses the given inputs with an error.
    fn assert_rejected(curves: &FdMatrix, grid: &[f64], cfg: &LambdaCvConfig) {
        let outcome = lambda_cv(curves, grid, cfg);
        assert!(outcome.is_err());
    }

    #[test]
    fn lambda_cv_default_config() {
        let (curves, grid) = simulated_curves(8, 30);
        let cfg = LambdaCvConfig {
            max_iter: 5,
            tol: 1e-2,
            ..Default::default()
        };

        let out = lambda_cv(&curves, &grid, &cfg).unwrap();

        assert_eq!(out.cv_scores.len(), cfg.lambdas.len());
        assert!(out.best_lambda >= 0.0);
        assert!(out.cv_scores.iter().copied().all(f64::is_finite));
    }

    #[test]
    fn lambda_cv_loo() {
        let (curves, grid) = simulated_curves(6, 25);
        let cfg = LambdaCvConfig {
            lambdas: vec![0.0, 1.0],
            n_folds: 0, // leave-one-out
            max_iter: 3,
            tol: 1e-2,
            seed: 7,
        };

        let out = lambda_cv(&curves, &grid, &cfg).unwrap();

        assert_eq!(out.cv_scores.len(), 2);
    }

    #[test]
    fn lambda_cv_rejects_too_few_rows() {
        let grid = uniform_grid(10);
        let curves = sim_fundata(3, &grid, 2, EFunType::Fourier, EValType::Exponential, Some(0));
        assert_rejected(&curves, &grid, &LambdaCvConfig::default());
    }

    #[test]
    fn lambda_cv_rejects_negative_lambda() {
        let (curves, grid) = simulated_curves(8, 20);
        let cfg = LambdaCvConfig {
            lambdas: vec![-1.0, 0.0],
            ..Default::default()
        };
        assert_rejected(&curves, &grid, &cfg);
    }

    #[test]
    fn lambda_cv_rejects_one_fold() {
        let (curves, grid) = simulated_curves(8, 20);
        let cfg = LambdaCvConfig {
            n_folds: 1,
            ..Default::default()
        };
        assert_rejected(&curves, &grid, &cfg);
    }

    #[test]
    fn lambda_cv_rejects_argval_mismatch() {
        let (curves, _) = simulated_curves(8, 20);
        let short_grid = uniform_grid(15);
        assert_rejected(&curves, &short_grid, &LambdaCvConfig::default());
    }
}