// optirs_core/domain_specific/mod.rs
//
// Domain-specific optimization strategies
//
// This module provides specialized optimization strategies tailored for different
// machine learning domains, building on the adaptive selection framework to provide
// domain-aware optimization approaches.
7use crate::adaptive_selection::{OptimizerType, ProblemCharacteristics};
8use crate::error::{OptimError, Result};
9use scirs2_core::ndarray::ScalarOperand;
10use scirs2_core::numeric::Float;
11use std::collections::HashMap;
12use std::fmt::Debug;
13
/// Domain-specific optimization strategy
///
/// Each variant identifies a machine-learning domain and carries boolean
/// switches that enable that domain's specialized tuning behaviors inside
/// `DomainSpecificSelector::select_optimal_config`.
#[derive(Debug, Clone)]
pub enum DomainStrategy {
    /// Computer Vision optimization
    ComputerVision {
        /// Image resolution considerations (scales the learning rate by input size)
        resolution_adaptive: bool,
        /// Batch normalization optimization (selects AdamW and BN hyperparameters)
        batch_norm_tuning: bool,
        /// Data augmentation awareness (stronger regularization, mixup/cutmix)
        augmentation_aware: bool,
    },
    /// Natural Language Processing optimization
    NaturalLanguage {
        /// Sequence length adaptation (reduced LR / tighter clipping for long sequences)
        sequence_adaptive: bool,
        /// Attention mechanism optimization (AdamW, dropout, layer-wise LR decay)
        attention_optimized: bool,
        /// Vocabulary size considerations (embedding tying/dropout for large vocabularies)
        vocab_aware: bool,
    },
    /// Recommendation Systems optimization
    RecommendationSystems {
        /// Collaborative filtering optimization
        collaborative_filtering: bool,
        /// Matrix factorization tuning
        matrix_factorization: bool,
        /// Cold start handling
        cold_start_aware: bool,
    },
    /// Time Series optimization
    TimeSeries {
        /// Temporal dependency handling
        temporal_aware: bool,
        /// Seasonality consideration
        seasonality_adaptive: bool,
        /// Multi-step ahead optimization
        multi_step: bool,
    },
    /// Reinforcement Learning optimization
    ReinforcementLearning {
        /// Policy gradient optimization
        policy_gradient: bool,
        /// Value function optimization
        value_function: bool,
        /// Exploration-exploitation balance
        exploration_aware: bool,
    },
    /// Scientific Computing optimization
    ScientificComputing {
        /// Numerical stability prioritization
        stability_focused: bool,
        /// High precision requirements
        precision_critical: bool,
        /// Sparse matrix optimization
        sparse_optimized: bool,
    },
}
72
/// Domain-specific configuration parameters
///
/// Per-domain defaults produced by
/// `DomainSpecificSelector::default_config_for_strategy`; these seed the
/// optimization process before context-specific tuning is applied.
#[derive(Debug, Clone)]
pub struct DomainConfig<A: Float> {
    /// Base learning rate for the domain
    pub base_learning_rate: A,
    /// Batch size recommendations
    pub recommended_batch_sizes: Vec<usize>,
    /// Gradient clipping thresholds (may be empty when clipping is discouraged)
    pub gradient_clip_values: Vec<A>,
    /// Regularization strengths as a (min, max) range
    pub regularization_range: (A, A),
    /// Optimizer preferences (ranked by effectiveness)
    pub optimizer_ranking: Vec<OptimizerType>,
    /// Domain-specific hyperparameters, keyed by name
    pub domain_params: HashMap<String, A>,
}
89
/// Domain-specific optimizer selector
///
/// Combines a `DomainStrategy` with its default `DomainConfig`, accumulated
/// per-domain performance history, and cross-domain transfer observations to
/// produce optimization configurations via `select_optimal_config`.
#[derive(Debug)]
pub struct DomainSpecificSelector<A: Float> {
    /// Current domain strategy
    strategy: DomainStrategy,
    /// Domain configuration (defaults derived from the strategy at construction)
    config: DomainConfig<A>,
    /// Performance history per domain, keyed by domain name
    domain_performance: HashMap<String, Vec<DomainPerformanceMetrics<A>>>,
    /// Cross-domain transfer learning data
    transfer_knowledge: Vec<CrossDomainKnowledge<A>>,
    /// Current optimization context (`None` until `setcontext` is called)
    // NOTE(review): field should be `current_context` per snake_case convention;
    // left unchanged to avoid touching every accessor in this file at once.
    currentcontext: Option<OptimizationContext<A>>,
}
104
/// Performance metrics specific to domains
///
/// One record per completed training run; accumulated per domain by
/// `DomainSpecificSelector::update_domain_performance` and averaged when
/// building recommendations.
#[derive(Debug, Clone)]
pub struct DomainPerformanceMetrics<A: Float> {
    /// Standard performance metrics
    pub validation_accuracy: A,
    /// Domain-specific metrics
    pub domain_specific_score: A,
    /// Training stability
    pub stability_score: A,
    /// Convergence speed (epochs to target)
    pub convergence_epochs: usize,
    /// Resource efficiency
    pub resource_efficiency: A,
    /// Transfer learning potential
    pub transfer_score: A,
}
121
/// Cross-domain knowledge transfer
///
/// Records that hyperparameters/strategy from `source_domain` were applied to
/// `target_domain`; used to surface transfer-learning recommendations.
#[derive(Debug, Clone)]
pub struct CrossDomainKnowledge<A: Float> {
    /// Source domain
    pub source_domain: String,
    /// Target domain
    pub target_domain: String,
    /// Transferable hyperparameters, keyed by name
    pub transferable_params: HashMap<String, A>,
    /// Transfer effectiveness score (doubles as recommendation confidence)
    pub transfer_score: A,
    /// Optimization strategy that worked
    pub successful_strategy: OptimizerType,
}
136
/// Current optimization context
///
/// Bundles everything the domain optimizers need to know about the task at
/// hand; set once via `DomainSpecificSelector::setcontext`.
#[derive(Debug, Clone)]
pub struct OptimizationContext<A: Float> {
    /// Problem characteristics (the CV/NLP paths reinterpret `input_dim` /
    /// `output_dim` as pixel count, sequence length, or vocabulary size)
    pub problem_chars: ProblemCharacteristics,
    /// Resource constraints
    pub resource_constraints: ResourceConstraints<A>,
    /// Training characteristics
    pub training_config: TrainingConfiguration<A>,
    /// Domain-specific metadata
    pub domain_metadata: HashMap<String, String>,
}
149
/// Resource constraints for optimization
///
/// The type parameter `A` is only used for `max_time`; all other fields are
/// plain flags or byte counts.
#[derive(Debug, Clone)]
pub struct ResourceConstraints<A: Float> {
    /// Maximum memory available (bytes); drives batch-size selection
    pub max_memory: usize,
    /// Maximum training time (seconds)
    pub max_time: A,
    /// GPU availability and type
    pub gpu_available: bool,
    /// Distributed training capability
    pub distributed_capable: bool,
    /// Energy efficiency requirements
    pub energy_efficient: bool,
}
164
/// Training configuration parameters
#[derive(Debug, Clone)]
pub struct TrainingConfiguration<A: Float> {
    /// Maximum number of epochs (also used as `t_max` for cosine annealing)
    pub max_epochs: usize,
    /// Early stopping patience
    pub early_stopping_patience: usize,
    /// Validation frequency
    pub validation_frequency: usize,
    /// Learning rate scheduling
    pub lr_schedule_type: LearningRateScheduleType,
    /// Regularization approach
    pub regularization_approach: RegularizationApproach<A>,
}
179
/// Learning rate schedule types
///
/// Schedule hyperparameters are plain `f64` (not the generic `A`) so the
/// variants stay parameter-free over the float type.
#[derive(Debug, Clone)]
pub enum LearningRateScheduleType {
    /// Constant learning rate
    Constant,
    /// Exponential decay
    ExponentialDecay {
        /// Decay rate
        decay_rate: f64,
    },
    /// Cosine annealing
    CosineAnnealing {
        /// Maximum number of iterations
        t_max: usize,
    },
    /// Reduce on plateau
    ReduceOnPlateau {
        /// Number of epochs with no improvement before reducing
        patience: usize,
        /// Factor by which learning rate will be reduced
        factor: f64,
    },
    /// One cycle policy
    OneCycle {
        /// Maximum learning rate
        max_lr: f64,
    },
}
208
/// Regularization approach
#[derive(Debug, Clone)]
pub enum RegularizationApproach<A: Float> {
    /// L2 regularization only
    L2Only {
        /// Regularization weight
        weight: A,
    },
    /// L1 regularization only
    L1Only {
        /// Regularization weight
        weight: A,
    },
    /// Elastic net (L1 + L2)
    ElasticNet {
        /// L1 regularization weight
        l1_weight: A,
        /// L2 regularization weight
        l2_weight: A,
    },
    /// Dropout regularization
    Dropout {
        /// Dropout rate
        dropout_rate: A,
    },
    /// Combined approach
    Combined {
        /// L2 regularization weight
        l2_weight: A,
        /// Dropout rate
        dropout_rate: A,
        /// Additional regularization techniques, by name
        additional_techniques: Vec<String>,
    },
}
244
245impl<A: Float + ScalarOperand + Debug + std::iter::Sum + Send + Sync> DomainSpecificSelector<A> {
246    /// Create a new domain-specific selector
247    pub fn new(strategy: DomainStrategy) -> Self {
248        let config = Self::default_config_for_strategy(&strategy);
249
250        Self {
251            strategy,
252            config,
253            domain_performance: HashMap::new(),
254            transfer_knowledge: Vec::new(),
255            currentcontext: None,
256        }
257    }
258
    /// Set optimization context
    ///
    /// Stores the context that later calls to `select_optimal_config` read;
    /// replaces any previously set context.
    pub fn setcontext(&mut self, context: OptimizationContext<A>) {
        self.currentcontext = Some(context);
    }
263
264    /// Select optimal configuration for the current domain and context
265    pub fn select_optimal_config(&mut self) -> Result<DomainOptimizationConfig<A>> {
266        let context = self
267            .currentcontext
268            .as_ref()
269            .ok_or_else(|| OptimError::InvalidConfig("No optimization context set".to_string()))?;
270
271        match &self.strategy {
272            DomainStrategy::ComputerVision {
273                resolution_adaptive,
274                batch_norm_tuning,
275                augmentation_aware,
276            } => self.optimize_computer_vision(
277                context,
278                *resolution_adaptive,
279                *batch_norm_tuning,
280                *augmentation_aware,
281            ),
282            DomainStrategy::NaturalLanguage {
283                sequence_adaptive,
284                attention_optimized,
285                vocab_aware,
286            } => self.optimize_natural_language(
287                context,
288                *sequence_adaptive,
289                *attention_optimized,
290                *vocab_aware,
291            ),
292            DomainStrategy::RecommendationSystems {
293                collaborative_filtering,
294                matrix_factorization,
295                cold_start_aware,
296            } => self.optimize_recommendation_systems(
297                context,
298                *collaborative_filtering,
299                *matrix_factorization,
300                *cold_start_aware,
301            ),
302            DomainStrategy::TimeSeries {
303                temporal_aware,
304                seasonality_adaptive,
305                multi_step,
306            } => self.optimize_time_series(
307                context,
308                *temporal_aware,
309                *seasonality_adaptive,
310                *multi_step,
311            ),
312            DomainStrategy::ReinforcementLearning {
313                policy_gradient,
314                value_function,
315                exploration_aware,
316            } => self.optimize_reinforcement_learning(
317                context,
318                *policy_gradient,
319                *value_function,
320                *exploration_aware,
321            ),
322            DomainStrategy::ScientificComputing {
323                stability_focused,
324                precision_critical,
325                sparse_optimized,
326            } => self.optimize_scientific_computing(
327                context,
328                *stability_focused,
329                *precision_critical,
330                *sparse_optimized,
331            ),
332        }
333    }
334
335    /// Optimize for computer vision tasks
336    fn optimize_computer_vision(
337        &self,
338        context: &OptimizationContext<A>,
339        resolution_adaptive: bool,
340        batch_norm_tuning: bool,
341        augmentation_aware: bool,
342    ) -> Result<DomainOptimizationConfig<A>> {
343        let mut config = DomainOptimizationConfig::default();
344
345        // Resolution-_adaptive optimization
346        if resolution_adaptive {
347            let resolution_factor = self.estimate_resolution_factor(&context.problem_chars);
348            config.learning_rate =
349                self.config.base_learning_rate * A::from(resolution_factor).unwrap();
350
351            // Larger images need smaller learning rates
352            if context.problem_chars.input_dim > 512 * 512 {
353                config.learning_rate = config.learning_rate * A::from(0.5).unwrap();
354            }
355        }
356
357        // Batch normalization _tuning
358        if batch_norm_tuning {
359            config.optimizer_type = OptimizerType::AdamW; // Better for batch norm
360            config
361                .specialized_params
362                .insert("batch_norm_momentum".to_string(), A::from(0.99).unwrap());
363            config
364                .specialized_params
365                .insert("batch_norm_eps".to_string(), A::from(1e-5).unwrap());
366        }
367
368        // Data augmentation awareness
369        if augmentation_aware {
370            // More aggressive regularization with augmentation
371            config.regularization_strength = config.regularization_strength * A::from(1.5).unwrap();
372            config
373                .specialized_params
374                .insert("mixup_alpha".to_string(), A::from(0.2).unwrap());
375            config
376                .specialized_params
377                .insert("cutmix_alpha".to_string(), A::from(1.0).unwrap());
378        }
379
380        // CV-specific optimizations
381        config.batch_size = self.select_cv_batch_size(&context.resource_constraints);
382        config.gradient_clip_norm = Some(A::from(1.0).unwrap());
383
384        // Use cosine annealing for CV tasks
385        config.lr_schedule = LearningRateScheduleType::CosineAnnealing {
386            t_max: context.training_config.max_epochs,
387        };
388
389        Ok(config)
390    }
391
392    /// Optimize for natural language processing tasks
393    fn optimize_natural_language(
394        &self,
395        context: &OptimizationContext<A>,
396        sequence_adaptive: bool,
397        attention_optimized: bool,
398        vocab_aware: bool,
399    ) -> Result<DomainOptimizationConfig<A>> {
400        let mut config = DomainOptimizationConfig::default();
401
402        // Sequence-_adaptive optimization
403        if sequence_adaptive {
404            let seq_length = context.problem_chars.input_dim; // Assuming input_dim represents sequence length
405
406            // Longer sequences need more careful optimization
407            if seq_length > 512 {
408                config.learning_rate = self.config.base_learning_rate * A::from(0.7).unwrap();
409                config.gradient_clip_norm = Some(A::from(0.5).unwrap());
410            } else {
411                config.learning_rate = self.config.base_learning_rate;
412                config.gradient_clip_norm = Some(A::from(1.0).unwrap());
413            }
414        }
415
416        // Attention mechanism optimization
417        if attention_optimized {
418            config.optimizer_type = OptimizerType::AdamW; // Best for transformers
419            config
420                .specialized_params
421                .insert("attention_dropout".to_string(), A::from(0.1).unwrap());
422            config
423                .specialized_params
424                .insert("attention_head_dim".to_string(), A::from(64.0).unwrap());
425
426            // Layer-wise learning rate decay for transformers
427            config
428                .specialized_params
429                .insert("layer_decay_rate".to_string(), A::from(0.95).unwrap());
430        }
431
432        // Vocabulary-_aware optimization
433        if vocab_aware {
434            let vocab_size = context.problem_chars.output_dim; // Assuming output_dim represents vocab size
435
436            // Large vocabularies need special handling
437            if vocab_size > 30000 {
438                config
439                    .specialized_params
440                    .insert("tie_embeddings".to_string(), A::from(1.0).unwrap());
441                config
442                    .specialized_params
443                    .insert("embedding_dropout".to_string(), A::from(0.1).unwrap());
444            }
445        }
446
447        // NLP-specific optimizations
448        config.batch_size = self.select_nlp_batch_size(&context.resource_constraints);
449        config.lr_schedule = LearningRateScheduleType::OneCycle {
450            max_lr: config.learning_rate.to_f64().unwrap(),
451        };
452
453        // Warmup for transformers
454        config
455            .specialized_params
456            .insert("warmup_steps".to_string(), A::from(1000.0).unwrap());
457
458        Ok(config)
459    }
460
461    /// Optimize for recommendation systems
462    fn optimize_recommendation_systems(
463        &self,
464        context: &OptimizationContext<A>,
465        collaborative_filtering: bool,
466        matrix_factorization: bool,
467        cold_start_aware: bool,
468    ) -> Result<DomainOptimizationConfig<A>> {
469        let mut config = DomainOptimizationConfig::default();
470
471        // Collaborative _filtering optimization
472        if collaborative_filtering {
473            config.optimizer_type = OptimizerType::Adam; // Good for sparse data
474            config.regularization_strength = A::from(0.01).unwrap(); // Prevent overfitting
475            config
476                .specialized_params
477                .insert("negative_sampling_rate".to_string(), A::from(5.0).unwrap());
478        }
479
480        // Matrix _factorization tuning
481        if matrix_factorization {
482            config.learning_rate = A::from(0.01).unwrap(); // Lower LR for stability
483            config
484                .specialized_params
485                .insert("embedding_dim".to_string(), A::from(128.0).unwrap());
486            config
487                .specialized_params
488                .insert("factorization_rank".to_string(), A::from(50.0).unwrap());
489        }
490
491        // Cold start handling
492        if cold_start_aware {
493            config
494                .specialized_params
495                .insert("content_weight".to_string(), A::from(0.3).unwrap());
496            config
497                .specialized_params
498                .insert("popularity_bias".to_string(), A::from(0.1).unwrap());
499        }
500
501        // RecSys-specific optimizations
502        config.batch_size = self.select_recsys_batch_size(&context.resource_constraints);
503        config.gradient_clip_norm = Some(A::from(5.0).unwrap()); // Higher clip for sparse gradients
504
505        Ok(config)
506    }
507
508    /// Optimize for time series tasks
509    fn optimize_time_series(
510        &self,
511        context: &OptimizationContext<A>,
512        temporal_aware: bool,
513        seasonality_adaptive: bool,
514        multi_step: bool,
515    ) -> Result<DomainOptimizationConfig<A>> {
516        let mut config = DomainOptimizationConfig::default();
517
518        // Temporal dependency handling
519        if temporal_aware {
520            config.optimizer_type = OptimizerType::RMSprop; // Good for RNNs
521            config.learning_rate = A::from(0.001).unwrap(); // Conservative for temporal stability
522            config.specialized_params.insert(
523                "sequence_length".to_string(),
524                A::from(context.problem_chars.input_dim as f64).unwrap(),
525            );
526        }
527
528        // Seasonality consideration
529        if seasonality_adaptive {
530            config
531                .specialized_params
532                .insert("seasonal_periods".to_string(), A::from(24.0).unwrap()); // Daily pattern
533            config
534                .specialized_params
535                .insert("trend_strength".to_string(), A::from(0.1).unwrap());
536        }
537
538        // Multi-_step ahead optimization
539        if multi_step {
540            config
541                .specialized_params
542                .insert("prediction_horizon".to_string(), A::from(12.0).unwrap());
543            config
544                .specialized_params
545                .insert("multi_step_loss_weight".to_string(), A::from(0.8).unwrap());
546        }
547
548        // Time series-specific optimizations
549        config.batch_size = 32; // Smaller batches for temporal consistency
550        config.gradient_clip_norm = Some(A::from(1.0).unwrap());
551        config.lr_schedule = LearningRateScheduleType::ReduceOnPlateau {
552            patience: 10,
553            factor: 0.5,
554        };
555
556        Ok(config)
557    }
558
559    /// Optimize for reinforcement learning tasks
560    fn optimize_reinforcement_learning(
561        &self,
562        context: &OptimizationContext<A>,
563        policy_gradient: bool,
564        value_function: bool,
565        exploration_aware: bool,
566    ) -> Result<DomainOptimizationConfig<A>> {
567        let mut config = DomainOptimizationConfig::default();
568
569        // Policy _gradient optimization
570        if policy_gradient {
571            config.optimizer_type = OptimizerType::Adam;
572            config.learning_rate = A::from(3e-4).unwrap(); // Standard RL learning rate
573            config
574                .specialized_params
575                .insert("entropy_coeff".to_string(), A::from(0.01).unwrap());
576        }
577
578        // Value _function optimization
579        if value_function {
580            config
581                .specialized_params
582                .insert("value_loss_coeff".to_string(), A::from(0.5).unwrap());
583            config
584                .specialized_params
585                .insert("huber_loss_delta".to_string(), A::from(1.0).unwrap());
586        }
587
588        // Exploration-exploitation balance
589        if exploration_aware {
590            config
591                .specialized_params
592                .insert("epsilon_start".to_string(), A::from(1.0).unwrap());
593            config
594                .specialized_params
595                .insert("epsilon_end".to_string(), A::from(0.1).unwrap());
596            config
597                .specialized_params
598                .insert("epsilon_decay".to_string(), A::from(0.995).unwrap());
599        }
600
601        // RL-specific optimizations
602        config.batch_size = 64; // Standard RL batch size
603        config.gradient_clip_norm = Some(A::from(0.5).unwrap()); // Important for RL stability
604        config.lr_schedule = LearningRateScheduleType::Constant; // Often constant in RL
605
606        Ok(config)
607    }
608
609    /// Optimize for scientific computing tasks
610    fn optimize_scientific_computing(
611        &self,
612        context: &OptimizationContext<A>,
613        stability_focused: bool,
614        precision_critical: bool,
615        sparse_optimized: bool,
616    ) -> Result<DomainOptimizationConfig<A>> {
617        let mut config = DomainOptimizationConfig::default();
618
619        // Numerical stability prioritization
620        if stability_focused {
621            config.optimizer_type = OptimizerType::LBFGS; // More stable for scientific problems
622            config.learning_rate = A::from(0.1).unwrap(); // Higher LR for LBFGS
623            config
624                .specialized_params
625                .insert("line_search_tolerance".to_string(), A::from(1e-6).unwrap());
626        }
627
628        // High precision requirements
629        if precision_critical {
630            config
631                .specialized_params
632                .insert("convergence_tolerance".to_string(), A::from(1e-8).unwrap());
633            config
634                .specialized_params
635                .insert("max_iterations".to_string(), A::from(1000.0).unwrap());
636        }
637
638        // Sparse matrix optimization
639        if sparse_optimized {
640            config.optimizer_type = OptimizerType::Adam;
641            config
642                .specialized_params
643                .insert("sparsity_threshold".to_string(), A::from(1e-6).unwrap());
644        }
645
646        // Scientific computing-specific optimizations
647        config.batch_size = context.problem_chars.dataset_size.min(1024); // Can use larger batches
648        config.gradient_clip_norm = None; // Don't clip for scientific precision
649        config.lr_schedule = LearningRateScheduleType::Constant; // Consistent optimization
650
651        Ok(config)
652    }
653
654    /// Update performance based on training results
655    pub fn update_domain_performance(
656        &mut self,
657        domain: String,
658        metrics: DomainPerformanceMetrics<A>,
659    ) {
660        self.domain_performance
661            .entry(domain)
662            .or_default()
663            .push(metrics);
664    }
665
    /// Record cross-domain transfer knowledge
    ///
    /// Stores the observation for later use by `get_domain_recommendations`.
    pub fn record_transfer_knowledge(&mut self, knowledge: CrossDomainKnowledge<A>) {
        self.transfer_knowledge.push(knowledge);
    }
670
671    /// Get domain-specific recommendations
672    pub fn get_domain_recommendations(&self, domain: &str) -> Vec<DomainRecommendation<A>> {
673        let mut recommendations = Vec::new();
674
675        // Analyze historical performance for this domain
676        if let Some(history) = self.domain_performance.get(domain) {
677            if !history.is_empty() {
678                let avg_performance = history.iter().map(|m| m.validation_accuracy).sum::<A>()
679                    / A::from(history.len()).unwrap();
680
681                recommendations.push(DomainRecommendation {
682                    recommendation_type: RecommendationType::PerformanceBaseline,
683                    description: format!(
684                        "Historical average performance: {:.4}",
685                        avg_performance.to_f64().unwrap()
686                    ),
687                    confidence: A::from(0.8).unwrap(),
688                    action: "Consider this as baseline for improvements".to_string(),
689                });
690            }
691        }
692
693        // Cross-domain transfer recommendations
694        for knowledge in &self.transfer_knowledge {
695            if knowledge.target_domain == domain {
696                recommendations.push(DomainRecommendation {
697                    recommendation_type: RecommendationType::TransferLearning,
698                    description: format!(
699                        "Transfer from {} domain with {:.2} effectiveness",
700                        knowledge.source_domain,
701                        knowledge.transfer_score.to_f64().unwrap()
702                    ),
703                    confidence: knowledge.transfer_score,
704                    action: format!("Use {:?} optimizer", knowledge.successful_strategy),
705                });
706            }
707        }
708
709        recommendations
710    }
711
712    /// Helper methods for domain-specific optimizations
713    fn estimate_resolution_factor(&self, problem_chars: &ProblemCharacteristics) -> f64 {
714        let resolution = problem_chars.input_dim as f64;
715
716        if resolution > 1_000_000.0 {
717            // Very high resolution
718            0.5
719        } else if resolution > 250_000.0 {
720            // High resolution
721            0.7
722        } else if resolution > 50_000.0 {
723            // Medium resolution
724            0.9
725        } else {
726            // Low resolution
727            1.0
728        }
729    }
730
731    fn select_cv_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
732        if constraints.max_memory > 16_000_000_000 {
733            // 16GB+
734            128
735        } else if constraints.max_memory > 8_000_000_000 {
736            // 8GB+
737            64
738        } else {
739            32
740        }
741    }
742
743    fn select_nlp_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
744        if constraints.max_memory > 32_000_000_000 {
745            // 32GB+
746            64
747        } else if constraints.max_memory > 16_000_000_000 {
748            // 16GB+
749            32
750        } else {
751            16
752        }
753    }
754
755    fn select_recsys_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
756        // RecSys can typically use larger batches due to simpler models
757        if constraints.max_memory > 8_000_000_000 {
758            512
759        } else {
760            256
761        }
762    }
763
    /// Create default configuration for a strategy
    ///
    /// Returns the per-domain starting hyperparameters; these seed `config`
    /// at construction and are refined later by the `optimize_*` methods
    /// using the runtime context.
    fn default_config_for_strategy(strategy: &DomainStrategy) -> DomainConfig<A> {
        match strategy {
            // CV: moderate LR, AdamW ranked first.
            DomainStrategy::ComputerVision { .. } => DomainConfig {
                base_learning_rate: A::from(0.001).unwrap(),
                recommended_batch_sizes: vec![32, 64, 128],
                gradient_clip_values: vec![A::from(1.0).unwrap(), A::from(2.0).unwrap()],
                regularization_range: (A::from(1e-5).unwrap(), A::from(1e-2).unwrap()),
                optimizer_ranking: vec![
                    OptimizerType::AdamW,
                    OptimizerType::SGDMomentum,
                    OptimizerType::Adam,
                ],
                domain_params: HashMap::new(),
            },
            // NLP: small transformer-style LR, tight clipping.
            DomainStrategy::NaturalLanguage { .. } => DomainConfig {
                base_learning_rate: A::from(2e-5).unwrap(),
                recommended_batch_sizes: vec![16, 32, 64],
                gradient_clip_values: vec![A::from(0.5).unwrap(), A::from(1.0).unwrap()],
                regularization_range: (A::from(1e-4).unwrap(), A::from(1e-1).unwrap()),
                optimizer_ranking: vec![OptimizerType::AdamW, OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
            // RecSys: larger LR/batches, loose clipping for sparse gradients.
            DomainStrategy::RecommendationSystems { .. } => DomainConfig {
                base_learning_rate: A::from(0.01).unwrap(),
                recommended_batch_sizes: vec![128, 256, 512],
                gradient_clip_values: vec![A::from(5.0).unwrap(), A::from(10.0).unwrap()],
                regularization_range: (A::from(1e-3).unwrap(), A::from(1e-1).unwrap()),
                optimizer_ranking: vec![OptimizerType::Adam, OptimizerType::AdaGrad],
                domain_params: HashMap::new(),
            },
            // Time series: conservative LR, RMSprop ranked first.
            DomainStrategy::TimeSeries { .. } => DomainConfig {
                base_learning_rate: A::from(0.001).unwrap(),
                recommended_batch_sizes: vec![16, 32, 64],
                gradient_clip_values: vec![A::from(1.0).unwrap()],
                regularization_range: (A::from(1e-4).unwrap(), A::from(1e-2).unwrap()),
                optimizer_ranking: vec![OptimizerType::RMSprop, OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
            // RL: standard 3e-4 LR, tight clipping.
            DomainStrategy::ReinforcementLearning { .. } => DomainConfig {
                base_learning_rate: A::from(3e-4).unwrap(),
                recommended_batch_sizes: vec![32, 64, 128],
                gradient_clip_values: vec![A::from(0.5).unwrap()],
                regularization_range: (A::from(1e-4).unwrap(), A::from(1e-2).unwrap()),
                optimizer_ranking: vec![OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
            // Scientific: large LBFGS-friendly LR, no clipping (empty list).
            DomainStrategy::ScientificComputing { .. } => DomainConfig {
                base_learning_rate: A::from(0.1).unwrap(),
                recommended_batch_sizes: vec![64, 128, 256, 512],
                gradient_clip_values: vec![],
                regularization_range: (A::from(1e-6).unwrap(), A::from(1e-3).unwrap()),
                optimizer_ranking: vec![OptimizerType::LBFGS, OptimizerType::Adam],
                domain_params: HashMap::new(),
            },
        }
    }
821}
822
/// Final domain optimization configuration
///
/// The output of `DomainSpecificSelector::select_optimal_config`: a concrete,
/// ready-to-apply set of optimizer choices for one training run.
#[derive(Debug, Clone)]
pub struct DomainOptimizationConfig<A: Float> {
    /// Selected optimizer type
    pub optimizer_type: OptimizerType,
    /// Optimized learning rate
    pub learning_rate: A,
    /// Optimized batch size
    pub batch_size: usize,
    /// Gradient clipping norm (`None` disables clipping)
    pub gradient_clip_norm: Option<A>,
    /// Regularization strength
    pub regularization_strength: A,
    /// Learning rate schedule
    pub lr_schedule: LearningRateScheduleType,
    /// Domain-specific specialized parameters, keyed by name
    pub specialized_params: HashMap<String, A>,
}
841
842impl<A: Float + Send + Sync> Default for DomainOptimizationConfig<A> {
843    fn default() -> Self {
844        Self {
845            optimizer_type: OptimizerType::Adam,
846            learning_rate: A::from(0.001).unwrap(),
847            batch_size: 32,
848            gradient_clip_norm: Some(A::from(1.0).unwrap()),
849            regularization_strength: A::from(1e-4).unwrap(),
850            lr_schedule: LearningRateScheduleType::Constant,
851            specialized_params: HashMap::new(),
852        }
853    }
854}
855
/// Domain-specific recommendation
///
/// A human-readable suggestion returned by the selector's
/// `get_domain_recommendations` (e.g. after performance metrics are recorded
/// or cross-domain transfer knowledge is registered).
#[derive(Debug, Clone)]
pub struct DomainRecommendation<A: Float> {
    /// Category of the recommendation (see [`RecommendationType`])
    pub recommendation_type: RecommendationType,
    /// Human-readable description of the finding
    pub description: String,
    /// Confidence in the recommendation, expected in the range 0.0-1.0
    pub confidence: A,
    /// Suggested action the caller can take in response
    pub action: String,
}
868
/// Types of domain recommendations
///
/// Categorizes a [`DomainRecommendation`] so callers can filter or `match`
/// on the kind of suggestion.
#[derive(Debug, Clone)]
pub enum RecommendationType {
    /// Performance baseline information
    PerformanceBaseline,
    /// Transfer learning suggestion (reuse knowledge from another domain)
    TransferLearning,
    /// Hyperparameter adjustment
    HyperparameterTuning,
    /// Architecture modification
    ArchitectureChange,
    /// Resource optimization
    ResourceOptimization,
}
883
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adaptive_selection::ProblemType;

    // Constructing a selector from a strategy seeds the domain's base
    // config; the computer-vision config ranks AdamW first.
    #[test]
    fn test_domain_specific_selector_creation() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let selector = DomainSpecificSelector::<f64>::new(strategy);
        assert_eq!(selector.config.optimizer_ranking[0], OptimizerType::AdamW);
    }

    // End-to-end selection for a CV workload: an ImageNet-scale problem
    // with >16GB of memory should pick AdamW, scale the batch size up to
    // 128, and keep gradient clipping enabled.
    #[test]
    fn test_computer_vision_optimization() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 50000,
                input_dim: 224 * 224 * 3, // Standard ImageNet resolution
                output_dim: 1000,
                problem_type: ProblemType::ComputerVision,
                gradient_sparsity: 0.1,
                gradient_noise: 0.05,
                memory_budget: 8_000_000_000,
                time_budget: 3600.0,
                batch_size: 64,
                lr_sensitivity: 0.5,
                regularization_strength: 0.01,
                architecture_type: Some("ResNet".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 17_000_000_000, // Slightly above 16GB to trigger 128 batch size
                max_time: 7200.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 100,
                early_stopping_patience: 10,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::CosineAnnealing { t_max: 100 },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert_eq!(config.batch_size, 128); // Should select larger batch size for high memory
        assert!(config.gradient_clip_norm.is_some());
    }

    // NLP workload with a transformer and a large vocabulary: expects AdamW
    // plus the domain's specialized parameters ("warmup_steps", and
    // "tie_embeddings" triggered by the large output vocabulary).
    #[test]
    fn test_natural_language_optimization() {
        let strategy = DomainStrategy::NaturalLanguage {
            sequence_adaptive: true,
            attention_optimized: true,
            vocab_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 100000,
                input_dim: 512,    // Sequence length
                output_dim: 50000, // Large vocabulary
                problem_type: ProblemType::NaturalLanguage,
                gradient_sparsity: 0.2,
                gradient_noise: 0.1,
                memory_budget: 32_000_000_000,
                time_budget: 7200.0,
                batch_size: 32,
                lr_sensitivity: 0.8,
                regularization_strength: 0.1,
                architecture_type: Some("Transformer".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 32_000_000_000,
                max_time: 10800.0,
                gpu_available: true,
                distributed_capable: true,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 50,
                early_stopping_patience: 5,
                validation_frequency: 1,
                lr_schedule_type: LearningRateScheduleType::OneCycle { max_lr: 2e-5 },
                regularization_approach: RegularizationApproach::Dropout { dropout_rate: 0.1 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::AdamW);
        assert!(config.specialized_params.contains_key("warmup_steps"));
        assert!(config.specialized_params.contains_key("tie_embeddings")); // Large vocab
    }

    // Time-series forecasting: the TimeSeries base config ranks RMSprop
    // first, and seasonality/multi-step settings should surface the
    // "seasonal_periods" and "prediction_horizon" parameters.
    #[test]
    fn test_time_series_optimization() {
        let strategy = DomainStrategy::TimeSeries {
            temporal_aware: true,
            seasonality_adaptive: true,
            multi_step: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 10000,
                input_dim: 168, // One week of hourly data
                output_dim: 24, // Next 24 hours
                problem_type: ProblemType::TimeSeries,
                gradient_sparsity: 0.05,
                gradient_noise: 0.2,
                memory_budget: 4_000_000_000,
                time_budget: 1800.0,
                batch_size: 32,
                lr_sensitivity: 0.7,
                regularization_strength: 0.01,
                architecture_type: Some("LSTM".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 8_000_000_000,
                max_time: 3600.0,
                gpu_available: true,
                distributed_capable: false,
                energy_efficient: true,
            },
            training_config: TrainingConfiguration {
                max_epochs: 200,
                early_stopping_patience: 20,
                validation_frequency: 5,
                lr_schedule_type: LearningRateScheduleType::ReduceOnPlateau {
                    patience: 10,
                    factor: 0.5,
                },
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::RMSprop);
        assert_eq!(config.batch_size, 32);
        assert!(config.specialized_params.contains_key("seasonal_periods"));
        assert!(config.specialized_params.contains_key("prediction_horizon"));
    }

    // Recording performance metrics for a domain should surface at least
    // one recommendation whose description echoes the recorded accuracy.
    #[test]
    fn test_performance_tracking() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: false,
            augmentation_aware: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let metrics = DomainPerformanceMetrics {
            validation_accuracy: 0.95,
            domain_specific_score: 0.92,
            stability_score: 0.88,
            convergence_epochs: 50,
            resource_efficiency: 0.85,
            transfer_score: 0.7,
        };

        selector.update_domain_performance("computer_vision".to_string(), metrics);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(!recommendations.is_empty());
        // The baseline recommendation should mention the 0.95 accuracy.
        assert!(recommendations[0].description.contains("0.95"));
    }

    // Registering cross-domain knowledge targeting this domain should
    // produce a TransferLearning recommendation for it.
    #[test]
    fn test_cross_domain_transfer() {
        let strategy = DomainStrategy::ComputerVision {
            resolution_adaptive: true,
            batch_norm_tuning: true,
            augmentation_aware: true,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let transfer_knowledge = CrossDomainKnowledge {
            source_domain: "natural_language".to_string(),
            target_domain: "computer_vision".to_string(),
            transferable_params: HashMap::from([
                ("learning_rate".to_string(), 0.001),
                ("weight_decay".to_string(), 0.01),
            ]),
            transfer_score: 0.8,
            successful_strategy: OptimizerType::AdamW,
        };

        selector.record_transfer_knowledge(transfer_knowledge);

        let recommendations = selector.get_domain_recommendations("computer_vision");
        assert!(recommendations
            .iter()
            .any(|r| matches!(r.recommendation_type, RecommendationType::TransferLearning)));
    }

    // Precision-critical scientific computing: the base config prefers
    // LBFGS with no gradient clip values, so clipping stays disabled and a
    // "convergence_tolerance" parameter is exposed.
    #[test]
    fn test_scientific_computing_optimization() {
        let strategy = DomainStrategy::ScientificComputing {
            stability_focused: true,
            precision_critical: true,
            sparse_optimized: false,
        };

        let mut selector = DomainSpecificSelector::<f64>::new(strategy);

        let context = OptimizationContext {
            problem_chars: ProblemCharacteristics {
                dataset_size: 1000,
                input_dim: 100,
                output_dim: 1,
                problem_type: ProblemType::Regression,
                gradient_sparsity: 0.01,
                gradient_noise: 0.01,
                memory_budget: 16_000_000_000,
                time_budget: 7200.0,
                batch_size: 100,
                lr_sensitivity: 0.3,
                regularization_strength: 1e-6,
                architecture_type: Some("MLP".to_string()),
            },
            resource_constraints: ResourceConstraints {
                max_memory: 16_000_000_000,
                max_time: 7200.0,
                gpu_available: false,
                distributed_capable: false,
                energy_efficient: false,
            },
            training_config: TrainingConfiguration {
                max_epochs: 1000,
                early_stopping_patience: 100,
                validation_frequency: 10,
                lr_schedule_type: LearningRateScheduleType::Constant,
                regularization_approach: RegularizationApproach::L2Only { weight: 1e-6 },
            },
            domain_metadata: HashMap::new(),
        };

        selector.setcontext(context);
        let config = selector.select_optimal_config().unwrap();

        assert_eq!(config.optimizer_type, OptimizerType::LBFGS);
        assert!(config.gradient_clip_norm.is_none()); // No clipping for precision
        assert!(config
            .specialized_params
            .contains_key("convergence_tolerance"));
    }
}