1use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10
11pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
18 let n_folds = n_folds.max(1);
19 let mut rng = StdRng::seed_from_u64(seed);
20 let mut indices: Vec<usize> = (0..n).collect();
21 indices.shuffle(&mut rng);
22
23 let mut folds = vec![0usize; n];
24 for (rank, &idx) in indices.iter().enumerate() {
25 folds[idx] = rank % n_folds;
26 }
27 folds
28}
29
30pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
34 let n_folds = n_folds.max(1);
35 let mut rng = StdRng::seed_from_u64(seed);
36 let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
37
38 let mut folds = vec![0usize; n];
39
40 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
42 for i in 0..n {
43 if y[i] < n_classes {
44 class_indices[y[i]].push(i);
45 }
46 }
47
48 for indices in &mut class_indices {
50 indices.shuffle(&mut rng);
51 for (rank, &idx) in indices.iter().enumerate() {
52 folds[idx] = rank % n_folds;
53 }
54 }
55
56 folds
57}
58
/// Splits sample indices into `(train, test)` for fold number `fold`.
///
/// `folds[i]` is the fold of sample `i`. Samples whose fold equals `fold` go
/// to `test`, and all others go to `train`. Both lists are in ascending index
/// order. If no sample belongs to `fold`, `test` is empty.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    // `partition` puts items satisfying the predicate first, i.e. training rows.
    let (train, test): (Vec<usize>, Vec<usize>) =
        (0..folds.len()).partition(|&i| folds[i] != fold);
    (train, test)
}
67
68pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
70 let m = data.ncols();
71 let n_sub = indices.len();
72 let mut sub = FdMatrix::zeros(n_sub, m);
73 for (new_i, &orig_i) in indices.iter().enumerate() {
74 for j in 0..m {
75 sub[(new_i, j)] = data[(orig_i, j)];
76 }
77 }
78 sub
79}
80
/// Returns `v[i]` for every `i` in `indices`, in the order given.
///
/// Panics if any index is out of range for `v`.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    let mut picked = Vec::with_capacity(indices.len());
    for &i in indices {
        picked.push(v[i]);
    }
    picked
}
85
/// Kind of supervised problem being cross-validated.
///
/// Selects how folds are stratified and which [`CvMetrics`] variant is
/// computed. `Eq` and `Hash` let the type be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CvType {
    /// Continuous response, scored with RMSE, MAE and R².
    Regression,
    /// Integer class labels stored as `f64`, scored with accuracy and a
    /// confusion matrix.
    Classification,
}
94
/// Performance summary of a set of predictions.
///
/// `PartialEq` allows direct comparison of metric values (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub enum CvMetrics {
    /// Metrics for a continuous response.
    Regression {
        /// Root mean squared error.
        rmse: f64,
        /// Mean absolute error.
        mae: f64,
        /// Coefficient of determination, `1 - SS_res / SS_tot`.
        r_squared: f64,
    },
    /// Metrics for class labels.
    Classification {
        /// Fraction of samples whose predicted label equals the true label.
        accuracy: f64,
        /// `confusion[true_label][predicted_label]` counts.
        confusion: Vec<Vec<usize>>,
    },
}
106
107#[derive(Debug, Clone)]
109pub struct CvFdataResult {
110 pub oof_predictions: Vec<f64>,
112 pub metrics: CvMetrics,
114 pub fold_metrics: Vec<CvMetrics>,
116 pub folds: Vec<usize>,
118 pub cv_type: CvType,
120 pub nrep: usize,
122 pub oof_sd: Option<Vec<f64>>,
124 pub rep_metrics: Option<Vec<CvMetrics>>,
126}
127
128pub fn cv_fdata<F, P>(
145 data: &FdMatrix,
146 y: &[f64],
147 fit_fn: F,
148 predict_fn: P,
149 n_folds: usize,
150 nrep: usize,
151 cv_type: CvType,
152 stratified: bool,
153 seed: u64,
154) -> CvFdataResult
155where
156 F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
157 P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
158{
159 let n = data.nrows();
160 let nrep = nrep.max(1);
161 let n_folds = n_folds.max(2).min(n);
162
163 let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
165 let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
166 let mut last_folds = vec![0usize; n];
167 let mut last_fold_metrics = Vec::new();
168
169 for r in 0..nrep {
170 let rep_seed = seed.wrapping_add(r as u64);
171
172 let folds = if stratified {
174 match cv_type {
175 CvType::Classification => {
176 let y_class: Vec<usize> = y.iter().map(|&v| v as usize).collect();
177 create_stratified_folds(n, &y_class, n_folds, rep_seed)
178 }
179 CvType::Regression => {
180 let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
182 sorted_y
183 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
184 let n_bins = n_folds.min(n);
185 let bin_labels: Vec<usize> = {
186 let mut labels = vec![0usize; n];
187 for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
188 labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
189 }
190 labels
191 };
192 create_stratified_folds(n, &bin_labels, n_folds, rep_seed)
193 }
194 }
195 } else {
196 create_folds(n, n_folds, rep_seed)
197 };
198
199 let mut oof_preds = vec![0.0; n];
200 let mut fold_metrics = Vec::with_capacity(n_folds);
201
202 for fold in 0..n_folds {
203 let (train_idx, test_idx) = fold_indices(&folds, fold);
204 if train_idx.is_empty() || test_idx.is_empty() {
205 continue;
206 }
207
208 let train_data = subset_rows(data, &train_idx);
209 let train_y = subset_vec(y, &train_idx);
210 let test_data = subset_rows(data, &test_idx);
211 let test_y = subset_vec(y, &test_idx);
212
213 let model = fit_fn(&train_data, &train_y);
214 let preds = predict_fn(&*model, &test_data);
215
216 for (local_i, &orig_i) in test_idx.iter().enumerate() {
217 if local_i < preds.len() {
218 oof_preds[orig_i] = preds[local_i];
219 }
220 }
221
222 fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
223 }
224
225 let rep_metric = compute_metrics(y, &oof_preds, cv_type);
226 all_oof.push(oof_preds);
227 all_rep_metrics.push(rep_metric);
228 last_folds = folds;
229 last_fold_metrics = fold_metrics;
230 }
231
232 let (final_oof, oof_sd) = if nrep == 1 {
234 (all_oof.into_iter().next().unwrap(), None)
235 } else {
236 let mut mean_oof = vec![0.0; n];
237 for oof in &all_oof {
238 for i in 0..n {
239 mean_oof[i] += oof[i];
240 }
241 }
242 for v in &mut mean_oof {
243 *v /= nrep as f64;
244 }
245
246 let mut sd_oof = vec![0.0; n];
247 for oof in &all_oof {
248 for i in 0..n {
249 let diff = oof[i] - mean_oof[i];
250 sd_oof[i] += diff * diff;
251 }
252 }
253 for v in &mut sd_oof {
254 *v = (*v / (nrep as f64 - 1.0).max(1.0)).sqrt();
255 }
256
257 (mean_oof, Some(sd_oof))
258 };
259
260 let overall_metrics = compute_metrics(y, &final_oof, cv_type);
261
262 CvFdataResult {
263 oof_predictions: final_oof,
264 metrics: overall_metrics,
265 fold_metrics: last_fold_metrics,
266 folds: last_folds,
267 cv_type,
268 nrep,
269 oof_sd,
270 rep_metrics: if nrep > 1 {
271 Some(all_rep_metrics)
272 } else {
273 None
274 },
275 }
276}
277
278fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
280 let n = y_true.len().min(y_pred.len());
281 if n == 0 {
282 return match cv_type {
283 CvType::Regression => CvMetrics::Regression {
284 rmse: f64::NAN,
285 mae: f64::NAN,
286 r_squared: f64::NAN,
287 },
288 CvType::Classification => CvMetrics::Classification {
289 accuracy: 0.0,
290 confusion: Vec::new(),
291 },
292 };
293 }
294
295 match cv_type {
296 CvType::Regression => {
297 let mean_y = y_true.iter().sum::<f64>() / n as f64;
298 let mut ss_res = 0.0;
299 let mut ss_tot = 0.0;
300 let mut mae_sum = 0.0;
301 for i in 0..n {
302 let resid = y_true[i] - y_pred[i];
303 ss_res += resid * resid;
304 ss_tot += (y_true[i] - mean_y).powi(2);
305 mae_sum += resid.abs();
306 }
307 let rmse = (ss_res / n as f64).sqrt();
308 let mae = mae_sum / n as f64;
309 let r_squared = if ss_tot > 1e-15 {
310 1.0 - ss_res / ss_tot
311 } else {
312 0.0
313 };
314 CvMetrics::Regression {
315 rmse,
316 mae,
317 r_squared,
318 }
319 }
320 CvType::Classification => {
321 let n_classes = y_true
322 .iter()
323 .chain(y_pred.iter())
324 .map(|&v| v as usize)
325 .max()
326 .unwrap_or(0)
327 + 1;
328 let mut confusion = vec![vec![0usize; n_classes]; n_classes];
329 let mut correct = 0usize;
330 for i in 0..n {
331 let true_c = y_true[i] as usize;
332 let pred_c = y_pred[i].round() as usize;
333 if true_c < n_classes && pred_c < n_classes {
334 confusion[true_c][pred_c] += 1;
335 }
336 if true_c == pred_c {
337 correct += 1;
338 }
339 }
340 let accuracy = correct as f64 / n as f64;
341 CvMetrics::Classification {
342 accuracy,
343 confusion,
344 }
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of entries in `folds` equal to `target`.
    fn fold_size(folds: &[usize], target: usize) -> usize {
        folds.iter().filter(|&&f| f == target).count()
    }

    #[test]
    fn test_create_folds_basic() {
        let assignment = create_folds(10, 5, 42);
        assert_eq!(assignment.len(), 10);
        // Ten samples over five folds: exactly two per fold.
        for fold in 0..5 {
            assert_eq!(fold_size(&assignment, fold), 2);
        }
    }

    #[test]
    fn test_create_folds_deterministic() {
        assert_eq!(create_folds(20, 5, 123), create_folds(20, 5, 123));
    }

    #[test]
    fn test_stratified_folds() {
        let labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let assignment = create_stratified_folds(labels.len(), &labels, 5, 42);
        assert_eq!(assignment.len(), labels.len());
        // Every fold holds exactly one sample of each class.
        for fold in 0..5 {
            for class in 0..2 {
                let hits = assignment
                    .iter()
                    .zip(labels.iter())
                    .filter(|&(&f, &c)| f == fold && c == class)
                    .count();
                assert_eq!(hits, 1);
            }
        }
    }

    #[test]
    fn test_fold_indices() {
        let assignment = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_indices(&assignment, 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_subset_rows() {
        let mut source = FdMatrix::zeros(4, 3);
        for row in 0..4 {
            for col in 0..3 {
                source[(row, col)] = (row * 10 + col) as f64;
            }
        }
        let picked = subset_rows(&source, &[1, 3]);
        assert_eq!(picked.nrows(), 2);
        assert_eq!(picked.ncols(), 3);
        assert!((picked[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((picked[(1, 0)] - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_cv_fdata_regression() {
        let n = 20;
        let m = 5;
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let mut data = FdMatrix::zeros(n, m);
        for (row, &target) in y.iter().enumerate() {
            for col in 0..m {
                data[(row, col)] = target + col as f64 * 0.1;
            }
        }

        // Model: predict the training mean for every test row.
        let result = cv_fdata(
            &data,
            &y,
            |_train_data, train_y| {
                let avg = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(avg)
            },
            |model, test_data| {
                let avg = *model.downcast_ref::<f64>().unwrap();
                vec![avg; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        assert!(result.oof_sd.is_none());
        if let CvMetrics::Regression { rmse, .. } = &result.metrics {
            assert!(*rmse > 0.0);
        } else {
            panic!("Expected regression metrics");
        }
    }

    #[test]
    fn test_cv_fdata_repeated() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        let reps = result.rep_metrics.as_ref().expect("rep metrics for nrep > 1");
        assert_eq!(reps.len(), 3);
    }

    #[test]
    fn test_compute_metrics_classification() {
        let truth = vec![0.0, 0.0, 1.0, 1.0];
        let guess = vec![0.0, 1.0, 1.0, 1.0];
        let scored = compute_metrics(&truth, &guess, CvType::Classification);
        if let CvMetrics::Classification {
            accuracy,
            confusion,
        } = scored
        {
            assert!((accuracy - 0.75).abs() < 1e-10);
            assert_eq!(confusion[0][0], 1);
            assert_eq!(confusion[0][1], 1);
            assert_eq!(confusion[1][1], 2);
        } else {
            panic!("Expected classification metrics");
        }
    }
}