1use scirs2_core::ndarray::Array2;
17use sklears_core::{
18 error::{Result, SklearsError},
19 traits::{Fit, Trained, Untrained},
20 types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
/// Per-feature and global summary statistics computed from a dataset during `fit`.
///
/// All per-feature vectors have length `shape.1` (one entry per column).
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
    /// Data shape as `(n_samples, n_features)`.
    pub shape: (usize, usize),
    /// Coarse distribution class assigned to each feature.
    pub distribution_types: Vec<DistributionType>,
    /// Per-feature skewness (standardized third moment; 0.0 for constant columns).
    pub skewness: Vec<Float>,
    /// Per-feature excess kurtosis (normal distribution maps to ~0).
    pub kurtosis: Vec<Float>,
    /// Per-feature percentage of values outside the 1.5*IQR Tukey fences.
    pub outlier_percentages: Vec<Float>,
    /// Per-feature percentage of non-finite (NaN/inf) entries, treated as missing.
    pub missing_percentages: Vec<Float>,
    /// Per-feature `(min, max)` over finite values; `(0.0, 0.0)` when all missing.
    pub ranges: Vec<(Float, Float)>,
    /// Mean absolute pairwise Pearson correlation, estimated on a sample of pairs.
    pub correlation_strength: Float,
    /// Overall 0-1 quality score derived from average missingness and outlier rate.
    pub quality_score: Float,
    /// Suggested batch size from a memory-budget heuristic (~100 MB per batch).
    pub optimal_batch_size: usize,
}
49
/// Coarse distribution class for a single feature, decided by threshold rules
/// (see `classify_distribution`); rules are checked in priority order.
#[derive(Debug, Clone, Copy)]
pub enum DistributionType {
    /// Approximately normal: |skewness| < 0.5, |excess kurtosis| < 1.0, < 5% outliers.
    Normal,
    /// Strongly asymmetric: |skewness| > 1.0.
    Skewed,
    /// Flat-topped: excess kurtosis < -1.0 with low skewness.
    Uniform,
    /// Likely multiple modes: very low kurtosis combined with > 5% outliers.
    Multimodal,
    /// Fat tails: excess kurtosis > 2.0 or > 10% outliers.
    HeavyTailed,
    /// More than half the values are (near-)zero.
    Sparse,
    /// Fallback when no rule matches, or the column is entirely missing.
    Unknown,
}
68
/// How aggressively parameter selection trades robustness against efficiency.
///
/// The strategy controls the decision thresholds in the `select_optimal_*`
/// methods and the score weights in `evaluate_parameters`.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategy {
    /// Prefer robust choices; scoring weights robustness highest (0.6).
    Conservative,
    /// Middle ground; robustness and quality weighted equally (0.4 each).
    Balanced,
    /// Prefer fast choices; scoring weights efficiency highest (0.5).
    Aggressive,
    /// Uses fixed fallback values in the selection heuristics.
    Custom,
}
81
/// Configuration for adaptive parameter selection.
///
/// NOTE(review): several fields (`use_cross_validation`, `cv_folds`,
/// `time_budget`, `parallel`, `tolerance`, `max_iterations`,
/// `parameter_bounds`) are stored and echoed in `adaptation_report` but are
/// not consumed by the analysis code visible in this file — presumably used
/// by downstream components; confirm before relying on them.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Robustness/efficiency trade-off governing all selection heuristics.
    pub strategy: AdaptationStrategy,
    /// Whether cross-validation is requested.
    pub use_cross_validation: bool,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Optional optimization time budget in seconds.
    pub time_budget: Option<Float>,
    /// Whether parallel processing is requested.
    pub parallel: bool,
    /// Convergence tolerance.
    pub tolerance: Float,
    /// Maximum optimization iterations.
    pub max_iterations: usize,
    /// Optional explicit `(low, high)` search bounds per parameter name.
    pub parameter_bounds: HashMap<String, (Float, Float)>,
}
102
103impl Default for AdaptiveConfig {
104 fn default() -> Self {
105 Self {
106 strategy: AdaptationStrategy::Balanced,
107 use_cross_validation: true,
108 cv_folds: 5,
109 time_budget: Some(60.0), parallel: true,
111 tolerance: 1e-4,
112 max_iterations: 100,
113 parameter_bounds: HashMap::new(),
114 }
115 }
116}
117
/// Typestate-based selector that analyzes a dataset and recommends
/// preprocessing parameters.
///
/// `State` is `Untrained` until `fit` succeeds, then `Trained`; the
/// trailing-underscore fields are `None` until populated by `fit`.
#[derive(Debug, Clone)]
pub struct AdaptiveParameterSelector<State = Untrained> {
    // Selection configuration (strategy, budgets, bounds).
    config: AdaptiveConfig,
    // Zero-sized marker carrying the Untrained/Trained typestate.
    state: PhantomData<State>,
    // Dataset statistics computed during fit.
    data_characteristics_: Option<DataCharacteristics>,
    // Best parameter values found during fit, keyed by parameter name.
    optimal_parameters_: Option<HashMap<String, Float>>,
    // All evaluated parameter combinations, sorted best-first.
    parameter_history_: Option<Vec<ParameterEvaluation>>,
}
128
/// Score record for one evaluated parameter combination.
#[derive(Debug, Clone)]
pub struct ParameterEvaluation {
    /// The parameter values that were evaluated, keyed by name.
    pub parameters: HashMap<String, Float>,
    /// Overall score: strategy-weighted blend of the three component scores.
    pub score: Float,
    /// Robustness component (0-1).
    pub robustness_score: Float,
    /// Efficiency component (0-1).
    pub efficiency_score: Float,
    /// Quality component (0-1).
    pub quality_score: Float,
    /// Wall-clock time spent evaluating, in seconds.
    pub evaluation_time: Float,
}
139
/// Full set of preprocessing recommendations produced by a trained selector.
#[derive(Debug, Clone)]
pub struct ParameterRecommendations {
    /// Recommended scaling configuration.
    pub scaling: ScalingParameters,
    /// Recommended imputation configuration.
    pub imputation: ImputationParameters,
    /// Recommended outlier-detection configuration.
    pub outlier_detection: OutlierDetectionParameters,
    /// Recommended transformation configuration.
    pub transformation: TransformationParameters,
    /// Confidence in the recommendations, clamped to [0, 1].
    pub confidence: Float,
}
154
155#[derive(Debug, Clone)]
157pub struct ScalingParameters {
158 pub method: String, pub outlier_threshold: Float,
160 pub quantile_range: (Float, Float),
161 pub with_centering: bool,
162 pub with_scaling: bool,
163}
164
/// Recommended imputation configuration.
#[derive(Debug, Clone)]
pub struct ImputationParameters {
    /// Imputation strategy name: "mean", "median", or "knn".
    pub strategy: String,
    /// Neighbor count; `Some` only for the "knn" strategy.
    pub n_neighbors: Option<usize>,
    /// Whether imputation should account for outliers.
    pub outlier_aware: bool,
    /// Iteration cap; `Some` only for iterative strategies such as "knn".
    pub max_iterations: Option<usize>,
}
173
174#[derive(Debug, Clone)]
176pub struct OutlierDetectionParameters {
177 pub method: String, pub contamination: Float,
179 pub threshold: Float,
180 pub ensemble_size: Option<usize>,
181}
182
183#[derive(Debug, Clone)]
185pub struct TransformationParameters {
186 pub method: String, pub handle_negatives: bool,
188 pub lambda: Option<Float>,
189 pub n_quantiles: Option<usize>,
190}
191
192impl AdaptiveParameterSelector<Untrained> {
193 pub fn new() -> Self {
195 Self {
196 config: AdaptiveConfig::default(),
197 state: PhantomData,
198 data_characteristics_: None,
199 optimal_parameters_: None,
200 parameter_history_: None,
201 }
202 }
203
204 pub fn conservative() -> Self {
206 Self::new().strategy(AdaptationStrategy::Conservative)
207 }
208
209 pub fn balanced() -> Self {
211 Self::new().strategy(AdaptationStrategy::Balanced)
212 }
213
214 pub fn aggressive() -> Self {
216 Self::new().strategy(AdaptationStrategy::Aggressive)
217 }
218
219 pub fn strategy(mut self, strategy: AdaptationStrategy) -> Self {
221 self.config.strategy = strategy;
222 self
223 }
224
225 pub fn cross_validation(mut self, enable: bool, folds: usize) -> Self {
227 self.config.use_cross_validation = enable;
228 self.config.cv_folds = folds;
229 self
230 }
231
232 pub fn time_budget(mut self, seconds: Float) -> Self {
234 self.config.time_budget = Some(seconds);
235 self
236 }
237
238 pub fn parallel(mut self, enable: bool) -> Self {
240 self.config.parallel = enable;
241 self
242 }
243
244 pub fn tolerance(mut self, tolerance: Float) -> Self {
246 self.config.tolerance = tolerance;
247 self
248 }
249
250 pub fn parameter_bounds(mut self, bounds: HashMap<String, (Float, Float)>) -> Self {
252 self.config.parameter_bounds = bounds;
253 self
254 }
255}
256
257impl Fit<Array2<Float>, ()> for AdaptiveParameterSelector<Untrained> {
258 type Fitted = AdaptiveParameterSelector<Trained>;
259
260 fn fit(mut self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
261 let (n_samples, n_features) = x.dim();
262
263 if n_samples == 0 || n_features == 0 {
264 return Err(SklearsError::InvalidInput(
265 "Input array is empty".to_string(),
266 ));
267 }
268
269 let characteristics = self.analyze_data_characteristics(x)?;
271
272 let optimal_parameters = self.optimize_parameters(x, &characteristics)?;
274
275 let parameter_history = self.evaluate_parameter_space(x, &characteristics)?;
277
278 self.data_characteristics_ = Some(characteristics);
279 self.optimal_parameters_ = Some(optimal_parameters);
280 self.parameter_history_ = Some(parameter_history);
281
282 Ok(AdaptiveParameterSelector {
283 config: self.config,
284 state: PhantomData,
285 data_characteristics_: self.data_characteristics_,
286 optimal_parameters_: self.optimal_parameters_,
287 parameter_history_: self.parameter_history_,
288 })
289 }
290}
291
292impl AdaptiveParameterSelector<Untrained> {
293 fn analyze_data_characteristics(&self, x: &Array2<Float>) -> Result<DataCharacteristics> {
295 let (n_samples, n_features) = x.dim();
296
297 let mut distribution_types = Vec::with_capacity(n_features);
298 let mut skewness = Vec::with_capacity(n_features);
299 let mut kurtosis = Vec::with_capacity(n_features);
300 let mut outlier_percentages = Vec::with_capacity(n_features);
301 let mut missing_percentages = Vec::with_capacity(n_features);
302 let mut ranges = Vec::with_capacity(n_features);
303
304 for j in 0..n_features {
306 let column = x.column(j);
307
308 let valid_values: Vec<Float> =
310 column.iter().filter(|x| x.is_finite()).copied().collect();
311
312 let missing_pct =
313 ((n_samples - valid_values.len()) as Float / n_samples as Float) * 100.0;
314 missing_percentages.push(missing_pct);
315
316 if valid_values.is_empty() {
317 distribution_types.push(DistributionType::Unknown);
318 skewness.push(0.0);
319 kurtosis.push(0.0);
320 outlier_percentages.push(0.0);
321 ranges.push((0.0, 0.0));
322 continue;
323 }
324
325 let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
327 let variance = valid_values
328 .iter()
329 .map(|x| (x - mean).powi(2))
330 .sum::<Float>()
331 / valid_values.len() as Float;
332 let std = variance.sqrt();
333
334 let feature_skewness = if std > 0.0 {
336 valid_values
337 .iter()
338 .map(|x| ((x - mean) / std).powi(3))
339 .sum::<Float>()
340 / valid_values.len() as Float
341 } else {
342 0.0
343 };
344
345 let feature_kurtosis = if std > 0.0 {
346 valid_values
347 .iter()
348 .map(|x| ((x - mean) / std).powi(4))
349 .sum::<Float>()
350 / valid_values.len() as Float
351 - 3.0 } else {
353 0.0
354 };
355
356 skewness.push(feature_skewness);
357 kurtosis.push(feature_kurtosis);
358
359 let mut sorted_values = valid_values.clone();
361 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
362
363 let q1_idx = sorted_values.len() / 4;
364 let q3_idx = 3 * sorted_values.len() / 4;
365 let q1 = sorted_values[q1_idx];
366 let q3 = sorted_values[q3_idx];
367 let iqr = q3 - q1;
368
369 let lower_bound = q1 - 1.5 * iqr;
370 let upper_bound = q3 + 1.5 * iqr;
371
372 let outlier_count = valid_values
373 .iter()
374 .filter(|&&x| x < lower_bound || x > upper_bound)
375 .count();
376 let outlier_pct = (outlier_count as Float / valid_values.len() as Float) * 100.0;
377 outlier_percentages.push(outlier_pct);
378
379 let min_val = sorted_values[0];
381 let max_val = sorted_values[sorted_values.len() - 1];
382 ranges.push((min_val, max_val));
383
384 let dist_type = self.classify_distribution(
386 feature_skewness,
387 feature_kurtosis,
388 outlier_pct,
389 &valid_values,
390 );
391 distribution_types.push(dist_type);
392 }
393
394 let correlation_strength = self.estimate_correlation_strength(x)?;
396
397 let avg_missing = missing_percentages.iter().sum::<Float>() / n_features as Float;
399 let avg_outliers = outlier_percentages.iter().sum::<Float>() / n_features as Float;
400 let quality_score = (100.0 - avg_missing - avg_outliers).max(0.0) / 100.0;
401
402 let optimal_batch_size = self.estimate_optimal_batch_size(n_samples, n_features);
404
405 Ok(DataCharacteristics {
406 shape: (n_samples, n_features),
407 distribution_types,
408 skewness,
409 kurtosis,
410 outlier_percentages,
411 missing_percentages,
412 ranges,
413 correlation_strength,
414 quality_score,
415 optimal_batch_size,
416 })
417 }
418
419 fn classify_distribution(
421 &self,
422 skewness: Float,
423 kurtosis: Float,
424 outlier_pct: Float,
425 values: &[Float],
426 ) -> DistributionType {
427 let zero_count = values.iter().filter(|&&x| x.abs() < 1e-10).count();
429 let sparsity = zero_count as Float / values.len() as Float;
430
431 if sparsity > 0.5 {
432 return DistributionType::Sparse;
433 }
434
435 if skewness.abs() < 0.5 && kurtosis.abs() < 1.0 && outlier_pct < 5.0 {
437 return DistributionType::Normal;
438 }
439
440 if skewness.abs() > 1.0 {
442 return DistributionType::Skewed;
443 }
444
445 if kurtosis > 2.0 || outlier_pct > 10.0 {
447 return DistributionType::HeavyTailed;
448 }
449
450 if kurtosis < -1.0 && skewness.abs() < 0.5 {
452 return DistributionType::Uniform;
453 }
454
455 if kurtosis < -1.5 && outlier_pct > 5.0 {
457 return DistributionType::Multimodal;
458 }
459
460 DistributionType::Unknown
461 }
462
463 fn estimate_correlation_strength(&self, x: &Array2<Float>) -> Result<Float> {
465 let (_n_samples, n_features) = x.dim();
466
467 if n_features < 2 {
468 return Ok(0.0);
469 }
470
471 let mut correlation_sum = 0.0;
472 let mut correlation_count = 0;
473
474 let max_pairs = 100.min(n_features * (n_features - 1) / 2);
476 let step = (n_features * (n_features - 1) / 2).max(1) / max_pairs.max(1);
477
478 let mut pair_count = 0;
479 for i in 0..n_features {
480 for j in (i + 1)..n_features {
481 if pair_count % step == 0 {
482 let col_i = x.column(i);
483 let col_j = x.column(j);
484
485 if let Ok(corr) = self.calculate_correlation(&col_i, &col_j) {
487 correlation_sum += corr.abs();
488 correlation_count += 1;
489 }
490 }
491 pair_count += 1;
492 }
493 }
494
495 Ok(if correlation_count > 0 {
496 correlation_sum / correlation_count as Float
497 } else {
498 0.0
499 })
500 }
501
502 fn calculate_correlation(
504 &self,
505 x: &scirs2_core::ndarray::ArrayView1<Float>,
506 y: &scirs2_core::ndarray::ArrayView1<Float>,
507 ) -> Result<Float> {
508 let pairs: Vec<(Float, Float)> = x
509 .iter()
510 .zip(y.iter())
511 .filter(|(&a, &b)| a.is_finite() && b.is_finite())
512 .map(|(&a, &b)| (a, b))
513 .collect();
514
515 if pairs.len() < 3 {
516 return Ok(0.0);
517 }
518
519 let mean_x = pairs.iter().map(|(x, _)| x).sum::<Float>() / pairs.len() as Float;
520 let mean_y = pairs.iter().map(|(_, y)| y).sum::<Float>() / pairs.len() as Float;
521
522 let mut sum_xy = 0.0;
523 let mut sum_x2 = 0.0;
524 let mut sum_y2 = 0.0;
525
526 for (x, y) in pairs {
527 let dx = x - mean_x;
528 let dy = y - mean_y;
529 sum_xy += dx * dy;
530 sum_x2 += dx * dx;
531 sum_y2 += dy * dy;
532 }
533
534 let denominator = (sum_x2 * sum_y2).sqrt();
535 if denominator > 1e-10 {
536 Ok(sum_xy / denominator)
537 } else {
538 Ok(0.0)
539 }
540 }
541
542 fn estimate_optimal_batch_size(&self, n_samples: usize, n_features: usize) -> usize {
544 let data_size = n_samples * n_features * std::mem::size_of::<Float>();
546 let target_memory = 100_000_000; let optimal_size = if data_size <= target_memory {
549 n_samples } else {
551 (target_memory / (n_features * std::mem::size_of::<Float>()))
552 .max(1000)
553 .min(n_samples)
554 };
555
556 optimal_size
557 }
558
559 fn optimize_parameters(
561 &self,
562 _x: &Array2<Float>,
563 characteristics: &DataCharacteristics,
564 ) -> Result<HashMap<String, Float>> {
565 let mut optimal_params = HashMap::new();
566
567 let scaling_method = self.select_optimal_scaling_method(characteristics);
569 optimal_params.insert("scaling_method".to_string(), scaling_method);
570
571 let outlier_threshold = self.select_optimal_outlier_threshold(characteristics);
573 optimal_params.insert("outlier_threshold".to_string(), outlier_threshold);
574
575 let imputation_strategy = self.select_optimal_imputation_strategy(characteristics);
577 optimal_params.insert("imputation_strategy".to_string(), imputation_strategy);
578
579 let (q_low, q_high) = self.select_optimal_quantile_range(characteristics);
581 optimal_params.insert("quantile_range_low".to_string(), q_low);
582 optimal_params.insert("quantile_range_high".to_string(), q_high);
583
584 let contamination_rate = self.select_optimal_contamination_rate(characteristics);
586 optimal_params.insert("contamination_rate".to_string(), contamination_rate);
587
588 Ok(optimal_params)
589 }
590
591 fn select_optimal_scaling_method(&self, characteristics: &DataCharacteristics) -> Float {
593 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
594 / characteristics.outlier_percentages.len() as Float;
595 let avg_skewness = characteristics
596 .skewness
597 .iter()
598 .map(|x| x.abs())
599 .sum::<Float>()
600 / characteristics.skewness.len() as Float;
601
602 match self.config.strategy {
603 AdaptationStrategy::Conservative => {
604 if avg_outlier_pct > 10.0 || avg_skewness > 1.0 {
605 2.0 } else {
607 0.0 }
609 }
610 AdaptationStrategy::Balanced => {
611 if avg_outlier_pct > 15.0 {
612 2.0 } else if avg_skewness > 2.0 {
614 1.0 } else {
616 0.0 }
618 }
619 AdaptationStrategy::Aggressive => {
620 if avg_outlier_pct > 20.0 {
621 2.0 } else {
623 0.0 }
625 }
626 AdaptationStrategy::Custom => 0.0, }
628 }
629
630 fn select_optimal_outlier_threshold(&self, characteristics: &DataCharacteristics) -> Float {
632 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
633 / characteristics.outlier_percentages.len() as Float;
634
635 match self.config.strategy {
636 AdaptationStrategy::Conservative => {
637 if avg_outlier_pct > 20.0 {
638 3.5
639 } else {
640 3.0
641 }
642 }
643 AdaptationStrategy::Balanced => {
644 if avg_outlier_pct > 15.0 {
645 2.5
646 } else {
647 2.0
648 }
649 }
650 AdaptationStrategy::Aggressive => {
651 if avg_outlier_pct > 10.0 {
652 2.0
653 } else {
654 1.5
655 }
656 }
657 AdaptationStrategy::Custom => 2.5, }
659 }
660
661 fn select_optimal_imputation_strategy(&self, characteristics: &DataCharacteristics) -> Float {
663 let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
664 / characteristics.missing_percentages.len() as Float;
665 let has_skewed_features = characteristics.skewness.iter().any(|&s| s.abs() > 1.0);
666
667 if avg_missing_pct > 20.0 {
668 2.0 } else if has_skewed_features {
670 1.0 } else {
672 0.0 }
674 }
675
676 fn select_optimal_quantile_range(
678 &self,
679 characteristics: &DataCharacteristics,
680 ) -> (Float, Float) {
681 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
682 / characteristics.outlier_percentages.len() as Float;
683
684 match self.config.strategy {
685 AdaptationStrategy::Conservative => {
686 if avg_outlier_pct > 15.0 {
687 (10.0, 90.0)
688 } else {
689 (25.0, 75.0)
690 }
691 }
692 AdaptationStrategy::Balanced => {
693 if avg_outlier_pct > 10.0 {
694 (5.0, 95.0)
695 } else {
696 (25.0, 75.0)
697 }
698 }
699 AdaptationStrategy::Aggressive => {
700 (25.0, 75.0) }
702 AdaptationStrategy::Custom => (25.0, 75.0),
703 }
704 }
705
706 fn select_optimal_contamination_rate(&self, characteristics: &DataCharacteristics) -> Float {
708 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
709 / characteristics.outlier_percentages.len() as Float;
710
711 (avg_outlier_pct / 100.0 * 1.2).min(0.5).max(0.01)
713 }
714
715 fn evaluate_parameter_space(
717 &self,
718 x: &Array2<Float>,
719 characteristics: &DataCharacteristics,
720 ) -> Result<Vec<ParameterEvaluation>> {
721 let mut evaluations = Vec::new();
722
723 let scaling_methods = vec![0.0, 1.0, 2.0]; let outlier_thresholds = vec![1.5, 2.0, 2.5, 3.0, 3.5];
726 let contamination_rates = vec![0.05, 0.1, 0.15, 0.2];
727
728 let max_evaluations = 20; let mut evaluation_count = 0;
731
732 for &scaling_method in &scaling_methods {
733 for &threshold in &outlier_thresholds {
734 for &contamination in &contamination_rates {
735 if evaluation_count >= max_evaluations {
736 break;
737 }
738
739 let mut params = HashMap::new();
740 params.insert("scaling_method".to_string(), scaling_method);
741 params.insert("outlier_threshold".to_string(), threshold);
742 params.insert("contamination_rate".to_string(), contamination);
743
744 let evaluation = self.evaluate_parameters(¶ms, x, characteristics)?;
745 evaluations.push(evaluation);
746 evaluation_count += 1;
747 }
748 }
749 }
750
751 evaluations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
753
754 Ok(evaluations)
755 }
756
757 fn evaluate_parameters(
759 &self,
760 params: &HashMap<String, Float>,
761 _x: &Array2<Float>,
762 characteristics: &DataCharacteristics,
763 ) -> Result<ParameterEvaluation> {
764 let start_time = std::time::Instant::now();
765
766 let robustness_score = self.compute_robustness_score(params, characteristics);
768 let efficiency_score = self.compute_efficiency_score(params, characteristics);
769 let quality_score = self.compute_quality_score(params, characteristics);
770
771 let overall_score = match self.config.strategy {
773 AdaptationStrategy::Conservative => {
774 robustness_score * 0.6 + quality_score * 0.3 + efficiency_score * 0.1
775 }
776 AdaptationStrategy::Balanced => {
777 robustness_score * 0.4 + quality_score * 0.4 + efficiency_score * 0.2
778 }
779 AdaptationStrategy::Aggressive => {
780 robustness_score * 0.2 + quality_score * 0.3 + efficiency_score * 0.5
781 }
782 AdaptationStrategy::Custom => {
783 robustness_score * 0.33 + quality_score * 0.33 + efficiency_score * 0.34
784 }
785 };
786
787 let evaluation_time = start_time.elapsed().as_secs_f64() as Float;
788
789 Ok(ParameterEvaluation {
790 parameters: params.clone(),
791 score: overall_score,
792 robustness_score,
793 efficiency_score,
794 quality_score,
795 evaluation_time,
796 })
797 }
798
799 fn compute_robustness_score(
801 &self,
802 params: &HashMap<String, Float>,
803 characteristics: &DataCharacteristics,
804 ) -> Float {
805 let scaling_method = params.get("scaling_method").unwrap_or(&0.0);
806 let outlier_threshold = params.get("outlier_threshold").unwrap_or(&2.5);
807
808 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
809 / characteristics.outlier_percentages.len() as Float;
810
811 let mut score: Float = 0.0;
812
813 if avg_outlier_pct > 10.0 && *scaling_method == 2.0 {
815 score += 0.4;
816 }
817
818 if avg_outlier_pct > 15.0 && *outlier_threshold <= 2.5 {
820 score += 0.3;
821 } else if avg_outlier_pct <= 5.0 && *outlier_threshold >= 3.0 {
822 score += 0.3;
823 }
824
825 let avg_skewness = characteristics
827 .skewness
828 .iter()
829 .map(|x| x.abs())
830 .sum::<Float>()
831 / characteristics.skewness.len() as Float;
832 if avg_skewness > 1.0 && *scaling_method != 0.0 {
833 score += 0.3;
834 }
835
836 score.min(1.0 as Float)
837 }
838
839 fn compute_efficiency_score(
841 &self,
842 params: &HashMap<String, Float>,
843 characteristics: &DataCharacteristics,
844 ) -> Float {
845 let scaling_method = params.get("scaling_method").unwrap_or(&0.0);
846 let (n_samples, n_features) = characteristics.shape;
847
848 let mut score: Float = if *scaling_method == 0.0 {
850 1.0
851 } else if *scaling_method == 1.0 {
852 0.8 } else {
854 0.6 };
856
857 let data_size_factor = (n_samples * n_features) as Float;
859 if data_size_factor > 1_000_000.0 {
860 score *= 1.2; }
862
863 score.min(1.0 as Float)
864 }
865
866 fn compute_quality_score(
868 &self,
869 params: &HashMap<String, Float>,
870 characteristics: &DataCharacteristics,
871 ) -> Float {
872 let mut score = characteristics.quality_score; let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
876 / characteristics.missing_percentages.len() as Float;
877
878 if avg_missing_pct > 10.0 {
880 score *= 0.9; }
882
883 let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
885 / characteristics.outlier_percentages.len() as Float;
886
887 let outlier_threshold = params.get("outlier_threshold").unwrap_or(&2.5);
888 if avg_outlier_pct > 10.0 && *outlier_threshold <= 2.5 {
889 score *= 1.1; }
891
892 score.min(1.0 as Float)
893 }
894}
895
impl AdaptiveParameterSelector<Trained> {
    /// Dataset statistics computed during `fit`.
    pub fn data_characteristics(&self) -> Option<&DataCharacteristics> {
        self.data_characteristics_.as_ref()
    }

    /// Best parameter values found during `fit`, keyed by parameter name.
    pub fn optimal_parameters(&self) -> Option<&HashMap<String, Float>> {
        self.optimal_parameters_.as_ref()
    }

    /// All evaluated parameter combinations, sorted best-first.
    pub fn parameter_history(&self) -> Option<&Vec<ParameterEvaluation>> {
        self.parameter_history_.as_ref()
    }

    /// Decode the fitted `Float`-coded parameters into concrete,
    /// human-readable preprocessing recommendations.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` if characteristics or optimal
    /// parameters are missing (should not happen after a successful `fit`).
    pub fn recommend_parameters(&self) -> Result<ParameterRecommendations> {
        let characteristics = self.data_characteristics_.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("No data characteristics available".to_string())
        })?;

        let optimal_params = self.optimal_parameters_.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("No optimal parameters available".to_string())
        })?;

        // Scaling: decode the Float code (0 standard, 1 minmax, 2 robust).
        let scaling_method = optimal_params.get("scaling_method").unwrap_or(&0.0);
        let scaling = ScalingParameters {
            method: match *scaling_method as i32 {
                0 => "standard".to_string(),
                1 => "minmax".to_string(),
                2 => "robust".to_string(),
                _ => "standard".to_string(),
            },
            outlier_threshold: *optimal_params.get("outlier_threshold").unwrap_or(&2.5),
            quantile_range: (
                *optimal_params.get("quantile_range_low").unwrap_or(&25.0),
                *optimal_params.get("quantile_range_high").unwrap_or(&75.0),
            ),
            with_centering: true,
            with_scaling: true,
        };

        // Imputation: decode code (0 mean, 1 median, 2 knn); KNN gets
        // neighbor/iteration settings.
        let imputation_strategy = optimal_params.get("imputation_strategy").unwrap_or(&0.0);
        let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
            / characteristics.missing_percentages.len() as Float;

        let imputation = ImputationParameters {
            strategy: match *imputation_strategy as i32 {
                0 => "mean".to_string(),
                1 => "median".to_string(),
                2 => "knn".to_string(),
                _ => "mean".to_string(),
            },
            n_neighbors: if *imputation_strategy == 2.0 {
                Some(5)
            } else {
                None
            },
            outlier_aware: avg_missing_pct > 10.0,
            max_iterations: if *imputation_strategy == 2.0 {
                Some(10)
            } else {
                None
            },
        };

        // Outlier detection: fixed method, fitted contamination/threshold.
        let contamination_rate = *optimal_params.get("contamination_rate").unwrap_or(&0.1);
        let outlier_detection = OutlierDetectionParameters {
            method: "isolation_forest".to_string(),
            contamination: contamination_rate,
            threshold: *optimal_params.get("outlier_threshold").unwrap_or(&2.5),
            ensemble_size: Some(100),
        };

        // Transformation: driven by average absolute skewness.
        let avg_skewness = characteristics
            .skewness
            .iter()
            .map(|x| x.abs())
            .sum::<Float>()
            / characteristics.skewness.len() as Float;

        let transformation = TransformationParameters {
            method: if avg_skewness > 1.5 {
                "log1p".to_string()
            } else if avg_skewness > 1.0 {
                "box_cox".to_string()
            } else {
                "none".to_string()
            },
            handle_negatives: true,
            lambda: None, n_quantiles: Some(1000),
        };

        // Confidence: weighted blend of quality (0.5), completeness (0.3),
        // and cleanliness (0.2), clamped to [0, 1] below.
        let confidence = characteristics.quality_score * 0.5
            + (1.0
                - (characteristics.missing_percentages.iter().sum::<Float>()
                    / characteristics.missing_percentages.len() as Float
                    / 100.0))
                * 0.3
            + (1.0
                - (characteristics.outlier_percentages.iter().sum::<Float>()
                    / characteristics.outlier_percentages.len() as Float
                    / 100.0))
                * 0.2;

        Ok(ParameterRecommendations {
            scaling,
            imputation,
            outlier_detection,
            transformation,
            confidence: confidence.min(1.0).max(0.0),
        })
    }

    /// Render a human-readable report covering the data characteristics,
    /// the parameter recommendations, and the selector configuration.
    ///
    /// # Errors
    /// Propagates the same errors as [`Self::recommend_parameters`].
    pub fn adaptation_report(&self) -> Result<String> {
        let characteristics = self.data_characteristics_.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("No data characteristics available".to_string())
        })?;

        let recommendations = self.recommend_parameters()?;

        let mut report = String::new();

        report.push_str("=== Adaptive Parameter Selection Report ===\n\n");

        // Section 1: dataset statistics.
        report.push_str("=== Data Characteristics ===\n");
        report.push_str(&format!("Data shape: {:?}\n", characteristics.shape));
        report.push_str(&format!(
            "Overall quality score: {:.3}\n",
            characteristics.quality_score
        ));
        report.push_str(&format!(
            "Correlation strength: {:.3}\n",
            characteristics.correlation_strength
        ));
        report.push_str(&format!(
            "Optimal batch size: {}\n",
            characteristics.optimal_batch_size
        ));

        let avg_missing = characteristics.missing_percentages.iter().sum::<Float>()
            / characteristics.missing_percentages.len() as Float;
        let avg_outliers = characteristics.outlier_percentages.iter().sum::<Float>()
            / characteristics.outlier_percentages.len() as Float;
        let avg_skewness = characteristics
            .skewness
            .iter()
            .map(|x| x.abs())
            .sum::<Float>()
            / characteristics.skewness.len() as Float;

        report.push_str(&format!("Average missing values: {:.1}%\n", avg_missing));
        report.push_str(&format!("Average outlier rate: {:.1}%\n", avg_outliers));
        report.push_str(&format!("Average absolute skewness: {:.3}\n", avg_skewness));
        report.push_str("\n");

        // Section 2: recommendations.
        report.push_str("=== Parameter Recommendations ===\n");
        report.push_str(&format!(
            "Confidence: {:.1}%\n\n",
            recommendations.confidence * 100.0
        ));

        report.push_str("Scaling:\n");
        report.push_str(&format!(" Method: {}\n", recommendations.scaling.method));
        report.push_str(&format!(
            " Outlier threshold: {:.2}\n",
            recommendations.scaling.outlier_threshold
        ));
        report.push_str(&format!(
            " Quantile range: ({:.1}%, {:.1}%)\n",
            recommendations.scaling.quantile_range.0, recommendations.scaling.quantile_range.1
        ));
        report.push_str("\n");

        report.push_str("Imputation:\n");
        report.push_str(&format!(
            " Strategy: {}\n",
            recommendations.imputation.strategy
        ));
        if let Some(k) = recommendations.imputation.n_neighbors {
            report.push_str(&format!(" K-neighbors: {}\n", k));
        }
        report.push_str(&format!(
            " Outlier-aware: {}\n",
            recommendations.imputation.outlier_aware
        ));
        report.push_str("\n");

        report.push_str("Outlier Detection:\n");
        report.push_str(&format!(
            " Method: {}\n",
            recommendations.outlier_detection.method
        ));
        report.push_str(&format!(
            " Contamination: {:.3}\n",
            recommendations.outlier_detection.contamination
        ));
        report.push_str(&format!(
            " Threshold: {:.2}\n",
            recommendations.outlier_detection.threshold
        ));
        report.push_str("\n");

        report.push_str("Transformation:\n");
        report.push_str(&format!(
            " Method: {}\n",
            recommendations.transformation.method
        ));
        report.push_str(&format!(
            " Handle negatives: {}\n",
            recommendations.transformation.handle_negatives
        ));
        report.push_str("\n");

        // Section 3: selector configuration.
        report.push_str("=== Configuration ===\n");
        report.push_str(&format!("Strategy: {:?}\n", self.config.strategy));
        report.push_str(&format!(
            "Cross-validation: {} ({} folds)\n",
            self.config.use_cross_validation, self.config.cv_folds
        ));
        report.push_str(&format!("Parallel processing: {}\n", self.config.parallel));
        if let Some(budget) = self.config.time_budget {
            report.push_str(&format!("Time budget: {:.1}s\n", budget));
        }

        Ok(report)
    }

    /// Produce plain-English observations about the fitted data; always
    /// returns at least one message.
    pub fn get_insights(&self) -> Vec<String> {
        let mut insights = Vec::new();

        if let Some(characteristics) = &self.data_characteristics_ {
            let avg_missing = characteristics.missing_percentages.iter().sum::<Float>()
                / characteristics.missing_percentages.len() as Float;
            let avg_outliers = characteristics.outlier_percentages.iter().sum::<Float>()
                / characteristics.outlier_percentages.len() as Float;
            let avg_skewness = characteristics
                .skewness
                .iter()
                .map(|x| x.abs())
                .sum::<Float>()
                / characteristics.skewness.len() as Float;

            if avg_missing > 20.0 {
                insights.push("High missing value rate detected - consider advanced imputation methods like KNN or iterative imputation".to_string());
            }

            if avg_outliers > 15.0 {
                insights.push(
                    "High outlier rate detected - robust preprocessing methods are recommended"
                        .to_string(),
                );
            }

            if avg_skewness > 2.0 {
                insights.push(
                    "Highly skewed data detected - consider log or Box-Cox transformations"
                        .to_string(),
                );
            }

            if characteristics.correlation_strength > 0.7 {
                insights.push(
                    "High feature correlation detected - consider dimensionality reduction"
                        .to_string(),
                );
            }

            if characteristics.quality_score < 0.5 {
                insights.push(
                    "Low data quality detected - comprehensive preprocessing pipeline recommended"
                        .to_string(),
                );
            }

            if characteristics.shape.0 > 1_000_000 {
                insights.push(
                    "Large dataset detected - consider streaming or batch processing approaches"
                        .to_string(),
                );
            }

            if characteristics.optimal_batch_size < characteristics.shape.0 {
                insights.push(format!(
                    "Consider batch processing with batch size: {}",
                    characteristics.optimal_batch_size
                ));
            }
        }

        // Fallback so callers always receive at least one insight.
        if insights.is_empty() {
            insights.push("Data characteristics are within normal ranges - standard preprocessing should be sufficient".to_string());
        }

        insights
    }
}
1205
/// Equivalent to [`AdaptiveParameterSelector::new`]: balanced strategy,
/// nothing fitted yet.
impl Default for AdaptiveParameterSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}
1211
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_adaptive_parameter_selector_creation() {
        let selector = AdaptiveParameterSelector::new();
        // `AdaptationStrategy` derives only `Debug, Clone, Copy`, so compare
        // variants with `matches!` instead of casting discriminants to u8.
        assert!(matches!(
            selector.config.strategy,
            AdaptationStrategy::Balanced
        ));
        assert!(selector.config.use_cross_validation);
        assert_eq!(selector.config.cv_folds, 5);
    }

    #[test]
    fn test_adaptive_strategies() {
        let conservative = AdaptiveParameterSelector::conservative();
        assert!(matches!(
            conservative.config.strategy,
            AdaptationStrategy::Conservative
        ));

        let aggressive = AdaptiveParameterSelector::aggressive();
        assert!(matches!(
            aggressive.config.strategy,
            AdaptationStrategy::Aggressive
        ));
    }

    #[test]
    fn test_data_characteristics_analysis() {
        // 10 samples x 3 features; the ninth row (100, 1000, 10000) is a
        // deliberate outlier so the analysis has something to detect.
        let data = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 10.0, 100.0, //
                2.0, 20.0, 200.0, //
                3.0, 30.0, 300.0, //
                4.0, 40.0, 400.0, //
                5.0, 50.0, 500.0, //
                6.0, 60.0, 600.0, //
                7.0, 70.0, 700.0, //
                8.0, 80.0, 800.0, //
                100.0, 1000.0, 10000.0, //
                9.0, 90.0, 900.0, //
            ],
        )
        .unwrap();

        let selector = AdaptiveParameterSelector::balanced();
        let fitted = selector.fit(&data, &()).unwrap();

        let characteristics = fitted.data_characteristics().unwrap();
        assert_eq!(characteristics.shape, (10, 3));
        assert_eq!(characteristics.distribution_types.len(), 3);
        assert_eq!(characteristics.skewness.len(), 3);
        // Quality score must lie in the closed unit interval.
        assert!((0.0..=1.0).contains(&characteristics.quality_score));
    }

    #[test]
    fn test_parameter_recommendations() {
        // 8 samples x 2 features with one outlier row (100, 1000).
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 10.0, //
                2.0, 20.0, //
                3.0, 30.0, //
                4.0, 40.0, //
                5.0, 50.0, //
                100.0, 1000.0, //
                7.0, 70.0, //
                8.0, 80.0, //
            ],
        )
        .unwrap();

        let selector = AdaptiveParameterSelector::balanced();
        let fitted = selector.fit(&data, &()).unwrap();

        let recommendations = fitted.recommend_parameters().unwrap();
        assert!((0.0..=1.0).contains(&recommendations.confidence));
        assert!(!recommendations.scaling.method.is_empty());
        assert!(!recommendations.imputation.strategy.is_empty());
    }

    #[test]
    fn test_parameter_optimization() {
        // 6 samples x 2 features; final row (100, 1000) is an outlier.
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0, //
                2.0, 20.0, //
                3.0, 30.0, //
                4.0, 40.0, //
                5.0, 50.0, //
                100.0, 1000.0, //
            ],
        )
        .unwrap();

        let selector = AdaptiveParameterSelector::aggressive();
        let fitted = selector.fit(&data, &()).unwrap();

        let optimal_params = fitted.optimal_parameters().unwrap();
        assert!(optimal_params.contains_key("scaling_method"));
        assert!(optimal_params.contains_key("outlier_threshold"));
        assert!(optimal_params.contains_key("contamination_rate"));
    }

    #[test]
    fn test_distribution_classification() {
        // Evenly spaced values: classification may land on any of the
        // "well-behaved" distribution types depending on the heuristics.
        let data = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        let selector = AdaptiveParameterSelector::new();
        let fitted = selector.fit(&data, &()).unwrap();

        let characteristics = fitted.data_characteristics().unwrap();
        assert!(matches!(
            characteristics.distribution_types[0],
            DistributionType::Normal | DistributionType::Uniform | DistributionType::Unknown
        ));
    }

    #[test]
    fn test_missing_value_handling() {
        // NaN entries stand in for missing values (one per column).
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0, //
                2.0, Float::NAN, //
                3.0, 30.0, //
                Float::NAN, 40.0, //
                5.0, 50.0, //
                6.0, 60.0, //
            ],
        )
        .unwrap();

        let selector = AdaptiveParameterSelector::balanced();
        let fitted = selector.fit(&data, &()).unwrap();

        let characteristics = fitted.data_characteristics().unwrap();
        // At least one column must register a non-zero missing percentage.
        assert!(
            characteristics.missing_percentages[0] > 0.0
                || characteristics.missing_percentages[1] > 0.0
        );
    }

    #[test]
    fn test_adaptation_report() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0, //
                2.0, 20.0, //
                3.0, 30.0, //
                4.0, 40.0, //
                5.0, 50.0, //
                100.0, 1000.0, //
            ],
        )
        .unwrap();

        let selector = AdaptiveParameterSelector::balanced();
        let fitted = selector.fit(&data, &()).unwrap();

        // The report should contain all three major sections.
        let report = fitted.adaptation_report().unwrap();
        assert!(report.contains("Adaptive Parameter Selection Report"));
        assert!(report.contains("Data Characteristics"));
        assert!(report.contains("Parameter Recommendations"));
    }

    #[test]
    fn test_insights_generation() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
            .unwrap();

        let selector = AdaptiveParameterSelector::conservative();
        let fitted = selector.fit(&data, &()).unwrap();

        // Even clean data yields at least one (fallback) insight.
        let insights = fitted.get_insights();
        assert!(!insights.is_empty());
    }

    #[test]
    fn test_configuration_options() {
        // Builder-style configuration should override every default it touches.
        let selector = AdaptiveParameterSelector::new()
            .cross_validation(false, 3)
            .time_budget(30.0)
            .parallel(false)
            .tolerance(1e-3);

        assert!(!selector.config.use_cross_validation);
        assert_eq!(selector.config.cv_folds, 3);
        assert_eq!(selector.config.time_budget, Some(30.0));
        assert!(!selector.config.parallel);
        assert_relative_eq!(selector.config.tolerance, 1e-3, epsilon = 1e-10);
    }

    #[test]
    fn test_error_handling() {
        let selector = AdaptiveParameterSelector::new();

        // Fitting on an empty (0 x 0) matrix must fail rather than panic.
        let empty_data = Array2::from_shape_vec((0, 0), vec![]).unwrap();
        assert!(selector.fit(&empty_data, &()).is_err());
    }

    #[test]
    fn test_parameter_bounds() {
        let mut bounds = HashMap::new();
        bounds.insert("outlier_threshold".to_string(), (1.0, 4.0));

        let selector = AdaptiveParameterSelector::new().parameter_bounds(bounds.clone());

        assert_eq!(selector.config.parameter_bounds, bounds);
    }
}