use crate::adaptive_selection::{OptimizerType, ProblemCharacteristics};
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;

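/// Application domains supported by the selector, together with the
/// domain-specific switches that shape how an optimizer configuration is
/// derived for each of them.
///
/// A minimal construction sketch (marked `ignore` because the surrounding
/// crate path is not assumed here):
///
/// ```ignore
/// let strategy = DomainStrategy::NaturalLanguage {
///     sequence_adaptive: true,
///     attention_optimized: true,
///     vocab_aware: false,
/// };
/// ```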
#[derive(Debug, Clone)]
pub enum DomainStrategy {
    ComputerVision {
        resolution_adaptive: bool,
        batch_norm_tuning: bool,
        augmentation_aware: bool,
    },
    NaturalLanguage {
        sequence_adaptive: bool,
        attention_optimized: bool,
        vocab_aware: bool,
    },
    RecommendationSystems {
        collaborative_filtering: bool,
        matrix_factorization: bool,
        cold_start_aware: bool,
    },
    TimeSeries {
        temporal_aware: bool,
        seasonality_adaptive: bool,
        multi_step: bool,
    },
    ReinforcementLearning {
        policy_gradient: bool,
        value_function: bool,
        exploration_aware: bool,
    },
    ScientificComputing {
        stability_focused: bool,
        precision_critical: bool,
        sparse_optimized: bool,
    },
}

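/// Per-domain defaults (base learning rate, recommended batch sizes, gradient
/// clipping values, regularization range, and optimizer ranking) as produced
/// by `DomainSpecificSelector::default_config_for_strategy`.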
#[derive(Debug, Clone)]
pub struct DomainConfig<A: Float> {
    pub base_learning_rate: A,
    pub recommended_batch_sizes: Vec<usize>,
    pub gradient_clip_values: Vec<A>,
    pub regularization_range: (A, A),
    pub optimizer_ranking: Vec<OptimizerType>,
    pub domain_params: HashMap<String, A>,
}

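/// Selects and tunes optimizer configurations for a specific application
/// domain, tracks per-domain performance history, and accumulates
/// cross-domain transfer knowledge.
///
/// A minimal usage sketch covering performance tracking and recommendations
/// (values are illustrative; marked `ignore` so it is not compiled as a
/// doctest, since the crate path is not assumed here):
///
/// ```ignore
/// let strategy = DomainStrategy::ComputerVision {
///     resolution_adaptive: true,
///     batch_norm_tuning: true,
///     augmentation_aware: false,
/// };
/// let mut selector = DomainSpecificSelector::<f64>::new(strategy);
///
/// // Record a finished training run and query domain recommendations.
/// selector.update_domain_performance(
///     "computer_vision".to_string(),
///     DomainPerformanceMetrics {
///         validation_accuracy: 0.93,
///         domain_specific_score: 0.90,
///         stability_score: 0.88,
///         convergence_epochs: 40,
///         resource_efficiency: 0.80,
///         transfer_score: 0.60,
///     },
/// );
/// let recommendations = selector.get_domain_recommendations("computer_vision");
/// assert!(!recommendations.is_empty());
/// ```
///
/// Deriving a full `DomainOptimizationConfig` additionally requires an
/// `OptimizationContext` (see `setcontext`, `select_optimal_config`, and the
/// tests at the bottom of this file).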
#[derive(Debug)]
pub struct DomainSpecificSelector<A: Float> {
    strategy: DomainStrategy,
    config: DomainConfig<A>,
    domain_performance: HashMap<String, Vec<DomainPerformanceMetrics<A>>>,
    transfer_knowledge: Vec<CrossDomainKnowledge<A>>,
    currentcontext: Option<OptimizationContext<A>>,
}

#[derive(Debug, Clone)]
pub struct DomainPerformanceMetrics<A: Float> {
    pub validation_accuracy: A,
    pub domain_specific_score: A,
    pub stability_score: A,
    pub convergence_epochs: usize,
    pub resource_efficiency: A,
    pub transfer_score: A,
}

#[derive(Debug, Clone)]
pub struct CrossDomainKnowledge<A: Float> {
    pub source_domain: String,
    pub target_domain: String,
    pub transferable_params: HashMap<String, A>,
    pub transfer_score: A,
    pub successful_strategy: OptimizerType,
}

#[derive(Debug, Clone)]
pub struct OptimizationContext<A: Float> {
    pub problem_chars: ProblemCharacteristics,
    pub resource_constraints: ResourceConstraints<A>,
    pub training_config: TrainingConfiguration<A>,
    pub domain_metadata: HashMap<String, String>,
}

#[derive(Debug, Clone)]
pub struct ResourceConstraints<A: Float> {
    pub max_memory: usize,
    pub max_time: A,
    pub gpu_available: bool,
    pub distributed_capable: bool,
    pub energy_efficient: bool,
}

#[derive(Debug, Clone)]
pub struct TrainingConfiguration<A: Float> {
    pub max_epochs: usize,
    pub early_stopping_patience: usize,
    pub validation_frequency: usize,
    pub lr_schedule_type: LearningRateScheduleType,
    pub regularization_approach: RegularizationApproach<A>,
}

#[derive(Debug, Clone)]
pub enum LearningRateScheduleType {
    Constant,
    ExponentialDecay {
        decay_rate: f64,
    },
    CosineAnnealing {
        t_max: usize,
    },
    ReduceOnPlateau {
        patience: usize,
        factor: f64,
    },
    OneCycle {
        max_lr: f64,
    },
}

#[derive(Debug, Clone)]
pub enum RegularizationApproach<A: Float> {
    L2Only {
        weight: A,
    },
    L1Only {
        weight: A,
    },
    ElasticNet {
        l1_weight: A,
        l2_weight: A,
    },
    Dropout {
        dropout_rate: A,
    },
    Combined {
        l2_weight: A,
        dropout_rate: A,
        additional_techniques: Vec<String>,
    },
}

impl<A: Float + ScalarOperand + Debug + std::iter::Sum + Send + Sync> DomainSpecificSelector<A> {
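    /// Creates a selector for `strategy`, seeded with the default
    /// `DomainConfig` returned by `default_config_for_strategy`.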
    pub fn new(strategy: DomainStrategy) -> Self {
        let config = Self::default_config_for_strategy(&strategy);

        Self {
            strategy,
            config,
            domain_performance: HashMap::new(),
            transfer_knowledge: Vec::new(),
            currentcontext: None,
        }
    }

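    /// Stores the optimization context that `select_optimal_config` reads
    /// from; it must be called before the first configuration is requested.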
    pub fn setcontext(&mut self, context: OptimizationContext<A>) {
        self.currentcontext = Some(context);
    }

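    /// Derives a `DomainOptimizationConfig` from the current strategy and the
    /// previously set context, dispatching to the matching `optimize_*`
    /// routine. Returns `OptimError::InvalidConfig` when no context has been
    /// set.
    ///
    /// A sketch of the error path only (full contexts are constructed in the
    /// tests at the bottom of this file; marked `ignore` because the crate
    /// path is not assumed here):
    ///
    /// ```ignore
    /// let mut selector = DomainSpecificSelector::<f64>::new(DomainStrategy::TimeSeries {
    ///     temporal_aware: true,
    ///     seasonality_adaptive: false,
    ///     multi_step: false,
    /// });
    /// assert!(selector.select_optimal_config().is_err()); // no context set yet
    /// ```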
    pub fn select_optimal_config(&mut self) -> Result<DomainOptimizationConfig<A>> {
        let context = self
            .currentcontext
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("No optimization context set".to_string()))?;

        match &self.strategy {
            DomainStrategy::ComputerVision {
                resolution_adaptive,
                batch_norm_tuning,
                augmentation_aware,
            } => self.optimize_computer_vision(
                context,
                *resolution_adaptive,
                *batch_norm_tuning,
                *augmentation_aware,
            ),
            DomainStrategy::NaturalLanguage {
                sequence_adaptive,
                attention_optimized,
                vocab_aware,
            } => self.optimize_natural_language(
                context,
                *sequence_adaptive,
                *attention_optimized,
                *vocab_aware,
            ),
            DomainStrategy::RecommendationSystems {
                collaborative_filtering,
                matrix_factorization,
                cold_start_aware,
            } => self.optimize_recommendation_systems(
                context,
                *collaborative_filtering,
                *matrix_factorization,
                *cold_start_aware,
            ),
            DomainStrategy::TimeSeries {
                temporal_aware,
                seasonality_adaptive,
                multi_step,
            } => self.optimize_time_series(
                context,
                *temporal_aware,
                *seasonality_adaptive,
                *multi_step,
            ),
            DomainStrategy::ReinforcementLearning {
                policy_gradient,
                value_function,
                exploration_aware,
            } => self.optimize_reinforcement_learning(
                context,
                *policy_gradient,
                *value_function,
                *exploration_aware,
            ),
            DomainStrategy::ScientificComputing {
                stability_focused,
                precision_critical,
                sparse_optimized,
            } => self.optimize_scientific_computing(
                context,
                *stability_focused,
                *precision_critical,
                *sparse_optimized,
            ),
        }
    }

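    /// Vision-specific tuning: scales the base learning rate down for large
    /// input resolutions, prefers AdamW with explicit batch-norm
    /// hyperparameters, strengthens regularization (mixup/cutmix) for
    /// augmentation-heavy pipelines, and uses cosine annealing over
    /// `max_epochs`.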
    fn optimize_computer_vision(
        &self,
        context: &OptimizationContext<A>,
        resolution_adaptive: bool,
        batch_norm_tuning: bool,
        augmentation_aware: bool,
    ) -> Result<DomainOptimizationConfig<A>> {
        let mut config = DomainOptimizationConfig::default();

        if resolution_adaptive {
            let resolution_factor = self.estimate_resolution_factor(&context.problem_chars);
            config.learning_rate =
                self.config.base_learning_rate * A::from(resolution_factor).unwrap();

            if context.problem_chars.input_dim > 512 * 512 {
                config.learning_rate = config.learning_rate * A::from(0.5).unwrap();
            }
        }

        if batch_norm_tuning {
            config.optimizer_type = OptimizerType::AdamW;
            config
                .specialized_params
                .insert("batch_norm_momentum".to_string(), A::from(0.99).unwrap());
            config
                .specialized_params
                .insert("batch_norm_eps".to_string(), A::from(1e-5).unwrap());
        }

        if augmentation_aware {
            config.regularization_strength =
                config.regularization_strength * A::from(1.5).unwrap();
            config
                .specialized_params
                .insert("mixup_alpha".to_string(), A::from(0.2).unwrap());
            config
                .specialized_params
                .insert("cutmix_alpha".to_string(), A::from(1.0).unwrap());
        }

        config.batch_size = self.select_cv_batch_size(&context.resource_constraints);
        config.gradient_clip_norm = Some(A::from(1.0).unwrap());

        config.lr_schedule = LearningRateScheduleType::CosineAnnealing {
            t_max: context.training_config.max_epochs,
        };

        Ok(config)
    }

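    /// NLP-specific tuning: lowers the learning rate and tightens gradient
    /// clipping for sequences longer than 512, prefers AdamW with attention
    /// dropout and layer-wise decay, ties embeddings for vocabularies above
    /// 30k, and uses a one-cycle schedule with warmup steps.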
    fn optimize_natural_language(
        &self,
        context: &OptimizationContext<A>,
        sequence_adaptive: bool,
        attention_optimized: bool,
        vocab_aware: bool,
    ) -> Result<DomainOptimizationConfig<A>> {
        let mut config = DomainOptimizationConfig::default();

        if sequence_adaptive {
            let seq_length = context.problem_chars.input_dim;
            if seq_length > 512 {
                config.learning_rate = self.config.base_learning_rate * A::from(0.7).unwrap();
                config.gradient_clip_norm = Some(A::from(0.5).unwrap());
            } else {
                config.learning_rate = self.config.base_learning_rate;
                config.gradient_clip_norm = Some(A::from(1.0).unwrap());
            }
        }

        if attention_optimized {
            config.optimizer_type = OptimizerType::AdamW;
            config
                .specialized_params
                .insert("attention_dropout".to_string(), A::from(0.1).unwrap());
            config
                .specialized_params
                .insert("attention_head_dim".to_string(), A::from(64.0).unwrap());

            config
                .specialized_params
                .insert("layer_decay_rate".to_string(), A::from(0.95).unwrap());
        }

        if vocab_aware {
            let vocab_size = context.problem_chars.output_dim;
            if vocab_size > 30000 {
                config
                    .specialized_params
                    .insert("tie_embeddings".to_string(), A::from(1.0).unwrap());
                config
                    .specialized_params
                    .insert("embedding_dropout".to_string(), A::from(0.1).unwrap());
            }
        }

        config.batch_size = self.select_nlp_batch_size(&context.resource_constraints);
        config.lr_schedule = LearningRateScheduleType::OneCycle {
            max_lr: config.learning_rate.to_f64().unwrap(),
        };

        config
            .specialized_params
            .insert("warmup_steps".to_string(), A::from(1000.0).unwrap());

        Ok(config)
    }

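    /// Recommendation-specific tuning: Adam with negative sampling for
    /// collaborative filtering, embedding and factorization sizes for matrix
    /// factorization, content/popularity weights for cold-start handling, and
    /// a comparatively loose gradient clip.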
    fn optimize_recommendation_systems(
        &self,
        context: &OptimizationContext<A>,
        collaborative_filtering: bool,
        matrix_factorization: bool,
        cold_start_aware: bool,
    ) -> Result<DomainOptimizationConfig<A>> {
        let mut config = DomainOptimizationConfig::default();

        if collaborative_filtering {
            config.optimizer_type = OptimizerType::Adam;
            config.regularization_strength = A::from(0.01).unwrap();
            config
                .specialized_params
                .insert("negative_sampling_rate".to_string(), A::from(5.0).unwrap());
        }

        if matrix_factorization {
            config.learning_rate = A::from(0.01).unwrap();
            config
                .specialized_params
                .insert("embedding_dim".to_string(), A::from(128.0).unwrap());
            config
                .specialized_params
                .insert("factorization_rank".to_string(), A::from(50.0).unwrap());
        }

        if cold_start_aware {
            config
                .specialized_params
                .insert("content_weight".to_string(), A::from(0.3).unwrap());
            config
                .specialized_params
                .insert("popularity_bias".to_string(), A::from(0.1).unwrap());
        }

        config.batch_size = self.select_recsys_batch_size(&context.resource_constraints);
        config.gradient_clip_norm = Some(A::from(5.0).unwrap());

        Ok(config)
    }

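    /// Time-series tuning: RMSprop with a conservative learning rate,
    /// sequence-length and seasonality parameters, multi-step forecasting
    /// weights, and a reduce-on-plateau schedule.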
    fn optimize_time_series(
        &self,
        context: &OptimizationContext<A>,
        temporal_aware: bool,
        seasonality_adaptive: bool,
        multi_step: bool,
    ) -> Result<DomainOptimizationConfig<A>> {
        let mut config = DomainOptimizationConfig::default();

        if temporal_aware {
            config.optimizer_type = OptimizerType::RMSprop;
            config.learning_rate = A::from(0.001).unwrap();
            config.specialized_params.insert(
                "sequence_length".to_string(),
                A::from(context.problem_chars.input_dim as f64).unwrap(),
            );
        }

        if seasonality_adaptive {
            config
                .specialized_params
                .insert("seasonal_periods".to_string(), A::from(24.0).unwrap());
            config
                .specialized_params
                .insert("trend_strength".to_string(), A::from(0.1).unwrap());
        }

        if multi_step {
            config
                .specialized_params
                .insert("prediction_horizon".to_string(), A::from(12.0).unwrap());
            config
                .specialized_params
                .insert("multi_step_loss_weight".to_string(), A::from(0.8).unwrap());
        }

        config.batch_size = 32;
        config.gradient_clip_norm = Some(A::from(1.0).unwrap());
        config.lr_schedule = LearningRateScheduleType::ReduceOnPlateau {
            patience: 10,
            factor: 0.5,
        };

        Ok(config)
    }

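    /// RL tuning: Adam at 3e-4 with an entropy bonus for policy gradients,
    /// value-loss and Huber-loss coefficients for value functions, an
    /// epsilon-greedy exploration schedule, tight gradient clipping, and a
    /// constant learning rate.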
    fn optimize_reinforcement_learning(
        &self,
        context: &OptimizationContext<A>,
        policy_gradient: bool,
        value_function: bool,
        exploration_aware: bool,
    ) -> Result<DomainOptimizationConfig<A>> {
        let mut config = DomainOptimizationConfig::default();

        if policy_gradient {
            config.optimizer_type = OptimizerType::Adam;
            config.learning_rate = A::from(3e-4).unwrap();
            config
                .specialized_params
                .insert("entropy_coeff".to_string(), A::from(0.01).unwrap());
        }

        if value_function {
            config
                .specialized_params
                .insert("value_loss_coeff".to_string(), A::from(0.5).unwrap());
            config
                .specialized_params
                .insert("huber_loss_delta".to_string(), A::from(1.0).unwrap());
        }

        if exploration_aware {
            config
                .specialized_params
                .insert("epsilon_start".to_string(), A::from(1.0).unwrap());
            config
                .specialized_params
                .insert("epsilon_end".to_string(), A::from(0.1).unwrap());
            config
                .specialized_params
                .insert("epsilon_decay".to_string(), A::from(0.995).unwrap());
        }

        config.batch_size = 64;
        config.gradient_clip_norm = Some(A::from(0.5).unwrap());
        config.lr_schedule = LearningRateScheduleType::Constant;

        Ok(config)
    }

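    /// Scientific-computing tuning: L-BFGS with a line-search tolerance when
    /// stability is the priority, tight convergence tolerances for
    /// precision-critical problems, Adam with a sparsity threshold for sparse
    /// workloads, near full-batch sizes (capped at 1024), and no gradient
    /// clipping.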
    fn optimize_scientific_computing(
        &self,
        context: &OptimizationContext<A>,
        stability_focused: bool,
        precision_critical: bool,
        sparse_optimized: bool,
    ) -> Result<DomainOptimizationConfig<A>> {
        let mut config = DomainOptimizationConfig::default();

        if stability_focused {
            config.optimizer_type = OptimizerType::LBFGS;
            config.learning_rate = A::from(0.1).unwrap();
            config
                .specialized_params
                .insert("line_search_tolerance".to_string(), A::from(1e-6).unwrap());
        }

        if precision_critical {
            config
                .specialized_params
                .insert("convergence_tolerance".to_string(), A::from(1e-8).unwrap());
            config
                .specialized_params
                .insert("max_iterations".to_string(), A::from(1000.0).unwrap());
        }

        if sparse_optimized {
            config.optimizer_type = OptimizerType::Adam;
            config
                .specialized_params
                .insert("sparsity_threshold".to_string(), A::from(1e-6).unwrap());
        }

        config.batch_size = context.problem_chars.dataset_size.min(1024);
        config.gradient_clip_norm = None;
        config.lr_schedule = LearningRateScheduleType::Constant;

        Ok(config)
    }

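    /// Appends the metrics of a finished run to the per-domain performance
    /// history used by `get_domain_recommendations`.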
    pub fn update_domain_performance(
        &mut self,
        domain: String,
        metrics: DomainPerformanceMetrics<A>,
    ) {
        self.domain_performance
            .entry(domain)
            .or_default()
            .push(metrics);
    }

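    /// Records knowledge transferred from another domain so that it can be
    /// surfaced as a recommendation for its target domain.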
    pub fn record_transfer_knowledge(&mut self, knowledge: CrossDomainKnowledge<A>) {
        self.transfer_knowledge.push(knowledge);
    }

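    /// Builds recommendations for `domain` from the recorded performance
    /// history (a baseline entry) and from any transfer knowledge targeting
    /// that domain (one entry per matching source).
    ///
    /// A short sketch mirroring the cross-domain test below (marked `ignore`
    /// because the crate path is not assumed here):
    ///
    /// ```ignore
    /// let mut selector = DomainSpecificSelector::<f64>::new(DomainStrategy::ComputerVision {
    ///     resolution_adaptive: true,
    ///     batch_norm_tuning: false,
    ///     augmentation_aware: false,
    /// });
    /// selector.record_transfer_knowledge(CrossDomainKnowledge {
    ///     source_domain: "natural_language".to_string(),
    ///     target_domain: "computer_vision".to_string(),
    ///     transferable_params: HashMap::new(),
    ///     transfer_score: 0.8,
    ///     successful_strategy: OptimizerType::AdamW,
    /// });
    /// let recs = selector.get_domain_recommendations("computer_vision");
    /// assert!(recs
    ///     .iter()
    ///     .any(|r| matches!(r.recommendation_type, RecommendationType::TransferLearning)));
    /// ```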
    pub fn get_domain_recommendations(&self, domain: &str) -> Vec<DomainRecommendation<A>> {
        let mut recommendations = Vec::new();

        if let Some(history) = self.domain_performance.get(domain) {
            if !history.is_empty() {
                let avg_performance = history.iter().map(|m| m.validation_accuracy).sum::<A>()
                    / A::from(history.len()).unwrap();

                recommendations.push(DomainRecommendation {
                    recommendation_type: RecommendationType::PerformanceBaseline,
                    description: format!(
                        "Historical average performance: {:.4}",
                        avg_performance.to_f64().unwrap()
                    ),
                    confidence: A::from(0.8).unwrap(),
                    action: "Consider this as baseline for improvements".to_string(),
                });
            }
        }

        for knowledge in &self.transfer_knowledge {
            if knowledge.target_domain == domain {
                recommendations.push(DomainRecommendation {
                    recommendation_type: RecommendationType::TransferLearning,
                    description: format!(
                        "Transfer from {} domain with {:.2} effectiveness",
                        knowledge.source_domain,
                        knowledge.transfer_score.to_f64().unwrap()
                    ),
                    confidence: knowledge.transfer_score,
                    action: format!("Use {:?} optimizer", knowledge.successful_strategy),
                });
            }
        }

        recommendations
    }

    fn estimate_resolution_factor(&self, problem_chars: &ProblemCharacteristics) -> f64 {
        let resolution = problem_chars.input_dim as f64;

        if resolution > 1_000_000.0 {
            0.5
        } else if resolution > 250_000.0 {
            0.7
        } else if resolution > 50_000.0 {
            0.9
        } else {
            1.0
        }
    }

    fn select_cv_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
        if constraints.max_memory > 16_000_000_000 {
            128
        } else if constraints.max_memory > 8_000_000_000 {
            64
        } else {
            32
        }
    }

    fn select_nlp_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
        if constraints.max_memory > 32_000_000_000 {
            64
        } else if constraints.max_memory > 16_000_000_000 {
            32
        } else {
            16
        }
    }

    fn select_recsys_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
        if constraints.max_memory > 8_000_000_000 {
            512
        } else {
            256
        }
    }

    fn default_config_for_strategy(strategy: &DomainStrategy) -> DomainConfig<A> {
        match strategy {
            DomainStrategy::ComputerVision { .. } => DomainConfig {
                base_learning_rate: A::from(0.001).unwrap(),
                recommended_batch_sizes: vec![32, 64, 128],
                gradient_clip_values: vec![A::from(1.0).unwrap(), A::from(2.0).unwrap()],
                regularization_range: (A::from(1e-5).unwrap(), A::from(1e-2).unwrap()),
                optimizer_ranking: vec![
                    OptimizerType::AdamW,
                    OptimizerType::SGDMomentum,
                    OptimizerType::Adam,
                ],
                domain_params: HashMap::new(),
            },
            DomainStrategy::NaturalLanguage { .. } => DomainConfig {
                base_learning_rate: A::from(2e-5).unwrap(),
                recommended_batch_sizes: vec![16, 32, 64],
                gradient_clip_values: vec![A::from(0.5).unwrap(), A::from(1.0).unwrap()],
                regularization_range: (A::from(1e-4).unwrap(), A::from(1e-1).unwrap()),
                optimizer_ranking: vec![OptimizerType::AdamW, OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
            DomainStrategy::RecommendationSystems { .. } => DomainConfig {
                base_learning_rate: A::from(0.01).unwrap(),
                recommended_batch_sizes: vec![128, 256, 512],
                gradient_clip_values: vec![A::from(5.0).unwrap(), A::from(10.0).unwrap()],
                regularization_range: (A::from(1e-3).unwrap(), A::from(1e-1).unwrap()),
                optimizer_ranking: vec![OptimizerType::Adam, OptimizerType::AdaGrad],
                domain_params: HashMap::new(),
            },
            DomainStrategy::TimeSeries { .. } => DomainConfig {
                base_learning_rate: A::from(0.001).unwrap(),
                recommended_batch_sizes: vec![16, 32, 64],
                gradient_clip_values: vec![A::from(1.0).unwrap()],
                regularization_range: (A::from(1e-4).unwrap(), A::from(1e-2).unwrap()),
                optimizer_ranking: vec![OptimizerType::RMSprop, OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
            DomainStrategy::ReinforcementLearning { .. } => DomainConfig {
                base_learning_rate: A::from(3e-4).unwrap(),
                recommended_batch_sizes: vec![32, 64, 128],
                gradient_clip_values: vec![A::from(0.5).unwrap()],
                regularization_range: (A::from(1e-4).unwrap(), A::from(1e-2).unwrap()),
                optimizer_ranking: vec![OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
            DomainStrategy::ScientificComputing { .. } => DomainConfig {
                base_learning_rate: A::from(0.1).unwrap(),
                recommended_batch_sizes: vec![64, 128, 256, 512],
                gradient_clip_values: vec![],
                regularization_range: (A::from(1e-6).unwrap(), A::from(1e-3).unwrap()),
                optimizer_ranking: vec![OptimizerType::LBFGS, OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
        }
    }
}

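/// The concrete optimizer configuration produced by
/// `DomainSpecificSelector::select_optimal_config`. Every `optimize_*` routine
/// starts from `Default::default()` and then applies its domain-specific
/// adjustments.
///
/// A small sketch of the baseline values (marked `ignore` because the crate
/// path is not assumed here):
///
/// ```ignore
/// let base = DomainOptimizationConfig::<f64>::default();
/// assert_eq!(base.optimizer_type, OptimizerType::Adam);
/// assert_eq!(base.batch_size, 32);
/// assert!(base.gradient_clip_norm.is_some());
/// ```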
#[derive(Debug, Clone)]
pub struct DomainOptimizationConfig<A: Float> {
    pub optimizer_type: OptimizerType,
    pub learning_rate: A,
    pub batch_size: usize,
    pub gradient_clip_norm: Option<A>,
    pub regularization_strength: A,
    pub lr_schedule: LearningRateScheduleType,
    pub specialized_params: HashMap<String, A>,
}

impl<A: Float + Send + Sync> Default for DomainOptimizationConfig<A> {
    fn default() -> Self {
        Self {
            optimizer_type: OptimizerType::Adam,
            learning_rate: A::from(0.001).unwrap(),
            batch_size: 32,
            gradient_clip_norm: Some(A::from(1.0).unwrap()),
            regularization_strength: A::from(1e-4).unwrap(),
            lr_schedule: LearningRateScheduleType::Constant,
            specialized_params: HashMap::new(),
        }
    }
}

#[derive(Debug, Clone)]
pub struct DomainRecommendation<A: Float> {
    pub recommendation_type: RecommendationType,
    pub description: String,
    pub confidence: A,
    pub action: String,
}

#[derive(Debug, Clone)]
pub enum RecommendationType {
    PerformanceBaseline,
    TransferLearning,
    HyperparameterTuning,
    ArchitectureChange,
    ResourceOptimization,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::adaptive_selection::ProblemType;

    #[test]
    fn test_domain_specific_selector_creation() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let selector = DomainSpecificSelector::<f64>::new(strategy);
        assert_eq!(selector.config.optimizer_ranking[0], OptimizerType::AdamW);
    }

    #[test]
    fn test_computer_vision_optimization() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 50000,
                input_dim: 224 * 224 * 3,
                output_dim: 1000,
                problem_type: ProblemType::ComputerVision,
                gradient_sparsity: 0.1,
                gradient_noise: 0.05,
                memory_budget: 8_000_000_000,
                time_budget: 3600.0,
                batch_size: 64,
                lr_sensitivity: 0.5,
                regularization_strength: 0.01,
                architecture_type: Some("ResNet".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 17_000_000_000,
                max_time: 7200.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 100,
                early_stopping_patience: 10,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::CosineAnnealing { t_max: 100 },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert_eq!(config.batch_size, 128);
        assert!(config.gradient_clip_norm.is_some());
    }

    #[test]
    fn test_natural_language_optimization() {
        let strategy = DomainStrategy::NaturalLanguage {
            sequence_adaptive: true,
            attention_optimized: true,
            vocab_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 100000,
                input_dim: 512,
                output_dim: 50000,
                problem_type: ProblemType::NaturalLanguage,
                gradient_sparsity: 0.2,
                gradient_noise: 0.1,
                memory_budget: 32_000_000_000,
                time_budget: 7200.0,
                batch_size: 32,
                lr_sensitivity: 0.8,
                regularization_strength: 0.1,
                architecture_type: Some("Transformer".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 32_000_000_000,
                max_time: 10800.0,
                gpu_available: true,
                distributed_capable: true,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 50,
                early_stopping_patience: 5,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::OneCycle { max_lr: 2e-5 },
                regularization_approach: RegularizationApproach::Dropout { dropout_rate: 0.1 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert!(config.specialized_params.contains_key("warmup_steps"));
        assert!(config.specialized_params.contains_key("tie_embeddings"));
    }

    #[test]
    fn test_time_series_optimization() {
        let strategy = DomainStrategy::TimeSeries {
            temporal_aware: true,
            seasonality_adaptive: true,
            multi_step: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 10000,
                input_dim: 168,
                output_dim: 24,
                problem_type: ProblemType::TimeSeries,
                gradient_sparsity: 0.05,
                gradient_noise: 0.2,
                memory_budget: 4_000_000_000,
                time_budget: 1800.0,
                batch_size: 32,
                lr_sensitivity: 0.7,
                regularization_strength: 0.01,
                architecture_type: Some("LSTM".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 8_000_000_000,
                max_time: 3600.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: true,
            },
            training_config: TrainingConfiguration {
                max_epochs: 200,
                early_stopping_patience: 20,
                validation_frequency: 5,
                lr_schedule_type: LearningRateScheduleType::ReduceOnPlateau {
                    patience: 10,
                    factor: 0.5,
                },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::RMSprop);
        assert_eq!(config.batch_size, 32);
        assert!(config.specialized_params.contains_key("seasonal_periods"));
        assert!(config.specialized_params.contains_key("prediction_horizon"));
    }

    #[test]
    fn test_performance_tracking() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: false,
            augmentation_aware: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let metrics = DomainPerformanceMetrics {
            validation_accuracy: 0.95,
            domain_specific_score: 0.92,
            stability_score: 0.88,
            convergence_epochs: 50,
            resource_efficiency: 0.85,
            transfer_score: 0.7,
        };

        selector.update_domain_performance("computer_vision".to_string(), metrics);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(!recommendations.is_empty());
        assert!(recommendations[0].description.contains("0.95"));
    }

    #[test]
    fn test_cross_domain_transfer() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let transfer_knowledge = CrossDomainKnowledge {
            source_domain: "natural_language".to_string(),
            target_domain: "computer_vision".to_string(),
            transferable_params: HashMap::from([
                ("learning_rate".to_string(), 0.001),
                ("weight_decay".to_string(), 0.01),
            ]),
            transfer_score: 0.8,
            successful_strategy: OptimizerType::AdamW,
        };

        selector.record_transfer_knowledge(transfer_knowledge);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(recommendations
            .iter()
            .any(|r| matches!(r.recommendation_type, RecommendationType::TransferLearning)));
    }

    #[test]
    fn test_scientific_computing_optimization() {
        let strategy = DomainStrategy::ScientificComputing {
            stability_focused: true,
            precision_critical: true,
            sparse_optimized: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 1000,
                input_dim: 100,
                output_dim: 1,
                problem_type: ProblemType::Regression,
                gradient_sparsity: 0.01,
                gradient_noise: 0.01,
                memory_budget: 16_000_000_000,
                time_budget: 7200.0,
                batch_size: 100,
                lr_sensitivity: 0.3,
                regularization_strength: 1e-6,
                architecture_type: Some("MLP".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 16_000_000_000,
                max_time: 7200.0,
                gpu_available: false,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 1000,
                early_stopping_patience: 100,
                validation_frequency: 10,
                lr_schedule_type: LearningRateScheduleType::Constant,
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-6 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::LBFGS);
        assert!(config.gradient_clip_norm.is_none());
        assert!(config
            .specialized_params
            .contains_key("convergence_tolerance"));
    }
}