1use crate::matrix::FdMatrix;
8use rand::prelude::*;
9use std::any::Any;
10use std::collections::HashMap;
11
12pub fn create_folds(n: usize, n_folds: usize, seed: u64) -> Vec<usize> {
19 let n_folds = n_folds.max(1);
20 let mut rng = StdRng::seed_from_u64(seed);
21 let mut indices: Vec<usize> = (0..n).collect();
22 indices.shuffle(&mut rng);
23
24 let mut folds = vec![0usize; n];
25 for (rank, &idx) in indices.iter().enumerate() {
26 folds[idx] = rank % n_folds;
27 }
28 folds
29}
30
31pub fn create_stratified_folds(n: usize, y: &[usize], n_folds: usize, seed: u64) -> Vec<usize> {
35 let n_folds = n_folds.max(1);
36 let mut rng = StdRng::seed_from_u64(seed);
37 let n_classes = y.iter().copied().max().unwrap_or(0) + 1;
38
39 let mut folds = vec![0usize; n];
40
41 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
43 for i in 0..n {
44 if y[i] < n_classes {
45 class_indices[y[i]].push(i);
46 }
47 }
48
49 for indices in &mut class_indices {
51 indices.shuffle(&mut rng);
52 for (rank, &idx) in indices.iter().enumerate() {
53 folds[idx] = rank % n_folds;
54 }
55 }
56
57 folds
58}
59
/// Split sample indices into (train, test) for the given `fold` id.
///
/// A sample is in the test set when its fold assignment equals `fold`;
/// otherwise it is in the training set. Both vectors are in ascending order.
pub fn fold_indices(folds: &[usize], fold: usize) -> (Vec<usize>, Vec<usize>) {
    let mut train = Vec::new();
    let mut test = Vec::new();
    for (i, &assigned) in folds.iter().enumerate() {
        if assigned == fold {
            test.push(i);
        } else {
            train.push(i);
        }
    }
    (train, test)
}
68
69pub fn subset_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
71 let m = data.ncols();
72 let n_sub = indices.len();
73 let mut sub = FdMatrix::zeros(n_sub, m);
74 for (new_i, &orig_i) in indices.iter().enumerate() {
75 for j in 0..m {
76 sub[(new_i, j)] = data[(orig_i, j)];
77 }
78 }
79 sub
80}
81
/// Gather the elements of `v` at the given `indices`, preserving index order.
pub fn subset_vec(v: &[f64], indices: &[usize]) -> Vec<f64> {
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        out.push(v[i]);
    }
    out
}
86
/// The kind of supervised task being cross-validated; selects which summary
/// metrics are computed.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CvType {
    // Continuous targets; scored with RMSE / MAE / R².
    Regression,
    // Integer class labels encoded as f64; scored with accuracy + confusion matrix.
    Classification,
}
96
/// Summary metrics produced by a cross-validation run; the variant matches
/// the [`CvType`] of the run.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum CvMetrics {
    // rmse/mae/r_squared over the scored (y_true, y_pred) pairs.
    Regression { rmse: f64, mae: f64, r_squared: f64 },
    // accuracy plus a confusion matrix indexed [true_class][predicted_class].
    Classification {
        accuracy: f64,
        confusion: Vec<Vec<usize>>,
    },
}
109
/// A named metric: a label paired with a function mapping (y_true, y_pred) to a score.
pub type MetricFn = (&'static str, fn(&[f64], &[f64]) -> f64);
112
/// Root mean squared error over paired elements.
///
/// Pairs are truncated to the shorter input; returns NaN when there are no
/// pairs to score.
pub fn metric_rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return f64::NAN;
    }
    let squared_sum: f64 = y_true
        .iter()
        .zip(y_pred)
        .take(n)
        .map(|(t, p)| (t - p) * (t - p))
        .sum();
    (squared_sum / n as f64).sqrt()
}
124
/// Mean absolute error over paired elements.
///
/// Pairs are truncated to the shorter input; returns NaN when there are no
/// pairs to score.
pub fn metric_mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return f64::NAN;
    }
    let abs_sum: f64 = y_true
        .iter()
        .zip(y_pred)
        .take(n)
        .map(|(t, p)| (t - p).abs())
        .sum();
    abs_sum / n as f64
}
133
/// Coefficient of determination (R²) over paired elements.
///
/// Pairs are truncated to `n = min(len(y_true), len(y_pred))`. Returns NaN
/// when there are no pairs, and 0.0 when `y_true` is (numerically) constant
/// so the total sum of squares vanishes.
///
/// Fix: the mean was previously computed as `sum(all of y_true) / n`, which
/// mixes elements beyond index `n` into the baseline when the inputs have
/// different lengths. The mean is now taken over the first `n` elements only,
/// consistent with `ss_res` / `ss_tot`.
pub fn metric_r_squared(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return f64::NAN;
    }
    // Baseline mean over exactly the n scored elements.
    let mean = y_true.iter().take(n).sum::<f64>() / n as f64;
    let ss_res: f64 = (0..n).map(|i| (y_true[i] - y_pred[i]).powi(2)).sum();
    let ss_tot: f64 = (0..n).map(|i| (y_true[i] - mean).powi(2)).sum();
    if ss_tot > 1e-15 {
        1.0 - ss_res / ss_tot
    } else {
        0.0
    }
}
149
150pub fn regression_metrics() -> Vec<MetricFn> {
152 vec![
153 ("rmse", metric_rmse as fn(&[f64], &[f64]) -> f64),
154 ("mae", metric_mae),
155 ("r_squared", metric_r_squared),
156 ]
157}
158
/// Fraction of pairs whose (truncated) true label equals the (rounded)
/// predicted label.
///
/// Pairs are truncated to the shorter input; returns NaN when there are none.
pub fn metric_accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return f64::NAN;
    }
    let mut correct = 0usize;
    for (t, p) in y_true.iter().zip(y_pred).take(n) {
        if (*t as usize) == (p.round() as usize) {
            correct += 1;
        }
    }
    correct as f64 / n as f64
}
172
/// Binary precision: TP / (TP + FP), with class 1 as the positive class.
///
/// Predictions are rounded to the nearest integer; true labels are truncated.
/// Returns 0.0 when nothing was predicted positive.
pub fn metric_precision(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    let (mut tp, mut fp) = (0usize, 0usize);
    for (t, p) in y_true.iter().zip(y_pred).take(n) {
        if p.round() as usize == 1 {
            match *t as usize {
                1 => tp += 1,
                _ => fp += 1,
            }
        }
    }
    let predicted_positive = tp + fp;
    if predicted_positive == 0 {
        0.0
    } else {
        tp as f64 / predicted_positive as f64
    }
}
195
/// Binary recall: TP / (TP + FN), with class 1 as the positive class.
///
/// Predictions are rounded to the nearest integer; true labels are truncated.
/// Returns 0.0 when there are no positive true labels.
pub fn metric_recall(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    let (mut tp, mut fn_count) = (0usize, 0usize);
    for (t, p) in y_true.iter().zip(y_pred).take(n) {
        if *t as usize == 1 {
            match p.round() as usize {
                1 => tp += 1,
                _ => fn_count += 1,
            }
        }
    }
    let actual_positive = tp + fn_count;
    if actual_positive == 0 {
        0.0
    } else {
        tp as f64 / actual_positive as f64
    }
}
218
219pub fn metric_f1(y_true: &[f64], y_pred: &[f64]) -> f64 {
221 let p = metric_precision(y_true, y_pred);
222 let r = metric_recall(y_true, y_pred);
223 if p + r > 0.0 {
224 2.0 * p * r / (p + r)
225 } else {
226 0.0
227 }
228}
229
230pub fn classification_metrics() -> Vec<MetricFn> {
232 vec![
233 ("accuracy", metric_accuracy as fn(&[f64], &[f64]) -> f64),
234 ("precision", metric_precision),
235 ("recall", metric_recall),
236 ("f1", metric_f1),
237 ]
238}
239
240fn evaluate_metrics(
242 y_true: &[f64],
243 y_pred: &[f64],
244 metric_fns: &[MetricFn],
245) -> HashMap<String, f64> {
246 metric_fns
247 .iter()
248 .map(|(name, f)| ((*name).to_string(), f(y_true, y_pred)))
249 .collect()
250}
251
/// Output of a (possibly repeated) cross-validation run.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CvFdataResult {
    // Out-of-fold prediction per observation (mean across repeats when nrep > 1).
    pub oof_predictions: Vec<f64>,
    // Overall metrics computed from `oof_predictions` against the full target vector.
    pub metrics: CvMetrics,
    // Per-fold metrics from the last repeat.
    pub fold_metrics: Vec<CvMetrics>,
    // Fold id (0..n_folds) per observation, from the last repeat.
    pub folds: Vec<usize>,
    // Task type the metrics were computed for.
    pub cv_type: CvType,
    // Number of CV repeats that were run (>= 1).
    pub nrep: usize,
    // Per-observation std-dev of OOF predictions across repeats; None when nrep == 1.
    pub oof_sd: Option<Vec<f64>>,
    // Per-repeat overall metrics; None when nrep == 1.
    pub rep_metrics: Option<Vec<CvMetrics>>,
    // Custom metric values on the aggregated OOF predictions (empty when none requested).
    pub custom_metrics: HashMap<String, f64>,
    // Custom metric values per fold, from the last repeat (empty when none requested).
    pub fold_custom_metrics: Vec<HashMap<String, f64>>,
}
277
278fn create_cv_folds(
282 n: usize,
283 y: &[f64],
284 n_folds: usize,
285 cv_type: CvType,
286 stratified: bool,
287 seed: u64,
288) -> Vec<usize> {
289 if stratified {
290 match cv_type {
291 CvType::Classification => {
292 let y_class: Vec<usize> = y
293 .iter()
294 .map(|&v| crate::utility::f64_to_usize_clamped(v))
295 .collect();
296 create_stratified_folds(n, &y_class, n_folds, seed)
297 }
298 CvType::Regression => {
299 let mut sorted_y: Vec<(usize, f64)> = y.iter().copied().enumerate().collect();
300 sorted_y.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
301 let n_bins = n_folds.min(n);
302 let bin_labels: Vec<usize> = {
303 let mut labels = vec![0usize; n];
304 for (rank, &(orig_i, _)) in sorted_y.iter().enumerate() {
305 labels[orig_i] = (rank * n_bins / n).min(n_bins - 1);
306 }
307 labels
308 };
309 create_stratified_folds(n, &bin_labels, n_folds, seed)
310 }
311 }
312 } else {
313 create_folds(n, n_folds, seed)
314 }
315}
316
/// Combine per-repeat OOF prediction vectors into a single mean vector,
/// plus the per-observation sample standard deviation across repeats.
///
/// With a single repeat the predictions pass through unchanged and the
/// standard deviation is `None`.
fn aggregate_oof_predictions(all_oof: Vec<Vec<f64>>, n: usize) -> (Vec<f64>, Option<Vec<f64>>) {
    let nrep = all_oof.len();
    if nrep == 1 {
        let only = all_oof.into_iter().next().expect("non-empty iterator");
        return (only, None);
    }

    let denom = nrep as f64;
    let mut mean = vec![0.0; n];
    for rep in &all_oof {
        for (acc, &v) in mean.iter_mut().zip(rep) {
            *acc += v;
        }
    }
    mean.iter_mut().for_each(|v| *v /= denom);

    // Sample standard deviation (n - 1 denominator, floored at 1 defensively).
    let mut sd = vec![0.0; n];
    for rep in &all_oof {
        for ((acc, &v), &m) in sd.iter_mut().zip(rep).zip(&mean) {
            let diff = v - m;
            *acc += diff * diff;
        }
    }
    let dof = (denom - 1.0).max(1.0);
    for v in &mut sd {
        *v = (*v / dof).sqrt();
    }

    (mean, Some(sd))
}
349
/// Cross-validate a model over functional data using only the built-in
/// metrics for `cv_type`.
///
/// Convenience wrapper: delegates to [`cv_fdata_with_metrics`] with an empty
/// custom-metric list. `fit_fn` trains on a training subset and returns an
/// opaque boxed model; `predict_fn` must return one prediction per row of the
/// test matrix it receives.
pub fn cv_fdata<F, P>(
    data: &FdMatrix,
    y: &[f64],
    fit_fn: F,
    predict_fn: P,
    n_folds: usize,
    nrep: usize,
    cv_type: CvType,
    stratified: bool,
    seed: u64,
) -> CvFdataResult
where
    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
{
    cv_fdata_with_metrics(
        data,
        y,
        fit_fn,
        predict_fn,
        n_folds,
        nrep,
        cv_type,
        stratified,
        seed,
        &[],
    )
}
382
/// Cross-validate a model over functional data, additionally evaluating the
/// caller-supplied named metrics per fold and on the aggregated out-of-fold
/// predictions.
///
/// `nrep` is clamped to at least 1 and `n_folds` to `[2, n]`. Each repeat `r`
/// uses seed `seed + r`, so the whole run is deterministic. When `nrep > 1`,
/// the returned OOF predictions are per-observation means across repeats and
/// `oof_sd` / `rep_metrics` are populated; fold-level fields always reflect
/// the last repeat only.
pub fn cv_fdata_with_metrics<F, P>(
    data: &FdMatrix,
    y: &[f64],
    fit_fn: F,
    predict_fn: P,
    n_folds: usize,
    nrep: usize,
    cv_type: CvType,
    stratified: bool,
    seed: u64,
    metric_fns: &[MetricFn],
) -> CvFdataResult
where
    F: Fn(&FdMatrix, &[f64]) -> Box<dyn Any>,
    P: Fn(&dyn Any, &FdMatrix) -> Vec<f64>,
{
    let n = data.nrows();
    let nrep = nrep.max(1);
    let n_folds = n_folds.max(2).min(n);

    let mut all_oof: Vec<Vec<f64>> = Vec::with_capacity(nrep);
    let mut all_rep_metrics: Vec<CvMetrics> = Vec::with_capacity(nrep);
    let mut last_folds = vec![0usize; n];
    let mut last_fold_metrics = Vec::new();
    let mut last_fold_custom = Vec::new();

    for r in 0..nrep {
        // Distinct but deterministic fold layout per repeat.
        let rep_seed = seed.wrapping_add(r as u64);
        let folds = create_cv_folds(n, y, n_folds, cv_type, stratified, rep_seed);

        let mut oof_preds = vec![0.0; n];
        let mut fold_metrics = Vec::with_capacity(n_folds);
        let mut fold_custom = Vec::with_capacity(n_folds);

        for fold in 0..n_folds {
            let (train_idx, test_idx) = fold_indices(&folds, fold);
            // Degenerate folds are skipped; their OOF slots keep the initial 0.0.
            if train_idx.is_empty() || test_idx.is_empty() {
                continue;
            }

            let train_data = subset_rows(data, &train_idx);
            let train_y = subset_vec(y, &train_idx);
            let test_data = subset_rows(data, &test_idx);
            let test_y = subset_vec(y, &test_idx);

            let model = fit_fn(&train_data, &train_y);
            let preds = predict_fn(&*model, &test_data);

            // Scatter fold predictions back to original sample positions,
            // tolerating a predict_fn that returns too few values.
            for (local_i, &orig_i) in test_idx.iter().enumerate() {
                if local_i < preds.len() {
                    oof_preds[orig_i] = preds[local_i];
                }
            }

            fold_metrics.push(compute_metrics(&test_y, &preds, cv_type));
            if !metric_fns.is_empty() {
                fold_custom.push(evaluate_metrics(&test_y, &preds, metric_fns));
            }
        }

        // Score this repeat's full OOF vector against the complete target.
        let rep_metric = compute_metrics(y, &oof_preds, cv_type);
        all_oof.push(oof_preds);
        all_rep_metrics.push(rep_metric);
        last_folds = folds;
        last_fold_metrics = fold_metrics;
        last_fold_custom = fold_custom;
    }

    let (final_oof, oof_sd) = aggregate_oof_predictions(all_oof, n);
    let overall_metrics = compute_metrics(y, &final_oof, cv_type);
    let custom_metrics = if metric_fns.is_empty() {
        HashMap::new()
    } else {
        evaluate_metrics(y, &final_oof, metric_fns)
    };

    CvFdataResult {
        oof_predictions: final_oof,
        metrics: overall_metrics,
        fold_metrics: last_fold_metrics,
        folds: last_folds,
        cv_type,
        nrep,
        oof_sd,
        rep_metrics: if nrep > 1 {
            Some(all_rep_metrics)
        } else {
            None
        },
        custom_metrics,
        fold_custom_metrics: last_fold_custom,
    }
}
512
513fn compute_metrics(y_true: &[f64], y_pred: &[f64], cv_type: CvType) -> CvMetrics {
515 let n = y_true.len().min(y_pred.len());
516 if n == 0 {
517 return match cv_type {
518 CvType::Regression => CvMetrics::Regression {
519 rmse: f64::NAN,
520 mae: f64::NAN,
521 r_squared: f64::NAN,
522 },
523 CvType::Classification => CvMetrics::Classification {
524 accuracy: 0.0,
525 confusion: Vec::new(),
526 },
527 };
528 }
529
530 match cv_type {
531 CvType::Regression => {
532 let mean_y = y_true.iter().sum::<f64>() / n as f64;
533 let mut ss_res = 0.0;
534 let mut ss_tot = 0.0;
535 let mut mae_sum = 0.0;
536 for i in 0..n {
537 let resid = y_true[i] - y_pred[i];
538 ss_res += resid * resid;
539 ss_tot += (y_true[i] - mean_y).powi(2);
540 mae_sum += resid.abs();
541 }
542 let rmse = (ss_res / n as f64).sqrt();
543 let mae = mae_sum / n as f64;
544 let r_squared = if ss_tot > 1e-15 {
545 1.0 - ss_res / ss_tot
546 } else {
547 0.0
548 };
549 CvMetrics::Regression {
550 rmse,
551 mae,
552 r_squared,
553 }
554 }
555 CvType::Classification => {
556 let n_classes = y_true
557 .iter()
558 .chain(y_pred.iter())
559 .map(|&v| v as usize)
560 .max()
561 .unwrap_or(0)
562 + 1;
563 let mut confusion = vec![vec![0usize; n_classes]; n_classes];
564 let mut correct = 0usize;
565 for i in 0..n {
566 let true_c = y_true[i] as usize;
567 let pred_c = y_pred[i].round() as usize;
568 if true_c < n_classes && pred_c < n_classes {
569 confusion[true_c][pred_c] += 1;
570 }
571 if true_c == pred_c {
572 correct += 1;
573 }
574 }
575 let accuracy = correct as f64 / n as f64;
576 CvMetrics::Classification {
577 accuracy,
578 confusion,
579 }
580 }
581 }
582}
583
/// Result of a hyperparameter search driven by cross-validation error.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CvSelectionResult<T: Clone> {
    // The candidate values that were evaluated, in search order.
    pub candidates: Vec<T>,
    // CV error for each candidate, parallel to `candidates`.
    pub cv_errors: Vec<f64>,
    // The candidate with the smallest CV error.
    pub optimal: T,
    // The smallest CV error found.
    pub min_error: f64,
}
616
617impl<T: Clone + PartialOrd> CvSelectionResult<T> {
618 #[must_use]
622 pub fn from_search(candidates: Vec<T>, cv_errors: Vec<f64>) -> Option<Self> {
623 if candidates.is_empty() || candidates.len() != cv_errors.len() {
624 return None;
625 }
626 let (idx, &min_error) = cv_errors
627 .iter()
628 .enumerate()
629 .min_by(|(_, a), (_, b)| a.total_cmp(b))?;
630 Some(Self {
631 optimal: candidates[idx].clone(),
632 candidates,
633 cv_errors,
634 min_error,
635 })
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FdarError;

    #[test]
    fn test_create_folds_basic() {
        let folds = create_folds(10, 5, 42);
        assert_eq!(folds.len(), 10);
        // 10 samples over 5 folds: each fold gets exactly 2 samples.
        for f in 0..5 {
            let count = folds.iter().filter(|&&x| x == f).count();
            assert_eq!(count, 2);
        }
    }

    #[test]
    fn test_create_folds_deterministic() {
        let f1 = create_folds(20, 5, 123);
        let f2 = create_folds(20, 5, 123);
        assert_eq!(f1, f2);
    }

    #[test]
    fn test_stratified_folds() {
        let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let folds = create_stratified_folds(10, &y, 5, 42);
        assert_eq!(folds.len(), 10);
        // Each fold must receive exactly one sample of each class.
        for f in 0..5 {
            let class0_count = (0..10).filter(|&i| folds[i] == f && y[i] == 0).count();
            let class1_count = (0..10).filter(|&i| folds[i] == f && y[i] == 1).count();
            assert_eq!(class0_count, 1);
            assert_eq!(class1_count, 1);
        }
    }

    #[test]
    fn test_fold_indices() {
        let folds = vec![0, 1, 2, 0, 1, 2];
        let (train, test) = fold_indices(&folds, 1);
        assert_eq!(test, vec![1, 4]);
        assert_eq!(train, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_subset_rows() {
        let mut data = FdMatrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                data[(i, j)] = (i * 10 + j) as f64;
            }
        }
        let sub = subset_rows(&data, &[1, 3]);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 3);
        assert!((sub[(0, 0)] - 10.0).abs() < 1e-10);
        assert!((sub[(1, 0)] - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_cv_fdata_regression() -> Result<(), FdarError> {
        let n = 20;
        let m = 5;
        let mut data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();
        for i in 0..n {
            for j in 0..m {
                data[(i, j)] = y[i] + j as f64 * 0.1;
            }
        }

        // Trivial "model": predict the training mean everywhere.
        let result = cv_fdata(
            &data,
            &y,
            |_train_data, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
        );

        assert_eq!(result.oof_predictions.len(), n);
        assert_eq!(result.nrep, 1);
        assert!(result.oof_sd.is_none());
        match &result.metrics {
            CvMetrics::Regression { rmse, .. } => assert!(*rmse > 0.0),
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "cv_fdata_regression",
                    detail: "expected regression metrics".into(),
                });
            }
        }
        Ok(())
    }

    #[test]
    fn test_cv_fdata_repeated() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| (i % 2) as f64).collect();

        let result = cv_fdata(
            &data,
            &y,
            |_d, _y| Box::new(0.5_f64),
            |_model, test_data| vec![0.5; test_data.nrows()],
            5,
            3,
            CvType::Regression,
            false,
            42,
        );

        // Repeated CV must expose per-repeat aggregates.
        assert_eq!(result.nrep, 3);
        assert!(result.oof_sd.is_some());
        assert!(result.rep_metrics.is_some());
        assert_eq!(result.rep_metrics.as_ref().unwrap().len(), 3);
    }

    #[test]
    fn test_custom_metrics() {
        let n = 20;
        let m = 3;
        let data = FdMatrix::zeros(n, m);
        let y: Vec<f64> = (0..n).map(|i| i as f64).collect();

        let metrics = regression_metrics();
        let result = cv_fdata_with_metrics(
            &data,
            &y,
            |_d, train_y| {
                let mean = train_y.iter().sum::<f64>() / train_y.len() as f64;
                Box::new(mean)
            },
            |model, test_data| {
                let mean = model.downcast_ref::<f64>().unwrap();
                vec![*mean; test_data.nrows()]
            },
            5,
            1,
            CvType::Regression,
            false,
            42,
            &metrics,
        );

        assert!(result.custom_metrics.contains_key("rmse"));
        assert!(result.custom_metrics.contains_key("mae"));
        assert!(result.custom_metrics.contains_key("r_squared"));
        assert!(*result.custom_metrics.get("rmse").unwrap() > 0.0);
        assert_eq!(result.fold_custom_metrics.len(), 5);
    }

    #[test]
    fn test_classification_metrics_standalone() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0, 0.0];
        assert!((metric_accuracy(&y_true, &y_pred) - 0.6).abs() < 1e-10);
        assert!((metric_precision(&y_true, &y_pred) - 2.0 / 3.0).abs() < 1e-10);
        assert!((metric_recall(&y_true, &y_pred) - 2.0 / 3.0).abs() < 1e-10);
        // Precision == recall == 2/3, so F1 (their harmonic mean) is 2/3 too.
        let f1 = metric_f1(&y_true, &y_pred);
        assert!((f1 - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_metrics_classification() -> Result<(), FdarError> {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 1.0, 1.0];
        let m = compute_metrics(&y_true, &y_pred, CvType::Classification);
        match m {
            CvMetrics::Classification {
                accuracy,
                confusion,
            } => {
                assert!((accuracy - 0.75).abs() < 1e-10);
                assert_eq!(confusion[0][0], 1);
                assert_eq!(confusion[0][1], 1);
                assert_eq!(confusion[1][1], 2);
            }
            _ => {
                return Err(FdarError::ComputationFailed {
                    operation: "compute_metrics_classification",
                    detail: "expected classification metrics".into(),
                });
            }
        }
        Ok(())
    }

    #[test]
    fn cv_selection_basic() {
        let candidates: Vec<f64> = vec![0.01, 0.1, 1.0, 10.0];
        let cv_errors = vec![2.5, 1.2, 0.8, 1.5];
        let result = CvSelectionResult::from_search(candidates, cv_errors).unwrap();
        assert!((result.optimal - 1.0_f64).abs() < 1e-15);
        assert!((result.min_error - 0.8).abs() < 1e-15);
        assert_eq!(result.candidates.len(), 4);
        assert_eq!(result.cv_errors.len(), 4);
    }

    #[test]
    fn cv_selection_usize() {
        let candidates = vec![1usize, 2, 3, 4, 5];
        let cv_errors = vec![3.0, 2.0, 1.0, 1.5, 2.5];
        let result = CvSelectionResult::from_search(candidates, cv_errors).unwrap();
        assert_eq!(result.optimal, 3);
        assert!((result.min_error - 1.0).abs() < 1e-15);
    }

    #[test]
    fn cv_selection_empty() {
        let result = CvSelectionResult::<f64>::from_search(vec![], vec![]);
        assert!(result.is_none());
    }

    #[test]
    fn cv_selection_length_mismatch() {
        let result = CvSelectionResult::<f64>::from_search(vec![1.0, 2.0], vec![1.0]);
        assert!(result.is_none());
    }

    #[test]
    fn cv_selection_nan_handling() {
        let candidates: Vec<f64> = vec![1.0, 2.0, 3.0];
        // total_cmp ranks NaN above finite values, so 0.5 must win.
        let cv_errors = vec![f64::NAN, 0.5, f64::NAN];
        let result = CvSelectionResult::from_search(candidates, cv_errors).unwrap();
        assert!((result.optimal - 2.0_f64).abs() < 1e-15);
        assert!((result.min_error - 0.5).abs() < 1e-15);
    }
}