use ndarray::{Array1, Array2, Array3};
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};

use crate::circuit_interfaces::{InterfaceCircuit, InterfaceGate, InterfaceGateType};
use crate::error::{Result, SimulatorError};

/// Configuration for the advanced ML-based error mitigation engine.
#[derive(Debug, Clone)]
pub struct AdvancedMLMitigationConfig {
    pub enable_deep_learning: bool,
    pub enable_reinforcement_learning: bool,
    pub enable_transfer_learning: bool,
    pub enable_adversarial_training: bool,
    pub enable_ensemble_methods: bool,
    pub enable_online_learning: bool,
    pub learning_rate: f64,
    pub batch_size: usize,
    pub memory_size: usize,
    pub exploration_rate: f64,
    pub transfer_alpha: f64,
    pub ensemble_size: usize,
}

impl Default for AdvancedMLMitigationConfig {
    fn default() -> Self {
        Self {
            enable_deep_learning: true,
            enable_reinforcement_learning: true,
            enable_transfer_learning: false,
            enable_adversarial_training: false,
            enable_ensemble_methods: true,
            enable_online_learning: true,
            learning_rate: 0.001,
            batch_size: 64,
            memory_size: 10000,
            exploration_rate: 0.1,
            transfer_alpha: 0.5,
            ensemble_size: 5,
        }
    }
}
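
// Example (sketch, hypothetical usage): override selected fields while keeping
// the remaining defaults via struct update syntax.
//
// let config = AdvancedMLMitigationConfig {
//     learning_rate: 0.01,
//     ensemble_size: 3,
//     ..Default::default()
// };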

/// Fully connected feed-forward network used to predict mitigation corrections.
#[derive(Debug, Clone)]
pub struct DeepMitigationNetwork {
    pub layers: Vec<usize>,
    pub weights: Vec<Array2<f64>>,
    pub biases: Vec<Array1<f64>>,
    pub activation: ActivationFunction,
    pub loss_history: Vec<f64>,
}

/// Activation functions supported by the deep mitigation network.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationFunction {
    ReLU,
    Sigmoid,
    Tanh,
    Swish,
    GELU,
}

/// Tabular Q-learning agent that selects a mitigation strategy per circuit.
#[derive(Debug, Clone)]
pub struct QLearningMitigationAgent {
    pub q_table: HashMap<String, HashMap<MitigationAction, f64>>,
    pub learning_rate: f64,
    pub discount_factor: f64,
    pub exploration_rate: f64,
    pub experience_buffer: VecDeque<Experience>,
    pub stats: RLTrainingStats,
}

/// Mitigation strategies the agent can choose between.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MitigationAction {
    ZeroNoiseExtrapolation,
    VirtualDistillation,
    SymmetryVerification,
    PauliTwirling,
    RandomizedCompiling,
    ClusterExpansion,
    MachineLearningPrediction,
    EnsembleMitigation,
}

/// Single state/action/reward transition stored in the experience buffer.
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Array1<f64>,
    pub action: MitigationAction,
    pub reward: f64,
    pub next_state: Array1<f64>,
    pub done: bool,
}

/// Aggregate statistics collected while training the RL agent.
#[derive(Debug, Clone, Default)]
pub struct RLTrainingStats {
    pub episodes: usize,
    pub avg_reward: f64,
    pub success_rate: f64,
    pub exploration_decay: f64,
    pub loss_convergence: Vec<f64>,
}

/// Transfer learning model that adapts mitigation from a source device to a target device.
#[derive(Debug, Clone)]
pub struct TransferLearningModel {
    pub source_device: DeviceCharacteristics,
    pub target_device: DeviceCharacteristics,
    pub feature_extractor: DeepMitigationNetwork,
    pub device_heads: HashMap<String, DeepMitigationNetwork>,
    pub transfer_alpha: f64,
    pub adaptation_stats: TransferStats,
}

/// Noise and connectivity characteristics of a quantum device.
#[derive(Debug, Clone)]
pub struct DeviceCharacteristics {
    pub device_id: String,
    pub gate_errors: HashMap<String, f64>,
    pub coherence_times: HashMap<String, f64>,
    pub connectivity: Array2<bool>,
    pub noise_correlations: Array2<f64>,
}

/// Statistics describing how well transfer learning adapted to the target device.
#[derive(Debug, Clone, Default)]
pub struct TransferStats {
    pub adaptation_loss: f64,
    pub source_performance: f64,
    pub target_performance: f64,
    pub transfer_efficiency: f64,
}

/// Ensemble of mitigation models combined according to a configurable strategy.
pub struct EnsembleMitigation {
    pub models: Vec<Box<dyn MitigationModel>>,
    pub weights: Array1<f64>,
    pub combination_strategy: EnsembleStrategy,
    pub performance_history: Vec<f64>,
}

/// Strategies for combining the predictions of ensemble members.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleStrategy {
    WeightedAverage,
    MajorityVoting,
    Stacking,
    DynamicSelection,
    BayesianAveraging,
}

/// Interface implemented by every model that can participate in an ensemble.
pub trait MitigationModel: Send + Sync {
    fn mitigate(&self, measurements: &Array1<f64>, circuit: &InterfaceCircuit) -> Result<f64>;

    fn update(&mut self, training_data: &[(Array1<f64>, f64)]) -> Result<()>;

    fn confidence(&self) -> f64;

    fn name(&self) -> String;
}
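
// A minimal sketch of a `MitigationModel` implementor (hypothetical, for
// illustration only; not part of this module). It simply returns the raw mean
// as the "mitigated" value:
//
// struct MeanBaseline;
//
// impl MitigationModel for MeanBaseline {
//     fn mitigate(&self, measurements: &Array1<f64>, _circuit: &InterfaceCircuit) -> Result<f64> {
//         Ok(measurements.mean().unwrap_or(0.0))
//     }
//     fn update(&mut self, _training_data: &[(Array1<f64>, f64)]) -> Result<()> {
//         Ok(())
//     }
//     fn confidence(&self) -> f64 {
//         0.5
//     }
//     fn name(&self) -> String {
//         "MeanBaseline".to_string()
//     }
// }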

/// Result of a single mitigation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLMitigationResult {
    pub mitigated_value: f64,
    pub confidence: f64,
    pub model_used: String,
    pub raw_measurements: Vec<f64>,
    pub overhead: f64,
    pub error_reduction: f64,
    pub performance_metrics: PerformanceMetrics,
}

/// Regression-style quality metrics for a mitigation model.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    pub mae: f64,
    pub rmse: f64,
    pub r_squared: f64,
    pub bias: f64,
    pub variance: f64,
    pub computation_time_ms: f64,
}

/// Graph neural network operating on the circuit's connectivity structure.
#[derive(Debug, Clone)]
pub struct GraphMitigationNetwork {
    pub node_features: Array2<f64>,
    pub edge_features: Array3<f64>,
    pub attention_weights: Array2<f64>,
    pub conv_layers: Vec<GraphConvLayer>,
    pub pooling: GraphPooling,
}

/// Single graph convolution layer.
#[derive(Debug, Clone)]
pub struct GraphConvLayer {
    pub weights: Array2<f64>,
    pub bias: Array1<f64>,
    pub activation: ActivationFunction,
}

/// Pooling strategies for aggregating node embeddings into a graph embedding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphPooling {
    Mean,
    Max,
    Sum,
    Attention,
    Set2Set,
}

/// Main entry point coordinating the individual ML mitigation strategies.
pub struct AdvancedMLErrorMitigator {
    config: AdvancedMLMitigationConfig,
    deep_model: Option<DeepMitigationNetwork>,
    rl_agent: Option<QLearningMitigationAgent>,
    transfer_model: Option<TransferLearningModel>,
    ensemble: Option<EnsembleMitigation>,
    graph_model: Option<GraphMitigationNetwork>,
    training_history: VecDeque<(Array1<f64>, f64)>,
    performance_tracker: PerformanceTracker,
}

/// Tracks per-model accuracy, cost, and error-reduction history.
#[derive(Debug, Clone, Default)]
pub struct PerformanceTracker {
    pub accuracy_history: HashMap<String, Vec<f64>>,
    pub cost_history: HashMap<String, Vec<f64>>,
    pub error_reduction_history: Vec<f64>,
    pub best_models: HashMap<String, String>,
}

impl AdvancedMLErrorMitigator {
    /// Creates a mitigator and initializes the sub-models enabled in `config`.
    pub fn new(config: AdvancedMLMitigationConfig) -> Result<Self> {
        let mut mitigator = Self {
            config: config.clone(),
            deep_model: None,
            rl_agent: None,
            transfer_model: None,
            ensemble: None,
            graph_model: None,
            training_history: VecDeque::with_capacity(config.memory_size),
            performance_tracker: PerformanceTracker::default(),
        };

        if config.enable_deep_learning {
            mitigator.deep_model = Some(mitigator.create_deep_model()?);
        }

        if config.enable_reinforcement_learning {
            mitigator.rl_agent = Some(mitigator.create_rl_agent()?);
        }

        if config.enable_ensemble_methods {
            mitigator.ensemble = Some(mitigator.create_ensemble()?);
        }

        Ok(mitigator)
    }

    /// Mitigates errors for a set of noisy measurements of `circuit`.
    ///
    /// Pipeline: extract circuit/measurement features, select a mitigation
    /// strategy (via the RL agent when enabled), apply the chosen technique,
    /// score confidence and estimated error reduction, and feed the result
    /// back into the learned models.
    pub fn mitigate_errors(
        &mut self,
        measurements: &Array1<f64>,
        circuit: &InterfaceCircuit,
    ) -> Result<AdvancedMLMitigationResult> {
        let start_time = std::time::Instant::now();

        let features = self.extract_features(circuit, measurements)?;

        let strategy = self.select_mitigation_strategy(&features)?;

        let mitigated_value = match strategy {
            MitigationAction::MachineLearningPrediction => {
                self.apply_deep_learning_mitigation(&features, measurements)?
            }
            MitigationAction::EnsembleMitigation => {
                self.apply_ensemble_mitigation(&features, measurements, circuit)?
            }
            _ => {
                self.apply_traditional_mitigation(strategy, measurements, circuit)?
            }
        };

        let confidence = self.calculate_confidence(&features, mitigated_value)?;
        let error_reduction = self.estimate_error_reduction(measurements, mitigated_value)?;

        let computation_time = start_time.elapsed().as_millis() as f64;

        self.update_models(&features, mitigated_value)?;

        Ok(AdvancedMLMitigationResult {
            mitigated_value,
            confidence,
            model_used: format!("{:?}", strategy),
            raw_measurements: measurements.to_vec(),
            overhead: computation_time / 1000.0,
            error_reduction,
            performance_metrics: PerformanceMetrics {
                computation_time_ms: computation_time,
                ..Default::default()
            },
        })
    }

    /// Builds the feed-forward network with Xavier/Glorot-style uniform initialization.
    pub fn create_deep_model(&self) -> Result<DeepMitigationNetwork> {
        // 18 input features (see `extract_features`), three hidden layers, scalar output.
        let layers = vec![18, 128, 64, 32, 1];
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for i in 0..layers.len() - 1 {
            let fan_in = layers[i];
            let fan_out = layers[i + 1];
            let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();

            let w =
                Array2::from_shape_fn((fan_out, fan_in), |_| thread_rng().gen_range(-limit..limit));
            let b = Array1::zeros(fan_out);

            weights.push(w);
            biases.push(b);
        }

        Ok(DeepMitigationNetwork {
            layers,
            weights,
            biases,
            activation: ActivationFunction::ReLU,
            loss_history: Vec::new(),
        })
    }
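
    // Note: the weight limit above follows the Xavier/Glorot uniform scheme,
    //     limit = sqrt(6 / (fan_in + fan_out)),
    // so each weight is drawn uniformly from [-limit, limit].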

    /// Creates the tabular Q-learning agent used for strategy selection.
    pub fn create_rl_agent(&self) -> Result<QLearningMitigationAgent> {
        Ok(QLearningMitigationAgent {
            q_table: HashMap::new(),
            learning_rate: self.config.learning_rate,
            discount_factor: 0.95,
            exploration_rate: self.config.exploration_rate,
            experience_buffer: VecDeque::with_capacity(self.config.memory_size),
            stats: RLTrainingStats::default(),
        })
    }
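
    // Sketch of the standard tabular Q-learning update these fields are set up
    // for (the current `update_rl_agent` only tracks aggregate statistics):
    //
    //     Q(s, a) <- Q(s, a)
    //         + learning_rate * (reward + discount_factor * max_a' Q(s', a') - Q(s, a))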

    /// Creates an (initially empty) ensemble with uniform member weights.
    fn create_ensemble(&self) -> Result<EnsembleMitigation> {
        // No concrete models are registered here; weights are initialized uniformly.
        let models: Vec<Box<dyn MitigationModel>> = Vec::new();
        let weights = Array1::ones(self.config.ensemble_size) / self.config.ensemble_size as f64;

        Ok(EnsembleMitigation {
            models,
            weights,
            combination_strategy: EnsembleStrategy::WeightedAverage,
            performance_history: Vec::new(),
        })
    }

    /// Builds the feature vector fed to the ML models.
    ///
    /// Layout (18 entries, matching the deep model's input layer): gate count,
    /// qubit count, ten per-gate-type fractions, measurement
    /// mean/std/variance/count, circuit connectivity, and an entanglement estimate.
    pub fn extract_features(
        &self,
        circuit: &InterfaceCircuit,
        measurements: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        let mut features = Vec::new();

        // Basic circuit statistics.
        features.push(circuit.gates.len() as f64);
        features.push(circuit.num_qubits as f64);

        // Gate type distribution.
        let mut gate_counts = HashMap::new();
        for gate in &circuit.gates {
            *gate_counts
                .entry(format!("{:?}", gate.gate_type))
                .or_insert(0) += 1;
        }

        // max(1.0) avoids division by zero for empty circuits.
        let total_gates = (circuit.gates.len() as f64).max(1.0);
        for gate_type in [
            "PauliX", "PauliY", "PauliZ", "Hadamard", "CNOT", "CZ", "RX", "RY", "RZ", "Phase",
        ] {
            let count = gate_counts.get(gate_type).unwrap_or(&0);
            features.push(*count as f64 / total_gates);
        }

        // Measurement statistics.
        features.push(measurements.mean().unwrap_or(0.0));
        features.push(measurements.std(0.0));
        features.push(measurements.var(0.0));
        features.push(measurements.len() as f64);

        // Structural features.
        features.push(self.calculate_circuit_connectivity(circuit)?);
        features.push(self.calculate_entanglement_estimate(circuit)?);

        Ok(Array1::from_vec(features))
    }

    /// Chooses a mitigation strategy with an epsilon-greedy policy: with
    /// probability `exploration_rate` a random action is explored, otherwise
    /// the action with the highest Q-value for the discretized state is used.
    /// Falls back to ML prediction when the RL agent is disabled.
    pub fn select_mitigation_strategy(
        &mut self,
        features: &Array1<f64>,
    ) -> Result<MitigationAction> {
        if let Some(ref mut agent) = self.rl_agent {
            let state_key = Self::features_to_state_key(features);

            if rand::random::<f64>() < agent.exploration_rate {
                // Exploration: sample a random action.
                let actions = [
                    MitigationAction::ZeroNoiseExtrapolation,
                    MitigationAction::VirtualDistillation,
                    MitigationAction::MachineLearningPrediction,
                    MitigationAction::EnsembleMitigation,
                ];
                Ok(actions[thread_rng().gen_range(0..actions.len())])
            } else {
                // Exploitation: pick the best known action for this state.
                let q_values = agent.q_table.get(&state_key).cloned().unwrap_or_default();

                let best_action = q_values
                    .iter()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(action, _)| *action)
                    .unwrap_or(MitigationAction::MachineLearningPrediction);

                Ok(best_action)
            }
        } else {
            Ok(MitigationAction::MachineLearningPrediction)
        }
    }

    /// Uses the deep model's predicted correction factor to rescale the raw mean.
    fn apply_deep_learning_mitigation(
        &self,
        features: &Array1<f64>,
        measurements: &Array1<f64>,
    ) -> Result<f64> {
        if let Some(ref model) = self.deep_model {
            let prediction = Self::forward_pass_static(model, features)?;

            let correction_factor = prediction[0];
            let mitigated_value = measurements.mean().unwrap_or(0.0) * (1.0 + correction_factor);

            Ok(mitigated_value)
        } else {
            Err(SimulatorError::InvalidConfiguration(
                "Deep learning model not initialized".to_string(),
            ))
        }
    }

    /// Combines the predictions of all ensemble members according to the
    /// configured combination strategy.
    fn apply_ensemble_mitigation(
        &self,
        features: &Array1<f64>,
        measurements: &Array1<f64>,
        circuit: &InterfaceCircuit,
    ) -> Result<f64> {
        if let Some(ref ensemble) = self.ensemble {
            let mut predictions = Vec::new();

            for model in &ensemble.models {
                let prediction = model.mitigate(measurements, circuit)?;
                predictions.push(prediction);
            }

            // Fall back to the raw mean if no models are registered yet.
            if predictions.is_empty() {
                return Ok(measurements.mean().unwrap_or(0.0));
            }

            let mitigated_value = match ensemble.combination_strategy {
                EnsembleStrategy::WeightedAverage => {
                    let weighted_sum: f64 = predictions
                        .iter()
                        .zip(ensemble.weights.iter())
                        .map(|(pred, weight)| pred * weight)
                        .sum();
                    weighted_sum
                }
                EnsembleStrategy::MajorityVoting => {
                    // Use the median prediction as a robust "vote".
                    let mut sorted_predictions = predictions.clone();
                    sorted_predictions
                        .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    sorted_predictions[sorted_predictions.len() / 2]
                }
                _ => predictions.iter().sum::<f64>() / predictions.len() as f64,
            };

            Ok(mitigated_value)
        } else {
            Ok(measurements.mean().unwrap_or(0.0))
        }
    }

    /// Applies one of the simpler, non-learned mitigation techniques.
    pub fn apply_traditional_mitigation(
        &self,
        strategy: MitigationAction,
        measurements: &Array1<f64>,
        _circuit: &InterfaceCircuit,
    ) -> Result<f64> {
        match strategy {
            MitigationAction::ZeroNoiseExtrapolation => {
                // Treat successive measurements as if taken at scaled noise levels.
                let noise_factors = [1.0, 1.5, 2.0];
                let values: Vec<f64> = noise_factors
                    .iter()
                    .zip(measurements.iter())
                    .map(|(factor, &val)| val / factor)
                    .collect();

                // Need at least two points to extrapolate; otherwise fall back to the mean.
                if values.len() < 2 {
                    return Ok(measurements.mean().unwrap_or(0.0));
                }

                let extrapolated = 2.0 * values[0] - values[1];
                Ok(extrapolated)
            }
            MitigationAction::VirtualDistillation => {
                let mean_val = measurements.mean().unwrap_or(0.0);
                let variance = measurements.var(0.0);
                // Simple variance-based correction.
                let corrected = mean_val + variance * 0.1;
                Ok(corrected)
            }
            _ => Ok(measurements.mean().unwrap_or(0.0)),
        }
    }
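
    // For reference: two-point linear zero-noise extrapolation with noise
    // scaling factors lambda1 < lambda2 estimates the zero-noise value as
    //     E(0) ~= E(lambda1) - lambda1 * (E(lambda2) - E(lambda1)) / (lambda2 - lambda1).
    // The arm above instead rescales each measurement by its factor and
    // combines the first two as `2.0 * values[0] - values[1]`, a lightweight
    // heuristic rather than the full extrapolation.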

    /// Runs a forward pass through the network (associated-function form).
    fn forward_pass_static(
        model: &DeepMitigationNetwork,
        input: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        let mut current = input.clone();

        for (weights, bias) in model.weights.iter().zip(model.biases.iter()) {
            current = weights.dot(&current) + bias;

            current.mapv_inplace(|x| Self::apply_activation_static(x, model.activation));
        }

        Ok(current)
    }

    /// Evaluates a single activation function at `x`.
    fn apply_activation_static(x: f64, activation: ActivationFunction) -> f64 {
        match activation {
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
            ActivationFunction::GELU => {
                0.5 * x
                    * (1.0
                        + ((2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
            }
        }
    }
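
    // The GELU branch above implements the common tanh approximation:
    //     GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).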

    /// Convenience wrapper around `apply_activation_static`.
    pub fn apply_activation(&self, x: f64, activation: ActivationFunction) -> f64 {
        Self::apply_activation_static(x, activation)
    }

    /// Convenience wrapper around `forward_pass_static`.
    pub fn forward_pass(
        &self,
        model: &DeepMitigationNetwork,
        input: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Self::forward_pass_static(model, input)
    }

    /// Fraction of multi-qubit gates relative to the number of possible qubit pairs.
    fn calculate_circuit_connectivity(&self, circuit: &InterfaceCircuit) -> Result<f64> {
        // Fewer than two qubits means no possible connections.
        if circuit.num_qubits < 2 {
            return Ok(0.0);
        }

        let mut connectivity_sum = 0.0;
        let total_possible_connections = (circuit.num_qubits * (circuit.num_qubits - 1)) / 2;

        for gate in &circuit.gates {
            if gate.qubits.len() > 1 {
                connectivity_sum += 1.0;
            }
        }

        Ok(connectivity_sum / total_possible_connections as f64)
    }

    /// Fraction of gates that are entangling (two- or three-qubit gates).
    fn calculate_entanglement_estimate(&self, circuit: &InterfaceCircuit) -> Result<f64> {
        // Avoid dividing by zero for empty circuits.
        if circuit.gates.is_empty() {
            return Ok(0.0);
        }

        let mut entangling_gates = 0;

        for gate in &circuit.gates {
            match gate.gate_type {
                InterfaceGateType::CNOT
                | InterfaceGateType::CZ
                | InterfaceGateType::CY
                | InterfaceGateType::SWAP
                | InterfaceGateType::ISwap
                | InterfaceGateType::Toffoli => {
                    entangling_gates += 1;
                }
                _ => {}
            }
        }

        Ok(entangling_gates as f64 / circuit.gates.len() as f64)
    }

    fn features_to_state_key(features: &Array1<f64>) -> String {
        let discretized: Vec<i32> = features
            .iter()
            .map(|&x| (x * 10.0).round() as i32)
            .collect();
        format!("{:?}", discretized)
    }

    fn calculate_confidence(&self, features: &Array1<f64>, _mitigated_value: f64) -> Result<f64> {
        let feature_variance = features.var(0.0);
        let confidence = 1.0 / (1.0 + feature_variance);
        Ok(confidence.min(1.0).max(0.0))
    }

    fn estimate_error_reduction(&self, original: &Array1<f64>, mitigated: f64) -> Result<f64> {
        let original_mean = original.mean().unwrap_or(0.0);
        let original_variance = original.var(0.0);

        let estimated_improvement = (original_variance.sqrt() - (mitigated - original_mean).abs())
            / original_variance.sqrt();
        Ok(estimated_improvement.max(0.0).min(1.0))
    }
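
    // The heuristic above scores error reduction as
    //     reduction ~= (sigma - |mitigated - mean|) / sigma,
    // clipped to [0, 1], where sigma is the standard deviation of the raw
    // measurements: mitigated values landing within one standard deviation of
    // the raw mean count as an improvement.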

    fn update_models(&mut self, features: &Array1<f64>, target: f64) -> Result<()> {
        if self.training_history.len() >= self.config.memory_size {
            self.training_history.pop_front();
        }
        self.training_history.push_back((features.clone(), target));

        if self.training_history.len() >= self.config.batch_size {
            self.update_deep_model()?;
        }

        self.update_rl_agent(features, target)?;

        Ok(())
    }

    fn update_deep_model(&mut self) -> Result<()> {
        if let Some(ref mut model) = self.deep_model {
            // Evaluate the most recent batch. Note: this currently only records the
            // batch loss; weight updates (backpropagation) are not performed here.
            let batch_size = self.config.batch_size.min(self.training_history.len());
            let batch: Vec<_> = self
                .training_history
                .iter()
                .rev()
                .take(batch_size)
                .collect();

            let mut total_loss = 0.0;

            for (features, target) in batch {
                let prediction = Self::forward_pass_static(model, features)?;
                let loss = (prediction[0] - target).powi(2);
                total_loss += loss;
            }

            let avg_loss = total_loss / batch_size as f64;
            model.loss_history.push(avg_loss);
        }

        Ok(())
    }

    fn update_rl_agent(&mut self, features: &Array1<f64>, reward: f64) -> Result<()> {
        if let Some(ref mut agent) = self.rl_agent {
            let state_key = Self::features_to_state_key(features);

            // Update running statistics.
            agent.stats.episodes += 1;
            agent.stats.avg_reward = (agent.stats.avg_reward * (agent.stats.episodes - 1) as f64
                + reward)
                / agent.stats.episodes as f64;

            // Decay exploration, keeping a small floor so the agent never stops exploring.
            agent.exploration_rate *= 0.995;
            agent.exploration_rate = agent.exploration_rate.max(0.01);
        }

        Ok(())
    }
}

/// Runs a small end-to-end demonstration of the advanced ML error mitigator on a
/// 4-qubit circuit with synthetic noisy measurements and prints the results.
pub fn benchmark_advanced_ml_error_mitigation() -> Result<()> {
    println!("Benchmarking Advanced ML Error Mitigation...");

    let config = AdvancedMLMitigationConfig::default();
    let mut mitigator = AdvancedMLErrorMitigator::new(config)?;

    let mut circuit = InterfaceCircuit::new(4, 0);
    circuit.add_gate(InterfaceGate::new(InterfaceGateType::Hadamard, vec![0]));
    circuit.add_gate(InterfaceGate::new(InterfaceGateType::CNOT, vec![0, 1]));
    circuit.add_gate(InterfaceGate::new(InterfaceGateType::RZ(0.5), vec![2]));

    let noisy_measurements = Array1::from_vec(vec![0.48, 0.52, 0.47, 0.53, 0.49]);

    let start_time = std::time::Instant::now();

    let result = mitigator.mitigate_errors(&noisy_measurements, &circuit)?;

    let duration = start_time.elapsed();

    println!("✅ Advanced ML Error Mitigation Results:");
    println!("   Mitigated Value: {:.6}", result.mitigated_value);
    println!("   Confidence: {:.4}", result.confidence);
    println!("   Model Used: {}", result.model_used);
    println!("   Error Reduction: {:.4}", result.error_reduction);
    println!("   Computation Time: {:.2}ms", duration.as_secs_f64() * 1000.0);

    Ok(())
}
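
// A typical way to invoke the benchmark from an example or binary target
// (hypothetical driver, shown for illustration only):
//
// fn main() -> Result<()> {
//     benchmark_advanced_ml_error_mitigation()
// }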

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_ml_mitigator_creation() {
        let config = AdvancedMLMitigationConfig::default();
        let mitigator = AdvancedMLErrorMitigator::new(config);
        assert!(mitigator.is_ok());
    }

    #[test]
    fn test_feature_extraction() {
        let config = AdvancedMLMitigationConfig::default();
        let mitigator = AdvancedMLErrorMitigator::new(config).unwrap();

        let mut circuit = InterfaceCircuit::new(2, 0);
        circuit.add_gate(InterfaceGate::new(InterfaceGateType::Hadamard, vec![0]));
        circuit.add_gate(InterfaceGate::new(InterfaceGateType::CNOT, vec![0, 1]));

        let measurements = Array1::from_vec(vec![0.5, 0.5, 0.5]);
        let features = mitigator.extract_features(&circuit, &measurements);

        assert!(features.is_ok());
        let features = features.unwrap();
        assert!(features.len() > 0);
    }

    #[test]
    fn test_activation_functions() {
        let config = AdvancedMLMitigationConfig::default();
        let mitigator = AdvancedMLErrorMitigator::new(config).unwrap();

        assert_eq!(
            mitigator.apply_activation(-1.0, ActivationFunction::ReLU),
            0.0
        );
        assert_eq!(
            mitigator.apply_activation(1.0, ActivationFunction::ReLU),
            1.0
        );

        let sigmoid_result = mitigator.apply_activation(0.0, ActivationFunction::Sigmoid);
        assert!((sigmoid_result - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_mitigation_strategy_selection() {
        let config = AdvancedMLMitigationConfig::default();
        let mut mitigator = AdvancedMLErrorMitigator::new(config).unwrap();

        let features = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let strategy = mitigator.select_mitigation_strategy(&features);

        assert!(strategy.is_ok());
    }

    #[test]
    fn test_traditional_mitigation() {
        let config = AdvancedMLMitigationConfig::default();
        let mitigator = AdvancedMLErrorMitigator::new(config).unwrap();

        let measurements = Array1::from_vec(vec![0.48, 0.52, 0.49]);
        let circuit = InterfaceCircuit::new(2, 0);

        let result = mitigator.apply_traditional_mitigation(
            MitigationAction::ZeroNoiseExtrapolation,
            &measurements,
            &circuit,
        );

        assert!(result.is_ok());
    }
}