1use scirs2_core::ndarray::{Array1, Array2, Array3};
19use scirs2_core::random::{thread_rng, Rng};
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22
23use crate::circuit_interfaces::{InterfaceCircuit, InterfaceGate, InterfaceGateType};
24use crate::error::{Result, SimulatorError};
25use scirs2_core::random::prelude::*;
26
/// Configuration flags and hyperparameters for the advanced ML-based
/// error-mitigation pipeline.
#[derive(Debug, Clone)]
pub struct AdvancedMLMitigationConfig {
    /// Enable the deep feed-forward mitigation network.
    pub enable_deep_learning: bool,
    /// Enable the Q-learning agent that selects mitigation strategies.
    pub enable_reinforcement_learning: bool,
    /// Enable transfer learning across device characteristics.
    pub enable_transfer_learning: bool,
    /// Enable adversarial training of the models.
    pub enable_adversarial_training: bool,
    /// Enable combining several mitigation models into an ensemble.
    pub enable_ensemble_methods: bool,
    /// Enable online updates of the models after each mitigation call.
    pub enable_online_learning: bool,
    /// Learning rate shared by the trainable components.
    pub learning_rate: f64,
    /// Mini-batch size used when updating the deep model.
    pub batch_size: usize,
    /// Capacity of the training-history / experience buffers.
    pub memory_size: usize,
    /// Initial epsilon for epsilon-greedy strategy exploration.
    pub exploration_rate: f64,
    /// Blending factor for transfer learning (source vs. target).
    pub transfer_alpha: f64,
    /// Number of models kept in the ensemble.
    pub ensemble_size: usize,
}
55
impl Default for AdvancedMLMitigationConfig {
    /// Sensible defaults: deep learning, RL, ensembles, and online learning
    /// on; transfer learning and adversarial training off.
    fn default() -> Self {
        Self {
            enable_deep_learning: true,
            enable_reinforcement_learning: true,
            enable_transfer_learning: false,
            enable_adversarial_training: false,
            enable_ensemble_methods: true,
            enable_online_learning: true,
            learning_rate: 0.001,
            batch_size: 64,
            memory_size: 10_000,
            exploration_rate: 0.1,
            transfer_alpha: 0.5,
            ensemble_size: 5,
        }
    }
}
74
/// Simple fully-connected feed-forward network used for mitigation
/// prediction.
#[derive(Debug, Clone)]
pub struct DeepMitigationNetwork {
    /// Layer widths, input first and output last.
    pub layers: Vec<usize>,
    /// Weight matrix per layer transition, shaped (fan_out, fan_in).
    pub weights: Vec<Array2<f64>>,
    /// Bias vector per layer transition.
    pub biases: Vec<Array1<f64>>,
    /// Activation applied after every layer.
    pub activation: ActivationFunction,
    /// Average training loss recorded per update batch.
    pub loss_history: Vec<f64>,
}
89
/// Supported neural-network activation functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationFunction {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Swish: `x * sigmoid(x)`.
    Swish,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
}
99
/// Tabular Q-learning agent that selects which mitigation strategy to apply.
#[derive(Debug, Clone)]
pub struct QLearningMitigationAgent {
    /// Q-values keyed by discretized state string, then by action.
    pub q_table: HashMap<String, HashMap<MitigationAction, f64>>,
    /// Step size for Q-value updates.
    pub learning_rate: f64,
    /// Discount factor (gamma) for future rewards.
    pub discount_factor: f64,
    /// Current epsilon for epsilon-greedy exploration; decays over time.
    pub exploration_rate: f64,
    /// Replay buffer of past transitions.
    pub experience_buffer: VecDeque<Experience>,
    /// Aggregate training statistics.
    pub stats: RLTrainingStats,
}
116
/// Mitigation strategies the agent can choose between.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MitigationAction {
    /// Extrapolate measurements to the zero-noise limit.
    ZeroNoiseExtrapolation,
    /// Virtual distillation / purification of noisy copies.
    VirtualDistillation,
    /// Discard results that violate known symmetries.
    SymmetryVerification,
    /// Convert coherent errors to stochastic ones via Pauli twirling.
    PauliTwirling,
    /// Randomized compiling of the circuit.
    RandomizedCompiling,
    /// Cluster-expansion based correction.
    ClusterExpansion,
    /// Prediction by the deep mitigation network.
    MachineLearningPrediction,
    /// Combination of several models via the ensemble.
    EnsembleMitigation,
}
129
/// One reinforcement-learning transition stored in the replay buffer.
#[derive(Debug, Clone)]
pub struct Experience {
    /// Feature vector describing the state before acting.
    pub state: Array1<f64>,
    /// Action taken in that state.
    pub action: MitigationAction,
    /// Reward observed after acting.
    pub reward: f64,
    /// Feature vector of the resulting state.
    pub next_state: Array1<f64>,
    /// Whether the episode terminated after this transition.
    pub done: bool,
}
144
/// Aggregate statistics collected while training the RL agent.
#[derive(Debug, Clone, Default)]
pub struct RLTrainingStats {
    /// Number of completed episodes (mitigation calls).
    pub episodes: usize,
    /// Running average reward over all episodes.
    pub avg_reward: f64,
    /// Fraction of episodes considered successful.
    pub success_rate: f64,
    /// Current multiplicative decay applied to exploration.
    pub exploration_decay: f64,
    /// Loss values tracked for convergence monitoring.
    pub loss_convergence: Vec<f64>,
}
159
/// Transfer-learning setup: a shared feature extractor plus per-device heads.
#[derive(Debug, Clone)]
pub struct TransferLearningModel {
    /// Characteristics of the device the model was trained on.
    pub source_device: DeviceCharacteristics,
    /// Characteristics of the device being adapted to.
    pub target_device: DeviceCharacteristics,
    /// Shared backbone network producing device-agnostic features.
    pub feature_extractor: DeepMitigationNetwork,
    /// Device-specific output heads keyed by device id.
    pub device_heads: HashMap<String, DeepMitigationNetwork>,
    /// Blending factor between source and target knowledge.
    pub transfer_alpha: f64,
    /// Statistics from the most recent adaptation run.
    pub adaptation_stats: TransferStats,
}
176
/// Noise and topology description of a quantum device.
#[derive(Debug, Clone)]
pub struct DeviceCharacteristics {
    /// Unique identifier of the device.
    pub device_id: String,
    /// Error rate per gate name.
    pub gate_errors: HashMap<String, f64>,
    /// Coherence times (e.g. T1/T2) keyed by name.
    pub coherence_times: HashMap<String, f64>,
    /// Qubit adjacency matrix: `true` where two qubits are coupled.
    pub connectivity: Array2<bool>,
    /// Pairwise noise-correlation matrix between qubits.
    pub noise_correlations: Array2<f64>,
}
191
/// Metrics describing how well a transfer-learning adaptation went.
#[derive(Debug, Clone, Default)]
pub struct TransferStats {
    /// Loss on the target device after adaptation.
    pub adaptation_loss: f64,
    /// Model performance on the source device.
    pub source_performance: f64,
    /// Model performance on the target device.
    pub target_performance: f64,
    /// Ratio expressing how much source knowledge transferred.
    pub transfer_efficiency: f64,
}
204
/// Ensemble of mitigation models combined by a configurable strategy.
///
/// NOTE: no `#[derive(Debug)]` because `dyn MitigationModel` is not `Debug`.
pub struct EnsembleMitigation {
    /// The member models; may be empty until models are registered.
    pub models: Vec<Box<dyn MitigationModel>>,
    /// Per-model weight used by weighted combination strategies.
    pub weights: Array1<f64>,
    /// How member predictions are combined into one value.
    pub combination_strategy: EnsembleStrategy,
    /// Historical ensemble performance scores.
    pub performance_history: Vec<f64>,
}
216
/// Strategies for combining ensemble member predictions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleStrategy {
    /// Weighted mean of member predictions.
    WeightedAverage,
    /// Median of member predictions (robust "voting" for real values).
    MajorityVoting,
    /// Meta-model stacked on member outputs.
    Stacking,
    /// Pick the best member per input.
    DynamicSelection,
    /// Bayesian model averaging.
    BayesianAveraging,
}
231
/// Interface implemented by every mitigation model usable in the ensemble.
pub trait MitigationModel: Send + Sync {
    /// Produce a mitigated expectation value from raw measurements.
    fn mitigate(&self, measurements: &Array1<f64>, circuit: &InterfaceCircuit) -> Result<f64>;

    /// Update the model from `(features, target)` training pairs.
    fn update(&mut self, training_data: &[(Array1<f64>, f64)]) -> Result<()>;

    /// Self-reported confidence in the model's predictions, in [0, 1].
    fn confidence(&self) -> f64;

    /// Human-readable model name (used for reporting).
    fn name(&self) -> String;
}
246
/// Outcome of one mitigation run, including diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLMitigationResult {
    /// The error-mitigated expectation value.
    pub mitigated_value: f64,
    /// Confidence in the mitigated value, in [0, 1].
    pub confidence: f64,
    /// Debug-formatted name of the strategy/model that was applied.
    pub model_used: String,
    /// The unmitigated input measurements.
    pub raw_measurements: Vec<f64>,
    /// Wall-clock overhead of mitigation, in seconds.
    pub overhead: f64,
    /// Estimated fractional error reduction, in [0, 1].
    pub error_reduction: f64,
    /// Detailed performance metrics for this run.
    pub performance_metrics: PerformanceMetrics,
}
265
/// Standard regression-quality metrics for a mitigation model.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Mean absolute error.
    pub mae: f64,
    /// Root mean squared error.
    pub rmse: f64,
    /// Coefficient of determination.
    pub r_squared: f64,
    /// Systematic bias of the predictions.
    pub bias: f64,
    /// Variance of the predictions.
    pub variance: f64,
    /// Wall-clock computation time in milliseconds.
    pub computation_time_ms: f64,
}
282
/// Graph neural network operating on the circuit's qubit-interaction graph.
#[derive(Debug, Clone)]
pub struct GraphMitigationNetwork {
    /// Feature matrix, one row per graph node (qubit).
    pub node_features: Array2<f64>,
    /// Feature tensor for edges between nodes.
    pub edge_features: Array3<f64>,
    /// Attention weights used during message passing.
    pub attention_weights: Array2<f64>,
    /// Stacked graph-convolution layers.
    pub conv_layers: Vec<GraphConvLayer>,
    /// Readout/pooling operation producing a graph-level embedding.
    pub pooling: GraphPooling,
}
297
/// One graph-convolution layer: linear transform plus activation.
#[derive(Debug, Clone)]
pub struct GraphConvLayer {
    /// Linear transform applied to aggregated neighbor features.
    pub weights: Array2<f64>,
    /// Bias added after the linear transform.
    pub bias: Array1<f64>,
    /// Nonlinearity applied to the layer output.
    pub activation: ActivationFunction,
}
308
/// Graph-level readout (pooling) operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphPooling {
    /// Mean over node embeddings.
    Mean,
    /// Element-wise maximum over node embeddings.
    Max,
    /// Sum over node embeddings.
    Sum,
    /// Attention-weighted readout.
    Attention,
    /// Set2Set readout.
    Set2Set,
}
318
/// Top-level coordinator that owns all ML mitigation sub-models and applies
/// them according to the configured strategy selection.
pub struct AdvancedMLErrorMitigator {
    /// Pipeline configuration.
    config: AdvancedMLMitigationConfig,
    /// Deep network, present when `enable_deep_learning` is set.
    deep_model: Option<DeepMitigationNetwork>,
    /// Q-learning strategy selector, present when RL is enabled.
    rl_agent: Option<QLearningMitigationAgent>,
    /// Transfer-learning model (currently never constructed here).
    transfer_model: Option<TransferLearningModel>,
    /// Model ensemble, present when ensemble methods are enabled.
    ensemble: Option<EnsembleMitigation>,
    /// Graph network (currently never constructed here).
    graph_model: Option<GraphMitigationNetwork>,
    /// Bounded FIFO of `(features, target)` pairs for online learning.
    training_history: VecDeque<(Array1<f64>, f64)>,
    /// Per-model bookkeeping of accuracy/cost history.
    performance_tracker: PerformanceTracker,
}
338
/// Bookkeeping of model performance over time.
#[derive(Debug, Clone, Default)]
pub struct PerformanceTracker {
    /// Accuracy history per model name.
    pub accuracy_history: HashMap<String, Vec<f64>>,
    /// Cost/overhead history per model name.
    pub cost_history: HashMap<String, Vec<f64>>,
    /// Error-reduction estimate per mitigation run.
    pub error_reduction_history: Vec<f64>,
    /// Best-performing model per task/category.
    pub best_models: HashMap<String, String>,
}
351
352impl AdvancedMLErrorMitigator {
353 pub fn new(config: AdvancedMLMitigationConfig) -> Result<Self> {
355 let mut mitigator = Self {
356 config: config.clone(),
357 deep_model: None,
358 rl_agent: None,
359 transfer_model: None,
360 ensemble: None,
361 graph_model: None,
362 training_history: VecDeque::with_capacity(config.memory_size),
363 performance_tracker: PerformanceTracker::default(),
364 };
365
366 if config.enable_deep_learning {
368 mitigator.deep_model = Some(mitigator.create_deep_model()?);
369 }
370
371 if config.enable_reinforcement_learning {
372 mitigator.rl_agent = Some(mitigator.create_rl_agent()?);
373 }
374
375 if config.enable_ensemble_methods {
376 mitigator.ensemble = Some(mitigator.create_ensemble()?);
377 }
378
379 Ok(mitigator)
380 }
381
382 pub fn mitigate_errors(
384 &mut self,
385 measurements: &Array1<f64>,
386 circuit: &InterfaceCircuit,
387 ) -> Result<AdvancedMLMitigationResult> {
388 let start_time = std::time::Instant::now();
389
390 let features = self.extract_features(circuit, measurements)?;
392
393 let strategy = self.select_mitigation_strategy(&features)?;
395
396 let mitigated_value = match strategy {
398 MitigationAction::MachineLearningPrediction => {
399 self.apply_deep_learning_mitigation(&features, measurements)?
400 }
401 MitigationAction::EnsembleMitigation => {
402 self.apply_ensemble_mitigation(&features, measurements, circuit)?
403 }
404 _ => {
405 self.apply_traditional_mitigation(strategy, measurements, circuit)?
407 }
408 };
409
410 let confidence = self.calculate_confidence(&features, mitigated_value)?;
412 let error_reduction = self.estimate_error_reduction(measurements, mitigated_value)?;
413
414 let computation_time = start_time.elapsed().as_millis() as f64;
415
416 self.update_models(&features, mitigated_value)?;
418
419 Ok(AdvancedMLMitigationResult {
420 mitigated_value,
421 confidence,
422 model_used: format!("{strategy:?}"),
423 raw_measurements: measurements.to_vec(),
424 overhead: computation_time / 1000.0, error_reduction,
426 performance_metrics: PerformanceMetrics {
427 computation_time_ms: computation_time,
428 ..Default::default()
429 },
430 })
431 }
432
433 pub fn create_deep_model(&self) -> Result<DeepMitigationNetwork> {
435 let layers = vec![18, 128, 64, 32, 1]; let mut weights = Vec::new();
437 let mut biases = Vec::new();
438
439 for i in 0..layers.len() - 1 {
441 let fan_in = layers[i];
442 let fan_out = layers[i + 1];
443 let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
444
445 let w = Array2::from_shape_fn((fan_out, fan_in), |_| {
446 thread_rng().random_range(-limit..limit)
447 });
448 let b = Array1::zeros(fan_out);
449
450 weights.push(w);
451 biases.push(b);
452 }
453
454 Ok(DeepMitigationNetwork {
455 layers,
456 weights,
457 biases,
458 activation: ActivationFunction::ReLU,
459 loss_history: Vec::new(),
460 })
461 }
462
463 pub fn create_rl_agent(&self) -> Result<QLearningMitigationAgent> {
465 Ok(QLearningMitigationAgent {
466 q_table: HashMap::new(),
467 learning_rate: self.config.learning_rate,
468 discount_factor: 0.95,
469 exploration_rate: self.config.exploration_rate,
470 experience_buffer: VecDeque::with_capacity(self.config.memory_size),
471 stats: RLTrainingStats::default(),
472 })
473 }
474
475 fn create_ensemble(&self) -> Result<EnsembleMitigation> {
477 let models: Vec<Box<dyn MitigationModel>> = Vec::new();
478 let weights = Array1::ones(self.config.ensemble_size) / self.config.ensemble_size as f64;
479
480 Ok(EnsembleMitigation {
481 models,
482 weights,
483 combination_strategy: EnsembleStrategy::WeightedAverage,
484 performance_history: Vec::new(),
485 })
486 }
487
488 pub fn extract_features(
490 &self,
491 circuit: &InterfaceCircuit,
492 measurements: &Array1<f64>,
493 ) -> Result<Array1<f64>> {
494 let mut features = Vec::new();
495
496 features.push(circuit.gates.len() as f64); features.push(circuit.num_qubits as f64); let mut gate_counts = HashMap::new();
502 for gate in &circuit.gates {
503 *gate_counts
504 .entry(format!("{:?}", gate.gate_type))
505 .or_insert(0) += 1;
506 }
507
508 let total_gates = circuit.gates.len() as f64;
510 for gate_type in [
511 "PauliX", "PauliY", "PauliZ", "Hadamard", "CNOT", "CZ", "RX", "RY", "RZ", "Phase",
512 ] {
513 let count = gate_counts.get(gate_type).unwrap_or(&0);
514 features.push(f64::from(*count) / total_gates);
515 }
516
517 features.push(measurements.mean().unwrap_or(0.0));
519 features.push(measurements.std(0.0));
520 features.push(measurements.var(0.0));
521 features.push(measurements.len() as f64);
522
523 features.push(self.calculate_circuit_connectivity(circuit)?);
525 features.push(self.calculate_entanglement_estimate(circuit)?);
526
527 Ok(Array1::from_vec(features))
528 }
529
530 pub fn select_mitigation_strategy(
532 &mut self,
533 features: &Array1<f64>,
534 ) -> Result<MitigationAction> {
535 if let Some(ref mut agent) = self.rl_agent {
536 let state_key = Self::features_to_state_key(features);
537
538 if thread_rng().random::<f64>() < agent.exploration_rate {
540 let actions = [
542 MitigationAction::ZeroNoiseExtrapolation,
543 MitigationAction::VirtualDistillation,
544 MitigationAction::MachineLearningPrediction,
545 MitigationAction::EnsembleMitigation,
546 ];
547 Ok(actions[thread_rng().random_range(0..actions.len())])
548 } else {
549 let q_values = agent.q_table.get(&state_key).cloned().unwrap_or_default();
551
552 let best_action = q_values
553 .iter()
554 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
555 .map_or(
556 MitigationAction::MachineLearningPrediction,
557 |(action, _)| *action,
558 );
559
560 Ok(best_action)
561 }
562 } else {
563 Ok(MitigationAction::MachineLearningPrediction)
565 }
566 }
567
568 fn apply_deep_learning_mitigation(
570 &self,
571 features: &Array1<f64>,
572 measurements: &Array1<f64>,
573 ) -> Result<f64> {
574 if let Some(ref model) = self.deep_model {
575 let prediction = Self::forward_pass_static(model, features)?;
576
577 let correction_factor = prediction[0];
579 let mitigated_value = measurements.mean().unwrap_or(0.0) * (1.0 + correction_factor);
580
581 Ok(mitigated_value)
582 } else {
583 Err(SimulatorError::InvalidConfiguration(
584 "Deep learning model not initialized".to_string(),
585 ))
586 }
587 }
588
589 fn apply_ensemble_mitigation(
591 &self,
592 features: &Array1<f64>,
593 measurements: &Array1<f64>,
594 circuit: &InterfaceCircuit,
595 ) -> Result<f64> {
596 if let Some(ref ensemble) = self.ensemble {
597 let mut predictions = Vec::new();
598
599 for model in &ensemble.models {
601 let prediction = model.mitigate(measurements, circuit)?;
602 predictions.push(prediction);
603 }
604
605 let mitigated_value = match ensemble.combination_strategy {
607 EnsembleStrategy::WeightedAverage => {
608 let weighted_sum: f64 = predictions
609 .iter()
610 .zip(ensemble.weights.iter())
611 .map(|(pred, weight)| pred * weight)
612 .sum();
613 weighted_sum
614 }
615 EnsembleStrategy::MajorityVoting => {
616 let mut sorted_predictions = predictions.clone();
618 sorted_predictions
619 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
620 sorted_predictions[sorted_predictions.len() / 2]
621 }
622 _ => {
623 predictions.iter().sum::<f64>() / predictions.len() as f64
625 }
626 };
627
628 Ok(mitigated_value)
629 } else {
630 Ok(measurements.mean().unwrap_or(0.0))
632 }
633 }
634
635 pub fn apply_traditional_mitigation(
637 &self,
638 strategy: MitigationAction,
639 measurements: &Array1<f64>,
640 _circuit: &InterfaceCircuit,
641 ) -> Result<f64> {
642 match strategy {
643 MitigationAction::ZeroNoiseExtrapolation => {
644 let noise_factors = [1.0, 1.5, 2.0];
646 let values: Vec<f64> = noise_factors
647 .iter()
648 .zip(measurements.iter())
649 .map(|(factor, &val)| val / factor)
650 .collect();
651
652 let extrapolated = 2.0f64.mul_add(values[0], -values[1]);
654 Ok(extrapolated)
655 }
656 MitigationAction::VirtualDistillation => {
657 let mean_val = measurements.mean().unwrap_or(0.0);
659 let variance = measurements.var(0.0);
660 let corrected = mean_val + variance * 0.1; Ok(corrected)
662 }
663 _ => {
664 Ok(measurements.mean().unwrap_or(0.0))
666 }
667 }
668 }
669
670 fn forward_pass_static(
672 model: &DeepMitigationNetwork,
673 input: &Array1<f64>,
674 ) -> Result<Array1<f64>> {
675 let mut current = input.clone();
676
677 for (weights, bias) in model.weights.iter().zip(model.biases.iter()) {
678 current = weights.dot(¤t) + bias;
680
681 current.mapv_inplace(|x| Self::apply_activation_static(x, model.activation));
683 }
684
685 Ok(current)
686 }
687
688 fn apply_activation_static(x: f64, activation: ActivationFunction) -> f64 {
690 match activation {
691 ActivationFunction::ReLU => x.max(0.0),
692 ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
693 ActivationFunction::Tanh => x.tanh(),
694 ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
695 ActivationFunction::GELU => {
696 0.5 * x
697 * (1.0
698 + ((2.0 / std::f64::consts::PI).sqrt()
699 * 0.044_715f64.mul_add(x.powi(3), x))
700 .tanh())
701 }
702 }
703 }
704
705 #[must_use]
707 pub fn apply_activation(&self, x: f64, activation: ActivationFunction) -> f64 {
708 Self::apply_activation_static(x, activation)
709 }
710
711 pub fn forward_pass(
713 &self,
714 model: &DeepMitigationNetwork,
715 input: &Array1<f64>,
716 ) -> Result<Array1<f64>> {
717 Self::forward_pass_static(model, input)
718 }
719
720 fn calculate_circuit_connectivity(&self, circuit: &InterfaceCircuit) -> Result<f64> {
722 if circuit.num_qubits == 0 {
723 return Ok(0.0);
724 }
725
726 let mut connectivity_sum = 0.0;
727 let total_possible_connections = (circuit.num_qubits * (circuit.num_qubits - 1)) / 2;
728
729 for gate in &circuit.gates {
730 if gate.qubits.len() > 1 {
731 connectivity_sum += 1.0;
732 }
733 }
734
735 Ok(connectivity_sum / total_possible_connections as f64)
736 }
737
738 fn calculate_entanglement_estimate(&self, circuit: &InterfaceCircuit) -> Result<f64> {
740 let mut entangling_gates = 0;
741
742 for gate in &circuit.gates {
743 match gate.gate_type {
744 InterfaceGateType::CNOT
745 | InterfaceGateType::CZ
746 | InterfaceGateType::CY
747 | InterfaceGateType::SWAP
748 | InterfaceGateType::ISwap
749 | InterfaceGateType::Toffoli => {
750 entangling_gates += 1;
751 }
752 _ => {}
753 }
754 }
755
756 Ok(f64::from(entangling_gates) / circuit.gates.len() as f64)
757 }
758
759 fn features_to_state_key(features: &Array1<f64>) -> String {
761 let discretized: Vec<i32> = features
763 .iter()
764 .map(|&x| (x * 10.0).round() as i32)
765 .collect();
766 format!("{discretized:?}")
767 }
768
769 fn calculate_confidence(&self, features: &Array1<f64>, _mitigated_value: f64) -> Result<f64> {
771 let feature_variance = features.var(0.0);
773 let confidence = 1.0 / (1.0 + feature_variance);
774 Ok(confidence.clamp(0.0, 1.0))
775 }
776
777 fn estimate_error_reduction(&self, original: &Array1<f64>, mitigated: f64) -> Result<f64> {
779 let original_mean = original.mean().unwrap_or(0.0);
780 let original_variance = original.var(0.0);
781
782 let estimated_improvement = (original_variance.sqrt() - (mitigated - original_mean).abs())
784 / original_variance.sqrt();
785 Ok(estimated_improvement.clamp(0.0, 1.0))
786 }
787
788 fn update_models(&mut self, features: &Array1<f64>, target: f64) -> Result<()> {
790 if self.training_history.len() >= self.config.memory_size {
792 self.training_history.pop_front();
793 }
794 self.training_history.push_back((features.clone(), target));
795
796 if self.training_history.len() >= self.config.batch_size {
798 self.update_deep_model()?;
799 }
800
801 self.update_rl_agent(features, target)?;
803
804 Ok(())
805 }
806
807 fn update_deep_model(&mut self) -> Result<()> {
809 if let Some(ref mut model) = self.deep_model {
810 let batch_size = self.config.batch_size.min(self.training_history.len());
814 let batch: Vec<_> = self
815 .training_history
816 .iter()
817 .rev()
818 .take(batch_size)
819 .collect();
820
821 let mut total_loss = 0.0;
822
823 for (features, target) in batch {
824 let prediction = Self::forward_pass_static(model, features)?;
825 let loss = (prediction[0] - target).powi(2);
826 total_loss += loss;
827 }
828
829 let avg_loss = total_loss / batch_size as f64;
830 model.loss_history.push(avg_loss);
831 }
832
833 Ok(())
834 }
835
836 fn update_rl_agent(&mut self, features: &Array1<f64>, reward: f64) -> Result<()> {
838 if let Some(ref mut agent) = self.rl_agent {
839 let state_key = Self::features_to_state_key(features);
840
841 agent.stats.episodes += 1;
845 agent.stats.avg_reward = agent
846 .stats
847 .avg_reward
848 .mul_add((agent.stats.episodes - 1) as f64, reward)
849 / agent.stats.episodes as f64;
850
851 agent.exploration_rate *= 0.995;
853 agent.exploration_rate = agent.exploration_rate.max(0.01);
854 }
855
856 Ok(())
857 }
858}
859
860pub fn benchmark_advanced_ml_error_mitigation() -> Result<()> {
862 println!("Benchmarking Advanced ML Error Mitigation...");
863
864 let config = AdvancedMLMitigationConfig::default();
865 let mut mitigator = AdvancedMLErrorMitigator::new(config)?;
866
867 let mut circuit = InterfaceCircuit::new(4, 0);
869 circuit.add_gate(InterfaceGate::new(InterfaceGateType::Hadamard, vec![0]));
870 circuit.add_gate(InterfaceGate::new(InterfaceGateType::CNOT, vec![0, 1]));
871 circuit.add_gate(InterfaceGate::new(InterfaceGateType::RZ(0.5), vec![2]));
872
873 let noisy_measurements = Array1::from_vec(vec![0.48, 0.52, 0.47, 0.53, 0.49]);
875
876 let start_time = std::time::Instant::now();
877
878 let result = mitigator.mitigate_errors(&noisy_measurements, &circuit)?;
880
881 let duration = start_time.elapsed();
882
883 println!("✅ Advanced ML Error Mitigation Results:");
884 println!(" Mitigated Value: {:.6}", result.mitigated_value);
885 println!(" Confidence: {:.4}", result.confidence);
886 println!(" Model Used: {}", result.model_used);
887 println!(" Error Reduction: {:.4}", result.error_reduction);
888 println!(" Computation Time: {:.2}ms", duration.as_millis());
889
890 Ok(())
891}
892
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction with the default configuration must succeed.
    #[test]
    fn test_advanced_ml_mitigator_creation() {
        let mitigator = AdvancedMLErrorMitigator::new(AdvancedMLMitigationConfig::default());
        assert!(mitigator.is_ok());
    }

    /// Feature extraction on a small Bell-pair circuit yields a non-empty
    /// feature vector.
    #[test]
    fn test_feature_extraction() {
        let mitigator = AdvancedMLErrorMitigator::new(AdvancedMLMitigationConfig::default())
            .expect("Failed to create mitigator");

        let mut bell_circuit = InterfaceCircuit::new(2, 0);
        bell_circuit.add_gate(InterfaceGate::new(InterfaceGateType::Hadamard, vec![0]));
        bell_circuit.add_gate(InterfaceGate::new(InterfaceGateType::CNOT, vec![0, 1]));

        let samples = Array1::from_vec(vec![0.5, 0.5, 0.5]);
        let extracted = mitigator.extract_features(&bell_circuit, &samples);

        assert!(extracted.is_ok());
        let extracted = extracted.expect("Failed to extract features");
        assert!(!extracted.is_empty());
    }

    /// Spot-check ReLU at negative/positive inputs and sigmoid at zero.
    #[test]
    fn test_activation_functions() {
        let mitigator = AdvancedMLErrorMitigator::new(AdvancedMLMitigationConfig::default())
            .expect("Failed to create mitigator");

        assert_eq!(
            mitigator.apply_activation(-1.0, ActivationFunction::ReLU),
            0.0
        );
        assert_eq!(
            mitigator.apply_activation(1.0, ActivationFunction::ReLU),
            1.0
        );

        let sigmoid_at_zero = mitigator.apply_activation(0.0, ActivationFunction::Sigmoid);
        assert!((sigmoid_at_zero - 0.5).abs() < 1e-10);
    }

    /// Strategy selection must return some action without erroring.
    #[test]
    fn test_mitigation_strategy_selection() {
        let mut mitigator = AdvancedMLErrorMitigator::new(AdvancedMLMitigationConfig::default())
            .expect("Failed to create mitigator");

        let state = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(mitigator.select_mitigation_strategy(&state).is_ok());
    }

    /// Zero-noise extrapolation on a short measurement vector must succeed.
    #[test]
    fn test_traditional_mitigation() {
        let mitigator = AdvancedMLErrorMitigator::new(AdvancedMLMitigationConfig::default())
            .expect("Failed to create mitigator");

        let samples = Array1::from_vec(vec![0.48, 0.52, 0.49]);
        let empty_circuit = InterfaceCircuit::new(2, 0);

        let outcome = mitigator.apply_traditional_mitigation(
            MitigationAction::ZeroNoiseExtrapolation,
            &samples,
            &empty_circuit,
        );
        assert!(outcome.is_ok());
    }
}