1use crate::adaptive_selection::{OptimizerType, ProblemCharacteristics};
8use crate::error::{OptimError, Result};
9use scirs2_core::ndarray::ScalarOperand;
10use scirs2_core::numeric::Float;
11use std::collections::HashMap;
12use std::fmt::Debug;
13
/// Domain-specific optimization strategy.
///
/// Each variant identifies an application domain and carries boolean flags
/// that toggle the specialized tuning behaviors applied by
/// `DomainSpecificSelector::select_optimal_config`.
#[derive(Debug, Clone)]
pub enum DomainStrategy {
    /// Vision workloads (image models — exact scope assumed; confirm with callers).
    ComputerVision {
        /// Scale the learning rate with the estimated input resolution.
        resolution_adaptive: bool,
        /// Tune batch-normalization hyperparameters (momentum, epsilon) and use AdamW.
        batch_norm_tuning: bool,
        /// Strengthen regularization and emit mixup/cutmix parameters.
        augmentation_aware: bool,
    },
    /// Text / sequence-modeling workloads.
    NaturalLanguage {
        /// Adapt learning rate and gradient clipping to the sequence length.
        sequence_adaptive: bool,
        /// Emit attention-specific hyperparameters (AdamW, dropout, layer decay).
        attention_optimized: bool,
        /// Enable vocabulary-size-dependent settings (embedding tying/dropout).
        vocab_aware: bool,
    },
    /// Recommender-system workloads.
    RecommendationSystems {
        /// Collaborative-filtering settings (Adam, negative sampling).
        collaborative_filtering: bool,
        /// Matrix-factorization settings (embedding dimension, rank).
        matrix_factorization: bool,
        /// Cold-start mitigation parameters (content weight, popularity bias).
        cold_start_aware: bool,
    },
    /// Time-series / forecasting workloads.
    TimeSeries {
        /// Temporal settings (RMSprop, recorded sequence length).
        temporal_aware: bool,
        /// Seasonality parameters (seasonal periods, trend strength).
        seasonality_adaptive: bool,
        /// Multi-step forecasting parameters (horizon, loss weighting).
        multi_step: bool,
    },
    /// Reinforcement-learning workloads.
    ReinforcementLearning {
        /// Policy-gradient settings (Adam at 3e-4, entropy coefficient).
        policy_gradient: bool,
        /// Value-function loss parameters (loss coefficient, Huber delta).
        value_function: bool,
        /// Epsilon-greedy exploration schedule parameters.
        exploration_aware: bool,
    },
    /// Scientific-computing / numerical-optimization workloads.
    ScientificComputing {
        /// Prefer stable second-order optimization (L-BFGS, line-search tolerance).
        stability_focused: bool,
        /// Tighten convergence tolerances for high-precision runs.
        precision_critical: bool,
        /// Sparsity-aware settings (Adam, sparsity threshold).
        sparse_optimized: bool,
    },
}
72
/// Static per-domain defaults used to seed optimization decisions.
#[derive(Debug, Clone)]
pub struct DomainConfig<A: Float> {
    /// Default learning rate for the domain.
    pub base_learning_rate: A,
    /// Batch sizes known to work well for the domain.
    pub recommended_batch_sizes: Vec<usize>,
    /// Candidate gradient-clipping norms (empty when clipping is not used).
    pub gradient_clip_values: Vec<A>,
    /// `(min, max)` range for the regularization strength.
    pub regularization_range: (A, A),
    /// Candidate optimizers — presumably ordered most-preferred first (see tests).
    pub optimizer_ranking: Vec<OptimizerType>,
    /// Additional free-form domain parameters keyed by name.
    pub domain_params: HashMap<String, A>,
}
89
/// Selects optimizer configurations specialized for one application domain.
///
/// Holds the active strategy, the strategy-derived defaults, per-domain
/// performance history, and recorded cross-domain transfer knowledge.
#[derive(Debug)]
pub struct DomainSpecificSelector<A: Float> {
    // Active domain strategy, fixed at construction.
    strategy: DomainStrategy,
    // Defaults derived from `strategy` via `default_config_for_strategy`.
    config: DomainConfig<A>,
    // Performance history keyed by domain name.
    domain_performance: HashMap<String, Vec<DomainPerformanceMetrics<A>>>,
    // Cross-domain transfer observations recorded so far.
    transfer_knowledge: Vec<CrossDomainKnowledge<A>>,
    // Context consumed by `select_optimal_config`; `None` until `setcontext`.
    currentcontext: Option<OptimizationContext<A>>,
}
104
/// Performance metrics recorded for one optimization run within a domain.
#[derive(Debug, Clone)]
pub struct DomainPerformanceMetrics<A: Float> {
    /// Validation accuracy achieved (assumed higher-is-better — confirm scale).
    pub validation_accuracy: A,
    /// Domain-defined quality score.
    pub domain_specific_score: A,
    /// Training-stability score.
    pub stability_score: A,
    /// Number of epochs needed to converge.
    pub convergence_epochs: usize,
    /// Resource-usage efficiency score.
    pub resource_efficiency: A,
    /// How well the run's settings transfer to other domains.
    pub transfer_score: A,
}
121
/// A record of hyperparameter knowledge transferred between two domains.
#[derive(Debug, Clone)]
pub struct CrossDomainKnowledge<A: Float> {
    /// Domain the knowledge was learned in.
    pub source_domain: String,
    /// Domain the knowledge is applied to.
    pub target_domain: String,
    /// Hyperparameter values considered transferable, keyed by name.
    pub transferable_params: HashMap<String, A>,
    /// Effectiveness of the transfer; also used as recommendation confidence.
    pub transfer_score: A,
    /// Optimizer that worked in the source domain.
    pub successful_strategy: OptimizerType,
}
136
/// Everything `select_optimal_config` needs to know about the current run.
#[derive(Debug, Clone)]
pub struct OptimizationContext<A: Float> {
    /// Characteristics of the learning problem (dimensions, noise, sparsity, …).
    pub problem_chars: ProblemCharacteristics,
    /// Hardware / budget constraints.
    pub resource_constraints: ResourceConstraints<A>,
    /// Training-loop configuration (epochs, patience, schedules, …).
    pub training_config: TrainingConfiguration<A>,
    /// Free-form string metadata about the domain.
    pub domain_metadata: HashMap<String, String>,
}
149
/// Hardware and budget constraints for a run.
#[derive(Debug, Clone)]
pub struct ResourceConstraints<A: Float> {
    /// Maximum memory available — appears to be bytes given the
    /// multi-billion thresholds used by the batch-size selectors; confirm.
    pub max_memory: usize,
    /// Maximum wall-clock budget (units not encoded here — confirm with callers).
    pub max_time: A,
    /// Whether a GPU is available.
    pub gpu_available: bool,
    /// Whether distributed training is possible.
    pub distributed_capable: bool,
    /// Whether energy-efficient operation is required.
    pub energy_efficient: bool,
}
164
/// Training-loop configuration supplied as part of the optimization context.
#[derive(Debug, Clone)]
pub struct TrainingConfiguration<A: Float> {
    /// Hard cap on training epochs.
    pub max_epochs: usize,
    /// Epochs without improvement before early stopping.
    pub early_stopping_patience: usize,
    /// Validate every N epochs.
    pub validation_frequency: usize,
    /// Requested learning-rate schedule.
    pub lr_schedule_type: LearningRateScheduleType,
    /// Requested regularization scheme.
    pub regularization_approach: RegularizationApproach<A>,
}
179
/// Learning-rate schedule variants.
#[derive(Debug, Clone)]
pub enum LearningRateScheduleType {
    /// Fixed learning rate for the whole run.
    Constant,
    /// Multiplicative decay by `decay_rate` (the cadence is defined by the
    /// consumer of this enum, not encoded here).
    ExponentialDecay {
        /// Per-step/per-epoch multiplier.
        decay_rate: f64,
    },
    /// Cosine annealing over `t_max` epochs.
    CosineAnnealing {
        /// Annealing period.
        t_max: usize,
    },
    /// Scale the rate by `factor` after `patience` stalled evaluations.
    ReduceOnPlateau {
        /// Evaluations without improvement before reducing.
        patience: usize,
        /// Multiplicative reduction factor.
        factor: f64,
    },
    /// One-cycle policy peaking at `max_lr`.
    OneCycle {
        /// Peak learning rate of the cycle.
        max_lr: f64,
    },
}
208
/// Regularization scheme applied during training.
#[derive(Debug, Clone)]
pub enum RegularizationApproach<A: Float> {
    /// L2 (weight-decay) penalty only.
    L2Only {
        /// Penalty coefficient.
        weight: A,
    },
    /// L1 (sparsity-inducing) penalty only.
    L1Only {
        /// Penalty coefficient.
        weight: A,
    },
    /// Combined L1 + L2 penalty.
    ElasticNet {
        /// L1 coefficient.
        l1_weight: A,
        /// L2 coefficient.
        l2_weight: A,
    },
    /// Dropout only.
    Dropout {
        /// Drop probability (assumed in [0, 1] — confirm).
        dropout_rate: A,
    },
    /// L2 + dropout plus named extra techniques.
    Combined {
        /// L2 coefficient.
        l2_weight: A,
        /// Drop probability.
        dropout_rate: A,
        /// Names of further techniques to apply (free-form).
        additional_techniques: Vec<String>,
    },
}
244
245impl<A: Float + ScalarOperand + Debug + std::iter::Sum + Send + Sync> DomainSpecificSelector<A> {
246 pub fn new(strategy: DomainStrategy) -> Self {
248 let config = Self::default_config_for_strategy(&strategy);
249
250 Self {
251 strategy,
252 config,
253 domain_performance: HashMap::new(),
254 transfer_knowledge: Vec::new(),
255 currentcontext: None,
256 }
257 }
258
    /// Installs the optimization context consumed by `select_optimal_config`,
    /// replacing any previously set context.
    pub fn setcontext(&mut self, context: OptimizationContext<A>) {
        self.currentcontext = Some(context);
    }
263
264 pub fn select_optimal_config(&mut self) -> Result<DomainOptimizationConfig<A>> {
266 let context = self
267 .currentcontext
268 .as_ref()
269 .ok_or_else(|| OptimError::InvalidConfig("No optimization context set".to_string()))?;
270
271 match &self.strategy {
272 DomainStrategy::ComputerVision {
273 resolution_adaptive,
274 batch_norm_tuning,
275 augmentation_aware,
276 } => self.optimize_computer_vision(
277 context,
278 *resolution_adaptive,
279 *batch_norm_tuning,
280 *augmentation_aware,
281 ),
282 DomainStrategy::NaturalLanguage {
283 sequence_adaptive,
284 attention_optimized,
285 vocab_aware,
286 } => self.optimize_natural_language(
287 context,
288 *sequence_adaptive,
289 *attention_optimized,
290 *vocab_aware,
291 ),
292 DomainStrategy::RecommendationSystems {
293 collaborative_filtering,
294 matrix_factorization,
295 cold_start_aware,
296 } => self.optimize_recommendation_systems(
297 context,
298 *collaborative_filtering,
299 *matrix_factorization,
300 *cold_start_aware,
301 ),
302 DomainStrategy::TimeSeries {
303 temporal_aware,
304 seasonality_adaptive,
305 multi_step,
306 } => self.optimize_time_series(
307 context,
308 *temporal_aware,
309 *seasonality_adaptive,
310 *multi_step,
311 ),
312 DomainStrategy::ReinforcementLearning {
313 policy_gradient,
314 value_function,
315 exploration_aware,
316 } => self.optimize_reinforcement_learning(
317 context,
318 *policy_gradient,
319 *value_function,
320 *exploration_aware,
321 ),
322 DomainStrategy::ScientificComputing {
323 stability_focused,
324 precision_critical,
325 sparse_optimized,
326 } => self.optimize_scientific_computing(
327 context,
328 *stability_focused,
329 *precision_critical,
330 *sparse_optimized,
331 ),
332 }
333 }
334
335 fn optimize_computer_vision(
337 &self,
338 context: &OptimizationContext<A>,
339 resolution_adaptive: bool,
340 batch_norm_tuning: bool,
341 augmentation_aware: bool,
342 ) -> Result<DomainOptimizationConfig<A>> {
343 let mut config = DomainOptimizationConfig::default();
344
345 if resolution_adaptive {
347 let resolution_factor = self.estimate_resolution_factor(&context.problem_chars);
348 config.learning_rate =
349 self.config.base_learning_rate * A::from(resolution_factor).expect("unwrap failed");
350
351 if context.problem_chars.input_dim > 512 * 512 {
353 config.learning_rate = config.learning_rate * A::from(0.5).expect("unwrap failed");
354 }
355 }
356
357 if batch_norm_tuning {
359 config.optimizer_type = OptimizerType::AdamW; config.specialized_params.insert(
361 "batch_norm_momentum".to_string(),
362 A::from(0.99).expect("unwrap failed"),
363 );
364 config.specialized_params.insert(
365 "batch_norm_eps".to_string(),
366 A::from(1e-5).expect("unwrap failed"),
367 );
368 }
369
370 if augmentation_aware {
372 config.regularization_strength =
374 config.regularization_strength * A::from(1.5).expect("unwrap failed");
375 config.specialized_params.insert(
376 "mixup_alpha".to_string(),
377 A::from(0.2).expect("unwrap failed"),
378 );
379 config.specialized_params.insert(
380 "cutmix_alpha".to_string(),
381 A::from(1.0).expect("unwrap failed"),
382 );
383 }
384
385 config.batch_size = self.select_cv_batch_size(&context.resource_constraints);
387 config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
388
389 config.lr_schedule = LearningRateScheduleType::CosineAnnealing {
391 t_max: context.training_config.max_epochs,
392 };
393
394 Ok(config)
395 }
396
397 fn optimize_natural_language(
399 &self,
400 context: &OptimizationContext<A>,
401 sequence_adaptive: bool,
402 attention_optimized: bool,
403 vocab_aware: bool,
404 ) -> Result<DomainOptimizationConfig<A>> {
405 let mut config = DomainOptimizationConfig::default();
406
407 if sequence_adaptive {
409 let seq_length = context.problem_chars.input_dim; if seq_length > 512 {
413 config.learning_rate =
414 self.config.base_learning_rate * A::from(0.7).expect("unwrap failed");
415 config.gradient_clip_norm = Some(A::from(0.5).expect("unwrap failed"));
416 } else {
417 config.learning_rate = self.config.base_learning_rate;
418 config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
419 }
420 }
421
422 if attention_optimized {
424 config.optimizer_type = OptimizerType::AdamW; config.specialized_params.insert(
426 "attention_dropout".to_string(),
427 A::from(0.1).expect("unwrap failed"),
428 );
429 config.specialized_params.insert(
430 "attention_head_dim".to_string(),
431 A::from(64.0).expect("unwrap failed"),
432 );
433
434 config.specialized_params.insert(
436 "layer_decay_rate".to_string(),
437 A::from(0.95).expect("unwrap failed"),
438 );
439 }
440
441 if vocab_aware {
443 let vocab_size = context.problem_chars.output_dim; if vocab_size > 30000 {
447 config.specialized_params.insert(
448 "tie_embeddings".to_string(),
449 A::from(1.0).expect("unwrap failed"),
450 );
451 config.specialized_params.insert(
452 "embedding_dropout".to_string(),
453 A::from(0.1).expect("unwrap failed"),
454 );
455 }
456 }
457
458 config.batch_size = self.select_nlp_batch_size(&context.resource_constraints);
460 config.lr_schedule = LearningRateScheduleType::OneCycle {
461 max_lr: config.learning_rate.to_f64().expect("unwrap failed"),
462 };
463
464 config.specialized_params.insert(
466 "warmup_steps".to_string(),
467 A::from(1000.0).expect("unwrap failed"),
468 );
469
470 Ok(config)
471 }
472
473 fn optimize_recommendation_systems(
475 &self,
476 context: &OptimizationContext<A>,
477 collaborative_filtering: bool,
478 matrix_factorization: bool,
479 cold_start_aware: bool,
480 ) -> Result<DomainOptimizationConfig<A>> {
481 let mut config = DomainOptimizationConfig::default();
482
483 if collaborative_filtering {
485 config.optimizer_type = OptimizerType::Adam; config.regularization_strength = A::from(0.01).expect("unwrap failed"); config.specialized_params.insert(
488 "negative_sampling_rate".to_string(),
489 A::from(5.0).expect("unwrap failed"),
490 );
491 }
492
493 if matrix_factorization {
495 config.learning_rate = A::from(0.01).expect("unwrap failed"); config.specialized_params.insert(
497 "embedding_dim".to_string(),
498 A::from(128.0).expect("unwrap failed"),
499 );
500 config.specialized_params.insert(
501 "factorization_rank".to_string(),
502 A::from(50.0).expect("unwrap failed"),
503 );
504 }
505
506 if cold_start_aware {
508 config.specialized_params.insert(
509 "content_weight".to_string(),
510 A::from(0.3).expect("unwrap failed"),
511 );
512 config.specialized_params.insert(
513 "popularity_bias".to_string(),
514 A::from(0.1).expect("unwrap failed"),
515 );
516 }
517
518 config.batch_size = self.select_recsys_batch_size(&context.resource_constraints);
520 config.gradient_clip_norm = Some(A::from(5.0).expect("unwrap failed")); Ok(config)
523 }
524
525 fn optimize_time_series(
527 &self,
528 context: &OptimizationContext<A>,
529 temporal_aware: bool,
530 seasonality_adaptive: bool,
531 multi_step: bool,
532 ) -> Result<DomainOptimizationConfig<A>> {
533 let mut config = DomainOptimizationConfig::default();
534
535 if temporal_aware {
537 config.optimizer_type = OptimizerType::RMSprop; config.learning_rate = A::from(0.001).expect("unwrap failed"); config.specialized_params.insert(
540 "sequence_length".to_string(),
541 A::from(context.problem_chars.input_dim as f64).expect("unwrap failed"),
542 );
543 }
544
545 if seasonality_adaptive {
547 config.specialized_params.insert(
548 "seasonal_periods".to_string(),
549 A::from(24.0).expect("unwrap failed"),
550 ); config.specialized_params.insert(
552 "trend_strength".to_string(),
553 A::from(0.1).expect("unwrap failed"),
554 );
555 }
556
557 if multi_step {
559 config.specialized_params.insert(
560 "prediction_horizon".to_string(),
561 A::from(12.0).expect("unwrap failed"),
562 );
563 config.specialized_params.insert(
564 "multi_step_loss_weight".to_string(),
565 A::from(0.8).expect("unwrap failed"),
566 );
567 }
568
569 config.batch_size = 32; config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
572 config.lr_schedule = LearningRateScheduleType::ReduceOnPlateau {
573 patience: 10,
574 factor: 0.5,
575 };
576
577 Ok(config)
578 }
579
580 fn optimize_reinforcement_learning(
582 &self,
583 context: &OptimizationContext<A>,
584 policy_gradient: bool,
585 value_function: bool,
586 exploration_aware: bool,
587 ) -> Result<DomainOptimizationConfig<A>> {
588 let mut config = DomainOptimizationConfig::default();
589
590 if policy_gradient {
592 config.optimizer_type = OptimizerType::Adam;
593 config.learning_rate = A::from(3e-4).expect("unwrap failed"); config.specialized_params.insert(
595 "entropy_coeff".to_string(),
596 A::from(0.01).expect("unwrap failed"),
597 );
598 }
599
600 if value_function {
602 config.specialized_params.insert(
603 "value_loss_coeff".to_string(),
604 A::from(0.5).expect("unwrap failed"),
605 );
606 config.specialized_params.insert(
607 "huber_loss_delta".to_string(),
608 A::from(1.0).expect("unwrap failed"),
609 );
610 }
611
612 if exploration_aware {
614 config.specialized_params.insert(
615 "epsilon_start".to_string(),
616 A::from(1.0).expect("unwrap failed"),
617 );
618 config.specialized_params.insert(
619 "epsilon_end".to_string(),
620 A::from(0.1).expect("unwrap failed"),
621 );
622 config.specialized_params.insert(
623 "epsilon_decay".to_string(),
624 A::from(0.995).expect("unwrap failed"),
625 );
626 }
627
628 config.batch_size = 64; config.gradient_clip_norm = Some(A::from(0.5).expect("unwrap failed")); config.lr_schedule = LearningRateScheduleType::Constant; Ok(config)
634 }
635
636 fn optimize_scientific_computing(
638 &self,
639 context: &OptimizationContext<A>,
640 stability_focused: bool,
641 precision_critical: bool,
642 sparse_optimized: bool,
643 ) -> Result<DomainOptimizationConfig<A>> {
644 let mut config = DomainOptimizationConfig::default();
645
646 if stability_focused {
648 config.optimizer_type = OptimizerType::LBFGS; config.learning_rate = A::from(0.1).expect("unwrap failed"); config.specialized_params.insert(
651 "line_search_tolerance".to_string(),
652 A::from(1e-6).expect("unwrap failed"),
653 );
654 }
655
656 if precision_critical {
658 config.specialized_params.insert(
659 "convergence_tolerance".to_string(),
660 A::from(1e-8).expect("unwrap failed"),
661 );
662 config.specialized_params.insert(
663 "max_iterations".to_string(),
664 A::from(1000.0).expect("unwrap failed"),
665 );
666 }
667
668 if sparse_optimized {
670 config.optimizer_type = OptimizerType::Adam;
671 config.specialized_params.insert(
672 "sparsity_threshold".to_string(),
673 A::from(1e-6).expect("unwrap failed"),
674 );
675 }
676
677 config.batch_size = context.problem_chars.dataset_size.min(1024); config.gradient_clip_norm = None; config.lr_schedule = LearningRateScheduleType::Constant; Ok(config)
683 }
684
685 pub fn update_domain_performance(
687 &mut self,
688 domain: String,
689 metrics: DomainPerformanceMetrics<A>,
690 ) {
691 self.domain_performance
692 .entry(domain)
693 .or_default()
694 .push(metrics);
695 }
696
    /// Records one cross-domain transfer observation for use by
    /// `get_domain_recommendations`.
    pub fn record_transfer_knowledge(&mut self, knowledge: CrossDomainKnowledge<A>) {
        self.transfer_knowledge.push(knowledge);
    }
701
702 pub fn get_domain_recommendations(&self, domain: &str) -> Vec<DomainRecommendation<A>> {
704 let mut recommendations = Vec::new();
705
706 if let Some(history) = self.domain_performance.get(domain) {
708 if !history.is_empty() {
709 let avg_performance = history.iter().map(|m| m.validation_accuracy).sum::<A>()
710 / A::from(history.len()).expect("unwrap failed");
711
712 recommendations.push(DomainRecommendation {
713 recommendation_type: RecommendationType::PerformanceBaseline,
714 description: format!(
715 "Historical average performance: {:.4}",
716 avg_performance.to_f64().expect("unwrap failed")
717 ),
718 confidence: A::from(0.8).expect("unwrap failed"),
719 action: "Consider this as baseline for improvements".to_string(),
720 });
721 }
722 }
723
724 for knowledge in &self.transfer_knowledge {
726 if knowledge.target_domain == domain {
727 recommendations.push(DomainRecommendation {
728 recommendation_type: RecommendationType::TransferLearning,
729 description: format!(
730 "Transfer from {} domain with {:.2} effectiveness",
731 knowledge.source_domain,
732 knowledge.transfer_score.to_f64().expect("unwrap failed")
733 ),
734 confidence: knowledge.transfer_score,
735 action: format!("Use {:?} optimizer", knowledge.successful_strategy),
736 });
737 }
738 }
739
740 recommendations
741 }
742
743 fn estimate_resolution_factor(&self, problem_chars: &ProblemCharacteristics) -> f64 {
745 let resolution = problem_chars.input_dim as f64;
746
747 if resolution > 1_000_000.0 {
748 0.5
750 } else if resolution > 250_000.0 {
751 0.7
753 } else if resolution > 50_000.0 {
754 0.9
756 } else {
757 1.0
759 }
760 }
761
762 fn select_cv_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
763 if constraints.max_memory > 16_000_000_000 {
764 128
766 } else if constraints.max_memory > 8_000_000_000 {
767 64
769 } else {
770 32
771 }
772 }
773
774 fn select_nlp_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
775 if constraints.max_memory > 32_000_000_000 {
776 64
778 } else if constraints.max_memory > 16_000_000_000 {
779 32
781 } else {
782 16
783 }
784 }
785
786 fn select_recsys_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
787 if constraints.max_memory > 8_000_000_000 {
789 512
790 } else {
791 256
792 }
793 }
794
795 fn default_config_for_strategy(strategy: &DomainStrategy) -> DomainConfig<A> {
797 match strategy {
798 DomainStrategy::ComputerVision { .. } => DomainConfig {
799 base_learning_rate: A::from(0.001).expect("unwrap failed"),
800 recommended_batch_sizes: vec![32, 64, 128],
801 gradient_clip_values: vec![
802 A::from(1.0).expect("unwrap failed"),
803 A::from(2.0).expect("unwrap failed"),
804 ],
805 regularization_range: (
806 A::from(1e-5).expect("unwrap failed"),
807 A::from(1e-2).expect("unwrap failed"),
808 ),
809 optimizer_ranking: vec![
810 OptimizerType::AdamW,
811 OptimizerType::SGDMomentum,
812 OptimizerType::Adam,
813 ],
814 domain_params: HashMap::new(),
815 },
816 DomainStrategy::NaturalLanguage { .. } => DomainConfig {
817 base_learning_rate: A::from(2e-5).expect("unwrap failed"),
818 recommended_batch_sizes: vec![16, 32, 64],
819 gradient_clip_values: vec![
820 A::from(0.5).expect("unwrap failed"),
821 A::from(1.0).expect("unwrap failed"),
822 ],
823 regularization_range: (
824 A::from(1e-4).expect("unwrap failed"),
825 A::from(1e-1).expect("unwrap failed"),
826 ),
827 optimizer_ranking: vec![OptimizerType::AdamW, OptimizerType::Adam],
828 domain_params: HashMap::new(),
829 },
830 DomainStrategy::RecommendationSystems { .. } => DomainConfig {
831 base_learning_rate: A::from(0.01).expect("unwrap failed"),
832 recommended_batch_sizes: vec![128, 256, 512],
833 gradient_clip_values: vec![
834 A::from(5.0).expect("unwrap failed"),
835 A::from(10.0).expect("unwrap failed"),
836 ],
837 regularization_range: (
838 A::from(1e-3).expect("unwrap failed"),
839 A::from(1e-1).expect("unwrap failed"),
840 ),
841 optimizer_ranking: vec![OptimizerType::Adam, OptimizerType::AdaGrad],
842 domain_params: HashMap::new(),
843 },
844 DomainStrategy::TimeSeries { .. } => DomainConfig {
845 base_learning_rate: A::from(0.001).expect("unwrap failed"),
846 recommended_batch_sizes: vec![16, 32, 64],
847 gradient_clip_values: vec![A::from(1.0).expect("unwrap failed")],
848 regularization_range: (
849 A::from(1e-4).expect("unwrap failed"),
850 A::from(1e-2).expect("unwrap failed"),
851 ),
852 optimizer_ranking: vec![OptimizerType::RMSprop, OptimizerType::Adam],
853 domain_params: HashMap::new(),
854 },
855 DomainStrategy::ReinforcementLearning { .. } => DomainConfig {
856 base_learning_rate: A::from(3e-4).expect("unwrap failed"),
857 recommended_batch_sizes: vec![32, 64, 128],
858 gradient_clip_values: vec![A::from(0.5).expect("unwrap failed")],
859 regularization_range: (
860 A::from(1e-4).expect("unwrap failed"),
861 A::from(1e-2).expect("unwrap failed"),
862 ),
863 optimizer_ranking: vec![OptimizerType::Adam],
864 domain_params: HashMap::new(),
865 },
866 DomainStrategy::ScientificComputing { .. } => DomainConfig {
867 base_learning_rate: A::from(0.1).expect("unwrap failed"),
868 recommended_batch_sizes: vec![64, 128, 256, 512],
869 gradient_clip_values: vec![],
870 regularization_range: (
871 A::from(1e-6).expect("unwrap failed"),
872 A::from(1e-3).expect("unwrap failed"),
873 ),
874 optimizer_ranking: vec![OptimizerType::LBFGS, OptimizerType::Adam],
875 domain_params: HashMap::new(),
876 },
877 }
878 }
879}
880
/// Concrete optimizer configuration produced by `select_optimal_config`.
#[derive(Debug, Clone)]
pub struct DomainOptimizationConfig<A: Float> {
    /// Optimizer to use.
    pub optimizer_type: OptimizerType,
    /// Learning rate.
    pub learning_rate: A,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Gradient-norm clip threshold; `None` disables clipping.
    pub gradient_clip_norm: Option<A>,
    /// Regularization strength.
    pub regularization_strength: A,
    /// Learning-rate schedule.
    pub lr_schedule: LearningRateScheduleType,
    /// Domain-specific parameters keyed by name (e.g. "warmup_steps").
    pub specialized_params: HashMap<String, A>,
}
899
900impl<A: Float + Send + Sync> Default for DomainOptimizationConfig<A> {
901 fn default() -> Self {
902 Self {
903 optimizer_type: OptimizerType::Adam,
904 learning_rate: A::from(0.001).expect("unwrap failed"),
905 batch_size: 32,
906 gradient_clip_norm: Some(A::from(1.0).expect("unwrap failed")),
907 regularization_strength: A::from(1e-4).expect("unwrap failed"),
908 lr_schedule: LearningRateScheduleType::Constant,
909 specialized_params: HashMap::new(),
910 }
911 }
912}
913
/// A single actionable recommendation for a domain.
#[derive(Debug, Clone)]
pub struct DomainRecommendation<A: Float> {
    /// Category of the recommendation.
    pub recommendation_type: RecommendationType,
    /// Human-readable explanation.
    pub description: String,
    /// Confidence in the recommendation.
    pub confidence: A,
    /// Suggested action, as human-readable text.
    pub action: String,
}
926
/// Category tag for a [`DomainRecommendation`].
#[derive(Debug, Clone)]
pub enum RecommendationType {
    /// Baseline derived from historical performance in the domain.
    PerformanceBaseline,
    /// Suggestion to reuse knowledge from another domain.
    TransferLearning,
    /// Suggestion to tune hyperparameters.
    HyperparameterTuning,
    /// Suggestion to change the model architecture.
    ArchitectureChange,
    /// Suggestion to adjust resource usage.
    ResourceOptimization,
}
941
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adaptive_selection::ProblemType;

    // Construction seeds the config from the strategy's defaults:
    // CV ranks AdamW first.
    #[test]
    fn test_domain_specific_selector_creation() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let selector = DomainSpecificSelector::<f64>::new(strategy);
        assert_eq!(selector.config.optimizer_ranking[0], OptimizerType::AdamW);
    }

    // With all CV flags on and >16 GB memory, expect AdamW, batch 128,
    // and gradient clipping enabled.
    #[test]
    fn test_computer_vision_optimization() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 50000,
                // 224x224 RGB image flattened.
                input_dim: 224 * 224 * 3,
                output_dim: 1000,
                problem_type: ProblemType::ComputerVision,
                gradient_sparsity: 0.1,
                gradient_noise: 0.05,
                memory_budget: 8_000_000_000,
                time_budget: 3600.0,
                batch_size: 64,
                lr_sensitivity: 0.5,
                regularization_strength: 0.01,
                architecture_type: Some("ResNet".to_string()),
            },
            resource_constraints: ResourceConstraints {
                // Just above the 16 GB threshold -> batch size 128.
                max_memory: 17_000_000_000,
                max_time: 7200.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 100,
                early_stopping_patience: 10,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::CosineAnnealing { t_max: 100 },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert_eq!(config.batch_size, 128);
        assert!(config.gradient_clip_norm.is_some());
    }

    // NLP path: attention optimization selects AdamW, warmup is always
    // emitted, and the 50k vocabulary (>30k) triggers embedding tying.
    #[test]
    fn test_natural_language_optimization() {
        let strategy = DomainStrategy::NaturalLanguage {
            sequence_adaptive: true,
            attention_optimized: true,
            vocab_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 100000,
                // Exactly 512: NOT a "long sequence" (threshold is > 512).
                input_dim: 512,
                output_dim: 50000,
                problem_type: ProblemType::NaturalLanguage,
                gradient_sparsity: 0.2,
                gradient_noise: 0.1,
                memory_budget: 32_000_000_000,
                time_budget: 7200.0,
                batch_size: 32,
                lr_sensitivity: 0.8,
                regularization_strength: 0.1,
                architecture_type: Some("Transformer".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 32_000_000_000,
                max_time: 10800.0,
                gpu_available: true,
                distributed_capable: true,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 50,
                early_stopping_patience: 5,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::OneCycle { max_lr: 2e-5 },
                regularization_approach: RegularizationApproach::Dropout { dropout_rate: 0.1 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert!(config.specialized_params.contains_key("warmup_steps"));
        assert!(config.specialized_params.contains_key("tie_embeddings"));
    }

    // Time-series path: RMSprop, fixed batch 32, seasonality and
    // multi-step parameters present.
    #[test]
    fn test_time_series_optimization() {
        let strategy = DomainStrategy::TimeSeries {
            temporal_aware: true,
            seasonality_adaptive: true,
            multi_step: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 10000,
                // One week of hourly samples in, one day out.
                input_dim: 168,
                output_dim: 24,
                problem_type: ProblemType::TimeSeries,
                gradient_sparsity: 0.05,
                gradient_noise: 0.2,
                memory_budget: 4_000_000_000,
                time_budget: 1800.0,
                batch_size: 32,
                lr_sensitivity: 0.7,
                regularization_strength: 0.01,
                architecture_type: Some("LSTM".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 8_000_000_000,
                max_time: 3600.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: true,
            },
            training_config: TrainingConfiguration {
                max_epochs: 200,
                early_stopping_patience: 20,
                validation_frequency: 5,
                lr_schedule_type: LearningRateScheduleType::ReduceOnPlateau {
                    patience: 10,
                    factor: 0.5,
                },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::RMSprop);
        assert_eq!(config.batch_size, 32);
        assert!(config.specialized_params.contains_key("seasonal_periods"));
        assert!(config.specialized_params.contains_key("prediction_horizon"));
    }

    // Recording metrics produces a baseline recommendation whose
    // description embeds the average validation accuracy.
    #[test]
    fn test_performance_tracking() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: false,
            augmentation_aware: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let metrics = DomainPerformanceMetrics {
            validation_accuracy: 0.95,
            domain_specific_score: 0.92,
            stability_score: 0.88,
            convergence_epochs: 50,
            resource_efficiency: 0.85,
            transfer_score: 0.7,
        };

        selector.update_domain_performance("computer_vision".to_string(), metrics);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(!recommendations.is_empty());
        // Single-entry history -> average equals the recorded 0.95.
        assert!(recommendations[0].description.contains("0.95"));
    }

    // Transfer knowledge targeting a domain yields a TransferLearning
    // recommendation for that domain.
    #[test]
    fn test_cross_domain_transfer() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let transfer_knowledge = CrossDomainKnowledge {
            source_domain: "natural_language".to_string(),
            target_domain: "computer_vision".to_string(),
            transferable_params: HashMap::from([
                ("learning_rate".to_string(), 0.001),
                ("weight_decay".to_string(), 0.01),
            ]),
            transfer_score: 0.8,
            successful_strategy: OptimizerType::AdamW,
        };

        selector.record_transfer_knowledge(transfer_knowledge);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(recommendations
            .iter()
            .any(|r| matches!(r.recommendation_type, RecommendationType::TransferLearning)));
    }

    // Scientific-computing path: L-BFGS, no gradient clipping, and the
    // precision-critical convergence tolerance present.
    #[test]
    fn test_scientific_computing_optimization() {
        let strategy = DomainStrategy::ScientificComputing {
            stability_focused: true,
            precision_critical: true,
            sparse_optimized: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 1000,
                input_dim: 100,
                output_dim: 1,
                problem_type: ProblemType::Regression,
                gradient_sparsity: 0.01,
                gradient_noise: 0.01,
                memory_budget: 16_000_000_000,
                time_budget: 7200.0,
                batch_size: 100,
                lr_sensitivity: 0.3,
                regularization_strength: 1e-6,
                architecture_type: Some("MLP".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 16_000_000_000,
                max_time: 7200.0,
                gpu_available: false,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 1000,
                early_stopping_patience: 100,
                validation_frequency: 10,
                lr_schedule_type: LearningRateScheduleType::Constant,
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-6 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::LBFGS);
        assert!(config.gradient_clip_norm.is_none());
        assert!(config
            .specialized_params
            .contains_key("convergence_tolerance"));
    }
}