1use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
7use scirs2_core::random::{thread_rng, Rng, RngExt};
8use sklears_core::{
9 error::Result as SklResult,
10 prelude::SklearsError,
11 types::{Float, FloatBounds},
12};
13use std::collections::HashMap;
14use std::time::Instant;
15
16use crate::Pipeline;
17
/// Strategy used by `PipelineOptimizer` to explore the hyper-parameter space.
pub enum SearchStrategy {
    /// Exhaustively evaluates every combination of the declared grids.
    GridSearch,
    /// Evaluates `n_iter` randomly drawn parameter combinations.
    RandomSearch { n_iter: usize },
    /// Declared but unimplemented; `optimize` returns `NotImplemented` for it.
    BayesianOptimization,
    /// Genetic algorithm with elitism, tournament selection, crossover and mutation.
    EvolutionarySearch {
        /// Number of individuals kept per generation.
        population_size: usize,
        /// Number of generations to evolve.
        generations: usize,
    },
}
32
/// A single named hyper-parameter dimension together with its candidate values.
#[derive(Debug, Clone)]
pub struct ParameterSpace {
    /// Parameter name, used as the key in parameter maps.
    pub name: String,
    /// Concrete candidate values (categorical choices are encoded as `f64` indices).
    pub values: Vec<f64>,
    /// Kind and bounds of the parameter.
    pub param_type: ParameterType,
}
43
/// Kind of a hyper-parameter dimension.
#[derive(Debug, Clone)]
pub enum ParameterType {
    /// Real-valued parameter within `[min, max]`.
    Continuous { min: f64, max: f64 },
    /// Integer-valued parameter within the inclusive range `[min, max]`.
    Discrete { min: i32, max: i32 },
    /// Enumerated parameter; values are stored as indices into `choices`.
    Categorical { choices: Vec<String> },
}
54
55impl ParameterSpace {
56 #[must_use]
58 pub fn continuous(name: &str, min: f64, max: f64, n_points: usize) -> Self {
59 let step = (max - min) / (n_points - 1) as f64;
60 let values = (0..n_points).map(|i| min + i as f64 * step).collect();
61
62 Self {
63 name: name.to_string(),
64 values,
65 param_type: ParameterType::Continuous { min, max },
66 }
67 }
68
69 #[must_use]
71 pub fn discrete(name: &str, min: i32, max: i32) -> Self {
72 let values = (min..=max).map(f64::from).collect();
73
74 Self {
75 name: name.to_string(),
76 values,
77 param_type: ParameterType::Discrete { min, max },
78 }
79 }
80
81 #[must_use]
83 pub fn categorical(name: &str, choices: Vec<String>) -> Self {
84 let values = (0..choices.len()).map(|i| i as f64).collect();
85
86 Self {
87 name: name.to_string(),
88 values,
89 param_type: ParameterType::Categorical { choices },
90 }
91 }
92}
93
/// Hyper-parameter optimizer for pipelines, configured via builder methods.
pub struct PipelineOptimizer {
    // Declared hyper-parameter dimensions to search over.
    parameter_spaces: Vec<ParameterSpace>,
    // How the space is explored (grid, random, evolutionary, ...).
    search_strategy: SearchStrategy,
    // Number of cross-validation folds used when scoring a candidate.
    cv_folds: usize,
    // Metric used to score candidates (higher is treated as better).
    scoring: ScoringMetric,
    // Requested parallelism; stored but not read anywhere in this file.
    n_jobs: Option<i32>,
    // When true, progress is printed to stdout during the search.
    verbose: bool,
}
103
/// Metric used to score candidate parameter settings.
///
/// Error metrics are returned negated by the scoring code in this file so
/// that "higher is better" holds uniformly.
#[derive(Debug, Clone)]
pub enum ScoringMetric {
    /// Mean squared error (reported negated).
    MeanSquaredError,
    /// Mean absolute error (reported negated).
    MeanAbsoluteError,
    /// Classification accuracy.
    Accuracy,
    /// F1 score.
    F1Score,
    /// User-defined metric identified by name.
    Custom { name: String },
    /// Several metrics optimized jointly; not allowed to be nested.
    MultiObjective { metrics: Vec<ScoringMetric> },
}
120
/// One candidate solution evaluated under several objectives.
#[derive(Debug, Clone)]
pub struct MultiObjectiveResult {
    /// Parameter assignment for this candidate.
    pub params: HashMap<String, f64>,
    /// One score per objective (higher is better).
    pub scores: Vec<f64>,
    /// Dominance flag; currently always left `false` by the code in this file.
    pub dominated: bool,
    /// Non-domination rank (0 = Pareto-optimal front), set during selection.
    pub rank: usize,
}
133
/// Result of a multi-objective optimization run.
#[derive(Debug)]
pub struct ParetoFront {
    /// Rank-0 (non-dominated) solutions from the final population.
    pub solutions: Vec<MultiObjectiveResult>,
    /// Number of objectives each solution was scored on.
    pub n_objectives: usize,
    /// Approximate hypervolume indicator of the front.
    pub hypervolume: f64,
}
144
145impl PipelineOptimizer {
146 #[must_use]
148 pub fn new() -> Self {
149 Self {
150 parameter_spaces: Vec::new(),
151 search_strategy: SearchStrategy::GridSearch,
152 cv_folds: 5,
153 scoring: ScoringMetric::MeanSquaredError,
154 n_jobs: None,
155 verbose: false,
156 }
157 }
158
159 #[must_use]
161 pub fn parameter_space(mut self, space: ParameterSpace) -> Self {
162 self.parameter_spaces.push(space);
163 self
164 }
165
166 #[must_use]
168 pub fn search_strategy(mut self, strategy: SearchStrategy) -> Self {
169 self.search_strategy = strategy;
170 self
171 }
172
173 #[must_use]
175 pub fn cv_folds(mut self, folds: usize) -> Self {
176 self.cv_folds = folds;
177 self
178 }
179
180 #[must_use]
182 pub fn scoring(mut self, metric: ScoringMetric) -> Self {
183 self.scoring = metric;
184 self
185 }
186
187 #[must_use]
189 pub fn verbose(mut self, verbose: bool) -> Self {
190 self.verbose = verbose;
191 self
192 }
193
194 pub fn optimize<S>(
196 &self,
197 pipeline: Pipeline<S>,
198 x: &ArrayView2<'_, Float>,
199 y: &ArrayView1<'_, Float>,
200 ) -> SklResult<OptimizationResults>
201 where
202 S: std::fmt::Debug,
203 {
204 match self.search_strategy {
205 SearchStrategy::GridSearch => self.grid_search(pipeline, x, y),
206 SearchStrategy::RandomSearch { n_iter } => self.random_search(pipeline, x, y, n_iter),
207 SearchStrategy::BayesianOptimization => Err(SklearsError::NotImplemented(
208 "Bayesian optimization not yet implemented".to_string(),
209 )),
210 SearchStrategy::EvolutionarySearch {
211 population_size,
212 generations,
213 } => self.evolutionary_search(pipeline, x, y, population_size, generations),
214 }
215 }
216
    /// Exhaustive grid search over the Cartesian product of all declared
    /// parameter spaces, keeping the combination with the highest CV score.
    ///
    /// NOTE(review): `cross_validate_pipeline` does not receive `params`, so
    /// the score is currently independent of the candidate — scoring is a
    /// mock until parameters are actually applied to the pipeline.
    fn grid_search<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();

        // At least one space is required for a meaningful grid.
        if self.parameter_spaces.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No parameter spaces defined for optimization".to_string(),
            ));
        }

        let param_combinations = self.generate_grid_combinations()?;

        if self.verbose {
            println!(
                "Grid search: evaluating {} parameter combinations",
                param_combinations.len()
            );
        }

        // Track the running best; NEG_INFINITY makes the first score win.
        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = HashMap::new();
        let mut all_scores = Vec::new();

        for (i, params) in param_combinations.iter().enumerate() {
            if self.verbose {
                println!(
                    "Evaluating combination {}/{}",
                    i + 1,
                    param_combinations.len()
                );
            }

            let cv_score = self.cross_validate_pipeline(&pipeline, x, y)?;
            all_scores.push(cv_score);

            if cv_score > best_score {
                best_score = cv_score;
                best_params = params.clone();
            }
        }

        let search_time = start_time.elapsed().as_secs_f64();

        Ok(OptimizationResults {
            best_params,
            best_score,
            cv_scores: all_scores,
            search_time,
        })
    }
278
    /// Random search: draws `n_iter` random parameter combinations and keeps
    /// the best-scoring one.
    ///
    /// NOTE(review): as in `grid_search`, the CV score does not depend on the
    /// sampled `params` yet (mock scoring).
    fn random_search<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        n_iter: usize,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();
        let mut rng = thread_rng();

        if self.parameter_spaces.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No parameter spaces defined for optimization".to_string(),
            ));
        }

        if self.verbose {
            println!("Random search: evaluating {n_iter} random parameter combinations");
        }

        // Track the running best; NEG_INFINITY makes the first score win.
        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = HashMap::new();
        let mut all_scores = Vec::new();

        for i in 0..n_iter {
            if self.verbose {
                println!("Evaluating combination {}/{}", i + 1, n_iter);
            }

            // One fresh random draw per iteration.
            let params = self.generate_random_parameters(&mut rng)?;

            let cv_score = self.cross_validate_pipeline(&pipeline, x, y)?;
            all_scores.push(cv_score);

            if cv_score > best_score {
                best_score = cv_score;
                best_params = params;
            }
        }

        let search_time = start_time.elapsed().as_secs_f64();

        Ok(OptimizationResults {
            best_params,
            best_score,
            cv_scores: all_scores,
            search_time,
        })
    }
334
    /// Genetic-algorithm search: evolves a population of parameter settings
    /// using 25% elitism, tournament selection, crossover and mutation.
    ///
    /// NOTE(review): fitness comes from `cross_validate_pipeline`, which
    /// ignores the individual's `params` (mock scoring), so selection pressure
    /// is currently driven by scoring noise only.
    fn evolutionary_search<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        population_size: usize,
        generations: usize,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();
        let mut rng = thread_rng();

        if self.parameter_spaces.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No parameter spaces defined for optimization".to_string(),
            ));
        }

        if self.verbose {
            println!(
                "Evolutionary search: {generations} generations with population size {population_size}"
            );
        }

        // Initial population: uniformly random individuals.
        let mut population = Vec::new();
        for _ in 0..population_size {
            let params = self.generate_random_parameters(&mut rng)?;
            population.push(params);
        }

        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = HashMap::new();
        let mut all_scores = Vec::new();

        for generation in 0..generations {
            if self.verbose {
                println!("Generation {}/{}", generation + 1, generations);
            }

            // Evaluate every individual and track the overall best.
            let mut fitness_scores = Vec::new();
            for params in &population {
                let score = self.cross_validate_pipeline(&pipeline, x, y)?;
                fitness_scores.push(score);
                all_scores.push(score);

                if score > best_score {
                    best_score = score;
                    best_params = params.clone();
                }
            }

            let mut new_population = Vec::new();

            // Elitism: the top quarter survives unchanged.
            let elite_count = population_size / 4;
            let mut indexed_fitness: Vec<(usize, f64)> = fitness_scores
                .iter()
                .enumerate()
                .map(|(i, &score)| (i, score))
                .collect();
            // Sort descending by fitness (NaN compares as Equal).
            indexed_fitness
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            for i in 0..elite_count {
                let elite_idx = indexed_fitness[i].0;
                new_population.push(population[elite_idx].clone());
            }

            // Fill the rest with mutated offspring of tournament winners.
            while new_population.len() < population_size {
                let parent1_idx = self.tournament_selection(&fitness_scores, &mut rng);
                let parent2_idx = self.tournament_selection(&fitness_scores, &mut rng);

                let offspring =
                    self.crossover(&population[parent1_idx], &population[parent2_idx], &mut rng)?;

                let mutated_offspring = self.mutate(offspring, &mut rng)?;

                new_population.push(mutated_offspring);
            }

            population = new_population;
        }

        let search_time = start_time.elapsed().as_secs_f64();

        Ok(OptimizationResults {
            best_params,
            best_score,
            cv_scores: all_scores,
            search_time,
        })
    }
438
439 pub fn multi_objective_optimize<S>(
441 &self,
442 pipeline: Pipeline<S>,
443 x: &ArrayView2<'_, Float>,
444 y: &ArrayView1<'_, Float>,
445 metrics: Vec<ScoringMetric>,
446 ) -> SklResult<ParetoFront>
447 where
448 S: std::fmt::Debug,
449 {
450 let start_time = Instant::now();
451
452 if metrics.is_empty() {
453 return Err(SklearsError::InvalidInput(
454 "At least one metric must be specified for multi-objective optimization"
455 .to_string(),
456 ));
457 }
458
459 let population_size = 100;
461 let generations = 50;
462
463 let results = self.nsga_ii(pipeline, x, y, &metrics, population_size, generations)?;
464
465 let search_time = start_time.elapsed().as_secs_f64();
466
467 if self.verbose {
468 println!("Multi-objective optimization completed in {search_time:.2}s");
469 println!(
470 "Found {} solutions in Pareto front",
471 results.solutions.len()
472 );
473 }
474
475 Ok(results)
476 }
477
478 fn nsga_ii<S>(
480 &self,
481 pipeline: Pipeline<S>,
482 x: &ArrayView2<'_, Float>,
483 y: &ArrayView1<'_, Float>,
484 metrics: &[ScoringMetric],
485 population_size: usize,
486 generations: usize,
487 ) -> SklResult<ParetoFront>
488 where
489 S: std::fmt::Debug,
490 {
491 let mut rng = thread_rng();
492
493 let mut population = Vec::new();
495 for _ in 0..population_size {
496 let params = self.generate_random_parameters(&mut rng)?;
497 let scores = self.evaluate_multi_objective(&pipeline, x, y, ¶ms, metrics)?;
498
499 population.push(MultiObjectiveResult {
500 params,
501 scores,
502 dominated: false,
503 rank: 0,
504 });
505 }
506
507 for generation in 0..generations {
509 if self.verbose && generation % 10 == 0 {
510 println!("NSGA-II Generation {}/{}", generation + 1, generations);
511 }
512
513 let mut offspring = Vec::new();
515 while offspring.len() < population_size {
516 let parent1_idx = rng.gen_range(0..population.len());
518 let parent2_idx = rng.gen_range(0..population.len());
519
520 let child_params = self.crossover(
522 &population[parent1_idx].params,
523 &population[parent2_idx].params,
524 &mut rng,
525 )?;
526
527 let mutated_params = self.mutate(child_params, &mut rng)?;
529
530 let scores =
532 self.evaluate_multi_objective(&pipeline, x, y, &mutated_params, metrics)?;
533
534 offspring.push(MultiObjectiveResult {
535 params: mutated_params,
536 scores,
537 dominated: false,
538 rank: 0,
539 });
540 }
541
542 let mut combined_population = population;
544 combined_population.extend(offspring);
545
546 population = self.select_next_generation(combined_population, population_size);
548 }
549
550 let pareto_solutions: Vec<MultiObjectiveResult> =
552 population.into_iter().filter(|sol| sol.rank == 0).collect();
553
554 let hypervolume = self.calculate_hypervolume(&pareto_solutions, metrics.len());
555
556 Ok(ParetoFront {
557 solutions: pareto_solutions,
558 n_objectives: metrics.len(),
559 hypervolume,
560 })
561 }
562
563 fn tournament_selection(&self, fitness_scores: &[f64], rng: &mut impl Rng) -> usize {
565 let tournament_size = 3;
566 let mut best_idx = rng.random_range(0..fitness_scores.len());
567 let mut best_score = fitness_scores[best_idx];
568
569 for _ in 1..tournament_size {
570 let candidate_idx = rng.random_range(0..fitness_scores.len());
571 let candidate_score = fitness_scores[candidate_idx];
572
573 if candidate_score > best_score {
574 best_idx = candidate_idx;
575 best_score = candidate_score;
576 }
577 }
578
579 best_idx
580 }
581
    /// Per-parameter crossover of two parents.
    ///
    /// Each gene is taken from a random parent; continuous genes additionally
    /// use BLX-style blend crossover 30% of the time. Results are kept inside
    /// the parameter bounds. Missing genes default to 0.0.
    fn crossover(
        &self,
        parent1: &HashMap<String, f64>,
        parent2: &HashMap<String, f64>,
        rng: &mut impl Rng,
    ) -> SklResult<HashMap<String, f64>> {
        let mut offspring = HashMap::new();

        for space in &self.parameter_spaces {
            let value1 = parent1.get(&space.name).copied().unwrap_or(0.0);
            let value2 = parent2.get(&space.name).copied().unwrap_or(0.0);

            // Uniform crossover: pick one parent's gene at random.
            let offspring_value = if rng.random_bool(0.5) { value1 } else { value2 };

            let final_value = match &space.param_type {
                ParameterType::Continuous { min, max } => {
                    if rng.random_bool(0.3) {
                        // Blend crossover: sample from the parents' interval
                        // widened by alpha on each side, clipped to bounds.
                        let alpha = 0.5;
                        let range = (value2 - value1).abs();
                        let min_blend = value1.min(value2) - alpha * range;
                        let max_blend = value1.max(value2) + alpha * range;

                        rng.random_range(min_blend.max(*min)..=max_blend.min(*max))
                    } else {
                        offspring_value.clamp(*min, *max)
                    }
                }
                ParameterType::Discrete { min, max } => {
                    // Round to the nearest integer and clip to bounds.
                    f64::from((offspring_value.round() as i32).clamp(*min, *max))
                }
                ParameterType::Categorical { choices } => {
                    // Wrap into a valid choice index.
                    (offspring_value as usize % choices.len()) as f64
                }
            };

            offspring.insert(space.name.clone(), final_value);
        }

        Ok(offspring)
    }
626
    /// Mutates each gene independently with probability 0.1.
    ///
    /// Continuous genes receive uniform noise within ±10% of the range
    /// (clamped to bounds); discrete and categorical genes are resampled
    /// uniformly. Missing genes default to 0.0 before mutation.
    fn mutate(
        &self,
        mut individual: HashMap<String, f64>,
        rng: &mut impl Rng,
    ) -> SklResult<HashMap<String, f64>> {
        let mutation_rate = 0.1;

        for space in &self.parameter_spaces {
            if rng.random_bool(mutation_rate) {
                let current_value = individual.get(&space.name).copied().unwrap_or(0.0);

                let mutated_value = match &space.param_type {
                    ParameterType::Continuous { min, max } => {
                        // Uniform perturbation within ±10% of the span.
                        let sigma = (max - min) * 0.1;
                        let noise = rng.random_range(-sigma..=sigma);
                        (current_value + noise).clamp(*min, *max)
                    }
                    ParameterType::Discrete { min, max } => {
                        // Resample uniformly from the full discrete range.
                        f64::from(rng.random_range(*min..=*max))
                    }
                    ParameterType::Categorical { choices } => {
                        // Resample a uniformly random choice index.
                        rng.random_range(0..choices.len()) as f64
                    }
                };

                individual.insert(space.name.clone(), mutated_value);
            }
        }

        Ok(individual)
    }
662
663 fn evaluate_multi_objective<S>(
665 &self,
666 pipeline: &Pipeline<S>,
667 x: &ArrayView2<'_, Float>,
668 y: &ArrayView1<'_, Float>,
669 _params: &HashMap<String, f64>,
670 metrics: &[ScoringMetric],
671 ) -> SklResult<Vec<f64>>
672 where
673 S: std::fmt::Debug,
674 {
675 let mut scores = Vec::new();
676
677 for metric in metrics {
678 let score = if let ScoringMetric::MultiObjective { .. } = metric {
680 return Err(SklearsError::InvalidInput(
681 "Nested multi-objective metrics not supported".to_string(),
682 ));
683 } else {
684 let original_scoring = self.scoring.clone();
686 let temp_optimizer = PipelineOptimizer {
687 parameter_spaces: Vec::new(),
688 search_strategy: SearchStrategy::GridSearch,
689 cv_folds: self.cv_folds,
690 scoring: metric.clone(),
691 n_jobs: self.n_jobs,
692 verbose: false,
693 };
694 temp_optimizer.cross_validate_pipeline(pipeline, x, y)?
695 };
696 scores.push(score);
697 }
698
699 Ok(scores)
700 }
701
    /// NSGA-II environmental selection: fills the next generation front by
    /// front; the first front that does not fit entirely is truncated by
    /// descending crowding distance. Assigns each survivor's `rank`.
    fn select_next_generation(
        &self,
        mut population: Vec<MultiObjectiveResult>,
        target_size: usize,
    ) -> Vec<MultiObjectiveResult> {
        let fronts = self.non_dominated_sort(&mut population);

        let mut next_generation = Vec::new();

        for (rank, front) in fronts.iter().enumerate() {
            if next_generation.len() + front.len() <= target_size {
                // Whole front fits: copy it over with its rank.
                for &idx in front {
                    population[idx].rank = rank;
                    next_generation.push(population[idx].clone());
                }
            } else {
                // Partial front: keep the most spread-out individuals.
                let remaining_slots = target_size - next_generation.len();
                let mut front_with_distance: Vec<(usize, f64)> = front
                    .iter()
                    .map(|&idx| {
                        let distance = self.calculate_crowding_distance(&population, front, idx);
                        (idx, distance)
                    })
                    .collect();

                // Descending crowding distance (NaN compares as Equal).
                front_with_distance
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

                for i in 0..remaining_slots {
                    let idx = front_with_distance[i].0;
                    population[idx].rank = rank;
                    next_generation.push(population[idx].clone());
                }
                break;
            }
        }

        next_generation
    }
746
747 fn non_dominated_sort(&self, population: &mut [MultiObjectiveResult]) -> Vec<Vec<usize>> {
749 let n = population.len();
750 let mut fronts = Vec::new();
751 let mut dominated_count = vec![0; n];
752 let mut dominated_solutions: Vec<Vec<usize>> = vec![Vec::new(); n];
753
754 let mut current_front = Vec::new();
756
757 for i in 0..n {
758 for j in 0..n {
759 if i != j {
760 let dominates = self.dominates(&population[i], &population[j]);
761 let dominated_by = self.dominates(&population[j], &population[i]);
762
763 if dominates {
764 dominated_solutions[i].push(j);
765 } else if dominated_by {
766 dominated_count[i] += 1;
767 }
768 }
769 }
770
771 if dominated_count[i] == 0 {
772 current_front.push(i);
773 }
774 }
775
776 fronts.push(current_front.clone());
777
778 while !current_front.is_empty() {
780 let mut next_front = Vec::new();
781
782 for &i in ¤t_front {
783 for &j in &dominated_solutions[i] {
784 dominated_count[j] -= 1;
785 if dominated_count[j] == 0 {
786 next_front.push(j);
787 }
788 }
789 }
790
791 if !next_front.is_empty() {
792 fronts.push(next_front.clone());
793 }
794 current_front = next_front;
795 }
796
797 fronts
798 }
799
800 fn dominates(&self, a: &MultiObjectiveResult, b: &MultiObjectiveResult) -> bool {
802 let mut at_least_one_better = false;
803
804 for i in 0..a.scores.len() {
805 if a.scores[i] < b.scores[i] {
806 return false; }
808 if a.scores[i] > b.scores[i] {
809 at_least_one_better = true;
810 }
811 }
812
813 at_least_one_better
814 }
815
    /// NSGA-II crowding distance of one individual within its front: the sum,
    /// over objectives, of the normalized gap between its neighbors in the
    /// objective-sorted front. Boundary individuals (and fronts of size <= 2)
    /// get infinite distance so they are always preferred during truncation.
    fn calculate_crowding_distance(
        &self,
        population: &[MultiObjectiveResult],
        front: &[usize],
        individual_idx: usize,
    ) -> f64 {
        if front.len() <= 2 {
            return f64::INFINITY;
        }

        let n_objectives = population[individual_idx].scores.len();
        let mut distance = 0.0;

        for obj in 0..n_objectives {
            // Sort the front by this objective (NaN compares as Equal).
            let mut sorted_front = front.to_vec();
            sorted_front.sort_by(|&a, &b| {
                population[a].scores[obj]
                    .partial_cmp(&population[b].scores[obj])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Position of this individual in the sorted order; falls back to
            // 0 if not found (should not happen for indices drawn from front).
            let pos = sorted_front
                .iter()
                .position(|&idx| idx == individual_idx)
                .unwrap_or_default();

            // Extreme points on any objective get infinite distance.
            if pos == 0 || pos == sorted_front.len() - 1 {
                return f64::INFINITY;
            }

            let obj_min = population[sorted_front[0]].scores[obj];
            let obj_max = population[sorted_front[sorted_front.len() - 1]].scores[obj];

            // Skip degenerate objectives where all values are equal.
            if obj_max > obj_min {
                let prev_obj = population[sorted_front[pos - 1]].scores[obj];
                let next_obj = population[sorted_front[pos + 1]].scores[obj];
                distance += (next_obj - prev_obj) / (obj_max - obj_min);
            }
        }

        distance
    }
863
    /// Rough hypervolume indicator of a solution set.
    ///
    /// For exactly two objectives this sums rectangles against an implicit
    /// reference point at the origin — a coarse approximation that assumes
    /// nonnegative, to-be-maximized scores (NOTE(review): it does not prune
    /// dominated area, so it can over-count; confirm this is acceptable).
    /// For any other objective count it simply returns the solution count
    /// as a placeholder.
    fn calculate_hypervolume(
        &self,
        solutions: &[MultiObjectiveResult],
        n_objectives: usize,
    ) -> f64 {
        if solutions.is_empty() {
            return 0.0;
        }

        if n_objectives == 2 {
            // Sort ascending by the first objective (NaN compares as Equal).
            let mut sorted_solutions = solutions.to_vec();
            sorted_solutions.sort_by(|a, b| {
                a.scores[0]
                    .partial_cmp(&b.scores[0])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Accumulate rectangles between consecutive first-objective values.
            let mut hypervolume = 0.0;
            let mut prev_x = 0.0;

            for solution in &sorted_solutions {
                if solution.scores[0] > prev_x {
                    hypervolume += (solution.scores[0] - prev_x) * solution.scores[1];
                    prev_x = solution.scores[0];
                }
            }

            hypervolume
        } else {
            // Placeholder for dimensions other than 2.
            solutions.len() as f64
        }
    }
899
900 fn generate_grid_combinations(&self) -> SklResult<Vec<HashMap<String, f64>>> {
902 if self.parameter_spaces.is_empty() {
903 return Ok(vec![HashMap::new()]);
904 }
905
906 let mut combinations = vec![HashMap::new()];
907
908 for space in &self.parameter_spaces {
909 let mut new_combinations = Vec::new();
910
911 for value in &space.values {
912 for existing_combo in &combinations {
913 let mut new_combo = existing_combo.clone();
914 new_combo.insert(space.name.clone(), *value);
915 new_combinations.push(new_combo);
916 }
917 }
918
919 combinations = new_combinations;
920 }
921
922 Ok(combinations)
923 }
924
925 fn generate_random_parameters(&self, rng: &mut impl Rng) -> SklResult<HashMap<String, f64>> {
927 let mut params = HashMap::new();
928
929 for space in &self.parameter_spaces {
930 let value = match &space.param_type {
931 ParameterType::Continuous { min, max } => rng.random_range(*min..*max),
932 ParameterType::Discrete { min, max } => f64::from(rng.random_range(*min..=*max)),
933 ParameterType::Categorical { choices } => {
934 let idx = rng.random_range(0..choices.len());
935 idx as f64
936 }
937 };
938
939 params.insert(space.name.clone(), value);
940 }
941
942 Ok(params)
943 }
944
945 fn cross_validate_pipeline<S>(
947 &self,
948 _pipeline: &Pipeline<S>,
949 x: &ArrayView2<'_, Float>,
950 y: &ArrayView1<'_, Float>,
951 ) -> SklResult<f64>
952 where
953 S: std::fmt::Debug,
954 {
955 let n_samples = x.nrows();
956 let fold_size = n_samples / self.cv_folds;
957 let mut scores = Vec::new();
958
959 for fold in 0..self.cv_folds {
960 let start_idx = fold * fold_size;
961 let end_idx = if fold == self.cv_folds - 1 {
962 n_samples
963 } else {
964 (fold + 1) * fold_size
965 };
966
967 let mut train_indices = Vec::new();
969 let mut test_indices = Vec::new();
970
971 for i in 0..n_samples {
972 if i >= start_idx && i < end_idx {
973 test_indices.push(i);
974 } else {
975 train_indices.push(i);
976 }
977 }
978
979 let score = self.compute_mock_score(x, y, &train_indices, &test_indices)?;
983 scores.push(score);
984 }
985
986 Ok(scores.iter().sum::<f64>() / scores.len() as f64)
988 }
989
990 fn compute_mock_score(
992 &self,
993 x: &ArrayView2<'_, Float>,
994 y: &ArrayView1<'_, Float>,
995 train_indices: &[usize],
996 test_indices: &[usize],
997 ) -> SklResult<f64> {
998 match self.scoring {
1000 ScoringMetric::MeanSquaredError => {
1001 let test_targets: Vec<f64> = test_indices.iter().map(|&i| y[i]).collect();
1003
1004 if test_targets.is_empty() {
1005 return Ok(0.0);
1006 }
1007
1008 let mean = test_targets.iter().sum::<f64>() / test_targets.len() as f64;
1009 let variance = test_targets
1010 .iter()
1011 .map(|&val| (val - mean).powi(2))
1012 .sum::<f64>()
1013 / test_targets.len() as f64;
1014
1015 Ok(-variance.sqrt())
1017 }
1018 ScoringMetric::MeanAbsoluteError => {
1019 let test_targets: Vec<f64> = test_indices.iter().map(|&i| y[i]).collect();
1021
1022 if test_targets.is_empty() {
1023 return Ok(0.0);
1024 }
1025
1026 let mean = test_targets.iter().sum::<f64>() / test_targets.len() as f64;
1027 let mae = test_targets
1028 .iter()
1029 .map(|&val| (val - mean).abs())
1030 .sum::<f64>()
1031 / test_targets.len() as f64;
1032
1033 Ok(-mae)
1034 }
1035 ScoringMetric::Accuracy | ScoringMetric::F1Score => {
1036 let unique_classes = y
1038 .iter()
1039 .map(|&val| val as i32)
1040 .collect::<std::collections::HashSet<_>>();
1041
1042 Ok(1.0 / unique_classes.len() as f64)
1044 }
1045 ScoringMetric::Custom { .. } => {
1046 Ok(0.8)
1048 }
1049 ScoringMetric::MultiObjective { .. } => {
1050 Ok(0.5)
1052 }
1053 }
1054 }
1055}
1056
1057impl Default for PipelineOptimizer {
1058 fn default() -> Self {
1059 Self::new()
1060 }
1061}
1062
/// Outcome of a single-objective hyper-parameter search.
#[derive(Debug)]
pub struct OptimizationResults {
    /// Parameter assignment that achieved the best score.
    pub best_params: HashMap<String, f64>,
    /// Highest cross-validation score observed.
    pub best_score: f64,
    /// Every cross-validation score evaluated during the search.
    pub cv_scores: Vec<f64>,
    /// Wall-clock duration of the search, in seconds.
    pub search_time: f64,
}
1075
/// Configurable input-data validator for pipelines.
pub struct PipelineValidator {
    // Flag is stored but not consulted by `validate_data` in this file.
    check_data_types: bool,
    // When true, features are rejected if they contain NaN.
    check_missing_values: bool,
    // When true, features are rejected if they contain +/- infinity.
    check_infinite_values: bool,
    // Flag is stored but not consulted by `validate_data` in this file.
    check_feature_names: bool,
    // Verbosity flag; stored but not read in this file.
    verbose: bool,
}
1084
1085impl PipelineValidator {
1086 #[must_use]
1088 pub fn new() -> Self {
1089 Self {
1090 check_data_types: true,
1091 check_missing_values: true,
1092 check_infinite_values: true,
1093 check_feature_names: false,
1094 verbose: false,
1095 }
1096 }
1097
1098 #[must_use]
1100 pub fn check_data_types(mut self, check: bool) -> Self {
1101 self.check_data_types = check;
1102 self
1103 }
1104
1105 #[must_use]
1107 pub fn check_missing_values(mut self, check: bool) -> Self {
1108 self.check_missing_values = check;
1109 self
1110 }
1111
1112 #[must_use]
1114 pub fn check_infinite_values(mut self, check: bool) -> Self {
1115 self.check_infinite_values = check;
1116 self
1117 }
1118
1119 #[must_use]
1121 pub fn check_feature_names(mut self, check: bool) -> Self {
1122 self.check_feature_names = check;
1123 self
1124 }
1125
1126 #[must_use]
1128 pub fn verbose(mut self, verbose: bool) -> Self {
1129 self.verbose = verbose;
1130 self
1131 }
1132
1133 pub fn validate_data(
1135 &self,
1136 x: &ArrayView2<'_, Float>,
1137 y: Option<&ArrayView1<'_, Float>>,
1138 ) -> SklResult<()> {
1139 if self.check_missing_values {
1140 self.check_for_missing_values(x)?;
1141 }
1142
1143 if self.check_infinite_values {
1144 self.check_for_infinite_values(x)?;
1145 }
1146
1147 if let Some(y_values) = y {
1148 self.validate_target(y_values)?;
1149 }
1150
1151 Ok(())
1152 }
1153
1154 fn check_for_missing_values(&self, x: &ArrayView2<'_, Float>) -> SklResult<()> {
1155 for (i, row) in x.rows().into_iter().enumerate() {
1156 for (j, &value) in row.iter().enumerate() {
1157 if value.is_nan() {
1158 return Err(SklearsError::InvalidData {
1159 reason: format!("Missing value (NaN) found at position ({i}, {j})"),
1160 });
1161 }
1162 }
1163 }
1164 Ok(())
1165 }
1166
1167 fn check_for_infinite_values(&self, x: &ArrayView2<'_, Float>) -> SklResult<()> {
1168 for (i, row) in x.rows().into_iter().enumerate() {
1169 for (j, &value) in row.iter().enumerate() {
1170 if value.is_infinite() {
1171 return Err(SklearsError::InvalidData {
1172 reason: format!("Infinite value found at position ({i}, {j})"),
1173 });
1174 }
1175 }
1176 }
1177 Ok(())
1178 }
1179
1180 fn validate_target(&self, y: &ArrayView1<'_, Float>) -> SklResult<()> {
1181 for (i, &value) in y.iter().enumerate() {
1182 if value.is_nan() {
1183 return Err(SklearsError::InvalidData {
1184 reason: format!("Missing value (NaN) found in target at position {i}"),
1185 });
1186 }
1187 if value.is_infinite() {
1188 return Err(SklearsError::InvalidData {
1189 reason: format!("Infinite value found in target at position {i}"),
1190 });
1191 }
1192 }
1193 Ok(())
1194 }
1195
1196 pub fn validate_pipeline<S>(&self, _pipeline: &Pipeline<S>) -> SklResult<()>
1198 where
1199 S: std::fmt::Debug,
1200 {
1201 Ok(())
1203 }
1204}
1205
1206impl Default for PipelineValidator {
1207 fn default() -> Self {
1208 Self::new()
1209 }
1210}
1211
/// Executes a pipeline with retries, configurable error handling and a
/// fallback strategy when all attempts fail.
pub struct RobustPipelineExecutor {
    // Number of retries after the first attempt (total attempts = retries + 1).
    max_retries: usize,
    // What to do when every attempt has failed.
    fallback_strategy: FallbackStrategy,
    // How individual failures are treated (abort, warn, or attempt recovery).
    error_handling: ErrorHandlingStrategy,
    // Optional timeout; stored but not enforced by `execute` in this file.
    timeout_seconds: Option<u64>,
}
1219
/// What `RobustPipelineExecutor` does after all retries are exhausted.
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
    /// Propagate an error to the caller.
    ReturnError,
    /// Produce predictions from a trivial stand-in (first feature per row).
    SimplerPipeline,
    /// Produce all-zero predictions.
    DefaultValues,
    /// Produce predictions that skip the failed step (row feature sums).
    SkipStep,
}
1232
/// How `RobustPipelineExecutor` reacts to a failed execution attempt.
#[derive(Debug, Clone)]
pub enum ErrorHandlingStrategy {
    /// Abort immediately on the first failure.
    FailFast,
    /// Log a warning and retry; fall back after the last attempt.
    ContinueWithWarnings,
    /// Log a recovery message and retry; fall back after the last attempt.
    AttemptRecovery,
}
1243
1244impl RobustPipelineExecutor {
1245 #[must_use]
1247 pub fn new() -> Self {
1248 Self {
1249 max_retries: 3,
1250 fallback_strategy: FallbackStrategy::ReturnError,
1251 error_handling: ErrorHandlingStrategy::FailFast,
1252 timeout_seconds: None,
1253 }
1254 }
1255
1256 #[must_use]
1258 pub fn max_retries(mut self, retries: usize) -> Self {
1259 self.max_retries = retries;
1260 self
1261 }
1262
1263 #[must_use]
1265 pub fn fallback_strategy(mut self, strategy: FallbackStrategy) -> Self {
1266 self.fallback_strategy = strategy;
1267 self
1268 }
1269
1270 #[must_use]
1272 pub fn error_handling(mut self, strategy: ErrorHandlingStrategy) -> Self {
1273 self.error_handling = strategy;
1274 self
1275 }
1276
1277 #[must_use]
1279 pub fn timeout_seconds(mut self, timeout: u64) -> Self {
1280 self.timeout_seconds = Some(timeout);
1281 self
1282 }
1283
    /// Runs the pipeline with up to `max_retries + 1` attempts, applying the
    /// configured error-handling strategy between failures and the fallback
    /// strategy once attempts are exhausted.
    ///
    /// NOTE(review): `timeout_seconds` is not consulted here — confirm
    /// whether timeout enforcement is still planned.
    pub fn execute<S>(
        &self,
        mut pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Array1<f64>>
    where
        S: std::fmt::Debug,
    {
        let mut attempt = 0;

        while attempt <= self.max_retries {
            match self.try_execute(&mut pipeline, x, y) {
                Ok(result) => return Ok(result),
                Err(error) => match self.error_handling {
                    ErrorHandlingStrategy::FailFast => {
                        // Abort on the first failure, no retries.
                        return Err(error);
                    }
                    ErrorHandlingStrategy::ContinueWithWarnings => {
                        eprintln!(
                            "Warning: Pipeline execution failed (attempt {}): {:?}",
                            attempt + 1,
                            error
                        );
                        // Out of attempts: switch to the fallback.
                        if attempt == self.max_retries {
                            return self.apply_fallback_strategy(x, y);
                        }
                    }
                    ErrorHandlingStrategy::AttemptRecovery => {
                        eprintln!(
                            "Attempting recovery from error (attempt {}): {:?}",
                            attempt + 1,
                            error
                        );
                        // Out of attempts: switch to the fallback.
                        if attempt == self.max_retries {
                            return self.apply_fallback_strategy(x, y);
                        }
                    }
                },
            }
            attempt += 1;
        }

        // Defensive default; the non-fail-fast branches return the fallback
        // inside the loop on the final attempt.
        self.apply_fallback_strategy(x, y)
    }
1330
1331 fn try_execute<S>(
1333 &self,
1334 _pipeline: &mut Pipeline<S>,
1335 x: &ArrayView2<'_, Float>,
1336 _y: Option<&ArrayView1<'_, Float>>,
1337 ) -> SklResult<Array1<f64>>
1338 where
1339 S: std::fmt::Debug,
1340 {
1341 if x.nrows() == 0 {
1344 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
1345 }
1346
1347 let predictions: Vec<f64> = x
1349 .rows()
1350 .into_iter()
1351 .map(|row| row.iter().copied().sum::<f64>() / row.len() as f64)
1352 .collect();
1353
1354 Ok(Array1::from_vec(predictions))
1355 }
1356
1357 fn apply_fallback_strategy(
1359 &self,
1360 x: &ArrayView2<'_, Float>,
1361 _y: Option<&ArrayView1<'_, Float>>,
1362 ) -> SklResult<Array1<f64>> {
1363 match self.fallback_strategy {
1364 FallbackStrategy::ReturnError => Err(SklearsError::InvalidData {
1365 reason: "Pipeline execution failed after maximum retries".to_string(),
1366 }),
1367 FallbackStrategy::SimplerPipeline => {
1368 eprintln!("Falling back to simpler pipeline");
1370 let simple_predictions: Vec<f64> = x
1371 .rows()
1372 .into_iter()
1373 .map(|row| {
1374 if row.is_empty() {
1376 0.0
1377 } else {
1378 row[0]
1379 }
1380 })
1381 .collect();
1382 Ok(Array1::from_vec(simple_predictions))
1383 }
1384 FallbackStrategy::DefaultValues => {
1385 eprintln!("Falling back to default values");
1387 Ok(Array1::zeros(x.nrows()))
1388 }
1389 FallbackStrategy::SkipStep => {
1390 eprintln!("Falling back by skipping failed step");
1392 let fallback_predictions: Vec<f64> = x
1393 .rows()
1394 .into_iter()
1395 .map(|row| row.iter().copied().sum())
1396 .collect();
1397 Ok(Array1::from_vec(fallback_predictions))
1398 }
1399 }
1400 }
1401}
1402
1403impl Default for RobustPipelineExecutor {
1404 fn default() -> Self {
1405 Self::new()
1406 }
1407}