optirs_nas/nas_engine/
config.rs

1// Neural Architecture Search Configuration
2//
3// This module contains all configuration types, enums, and parameter definitions
4// for the Neural Architecture Search system.
5
6use crate::EvaluationMetric;
7use scirs2_core::numeric::Float;
8use std::collections::HashMap;
9use std::fmt::Debug;
10use std::time::Duration;
11
12/// Neural Architecture Search configuration for optimizers
13#[derive(Debug, Clone)]
14pub struct NASConfig<T: Float + Debug + Send + Sync + 'static> {
15    /// Search strategy to use
16    pub search_strategy: SearchStrategyType,
17
18    /// Architecture search space
19    pub search_space: SearchSpaceConfig,
20
21    /// Performance evaluation configuration
22    pub evaluation_config: EvaluationConfig<T>,
23
24    /// Multi-objective optimization settings
25    pub multi_objective_config: MultiObjectiveConfig<T>,
26
27    /// Search budget (number of architectures to evaluate)
28    pub search_budget: usize,
29
30    /// Early stopping criteria
31    pub early_stopping: EarlyStoppingConfig<T>,
32
33    /// Enable progressive search
34    pub progressive_search: bool,
35
36    /// Population size for evolutionary/genetic algorithms
37    pub population_size: usize,
38
39    /// Enable architecture transfer learning
40    pub enable_transfer_learning: bool,
41
42    /// Architecture encoding strategy
43    pub encoding_strategy: ArchitectureEncodingStrategy,
44
45    /// Enable performance prediction
46    pub enable_performance_prediction: bool,
47
48    /// Search parallelization factor
49    pub parallelization_factor: usize,
50
51    /// Enable automated hyperparameter tuning
52    pub auto_hyperparameter_tuning: bool,
53
54    /// Resource constraints
55    pub resource_constraints: ResourceConstraints<T>,
56}
57
/// Algorithm used to explore the architecture search space.
#[derive(Clone, Debug)]
pub enum SearchStrategyType {
    /// Uniform random sampling; the baseline strategy.
    Random,
    /// Evolutionary / genetic-algorithm search.
    Evolutionary,
    /// Search driven by a reinforcement-learning controller.
    ReinforcementLearning,
    /// Differentiable architecture search (DARTS).
    Differentiable,
    /// Bayesian optimization over architectures.
    BayesianOptimization,
    /// Progressive search.
    Progressive,
    /// Multi-objective evolutionary algorithm.
    MultiObjectiveEvolutionary,
    /// Search guided by a neural performance predictor.
    NeuralPredictorBased,
}
85
86/// Search space configuration
87#[derive(Debug, Clone)]
88pub struct SearchSpaceConfig {
89    /// Available optimizer components
90    pub components: Vec<OptimizerComponentConfig>,
91
92    /// Connection patterns between components
93    pub connection_patterns: Vec<ConnectionPatternType>,
94
95    /// Learning rate schedule search space
96    pub learning_rate_schedules: LearningRateScheduleSpace,
97
98    /// Regularization technique search space
99    pub regularization_techniques: RegularizationSpace,
100
101    /// Adaptive mechanism search space
102    pub adaptive_mechanisms: AdaptiveMechanismSpace,
103
104    /// Memory usage constraints
105    pub memory_constraints: MemoryConstraints,
106
107    /// Computation constraints
108    pub computation_constraints: ComputationConstraints,
109
110    /// Component types to include in search
111    pub component_types: Vec<ComponentType>,
112
113    /// Maximum number of components
114    pub max_components: usize,
115
116    /// Minimum number of components
117    pub min_components: usize,
118
119    /// Maximum number of connections
120    pub max_connections: usize,
121}
122
123/// Optimizer component configuration for search
124#[derive(Debug, Clone)]
125pub struct OptimizerComponentConfig {
126    /// Component type
127    pub component_type: ComponentType,
128
129    /// Hyperparameter search ranges
130    pub hyperparameter_ranges: HashMap<String, ParameterRange>,
131
132    /// Component complexity score
133    pub complexity_score: f64,
134
135    /// Memory requirement estimate
136    pub memory_requirement: usize,
137
138    /// Computational cost estimate
139    pub computational_cost: f64,
140
141    /// Compatibility constraints
142    pub compatibility_constraints: Vec<CompatibilityConstraint>,
143}
144
/// Structural constraints attached to a component type.
///
/// The default value has no parameter-count bounds and empty dimension and
/// rule lists.
#[derive(Clone, Debug, Default)]
pub struct ArchitectureConstraints {
    /// Optional lower bound on the parameter count.
    pub min_parameters: Option<usize>,

    /// Optional upper bound on the parameter count.
    pub max_parameters: Option<usize>,

    /// Required input dimensions.
    pub input_dimensions: Vec<usize>,

    /// Constraints on the output dimensions.
    pub output_dimensions: Vec<usize>,

    /// Free-form compatibility rules.
    pub compatibility_rules: Vec<String>,
}
159
160/// Component type configuration
161#[derive(Debug, Clone)]
162pub struct ComponentTypeConfig {
163    /// Component type name
164    pub name: String,
165    /// Available parameters
166    pub parameters: Vec<String>,
167    /// Default configuration
168    pub defaults: HashMap<String, String>,
169    /// Whether this component type is enabled
170    pub enabled: bool,
171    /// Probability of selection during search
172    pub probability: f64,
173    /// Hyperparameter ranges for this component
174    pub hyperparameter_ranges: HashMap<String, ParameterRange>,
175    /// Dependencies on other components
176    pub dependencies: Vec<String>,
177    /// Component type
178    pub component_type: ComponentType,
179    /// Architecture constraints
180    pub constraints: ArchitectureConstraints,
181}
182
/// Kinds of building blocks an optimizer architecture can be assembled from.
///
/// Usable as a hash-map key (`Eq + Hash`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ComponentType {
    // --- Generic component categories ---
    /// Gradient computation method.
    GradientComputation,
    /// Momentum technique.
    Momentum,
    /// Adaptive learning-rate method.
    AdaptiveLearningRate,
    /// Regularization technique.
    Regularization,
    /// Normalization method.
    Normalization,
    /// Second-order method.
    SecondOrder,
    /// Preconditioning technique.
    Preconditioning,
    /// Learning-rate scheduling.
    LearningRateScheduling,
    /// Gradient clipping method.
    GradientClipping,
    /// Memory management.
    MemoryManagement,
    /// Convergence acceleration.
    ConvergenceAcceleration,

    // --- Specific optimizer types ---
    SGD,
    Adam,
    AdamW,
    RMSprop,
    AdaGrad,
    AdaDelta,
    Nesterov,
    LRScheduler,
    BatchNorm,
    Dropout,
    LAMB,
    LARS,
    Lion,
    RAdam,
    Lookahead,
    SAM,
    LBFGS,
    SparseAdam,
    GroupedAdam,
    MAML,
    Reptile,
    MetaSGD,

    // --- Learning-rate schedules ---
    ConstantLR,
    ExponentialLR,
    StepLR,
    CosineAnnealingLR,
    OneCycleLR,
    CyclicLR,

    // --- Regularizers ---
    L1Regularizer,
    L2Regularizer,
    ElasticNetRegularizer,
    DropoutRegularizer,
    WeightDecay,

    // --- Adaptive components ---
    AdaptiveLR,
    AdaptiveMomentum,
    AdaptiveRegularization,

    // --- Learned optimizers ---
    LSTMOptimizer,
    TransformerOptimizer,
    AttentionOptimizer,

    /// User-defined component, identified by name.
    Custom(String),
}
263
/// Domain a single hyperparameter is searched over.
#[derive(Clone, Debug)]
pub enum ParameterRange {
    /// Continuous interval `[min, max]`.
    Continuous(f64, f64),
    /// Explicit set of allowed values.
    Discrete(Vec<f64>),
    /// Integer interval `[min, max]`.
    Integer(i32, i32),
    /// Boolean choice.
    Boolean,
    /// Choice among named categories.
    Categorical(Vec<String>),
    /// Log-uniform distribution between the two bounds.
    LogUniform(f64, f64),
}
285
/// How components within an architecture may be wired together.
#[derive(Clone, Debug)]
pub enum ConnectionPatternType {
    /// Components chained one after another.
    Sequential,
    /// Parallel branches.
    Parallel,
    /// Skip connections.
    SkipConnection,
    /// Dense connections.
    DenseConnection,
    /// Residual connections.
    ResidualConnection,
    /// Attention-based connections.
    AttentionConnection,
    /// Highway connections.
    HighwayConnection,
    /// Squeeze-and-excitation connections.
    SqueezeExcitationConnection,
    /// User-defined pattern, identified by name.
    Custom(String),
}
316
317/// Learning rate schedule search space
318#[derive(Debug, Clone)]
319pub struct LearningRateScheduleSpace {
320    /// Available schedule types
321    pub schedule_types: Vec<ScheduleType>,
322
323    /// Initial learning rate range
324    pub initial_lr_range: ParameterRange,
325
326    /// Schedule-specific parameters
327    pub schedule_parameters: HashMap<ScheduleType, HashMap<String, ParameterRange>>,
328}
329
/// Kinds of learning-rate schedule. Usable as a hash-map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ScheduleType {
    /// Fixed learning rate.
    Constant,
    /// Step decay.
    StepDecay,
    /// Exponential decay.
    ExponentialDecay,
    /// Cosine annealing.
    CosineAnnealing,
    /// Cyclical learning rate.
    CyclicalLR,
    /// One-cycle policy.
    OneCycleLR,
    /// Reduce on plateau.
    ReduceOnPlateau,
    /// Warmup followed by linear schedule.
    WarmupLinear,
    /// Warmup followed by cosine schedule.
    WarmupCosine,
    /// User-defined schedule, identified by name.
    Custom(String),
}
344
345/// Regularization search space
346#[derive(Debug, Clone)]
347pub struct RegularizationSpace {
348    /// Available regularization techniques
349    pub techniques: Vec<RegularizationTechnique>,
350
351    /// Regularization strength ranges
352    pub strength_ranges: HashMap<RegularizationTechnique, ParameterRange>,
353
354    /// Combination strategies
355    pub combination_strategies: Vec<String>,
356}
357
/// Regularization techniques. Usable as a hash-map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum RegularizationTechnique {
    /// L1 penalty.
    L1,
    /// L2 penalty.
    L2,
    /// Dropout.
    Dropout,
    /// DropConnect.
    DropConnect,
    /// Batch normalization.
    BatchNormalization,
    /// Layer normalization.
    LayerNormalization,
    /// Group normalization.
    GroupNormalization,
    /// Spectral normalization.
    SpectralNormalization,
    /// Weight decay.
    WeightDecay,
    /// Early stopping.
    EarlyStopping,
    /// User-defined technique, identified by name.
    Custom(String),
}
373
374/// Adaptive mechanism search space
375#[derive(Debug, Clone)]
376pub struct AdaptiveMechanismSpace {
377    /// Available adaptation strategies
378    pub adaptation_strategies: Vec<AdaptationStrategy>,
379
380    /// Adaptation parameters
381    pub adaptation_parameters: HashMap<AdaptationStrategy, HashMap<String, ParameterRange>>,
382
383    /// Adaptation frequency options
384    pub adaptation_frequencies: Vec<usize>,
385}
386
/// Signals an adaptive mechanism can react to. Usable as a hash-map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum AdaptationStrategy {
    /// Based on performance.
    PerformanceBased,
    /// Based on gradients.
    GradientBased,
    /// Based on loss.
    LossBased,
    /// Based on time.
    TimeBased,
    /// Combination of several signals.
    HybridAdaptation,
    /// User-defined strategy, identified by name.
    Custom(String),
}
397
398/// Memory constraints for architecture search
399#[derive(Debug, Clone)]
400pub struct MemoryConstraints {
401    /// Maximum memory usage (bytes)
402    pub max_memory_bytes: usize,
403
404    /// Memory usage per component limits
405    pub component_memory_limits: HashMap<ComponentType, usize>,
406
407    /// Enable memory optimization
408    pub enable_memory_optimization: bool,
409
410    /// Memory allocation strategy
411    pub allocation_strategy: MemoryAllocationStrategy,
412}
413
/// Memory allocation strategies.
#[derive(Clone, Debug)]
pub enum MemoryAllocationStrategy {
    /// Static allocation.
    Static,
    /// Dynamic allocation.
    Dynamic,
    /// Adaptive allocation.
    Adaptive,
    /// Lazy allocation.
    Lazy,
}
422
423/// Computation constraints
424#[derive(Debug, Clone)]
425pub struct ComputationConstraints {
426    /// Maximum computational cost
427    pub max_computational_cost: f64,
428
429    /// Cost per component limits
430    pub component_cost_limits: HashMap<ComponentType, f64>,
431
432    /// Enable computation optimization
433    pub enable_computation_optimization: bool,
434
435    /// Parallelization constraints
436    pub parallelization_constraints: ParallelizationConstraints,
437}
438
/// Limits and switches controlling parallel execution.
#[derive(Clone, Debug)]
pub struct ParallelizationConstraints {
    /// Upper bound on parallel workers.
    pub max_workers: usize,
    /// Smallest batch size for which parallelization is used.
    pub min_batch_size: usize,
    /// Whether SIMD optimization is enabled.
    pub enable_simd: bool,
    /// Whether GPU acceleration is enabled.
    pub enable_gpu: bool,
}
454
455/// Compatibility constraint between components
456#[derive(Debug, Clone)]
457pub struct CompatibilityConstraint {
458    /// Constraint type
459    pub constraint_type: CompatibilityType,
460
461    /// Target components
462    pub target_components: Vec<ComponentType>,
463
464    /// Constraint condition
465    pub condition: ConstraintCondition,
466}
467
/// Kinds of compatibility rule between components.
#[derive(Clone, Debug)]
pub enum CompatibilityType {
    /// The components must be used together.
    Requires,
    /// The components must not be used together.
    Excludes,
    /// Compatibility depends on a condition.
    Conditional,
    /// The components have version requirements.
    VersionRequirement,
    /// The components have parameter constraints.
    ParameterConstraint,
    /// User-defined rule, identified by name.
    Custom(String),
}
489
490/// Constraint conditions
491#[derive(Debug, Clone)]
492pub enum ConstraintCondition {
493    /// Always applies
494    Always,
495
496    /// Never applies
497    Never,
498
499    /// Applies under certain parameter conditions
500    ParameterCondition(ParameterCondition),
501
502    /// Applies based on architecture properties
503    ArchitectureCondition(String),
504
505    /// Custom condition logic
506    Custom(String),
507}
508
/// A predicate over a named parameter. The first field of every
/// variant is the parameter name.
#[derive(Clone, Debug)]
pub enum ParameterCondition {
    /// Parameter equals the given value.
    Equals(String, f64),
    /// Parameter is greater than the given value.
    GreaterThan(String, f64),
    /// Parameter is less than the given value.
    LessThan(String, f64),
    /// Parameter lies within the given range.
    InRange(String, f64, f64),
    /// Parameter is one of the given values.
    InSet(String, Vec<f64>),
    /// Compound boolean condition, given as an expression string.
    BooleanExpression(String),
}
530
531/// Performance evaluation configuration
532#[derive(Debug, Clone)]
533pub struct EvaluationConfig<T: Float + Debug + Send + Sync + 'static> {
534    /// Evaluation metrics to use
535    pub metrics: Vec<EvaluationMetric>,
536
537    /// Benchmark datasets
538    pub benchmark_datasets: Vec<BenchmarkDataset>,
539
540    /// Evaluation budget (time/iterations)
541    pub evaluation_budget: EvaluationBudget,
542
543    /// Statistical testing configuration
544    pub statistical_testing: StatisticalTestingConfig,
545
546    /// Cross-validation settings
547    pub cross_validation_folds: usize,
548
549    /// Enable early stopping during evaluation
550    pub enable_early_stopping: bool,
551
552    /// Evaluation parallelization
553    pub parallelization_factor: usize,
554
555    /// Problem domain specification
556    pub problem_domains: Vec<ProblemDomain>,
557
558    /// Resource limits for evaluation
559    pub resource_limits: HashMap<ResourceType, T>,
560}
561
/// Problem domains an evaluation can target.
#[derive(Clone, Debug)]
pub enum ProblemDomain {
    /// Classification.
    Classification,
    /// Regression.
    Regression,
    /// Reinforcement learning.
    Reinforcement,
    /// Generative modeling.
    GenerativeModeling,
    /// Sequence modeling.
    SequenceModeling,
    /// Computer vision.
    ComputerVision,
    /// Natural language processing.
    NaturalLanguageProcessing,
    /// Time-series forecasting.
    TimeSeriesForecasting,
    /// User-defined domain, identified by name.
    Custom(String),
}
575
/// Resources that evaluation limits can be placed on. Usable as a hash-map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ResourceType {
    /// Memory.
    Memory,
    /// CPU time.
    CPUTime,
    /// GPU time.
    GPUTime,
    /// Network bandwidth.
    NetworkBandwidth,
    /// Storage space.
    StorageSpace,
}
585
586/// Benchmark dataset configuration
587#[derive(Debug, Clone)]
588pub struct BenchmarkDataset {
589    /// Dataset name
590    pub name: String,
591
592    /// Dataset path or URL
593    pub path: String,
594
595    /// Dataset characteristics
596    pub characteristics: DatasetCharacteristics,
597
598    /// Problem type
599    pub problem_type: ProblemType,
600
601    /// Evaluation weight
602    pub weight: f64,
603
604    /// Enable data augmentation
605    pub enable_augmentation: bool,
606}
607
608/// Dataset characteristics
609#[derive(Debug, Clone)]
610pub struct DatasetCharacteristics {
611    /// Number of samples
612    pub num_samples: usize,
613
614    /// Number of features
615    pub num_features: usize,
616
617    /// Number of classes (for classification)
618    pub num_classes: Option<usize>,
619
620    /// Dataset size category
621    pub size_category: DatasetSizeCategory,
622
623    /// Feature correlation structure
624    pub correlation_structure: CorrelationStructure,
625
626    /// Noise level estimate
627    pub noise_level: f64,
628
629    /// Imbalance ratio (for classification)
630    pub imbalance_ratio: Option<f64>,
631}
632
/// Learning-problem types a benchmark dataset can pose.
#[derive(Clone, Debug)]
pub enum ProblemType {
    /// Two-class classification.
    BinaryClassification,
    /// Classification with more than two classes.
    MultiClassClassification,
    /// Regression.
    Regression,
    /// Multi-task learning.
    MultiTaskLearning,
    /// Meta-learning.
    MetaLearning,
    /// Transfer learning.
    TransferLearning,
    /// Few-shot learning.
    FewShotLearning,
    /// User-defined problem type, identified by name.
    Custom(String),
}
645
/// Coarse bucketing of datasets by sample count.
#[derive(Clone, Debug)]
pub enum DatasetSizeCategory {
    /// Fewer than 1K samples.
    Small,
    /// Between 1K and 100K samples.
    Medium,
    /// Between 100K and 1M samples.
    Large,
    /// More than 1M samples.
    VeryLarge,
}
654
/// Correlation structure among a dataset's features.
#[derive(Clone, Debug)]
pub enum CorrelationStructure {
    /// Independent features.
    Independent,
    /// Low correlation.
    LowCorrelation,
    /// Moderate correlation.
    ModerateCorrelation,
    /// High correlation.
    HighCorrelation,
    /// Block-structured correlation.
    BlockStructure,
    /// Hierarchical correlation.
    Hierarchical,
}
665
666/// Evaluation budget configuration
667#[derive(Debug, Clone)]
668pub struct EvaluationBudget {
669    /// Maximum evaluation time per architecture
670    pub max_time_per_architecture: Duration,
671
672    /// Maximum number of training epochs
673    pub max_epochs: usize,
674
675    /// Maximum number of function evaluations
676    pub max_function_evaluations: usize,
677
678    /// Early stopping patience
679    pub early_stopping_patience: usize,
680
681    /// Resource allocation per evaluation
682    pub resource_allocation: ResourceAllocation,
683}
684
/// Resources reserved for a single evaluation.
#[derive(Clone, Debug)]
pub struct ResourceAllocation {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Memory, in GB.
    pub memory_gb: f64,
    /// Number of GPU devices.
    pub gpu_devices: usize,
    /// Storage space, in GB.
    pub storage_gb: f64,
}
700
701/// Statistical testing configuration
702#[derive(Debug, Clone)]
703pub struct StatisticalTestingConfig {
704    /// Significance level (alpha)
705    pub significance_level: f64,
706
707    /// Statistical test type
708    pub test_type: StatisticalTestType,
709
710    /// Multiple comparison correction
711    pub multiple_comparison_correction: MultipleComparisonCorrection,
712
713    /// Minimum effect size
714    pub min_effect_size: f64,
715
716    /// Bootstrap samples for confidence intervals
717    pub bootstrap_samples: usize,
718}
719
/// Statistical tests available for comparing architectures.
#[derive(Clone, Debug)]
pub enum StatisticalTestType {
    /// Student's t-test.
    TTest,
    /// Wilcoxon rank-sum test.
    WilcoxonRankSum,
    /// Kruskal-Wallis test.
    KruskalWallis,
    /// Friedman test.
    FriedmanTest,
    /// Bootstrap test.
    Bootstrap,
}
729
/// Corrections for multiple hypothesis testing.
#[derive(Clone, Debug)]
pub enum MultipleComparisonCorrection {
    /// No correction.
    None,
    /// Bonferroni correction.
    Bonferroni,
    /// Benjamini-Hochberg procedure.
    BenjaminiHochberg,
    /// Benjamini-Yekutieli procedure.
    BenjaminiYekutieli,
    /// Holm procedure.
    Holm,
}
739
740/// Multi-objective optimization configuration
741#[derive(Debug, Clone)]
742pub struct MultiObjectiveConfig<T: Float + Debug + Send + Sync + 'static> {
743    /// List of objectives to optimize
744    pub objectives: Vec<ObjectiveConfig<T>>,
745
746    /// Multi-objective algorithm to use
747    pub algorithm: MultiObjectiveAlgorithm,
748
749    /// User preferences for trade-offs
750    pub user_preferences: Option<UserPreferences<T>>,
751
752    /// Diversity promotion strategy
753    pub diversity_strategy: DiversityStrategy,
754
755    /// Constraint handling method
756    pub constraint_handling: ConstraintHandlingMethod,
757}
758
759/// Objective configuration
760#[derive(Debug, Clone)]
761pub struct ObjectiveConfig<T: Float + Debug + Send + Sync + 'static> {
762    /// Objective name
763    pub name: String,
764
765    /// Objective type
766    pub objective_type: ObjectiveType,
767
768    /// Optimization direction
769    pub direction: OptimizationDirection,
770
771    /// Objective weight
772    pub weight: T,
773
774    /// Objective priority
775    pub priority: ObjectivePriority,
776
777    /// Normalization bounds
778    pub normalization_bounds: Option<(T, T)>,
779}
780
/// Quantities an objective can measure.
#[derive(Clone, Debug)]
pub enum ObjectiveType {
    Accuracy,
    Loss,
    TrainingTime,
    InferenceTime,
    MemoryUsage,
    EnergyConsumption,
    ModelSize,
    Robustness,
    Fairness,
    Performance,
    Efficiency,
    Interpretability,
    Privacy,
    Sustainability,
    Cost,
    /// User-defined objective, identified by name.
    Custom(String),
}
801
/// Whether an objective should be driven down or up.
#[derive(Clone, Debug)]
pub enum OptimizationDirection {
    /// Smaller is better.
    Minimize,
    /// Larger is better.
    Maximize,
}
808
/// Objective priority levels.
///
/// The derived ordering follows declaration order, so
/// `Low < Medium < High < Critical`. Do not reorder the variants.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum ObjectivePriority {
    Low,
    Medium,
    High,
    Critical,
}
817
/// Multi-objective optimization algorithms.
#[derive(Clone, Debug)]
pub enum MultiObjectiveAlgorithm {
    NSGA2,
    NSGA3,
    MOEAD,
    SPEA2,
    PAES,
    /// Weighted sum of objectives.
    WeightedSum,
    /// Epsilon-constraint method.
    EpsilonConstraint,
    /// Goal programming.
    GoalProgramming,
    /// User-defined algorithm, identified by name.
    Custom(String),
}
831
832/// User preferences for multi-objective optimization
833#[derive(Debug, Clone)]
834pub struct UserPreferences<T: Float + Debug + Send + Sync + 'static> {
835    /// Preference type
836    pub preference_type: PreferenceType<T>,
837
838    /// Reference point (for reference point methods)
839    pub reference_point: Option<Vec<T>>,
840
841    /// Aspiration levels
842    pub aspiration_levels: Option<Vec<T>>,
843
844    /// Reservation levels
845    pub reservation_levels: Option<Vec<T>>,
846}
847
848/// Preference specification types
849#[derive(Debug, Clone)]
850pub enum PreferenceType<T: Float + Debug + Send + Sync + 'static> {
851    WeightVector(Vec<T>),
852    ReferencePoint(Vec<T>),
853    GoalVector(Vec<T>),
854    RankingOrder(Vec<usize>),
855    PairwiseComparison(Vec<(usize, usize)>),
856}
857
/// Strategies for promoting diversity in the population.
#[derive(Clone, Debug)]
pub enum DiversityStrategy {
    Crowding,
    Sharing,
    Clearing,
    Clustering,
    Novelty,
    /// User-defined strategy, identified by name.
    Custom(String),
}
868
/// Methods for handling constraints during optimization.
#[derive(Clone, Debug)]
pub enum ConstraintHandlingMethod {
    PenaltyFunction,
    DeathPenalty,
    Repair,
    Decoder,
    PreserveFeasibility,
    /// User-defined method, identified by name.
    Custom(String),
}
879
880/// Early stopping configuration
881#[derive(Debug, Clone)]
882pub struct EarlyStoppingConfig<T: Float + Debug + Send + Sync + 'static> {
883    /// Enable early stopping
884    pub enabled: bool,
885
886    /// Patience (number of generations without improvement)
887    pub patience: usize,
888
889    /// Minimum improvement threshold
890    pub min_improvement: T,
891
892    /// Convergence detection strategy
893    pub convergence_strategy: ConvergenceDetectionStrategy,
894
895    /// Minimum number of generations before stopping
896    pub min_generations: usize,
897
898    /// Metric to monitor for early stopping
899    pub metric: EvaluationMetric,
900
901    /// Target performance value
902    pub target_performance: Option<T>,
903
904    /// Convergence detection strategy
905    pub convergence_detection: ConvergenceDetectionStrategy,
906}
907
/// Strategies for deciding that the search has converged.
#[derive(Clone, Debug)]
pub enum ConvergenceDetectionStrategy {
    /// Based on the best score.
    BestScore,
    /// Based on the average score.
    AverageScore,
    /// Based on population diversity.
    PopulationDiversity,
    /// Based on stability of the Pareto front.
    ParetoFrontStability,
    /// Based on absence of improvement.
    NoImprovement,
    /// User-defined strategy, identified by name.
    Custom(String),
}
918
/// Representations used to encode architectures for the search algorithm.
#[derive(Clone, Debug)]
pub enum ArchitectureEncodingStrategy {
    /// Direct parameter encoding.
    Direct,
    /// Graph-based encoding.
    Graph,
    /// String-based encoding.
    String,
    /// Binary encoding.
    Binary,
    /// Real-valued encoding.
    RealValued,
    /// Hierarchical encoding.
    Hierarchical,
    /// Encoding produced by a neural network.
    Neural,
    /// Hybrid of several encodings.
    Hybrid,
    /// User-defined encoding, identified by name.
    Custom(String),
}
949
950/// Resource constraints for NAS
951#[derive(Debug, Clone)]
952pub struct ResourceConstraints<T: Float + Debug + Send + Sync + 'static> {
953    /// Hardware resource limits
954    pub hardware_resources: HardwareResources,
955
956    /// Time constraints
957    pub time_constraints: TimeConstraints,
958
959    /// Energy constraints
960    pub energy_constraints: Option<T>,
961
962    /// Cost constraints
963    pub cost_constraints: Option<T>,
964
965    /// Resource violation handling
966    pub violation_handling: ResourceViolationHandling,
967
968    /// Maximum memory in GB
969    pub max_memory_gb: T,
970
971    /// Maximum computation hours
972    pub max_computation_hours: T,
973
974    /// Maximum energy in kWh
975    pub max_energy_kwh: T,
976
977    /// Maximum cost in USD
978    pub max_cost_usd: T,
979
980    /// Enable resource monitoring
981    pub enable_monitoring: bool,
982}
983
/// Hardware resource specifications.
///
/// NOTE(review): this struct holds two overlapping sets of fields that
/// describe the same resources with different types and, for network
/// bandwidth, different units:
/// - `max_memory_gb: f64` vs `ram_gb: u32`
/// - `max_cpu_cores` vs `cpu_cores`
/// - `max_gpu_devices` vs `num_gpus`
/// - `max_storage_gb: f64` vs `storage_gb: u32`
/// - `max_network_bandwidth` (MB/s) vs `network_bandwidth_mbps` (Mbps)
///
/// This file does not show which group a given consumer reads, so populate
/// both consistently. Assuming MB/s means megabytes per second and Mbps means
/// megabits per second, 1 MB/s = 8 Mbps.
#[derive(Debug, Clone)]
pub struct HardwareResources {
    /// Maximum memory usage, in GB. Overlaps with `ram_gb`.
    pub max_memory_gb: f64,

    /// Maximum CPU cores. Overlaps with `cpu_cores`.
    pub max_cpu_cores: usize,

    /// Maximum GPU devices. Overlaps with `num_gpus`.
    pub max_gpu_devices: usize,

    /// Maximum storage space, in GB. Overlaps with `storage_gb`.
    pub max_storage_gb: f64,

    /// Network bandwidth limit, in MB/s. Overlaps with `network_bandwidth_mbps`,
    /// which uses Mbps.
    pub max_network_bandwidth: f64,

    /// Whether cloud resource scaling is enabled.
    pub enable_cloud_scaling: bool,

    /// Optional cloud resource budget.
    pub cloud_budget: Option<f64>,

    /// CPU cores available. Overlaps with `max_cpu_cores`.
    pub cpu_cores: usize,

    /// RAM, in GB. Overlaps with `max_memory_gb`.
    pub ram_gb: u32,

    /// Number of GPUs. Overlaps with `max_gpu_devices`.
    pub num_gpus: usize,

    /// GPU memory, in GB.
    pub gpu_memory_gb: u32,

    /// Storage, in GB. Overlaps with `max_storage_gb`.
    pub storage_gb: u32,

    /// Network bandwidth, in Mbps. Overlaps with `max_network_bandwidth`,
    /// which uses MB/s.
    pub network_bandwidth_mbps: f32,
}
1026
1027/// Time constraints for search
1028#[derive(Debug, Clone)]
1029pub struct TimeConstraints {
1030    /// Maximum total search time
1031    pub max_search_time: Duration,
1032
1033    /// Maximum time per architecture evaluation
1034    pub max_evaluation_time: Duration,
1035
1036    /// Search deadline
1037    pub search_deadline: Option<std::time::Instant>,
1038
1039    /// Time budget allocation strategy
1040    pub budget_allocation: TimeBudgetAllocation,
1041}
1042
/// Strategies for dividing the time budget among evaluations.
#[derive(Clone, Debug)]
pub enum TimeBudgetAllocation {
    /// Equal share for every evaluation.
    Uniform,
    /// Adapted to architecture complexity.
    AdaptiveByComplexity,
    /// Adapted to observed performance.
    AdaptiveByPerformance,
    /// Allocated by priority.
    PriorityBased,
    /// User-defined strategy, identified by name.
    Custom(String),
}
1052
/// What to do when a resource limit is violated.
#[derive(Clone, Debug)]
pub enum ResourceViolationHandling {
    /// Stop the search immediately.
    Abort,
    /// Skip architectures that violate the limits.
    Skip,
    /// Scale down resources.
    ScaleDown,
    /// Fall back to approximation methods.
    Approximate,
    /// Allocate resources dynamically.
    Dynamic,
    /// Penalize the objective function.
    Penalty,
    /// User-defined handling, identified by name.
    Custom(String),
}
1077
1078// Default implementations
1079impl<T: Float + Debug + Send + Sync + 'static> Default for NASConfig<T> {
1080    fn default() -> Self {
1081        Self {
1082            search_strategy: SearchStrategyType::Evolutionary,
1083            search_space: SearchSpaceConfig::default(),
1084            evaluation_config: EvaluationConfig::default(),
1085            multi_objective_config: MultiObjectiveConfig::default(),
1086            search_budget: 100,
1087            early_stopping: EarlyStoppingConfig::default(),
1088            progressive_search: false,
1089            population_size: 20,
1090            enable_transfer_learning: false,
1091            encoding_strategy: ArchitectureEncodingStrategy::Graph,
1092            enable_performance_prediction: false,
1093            parallelization_factor: 1,
1094            auto_hyperparameter_tuning: false,
1095            resource_constraints: ResourceConstraints::default(),
1096        }
1097    }
1098}
1099
1100impl Default for SearchSpaceConfig {
1101    fn default() -> Self {
1102        Self {
1103            components: Vec::new(),
1104            connection_patterns: vec![
1105                ConnectionPatternType::Sequential,
1106                ConnectionPatternType::Parallel,
1107                ConnectionPatternType::SkipConnection,
1108            ],
1109            learning_rate_schedules: LearningRateScheduleSpace::default(),
1110            regularization_techniques: RegularizationSpace::default(),
1111            adaptive_mechanisms: AdaptiveMechanismSpace::default(),
1112            memory_constraints: MemoryConstraints::default(),
1113            computation_constraints: ComputationConstraints::default(),
1114            component_types: vec![ComponentType::SGD, ComponentType::Adam],
1115            max_components: 10,
1116            min_components: 1,
1117            max_connections: 20,
1118        }
1119    }
1120}
1121
1122impl Default for LearningRateScheduleSpace {
1123    fn default() -> Self {
1124        Self {
1125            schedule_types: vec![
1126                ScheduleType::Constant,
1127                ScheduleType::StepDecay,
1128                ScheduleType::ExponentialDecay,
1129                ScheduleType::CosineAnnealing,
1130            ],
1131            initial_lr_range: ParameterRange::LogUniform(1e-5, 1e-1),
1132            schedule_parameters: HashMap::new(),
1133        }
1134    }
1135}
1136
1137impl Default for RegularizationSpace {
1138    fn default() -> Self {
1139        Self {
1140            techniques: vec![
1141                RegularizationTechnique::L1,
1142                RegularizationTechnique::L2,
1143                RegularizationTechnique::Dropout,
1144                RegularizationTechnique::WeightDecay,
1145            ],
1146            strength_ranges: HashMap::new(),
1147            combination_strategies: Vec::new(),
1148        }
1149    }
1150}
1151
1152impl Default for AdaptiveMechanismSpace {
1153    fn default() -> Self {
1154        Self {
1155            adaptation_strategies: vec![
1156                AdaptationStrategy::PerformanceBased,
1157                AdaptationStrategy::GradientBased,
1158                AdaptationStrategy::LossBased,
1159            ],
1160            adaptation_parameters: HashMap::new(),
1161            adaptation_frequencies: vec![1, 5, 10, 25, 50, 100],
1162        }
1163    }
1164}
1165
1166impl Default for MemoryConstraints {
1167    fn default() -> Self {
1168        Self {
1169            max_memory_bytes: 32 * 1024 * 1024 * 1024, // 32GB
1170            component_memory_limits: HashMap::new(),
1171            enable_memory_optimization: true,
1172            allocation_strategy: MemoryAllocationStrategy::Dynamic,
1173        }
1174    }
1175}
1176
1177impl Default for ComputationConstraints {
1178    fn default() -> Self {
1179        Self {
1180            max_computational_cost: 1000.0,
1181            component_cost_limits: HashMap::new(),
1182            enable_computation_optimization: true,
1183            parallelization_constraints: ParallelizationConstraints::default(),
1184        }
1185    }
1186}
1187
1188impl Default for ParallelizationConstraints {
1189    fn default() -> Self {
1190        Self {
1191            max_workers: num_cpus::get(),
1192            min_batch_size: 32,
1193            enable_simd: true,
1194            enable_gpu: true,
1195        }
1196    }
1197}
1198
1199impl<T: Float + Debug + Send + Sync + 'static> Default for EvaluationConfig<T> {
1200    fn default() -> Self {
1201        Self {
1202            metrics: vec![EvaluationMetric::Accuracy, EvaluationMetric::TrainingTime],
1203            benchmark_datasets: Vec::new(),
1204            evaluation_budget: EvaluationBudget::default(),
1205            statistical_testing: StatisticalTestingConfig::default(),
1206            cross_validation_folds: 5,
1207            enable_early_stopping: true,
1208            parallelization_factor: 1,
1209            problem_domains: vec![ProblemDomain::Classification],
1210            resource_limits: HashMap::new(),
1211        }
1212    }
1213}
1214
1215impl Default for EvaluationBudget {
1216    fn default() -> Self {
1217        Self {
1218            max_time_per_architecture: Duration::from_secs(300),
1219            max_epochs: 100,
1220            max_function_evaluations: 1000,
1221            early_stopping_patience: 10,
1222            resource_allocation: ResourceAllocation::default(),
1223        }
1224    }
1225}
1226
1227impl Default for ResourceAllocation {
1228    fn default() -> Self {
1229        Self {
1230            cpu_cores: 4,
1231            memory_gb: 8.0,
1232            gpu_devices: 1,
1233            storage_gb: 10.0,
1234        }
1235    }
1236}
1237
1238impl Default for StatisticalTestingConfig {
1239    fn default() -> Self {
1240        Self {
1241            significance_level: 0.05,
1242            test_type: StatisticalTestType::TTest,
1243            multiple_comparison_correction: MultipleComparisonCorrection::BenjaminiHochberg,
1244            min_effect_size: 0.1,
1245            bootstrap_samples: 1000,
1246        }
1247    }
1248}
1249
1250impl<T: Float + Debug + Send + Sync + 'static> Default for MultiObjectiveConfig<T> {
1251    fn default() -> Self {
1252        Self {
1253            objectives: Vec::new(),
1254            algorithm: MultiObjectiveAlgorithm::NSGA2,
1255            user_preferences: None,
1256            diversity_strategy: DiversityStrategy::Crowding,
1257            constraint_handling: ConstraintHandlingMethod::PenaltyFunction,
1258        }
1259    }
1260}
1261
1262impl<T: Float + Debug + Send + Sync + 'static> Default for EarlyStoppingConfig<T> {
1263    fn default() -> Self {
1264        Self {
1265            enabled: true,
1266            patience: 20,
1267            min_improvement: scirs2_core::numeric::NumCast::from(0.001)
1268                .unwrap_or_else(|| T::zero()),
1269            convergence_strategy: ConvergenceDetectionStrategy::BestScore,
1270            min_generations: 10,
1271            metric: EvaluationMetric::Accuracy,
1272            target_performance: None,
1273            convergence_detection: ConvergenceDetectionStrategy::NoImprovement,
1274        }
1275    }
1276}
1277
1278impl<T: Float + Debug + Send + Sync + 'static> Default for ResourceConstraints<T> {
1279    fn default() -> Self {
1280        Self {
1281            hardware_resources: HardwareResources::default(),
1282            time_constraints: TimeConstraints::default(),
1283            energy_constraints: None,
1284            cost_constraints: None,
1285            violation_handling: ResourceViolationHandling::Skip,
1286            max_memory_gb: scirs2_core::numeric::NumCast::from(32.0).unwrap_or_else(|| T::zero()),
1287            max_computation_hours: scirs2_core::numeric::NumCast::from(24.0)
1288                .unwrap_or_else(|| T::zero()),
1289            max_energy_kwh: scirs2_core::numeric::NumCast::from(100.0).unwrap_or_else(|| T::zero()),
1290            max_cost_usd: scirs2_core::numeric::NumCast::from(1000.0).unwrap_or_else(|| T::zero()),
1291            enable_monitoring: true,
1292        }
1293    }
1294}
1295
1296impl Default for HardwareResources {
1297    fn default() -> Self {
1298        Self {
1299            max_memory_gb: 32.0,
1300            max_cpu_cores: num_cpus::get(),
1301            max_gpu_devices: 4,
1302            max_storage_gb: 100.0,
1303            max_network_bandwidth: 1000.0, // 1 GB/s
1304            enable_cloud_scaling: false,
1305            cloud_budget: None,
1306            cpu_cores: num_cpus::get(),
1307            ram_gb: 32,
1308            num_gpus: 1,
1309            gpu_memory_gb: 8,
1310            storage_gb: 100,
1311            network_bandwidth_mbps: 1000.0,
1312        }
1313    }
1314}
1315
1316impl Default for TimeConstraints {
1317    fn default() -> Self {
1318        Self {
1319            max_search_time: Duration::from_secs(3600),    // 1 hour
1320            max_evaluation_time: Duration::from_secs(300), // 5 minutes
1321            search_deadline: None,
1322            budget_allocation: TimeBudgetAllocation::Uniform,
1323        }
1324    }
1325}
1326
#[cfg(test)]
mod tests {
    use super::*;

    /// The top-level NAS config defaults to a budget of 100 architectures,
    /// a population of 20, and progressive search turned off.
    #[test]
    fn test_nas_config_default() {
        let NASConfig {
            search_budget,
            population_size,
            progressive_search,
            ..
        } = NASConfig::<f32>::default();

        assert_eq!(search_budget, 100);
        assert_eq!(population_size, 20);
        assert!(!progressive_search);
    }

    /// The default search space ships with exactly three connection patterns.
    #[test]
    fn test_search_space_config_default() {
        let config = SearchSpaceConfig::default();
        let patterns = &config.connection_patterns;

        assert!(!patterns.is_empty());
        assert_eq!(patterns.len(), 3);
    }

    /// Each `ParameterRange` variant round-trips the values it was built with.
    #[test]
    fn test_parameter_range_variants() {
        let continuous = ParameterRange::Continuous(0.0, 1.0);
        let discrete = ParameterRange::Discrete(vec![0.1, 0.5, 0.9]);
        let integer = ParameterRange::Integer(1, 10);

        if let ParameterRange::Continuous(lo, hi) = continuous {
            assert_eq!(lo, 0.0);
            assert_eq!(hi, 1.0);
        } else {
            panic!("Expected continuous range");
        }

        if let ParameterRange::Discrete(choices) = discrete {
            assert_eq!(choices.len(), 3);
        } else {
            panic!("Expected discrete range");
        }

        if let ParameterRange::Integer(lo, hi) = integer {
            assert_eq!(lo, 1);
            assert_eq!(hi, 10);
        } else {
            panic!("Expected integer range");
        }
    }
}