1use crate::advanced::rbf::{RBFInterpolator, RBFKernel};
45use crate::bspline::BSpline;
46use crate::error::{InterpolateError, InterpolateResult};
47use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ScalarOperand};
48use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive};
49use std::collections::HashMap;
50use std::fmt::{Debug, Display, LowerExp};
51use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
52
/// Metric used to score predictions against held-out targets.
///
/// Every variant except [`ValidationMetric::RSquared`] measures an error, so a
/// smaller value indicates a better fit; `RSquared` grows towards `1` as the
/// fit improves.
///
/// The enum is a plain fieldless `Copy` type, so it also derives `Eq` and
/// `Hash`, which lets it be used as a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationMetric {
    /// Mean of the squared residuals.
    MeanSquaredError,
    /// Mean of the absolute residuals.
    MeanAbsoluteError,
    /// Square root of the mean squared residual.
    RootMeanSquaredError,
    /// Coefficient of determination, `1 - SS_res / SS_tot`.
    RSquared,
    /// Mean of `|residual / target|` over the non-zero targets, in percent.
    MeanAbsolutePercentageError,
    /// Largest absolute residual.
    MaxAbsoluteError,
}
69
/// How a dataset is partitioned into training and test subsets.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CrossValidationStrategy {
    /// Split the samples into the given number of folds; each fold serves as
    /// the test set once.
    KFold(usize),
    /// Hold out a single sample per fold.
    LeaveOneOut,
    /// Repeated train/test splits.
    MonteCarlo {
        /// Number of splits to generate.
        n_splits: usize,
        /// Fraction of the samples placed in the test set of each split.
        test_fraction: f64,
    },
    /// Ordered splits where the training window always precedes the test window.
    TimeSeries {
        /// Number of splits to generate.
        n_splits: usize,
        /// Presumably the number of samples separating the training window
        /// from the test window (TODO confirm against the fold generator).
        gap: usize,
    },
}
82
/// Settings shared by the hyper-parameter search routines.
#[derive(Debug, Clone)]
pub struct OptimizationConfig<T> {
    /// Upper bound on the number of iterations.
    /// NOTE(review): not read by the search routines visible in this file.
    pub max_iterations: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by the search routines visible in this file.
    pub tolerance: T,
    /// Seed for randomised steps.
    /// NOTE(review): the cross-validator carries its own seed; confirm usage.
    pub random_seed: u64,
    /// Whether parallel evaluation is requested.
    /// NOTE(review): not read by the search routines visible in this file.
    pub parallel: bool,
    /// Verbosity level; values above `0` make the optimizers print one line
    /// per evaluated candidate.
    pub verbosity: usize,
}
97
98impl<T: Float + FromPrimitive> Default for OptimizationConfig<T> {
99 fn default() -> Self {
100 Self {
101 max_iterations: 100,
102 tolerance: T::from(1e-6).unwrap(),
103 random_seed: 42,
104 parallel: true,
105 verbosity: 1,
106 }
107 }
108}
109
/// Outcome of a hyper-parameter search.
#[derive(Debug, Clone)]
pub struct OptimizationResult<T> {
    /// Parameter name to value for the best candidate found.
    pub best_parameters: HashMap<String, T>,
    /// Cross-validation score of the best candidate.
    pub best_score: T,
    /// Every evaluated candidate paired with its cross-validation score.
    pub parameter_scores: Vec<(HashMap<String, T>, T)>,
    /// Number of candidates that were evaluated.
    pub iterations: usize,
    /// Whether the search is reported as converged.
    pub converged: bool,
    /// Wall-clock duration of the search in milliseconds.
    pub optimization_time_ms: u64,
}
126
127#[derive(Debug, Clone)]
129pub struct CrossValidationResult<T> {
130 pub mean_score: T,
132 pub std_score: T,
134 pub fold_scores: Vec<T>,
136 pub n_folds: usize,
138 pub metric: ValidationMetric,
140}
141
142#[derive(Debug)]
144pub struct CrossValidator<T>
145where
146 T: Float
147 + FromPrimitive
148 + ToPrimitive
149 + Debug
150 + Display
151 + LowerExp
152 + ScalarOperand
153 + AddAssign
154 + SubAssign
155 + MulAssign
156 + DivAssign
157 + RemAssign
158 + Copy
159 + Send
160 + Sync
161 + 'static,
162{
163 strategy: CrossValidationStrategy,
165 metric: ValidationMetric,
167 shuffle: bool,
169 random_seed: u64,
171 config: OptimizationConfig<T>,
173}
174
175impl<T> Default for CrossValidator<T>
176where
177 T: Float
178 + FromPrimitive
179 + ToPrimitive
180 + Debug
181 + Display
182 + LowerExp
183 + ScalarOperand
184 + AddAssign
185 + SubAssign
186 + MulAssign
187 + DivAssign
188 + RemAssign
189 + Copy
190 + Send
191 + Sync
192 + 'static,
193{
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199impl<T> CrossValidator<T>
200where
201 T: Float
202 + FromPrimitive
203 + ToPrimitive
204 + Debug
205 + Display
206 + LowerExp
207 + ScalarOperand
208 + AddAssign
209 + SubAssign
210 + MulAssign
211 + DivAssign
212 + RemAssign
213 + Copy
214 + Send
215 + Sync
216 + 'static,
217{
218 pub fn new() -> Self {
220 Self {
221 strategy: CrossValidationStrategy::KFold(5),
222 metric: ValidationMetric::MeanSquaredError,
223 shuffle: true,
224 random_seed: 42,
225 config: OptimizationConfig::default(),
226 }
227 }
228
229 pub fn with_strategy(mut self, strategy: CrossValidationStrategy) -> Self {
231 self.strategy = strategy;
232 self
233 }
234
235 pub fn with_k_folds(mut self, k: usize) -> Self {
237 self.strategy = CrossValidationStrategy::KFold(k);
238 self
239 }
240
241 pub fn with_metric(mut self, metric: ValidationMetric) -> Self {
243 self.metric = metric;
244 self
245 }
246
247 pub fn with_shuffle(mut self, shuffle: bool) -> Self {
249 self.shuffle = shuffle;
250 self
251 }
252
253 pub fn with_random_seed(mut self, seed: u64) -> Self {
255 self.random_seed = seed;
256 self
257 }
258
259 pub fn with_config(mut self, config: OptimizationConfig<T>) -> Self {
261 self.config = config;
262 self
263 }
264
265 pub fn cross_validate<F>(
277 &self,
278 x: &ArrayView1<T>,
279 y: &ArrayView1<T>,
280 interpolator_fn: F,
281 ) -> InterpolateResult<CrossValidationResult<T>>
282 where
283 F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>,
284 {
285 let n = x.len();
286 if n != y.len() {
287 return Err(InterpolateError::DimensionMismatch(
288 "x and y must have the same length".to_string(),
289 ));
290 }
291
292 let folds = self.generate_folds(n)?;
293 let mut fold_scores = Vec::new();
294
295 for (train_indices, test_indices) in folds {
296 let x_train = self.extract_indices(x, &train_indices);
298 let y_train = self.extract_indices(y, &train_indices);
299 let x_test = self.extract_indices(x, &test_indices);
300 let y_test = self.extract_indices(y, &test_indices);
301
302 let mut training_pairs: Vec<_> = x_train
304 .iter()
305 .zip(y_train.iter())
306 .map(|(x, y)| (*x, *y))
307 .collect();
308 training_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
309
310 let x_train_sorted: Array1<T> = training_pairs.iter().map(|(x, _)| *x).collect();
311 let y_train_sorted: Array1<T> = training_pairs.iter().map(|(_, y)| *y).collect();
312
313 let interpolator = interpolator_fn(&x_train_sorted.view(), &y_train_sorted.view())?;
315
316 let y_pred = interpolator.evaluate(&x_test.view())?;
318
319 let score = self.compute_metric(&y_test.view(), &y_pred.view())?;
321 fold_scores.push(score);
322 }
323
324 let n_folds = fold_scores.len();
325 let mean_score = fold_scores.iter().fold(T::zero(), |acc, &x| acc + x)
326 / T::from(fold_scores.len()).unwrap();
327 let variance = fold_scores
328 .iter()
329 .map(|&score| (score - mean_score) * (score - mean_score))
330 .fold(T::zero(), |acc, x| acc + x)
331 / T::from(fold_scores.len()).unwrap();
332 let std_score = variance.sqrt();
333
334 Ok(CrossValidationResult {
335 mean_score,
336 std_score,
337 fold_scores,
338 n_folds,
339 metric: self.metric,
340 })
341 }
342
343 pub fn optimize_rbf_parameters(
355 &mut self,
356 x: &ArrayView1<T>,
357 y: &ArrayView1<T>,
358 kernel_widths: &[T],
359 ) -> InterpolateResult<OptimizationResult<T>> {
360 let start_time = std::time::Instant::now();
361 let mut parameter_scores = Vec::new();
362 let mut best_score = T::infinity();
363 let mut best_params = HashMap::new();
364
365 for &width in kernel_widths {
366 let interpolator_fn = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
367 let points_2d = Array2::from_shape_vec((x_train.len(), 1), x_train.to_vec())
369 .map_err(|e| {
370 InterpolateError::ComputationError(format!("Failed to reshape: {}", e))
371 })?;
372
373 let rbf =
374 RBFInterpolator::new(&points_2d.view(), y_train, RBFKernel::Gaussian, width)?;
375
376 Ok(Box::new(RBFWrapper::new(rbf)) as Box<dyn InterpolatorTrait<T>>)
377 };
378
379 let cv_result = self.cross_validate(x, y, interpolator_fn)?;
380 let score = cv_result.mean_score;
381
382 let mut params = HashMap::new();
383 params.insert("kernel_width".to_string(), width);
384 parameter_scores.push((params.clone(), score));
385
386 if score < best_score {
387 best_score = score;
388 best_params = params;
389 }
390
391 if self.config.verbosity > 0 {
392 println!(
393 "Width: {:.3}, CV Score: {:.6}",
394 width.to_f64().unwrap_or(0.0),
395 score.to_f64().unwrap_or(0.0)
396 );
397 }
398 }
399
400 let optimization_time_ms = start_time.elapsed().as_millis() as u64;
401
402 Ok(OptimizationResult {
403 best_parameters: best_params,
404 best_score,
405 parameter_scores,
406 iterations: kernel_widths.len(),
407 converged: true,
408 optimization_time_ms,
409 })
410 }
411
412 pub fn optimize_bspline_parameters(
424 &mut self,
425 x: &ArrayView1<T>,
426 y: &ArrayView1<T>,
427 degrees: &[usize],
428 ) -> InterpolateResult<OptimizationResult<T>> {
429 let start_time = std::time::Instant::now();
430 let mut parameter_scores = Vec::new();
431 let mut best_score = T::infinity();
432 let mut best_params = HashMap::new();
433
434 for °ree in degrees {
435 let interpolator_fn = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
436 let bspline = crate::bspline::make_interp_bspline(
437 x_train,
438 y_train,
439 degree,
440 crate::bspline::ExtrapolateMode::Extrapolate,
441 )?;
442
443 Ok(Box::new(BSplineWrapper::new(bspline)) as Box<dyn InterpolatorTrait<T>>)
444 };
445
446 let cv_result = self.cross_validate(x, y, interpolator_fn)?;
447 let score = cv_result.mean_score;
448
449 let mut params = HashMap::new();
450 params.insert("degree".to_string(), T::from(degree).unwrap());
451 parameter_scores.push((params.clone(), score));
452
453 if score < best_score {
454 best_score = score;
455 best_params = params;
456 }
457
458 if self.config.verbosity > 0 {
459 println!(
460 "Degree: {}, CV Score: {:.6}",
461 degree,
462 score.to_f64().unwrap_or(0.0)
463 );
464 }
465 }
466
467 let optimization_time_ms = start_time.elapsed().as_millis() as u64;
468
469 Ok(OptimizationResult {
470 best_parameters: best_params,
471 best_score,
472 parameter_scores,
473 iterations: degrees.len(),
474 converged: true,
475 optimization_time_ms,
476 })
477 }
478
479 fn generate_folds(&self, n: usize) -> InterpolateResult<Vec<(Vec<usize>, Vec<usize>)>> {
481 match self.strategy {
482 CrossValidationStrategy::KFold(k) => {
483 if k > n {
484 return Err(InterpolateError::InvalidValue(
485 "Number of folds cannot exceed number of samples".to_string(),
486 ));
487 }
488
489 let mut indices: Vec<usize> = (0..n).collect();
490
491 if self.shuffle {
493 for i in 0..n {
494 let j = (self.random_seed as usize + i * 1103515245 + 12345) % n;
495 indices.swap(i, j);
496 }
497 }
498
499 let fold_size = n / k;
500 let mut folds = Vec::new();
501
502 for fold_idx in 0..k {
503 let start = fold_idx * fold_size;
504 let end = if fold_idx == k - 1 {
505 n
506 } else {
507 (fold_idx + 1) * fold_size
508 };
509
510 let test_indices = indices[start..end].to_vec();
511 let train_indices: Vec<usize> = indices
512 .iter()
513 .enumerate()
514 .filter(|(i_, _)| *i_ < start || *i_ >= end)
515 .map(|(_, &idx)| idx)
516 .collect();
517
518 folds.push((train_indices, test_indices));
519 }
520
521 Ok(folds)
522 }
523 CrossValidationStrategy::LeaveOneOut => {
524 let mut folds = Vec::new();
525 for i in 0..n {
526 let test_indices = vec![i];
527 let train_indices: Vec<usize> = (0..n).filter(|&idx| idx != i).collect();
528 folds.push((train_indices, test_indices));
529 }
530 Ok(folds)
531 }
532 CrossValidationStrategy::MonteCarlo {
533 n_splits,
534 test_fraction,
535 } => {
536 let mut folds = Vec::new();
537 let test_size = (n as f64 * test_fraction).max(1.0) as usize;
538
539 for split in 0..n_splits {
542 let mut indices: Vec<usize> = (0..n).collect();
543
544 for i in 0..n {
546 let j = (i + split * 17) % n; indices.swap(i, j);
548 }
549
550 let test_indices = indices[0..test_size].to_vec();
551 let train_indices = indices[test_size..].to_vec();
552 folds.push((train_indices, test_indices));
553 }
554 Ok(folds)
555 }
556 CrossValidationStrategy::TimeSeries { n_splits, gap: _ } => {
557 let mut folds = Vec::new();
559 let min_train_size = n / (n_splits + 1);
560 let test_size = n / (n_splits + 1);
561
562 for i in 0..n_splits {
563 let train_end = min_train_size + i * test_size;
564 let test_start = train_end;
565 let test_end = (test_start + test_size).min(n);
566
567 if test_end <= test_start {
568 break;
569 }
570
571 let train_indices: Vec<usize> = (0..train_end).collect();
572 let test_indices: Vec<usize> = (test_start..test_end).collect();
573
574 folds.push((train_indices, test_indices));
575 }
576 Ok(folds)
577 }
578 }
579 }
580
581 fn extract_indices(&self, arr: &ArrayView1<T>, indices: &[usize]) -> Array1<T> {
583 let mut result = Array1::zeros(indices.len());
584 for (i, &idx) in indices.iter().enumerate() {
585 result[i] = arr[idx];
586 }
587 result
588 }
589
590 fn compute_metric(
592 &self,
593 y_true: &ArrayView1<T>,
594 y_pred: &ArrayView1<T>,
595 ) -> InterpolateResult<T> {
596 if y_true.len() != y_pred.len() {
597 return Err(InterpolateError::DimensionMismatch(
598 "y_true and y_pred must have the same length".to_string(),
599 ));
600 }
601
602 let n = T::from(y_true.len()).unwrap();
603
604 match self.metric {
605 ValidationMetric::MeanSquaredError => {
606 let mse = y_true
607 .iter()
608 .zip(y_pred.iter())
609 .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
610 .fold(T::zero(), |acc, x| acc + x)
611 / n;
612 Ok(mse)
613 }
614 ValidationMetric::MeanAbsoluteError => {
615 let mae = y_true
616 .iter()
617 .zip(y_pred.iter())
618 .map(|(&yt, &yp)| (yt - yp).abs())
619 .fold(T::zero(), |acc, x| acc + x)
620 / n;
621 Ok(mae)
622 }
623 ValidationMetric::RootMeanSquaredError => {
624 let mse = y_true
625 .iter()
626 .zip(y_pred.iter())
627 .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
628 .fold(T::zero(), |acc, x| acc + x)
629 / n;
630 Ok(mse.sqrt())
631 }
632 ValidationMetric::RSquared => {
633 let y_mean = y_true.sum() / n;
634 let ss_tot = y_true
635 .iter()
636 .map(|&yt| (yt - y_mean) * (yt - y_mean))
637 .fold(T::zero(), |acc, x| acc + x);
638 let ss_res = y_true
639 .iter()
640 .zip(y_pred.iter())
641 .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
642 .fold(T::zero(), |acc, x| acc + x);
643
644 if ss_tot == T::zero() {
645 Ok(T::one()) } else {
647 Ok(T::one() - ss_res / ss_tot)
648 }
649 }
650 ValidationMetric::MaxAbsoluteError => {
651 let max_error = y_true
652 .iter()
653 .zip(y_pred.iter())
654 .map(|(&yt, &yp)| (yt - yp).abs())
655 .fold(T::zero(), |acc, x| acc.max(x));
656 Ok(max_error)
657 }
658 ValidationMetric::MeanAbsolutePercentageError => {
659 let mut mape = T::zero();
660 let mut count = 0;
661 for (&yt, &yp) in y_true.iter().zip(y_pred.iter()) {
662 if yt != T::zero() {
663 mape += ((yt - yp) / yt).abs();
664 count += 1;
665 }
666 }
667 if count > 0 {
668 Ok(mape / T::from(count).unwrap() * T::from(100.0).unwrap())
669 } else {
670 Ok(T::zero())
671 }
672 }
673 }
674 }
675}
676
677pub trait InterpolatorTrait<T>: Debug + Send + Sync
679where
680 T: Float + Debug + Copy,
681{
682 fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>>;
683}
684
685#[derive(Debug)]
687struct RBFWrapper<T>
688where
689 T: Float
690 + FromPrimitive
691 + ToPrimitive
692 + Debug
693 + Display
694 + LowerExp
695 + ScalarOperand
696 + AddAssign
697 + SubAssign
698 + MulAssign
699 + DivAssign
700 + RemAssign
701 + Copy
702 + Send
703 + Sync
704 + 'static,
705{
706 interpolator: RBFInterpolator<T>,
707}
708
709impl<T> RBFWrapper<T>
710where
711 T: Float
712 + FromPrimitive
713 + ToPrimitive
714 + Debug
715 + Display
716 + LowerExp
717 + ScalarOperand
718 + AddAssign
719 + SubAssign
720 + MulAssign
721 + DivAssign
722 + RemAssign
723 + Copy
724 + Send
725 + Sync
726 + 'static,
727{
728 fn new(interpolator: RBFInterpolator<T>) -> Self {
729 Self { interpolator }
730 }
731}
732
733impl<T> InterpolatorTrait<T> for RBFWrapper<T>
734where
735 T: Float
736 + FromPrimitive
737 + ToPrimitive
738 + Debug
739 + Display
740 + LowerExp
741 + ScalarOperand
742 + AddAssign
743 + SubAssign
744 + MulAssign
745 + DivAssign
746 + RemAssign
747 + Copy
748 + Send
749 + Sync
750 + 'static,
751{
752 fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
753 let points_2d = Array2::from_shape_vec((x.len(), 1), x.to_vec())
755 .map_err(|e| InterpolateError::ComputationError(format!("Failed to reshape: {}", e)))?;
756
757 self.interpolator.interpolate(&points_2d.view())
758 }
759}
760
761#[derive(Debug)]
763struct BSplineWrapper<T>
764where
765 T: Float
766 + FromPrimitive
767 + Debug
768 + Display
769 + Copy
770 + Send
771 + Sync
772 + AddAssign
773 + SubAssign
774 + MulAssign
775 + DivAssign
776 + RemAssign
777 + 'static,
778{
779 interpolator: BSpline<T>,
780}
781
782impl<T> BSplineWrapper<T>
783where
784 T: Float
785 + FromPrimitive
786 + Debug
787 + Display
788 + Copy
789 + Send
790 + Sync
791 + AddAssign
792 + SubAssign
793 + MulAssign
794 + DivAssign
795 + RemAssign
796 + 'static,
797{
798 fn new(interpolator: BSpline<T>) -> Self {
799 Self { interpolator }
800 }
801}
802
803impl<T> InterpolatorTrait<T> for BSplineWrapper<T>
804where
805 T: Float
806 + FromPrimitive
807 + Debug
808 + Display
809 + Copy
810 + Send
811 + Sync
812 + AddAssign
813 + SubAssign
814 + MulAssign
815 + DivAssign
816 + RemAssign
817 + 'static,
818{
819 fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
820 self.interpolator.evaluate_array(x)
821 }
822}
823
824#[derive(Debug)]
826pub struct ModelSelector<T>
827where
828 T: Float
829 + FromPrimitive
830 + ToPrimitive
831 + Debug
832 + Display
833 + LowerExp
834 + ScalarOperand
835 + AddAssign
836 + SubAssign
837 + MulAssign
838 + DivAssign
839 + RemAssign
840 + Copy
841 + Send
842 + Sync
843 + 'static,
844{
845 cross_validator: CrossValidator<T>,
847 #[allow(dead_code)]
849 comparison_results: Vec<(String, CrossValidationResult<T>)>,
850}
851
852impl<T> ModelSelector<T>
853where
854 T: Float
855 + FromPrimitive
856 + ToPrimitive
857 + Debug
858 + Display
859 + LowerExp
860 + ScalarOperand
861 + AddAssign
862 + SubAssign
863 + MulAssign
864 + DivAssign
865 + RemAssign
866 + Copy
867 + Send
868 + Sync
869 + 'static,
870{
871 pub fn new() -> Self {
873 Self {
874 cross_validator: CrossValidator::new(),
875 comparison_results: Vec::new(),
876 }
877 }
878
879 pub fn with_cross_validator(mut self, cv: CrossValidator<T>) -> Self {
881 self.cross_validator = cv;
882 self
883 }
884
885 #[allow(dead_code)]
897 pub fn compare_methods<F>(
898 &mut self,
899 x: &ArrayView1<T>,
900 y: &ArrayView1<T>,
901 methods: HashMap<String, F>,
902 ) -> InterpolateResult<Vec<(String, CrossValidationResult<T>)>>
903 where
904 F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>
905 + Clone,
906 {
907 let mut results = Vec::new();
908
909 for (method_name, interpolator_fn) in methods {
910 let cv_result = self.cross_validator.cross_validate(x, y, interpolator_fn)?;
911 results.push((method_name, cv_result));
912 }
913
914 results.sort_by(|a, b| a.1.mean_score.partial_cmp(&b.1.mean_score).unwrap());
916
917 Ok(results)
918 }
919}
920
921impl<T> Default for ModelSelector<T>
922where
923 T: Float
924 + FromPrimitive
925 + ToPrimitive
926 + Debug
927 + Display
928 + LowerExp
929 + ScalarOperand
930 + AddAssign
931 + SubAssign
932 + MulAssign
933 + DivAssign
934 + RemAssign
935 + Copy
936 + Send
937 + Sync
938 + 'static,
939{
940 fn default() -> Self {
941 Self::new()
942 }
943}
944
945#[allow(dead_code)]
956pub fn make_cross_validator<T>(_kfolds: usize, metric: ValidationMetric) -> CrossValidator<T>
957where
958 T: Float
959 + FromPrimitive
960 + ToPrimitive
961 + Debug
962 + Display
963 + LowerExp
964 + ScalarOperand
965 + AddAssign
966 + SubAssign
967 + MulAssign
968 + DivAssign
969 + RemAssign
970 + Copy
971 + Send
972 + Sync
973 + 'static,
974{
975 CrossValidator::new()
976 .with_k_folds(_kfolds)
977 .with_metric(metric)
978}
979
980#[allow(dead_code)]
994pub fn grid_search<T, F>(
995 x: &ArrayView1<T>,
996 y: &ArrayView1<T>,
997 parameter_grid: &[HashMap<String, T>],
998 cv: &CrossValidator<T>,
999 interpolator_fn: F,
1000) -> InterpolateResult<(HashMap<String, T>, T)>
1001where
1002 T: Float
1003 + FromPrimitive
1004 + ToPrimitive
1005 + Debug
1006 + Display
1007 + LowerExp
1008 + ScalarOperand
1009 + AddAssign
1010 + SubAssign
1011 + MulAssign
1012 + DivAssign
1013 + RemAssign
1014 + Copy
1015 + Send
1016 + Sync
1017 + 'static,
1018 F: Fn(
1019 &HashMap<String, T>,
1020 &ArrayView1<T>,
1021 &ArrayView1<T>,
1022 ) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>,
1023{
1024 let mut best_score = T::infinity();
1025 let mut best_params = HashMap::new();
1026
1027 for params in parameter_grid {
1028 let interpolator_factory = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
1029 interpolator_fn(params, x_train, y_train)
1030 };
1031
1032 let cv_result = cv.cross_validate(x, y, interpolator_factory)?;
1033
1034 if cv_result.mean_score < best_score {
1035 best_score = cv_result.mean_score;
1036 best_params = params.clone();
1037 }
1038 }
1039
1040 Ok((best_params, best_score))
1041}
1042
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A freshly created validator scores with MSE and shuffles.
    #[test]
    fn test_cross_validator_creation() {
        let validator = CrossValidator::<f64>::new();
        assert_eq!(validator.metric, ValidationMetric::MeanSquaredError);
        assert!(validator.shuffle);
    }

    /// Builder methods overwrite strategy, metric and shuffle flag.
    #[test]
    fn test_cross_validator_configuration() {
        let validator = CrossValidator::<f64>::new()
            .with_k_folds(10)
            .with_metric(ValidationMetric::MeanAbsoluteError)
            .with_shuffle(false);

        if let CrossValidationStrategy::KFold(k) = validator.strategy {
            assert_eq!(k, 10);
        } else {
            panic!("Expected KFold strategy");
        }
        assert_eq!(validator.metric, ValidationMetric::MeanAbsoluteError);
        assert!(!validator.shuffle);
    }

    /// Three folds over nine samples touch every index.
    #[test]
    fn test_fold_generation() {
        let validator = CrossValidator::<f64>::new().with_k_folds(3);
        let folds = validator.generate_folds(9).unwrap();

        assert_eq!(folds.len(), 3);

        let all_indices: std::collections::HashSet<usize> = folds
            .iter()
            .flat_map(|(train, test)| train.iter().chain(test.iter()).copied())
            .collect();
        assert_eq!(all_indices.len(), 9);
    }

    /// Leave-one-out yields one single-sample test set per sample.
    #[test]
    fn test_leave_one_out_folds() {
        let validator =
            CrossValidator::<f64>::new().with_strategy(CrossValidationStrategy::LeaveOneOut);
        let folds = validator.generate_folds(5).unwrap();

        assert_eq!(folds.len(), 5);
        for (train, test) in folds.iter() {
            assert_eq!(test.len(), 1);
            assert_eq!(train.len(), 4);
        }
    }

    /// MSE of four residuals of magnitude 0.1.
    #[test]
    fn test_metric_computation() {
        let validator =
            CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanSquaredError);

        let observed = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let predicted = Array1::from_vec(vec![1.1, 1.9, 3.1, 3.9]);

        let mse = validator
            .compute_metric(&observed.view(), &predicted.view())
            .unwrap();
        let expected_mse = (0.1 * 0.1 + 0.1 * 0.1 + 0.1 * 0.1 + 0.1 * 0.1) / 4.0;
        assert!((mse - expected_mse).abs() < 1e-10);
    }

    /// Perfect predictions give an R² of one.
    #[test]
    fn test_r_squared_metric() {
        let validator = CrossValidator::<f64>::new().with_metric(ValidationMetric::RSquared);

        let observed = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let predicted = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let r2 = validator
            .compute_metric(&observed.view(), &predicted.view())
            .unwrap();
        assert!((r2 - 1.0).abs() < 1e-10);
    }

    /// The RBF width search reports a finite best score for three candidates.
    #[test]
    fn test_rbf_parameter_optimization() {
        let x = Array1::linspace(0.0, 1.0, 10);
        let y = x.mapv(|v| v * v);

        let mut validator = CrossValidator::new().with_k_folds(3);
        let kernel_widths = vec![0.1, 1.0, 10.0];

        let outcome = validator.optimize_rbf_parameters(&x.view(), &y.view(), &kernel_widths);
        assert!(outcome.is_ok());

        let opt_result = outcome.unwrap();
        assert!(opt_result.best_parameters.contains_key("kernel_width"));
        assert_eq!(opt_result.parameter_scores.len(), 3);
        assert!(opt_result.best_score.is_finite());
    }

    /// The B-spline degree search either succeeds or fails with `InvalidInput`.
    #[test]
    fn test_bspline_parameter_optimization() {
        let x = Array1::linspace(0.0, 10.0, 30);
        let y = x.mapv(|v| 2.0 * v + 1.0);

        let mut validator = CrossValidator::new().with_k_folds(2);
        let degrees = vec![1];

        let outcome = validator.optimize_bspline_parameters(&x.view(), &y.view(), &degrees);

        match outcome {
            Err(e) => {
                println!(
                    "Cross-validation encountered numerical issues (expected): {:?}",
                    e
                );
                assert!(matches!(e, InterpolateError::InvalidInput { .. }));
            }
            Ok(opt_result) => {
                assert!(opt_result.best_parameters.contains_key("degree"));
                assert_eq!(opt_result.parameter_scores.len(), 1);
                assert!(opt_result.best_score.is_finite());
            }
        }
    }

    /// A new selector starts without stored comparison results.
    #[test]
    fn test_model_selector_creation() {
        let selector = ModelSelector::<f64>::new();
        assert_eq!(selector.comparison_results.len(), 0);
    }

    /// The helper produces a K-fold validator with the requested metric.
    #[test]
    fn test_make_cross_validator() {
        let validator = make_cross_validator::<f64>(5, ValidationMetric::MeanAbsoluteError);

        if let CrossValidationStrategy::KFold(k) = validator.strategy {
            assert_eq!(k, 5);
        } else {
            panic!("Expected KFold strategy");
        }
        assert_eq!(validator.metric, ValidationMetric::MeanAbsoluteError);
    }

    /// Index extraction keeps the requested order.
    #[test]
    fn test_extract_indices() {
        let validator = CrossValidator::<f64>::new();
        let values = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let wanted = vec![0, 2, 4];

        let extracted = validator.extract_indices(&values.view(), &wanted);
        assert_eq!(extracted, Array1::from_vec(vec![10.0, 30.0, 50.0]));
    }

    /// MSE and MAE are positive for imperfect predictions, and RMSE is the
    /// square root of MSE.
    #[test]
    fn test_validation_metrics() {
        let observed = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let predicted = Array1::from_vec(vec![1.5, 2.5, 2.5]);

        let score_with = |metric: ValidationMetric| {
            CrossValidator::<f64>::new()
                .with_metric(metric)
                .compute_metric(&observed.view(), &predicted.view())
                .unwrap()
        };

        let mse = score_with(ValidationMetric::MeanSquaredError);
        let mae = score_with(ValidationMetric::MeanAbsoluteError);
        let rmse = score_with(ValidationMetric::RootMeanSquaredError);

        assert!(mse > 0.0);
        assert!(mae > 0.0);
        assert!((rmse - mse.sqrt()).abs() < 1e-10);
    }
}