1use scirs2_core::ndarray::Array2;
17use sklears_core::{
18 error::{Result, SklearsError},
19 traits::{Fit, Trained, Untrained},
20 types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
/// Summary statistics describing a dataset, computed per feature during `fit`.
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
    /// (n_samples, n_features) of the analyzed matrix.
    pub shape: (usize, usize),
    /// Per-feature distribution classification.
    pub distribution_types: Vec<DistributionType>,
    /// Per-feature standardized skewness (0.0 when the feature is constant).
    pub skewness: Vec<Float>,
    /// Per-feature excess kurtosis (approximately 0 for a normal distribution).
    pub kurtosis: Vec<Float>,
    /// Per-feature percentage (0-100) of values outside the 1.5*IQR fences.
    pub outlier_percentages: Vec<Float>,
    /// Per-feature percentage (0-100) of non-finite (missing) values.
    pub missing_percentages: Vec<Float>,
    /// Per-feature (min, max) over finite values.
    pub ranges: Vec<(Float, Float)>,
    /// Mean absolute pairwise correlation over a sampled subset of feature pairs.
    pub correlation_strength: Float,
    /// Overall data quality in [0, 1], penalized by missing values and outliers.
    pub quality_score: Float,
    /// Suggested batch size for memory-bounded processing.
    pub optimal_batch_size: usize,
}
49
/// Coarse classification of a feature's empirical distribution.
#[derive(Debug, Clone, Copy)]
pub enum DistributionType {
    /// Approximately Gaussian: low skew, low excess kurtosis, few outliers.
    Normal,
    /// Strongly asymmetric (|skewness| > 1).
    Skewed,
    /// Flat, platykurtic distribution.
    Uniform,
    /// Multiple modes suggested by very low kurtosis plus outliers.
    Multimodal,
    /// Leptokurtic or outlier-rich distribution.
    HeavyTailed,
    /// More than half of the values are (near-)zero.
    Sparse,
    /// Could not be classified (e.g. a column with no finite values).
    Unknown,
}
68
/// How aggressively parameter selection trades robustness for efficiency.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategy {
    /// Favor robustness; deviate from defaults only on strong evidence.
    Conservative,
    /// Balance robustness, quality and efficiency (the default).
    Balanced,
    /// Favor efficiency; adapt eagerly.
    Aggressive,
    /// User-defined behavior; built-in heuristics fall back to neutral picks.
    Custom,
}
81
/// Configuration controlling the adaptive parameter search.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Overall optimization strategy.
    pub strategy: AdaptationStrategy,
    /// Whether to use cross-validation when evaluating parameters.
    pub use_cross_validation: bool,
    /// Number of CV folds (only meaningful when `use_cross_validation` is true).
    pub cv_folds: usize,
    /// Optional wall-clock budget for the search, in seconds.
    pub time_budget: Option<Float>,
    /// Whether evaluations may run in parallel.
    pub parallel: bool,
    /// Convergence tolerance for iterative optimization.
    pub tolerance: Float,
    /// Maximum number of optimization iterations.
    pub max_iterations: usize,
    /// Optional per-parameter (lower, upper) search bounds, keyed by name.
    pub parameter_bounds: HashMap<String, (Float, Float)>,
}
102
103impl Default for AdaptiveConfig {
104 fn default() -> Self {
105 Self {
106 strategy: AdaptationStrategy::Balanced,
107 use_cross_validation: true,
108 cv_folds: 5,
109 time_budget: Some(60.0), parallel: true,
111 tolerance: 1e-4,
112 max_iterations: 100,
113 parameter_bounds: HashMap::new(),
114 }
115 }
116}
117
/// Type-state estimator that analyzes a dataset and selects preprocessing
/// parameters. `State` is `Untrained` until `fit` is called, then `Trained`.
#[derive(Debug, Clone)]
pub struct AdaptiveParameterSelector<State = Untrained> {
    // Search configuration (strategy, CV, budget, ...).
    config: AdaptiveConfig,
    // Zero-sized marker carrying the type-state.
    state: PhantomData<State>,
    // Populated by `fit`: summary statistics of the training data.
    data_characteristics_: Option<DataCharacteristics>,
    // Populated by `fit`: best parameter values keyed by name.
    optimal_parameters_: Option<HashMap<String, Float>>,
    // Populated by `fit`: all evaluated candidates, sorted best-first.
    parameter_history_: Option<Vec<ParameterEvaluation>>,
}
128
/// One evaluated parameter candidate and its component scores.
#[derive(Debug, Clone)]
pub struct ParameterEvaluation {
    /// Candidate parameter values keyed by name.
    pub parameters: HashMap<String, Float>,
    /// Strategy-weighted overall score (higher is better).
    pub score: Float,
    /// Robustness component, in [0, 1].
    pub robustness_score: Float,
    /// Efficiency component, in [0, 1].
    pub efficiency_score: Float,
    /// Quality component, in [0, 1].
    pub quality_score: Float,
    /// Wall-clock time spent evaluating this candidate, in seconds.
    pub evaluation_time: Float,
}
139
/// Concrete preprocessing recommendations derived from the fitted analysis.
#[derive(Debug, Clone)]
pub struct ParameterRecommendations {
    /// Recommended scaling configuration.
    pub scaling: ScalingParameters,
    /// Recommended imputation configuration.
    pub imputation: ImputationParameters,
    /// Recommended outlier-detection configuration.
    pub outlier_detection: OutlierDetectionParameters,
    /// Recommended transformation configuration.
    pub transformation: TransformationParameters,
    /// Confidence in these recommendations, in [0, 1].
    pub confidence: Float,
}
154
/// Parameters for feature scaling.
#[derive(Debug, Clone)]
pub struct ScalingParameters {
    /// Scaling method name: "standard", "minmax" or "robust".
    pub method: String,
    /// Threshold used when handling outliers during scaling.
    pub outlier_threshold: Float,
    /// Percentile range (low, high) for robust scaling.
    pub quantile_range: (Float, Float),
    /// Whether to center the data before scaling.
    pub with_centering: bool,
    /// Whether to scale the data to unit spread.
    pub with_scaling: bool,
}
164
/// Parameters for missing-value imputation.
#[derive(Debug, Clone)]
pub struct ImputationParameters {
    /// Imputation strategy name: "mean", "median" or "knn".
    pub strategy: String,
    /// Neighbor count for KNN imputation (None for simple strategies).
    pub n_neighbors: Option<usize>,
    /// Whether the imputer should account for outliers.
    pub outlier_aware: bool,
    /// Iteration cap for iterative imputers (None for simple strategies).
    pub max_iterations: Option<usize>,
}
173
/// Parameters for outlier detection.
#[derive(Debug, Clone)]
pub struct OutlierDetectionParameters {
    /// Detection method name (e.g. "isolation_forest").
    pub method: String,
    /// Expected fraction of outliers.
    pub contamination: Float,
    /// Decision threshold for flagging outliers.
    pub threshold: Float,
    /// Number of ensemble members, when the method is an ensemble.
    pub ensemble_size: Option<usize>,
}
182
/// Parameters for feature transformation.
#[derive(Debug, Clone)]
pub struct TransformationParameters {
    /// Transformation name: "log1p", "box_cox" or "none".
    pub method: String,
    /// Whether negative inputs must be handled before transforming.
    pub handle_negatives: bool,
    /// Fixed Box-Cox lambda (None to estimate from data).
    pub lambda: Option<Float>,
    /// Quantile count for quantile-based transforms.
    pub n_quantiles: Option<usize>,
}
191
192impl AdaptiveParameterSelector<Untrained> {
193 pub fn new() -> Self {
195 Self {
196 config: AdaptiveConfig::default(),
197 state: PhantomData,
198 data_characteristics_: None,
199 optimal_parameters_: None,
200 parameter_history_: None,
201 }
202 }
203
204 pub fn conservative() -> Self {
206 Self::new().strategy(AdaptationStrategy::Conservative)
207 }
208
209 pub fn balanced() -> Self {
211 Self::new().strategy(AdaptationStrategy::Balanced)
212 }
213
214 pub fn aggressive() -> Self {
216 Self::new().strategy(AdaptationStrategy::Aggressive)
217 }
218
219 pub fn strategy(mut self, strategy: AdaptationStrategy) -> Self {
221 self.config.strategy = strategy;
222 self
223 }
224
225 pub fn cross_validation(mut self, enable: bool, folds: usize) -> Self {
227 self.config.use_cross_validation = enable;
228 self.config.cv_folds = folds;
229 self
230 }
231
232 pub fn time_budget(mut self, seconds: Float) -> Self {
234 self.config.time_budget = Some(seconds);
235 self
236 }
237
238 pub fn parallel(mut self, enable: bool) -> Self {
240 self.config.parallel = enable;
241 self
242 }
243
244 pub fn tolerance(mut self, tolerance: Float) -> Self {
246 self.config.tolerance = tolerance;
247 self
248 }
249
250 pub fn parameter_bounds(mut self, bounds: HashMap<String, (Float, Float)>) -> Self {
252 self.config.parameter_bounds = bounds;
253 self
254 }
255}
256
257impl Fit<Array2<Float>, ()> for AdaptiveParameterSelector<Untrained> {
258 type Fitted = AdaptiveParameterSelector<Trained>;
259
260 fn fit(mut self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
261 let (n_samples, n_features) = x.dim();
262
263 if n_samples == 0 || n_features == 0 {
264 return Err(SklearsError::InvalidInput(
265 "Input array is empty".to_string(),
266 ));
267 }
268
269 let characteristics = self.analyze_data_characteristics(x)?;
271
272 let optimal_parameters = self.optimize_parameters(x, &characteristics)?;
274
275 let parameter_history = self.evaluate_parameter_space(x, &characteristics)?;
277
278 self.data_characteristics_ = Some(characteristics);
279 self.optimal_parameters_ = Some(optimal_parameters);
280 self.parameter_history_ = Some(parameter_history);
281
282 Ok(AdaptiveParameterSelector {
283 config: self.config,
284 state: PhantomData,
285 data_characteristics_: self.data_characteristics_,
286 optimal_parameters_: self.optimal_parameters_,
287 parameter_history_: self.parameter_history_,
288 })
289 }
290}
291
292impl AdaptiveParameterSelector<Untrained> {
    /// Computes per-feature summary statistics (distribution class, skewness,
    /// kurtosis, outlier/missing rates, value ranges) plus dataset-level
    /// metrics (correlation strength, quality score, batch size).
    fn analyze_data_characteristics(&self, x: &Array2<Float>) -> Result<DataCharacteristics> {
        let (n_samples, n_features) = x.dim();

        let mut distribution_types = Vec::with_capacity(n_features);
        let mut skewness = Vec::with_capacity(n_features);
        let mut kurtosis = Vec::with_capacity(n_features);
        let mut outlier_percentages = Vec::with_capacity(n_features);
        let mut missing_percentages = Vec::with_capacity(n_features);
        let mut ranges = Vec::with_capacity(n_features);

        for j in 0..n_features {
            let column = x.column(j);

            // Non-finite entries (NaN / +-inf) are treated as missing.
            let valid_values: Vec<Float> =
                column.iter().filter(|x| x.is_finite()).copied().collect();

            let missing_pct =
                ((n_samples - valid_values.len()) as Float / n_samples as Float) * 100.0;
            missing_percentages.push(missing_pct);

            // Entirely-missing column: record neutral statistics and move on.
            if valid_values.is_empty() {
                distribution_types.push(DistributionType::Unknown);
                skewness.push(0.0);
                kurtosis.push(0.0);
                outlier_percentages.push(0.0);
                ranges.push((0.0, 0.0));
                continue;
            }

            // Population mean / standard deviation (divides by n, not n-1).
            let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
            let variance = valid_values
                .iter()
                .map(|x| (x - mean).powi(2))
                .sum::<Float>()
                / valid_values.len() as Float;
            let std = variance.sqrt();

            // Standardized third moment; 0.0 for constant columns.
            let feature_skewness = if std > 0.0 {
                valid_values
                    .iter()
                    .map(|x| ((x - mean) / std).powi(3))
                    .sum::<Float>()
                    / valid_values.len() as Float
            } else {
                0.0
            };

            // Excess kurtosis (fourth moment minus 3); 0.0 for constant columns.
            let feature_kurtosis = if std > 0.0 {
                valid_values
                    .iter()
                    .map(|x| ((x - mean) / std).powi(4))
                    .sum::<Float>()
                    / valid_values.len() as Float
                    - 3.0
            } else {
                0.0
            };

            skewness.push(feature_skewness);
            kurtosis.push(feature_kurtosis);

            // Sort to read off quartiles; all values are finite, so the
            // comparison cannot fail.
            let mut sorted_values = valid_values.clone();
            sorted_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));

            // Approximate quartiles by index position (no interpolation).
            let q1_idx = sorted_values.len() / 4;
            let q3_idx = 3 * sorted_values.len() / 4;
            let q1 = sorted_values[q1_idx];
            let q3 = sorted_values[q3_idx];
            let iqr = q3 - q1;

            // Tukey fences at 1.5 * IQR.
            let lower_bound = q1 - 1.5 * iqr;
            let upper_bound = q3 + 1.5 * iqr;

            let outlier_count = valid_values
                .iter()
                .filter(|&&x| x < lower_bound || x > upper_bound)
                .count();
            let outlier_pct = (outlier_count as Float / valid_values.len() as Float) * 100.0;
            outlier_percentages.push(outlier_pct);

            // Sorted values make min/max direct index reads.
            let min_val = sorted_values[0];
            let max_val = sorted_values[sorted_values.len() - 1];
            ranges.push((min_val, max_val));

            let dist_type = self.classify_distribution(
                feature_skewness,
                feature_kurtosis,
                outlier_pct,
                &valid_values,
            );
            distribution_types.push(dist_type);
        }

        let correlation_strength = self.estimate_correlation_strength(x)?;

        // Quality folds average missingness and outlier rates into [0, 1].
        let avg_missing = missing_percentages.iter().sum::<Float>() / n_features as Float;
        let avg_outliers = outlier_percentages.iter().sum::<Float>() / n_features as Float;
        let quality_score = (100.0 - avg_missing - avg_outliers).max(0.0) / 100.0;

        let optimal_batch_size = self.estimate_optimal_batch_size(n_samples, n_features);

        Ok(DataCharacteristics {
            shape: (n_samples, n_features),
            distribution_types,
            skewness,
            kurtosis,
            outlier_percentages,
            missing_percentages,
            ranges,
            correlation_strength,
            quality_score,
            optimal_batch_size,
        })
    }
418
419 fn classify_distribution(
421 &self,
422 skewness: Float,
423 kurtosis: Float,
424 outlier_pct: Float,
425 values: &[Float],
426 ) -> DistributionType {
427 let zero_count = values.iter().filter(|&&x| x.abs() < 1e-10).count();
429 let sparsity = zero_count as Float / values.len() as Float;
430
431 if sparsity > 0.5 {
432 return DistributionType::Sparse;
433 }
434
435 if skewness.abs() < 0.5 && kurtosis.abs() < 1.0 && outlier_pct < 5.0 {
437 return DistributionType::Normal;
438 }
439
440 if skewness.abs() > 1.0 {
442 return DistributionType::Skewed;
443 }
444
445 if kurtosis > 2.0 || outlier_pct > 10.0 {
447 return DistributionType::HeavyTailed;
448 }
449
450 if kurtosis < -1.0 && skewness.abs() < 0.5 {
452 return DistributionType::Uniform;
453 }
454
455 if kurtosis < -1.5 && outlier_pct > 5.0 {
457 return DistributionType::Multimodal;
458 }
459
460 DistributionType::Unknown
461 }
462
    /// Estimates overall feature correlation as the mean absolute Pearson
    /// correlation over a subsampled set of feature pairs (roughly 100 at
    /// most) to keep the cost bounded for wide matrices.
    fn estimate_correlation_strength(&self, x: &Array2<Float>) -> Result<Float> {
        let (_n_samples, n_features) = x.dim();

        // A single feature has no pairwise correlation.
        if n_features < 2 {
            return Ok(0.0);
        }

        let mut correlation_sum = 0.0;
        let mut correlation_count = 0;

        // Evaluate every `step`-th pair so that roughly `max_pairs` of the
        // n*(n-1)/2 total pairs are sampled. Since max_pairs <= total pairs,
        // integer division yields step >= 1.
        let max_pairs = 100.min(n_features * (n_features - 1) / 2);
        let step = (n_features * (n_features - 1) / 2).max(1) / max_pairs.max(1);

        let mut pair_count = 0;
        for i in 0..n_features {
            for j in (i + 1)..n_features {
                if pair_count % step == 0 {
                    let col_i = x.column(i);
                    let col_j = x.column(j);

                    // Pairs whose correlation cannot be computed are skipped.
                    if let Ok(corr) = self.calculate_correlation(&col_i, &col_j) {
                        correlation_sum += corr.abs();
                        correlation_count += 1;
                    }
                }
                pair_count += 1;
            }
        }

        Ok(if correlation_count > 0 {
            correlation_sum / correlation_count as Float
        } else {
            0.0
        })
    }
501
502 fn calculate_correlation(
504 &self,
505 x: &scirs2_core::ndarray::ArrayView1<Float>,
506 y: &scirs2_core::ndarray::ArrayView1<Float>,
507 ) -> Result<Float> {
508 let pairs: Vec<(Float, Float)> = x
509 .iter()
510 .zip(y.iter())
511 .filter(|(&a, &b)| a.is_finite() && b.is_finite())
512 .map(|(&a, &b)| (a, b))
513 .collect();
514
515 if pairs.len() < 3 {
516 return Ok(0.0);
517 }
518
519 let mean_x = pairs.iter().map(|(x, _)| x).sum::<Float>() / pairs.len() as Float;
520 let mean_y = pairs.iter().map(|(_, y)| y).sum::<Float>() / pairs.len() as Float;
521
522 let mut sum_xy = 0.0;
523 let mut sum_x2 = 0.0;
524 let mut sum_y2 = 0.0;
525
526 for (x, y) in pairs {
527 let dx = x - mean_x;
528 let dy = y - mean_y;
529 sum_xy += dx * dy;
530 sum_x2 += dx * dx;
531 sum_y2 += dy * dy;
532 }
533
534 let denominator = (sum_x2 * sum_y2).sqrt();
535 if denominator > 1e-10 {
536 Ok(sum_xy / denominator)
537 } else {
538 Ok(0.0)
539 }
540 }
541
542 fn estimate_optimal_batch_size(&self, n_samples: usize, n_features: usize) -> usize {
544 let data_size = n_samples * n_features * std::mem::size_of::<Float>();
546 let target_memory = 100_000_000; let optimal_size = if data_size <= target_memory {
549 n_samples } else {
551 (target_memory / (n_features * std::mem::size_of::<Float>()))
552 .max(1000)
553 .min(n_samples)
554 };
555
556 optimal_size
557 }
558
559 fn optimize_parameters(
561 &self,
562 _x: &Array2<Float>,
563 characteristics: &DataCharacteristics,
564 ) -> Result<HashMap<String, Float>> {
565 let mut optimal_params = HashMap::new();
566
567 let scaling_method = self.select_optimal_scaling_method(characteristics);
569 optimal_params.insert("scaling_method".to_string(), scaling_method);
570
571 let outlier_threshold = self.select_optimal_outlier_threshold(characteristics);
573 optimal_params.insert("outlier_threshold".to_string(), outlier_threshold);
574
575 let imputation_strategy = self.select_optimal_imputation_strategy(characteristics);
577 optimal_params.insert("imputation_strategy".to_string(), imputation_strategy);
578
579 let (q_low, q_high) = self.select_optimal_quantile_range(characteristics);
581 optimal_params.insert("quantile_range_low".to_string(), q_low);
582 optimal_params.insert("quantile_range_high".to_string(), q_high);
583
584 let contamination_rate = self.select_optimal_contamination_rate(characteristics);
586 optimal_params.insert("contamination_rate".to_string(), contamination_rate);
587
588 Ok(optimal_params)
589 }
590
    /// Picks a scaling method for the data, encoded as a `Float`:
    /// 0.0 = standard, 1.0 = minmax, 2.0 = robust (decoded back to a
    /// string in `recommend_parameters`).
    fn select_optimal_scaling_method(&self, characteristics: &DataCharacteristics) -> Float {
        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
            / characteristics.outlier_percentages.len() as Float;
        let avg_skewness = characteristics
            .skewness
            .iter()
            .map(|x| x.abs())
            .sum::<Float>()
            / characteristics.skewness.len() as Float;

        match self.config.strategy {
            AdaptationStrategy::Conservative => {
                // Switch to robust scaling on moderate evidence of outliers/skew.
                if avg_outlier_pct > 10.0 || avg_skewness > 1.0 {
                    2.0 // robust
                } else {
                    0.0 // standard
                }
            }
            AdaptationStrategy::Balanced => {
                if avg_outlier_pct > 15.0 {
                    2.0 // robust
                } else if avg_skewness > 2.0 {
                    1.0 // minmax
                } else {
                    0.0 // standard
                }
            }
            AdaptationStrategy::Aggressive => {
                // Only very heavy contamination justifies the costlier method.
                if avg_outlier_pct > 20.0 {
                    2.0 // robust
                } else {
                    0.0 // standard
                }
            }
            AdaptationStrategy::Custom => 0.0, // neutral default: standard
        }
    }
629
630 fn select_optimal_outlier_threshold(&self, characteristics: &DataCharacteristics) -> Float {
632 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
633 / characteristics.outlier_percentages.len() as Float;
634
635 match self.config.strategy {
636 AdaptationStrategy::Conservative => {
637 if avg_outlier_pct > 20.0 {
638 3.5
639 } else {
640 3.0
641 }
642 }
643 AdaptationStrategy::Balanced => {
644 if avg_outlier_pct > 15.0 {
645 2.5
646 } else {
647 2.0
648 }
649 }
650 AdaptationStrategy::Aggressive => {
651 if avg_outlier_pct > 10.0 {
652 2.0
653 } else {
654 1.5
655 }
656 }
657 AdaptationStrategy::Custom => 2.5, }
659 }
660
661 fn select_optimal_imputation_strategy(&self, characteristics: &DataCharacteristics) -> Float {
663 let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
664 / characteristics.missing_percentages.len() as Float;
665 let has_skewed_features = characteristics.skewness.iter().any(|&s| s.abs() > 1.0);
666
667 if avg_missing_pct > 20.0 {
668 2.0 } else if has_skewed_features {
670 1.0 } else {
672 0.0 }
674 }
675
676 fn select_optimal_quantile_range(
678 &self,
679 characteristics: &DataCharacteristics,
680 ) -> (Float, Float) {
681 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
682 / characteristics.outlier_percentages.len() as Float;
683
684 match self.config.strategy {
685 AdaptationStrategy::Conservative => {
686 if avg_outlier_pct > 15.0 {
687 (10.0, 90.0)
688 } else {
689 (25.0, 75.0)
690 }
691 }
692 AdaptationStrategy::Balanced => {
693 if avg_outlier_pct > 10.0 {
694 (5.0, 95.0)
695 } else {
696 (25.0, 75.0)
697 }
698 }
699 AdaptationStrategy::Aggressive => {
700 (25.0, 75.0) }
702 AdaptationStrategy::Custom => (25.0, 75.0),
703 }
704 }
705
706 fn select_optimal_contamination_rate(&self, characteristics: &DataCharacteristics) -> Float {
708 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
709 / characteristics.outlier_percentages.len() as Float;
710
711 (avg_outlier_pct / 100.0 * 1.2).min(0.5).max(0.01)
713 }
714
715 fn evaluate_parameter_space(
717 &self,
718 x: &Array2<Float>,
719 characteristics: &DataCharacteristics,
720 ) -> Result<Vec<ParameterEvaluation>> {
721 let mut evaluations = Vec::new();
722
723 let scaling_methods = vec![0.0, 1.0, 2.0]; let outlier_thresholds = vec![1.5, 2.0, 2.5, 3.0, 3.5];
726 let contamination_rates = vec![0.05, 0.1, 0.15, 0.2];
727
728 let max_evaluations = 20; let mut evaluation_count = 0;
731
732 for &scaling_method in &scaling_methods {
733 for &threshold in &outlier_thresholds {
734 for &contamination in &contamination_rates {
735 if evaluation_count >= max_evaluations {
736 break;
737 }
738
739 let mut params = HashMap::new();
740 params.insert("scaling_method".to_string(), scaling_method);
741 params.insert("outlier_threshold".to_string(), threshold);
742 params.insert("contamination_rate".to_string(), contamination);
743
744 let evaluation = self.evaluate_parameters(¶ms, x, characteristics)?;
745 evaluations.push(evaluation);
746 evaluation_count += 1;
747 }
748 }
749 }
750
751 evaluations.sort_by(|a, b| {
753 b.score
754 .partial_cmp(&a.score)
755 .expect("operation should succeed")
756 });
757
758 Ok(evaluations)
759 }
760
761 fn evaluate_parameters(
763 &self,
764 params: &HashMap<String, Float>,
765 _x: &Array2<Float>,
766 characteristics: &DataCharacteristics,
767 ) -> Result<ParameterEvaluation> {
768 let start_time = std::time::Instant::now();
769
770 let robustness_score = self.compute_robustness_score(params, characteristics);
772 let efficiency_score = self.compute_efficiency_score(params, characteristics);
773 let quality_score = self.compute_quality_score(params, characteristics);
774
775 let overall_score = match self.config.strategy {
777 AdaptationStrategy::Conservative => {
778 robustness_score * 0.6 + quality_score * 0.3 + efficiency_score * 0.1
779 }
780 AdaptationStrategy::Balanced => {
781 robustness_score * 0.4 + quality_score * 0.4 + efficiency_score * 0.2
782 }
783 AdaptationStrategy::Aggressive => {
784 robustness_score * 0.2 + quality_score * 0.3 + efficiency_score * 0.5
785 }
786 AdaptationStrategy::Custom => {
787 robustness_score * 0.33 + quality_score * 0.33 + efficiency_score * 0.34
788 }
789 };
790
791 let evaluation_time = start_time.elapsed().as_secs_f64() as Float;
792
793 Ok(ParameterEvaluation {
794 parameters: params.clone(),
795 score: overall_score,
796 robustness_score,
797 efficiency_score,
798 quality_score,
799 evaluation_time,
800 })
801 }
802
803 fn compute_robustness_score(
805 &self,
806 params: &HashMap<String, Float>,
807 characteristics: &DataCharacteristics,
808 ) -> Float {
809 let scaling_method = params.get("scaling_method").unwrap_or(&0.0);
810 let outlier_threshold = params.get("outlier_threshold").unwrap_or(&2.5);
811
812 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
813 / characteristics.outlier_percentages.len() as Float;
814
815 let mut score: Float = 0.0;
816
817 if avg_outlier_pct > 10.0 && *scaling_method == 2.0 {
819 score += 0.4;
820 }
821
822 if avg_outlier_pct > 15.0 && *outlier_threshold <= 2.5 {
824 score += 0.3;
825 } else if avg_outlier_pct <= 5.0 && *outlier_threshold >= 3.0 {
826 score += 0.3;
827 }
828
829 let avg_skewness = characteristics
831 .skewness
832 .iter()
833 .map(|x| x.abs())
834 .sum::<Float>()
835 / characteristics.skewness.len() as Float;
836 if avg_skewness > 1.0 && *scaling_method != 0.0 {
837 score += 0.3;
838 }
839
840 score.min(1.0 as Float)
841 }
842
    /// Heuristic efficiency score in (0, 1]: cheaper scaling methods score
    /// higher.
    ///
    /// NOTE(review): the `score *= 1.2` boost for large datasets is applied
    /// before the `min(1.0)` cap, so it only offsets the penalty of the
    /// more expensive methods — presumably intentional, but worth confirming.
    fn compute_efficiency_score(
        &self,
        params: &HashMap<String, Float>,
        characteristics: &DataCharacteristics,
    ) -> Float {
        let scaling_method = params.get("scaling_method").unwrap_or(&0.0);
        let (n_samples, n_features) = characteristics.shape;

        // Base cost by method encoding: 0.0 standard (cheapest), 1.0 minmax,
        // anything else (robust) the most expensive.
        let mut score: Float = if *scaling_method == 0.0 {
            1.0
        } else if *scaling_method == 1.0 {
            0.8
        } else {
            0.6
        };

        // Large matrices receive a slight boost (capped to 1.0 below).
        let data_size_factor = (n_samples * n_features) as Float;
        if data_size_factor > 1_000_000.0 {
            score *= 1.2;
        }

        score.min(1.0 as Float)
    }
869
870 fn compute_quality_score(
872 &self,
873 params: &HashMap<String, Float>,
874 characteristics: &DataCharacteristics,
875 ) -> Float {
876 let mut score = characteristics.quality_score; let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
880 / characteristics.missing_percentages.len() as Float;
881
882 if avg_missing_pct > 10.0 {
884 score *= 0.9; }
886
887 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
889 / characteristics.outlier_percentages.len() as Float;
890
891 let outlier_threshold = params.get("outlier_threshold").unwrap_or(&2.5);
892 if avg_outlier_pct > 10.0 && *outlier_threshold <= 2.5 {
893 score *= 1.1; }
895
896 score.min(1.0 as Float)
897 }
898}
899
900impl AdaptiveParameterSelector<Trained> {
    /// Summary statistics computed during `fit`.
    pub fn data_characteristics(&self) -> Option<&DataCharacteristics> {
        self.data_characteristics_.as_ref()
    }

    /// Best parameter values (encoded as `Float`s) found during `fit`.
    pub fn optimal_parameters(&self) -> Option<&HashMap<String, Float>> {
        self.optimal_parameters_.as_ref()
    }

    /// All evaluated parameter candidates, sorted best-first.
    pub fn parameter_history(&self) -> Option<&Vec<ParameterEvaluation>> {
        self.parameter_history_.as_ref()
    }
915
    /// Decodes the optimal `Float`-encoded parameters into concrete,
    /// human-readable preprocessing recommendations.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` if the selector lacks fitted
    /// characteristics or parameters.
    pub fn recommend_parameters(&self) -> Result<ParameterRecommendations> {
        let characteristics = self.data_characteristics_.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("No data characteristics available".to_string())
        })?;

        let optimal_params = self.optimal_parameters_.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("No optimal parameters available".to_string())
        })?;

        // Decode scaling method: 0 = standard, 1 = minmax, 2 = robust.
        let scaling_method = optimal_params.get("scaling_method").unwrap_or(&0.0);
        let scaling = ScalingParameters {
            method: match *scaling_method as i32 {
                0 => "standard".to_string(),
                1 => "minmax".to_string(),
                2 => "robust".to_string(),
                _ => "standard".to_string(),
            },
            outlier_threshold: *optimal_params.get("outlier_threshold").unwrap_or(&2.5),
            quantile_range: (
                *optimal_params.get("quantile_range_low").unwrap_or(&25.0),
                *optimal_params.get("quantile_range_high").unwrap_or(&75.0),
            ),
            with_centering: true,
            with_scaling: true,
        };

        // Decode imputation strategy: 0 = mean, 1 = median, 2 = KNN.
        let imputation_strategy = optimal_params.get("imputation_strategy").unwrap_or(&0.0);
        let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
            / characteristics.missing_percentages.len() as Float;

        let imputation = ImputationParameters {
            strategy: match *imputation_strategy as i32 {
                0 => "mean".to_string(),
                1 => "median".to_string(),
                2 => "knn".to_string(),
                _ => "mean".to_string(),
            },
            // KNN-specific knobs are only populated when KNN is selected.
            n_neighbors: if *imputation_strategy == 2.0 {
                Some(5)
            } else {
                None
            },
            outlier_aware: avg_missing_pct > 10.0,
            max_iterations: if *imputation_strategy == 2.0 {
                Some(10)
            } else {
                None
            },
        };

        let contamination_rate = *optimal_params.get("contamination_rate").unwrap_or(&0.1);
        let outlier_detection = OutlierDetectionParameters {
            method: "isolation_forest".to_string(),
            contamination: contamination_rate,
            threshold: *optimal_params.get("outlier_threshold").unwrap_or(&2.5),
            ensemble_size: Some(100),
        };

        // Pick a transformation from the average absolute skewness.
        let avg_skewness = characteristics
            .skewness
            .iter()
            .map(|x| x.abs())
            .sum::<Float>()
            / characteristics.skewness.len() as Float;

        let transformation = TransformationParameters {
            method: if avg_skewness > 1.5 {
                "log1p".to_string()
            } else if avg_skewness > 1.0 {
                "box_cox".to_string()
            } else {
                "none".to_string()
            },
            handle_negatives: true,
            lambda: None,
            n_quantiles: Some(1000),
        };

        // Confidence blends quality (50%), completeness (30%) and
        // cleanliness (20%); clamped to [0, 1] below.
        let confidence = characteristics.quality_score * 0.5
            + (1.0
                - (characteristics.missing_percentages.iter().sum::<Float>()
                    / characteristics.missing_percentages.len() as Float
                    / 100.0))
                * 0.3
            + (1.0
                - (characteristics.outlier_percentages.iter().sum::<Float>()
                    / characteristics.outlier_percentages.len() as Float
                    / 100.0))
                * 0.2;

        Ok(ParameterRecommendations {
            scaling,
            imputation,
            outlier_detection,
            transformation,
            confidence: confidence.min(1.0).max(0.0),
        })
    }
1020
    /// Renders a human-readable, multi-section report covering the data
    /// characteristics, the parameter recommendations, and the active
    /// configuration.
    ///
    /// # Errors
    /// Fails when the selector lacks fitted characteristics, or when
    /// `recommend_parameters` fails.
    pub fn adaptation_report(&self) -> Result<String> {
        let characteristics = self.data_characteristics_.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("No data characteristics available".to_string())
        })?;

        let recommendations = self.recommend_parameters()?;

        let mut report = String::new();

        report.push_str("=== Adaptive Parameter Selection Report ===\n\n");

        // --- Data characteristics section ---
        report.push_str("=== Data Characteristics ===\n");
        report.push_str(&format!("Data shape: {:?}\n", characteristics.shape));
        report.push_str(&format!(
            "Overall quality score: {:.3}\n",
            characteristics.quality_score
        ));
        report.push_str(&format!(
            "Correlation strength: {:.3}\n",
            characteristics.correlation_strength
        ));
        report.push_str(&format!(
            "Optimal batch size: {}\n",
            characteristics.optimal_batch_size
        ));

        // Dataset-wide averages of the per-feature statistics.
        let avg_missing = characteristics.missing_percentages.iter().sum::<Float>()
            / characteristics.missing_percentages.len() as Float;
        let avg_outliers = characteristics.outlier_percentages.iter().sum::<Float>()
            / characteristics.outlier_percentages.len() as Float;
        let avg_skewness = characteristics
            .skewness
            .iter()
            .map(|x| x.abs())
            .sum::<Float>()
            / characteristics.skewness.len() as Float;

        report.push_str(&format!("Average missing values: {:.1}%\n", avg_missing));
        report.push_str(&format!("Average outlier rate: {:.1}%\n", avg_outliers));
        report.push_str(&format!("Average absolute skewness: {:.3}\n", avg_skewness));
        report.push_str("\n");

        // --- Recommendations section ---
        report.push_str("=== Parameter Recommendations ===\n");
        report.push_str(&format!(
            "Confidence: {:.1}%\n\n",
            recommendations.confidence * 100.0
        ));

        report.push_str("Scaling:\n");
        report.push_str(&format!("  Method: {}\n", recommendations.scaling.method));
        report.push_str(&format!(
            "  Outlier threshold: {:.2}\n",
            recommendations.scaling.outlier_threshold
        ));
        report.push_str(&format!(
            "  Quantile range: ({:.1}%, {:.1}%)\n",
            recommendations.scaling.quantile_range.0, recommendations.scaling.quantile_range.1
        ));
        report.push_str("\n");

        report.push_str("Imputation:\n");
        report.push_str(&format!(
            "  Strategy: {}\n",
            recommendations.imputation.strategy
        ));
        // K is only present for KNN imputation.
        if let Some(k) = recommendations.imputation.n_neighbors {
            report.push_str(&format!("  K-neighbors: {}\n", k));
        }
        report.push_str(&format!(
            "  Outlier-aware: {}\n",
            recommendations.imputation.outlier_aware
        ));
        report.push_str("\n");

        report.push_str("Outlier Detection:\n");
        report.push_str(&format!(
            "  Method: {}\n",
            recommendations.outlier_detection.method
        ));
        report.push_str(&format!(
            "  Contamination: {:.3}\n",
            recommendations.outlier_detection.contamination
        ));
        report.push_str(&format!(
            "  Threshold: {:.2}\n",
            recommendations.outlier_detection.threshold
        ));
        report.push_str("\n");

        report.push_str("Transformation:\n");
        report.push_str(&format!(
            "  Method: {}\n",
            recommendations.transformation.method
        ));
        report.push_str(&format!(
            "  Handle negatives: {}\n",
            recommendations.transformation.handle_negatives
        ));
        report.push_str("\n");

        // --- Configuration section ---
        report.push_str("=== Configuration ===\n");
        report.push_str(&format!("Strategy: {:?}\n", self.config.strategy));
        report.push_str(&format!(
            "Cross-validation: {} ({} folds)\n",
            self.config.use_cross_validation, self.config.cv_folds
        ));
        report.push_str(&format!("Parallel processing: {}\n", self.config.parallel));
        // The time budget line is only emitted when a budget is configured.
        if let Some(budget) = self.config.time_budget {
            report.push_str(&format!("Time budget: {:.1}s\n", budget));
        }

        Ok(report)
    }
1138
    /// Produces human-readable observations about the fitted data (high
    /// missingness, outliers, skew, correlation, size). Returns a single
    /// "within normal ranges" note when nothing stands out — including
    /// when the selector has no fitted characteristics at all.
    pub fn get_insights(&self) -> Vec<String> {
        let mut insights = Vec::new();

        if let Some(characteristics) = &self.data_characteristics_ {
            // Dataset-wide averages of the per-feature statistics.
            let avg_missing = characteristics.missing_percentages.iter().sum::<Float>()
                / characteristics.missing_percentages.len() as Float;
            let avg_outliers = characteristics.outlier_percentages.iter().sum::<Float>()
                / characteristics.outlier_percentages.len() as Float;
            let avg_skewness = characteristics
                .skewness
                .iter()
                .map(|x| x.abs())
                .sum::<Float>()
                / characteristics.skewness.len() as Float;

            if avg_missing > 20.0 {
                insights.push("High missing value rate detected - consider advanced imputation methods like KNN or iterative imputation".to_string());
            }

            if avg_outliers > 15.0 {
                insights.push(
                    "High outlier rate detected - robust preprocessing methods are recommended"
                        .to_string(),
                );
            }

            if avg_skewness > 2.0 {
                insights.push(
                    "Highly skewed data detected - consider log or Box-Cox transformations"
                        .to_string(),
                );
            }

            if characteristics.correlation_strength > 0.7 {
                insights.push(
                    "High feature correlation detected - consider dimensionality reduction"
                        .to_string(),
                );
            }

            if characteristics.quality_score < 0.5 {
                insights.push(
                    "Low data quality detected - comprehensive preprocessing pipeline recommended"
                        .to_string(),
                );
            }

            if characteristics.shape.0 > 1_000_000 {
                insights.push(
                    "Large dataset detected - consider streaming or batch processing approaches"
                        .to_string(),
                );
            }

            // A batch size smaller than the sample count means the data
            // exceeded the memory target during analysis.
            if characteristics.optimal_batch_size < characteristics.shape.0 {
                insights.push(format!(
                    "Consider batch processing with batch size: {}",
                    characteristics.optimal_batch_size
                ));
            }
        }

        if insights.is_empty() {
            insights.push("Data characteristics are within normal ranges - standard preprocessing should be sufficient".to_string());
        }

        insights
    }
1208}
1209
1210impl Default for AdaptiveParameterSelector<Untrained> {
1211 fn default() -> Self {
1212 Self::new()
1213 }
1214}
1215
1216#[allow(non_snake_case)]
1217#[cfg(test)]
1218mod tests {
1219 use super::*;
1220 use approx::assert_relative_eq;
1221 use scirs2_core::ndarray::Array2;
1222
1223 #[test]
1224 fn test_adaptive_parameter_selector_creation() {
1225 let selector = AdaptiveParameterSelector::new();
1226 assert_eq!(
1227 selector.config.strategy as u8,
1228 AdaptationStrategy::Balanced as u8
1229 );
1230 assert!(selector.config.use_cross_validation);
1231 assert_eq!(selector.config.cv_folds, 5);
1232 }
1233
1234 #[test]
1235 fn test_adaptive_strategies() {
1236 let conservative = AdaptiveParameterSelector::conservative();
1237 assert_eq!(
1238 conservative.config.strategy as u8,
1239 AdaptationStrategy::Conservative as u8
1240 );
1241
1242 let aggressive = AdaptiveParameterSelector::aggressive();
1243 assert_eq!(
1244 aggressive.config.strategy as u8,
1245 AdaptationStrategy::Aggressive as u8
1246 );
1247 }
1248
    #[test]
    fn test_data_characteristics_analysis() {
        // 10 samples x 3 features; one row (100, 1000, 10000) is a deliberate
        // outlier relative to each column's otherwise linear progression.
        let data = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 10.0, 100.0, 2.0, 20.0, 200.0, 3.0, 30.0, 300.0, 4.0, 40.0, 400.0, 5.0,
                50.0, 500.0, 6.0, 60.0, 600.0, 7.0, 70.0, 700.0, 8.0, 80.0, 800.0, 100.0,
                1000.0, 10000.0, 9.0, 90.0, 900.0,
            ],
        )
        .expect("operation should succeed");

        let selector = AdaptiveParameterSelector::balanced();
        let fitted = selector
            .fit(&data, &())
            .expect("model fitting should succeed");

        // After fitting, per-feature statistics exist for all 3 columns and
        // the quality score stays within its documented [0, 1] range.
        let characteristics = fitted
            .data_characteristics()
            .expect("operation should succeed");
        assert_eq!(characteristics.shape, (10, 3));
        assert_eq!(characteristics.distribution_types.len(), 3);
        assert_eq!(characteristics.skewness.len(), 3);
        assert!(characteristics.quality_score >= 0.0 && characteristics.quality_score <= 1.0);
    }
1275
1276 #[test]
1277 fn test_parameter_recommendations() {
1278 let data = Array2::from_shape_vec(
1279 (8, 2),
1280 vec![
1281 1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 100.0,
1282 1000.0, 7.0, 70.0, 8.0, 80.0,
1284 ],
1285 )
1286 .expect("operation should succeed");
1287
1288 let selector = AdaptiveParameterSelector::balanced();
1289 let fitted = selector
1290 .fit(&data, &())
1291 .expect("model fitting should succeed");
1292
1293 let recommendations = fitted
1294 .recommend_parameters()
1295 .expect("operation should succeed");
1296 assert!(recommendations.confidence >= 0.0 && recommendations.confidence <= 1.0);
1297 assert!(!recommendations.scaling.method.is_empty());
1298 assert!(!recommendations.imputation.strategy.is_empty());
1299 }
1300
1301 #[test]
1302 fn test_parameter_optimization() {
1303 let data = Array2::from_shape_vec(
1304 (6, 2),
1305 vec![
1306 1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 100.0,
1307 1000.0, ],
1309 )
1310 .expect("operation should succeed");
1311
1312 let selector = AdaptiveParameterSelector::aggressive();
1313 let fitted = selector
1314 .fit(&data, &())
1315 .expect("model fitting should succeed");
1316
1317 let optimal_params = fitted
1318 .optimal_parameters()
1319 .expect("operation should succeed");
1320 assert!(optimal_params.contains_key("scaling_method"));
1321 assert!(optimal_params.contains_key("outlier_threshold"));
1322 assert!(optimal_params.contains_key("contamination_rate"));
1323 }
1324
1325 #[test]
1326 fn test_distribution_classification() {
1327 let data = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0])
1328 .expect("shape and data length should match");
1329
1330 let selector = AdaptiveParameterSelector::new();
1331 let fitted = selector
1332 .fit(&data, &())
1333 .expect("model fitting should succeed");
1334
1335 let characteristics = fitted
1336 .data_characteristics()
1337 .expect("operation should succeed");
1338 assert!(matches!(
1340 characteristics.distribution_types[0],
1341 DistributionType::Normal | DistributionType::Uniform | DistributionType::Unknown
1342 ));
1343 }
1344
1345 #[test]
1346 fn test_missing_value_handling() {
1347 let data = Array2::from_shape_vec(
1348 (6, 2),
1349 vec![
1350 1.0,
1351 10.0,
1352 2.0,
1353 Float::NAN, 3.0,
1355 30.0,
1356 Float::NAN,
1357 40.0, 5.0,
1359 50.0,
1360 6.0,
1361 60.0,
1362 ],
1363 )
1364 .expect("operation should succeed");
1365
1366 let selector = AdaptiveParameterSelector::balanced();
1367 let fitted = selector
1368 .fit(&data, &())
1369 .expect("model fitting should succeed");
1370
1371 let characteristics = fitted
1372 .data_characteristics()
1373 .expect("operation should succeed");
1374 assert!(
1376 characteristics.missing_percentages[0] > 0.0
1377 || characteristics.missing_percentages[1] > 0.0
1378 );
1379 }
1380
1381 #[test]
1382 fn test_adaptation_report() {
1383 let data = Array2::from_shape_vec(
1384 (6, 2),
1385 vec![
1386 1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 100.0, 1000.0,
1387 ],
1388 )
1389 .expect("operation should succeed");
1390
1391 let selector = AdaptiveParameterSelector::balanced();
1392 let fitted = selector
1393 .fit(&data, &())
1394 .expect("model fitting should succeed");
1395
1396 let report = fitted
1397 .adaptation_report()
1398 .expect("operation should succeed");
1399 assert!(report.contains("Adaptive Parameter Selection Report"));
1400 assert!(report.contains("Data Characteristics"));
1401 assert!(report.contains("Parameter Recommendations"));
1402 }
1403
1404 #[test]
1405 fn test_insights_generation() {
1406 let data = Array2::from_shape_vec((4, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
1407 .expect("operation should succeed");
1408
1409 let selector = AdaptiveParameterSelector::conservative();
1410 let fitted = selector
1411 .fit(&data, &())
1412 .expect("model fitting should succeed");
1413
1414 let insights = fitted.get_insights();
1415 assert!(!insights.is_empty());
1416 }
1417
1418 #[test]
1419 fn test_configuration_options() {
1420 let selector = AdaptiveParameterSelector::new()
1421 .cross_validation(false, 3)
1422 .time_budget(30.0)
1423 .parallel(false)
1424 .tolerance(1e-3);
1425
1426 assert!(!selector.config.use_cross_validation);
1427 assert_eq!(selector.config.cv_folds, 3);
1428 assert_eq!(selector.config.time_budget, Some(30.0));
1429 assert!(!selector.config.parallel);
1430 assert_relative_eq!(selector.config.tolerance, 1e-3, epsilon = 1e-10);
1431 }
1432
1433 #[test]
1434 fn test_error_handling() {
1435 let selector = AdaptiveParameterSelector::new();
1436
1437 let empty_data =
1439 Array2::from_shape_vec((0, 0), vec![]).expect("shape and data length should match");
1440 assert!(selector.fit(&empty_data, &()).is_err());
1441 }
1442
1443 #[test]
1444 fn test_parameter_bounds() {
1445 let mut bounds = HashMap::new();
1446 bounds.insert("outlier_threshold".to_string(), (1.0, 4.0));
1447
1448 let selector = AdaptiveParameterSelector::new().parameter_bounds(bounds.clone());
1449
1450 assert_eq!(selector.config.parameter_bounds, bounds);
1451 }
1452}