Skip to main content

optirs_core/domain_specific/
mod.rs

1// Domain-specific optimization strategies
2//
3// This module provides specialized optimization strategies tailored for different
4// machine learning domains, building on the adaptive selection framework to provide
5// domain-aware optimization approaches.
6
7use crate::adaptive_selection::{OptimizerType, ProblemCharacteristics};
8use crate::error::{OptimError, Result};
9use scirs2_core::ndarray::ScalarOperand;
10use scirs2_core::numeric::Float;
11use std::collections::HashMap;
12use std::fmt::Debug;
13
/// Domain-specific optimization strategy.
///
/// Each variant names a machine-learning domain and carries boolean feature
/// flags; the selector's `optimize_*` routines read these flags to decide
/// which specialized tuning steps to apply.
#[derive(Debug, Clone)]
pub enum DomainStrategy {
    /// Computer Vision optimization
    ComputerVision {
        /// Scale the learning rate with input image resolution
        resolution_adaptive: bool,
        /// Tune batch-normalization hyperparameters
        batch_norm_tuning: bool,
        /// Account for data augmentation (stronger regularization)
        augmentation_aware: bool,
    },
    /// Natural Language Processing optimization
    NaturalLanguage {
        /// Adapt learning rate / clipping to sequence length
        sequence_adaptive: bool,
        /// Attention mechanism optimization (transformer-oriented)
        attention_optimized: bool,
        /// Vocabulary size considerations (embedding tying, dropout)
        vocab_aware: bool,
    },
    /// Recommendation Systems optimization
    RecommendationSystems {
        /// Collaborative filtering optimization
        collaborative_filtering: bool,
        /// Matrix factorization tuning
        matrix_factorization: bool,
        /// Cold start handling (content weighting)
        cold_start_aware: bool,
    },
    /// Time Series optimization
    TimeSeries {
        /// Temporal dependency handling
        temporal_aware: bool,
        /// Seasonality consideration
        seasonality_adaptive: bool,
        /// Multi-step ahead optimization
        multi_step: bool,
    },
    /// Reinforcement Learning optimization
    ReinforcementLearning {
        /// Policy gradient optimization
        policy_gradient: bool,
        /// Value function optimization
        value_function: bool,
        /// Exploration-exploitation balance
        exploration_aware: bool,
    },
    /// Scientific Computing optimization
    ScientificComputing {
        /// Numerical stability prioritization
        stability_focused: bool,
        /// High precision requirements
        precision_critical: bool,
        /// Sparse matrix optimization
        sparse_optimized: bool,
    },
}
72
/// Domain-specific configuration parameters.
///
/// Acts as a bundle of per-domain defaults: base learning rate, batch-size
/// and clipping candidates, a regularization search range, and an ordered
/// list of preferred optimizers.
#[derive(Debug, Clone)]
pub struct DomainConfig<A: Float> {
    /// Base learning rate for the domain
    pub base_learning_rate: A,
    /// Batch size recommendations
    pub recommended_batch_sizes: Vec<usize>,
    /// Gradient clipping thresholds to consider
    pub gradient_clip_values: Vec<A>,
    /// Regularization strengths as a (min, max) range
    pub regularization_range: (A, A),
    /// Optimizer preferences (ranked by effectiveness)
    pub optimizer_ranking: Vec<OptimizerType>,
    /// Additional domain-specific hyperparameters keyed by name
    pub domain_params: HashMap<String, A>,
}
89
/// Domain-specific optimizer selector.
///
/// Holds the active strategy and its default configuration, plus accumulated
/// per-domain performance history and cross-domain transfer knowledge used
/// to produce recommendations.
#[derive(Debug)]
pub struct DomainSpecificSelector<A: Float> {
    /// Current domain strategy
    strategy: DomainStrategy,
    /// Domain configuration (defaults derived from the strategy)
    config: DomainConfig<A>,
    /// Performance history per domain, keyed by domain name
    domain_performance: HashMap<String, Vec<DomainPerformanceMetrics<A>>>,
    /// Cross-domain transfer learning data
    transfer_knowledge: Vec<CrossDomainKnowledge<A>>,
    /// Current optimization context (`None` until `setcontext` is called)
    // NOTE(review): conventional snake_case would be `current_context`;
    // renaming would touch every use site, so it is left as-is here.
    currentcontext: Option<OptimizationContext<A>>,
}
104
/// Performance metrics specific to domains.
#[derive(Debug, Clone)]
pub struct DomainPerformanceMetrics<A: Float> {
    /// Validation accuracy (used as the baseline performance measure
    /// when averaging history for recommendations)
    pub validation_accuracy: A,
    /// Domain-specific score (semantics depend on the domain)
    pub domain_specific_score: A,
    /// Training stability score
    pub stability_score: A,
    /// Convergence speed (number of epochs to reach the target)
    pub convergence_epochs: usize,
    /// Resource efficiency score
    pub resource_efficiency: A,
    /// Transfer learning potential score
    pub transfer_score: A,
}
121
/// Cross-domain knowledge transfer record.
#[derive(Debug, Clone)]
pub struct CrossDomainKnowledge<A: Float> {
    /// Source domain name
    pub source_domain: String,
    /// Target domain name (matched against queries in
    /// `get_domain_recommendations`)
    pub target_domain: String,
    /// Transferable hyperparameters keyed by name
    pub transferable_params: HashMap<String, A>,
    /// Transfer effectiveness score (also reused as recommendation confidence)
    pub transfer_score: A,
    /// Optimization strategy that worked in the source domain
    pub successful_strategy: OptimizerType,
}
136
/// Current optimization context.
///
/// Supplied via `DomainSpecificSelector::setcontext` and consulted by the
/// per-domain tuning routines.
#[derive(Debug, Clone)]
pub struct OptimizationContext<A: Float> {
    /// Problem characteristics (input/output dimensions, dataset size, ...)
    pub problem_chars: ProblemCharacteristics,
    /// Resource constraints
    pub resource_constraints: ResourceConstraints<A>,
    /// Training characteristics
    pub training_config: TrainingConfiguration<A>,
    /// Domain-specific metadata as free-form key/value strings
    pub domain_metadata: HashMap<String, String>,
}
149
/// Resource constraints for optimization.
#[derive(Debug, Clone)]
pub struct ResourceConstraints<A: Float> {
    /// Maximum memory available (bytes); drives batch-size selection
    pub max_memory: usize,
    /// Maximum training time (seconds)
    pub max_time: A,
    /// Whether a GPU is available
    pub gpu_available: bool,
    /// Distributed training capability
    pub distributed_capable: bool,
    /// Energy efficiency requirements
    pub energy_efficient: bool,
}
164
/// Training configuration parameters.
#[derive(Debug, Clone)]
pub struct TrainingConfiguration<A: Float> {
    /// Maximum number of epochs
    pub max_epochs: usize,
    /// Early stopping patience (epochs without improvement)
    pub early_stopping_patience: usize,
    /// Validation frequency (presumably in epochs — TODO confirm with callers)
    pub validation_frequency: usize,
    /// Learning rate scheduling policy
    pub lr_schedule_type: LearningRateScheduleType,
    /// Regularization approach
    pub regularization_approach: RegularizationApproach<A>,
}
179
/// Learning rate schedule types.
#[derive(Debug, Clone)]
pub enum LearningRateScheduleType {
    /// Constant learning rate
    Constant,
    /// Exponential decay
    ExponentialDecay {
        /// Multiplicative decay rate
        decay_rate: f64,
    },
    /// Cosine annealing
    CosineAnnealing {
        /// Maximum number of iterations (annealing period)
        t_max: usize,
    },
    /// Reduce on plateau
    ReduceOnPlateau {
        /// Number of epochs with no improvement before reducing
        patience: usize,
        /// Factor by which the learning rate will be reduced
        factor: f64,
    },
    /// One cycle policy
    OneCycle {
        /// Maximum (peak) learning rate
        max_lr: f64,
    },
}
208
/// Regularization approach.
#[derive(Debug, Clone)]
pub enum RegularizationApproach<A: Float> {
    /// L2 regularization only
    L2Only {
        /// Regularization weight
        weight: A,
    },
    /// L1 regularization only
    L1Only {
        /// Regularization weight
        weight: A,
    },
    /// Elastic net (L1 + L2)
    ElasticNet {
        /// L1 regularization weight
        l1_weight: A,
        /// L2 regularization weight
        l2_weight: A,
    },
    /// Dropout regularization
    Dropout {
        /// Dropout rate (fraction of units dropped)
        dropout_rate: A,
    },
    /// Combined approach
    Combined {
        /// L2 regularization weight
        l2_weight: A,
        /// Dropout rate
        dropout_rate: A,
        /// Names of additional regularization techniques
        additional_techniques: Vec<String>,
    },
}
244
245impl<A: Float + ScalarOperand + Debug + std::iter::Sum + Send + Sync> DomainSpecificSelector<A> {
246    /// Create a new domain-specific selector
247    pub fn new(strategy: DomainStrategy) -> Self {
248        let config = Self::default_config_for_strategy(&strategy);
249
250        Self {
251            strategy,
252            config,
253            domain_performance: HashMap::new(),
254            transfer_knowledge: Vec::new(),
255            currentcontext: None,
256        }
257    }
258
259    /// Set optimization context
260    pub fn setcontext(&mut self, context: OptimizationContext<A>) {
261        self.currentcontext = Some(context);
262    }
263
264    /// Select optimal configuration for the current domain and context
265    pub fn select_optimal_config(&mut self) -> Result<DomainOptimizationConfig<A>> {
266        let context = self
267            .currentcontext
268            .as_ref()
269            .ok_or_else(|| OptimError::InvalidConfig("No optimization context set".to_string()))?;
270
271        match &self.strategy {
272            DomainStrategy::ComputerVision {
273                resolution_adaptive,
274                batch_norm_tuning,
275                augmentation_aware,
276            } => self.optimize_computer_vision(
277                context,
278                *resolution_adaptive,
279                *batch_norm_tuning,
280                *augmentation_aware,
281            ),
282            DomainStrategy::NaturalLanguage {
283                sequence_adaptive,
284                attention_optimized,
285                vocab_aware,
286            } => self.optimize_natural_language(
287                context,
288                *sequence_adaptive,
289                *attention_optimized,
290                *vocab_aware,
291            ),
292            DomainStrategy::RecommendationSystems {
293                collaborative_filtering,
294                matrix_factorization,
295                cold_start_aware,
296            } => self.optimize_recommendation_systems(
297                context,
298                *collaborative_filtering,
299                *matrix_factorization,
300                *cold_start_aware,
301            ),
302            DomainStrategy::TimeSeries {
303                temporal_aware,
304                seasonality_adaptive,
305                multi_step,
306            } => self.optimize_time_series(
307                context,
308                *temporal_aware,
309                *seasonality_adaptive,
310                *multi_step,
311            ),
312            DomainStrategy::ReinforcementLearning {
313                policy_gradient,
314                value_function,
315                exploration_aware,
316            } => self.optimize_reinforcement_learning(
317                context,
318                *policy_gradient,
319                *value_function,
320                *exploration_aware,
321            ),
322            DomainStrategy::ScientificComputing {
323                stability_focused,
324                precision_critical,
325                sparse_optimized,
326            } => self.optimize_scientific_computing(
327                context,
328                *stability_focused,
329                *precision_critical,
330                *sparse_optimized,
331            ),
332        }
333    }
334
335    /// Optimize for computer vision tasks
336    fn optimize_computer_vision(
337        &self,
338        context: &OptimizationContext<A>,
339        resolution_adaptive: bool,
340        batch_norm_tuning: bool,
341        augmentation_aware: bool,
342    ) -> Result<DomainOptimizationConfig<A>> {
343        let mut config = DomainOptimizationConfig::default();
344
345        // Resolution-_adaptive optimization
346        if resolution_adaptive {
347            let resolution_factor = self.estimate_resolution_factor(&context.problem_chars);
348            config.learning_rate =
349                self.config.base_learning_rate * A::from(resolution_factor).expect("unwrap failed");
350
351            // Larger images need smaller learning rates
352            if context.problem_chars.input_dim > 512 * 512 {
353                config.learning_rate = config.learning_rate * A::from(0.5).expect("unwrap failed");
354            }
355        }
356
357        // Batch normalization _tuning
358        if batch_norm_tuning {
359            config.optimizer_type = OptimizerType::AdamW; // Better for batch norm
360            config.specialized_params.insert(
361                "batch_norm_momentum".to_string(),
362                A::from(0.99).expect("unwrap failed"),
363            );
364            config.specialized_params.insert(
365                "batch_norm_eps".to_string(),
366                A::from(1e-5).expect("unwrap failed"),
367            );
368        }
369
370        // Data augmentation awareness
371        if augmentation_aware {
372            // More aggressive regularization with augmentation
373            config.regularization_strength =
374                config.regularization_strength * A::from(1.5).expect("unwrap failed");
375            config.specialized_params.insert(
376                "mixup_alpha".to_string(),
377                A::from(0.2).expect("unwrap failed"),
378            );
379            config.specialized_params.insert(
380                "cutmix_alpha".to_string(),
381                A::from(1.0).expect("unwrap failed"),
382            );
383        }
384
385        // CV-specific optimizations
386        config.batch_size = self.select_cv_batch_size(&context.resource_constraints);
387        config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
388
389        // Use cosine annealing for CV tasks
390        config.lr_schedule = LearningRateScheduleType::CosineAnnealing {
391            t_max: context.training_config.max_epochs,
392        };
393
394        Ok(config)
395    }
396
397    /// Optimize for natural language processing tasks
398    fn optimize_natural_language(
399        &self,
400        context: &OptimizationContext<A>,
401        sequence_adaptive: bool,
402        attention_optimized: bool,
403        vocab_aware: bool,
404    ) -> Result<DomainOptimizationConfig<A>> {
405        let mut config = DomainOptimizationConfig::default();
406
407        // Sequence-_adaptive optimization
408        if sequence_adaptive {
409            let seq_length = context.problem_chars.input_dim; // Assuming input_dim represents sequence length
410
411            // Longer sequences need more careful optimization
412            if seq_length > 512 {
413                config.learning_rate =
414                    self.config.base_learning_rate * A::from(0.7).expect("unwrap failed");
415                config.gradient_clip_norm = Some(A::from(0.5).expect("unwrap failed"));
416            } else {
417                config.learning_rate = self.config.base_learning_rate;
418                config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
419            }
420        }
421
422        // Attention mechanism optimization
423        if attention_optimized {
424            config.optimizer_type = OptimizerType::AdamW; // Best for transformers
425            config.specialized_params.insert(
426                "attention_dropout".to_string(),
427                A::from(0.1).expect("unwrap failed"),
428            );
429            config.specialized_params.insert(
430                "attention_head_dim".to_string(),
431                A::from(64.0).expect("unwrap failed"),
432            );
433
434            // Layer-wise learning rate decay for transformers
435            config.specialized_params.insert(
436                "layer_decay_rate".to_string(),
437                A::from(0.95).expect("unwrap failed"),
438            );
439        }
440
441        // Vocabulary-_aware optimization
442        if vocab_aware {
443            let vocab_size = context.problem_chars.output_dim; // Assuming output_dim represents vocab size
444
445            // Large vocabularies need special handling
446            if vocab_size > 30000 {
447                config.specialized_params.insert(
448                    "tie_embeddings".to_string(),
449                    A::from(1.0).expect("unwrap failed"),
450                );
451                config.specialized_params.insert(
452                    "embedding_dropout".to_string(),
453                    A::from(0.1).expect("unwrap failed"),
454                );
455            }
456        }
457
458        // NLP-specific optimizations
459        config.batch_size = self.select_nlp_batch_size(&context.resource_constraints);
460        config.lr_schedule = LearningRateScheduleType::OneCycle {
461            max_lr: config.learning_rate.to_f64().expect("unwrap failed"),
462        };
463
464        // Warmup for transformers
465        config.specialized_params.insert(
466            "warmup_steps".to_string(),
467            A::from(1000.0).expect("unwrap failed"),
468        );
469
470        Ok(config)
471    }
472
473    /// Optimize for recommendation systems
474    fn optimize_recommendation_systems(
475        &self,
476        context: &OptimizationContext<A>,
477        collaborative_filtering: bool,
478        matrix_factorization: bool,
479        cold_start_aware: bool,
480    ) -> Result<DomainOptimizationConfig<A>> {
481        let mut config = DomainOptimizationConfig::default();
482
483        // Collaborative _filtering optimization
484        if collaborative_filtering {
485            config.optimizer_type = OptimizerType::Adam; // Good for sparse data
486            config.regularization_strength = A::from(0.01).expect("unwrap failed"); // Prevent overfitting
487            config.specialized_params.insert(
488                "negative_sampling_rate".to_string(),
489                A::from(5.0).expect("unwrap failed"),
490            );
491        }
492
493        // Matrix _factorization tuning
494        if matrix_factorization {
495            config.learning_rate = A::from(0.01).expect("unwrap failed"); // Lower LR for stability
496            config.specialized_params.insert(
497                "embedding_dim".to_string(),
498                A::from(128.0).expect("unwrap failed"),
499            );
500            config.specialized_params.insert(
501                "factorization_rank".to_string(),
502                A::from(50.0).expect("unwrap failed"),
503            );
504        }
505
506        // Cold start handling
507        if cold_start_aware {
508            config.specialized_params.insert(
509                "content_weight".to_string(),
510                A::from(0.3).expect("unwrap failed"),
511            );
512            config.specialized_params.insert(
513                "popularity_bias".to_string(),
514                A::from(0.1).expect("unwrap failed"),
515            );
516        }
517
518        // RecSys-specific optimizations
519        config.batch_size = self.select_recsys_batch_size(&context.resource_constraints);
520        config.gradient_clip_norm = Some(A::from(5.0).expect("unwrap failed")); // Higher clip for sparse gradients
521
522        Ok(config)
523    }
524
525    /// Optimize for time series tasks
526    fn optimize_time_series(
527        &self,
528        context: &OptimizationContext<A>,
529        temporal_aware: bool,
530        seasonality_adaptive: bool,
531        multi_step: bool,
532    ) -> Result<DomainOptimizationConfig<A>> {
533        let mut config = DomainOptimizationConfig::default();
534
535        // Temporal dependency handling
536        if temporal_aware {
537            config.optimizer_type = OptimizerType::RMSprop; // Good for RNNs
538            config.learning_rate = A::from(0.001).expect("unwrap failed"); // Conservative for temporal stability
539            config.specialized_params.insert(
540                "sequence_length".to_string(),
541                A::from(context.problem_chars.input_dim as f64).expect("unwrap failed"),
542            );
543        }
544
545        // Seasonality consideration
546        if seasonality_adaptive {
547            config.specialized_params.insert(
548                "seasonal_periods".to_string(),
549                A::from(24.0).expect("unwrap failed"),
550            ); // Daily pattern
551            config.specialized_params.insert(
552                "trend_strength".to_string(),
553                A::from(0.1).expect("unwrap failed"),
554            );
555        }
556
557        // Multi-_step ahead optimization
558        if multi_step {
559            config.specialized_params.insert(
560                "prediction_horizon".to_string(),
561                A::from(12.0).expect("unwrap failed"),
562            );
563            config.specialized_params.insert(
564                "multi_step_loss_weight".to_string(),
565                A::from(0.8).expect("unwrap failed"),
566            );
567        }
568
569        // Time series-specific optimizations
570        config.batch_size = 32; // Smaller batches for temporal consistency
571        config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
572        config.lr_schedule = LearningRateScheduleType::ReduceOnPlateau {
573            patience: 10,
574            factor: 0.5,
575        };
576
577        Ok(config)
578    }
579
580    /// Optimize for reinforcement learning tasks
581    fn optimize_reinforcement_learning(
582        &self,
583        context: &OptimizationContext<A>,
584        policy_gradient: bool,
585        value_function: bool,
586        exploration_aware: bool,
587    ) -> Result<DomainOptimizationConfig<A>> {
588        let mut config = DomainOptimizationConfig::default();
589
590        // Policy _gradient optimization
591        if policy_gradient {
592            config.optimizer_type = OptimizerType::Adam;
593            config.learning_rate = A::from(3e-4).expect("unwrap failed"); // Standard RL learning rate
594            config.specialized_params.insert(
595                "entropy_coeff".to_string(),
596                A::from(0.01).expect("unwrap failed"),
597            );
598        }
599
600        // Value _function optimization
601        if value_function {
602            config.specialized_params.insert(
603                "value_loss_coeff".to_string(),
604                A::from(0.5).expect("unwrap failed"),
605            );
606            config.specialized_params.insert(
607                "huber_loss_delta".to_string(),
608                A::from(1.0).expect("unwrap failed"),
609            );
610        }
611
612        // Exploration-exploitation balance
613        if exploration_aware {
614            config.specialized_params.insert(
615                "epsilon_start".to_string(),
616                A::from(1.0).expect("unwrap failed"),
617            );
618            config.specialized_params.insert(
619                "epsilon_end".to_string(),
620                A::from(0.1).expect("unwrap failed"),
621            );
622            config.specialized_params.insert(
623                "epsilon_decay".to_string(),
624                A::from(0.995).expect("unwrap failed"),
625            );
626        }
627
628        // RL-specific optimizations
629        config.batch_size = 64; // Standard RL batch size
630        config.gradient_clip_norm = Some(A::from(0.5).expect("unwrap failed")); // Important for RL stability
631        config.lr_schedule = LearningRateScheduleType::Constant; // Often constant in RL
632
633        Ok(config)
634    }
635
636    /// Optimize for scientific computing tasks
637    fn optimize_scientific_computing(
638        &self,
639        context: &OptimizationContext<A>,
640        stability_focused: bool,
641        precision_critical: bool,
642        sparse_optimized: bool,
643    ) -> Result<DomainOptimizationConfig<A>> {
644        let mut config = DomainOptimizationConfig::default();
645
646        // Numerical stability prioritization
647        if stability_focused {
648            config.optimizer_type = OptimizerType::LBFGS; // More stable for scientific problems
649            config.learning_rate = A::from(0.1).expect("unwrap failed"); // Higher LR for LBFGS
650            config.specialized_params.insert(
651                "line_search_tolerance".to_string(),
652                A::from(1e-6).expect("unwrap failed"),
653            );
654        }
655
656        // High precision requirements
657        if precision_critical {
658            config.specialized_params.insert(
659                "convergence_tolerance".to_string(),
660                A::from(1e-8).expect("unwrap failed"),
661            );
662            config.specialized_params.insert(
663                "max_iterations".to_string(),
664                A::from(1000.0).expect("unwrap failed"),
665            );
666        }
667
668        // Sparse matrix optimization
669        if sparse_optimized {
670            config.optimizer_type = OptimizerType::Adam;
671            config.specialized_params.insert(
672                "sparsity_threshold".to_string(),
673                A::from(1e-6).expect("unwrap failed"),
674            );
675        }
676
677        // Scientific computing-specific optimizations
678        config.batch_size = context.problem_chars.dataset_size.min(1024); // Can use larger batches
679        config.gradient_clip_norm = None; // Don't clip for scientific precision
680        config.lr_schedule = LearningRateScheduleType::Constant; // Consistent optimization
681
682        Ok(config)
683    }
684
685    /// Update performance based on training results
686    pub fn update_domain_performance(
687        &mut self,
688        domain: String,
689        metrics: DomainPerformanceMetrics<A>,
690    ) {
691        self.domain_performance
692            .entry(domain)
693            .or_default()
694            .push(metrics);
695    }
696
    /// Record cross-domain transfer knowledge.
    ///
    /// Appends `knowledge` to the in-memory log consulted by
    /// `get_domain_recommendations` when matching target domains.
    pub fn record_transfer_knowledge(&mut self, knowledge: CrossDomainKnowledge<A>) {
        self.transfer_knowledge.push(knowledge);
    }
701
702    /// Get domain-specific recommendations
703    pub fn get_domain_recommendations(&self, domain: &str) -> Vec<DomainRecommendation<A>> {
704        let mut recommendations = Vec::new();
705
706        // Analyze historical performance for this domain
707        if let Some(history) = self.domain_performance.get(domain) {
708            if !history.is_empty() {
709                let avg_performance = history.iter().map(|m| m.validation_accuracy).sum::<A>()
710                    / A::from(history.len()).expect("unwrap failed");
711
712                recommendations.push(DomainRecommendation {
713                    recommendation_type: RecommendationType::PerformanceBaseline,
714                    description: format!(
715                        "Historical average performance: {:.4}",
716                        avg_performance.to_f64().expect("unwrap failed")
717                    ),
718                    confidence: A::from(0.8).expect("unwrap failed"),
719                    action: "Consider this as baseline for improvements".to_string(),
720                });
721            }
722        }
723
724        // Cross-domain transfer recommendations
725        for knowledge in &self.transfer_knowledge {
726            if knowledge.target_domain == domain {
727                recommendations.push(DomainRecommendation {
728                    recommendation_type: RecommendationType::TransferLearning,
729                    description: format!(
730                        "Transfer from {} domain with {:.2} effectiveness",
731                        knowledge.source_domain,
732                        knowledge.transfer_score.to_f64().expect("unwrap failed")
733                    ),
734                    confidence: knowledge.transfer_score,
735                    action: format!("Use {:?} optimizer", knowledge.successful_strategy),
736                });
737            }
738        }
739
740        recommendations
741    }
742
743    /// Helper methods for domain-specific optimizations
744    fn estimate_resolution_factor(&self, problem_chars: &ProblemCharacteristics) -> f64 {
745        let resolution = problem_chars.input_dim as f64;
746
747        if resolution > 1_000_000.0 {
748            // Very high resolution
749            0.5
750        } else if resolution > 250_000.0 {
751            // High resolution
752            0.7
753        } else if resolution > 50_000.0 {
754            // Medium resolution
755            0.9
756        } else {
757            // Low resolution
758            1.0
759        }
760    }
761
762    fn select_cv_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
763        if constraints.max_memory > 16_000_000_000 {
764            // 16GB+
765            128
766        } else if constraints.max_memory > 8_000_000_000 {
767            // 8GB+
768            64
769        } else {
770            32
771        }
772    }
773
774    fn select_nlp_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
775        if constraints.max_memory > 32_000_000_000 {
776            // 32GB+
777            64
778        } else if constraints.max_memory > 16_000_000_000 {
779            // 16GB+
780            32
781        } else {
782            16
783        }
784    }
785
786    fn select_recsys_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
787        // RecSys can typically use larger batches due to simpler models
788        if constraints.max_memory > 8_000_000_000 {
789            512
790        } else {
791            256
792        }
793    }
794
795    /// Create default configuration for a strategy
796    fn default_config_for_strategy(strategy: &DomainStrategy) -> DomainConfig<A> {
797        match strategy {
798            DomainStrategy::ComputerVision { .. } => DomainConfig {
799                base_learning_rate: A::from(0.001).expect("unwrap failed"),
800                recommended_batch_sizes: vec![32, 64, 128],
801                gradient_clip_values: vec![
802                    A::from(1.0).expect("unwrap failed"),
803                    A::from(2.0).expect("unwrap failed"),
804                ],
805                regularization_range: (
806                    A::from(1e-5).expect("unwrap failed"),
807                    A::from(1e-2).expect("unwrap failed"),
808                ),
809                optimizer_ranking: vec![
810                    OptimizerType::AdamW,
811                    OptimizerType::SGDMomentum,
812                    OptimizerType::Adam,
813                ],
814                domain_params: HashMap::new(),
815            },
816            DomainStrategy::NaturalLanguage { .. } => DomainConfig {
817                base_learning_rate: A::from(2e-5).expect("unwrap failed"),
818                recommended_batch_sizes: vec![16, 32, 64],
819                gradient_clip_values: vec![
820                    A::from(0.5).expect("unwrap failed"),
821                    A::from(1.0).expect("unwrap failed"),
822                ],
823                regularization_range: (
824                    A::from(1e-4).expect("unwrap failed"),
825                    A::from(1e-1).expect("unwrap failed"),
826                ),
827                optimizer_ranking: vec![OptimizerType::AdamW, OptimizerType::Adam],
828                domain_params: HashMap::new(),
829            },
830            DomainStrategy::RecommendationSystems { .. } => DomainConfig {
831                base_learning_rate: A::from(0.01).expect("unwrap failed"),
832                recommended_batch_sizes: vec![128, 256, 512],
833                gradient_clip_values: vec![
834                    A::from(5.0).expect("unwrap failed"),
835                    A::from(10.0).expect("unwrap failed"),
836                ],
837                regularization_range: (
838                    A::from(1e-3).expect("unwrap failed"),
839                    A::from(1e-1).expect("unwrap failed"),
840                ),
841                optimizer_ranking: vec![OptimizerType::Adam, OptimizerType::AdaGrad],
842                domain_params: HashMap::new(),
843            },
844            DomainStrategy::TimeSeries { .. } => DomainConfig {
845                base_learning_rate: A::from(0.001).expect("unwrap failed"),
846                recommended_batch_sizes: vec![16, 32, 64],
847                gradient_clip_values: vec![A::from(1.0).expect("unwrap failed")],
848                regularization_range: (
849                    A::from(1e-4).expect("unwrap failed"),
850                    A::from(1e-2).expect("unwrap failed"),
851                ),
852                optimizer_ranking: vec![OptimizerType::RMSprop, OptimizerType::Adam],
853                domain_params: HashMap::new(),
854            },
855            DomainStrategy::ReinforcementLearning { .. } => DomainConfig {
856                base_learning_rate: A::from(3e-4).expect("unwrap failed"),
857                recommended_batch_sizes: vec![32, 64, 128],
858                gradient_clip_values: vec![A::from(0.5).expect("unwrap failed")],
859                regularization_range: (
860                    A::from(1e-4).expect("unwrap failed"),
861                    A::from(1e-2).expect("unwrap failed"),
862                ),
863                optimizer_ranking: vec![OptimizerType::Adam],
864                domain_params: HashMap::new(),
865            },
866            DomainStrategy::ScientificComputing { .. } => DomainConfig {
867                base_learning_rate: A::from(0.1).expect("unwrap failed"),
868                recommended_batch_sizes: vec![64, 128, 256, 512],
869                gradient_clip_values: vec![],
870                regularization_range: (
871                    A::from(1e-6).expect("unwrap failed"),
872                    A::from(1e-3).expect("unwrap failed"),
873                ),
874                optimizer_ranking: vec![OptimizerType::LBFGS, OptimizerType::Adam],
875                domain_params: HashMap::new(),
876            },
877        }
878    }
879}
880
/// Final domain optimization configuration
///
/// The fully resolved set of training hyperparameters produced by the
/// domain-specific selection process. All numeric values use the generic
/// float type `A`.
#[derive(Debug, Clone)]
pub struct DomainOptimizationConfig<A: Float> {
    /// Selected optimizer type
    pub optimizer_type: OptimizerType,
    /// Optimized learning rate
    pub learning_rate: A,
    /// Optimized batch size
    pub batch_size: usize,
    /// Gradient clipping norm (if applicable); `None` disables clipping,
    /// e.g. for precision-critical scientific-computing configurations
    pub gradient_clip_norm: Option<A>,
    /// Regularization strength
    pub regularization_strength: A,
    /// Learning rate schedule
    pub lr_schedule: LearningRateScheduleType,
    /// Domain-specific specialized parameters keyed by name
    /// (e.g. "warmup_steps", "seasonal_periods", "prediction_horizon")
    pub specialized_params: HashMap<String, A>,
}
899
900impl<A: Float + Send + Sync> Default for DomainOptimizationConfig<A> {
901    fn default() -> Self {
902        Self {
903            optimizer_type: OptimizerType::Adam,
904            learning_rate: A::from(0.001).expect("unwrap failed"),
905            batch_size: 32,
906            gradient_clip_norm: Some(A::from(1.0).expect("unwrap failed")),
907            regularization_strength: A::from(1e-4).expect("unwrap failed"),
908            lr_schedule: LearningRateScheduleType::Constant,
909            specialized_params: HashMap::new(),
910        }
911    }
912}
913
/// Domain-specific recommendation
///
/// Advisory output describing a suggested improvement for a given domain,
/// paired with a confidence score and a concrete suggested action.
#[derive(Debug, Clone)]
pub struct DomainRecommendation<A: Float> {
    /// Type of recommendation (see [`RecommendationType`])
    pub recommendation_type: RecommendationType,
    /// Human-readable description
    pub description: String,
    /// Confidence in the recommendation (0.0-1.0)
    pub confidence: A,
    /// Suggested action
    pub action: String,
}
926
/// Types of domain recommendations
///
/// Categorizes a recommendation so callers can filter or compare by kind.
/// Fieldless data enum, so the common traits (`Copy`, `PartialEq`, `Eq`,
/// `Hash`) are derived eagerly; this lets callers use `==` and hash-based
/// collections instead of falling back to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecommendationType {
    /// Performance baseline information
    PerformanceBaseline,
    /// Transfer learning suggestion
    TransferLearning,
    /// Hyperparameter adjustment
    HyperparameterTuning,
    /// Architecture modification
    ArchitectureChange,
    /// Resource optimization
    ResourceOptimization,
}
941
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adaptive_selection::ProblemType;

    // Constructing a computer-vision selector should install the CV base
    // config, whose optimizer ranking places AdamW first.
    #[test]
    fn test_domain_specific_selector_creation() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let selector = DomainSpecificSelector::<f64>::new(strategy);
        assert_eq!(selector.config.optimizer_ranking[0], OptimizerType::AdamW);
    }

    // End-to-end CV selection: an ImageNet-scale problem with ample memory
    // should pick AdamW, scale the batch size up to 128, and enable clipping.
    #[test]
    fn test_computer_vision_optimization() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 50000,
                input_dim: 224 * 224 * 3, // Standard ImageNet resolution
                output_dim: 1000,
                problem_type: ProblemType::ComputerVision,
                gradient_sparsity: 0.1,
                gradient_noise: 0.05,
                memory_budget: 8_000_000_000,
                time_budget: 3600.0,
                batch_size: 64,
                lr_sensitivity: 0.5,
                regularization_strength: 0.01,
                architecture_type: Some("ResNet".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 17_000_000_000, // Slightly above 16GB to trigger 128 batch size
                max_time: 7200.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 100,
                early_stopping_patience: 10,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::CosineAnnealing { t_max: 100 },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert_eq!(config.batch_size, 128); // Should select larger batch size for high memory
        assert!(config.gradient_clip_norm.is_some());
    }

    // NLP selection with a transformer + large vocabulary: expects AdamW plus
    // NLP-specific specialized params (warmup schedule, embedding tying).
    #[test]
    fn test_natural_language_optimization() {
        let strategy = DomainStrategy::NaturalLanguage {
            sequence_adaptive: true,
            attention_optimized: true,
            vocab_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 100000,
                input_dim: 512,    // Sequence length
                output_dim: 50000, // Large vocabulary
                problem_type: ProblemType::NaturalLanguage,
                gradient_sparsity: 0.2,
                gradient_noise: 0.1,
                memory_budget: 32_000_000_000,
                time_budget: 7200.0,
                batch_size: 32,
                lr_sensitivity: 0.8,
                regularization_strength: 0.1,
                architecture_type: Some("Transformer".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 32_000_000_000,
                max_time: 10800.0,
                gpu_available: true,
                distributed_capable: true,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 50,
                early_stopping_patience: 5,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::OneCycle { max_lr: 2e-5 },
                regularization_approach: RegularizationApproach::Dropout { dropout_rate: 0.1 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert!(config.specialized_params.contains_key("warmup_steps"));
        assert!(config.specialized_params.contains_key("tie_embeddings")); // Large vocab
    }

    // Time-series selection with seasonality + multi-step forecasting:
    // expects RMSprop and seasonal/horizon specialized params.
    #[test]
    fn test_time_series_optimization() {
        let strategy = DomainStrategy::TimeSeries {
            temporal_aware: true,
            seasonality_adaptive: true,
            multi_step: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 10000,
                input_dim: 168, // One week of hourly data
                output_dim: 24, // Next 24 hours
                problem_type: ProblemType::TimeSeries,
                gradient_sparsity: 0.05,
                gradient_noise: 0.2,
                memory_budget: 4_000_000_000,
                time_budget: 1800.0,
                batch_size: 32,
                lr_sensitivity: 0.7,
                regularization_strength: 0.01,
                architecture_type: Some("LSTM".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 8_000_000_000,
                max_time: 3600.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: true,
            },
            training_config: TrainingConfiguration {
                max_epochs: 200,
                early_stopping_patience: 20,
                validation_frequency: 5,
                lr_schedule_type: LearningRateScheduleType::ReduceOnPlateau {
                    patience: 10,
                    factor: 0.5,
                },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::RMSprop);
        assert_eq!(config.batch_size, 32);
        assert!(config.specialized_params.contains_key("seasonal_periods"));
        assert!(config.specialized_params.contains_key("prediction_horizon"));
    }

    // Recording performance metrics for a domain should surface at least one
    // recommendation whose description embeds the reported accuracy.
    #[test]
    fn test_performance_tracking() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: false,
            augmentation_aware: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let metrics = DomainPerformanceMetrics {
            validation_accuracy: 0.95,
            domain_specific_score: 0.92,
            stability_score: 0.88,
            convergence_epochs: 50,
            resource_efficiency: 0.85,
            transfer_score: 0.7,
        };

        selector.update_domain_performance("computer_vision".to_string(), metrics);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(!recommendations.is_empty());
        assert!(recommendations[0].description.contains("0.95"));
    }

    // Recording cross-domain transfer knowledge should yield a
    // TransferLearning recommendation for the target domain.
    #[test]
    fn test_cross_domain_transfer() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let transfer_knowledge = CrossDomainKnowledge {
            source_domain: "natural_language".to_string(),
            target_domain: "computer_vision".to_string(),
            transferable_params: HashMap::from([
                ("learning_rate".to_string(), 0.001),
                ("weight_decay".to_string(), 0.01),
            ]),
            transfer_score: 0.8,
            successful_strategy: OptimizerType::AdamW,
        };

        selector.record_transfer_knowledge(transfer_knowledge);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(recommendations
            .iter()
            .any(|r| matches!(r.recommendation_type, RecommendationType::TransferLearning)));
    }

    // Precision-critical scientific computing: expects L-BFGS, no gradient
    // clipping (to preserve precision), and a convergence-tolerance param.
    #[test]
    fn test_scientific_computing_optimization() {
        let strategy = DomainStrategy::ScientificComputing {
            stability_focused: true,
            precision_critical: true,
            sparse_optimized: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 1000,
                input_dim: 100,
                output_dim: 1,
                problem_type: ProblemType::Regression,
                gradient_sparsity: 0.01,
                gradient_noise: 0.01,
                memory_budget: 16_000_000_000,
                time_budget: 7200.0,
                batch_size: 100,
                lr_sensitivity: 0.3,
                regularization_strength: 1e-6,
                architecture_type: Some("MLP".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 16_000_000_000,
                max_time: 7200.0,
                gpu_available: false,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 1000,
                early_stopping_patience: 100,
                validation_frequency: 10,
                lr_schedule_type: LearningRateScheduleType::Constant,
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-6 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().expect("unwrap failed");

        assert_eq!(config.optimizer_type, OptimizerType::LBFGS);
        assert!(config.gradient_clip_norm.is_none()); // No clipping for precision
        assert!(config
            .specialized_params
            .contains_key("convergence_tolerance"));
    }
}