1use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10use std::collections::HashMap;
11
12pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
19 let n_folds = n_folds.max(1);
20 let mut rng = StdRng::seed_from_u64(seed);
21 let mut indices: Vec<usize> = (0..n).collect();
22 indices.shuffle(&mut rng);
23
24 let mut folds = vec![0usize; n];
25 for (rank, &idx) in indices.iter().enumerate() {
26 folds[idx] = rank % n_folds;
27 }
28 folds
29}
30
31pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
35 let n_folds = n_folds.max(1);
36 let mut rng = StdRng::seed_from_u64(seed);
37 let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
38
39 let mut folds = vec![0usize; n];
40
41 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
43 for i in 0..n {
44 if y[i] < n_classes {
45 class_indices[y[i]].push(i);
46 }
47 }
48
49 for indices in &mut class_indices {
51 indices.shuffle(&mut rng);
52 for (rank, &idx) in indices.iter().enumerate() {
53 folds[idx] = rank % n_folds;
54 }
55 }
56
57 folds
58}
59
/// Splits sample positions into `(train, test)` for one fold.
///
/// `test` holds every index `i` with `folds[i] == fold`, and `train` holds all
/// the others. Both lists are in ascending order.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    // `partition` returns (predicate true, predicate false) = (train, test).
    (0..folds.len()).partition(|&i| folds[i] != fold)
}
68
69pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
71 let m = data.ncols();
72 let n_sub = indices.len();
73 let mut sub = FdMatrix::zeros(n_sub, m);
74 for (new_i, &orig_i) in indices.iter().enumerate() {
75 for j in 0..m {
76 sub[(new_i, j)] = data[(orig_i, j)];
77 }
78 }
79 sub
80}
81
/// Gathers `v[i]` for every `i` in `indices`, preserving the order of `indices`.
///
/// Panics if any index is out of bounds for `v`.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    let mut picked = Vec::with_capacity(indices.len());
    for &i in indices {
        picked.push(v[i]);
    }
    picked
}
86
/// Kind of prediction task being cross-validated.
///
/// Selects which [`CvMetrics`] variant `compute_metrics` produces. When
/// stratification is requested, it also selects how `create_cv_folds`
/// stratifies (by class label vs. by sorted response).
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CvType {
    /// Continuous response, scored with RMSE, MAE and R².
    Regression,
    /// Class labels stored as `f64`, scored with accuracy and a confusion matrix.
    Classification,
}
96
/// Summary scores for one set of predictions, as produced by `compute_metrics`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum CvMetrics {
    /// Regression scores: root-mean-squared error, mean absolute error and
    /// coefficient of determination (R²).
    Regression { rmse: f64, mae: f64, r_squared: f64 },
    /// Classification scores.
    Classification {
        /// Fraction of samples whose rounded prediction equals the label.
        accuracy: f64,
        /// Confusion counts indexed as `confusion[true_class][predicted_class]`.
        confusion: Vec<Vec<usize>>,
    },
}
109
/// A named scoring function: `(name, f)` where `f(y_true, y_pred)` returns a score.
///
/// `cv_fdata_with_metrics` evaluates each entry and stores the result under `name`.
pub type MetricFn = (&'static str, fn(&[f64], &[f64]) -> f64);
112
/// Root-mean-squared error over the overlapping prefix of `y_true` and `y_pred`.
///
/// Returns `NaN` when either slice is empty.
pub fn metric_rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let len = y_true.len().min(y_pred.len());
    if len == 0 {
        return f64::NAN;
    }
    let squared_error: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum();
    (squared_error / len as f64).sqrt()
}
124
/// Mean absolute error over the overlapping prefix of `y_true` and `y_pred`.
///
/// Returns `NaN` when either slice is empty.
pub fn metric_mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    match y_true.len().min(y_pred.len()) {
        0 => f64::NAN,
        len => {
            let abs_error: f64 = y_true
                .iter()
                .zip(y_pred)
                .map(|(t, p)| (t - p).abs())
                .sum();
            abs_error / len as f64
        }
    }
}
133
/// Coefficient of determination R² = 1 − SS_res / SS_tot.
///
/// Only the overlapping prefix (`min(y_true.len(), y_pred.len())` pairs) is
/// used, including when computing the mean of `y_true`. Returns `NaN` when
/// either slice is empty, and `0.0` when `y_true` is numerically constant
/// (SS_tot ≤ 1e-15).
pub fn metric_r_squared(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return f64::NAN;
    }
    // Truncate both sides first so the mean is taken over the same n values
    // it is divided by (previously the whole of y_true was summed).
    let y_true = &y_true[..n];
    let y_pred = &y_pred[..n];

    let mean = y_true.iter().sum::<f64>() / n as f64;
    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum();
    let ss_tot: f64 = y_true.iter().map(|t| (t - mean).powi(2)).sum();
    if ss_tot > 1e-15 {
        1.0 - ss_res / ss_tot
    } else {
        0.0
    }
}
149
150pub fn regression_metrics() -> Vec<MetricFn> {
152 vec![
153 ("rmse", metric_rmse as fn(&[f64], &[f64]) -> f64),
154 ("mae", metric_mae),
155 ("r_squared", metric_r_squared),
156 ]
157}
158
/// Fraction of pairs whose rounded prediction equals the label.
///
/// Labels are converted with `as usize`; predictions are rounded first. Only
/// the overlapping prefix is scored. Returns `NaN` when either slice is empty.
pub fn metric_accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let total = y_true.len().min(y_pred.len());
    if total == 0 {
        return f64::NAN;
    }
    let hits = y_true
        .iter()
        .zip(y_pred)
        .filter(|&(&t, &p)| t as usize == p.round() as usize)
        .count();
    hits as f64 / total as f64
}
172
/// Binary precision for the positive class `1`: TP / (TP + FP).
///
/// A pair counts as a positive prediction when `y_pred.round() as usize == 1`.
/// Returns `0.0` when there are no positive predictions (including empty input).
pub fn metric_precision(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let mut predicted_pos = 0usize;
    let mut true_pos = 0usize;
    for (&t, &p) in y_true.iter().zip(y_pred) {
        if p.round() as usize != 1 {
            continue;
        }
        predicted_pos += 1;
        if t as usize == 1 {
            true_pos += 1;
        }
    }
    if predicted_pos == 0 {
        return 0.0;
    }
    true_pos as f64 / predicted_pos as f64
}
195
/// Binary recall for the positive class `1`: TP / (TP + FN).
///
/// A pair counts as an actual positive when `y_true as usize == 1`. Returns
/// `0.0` when there are no actual positives (including empty input).
pub fn metric_recall(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let (actual_pos, true_pos) = y_true
        .iter()
        .zip(y_pred)
        .filter(|&(&t, _)| t as usize == 1)
        .fold((0usize, 0usize), |(actual, hit), (_, &p)| {
            let caught = usize::from(p.round() as usize == 1);
            (actual + 1, hit + caught)
        });
    if actual_pos == 0 {
        0.0
    } else {
        true_pos as f64 / actual_pos as f64
    }
}
218
219pub fn metric_f1(y_true: &[f64], y_pred: &[f64]) -> f64 {
221 let p = metric_precision(y_true, y_pred);
222 let r = metric_recall(y_true, y_pred);
223 if p + r > 0.0 {
224 2.0 * p * r / (p + r)
225 } else {
226 0.0
227 }
228}
229
230pub fn classification_metrics() -> Vec<MetricFn> {
232 vec![
233 ("accuracy", metric_accuracy as fn(&[f64], &[f64]) -> f64),
234 ("precision", metric_precision),
235 ("recall", metric_recall),
236 ("f1", metric_f1),
237 ]
238}
239
240fn evaluate_metrics(
242 y_true: &[f64],
243 y_pred: &[f64],
244 metric_fns: &[MetricFn],
245) -> HashMap<String, f64> {
246 metric_fns
247 .iter()
248 .map(|(name, f)| ((*name).to_string(), f(y_true, y_pred)))
249 .collect()
250}
251
/// Outcome of a (possibly repeated) cross-validation run, as built by
/// `cv_fdata_with_metrics`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CvFdataResult {
    // Out-of-fold prediction per sample. With several repetitions, this is
    // the element-wise mean across repetitions.
    pub oof_predictions: Vec<f64>,
    // Metrics computed on `oof_predictions` against the full response.
    pub metrics: CvMetrics,
    // Per-fold metrics from the LAST repetition only (skipped folds omitted).
    pub fold_metrics: Vec<CvMetrics>,
    // Fold assignment (fold id per sample) from the LAST repetition.
    pub folds: Vec<usize>,
    pub cv_type: CvType,
    // Number of repetitions actually run (at least 1).
    pub nrep: usize,
    // Per-sample standard deviation of OOF predictions across repetitions
    // (n − 1 denominator); `None` when only one repetition ran.
    pub oof_sd: Option<Vec<f64>>,
    // Metrics of each repetition's own OOF predictions; `None` when nrep == 1.
    pub rep_metrics: Option<Vec<CvMetrics>>,
    // User-supplied metrics evaluated on `oof_predictions`; empty when none given.
    pub custom_metrics: HashMap<String, f64>,
    // User-supplied metrics per fold of the LAST repetition; empty when none given.
    pub fold_custom_metrics: Vec<HashMap<String, f64>>,
}
277
278fn create_cv_folds(
282 n: usize,
283 y: &[f64],
284 n_folds: usize,
285 cv_type: CvType,
286 stratified: bool,
287 seed: u64,
288) -> Vec<usize> {
289 if stratified {
290 match cv_type {
291 CvType::Classification => {
292 let y_class: Vec<usize> = y
293 .iter()
294 .map(|&v| crate::utility::f64_to_usize_clamped(v))
295 .collect();
296 create_stratified_folds(n, &y_class, n_folds, seed)
297 }
298 CvType::Regression => {
299 let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
300 sorted_y.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
301 let n_bins = n_folds.min(n);
302 let bin_labels: Vec<usize> = {
303 let mut labels = vec![0usize; n];
304 for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
305 labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
306 }
307 labels
308 };
309 create_stratified_folds(n, &bin_labels, n_folds, seed)
310 }
311 }
312 } else {
313 create_folds(n, n_folds, seed)
314 }
315}
316
/// Combines the out-of-fold predictions of several repetitions.
///
/// With a single repetition, its predictions are returned unchanged and the
/// spread is `None`. Otherwise it returns the per-sample mean across
/// repetitions and the per-sample sample standard deviation (denominator
/// `nrep − 1`, floored at 1). Every inner vector must have at least `n`
/// entries.
fn aggregate_oof_predictions(all_oof: Vec<Vec<f64>>, n: usize) -> (Vec<f64>, Option<Vec<f64>>) {
    let reps = all_oof.len();
    if reps == 1 {
        let only = all_oof.into_iter().next().expect("non-empty iterator");
        return (only, None);
    }

    let reps_f = reps as f64;
    let mean: Vec<f64> = (0..n)
        .map(|i| all_oof.iter().fold(0.0, |acc, rep| acc + rep[i]) / reps_f)
        .collect();

    let denom = (reps_f - 1.0).max(1.0);
    let spread: Vec<f64> = (0..n)
        .map(|i| {
            let sum_sq = all_oof.iter().fold(0.0, |acc, rep| {
                let dev = rep[i] - mean[i];
                acc + dev * dev
            });
            (sum_sq / denom).sqrt()
        })
        .collect();

    (mean, Some(spread))
}
349
350pub fn cv_fdata<F, P>(
355 data: &FdMatrix,
356 y: &[f64],
357 fit_fn: F,
358 predict_fn: P,
359 n_folds: usize,
360 nrep: usize,
361 cv_type: CvType,
362 stratified: bool,
363 seed: u64,
364) -> CvFdataResult
365where
366 F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
367 P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
368{
369 cv_fdata_with_metrics(
370 data,
371 y,
372 fit_fn,
373 predict_fn,
374 n_folds,
375 nrep,
376 cv_type,
377 stratified,
378 seed,
379 &[],
380 )
381}
382
383pub fn cv_fdata_with_metrics<F, P>(
420 data: &FdMatrix,
421 y: &[f64],
422 fit_fn: F,
423 predict_fn: P,
424 n_folds: usize,
425 nrep: usize,
426 cv_type: CvType,
427 stratified: bool,
428 seed: u64,
429 metric_fns: &[MetricFn],
430) -> CvFdataResult
431where
432 F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
433 P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
434{
435 let n = data.nrows();
436 let nrep = nrep.max(1);
437 let n_folds = n_folds.max(2).min(n);
438
439 let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
440 let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
441 let mut last_folds = vec![0usize; n];
442 let mut last_fold_metrics = Vec::new();
443 let mut last_fold_custom = Vec::new();
444
445 for r in 0..nrep {
446 let rep_seed = seed.wrapping_add(r as u64);
447 let folds = create_cv_folds(n, y, n_folds, cv_type, stratified, rep_seed);
448
449 let mut oof_preds = vec![0.0; n];
450 let mut fold_metrics = Vec::with_capacity(n_folds);
451 let mut fold_custom = Vec::with_capacity(n_folds);
452
453 for fold in 0..n_folds {
454 let (train_idx, test_idx) = fold_indices(&folds, fold);
455 if train_idx.is_empty() || test_idx.is_empty() {
456 continue;
457 }
458
459 let train_data = subset_rows(data, &train_idx);
460 let train_y = subset_vec(y, &train_idx);
461 let test_data = subset_rows(data, &test_idx);
462 let test_y = subset_vec(y, &test_idx);
463
464 let model = fit_fn(&train_data, &train_y);
465 let preds = predict_fn(&*model, &test_data);
466
467 for (local_i, &orig_i) in test_idx.iter().enumerate() {
468 if local_i < preds.len() {
469 oof_preds[orig_i] = preds[local_i];
470 }
471 }
472
473 fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
474 if !metric_fns.is_empty() {
475 fold_custom.push(evaluate_metrics(&test_y, &preds, metric_fns));
476 }
477 }
478
479 let rep_metric = compute_metrics(y, &oof_preds, cv_type);
480 all_oof.push(oof_preds);
481 all_rep_metrics.push(rep_metric);
482 last_folds = folds;
483 last_fold_metrics = fold_metrics;
484 last_fold_custom = fold_custom;
485 }
486
487 let (final_oof, oof_sd) = aggregate_oof_predictions(all_oof, n);
488 let overall_metrics = compute_metrics(y, &final_oof, cv_type);
489 let custom_metrics = if metric_fns.is_empty() {
490 HashMap::new()
491 } else {
492 evaluate_metrics(y, &final_oof, metric_fns)
493 };
494
495 CvFdataResult {
496 oof_predictions: final_oof,
497 metrics: overall_metrics,
498 fold_metrics: last_fold_metrics,
499 folds: last_folds,
500 cv_type,
501 nrep,
502 oof_sd,
503 rep_metrics: if nrep > 1 {
504 Some(all_rep_metrics)
505 } else {
506 None
507 },
508 custom_metrics,
509 fold_custom_metrics: last_fold_custom,
510 }
511}
512
513fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
515 let n = y_true.len().min(y_pred.len());
516 if n == 0 {
517 return match cv_type {
518 CvType::Regression => CvMetrics::Regression {
519 rmse: f64::NAN,
520 mae: f64::NAN,
521 r_squared: f64::NAN,
522 },
523 CvType::Classification => CvMetrics::Classification {
524 accuracy: 0.0,
525 confusion: Vec::new(),
526 },
527 };
528 }
529
530 match cv_type {
531 CvType::Regression => {
532 let mean_y = y_true.iter().sum::<f64>() / n as f64;
533 let mut ss_res = 0.0;
534 let mut ss_tot = 0.0;
535 let mut mae_sum = 0.0;
536 for i in 0..n {
537 let resid = y_true[i] - y_pred[i];
538 ss_res += resid * resid;
539 ss_tot += (y_true[i] - mean_y).powi(2);
540 mae_sum += resid.abs();
541 }
542 let rmse = (ss_res / n as f64).sqrt();
543 let mae = mae_sum / n as f64;
544 let r_squared = if ss_tot > 1e-15 {
545 1.0 - ss_res / ss_tot
546 } else {
547 0.0
548 };
549 CvMetrics::Regression {
550 rmse,
551 mae,
552 r_squared,
553 }
554 }
555 CvType::Classification => {
556 let n_classes = y_true
557 .iter()
558 .chain(y_pred.iter())
559 .map(|&v| v as usize)
560 .max()
561 .unwrap_or(0)
562 + 1;
563 let mut confusion = vec![vec![0usize; n_classes]; n_classes];
564 let mut correct = 0usize;
565 for i in 0..n {
566 let true_c = y_true[i] as usize;
567 let pred_c = y_pred[i].round() as usize;
568 if true_c < n_classes && pred_c < n_classes {
569 confusion[true_c][pred_c] += 1;
570 }
571 if true_c == pred_c {
572 correct += 1;
573 }
574 }
575 let accuracy = correct as f64 / n as f64;
576 CvMetrics::Classification {
577 accuracy,
578 confusion,
579 }
580 }
581 }
582}
583
// Unit tests for fold construction, row/vector subsetting, the metric
// functions, and the cross-validation driver.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FdarError;

    #[test]
    fn test_create_folds_basic() {
        let folds = create_folds(10, 5, 42);
        assert_eq!(folds.len(), 10);
        for f in 0..5 {
            // 10 samples dealt round-robin over 5 folds: exactly two per fold.
            let count = folds.iter().filter(|&&x| x == f).count();
            assert_eq!(count, 2);
        }
    }

    #[test]
    fn test_create_folds_deterministic() {
        // The same seed must reproduce the same assignment.
        let f1 = create_folds(20, 5, 123);
        let f2 = create_folds(20, 5, 123);
        assert_eq!(f1, f2);
    }

    #[test]
    fn test_stratified_folds() {
        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let folds = create_stratified_folds(10, &y, 5, 42);
        assert_eq!(folds.len(), 10);
        for f in 0..5 {
            // Five samples of each class over five folds: one of each per fold.
            let class0_count = (0..10).filter(|&i| folds[i] == f && y[i] == 0).count();
            let class1_count = (0..10).filter(|&i| folds[i] == f && y[i] == 1).count();
            assert_eq!(class0_count, 1);
            assert_eq!(class1_count, 1);
        }
    }

    #[test]
    fn test_fold_indices() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_indices(&folds, 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_subset_rows() {
        // Cell (i, j) holds 10*i + j so copied rows are recognisable.
        let mut data = FdMatrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                data[(i, j)] = (i * 10 + j) as f64;
            }
        }
        let sub = subset_rows(&data, &[1, 3]);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 3);
        assert!((sub[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((sub[(1, 0)] - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_cv_fdata_regression() -> Result<(), FdarError> {
        let n = 20;
        // Curves are shifted copies of the response; the model below ignores them.
        let m = 5;
        let mut data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = y[i] + j as f64 * 0.1;
            }
        }

        // Baseline model: predict the training mean of y for every test row.
        // Since y varies, its out-of-fold error must be strictly positive.
        let result = cv_fdata(
            &data,
            &y,
            |_train_data, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        assert!(result.oof_sd.is_none());
        match &result.metrics {
            CvMetrics::Regression { rmse, .. } => assert!(*rmse > 0.0),
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "cv_fdata_regression",
                    detail: "expected regression metrics".into(),
                });
            }
        }
        Ok(())
    }

    #[test]
    fn test_cv_fdata_repeated() {
        // Constant predictor with three repetitions: the repeated-CV outputs
        // (oof_sd, rep_metrics) must be populated.
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        assert!(result.rep_metrics.is_some());
        assert_eq!(result.rep_metrics.as_ref().unwrap().len(), 3);
    }

    #[test]
    fn test_custom_metrics() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();

        // Built-in regression suite passed as "custom" metrics; expect one
        // map entry per metric overall and one map per fold (5 folds).
        let metrics = regression_metrics();
        let result = cv_fdata_with_metrics(
            &data,
            &y,
            |_d, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
            &metrics,
        );

        assert!(result.custom_metrics.contains_key("rmse"));
        assert!(result.custom_metrics.contains_key("mae"));
        assert!(result.custom_metrics.contains_key("r_squared"));
        assert!(*result.custom_metrics.get("rmse").unwrap() > 0.0);
        assert_eq!(result.fold_custom_metrics.len(), 5);
    }

    #[test]
    fn test_classification_metrics_standalone() {
        // Pairs (true, pred): (0,0) (0,1) (1,1) (1,1) (1,0)
        // -> 3/5 correct, TP = 2, FP = 1, FN = 1.
        let y_true = vec![0.0, 0.0, 1.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0, 0.0];
        assert!((metric_accuracy(&y_true, &y_pred) - 0.6).abs() < 1e-10);
        // precision = TP / (TP + FP) = 2/3
        assert!((metric_precision(&y_true, &y_pred) - 2.0 / 3.0).abs() < 1e-10);
        // recall = TP / (TP + FN) = 2/3
        assert!((metric_recall(&y_true, &y_pred) - 2.0 / 3.0).abs() < 1e-10);
        let f1 = metric_f1(&y_true, &y_pred);
        // harmonic mean of two equal values is that value
        assert!((f1 - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_metrics_classification() -> Result<(), FdarError> {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0];
        let m = compute_metrics(&y_true, &y_pred, CvType::Classification);
        match m {
            CvMetrics::Classification {
                accuracy,
                confusion,
            } => {
                assert!((accuracy - 0.75).abs() < 1e-10);
                // confusion[true][pred]
                assert_eq!(confusion[0][0], 1);
                assert_eq!(confusion[0][1], 1);
                assert_eq!(confusion[1][1], 2);
            }
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "compute_metrics_classification",
                    detail: "expected classification metrics".into(),
                });
            }
        }
        Ok(())
    }
}