quantrs2_sim/quantum_machine_learning_layers/
config.rs

1//! Quantum Machine Learning Configuration Types
2//!
3//! This module contains all configuration types for the QML framework.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::f64::consts::PI;
8
9/// Quantum machine learning configuration
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct QMLConfig {
12    /// Number of qubits in the quantum layer
13    pub num_qubits: usize,
14    /// QML architecture type
15    pub architecture_type: QMLArchitectureType,
16    /// Layer configuration for each QML layer
17    pub layer_configs: Vec<QMLLayerConfig>,
18    /// Training algorithm configuration
19    pub training_config: QMLTrainingConfig,
20    /// Hardware-aware optimization settings
21    pub hardware_optimization: HardwareOptimizationConfig,
22    /// Classical preprocessing configuration
23    pub classical_preprocessing: ClassicalPreprocessingConfig,
24    /// Hybrid training configuration
25    pub hybrid_training: HybridTrainingConfig,
26    /// Enable quantum advantage analysis
27    pub quantum_advantage_analysis: bool,
28    /// Noise-aware training settings
29    pub noise_aware_training: NoiseAwareTrainingConfig,
30    /// Performance optimization settings
31    pub performance_optimization: PerformanceOptimizationConfig,
32}
33
34impl Default for QMLConfig {
35    fn default() -> Self {
36        Self {
37            num_qubits: 8,
38            architecture_type: QMLArchitectureType::VariationalQuantumCircuit,
39            layer_configs: vec![QMLLayerConfig {
40                layer_type: QMLLayerType::ParameterizedQuantumCircuit,
41                num_parameters: 16,
42                ansatz_type: AnsatzType::Hardware,
43                entanglement_pattern: EntanglementPattern::Linear,
44                rotation_gates: vec![RotationGate::RY, RotationGate::RZ],
45                depth: 4,
46                enable_gradient_computation: true,
47            }],
48            training_config: QMLTrainingConfig::default(),
49            hardware_optimization: HardwareOptimizationConfig::default(),
50            classical_preprocessing: ClassicalPreprocessingConfig::default(),
51            hybrid_training: HybridTrainingConfig::default(),
52            quantum_advantage_analysis: true,
53            noise_aware_training: NoiseAwareTrainingConfig::default(),
54            performance_optimization: PerformanceOptimizationConfig::default(),
55        }
56    }
57}
58
59/// QML architecture types
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
61pub enum QMLArchitectureType {
62    VariationalQuantumCircuit,
63    QuantumConvolutionalNN,
64    QuantumRecurrentNN,
65    QuantumGraphNN,
66    QuantumAttentionNetwork,
67    QuantumTransformer,
68    HybridClassicalQuantum,
69    QuantumBoltzmannMachine,
70    QuantumGAN,
71    QuantumAutoencoder,
72}
73
74/// QML layer configuration
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct QMLLayerConfig {
77    pub layer_type: QMLLayerType,
78    pub num_parameters: usize,
79    pub ansatz_type: AnsatzType,
80    pub entanglement_pattern: EntanglementPattern,
81    pub rotation_gates: Vec<RotationGate>,
82    pub depth: usize,
83    pub enable_gradient_computation: bool,
84}
85
86/// Types of QML layers
87#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
88pub enum QMLLayerType {
89    ParameterizedQuantumCircuit,
90    QuantumConvolutional,
91    QuantumPooling,
92    QuantumDense,
93    QuantumLSTM,
94    QuantumGRU,
95    QuantumAttention,
96    QuantumDropout,
97    QuantumBatchNorm,
98    DataReUpload,
99}
100
101/// Ansatz types for parameterized quantum circuits
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
103pub enum AnsatzType {
104    Hardware,
105    ProblemSpecific,
106    AllToAll,
107    Layered,
108    Alternating,
109    BrickWall,
110    Tree,
111    Custom,
112}
113
114/// Entanglement patterns
115#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
116pub enum EntanglementPattern {
117    Linear,
118    Circular,
119    AllToAll,
120    Star,
121    Grid,
122    Random,
123    Block,
124    Custom,
125}
126
127/// Rotation gates for parameterized circuits
128#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
129pub enum RotationGate {
130    RX,
131    RY,
132    RZ,
133    U3,
134    Phase,
135}
136
137/// QML training configuration
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct QMLTrainingConfig {
140    pub algorithm: QMLTrainingAlgorithm,
141    pub learning_rate: f64,
142    pub epochs: usize,
143    pub batch_size: usize,
144    pub gradient_method: GradientMethod,
145    pub optimizer: OptimizerType,
146    pub regularization: RegularizationConfig,
147    pub early_stopping: EarlyStoppingConfig,
148    pub lr_schedule: LearningRateSchedule,
149}
150
151impl Default for QMLTrainingConfig {
152    fn default() -> Self {
153        Self {
154            algorithm: QMLTrainingAlgorithm::ParameterShift,
155            learning_rate: 0.01,
156            epochs: 100,
157            batch_size: 32,
158            gradient_method: GradientMethod::ParameterShift,
159            optimizer: OptimizerType::Adam,
160            regularization: RegularizationConfig::default(),
161            early_stopping: EarlyStoppingConfig::default(),
162            lr_schedule: LearningRateSchedule::Constant,
163        }
164    }
165}
166
167/// QML training algorithms
168#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
169pub enum QMLTrainingAlgorithm {
170    ParameterShift,
171    FiniteDifference,
172    QuantumNaturalGradient,
173    SPSA,
174    QAOA,
175    VQE,
176    Rotosolve,
177    HybridTraining,
178}
179
180/// Gradient computation methods
181#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
182pub enum GradientMethod {
183    ParameterShift,
184    FiniteDifference,
185    Adjoint,
186    Backpropagation,
187    QuantumFisherInformation,
188}
189
190/// Optimizer types
191#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
192pub enum OptimizerType {
193    SGD,
194    Adam,
195    AdaGrad,
196    RMSprop,
197    Momentum,
198    LBFGS,
199    QuantumNaturalGradient,
200    SPSA,
201}
202
203/// Regularization configuration
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct RegularizationConfig {
206    pub l1_strength: f64,
207    pub l2_strength: f64,
208    pub dropout_prob: f64,
209    pub parameter_bounds: Option<(f64, f64)>,
210    pub enable_clipping: bool,
211    pub gradient_clip_threshold: f64,
212}
213
214impl Default for RegularizationConfig {
215    fn default() -> Self {
216        Self {
217            l1_strength: 0.0,
218            l2_strength: 0.001,
219            dropout_prob: 0.1,
220            parameter_bounds: Some((-PI, PI)),
221            enable_clipping: true,
222            gradient_clip_threshold: 1.0,
223        }
224    }
225}
226
227/// Early stopping configuration
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct EarlyStoppingConfig {
230    pub enabled: bool,
231    pub patience: usize,
232    pub min_delta: f64,
233    pub monitor_metric: String,
234    pub mode_max: bool,
235}
236
237impl Default for EarlyStoppingConfig {
238    fn default() -> Self {
239        Self {
240            enabled: true,
241            patience: 10,
242            min_delta: 1e-6,
243            monitor_metric: "val_loss".to_string(),
244            mode_max: false,
245        }
246    }
247}
248
249/// Learning rate schedules
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251pub enum LearningRateSchedule {
252    Constant,
253    ExponentialDecay,
254    StepDecay,
255    CosineAnnealing,
256    WarmRestart,
257    ReduceOnPlateau,
258}
259
260/// Hardware optimization configuration
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct HardwareOptimizationConfig {
263    pub target_hardware: QuantumHardwareTarget,
264    pub minimize_gate_count: bool,
265    pub minimize_depth: bool,
266    pub noise_aware: bool,
267    pub connectivity_constraints: ConnectivityConstraints,
268    pub gate_fidelities: HashMap<String, f64>,
269    pub enable_parallelization: bool,
270    pub optimization_level: HardwareOptimizationLevel,
271}
272
273impl Default for HardwareOptimizationConfig {
274    fn default() -> Self {
275        let mut gate_fidelities = HashMap::new();
276        gate_fidelities.insert("single_qubit".to_string(), 0.999);
277        gate_fidelities.insert("two_qubit".to_string(), 0.99);
278
279        Self {
280            target_hardware: QuantumHardwareTarget::Simulator,
281            minimize_gate_count: true,
282            minimize_depth: true,
283            noise_aware: false,
284            connectivity_constraints: ConnectivityConstraints::AllToAll,
285            gate_fidelities,
286            enable_parallelization: true,
287            optimization_level: HardwareOptimizationLevel::Medium,
288        }
289    }
290}
291
292/// Quantum hardware targets
293#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
294pub enum QuantumHardwareTarget {
295    Simulator,
296    IBM,
297    Google,
298    IonQ,
299    Rigetti,
300    Quantinuum,
301    Xanadu,
302    Custom,
303}
304
305/// Connectivity constraints
306#[derive(Debug, Clone, Serialize, Deserialize)]
307pub enum ConnectivityConstraints {
308    AllToAll,
309    Linear,
310    Grid(usize, usize),
311    Custom(Vec<(usize, usize)>),
312    HeavyHex,
313    Square,
314}
315
316/// Hardware optimization levels
317#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
318pub enum HardwareOptimizationLevel {
319    Basic,
320    Medium,
321    Aggressive,
322    Maximum,
323}
324
325/// Classical preprocessing configuration
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct ClassicalPreprocessingConfig {
328    pub feature_scaling: bool,
329    pub scaling_method: ScalingMethod,
330    pub enable_pca: bool,
331    pub pca_components: Option<usize>,
332    pub encoding_method: DataEncodingMethod,
333    pub feature_selection: FeatureSelectionConfig,
334}
335
336impl Default for ClassicalPreprocessingConfig {
337    fn default() -> Self {
338        Self {
339            feature_scaling: true,
340            scaling_method: ScalingMethod::StandardScaler,
341            enable_pca: false,
342            pca_components: None,
343            encoding_method: DataEncodingMethod::Amplitude,
344            feature_selection: FeatureSelectionConfig::default(),
345        }
346    }
347}
348
349/// Scaling methods for classical preprocessing
350#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
351pub enum ScalingMethod {
352    StandardScaler,
353    MinMaxScaler,
354    RobustScaler,
355    QuantileUniform,
356    PowerTransformer,
357}
358
359/// Data encoding methods for quantum circuits
360#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
361pub enum DataEncodingMethod {
362    Amplitude,
363    Angle,
364    Basis,
365    QuantumFeatureMap,
366    IQP,
367    PauliFeatureMap,
368    DataReUpload,
369}
370
371/// Feature selection configuration
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct FeatureSelectionConfig {
374    pub enabled: bool,
375    pub method: FeatureSelectionMethod,
376    pub num_features: Option<usize>,
377    pub threshold: f64,
378}
379
380impl Default for FeatureSelectionConfig {
381    fn default() -> Self {
382        Self {
383            enabled: false,
384            method: FeatureSelectionMethod::VarianceThreshold,
385            num_features: None,
386            threshold: 0.0,
387        }
388    }
389}
390
391/// Feature selection methods
392#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
393pub enum FeatureSelectionMethod {
394    VarianceThreshold,
395    UnivariateSelection,
396    RecursiveFeatureElimination,
397    L1Based,
398    TreeBased,
399    QuantumFeatureImportance,
400}
401
402/// Hybrid training configuration
403#[derive(Debug, Clone, Serialize, Deserialize)]
404pub struct HybridTrainingConfig {
405    pub enabled: bool,
406    pub classical_architecture: ClassicalArchitecture,
407    pub interface_config: QuantumClassicalInterface,
408    pub alternating_schedule: AlternatingSchedule,
409    pub gradient_flow: GradientFlowConfig,
410}
411
412impl Default for HybridTrainingConfig {
413    fn default() -> Self {
414        Self {
415            enabled: false,
416            classical_architecture: ClassicalArchitecture::MLP,
417            interface_config: QuantumClassicalInterface::Expectation,
418            alternating_schedule: AlternatingSchedule::Simultaneous,
419            gradient_flow: GradientFlowConfig::default(),
420        }
421    }
422}
423
424/// Classical neural network architectures for hybrid training
425#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
426pub enum ClassicalArchitecture {
427    MLP,
428    CNN,
429    RNN,
430    LSTM,
431    Transformer,
432    ResNet,
433    Custom,
434}
435
436/// Quantum-classical interfaces
437#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
438pub enum QuantumClassicalInterface {
439    Expectation,
440    Sampling,
441    StateTomography,
442    ProcessTomography,
443    ShadowTomography,
444}
445
446/// Alternating training schedules for hybrid systems
447#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
448pub enum AlternatingSchedule {
449    Simultaneous,
450    Alternating,
451    ClassicalFirst,
452    QuantumFirst,
453    Custom,
454}
455
456/// Gradient flow configuration for hybrid training
457#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct GradientFlowConfig {
459    pub classical_to_quantum: bool,
460    pub quantum_to_classical: bool,
461    pub gradient_scaling: f64,
462    pub enable_clipping: bool,
463    pub accumulation_steps: usize,
464}
465
466impl Default for GradientFlowConfig {
467    fn default() -> Self {
468        Self {
469            classical_to_quantum: true,
470            quantum_to_classical: true,
471            gradient_scaling: 1.0,
472            enable_clipping: true,
473            accumulation_steps: 1,
474        }
475    }
476}
477
478/// Noise-aware training configuration
479#[derive(Debug, Clone, Serialize, Deserialize, Default)]
480pub struct NoiseAwareTrainingConfig {
481    pub enabled: bool,
482    pub noise_parameters: NoiseParameters,
483    pub error_mitigation: ErrorMitigationConfig,
484    pub noise_characterization: NoiseCharacterizationConfig,
485    pub robust_training: RobustTrainingConfig,
486}
487
488/// Noise parameters for quantum devices
489#[derive(Debug, Clone, Serialize, Deserialize)]
490pub struct NoiseParameters {
491    pub single_qubit_error: f64,
492    pub two_qubit_error: f64,
493    pub measurement_error: f64,
494    pub coherence_times: (f64, f64),
495    pub gate_times: HashMap<String, f64>,
496}
497
498impl Default for NoiseParameters {
499    fn default() -> Self {
500        let mut gate_times = HashMap::new();
501        gate_times.insert("single_qubit".to_string(), 50e-9);
502        gate_times.insert("two_qubit".to_string(), 200e-9);
503
504        Self {
505            single_qubit_error: 0.001,
506            two_qubit_error: 0.01,
507            measurement_error: 0.01,
508            coherence_times: (50e-6, 100e-6),
509            gate_times,
510        }
511    }
512}
513
514/// Error mitigation configuration
515#[derive(Debug, Clone, Serialize, Deserialize, Default)]
516pub struct ErrorMitigationConfig {
517    pub zero_noise_extrapolation: bool,
518    pub readout_error_mitigation: bool,
519    pub symmetry_verification: bool,
520    pub virtual_distillation: VirtualDistillationConfig,
521    pub quantum_error_correction: bool,
522}
523
524/// Virtual distillation configuration
525#[derive(Debug, Clone, Serialize, Deserialize)]
526pub struct VirtualDistillationConfig {
527    pub enabled: bool,
528    pub num_copies: usize,
529    pub protocol: DistillationProtocol,
530}
531
532impl Default for VirtualDistillationConfig {
533    fn default() -> Self {
534        Self {
535            enabled: false,
536            num_copies: 2,
537            protocol: DistillationProtocol::Standard,
538        }
539    }
540}
541
542/// Distillation protocols
543#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
544pub enum DistillationProtocol {
545    Standard,
546    Improved,
547    QuantumAdvantage,
548}
549
550/// Noise characterization configuration
551#[derive(Debug, Clone, Serialize, Deserialize)]
552pub struct NoiseCharacterizationConfig {
553    pub enabled: bool,
554    pub method: NoiseCharacterizationMethod,
555    pub benchmarking: BenchmarkingProtocols,
556    pub calibration_frequency: CalibrationFrequency,
557}
558
559impl Default for NoiseCharacterizationConfig {
560    fn default() -> Self {
561        Self {
562            enabled: false,
563            method: NoiseCharacterizationMethod::ProcessTomography,
564            benchmarking: BenchmarkingProtocols::default(),
565            calibration_frequency: CalibrationFrequency::Daily,
566        }
567    }
568}
569
570/// Noise characterization methods
571#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
572pub enum NoiseCharacterizationMethod {
573    ProcessTomography,
574    RandomizedBenchmarking,
575    GateSetTomography,
576    QuantumDetectorTomography,
577    CrossEntropyBenchmarking,
578}
579
580/// Benchmarking protocols
581#[derive(Debug, Clone, Serialize, Deserialize)]
582pub struct BenchmarkingProtocols {
583    pub randomized_benchmarking: bool,
584    pub quantum_volume: bool,
585    pub cross_entropy_benchmarking: bool,
586    pub mirror_benchmarking: bool,
587}
588
589impl Default for BenchmarkingProtocols {
590    fn default() -> Self {
591        Self {
592            randomized_benchmarking: true,
593            quantum_volume: false,
594            cross_entropy_benchmarking: false,
595            mirror_benchmarking: false,
596        }
597    }
598}
599
600/// Calibration frequency
601#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
602pub enum CalibrationFrequency {
603    RealTime,
604    Hourly,
605    Daily,
606    Weekly,
607    Manual,
608}
609
610/// Robust training configuration
611#[derive(Debug, Clone, Serialize, Deserialize, Default)]
612pub struct RobustTrainingConfig {
613    pub enabled: bool,
614    pub noise_injection: NoiseInjectionConfig,
615    pub adversarial_training: AdversarialTrainingConfig,
616    pub ensemble_methods: EnsembleMethodsConfig,
617}
618
619/// Noise injection configuration
620#[derive(Debug, Clone, Serialize, Deserialize)]
621pub struct NoiseInjectionConfig {
622    pub enabled: bool,
623    pub injection_probability: f64,
624    pub noise_strength: f64,
625    pub noise_type: NoiseType,
626}
627
628impl Default for NoiseInjectionConfig {
629    fn default() -> Self {
630        Self {
631            enabled: false,
632            injection_probability: 0.1,
633            noise_strength: 0.01,
634            noise_type: NoiseType::Depolarizing,
635        }
636    }
637}
638
639/// Noise types for training
640#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
641pub enum NoiseType {
642    Depolarizing,
643    AmplitudeDamping,
644    PhaseDamping,
645    BitFlip,
646    PhaseFlip,
647    Pauli,
648}
649
650/// Adversarial training configuration
651#[derive(Debug, Clone, Serialize, Deserialize)]
652pub struct AdversarialTrainingConfig {
653    pub enabled: bool,
654    pub attack_strength: f64,
655    pub attack_method: AdversarialAttackMethod,
656    pub defense_method: AdversarialDefenseMethod,
657}
658
659impl Default for AdversarialTrainingConfig {
660    fn default() -> Self {
661        Self {
662            enabled: false,
663            attack_strength: 0.01,
664            attack_method: AdversarialAttackMethod::FGSM,
665            defense_method: AdversarialDefenseMethod::AdversarialTraining,
666        }
667    }
668}
669
670/// Adversarial attack methods
671#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
672pub enum AdversarialAttackMethod {
673    FGSM,
674    PGD,
675    CarliniWagner,
676    QuantumAdversarial,
677}
678
679/// Adversarial defense methods
680#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
681pub enum AdversarialDefenseMethod {
682    AdversarialTraining,
683    DefensiveDistillation,
684    CertifiedDefenses,
685    QuantumErrorCorrection,
686}
687
688/// Ensemble methods configuration
689#[derive(Debug, Clone, Serialize, Deserialize)]
690pub struct EnsembleMethodsConfig {
691    pub enabled: bool,
692    pub num_ensemble: usize,
693    pub ensemble_method: EnsembleMethod,
694    pub voting_strategy: VotingStrategy,
695}
696
697impl Default for EnsembleMethodsConfig {
698    fn default() -> Self {
699        Self {
700            enabled: false,
701            num_ensemble: 5,
702            ensemble_method: EnsembleMethod::Bagging,
703            voting_strategy: VotingStrategy::MajorityVoting,
704        }
705    }
706}
707
708/// Ensemble methods
709#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
710pub enum EnsembleMethod {
711    Bagging,
712    Boosting,
713    RandomForest,
714    QuantumEnsemble,
715}
716
717/// Voting strategies for ensembles
718#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
719pub enum VotingStrategy {
720    MajorityVoting,
721    WeightedVoting,
722    SoftVoting,
723    QuantumVoting,
724}
725
726/// Performance optimization configuration
727#[derive(Debug, Clone, Serialize, Deserialize)]
728pub struct PerformanceOptimizationConfig {
729    pub enabled: bool,
730    pub memory_optimization: MemoryOptimizationConfig,
731    pub computation_optimization: ComputationOptimizationConfig,
732    pub parallelization: ParallelizationConfig,
733    pub caching: CachingConfig,
734}
735
736impl Default for PerformanceOptimizationConfig {
737    fn default() -> Self {
738        Self {
739            enabled: true,
740            memory_optimization: MemoryOptimizationConfig::default(),
741            computation_optimization: ComputationOptimizationConfig::default(),
742            parallelization: ParallelizationConfig::default(),
743            caching: CachingConfig::default(),
744        }
745    }
746}
747
748/// Memory optimization configuration
749#[derive(Debug, Clone, Serialize, Deserialize)]
750pub struct MemoryOptimizationConfig {
751    pub enabled: bool,
752    pub memory_mapping: bool,
753    pub gradient_checkpointing: bool,
754    pub memory_pool_size: Option<usize>,
755}
756
757impl Default for MemoryOptimizationConfig {
758    fn default() -> Self {
759        Self {
760            enabled: true,
761            memory_mapping: false,
762            gradient_checkpointing: false,
763            memory_pool_size: None,
764        }
765    }
766}
767
768/// Computation optimization configuration
769#[derive(Debug, Clone, Serialize, Deserialize)]
770pub struct ComputationOptimizationConfig {
771    pub enabled: bool,
772    pub mixed_precision: bool,
773    pub simd_optimization: bool,
774    pub jit_compilation: bool,
775}
776
777impl Default for ComputationOptimizationConfig {
778    fn default() -> Self {
779        Self {
780            enabled: true,
781            mixed_precision: false,
782            simd_optimization: true,
783            jit_compilation: false,
784        }
785    }
786}
787
788/// Parallelization configuration
789#[derive(Debug, Clone, Serialize, Deserialize)]
790pub struct ParallelizationConfig {
791    pub enabled: bool,
792    pub num_threads: Option<usize>,
793    pub data_parallelism: bool,
794    pub model_parallelism: bool,
795    pub pipeline_parallelism: bool,
796}
797
798impl Default for ParallelizationConfig {
799    fn default() -> Self {
800        Self {
801            enabled: true,
802            num_threads: None,
803            data_parallelism: true,
804            model_parallelism: false,
805            pipeline_parallelism: false,
806        }
807    }
808}
809
810/// Caching configuration
811#[derive(Debug, Clone, Serialize, Deserialize)]
812pub struct CachingConfig {
813    pub enabled: bool,
814    pub cache_size: usize,
815    pub cache_gradients: bool,
816    pub cache_intermediate: bool,
817}
818
819impl Default for CachingConfig {
820    fn default() -> Self {
821        Self {
822            enabled: true,
823            cache_size: 1000,
824            cache_gradients: true,
825            cache_intermediate: false,
826        }
827    }
828}