use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

use crate::circuit_interfaces::CircuitInterface;
use crate::concatenated_error_correction::ErrorType;
use crate::error::Result;

/// Machine learning model types for adaptive error correction
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MLModelType {
    /// Feedforward neural network classifier
    NeuralNetwork,
    /// Decision tree classifier
    DecisionTree,
    /// Support vector machine
    SVM,
    /// Reinforcement learning agent
    ReinforcementLearning,
    /// Ensemble of multiple models
    Ensemble,
}

/// Learning strategies for model training
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LearningStrategy {
    /// Supervised learning from labeled syndromes
    Supervised,
    /// Unsupervised learning
    Unsupervised,
    /// Reinforcement learning from correction rewards
    Reinforcement,
    /// Online learning with incremental updates
    Online,
    /// Transfer learning from a pre-trained model
    Transfer,
}

/// Configuration for adaptive ML error correction
#[derive(Debug, Clone)]
pub struct AdaptiveMLConfig {
    /// Type of ML model to use
    pub model_type: MLModelType,
    /// Learning strategy for training
    pub learning_strategy: LearningStrategy,
    /// Learning rate for model updates
    pub learning_rate: f64,
    /// Batch size for training
    pub batch_size: usize,
    /// Maximum number of stored training examples
    pub max_history_size: usize,
    /// Minimum confidence required to apply an ML correction
    pub confidence_threshold: f64,
    /// Enable real-time learning from corrections
    pub real_time_learning: bool,
    /// Retrain models every N corrections
    pub update_frequency: usize,
    /// Feature extraction method for syndromes
    pub feature_extraction: FeatureExtractionMethod,
    /// Enable hardware-aware optimization
    pub hardware_aware: bool,
}

/// Feature extraction methods for syndrome data
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeatureExtractionMethod {
    /// Use the raw syndrome bits as features
    RawSyndrome,
    /// Discrete Fourier transform of the syndrome
    FourierTransform,
    /// Principal component analysis projection
    PCA,
    /// Learned autoencoder embedding
    Autoencoder,
    /// Temporal convolution over the syndrome
    TemporalConvolution,
}

impl Default for AdaptiveMLConfig {
    fn default() -> Self {
        Self {
            model_type: MLModelType::NeuralNetwork,
            learning_strategy: LearningStrategy::Online,
            learning_rate: 0.001,
            batch_size: 32,
            max_history_size: 10_000,
            confidence_threshold: 0.8,
            real_time_learning: true,
            update_frequency: 100,
            feature_extraction: FeatureExtractionMethod::RawSyndrome,
            hardware_aware: true,
        }
    }
}

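// Example (sketch, not from the original source): overriding selected
// defaults with struct-update syntax.
//
//     let config = AdaptiveMLConfig {
//         learning_rate: 0.01,
//         batch_size: 64,
//         ..Default::default()
//     };
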
/// Neural network for syndrome classification
#[derive(Debug, Clone)]
pub struct SyndromeClassificationNetwork {
    /// Number of input features
    input_size: usize,
    /// Sizes of the hidden layers
    hidden_sizes: Vec<usize>,
    /// Number of output classes
    output_size: usize,
    /// Weight matrices for each layer
    weights: Vec<Array2<f64>>,
    /// Bias vectors for each layer
    biases: Vec<Array1<f64>>,
    /// Learning rate for gradient descent
    learning_rate: f64,
    /// Recorded (input, target) training pairs
    training_history: Vec<(Array1<f64>, Array1<f64>)>,
}

impl SyndromeClassificationNetwork {
    /// Create a new network with randomly initialized weights.
    #[must_use]
    pub fn new(
        input_size: usize,
        hidden_sizes: Vec<usize>,
        output_size: usize,
        learning_rate: f64,
    ) -> Self {
        let mut layer_sizes = vec![input_size];
        layer_sizes.extend(&hidden_sizes);
        layer_sizes.push(output_size);

        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for i in 0..layer_sizes.len() - 1 {
            let rows = layer_sizes[i + 1];
            let cols = layer_sizes[i];

            // Xavier/Glorot-style scaling keeps activations well-conditioned.
            let scale = (2.0 / (rows + cols) as f64).sqrt();
            let mut weight_matrix = Array2::zeros((rows, cols));
            for elem in &mut weight_matrix {
                *elem = (fastrand::f64() - 0.5) * 2.0 * scale;
            }
            weights.push(weight_matrix);

            biases.push(Array1::zeros(rows));
        }

        Self {
            input_size,
            hidden_sizes,
            output_size,
            weights,
            biases,
            learning_rate,
            training_history: Vec::new(),
        }
    }

    /// Forward pass: ReLU on hidden layers, softmax on the output layer.
    #[must_use]
    pub fn forward(&self, input: &Array1<f64>) -> Array1<f64> {
        let mut activation = input.clone();
        let n_layers = self.weights.len();

        for (idx, (weight, bias)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            activation = weight.dot(&activation) + bias;

            if idx == n_layers - 1 {
                // Numerically stable softmax: subtract the max before exponentiating.
                let max_val = activation.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                activation.mapv_inplace(|x| (x - max_val).exp());
                let sum = activation.sum();
                activation.mapv_inplace(|x| x / sum);
            } else {
                // ReLU activation.
                activation.mapv_inplace(|x| x.max(0.0));
            }
        }

        activation
    }

    /// Train on a batch of (input, target) pairs; returns the mean loss.
    pub fn train_batch(&mut self, inputs: &[Array1<f64>], targets: &[Array1<f64>]) -> f64 {
        let batch_size = inputs.len();
        if batch_size == 0 {
            return 0.0;
        }
        let mut total_loss = 0.0;

        // Accumulate gradients over the batch.
        let mut weight_gradients: Vec<Array2<f64>> = self
            .weights
            .iter()
            .map(|w| Array2::zeros(w.raw_dim()))
            .collect();
        let mut bias_gradients: Vec<Array1<f64>> = self
            .biases
            .iter()
            .map(|b| Array1::zeros(b.raw_dim()))
            .collect();

        for (input, target) in inputs.iter().zip(targets.iter()) {
            let (loss, w_grads, b_grads) = self.backward(input, target);
            total_loss += loss;

            for (wg_acc, wg) in weight_gradients.iter_mut().zip(w_grads.iter()) {
                *wg_acc = &*wg_acc + wg;
            }
            for (bg_acc, bg) in bias_gradients.iter_mut().zip(b_grads.iter()) {
                *bg_acc = &*bg_acc + bg;
            }
        }

        // Apply the averaged gradients.
        let lr = self.learning_rate / batch_size as f64;
        for (weight, gradient) in self.weights.iter_mut().zip(weight_gradients.iter()) {
            *weight = &*weight - &(gradient * lr);
        }
        for (bias, gradient) in self.biases.iter_mut().zip(bias_gradients.iter()) {
            *bias = &*bias - &(gradient * lr);
        }

        total_loss / batch_size as f64
    }

    /// Backward pass; returns (loss, weight gradients, bias gradients).
    fn backward(
        &self,
        input: &Array1<f64>,
        target: &Array1<f64>,
    ) -> (f64, Vec<Array2<f64>>, Vec<Array1<f64>>) {
        // Forward pass, caching pre-activations and activations per layer.
        let mut activations = vec![input.clone()];
        let mut z_values = Vec::new();

        let n_layers = self.weights.len();

        for (idx, (weight, bias)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            let last_activation = activations
                .last()
                .expect("activations should never be empty");
            let z = weight.dot(last_activation) + bias;
            z_values.push(z.clone());

            let mut activation = z;
            if idx == n_layers - 1 {
                // Softmax on the output layer.
                let max_val = activation.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                activation.mapv_inplace(|x| (x - max_val).exp());
                let sum = activation.sum();
                activation.mapv_inplace(|x| x / sum);
            } else {
                // ReLU on hidden layers.
                activation.mapv_inplace(|x| x.max(0.0));
            }
            activations.push(activation);
        }

        // Cross-entropy loss against the one-hot target.
        let output = activations
            .last()
            .expect("activations should have output from forward pass");
        let loss = -target
            .iter()
            .zip(output.iter())
            .map(|(&t, &o)| if t > 0.0 { t * o.ln() } else { 0.0 })
            .sum::<f64>();

        let mut weight_gradients = Vec::with_capacity(self.weights.len());
        let mut bias_gradients = Vec::with_capacity(self.biases.len());

        // For softmax + cross-entropy, the output-layer error is (output - target).
        let mut delta = output - target;

        for i in (0..self.weights.len()).rev() {
            let weight_grad = delta
                .view()
                .insert_axis(Axis(1))
                .dot(&activations[i].view().insert_axis(Axis(0)));
            weight_gradients.insert(0, weight_grad);

            bias_gradients.insert(0, delta.clone());

            if i > 0 {
                // Propagate the error backward and apply the ReLU derivative.
                delta = self.weights[i].t().dot(&delta);

                for (j, &z) in z_values[i - 1].iter().enumerate() {
                    if z <= 0.0 {
                        delta[j] = 0.0;
                    }
                }
            }
        }

        (loss, weight_gradients, bias_gradients)
    }

    /// Predict the most likely error class and its softmax confidence.
    #[must_use]
    pub fn predict(&self, syndrome: &Array1<f64>) -> (usize, f64) {
        let output = self.forward(syndrome);
        let max_idx = output
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);
        let confidence = output.get(max_idx).copied().unwrap_or(0.0);
        (max_idx, confidence)
    }
}

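// Example (sketch): with a freshly initialized network the predicted class is
// arbitrary, but the softmax output is always a probability distribution.
//
//     let net = SyndromeClassificationNetwork::new(4, vec![8], 4, 0.01);
//     let features = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0]);
//     let (class, confidence) = net.predict(&features);
//     assert!(class < 4 && confidence > 0.0 && confidence <= 1.0);
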
/// Q-learning agent for selecting error correction strategies
#[derive(Debug, Clone)]
pub struct ErrorCorrectionAgent {
    /// Q-table mapping state strings to per-action values
    q_table: HashMap<String, Array1<f64>>,
    /// Learning rate for Q-value updates
    learning_rate: f64,
    /// Discount factor for future rewards
    discount_factor: f64,
    /// Exploration rate for epsilon-greedy action selection
    epsilon: f64,
    /// Number of available correction actions
    action_space_size: usize,
    /// Total number of training steps taken
    training_steps: usize,
    /// Recent episode rewards
    episode_rewards: VecDeque<f64>,
}

impl ErrorCorrectionAgent {
    /// Create a new agent with an empty Q-table.
    #[must_use]
    pub fn new(
        action_space_size: usize,
        learning_rate: f64,
        discount_factor: f64,
        epsilon: f64,
    ) -> Self {
        Self {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor,
            epsilon,
            action_space_size,
            training_steps: 0,
            episode_rewards: VecDeque::with_capacity(1000),
        }
    }

    /// Select an action using an epsilon-greedy policy.
    pub fn select_action(&mut self, state: &str) -> usize {
        if fastrand::f64() < self.epsilon {
            // Explore: pick a random action.
            fastrand::usize(0..self.action_space_size)
        } else {
            // Exploit: pick the action with the highest Q-value.
            let q_values = self
                .q_table
                .entry(state.to_string())
                .or_insert_with(|| Array1::zeros(self.action_space_size));

            q_values
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0)
        }
    }

    /// Apply a temporal-difference Q-learning update.
    pub fn update_q_value(
        &mut self,
        state: &str,
        action: usize,
        reward: f64,
        next_state: &str,
        done: bool,
    ) {
        let current_q = self
            .q_table
            .entry(state.to_string())
            .or_insert_with(|| Array1::zeros(self.action_space_size))
            .clone();

        let next_q_max = if done {
            0.0
        } else {
            let next_q_values = self
                .q_table
                .entry(next_state.to_string())
                .or_insert_with(|| Array1::zeros(self.action_space_size));
            next_q_values
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        };

        let td_target = self.discount_factor.mul_add(next_q_max, reward);
        let current_q_action = current_q.get(action).copied().unwrap_or(0.0);
        let td_error = td_target - current_q_action;

        // Move the Q-value toward the TD target.
        if let Some(q_values) = self.q_table.get_mut(state) {
            if action < q_values.len() {
                q_values[action] += self.learning_rate * td_error;
            }
        }

        self.training_steps += 1;

        // Decay exploration over time, with a floor at 0.01.
        if self.training_steps % 1000 == 0 {
            self.epsilon = (self.epsilon * 0.995).max(0.01);
        }
    }

    /// Reward: +10 per error removed, minus the correction cost,
    /// plus a bonus of 5 when all errors are cleared.
    #[must_use]
    pub fn calculate_reward(
        &self,
        errors_before: usize,
        errors_after: usize,
        correction_cost: f64,
    ) -> f64 {
        let error_reduction = errors_before as f64 - errors_after as f64;
        let reward = error_reduction.mul_add(10.0, -correction_cost);

        // Bonus for clearing every error.
        if errors_after == 0 {
            reward + 5.0
        } else {
            reward
        }
    }
}

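// Example (sketch): one rewarded transition. With learning rate 0.5 and a
// terminal transition (done = true), the updated Q-value moves from 0 to
// 0.5 * (reward - 0) = reward / 2.
//
//     let mut agent = ErrorCorrectionAgent::new(4, 0.5, 0.99, 0.1);
//     let action = agent.select_action("0101");
//     let reward = agent.calculate_reward(2, 0, 1.0); // 2 * 10 - 1 + 5 = 24
//     agent.update_q_value("0101", action, reward, "0000", true);
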
/// Adaptive ML-based error correction engine
pub struct AdaptiveMLErrorCorrection {
    /// Configuration
    config: AdaptiveMLConfig,
    /// Neural network syndrome classifier
    classifier: SyndromeClassificationNetwork,
    /// Reinforcement learning agent for action selection
    rl_agent: ErrorCorrectionAgent,
    /// Feature extractor for syndrome preprocessing
    feature_extractor: FeatureExtractor,
    /// Shared history of training examples
    training_history: Arc<Mutex<VecDeque<TrainingExample>>>,
    /// Accumulated correction metrics
    metrics: CorrectionMetrics,
    /// Interface to the circuit backend
    circuit_interface: CircuitInterface,
    /// Counter used to trigger periodic retraining
    update_counter: usize,
}

/// A single training example recorded during correction
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Extracted syndrome features
    pub syndrome: Array1<f64>,
    /// Predicted error type
    pub error_type: ErrorType,
    /// RL action taken
    pub action: usize,
    /// Reward received
    pub reward: f64,
    /// Time of the example in seconds since the correction started
    pub timestamp: f64,
}

/// Feature extractor for syndrome data
#[derive(Debug, Clone)]
pub struct FeatureExtractor {
    /// Extraction method
    method: FeatureExtractionMethod,
    /// Optional PCA projection matrix
    pca_components: Option<Array2<f64>>,
    /// Optional autoencoder network
    autoencoder: Option<SyndromeClassificationNetwork>,
}

impl FeatureExtractor {
    /// Create a new feature extractor for the given method.
    #[must_use]
    pub const fn new(method: FeatureExtractionMethod) -> Self {
        Self {
            method,
            pca_components: None,
            autoencoder: None,
        }
    }

    /// Extract a feature vector from a boolean syndrome.
    #[must_use]
    pub fn extract_features(&self, syndrome: &[bool]) -> Array1<f64> {
        match self.method {
            FeatureExtractionMethod::RawSyndrome => {
                let mut features: Vec<f64> = syndrome
                    .iter()
                    .map(|&b| if b { 1.0 } else { 0.0 })
                    .collect();
                // Pad to a minimum feature length of 4.
                while features.len() < 4 {
                    features.push(0.0);
                }
                Array1::from_vec(features)
            }
            FeatureExtractionMethod::FourierTransform => self.fft_features(syndrome),
            FeatureExtractionMethod::PCA => self.pca_features(syndrome),
            FeatureExtractionMethod::Autoencoder => self.autoencoder_features(syndrome),
            FeatureExtractionMethod::TemporalConvolution => self.temporal_conv_features(syndrome),
        }
    }

    /// Naive DFT of the syndrome signal: up to 8 frequency bins,
    /// with real and imaginary parts interleaved.
    fn fft_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut signal: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();

        // Pad to a minimum length of 4.
        while signal.len() < 4 {
            signal.push(0.0);
        }

        let mut features = Vec::new();
        let n = signal.len();

        for k in 0..n.min(8) {
            let mut real_part = 0.0;
            let mut imag_part = 0.0;

            for (i, &x) in signal.iter().enumerate() {
                let angle = -2.0 * std::f64::consts::PI * k as f64 * i as f64 / n as f64;
                real_part += x * angle.cos();
                imag_part += x * angle.sin();
            }

            features.push(real_part);
            features.push(imag_part);
        }

        Array1::from_vec(features)
    }

    /// Project the syndrome onto the PCA components, if available;
    /// otherwise fall back to the raw features.
    fn pca_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut features: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        while features.len() < 4 {
            features.push(0.0);
        }
        let raw_features = Array1::from_vec(features);

        if let Some(ref components) = self.pca_components {
            components.dot(&raw_features)
        } else {
            raw_features
        }
    }

    /// Encode the syndrome with the autoencoder, if available;
    /// otherwise fall back to the raw features.
    fn autoencoder_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut features: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        while features.len() < 4 {
            features.push(0.0);
        }
        let raw_features = Array1::from_vec(features);

        if let Some(ref encoder) = self.autoencoder {
            encoder.forward(&raw_features)
        } else {
            raw_features
        }
    }

    /// 1D convolution with a fixed ramp kernel over the syndrome signal.
    fn temporal_conv_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut signal: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();

        while signal.len() < 4 {
            signal.push(0.0);
        }

        let kernel_size = 3;
        let mut features = Vec::new();

        for i in 0..signal.len().saturating_sub(kernel_size - 1) {
            let mut conv_sum = 0.0;
            for j in 0..kernel_size {
                conv_sum += signal[i + j] * (j as f64 + 1.0) / kernel_size as f64;
            }
            features.push(conv_sum);
        }

        // Fall back to the raw signal if it is shorter than the kernel.
        if features.is_empty() {
            features = signal;
        }

        Array1::from_vec(features)
    }
}

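// Example (sketch): raw-syndrome extraction maps bits to 0.0/1.0 and pads to
// length 4, so a two-bit syndrome still yields four features.
//
//     let extractor = FeatureExtractor::new(FeatureExtractionMethod::RawSyndrome);
//     let features = extractor.extract_features(&[true, false]);
//     assert_eq!(features.len(), 4); // [1.0, 0.0, 0.0, 0.0]
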
/// Metrics tracking correction performance
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorrectionMetrics {
    /// Total number of correction attempts
    pub total_corrections: usize,
    /// Corrections that reduced the error count
    pub successful_corrections: usize,
    /// Corrections that increased the error count
    pub false_positives: usize,
    /// Missed errors
    pub false_negatives: usize,
    /// Running average of classifier confidence
    pub average_confidence: f64,
    /// Training loss over time
    pub learning_curve: Vec<f64>,
    /// Recent RL rewards
    pub reward_history: Vec<f64>,
    /// Average correction time in milliseconds
    pub avg_correction_time_ms: f64,
}

impl CorrectionMetrics {
    /// Fraction of corrections that succeeded (1.0 if none attempted).
    #[must_use]
    pub fn accuracy(&self) -> f64 {
        if self.total_corrections == 0 {
            return 1.0;
        }
        self.successful_corrections as f64 / self.total_corrections as f64
    }

    /// Precision: true positives over predicted positives.
    #[must_use]
    pub fn precision(&self) -> f64 {
        let true_positives = self.successful_corrections;
        let predicted_positives = true_positives + self.false_positives;

        if predicted_positives == 0 {
            return 1.0;
        }
        true_positives as f64 / predicted_positives as f64
    }

    /// Recall: true positives over actual positives.
    #[must_use]
    pub fn recall(&self) -> f64 {
        let true_positives = self.successful_corrections;
        let actual_positives = true_positives + self.false_negatives;

        if actual_positives == 0 {
            return 1.0;
        }
        true_positives as f64 / actual_positives as f64
    }

    /// F1 score: harmonic mean of precision and recall.
    #[must_use]
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();

        if precision + recall == 0.0 {
            return 0.0;
        }
        2.0 * precision * recall / (precision + recall)
    }
}

impl AdaptiveMLErrorCorrection {
    /// Create a new adaptive ML error correction engine.
    pub fn new(config: AdaptiveMLConfig) -> Result<Self> {
        let circuit_interface = CircuitInterface::new(Default::default())?;

        let feature_extractor = FeatureExtractor::new(config.feature_extraction);

        // Probe the extractor with a dummy syndrome to determine the input size.
        let test_syndrome = vec![false, false, false, false];
        let test_features = feature_extractor.extract_features(&test_syndrome);
        let input_size = test_features.len();

        let hidden_sizes = vec![input_size * 2, input_size];
        let output_size = 4; // One class per ErrorType variant.
        let classifier = SyndromeClassificationNetwork::new(
            input_size,
            hidden_sizes,
            output_size,
            config.learning_rate,
        );

        let action_space_size = 8; // Number of available correction strategies.
        let rl_agent = ErrorCorrectionAgent::new(
            action_space_size,
            config.learning_rate,
            0.99, // discount factor
            0.1,  // initial exploration rate
        );

        let training_history =
            Arc::new(Mutex::new(VecDeque::with_capacity(config.max_history_size)));

        Ok(Self {
            config,
            classifier,
            rl_agent,
            feature_extractor,
            training_history,
            metrics: CorrectionMetrics::default(),
            circuit_interface,
            update_counter: 0,
        })
    }

    /// Run one adaptive correction step on the given state and syndrome.
    pub fn correct_errors_adaptive(
        &mut self,
        state: &mut Array1<Complex64>,
        syndrome: &[bool],
    ) -> Result<AdaptiveCorrectionResult> {
        let start_time = std::time::Instant::now();

        // Extract features and classify the likely error type.
        let features = self.feature_extractor.extract_features(syndrome);

        let (predicted_error_class, confidence) = self.classifier.predict(&features);
        let predicted_error_type = self.class_to_error_type(predicted_error_class);

        // Let the RL agent choose a correction strategy.
        let state_repr = self.syndrome_to_string(syndrome);
        let action = self.rl_agent.select_action(&state_repr);

        let errors_before = self.count_errors(state, syndrome);

        // Apply the ML correction only when confidence is high enough;
        // otherwise fall back to a classical syndrome-based correction.
        let correction_applied = if confidence >= self.config.confidence_threshold {
            self.apply_ml_correction(state, predicted_error_type, action)?;
            true
        } else {
            self.apply_classical_correction(state, syndrome)?;
            false
        };

        let errors_after = self.count_errors(state, syndrome);

        // Update the RL agent from the observed outcome.
        let reward = self
            .rl_agent
            .calculate_reward(errors_before, errors_after, 1.0);

        let next_state_repr = self.state_to_string(state);
        self.rl_agent.update_q_value(
            &state_repr,
            action,
            reward,
            &next_state_repr,
            errors_after == 0,
        );

        // Record the example for later retraining.
        if self.config.real_time_learning {
            let training_example = TrainingExample {
                syndrome: features,
                error_type: predicted_error_type,
                action,
                reward,
                timestamp: start_time.elapsed().as_secs_f64(),
            };

            if let Ok(mut history) = self.training_history.lock() {
                history.push_back(training_example);
                if history.len() > self.config.max_history_size {
                    history.pop_front();
                }
            }
        }

        self.update_metrics(errors_before, errors_after, confidence, reward);

        // Periodically retrain the classifier on the accumulated history.
        self.update_counter += 1;
        if self.update_counter % self.config.update_frequency == 0 {
            self.retrain_models()?;
        }

        let processing_time = start_time.elapsed().as_secs_f64() * 1000.0;

        Ok(AdaptiveCorrectionResult {
            predicted_error_type,
            confidence,
            correction_applied,
            errors_corrected: errors_before.saturating_sub(errors_after),
            reward,
            processing_time_ms: processing_time,
            rl_action: action,
        })
    }

    /// Dispatch the RL action to a concrete correction strategy.
    fn apply_ml_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
        action: usize,
    ) -> Result<()> {
        match action {
            0 => {
                // Single-qubit correction on qubit 0.
                self.apply_single_qubit_correction(state, error_type, 0)?;
            }
            1 => {
                // Two-qubit correction on qubits 0 and 1.
                self.apply_two_qubit_correction(state, error_type, 0, 1)?;
            }
            2 => {
                // Syndrome-guided correction.
                self.apply_syndrome_based_correction(state, error_type)?;
            }
            3 => {
                // Probabilistic correction across all qubits.
                self.apply_probabilistic_correction(state, error_type)?;
            }
            _ => {
                // Default: single-qubit correction on qubit 0.
                self.apply_single_qubit_correction(state, error_type, 0)?;
            }
        }
        Ok(())
    }

    /// Apply a Pauli correction to a single qubit of the state vector.
    fn apply_single_qubit_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
        qubit: usize,
    ) -> Result<()> {
        let n_qubits = (state.len() as f64).log2().ceil() as usize;
        if qubit >= n_qubits {
            return Ok(());
        }

        match error_type {
            ErrorType::BitFlip => {
                // Apply an X gate: swap amplitudes that differ in this qubit.
                for i in 0..state.len() {
                    if (i >> qubit) & 1 == 0 {
                        let partner = i | (1 << qubit);
                        if partner < state.len() {
                            state.swap(i, partner);
                        }
                    }
                }
            }
            ErrorType::PhaseFlip => {
                // Apply a Z gate: negate amplitudes where this qubit is 1.
                for i in 0..state.len() {
                    if (i >> qubit) & 1 == 1 {
                        state[i] *= -1.0;
                    }
                }
            }
            ErrorType::BitPhaseFlip => {
                // Apply a Y-type correction as Z followed by X.
                self.apply_single_qubit_correction(state, ErrorType::PhaseFlip, qubit)?;
                self.apply_single_qubit_correction(state, ErrorType::BitFlip, qubit)?;
            }
            ErrorType::Identity => {
                // No correction needed.
            }
        }

        Ok(())
    }

    /// Apply the same correction to two qubits.
    fn apply_two_qubit_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
        qubit1: usize,
        qubit2: usize,
    ) -> Result<()> {
        self.apply_single_qubit_correction(state, error_type, qubit1)?;
        self.apply_single_qubit_correction(state, error_type, qubit2)?;
        Ok(())
    }

    /// Apply a correction to a randomly selected qubit.
    fn apply_syndrome_based_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
    ) -> Result<()> {
        let n_qubits = (state.len() as f64).log2().ceil() as usize;
        let target_qubit = fastrand::usize(0..n_qubits);
        self.apply_single_qubit_correction(state, error_type, target_qubit)?;
        Ok(())
    }

    /// Apply corrections to each qubit with an error-type-dependent probability.
    fn apply_probabilistic_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
    ) -> Result<()> {
        let n_qubits = (state.len() as f64).log2().ceil() as usize;

        for qubit in 0..n_qubits {
            let prob = match error_type {
                ErrorType::BitFlip => 0.3,
                ErrorType::PhaseFlip => 0.2,
                ErrorType::BitPhaseFlip => 0.1,
                ErrorType::Identity => 0.0,
            };

            if fastrand::f64() < prob {
                self.apply_single_qubit_correction(state, error_type, qubit)?;
            }
        }

        Ok(())
    }

    /// Fallback: apply a bit-flip correction wherever the syndrome fires.
    fn apply_classical_correction(
        &self,
        state: &mut Array1<Complex64>,
        syndrome: &[bool],
    ) -> Result<()> {
        for (i, &has_error) in syndrome.iter().enumerate() {
            if has_error {
                self.apply_single_qubit_correction(state, ErrorType::BitFlip, i)?;
            }
        }
        Ok(())
    }

    /// Estimate the error count from the number of set syndrome bits.
    fn count_errors(&self, _state: &Array1<Complex64>, syndrome: &[bool]) -> usize {
        syndrome.iter().map(|&b| usize::from(b)).sum()
    }

    /// Map a classifier output class to an error type.
    const fn class_to_error_type(&self, class: usize) -> ErrorType {
        match class {
            0 => ErrorType::Identity,
            1 => ErrorType::BitFlip,
            2 => ErrorType::PhaseFlip,
            3 => ErrorType::BitPhaseFlip,
            _ => ErrorType::Identity,
        }
    }

    /// Encode a syndrome as a bit string for Q-table lookup.
    fn syndrome_to_string(&self, syndrome: &[bool]) -> String {
        syndrome
            .iter()
            .map(|&b| if b { '1' } else { '0' })
            .collect()
    }

    /// Encode the state's amplitude magnitudes as a string representation.
    fn state_to_string(&self, state: &Array1<Complex64>) -> String {
        let amplitudes: Vec<f64> = state.iter().map(|c| c.norm()).collect();
        format!("{amplitudes:.3?}")
    }

    /// Update running metrics after a correction attempt.
    fn update_metrics(
        &mut self,
        errors_before: usize,
        errors_after: usize,
        confidence: f64,
        reward: f64,
    ) {
        self.metrics.total_corrections += 1;

        if errors_after < errors_before {
            self.metrics.successful_corrections += 1;
        } else if errors_after > errors_before {
            self.metrics.false_positives += 1;
        }

        // Running average of confidence over all corrections.
        self.metrics.average_confidence = self
            .metrics
            .average_confidence
            .mul_add((self.metrics.total_corrections - 1) as f64, confidence)
            / self.metrics.total_corrections as f64;

        self.metrics.reward_history.push(reward);
        if self.metrics.reward_history.len() > 1000 {
            self.metrics.reward_history.remove(0);
        }
    }

    /// Retrain the classifier on the accumulated training history.
    fn retrain_models(&mut self) -> Result<()> {
        let history = self.training_history.lock().map_err(|e| {
            crate::error::SimulatorError::InvalidOperation(format!("Lock poisoned: {e}"))
        })?;
        if history.len() < self.config.batch_size {
            return Ok(());
        }

        // Build one-hot training targets from the recorded error types.
        let mut inputs = Vec::new();
        let mut targets = Vec::new();

        for example in history.iter() {
            inputs.push(example.syndrome.clone());

            let mut target = Array1::zeros(4);
            let error_class = match example.error_type {
                ErrorType::Identity => 0,
                ErrorType::BitFlip => 1,
                ErrorType::PhaseFlip => 2,
                ErrorType::BitPhaseFlip => 3,
            };
            target[error_class] = 1.0;
            targets.push(target);
        }

        // Train in mini-batches and record the loss curve.
        let batch_size = self.config.batch_size.min(inputs.len());
        for (input_chunk, target_chunk) in inputs.chunks(batch_size).zip(targets.chunks(batch_size)) {
            let loss = self.classifier.train_batch(input_chunk, target_chunk);
            self.metrics.learning_curve.push(loss);
        }

        Ok(())
    }

    /// Get the current correction metrics.
    #[must_use]
    pub const fn get_metrics(&self) -> &CorrectionMetrics {
        &self.metrics
    }

    /// Reset metrics, training history, and the update counter.
    pub fn reset(&mut self) {
        self.metrics = CorrectionMetrics::default();
        if let Ok(mut history) = self.training_history.lock() {
            history.clear();
        }
        self.update_counter = 0;
    }
}

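// Example (sketch): one adaptive correction step on a 2-qubit state. The
// syndrome is illustrative; with a fresh, untrained engine the classifier
// confidence typically sits below the 0.8 threshold, so the classical
// fallback path is taken.
//
//     let mut engine = AdaptiveMLErrorCorrection::new(AdaptiveMLConfig::default())?;
//     let mut state = Array1::from_vec(vec![
//         Complex64::new(1.0, 0.0),
//         Complex64::new(0.0, 0.0),
//         Complex64::new(0.0, 0.0),
//         Complex64::new(0.0, 0.0),
//     ]);
//     let result = engine.correct_errors_adaptive(&mut state, &[true, false])?;
//     println!("confidence = {:.3}", result.confidence);
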
/// Result of one adaptive correction step
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveCorrectionResult {
    /// Error type predicted by the classifier
    pub predicted_error_type: ErrorType,
    /// Classifier confidence for the prediction
    pub confidence: f64,
    /// Whether the ML correction path was taken
    pub correction_applied: bool,
    /// Number of errors removed
    pub errors_corrected: usize,
    /// Reward received by the RL agent
    pub reward: f64,
    /// Processing time in milliseconds
    pub processing_time_ms: f64,
    /// Action selected by the RL agent
    pub rl_action: usize,
}

/// Benchmark adaptive ML error correction across several configurations.
pub fn benchmark_adaptive_ml_error_correction() -> Result<HashMap<String, f64>> {
    let mut results = HashMap::new();

    // Compare an online neural-network setup against an RL-driven one.
    let configs = vec![
        AdaptiveMLConfig {
            model_type: MLModelType::NeuralNetwork,
            learning_strategy: LearningStrategy::Online,
            ..Default::default()
        },
        AdaptiveMLConfig {
            model_type: MLModelType::ReinforcementLearning,
            learning_strategy: LearningStrategy::Reinforcement,
            ..Default::default()
        },
    ];

    for (i, config) in configs.into_iter().enumerate() {
        let start = std::time::Instant::now();

        let mut adaptive_ec = AdaptiveMLErrorCorrection::new(config)?;

        // Run repeated corrections on a 2-qubit superposition test state.
        for _ in 0..100 {
            let mut test_state = Array1::from_vec(vec![
                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
            ]);

            let syndrome = vec![true, false, true, false];
            let _result = adaptive_ec.correct_errors_adaptive(&mut test_state, &syndrome)?;
        }

        let time = start.elapsed().as_secs_f64() * 1000.0;
        results.insert(format!("config_{i}"), time);
    }

    Ok(results)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neural_network_creation() {
        let nn = SyndromeClassificationNetwork::new(4, vec![8, 4], 2, 0.01);
        assert_eq!(nn.input_size, 4);
        assert_eq!(nn.output_size, 2);
        assert_eq!(nn.weights.len(), 3); // Two hidden layers plus the output layer.
    }

    #[test]
    fn test_neural_network_forward() {
        let nn = SyndromeClassificationNetwork::new(3, vec![4], 2, 0.01);
        let input = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let output = nn.forward(&input);

        assert_eq!(output.len(), 2);
        assert_abs_diff_eq!(output.sum(), 1.0, epsilon = 1e-6); // Softmax output sums to 1.
    }

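    // Hedged sketch of a training test: with random initialization the exact
    // loss is nondeterministic, so this only checks structural invariants
    // (finite cross-entropy loss and a still-normalized softmax output).
    #[test]
    fn test_neural_network_train_batch() {
        let mut nn = SyndromeClassificationNetwork::new(3, vec![4], 2, 0.01);
        let inputs = vec![Array1::from_vec(vec![1.0, 0.0, 1.0])];
        let targets = vec![Array1::from_vec(vec![1.0, 0.0])];

        let loss = nn.train_batch(&inputs, &targets);
        assert!(loss.is_finite());

        let output = nn.forward(&inputs[0]);
        assert_abs_diff_eq!(output.sum(), 1.0, epsilon = 1e-6);
    }
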
    #[test]
    fn test_rl_agent_creation() {
        let agent = ErrorCorrectionAgent::new(4, 0.1, 0.99, 0.1);
        assert_eq!(agent.action_space_size, 4);
        assert!(agent.q_table.is_empty());
    }

    #[test]
    fn test_rl_agent_action_selection() {
        let mut agent = ErrorCorrectionAgent::new(3, 0.1, 0.99, 0.0); // epsilon = 0: always greedy
        let state = "001";

        let action = agent.select_action(state);
        assert!(action < 3);
    }

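    // Hedged sketch of a Q-update test: for a terminal transition the TD
    // target is just the reward, so the new Q-value is
    // learning_rate * (reward - 0) = 0.5 * 10 = 5.
    #[test]
    fn test_rl_agent_q_update() {
        let mut agent = ErrorCorrectionAgent::new(2, 0.5, 0.99, 0.0);
        agent.update_q_value("s0", 1, 10.0, "s1", true);

        let q = agent.q_table.get("s0").expect("state should exist after update");
        assert_abs_diff_eq!(q[1], 5.0, epsilon = 1e-10);
    }

    // Hedged sketch of a reward test: removing 2 errors at cost 1.0 with all
    // errors cleared gives 2 * 10 - 1 + 5 = 24.
    #[test]
    fn test_reward_calculation() {
        let agent = ErrorCorrectionAgent::new(4, 0.1, 0.99, 0.1);
        assert_abs_diff_eq!(agent.calculate_reward(2, 0, 1.0), 24.0, epsilon = 1e-10);
    }
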
    #[test]
    fn test_feature_extraction() {
        let extractor = FeatureExtractor::new(FeatureExtractionMethod::RawSyndrome);
        let syndrome = vec![true, false, true, false];
        let features = extractor.extract_features(&syndrome);

        assert_eq!(features.len(), 4);
        assert_abs_diff_eq!(features[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(features[1], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(features[2], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(features[3], 0.0, epsilon = 1e-10);
    }

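    // Hedged sketch of a DFT feature test: for the length-4 signal
    // [1, 0, 1, 0], bin 0 (the DC term) sums the signal to 2, and each of the
    // 4 bins contributes a (real, imag) pair, giving 8 features.
    #[test]
    fn test_fft_features() {
        let extractor = FeatureExtractor::new(FeatureExtractionMethod::FourierTransform);
        let features = extractor.extract_features(&[true, false, true, false]);

        assert_eq!(features.len(), 8);
        assert_abs_diff_eq!(features[0], 2.0, epsilon = 1e-9); // DC component
        assert_abs_diff_eq!(features[2], 0.0, epsilon = 1e-9); // bin 1, real part
    }
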
    #[test]
    fn test_adaptive_ml_error_correction_creation() {
        let config = AdaptiveMLConfig::default();
        let adaptive_ec = AdaptiveMLErrorCorrection::new(config);
        assert!(adaptive_ec.is_ok());
    }

    #[test]
    fn test_error_correction_application() {
        let config = AdaptiveMLConfig::default();
        let mut adaptive_ec = AdaptiveMLErrorCorrection::new(config)
            .expect("Failed to create AdaptiveMLErrorCorrection");

        let mut state = Array1::from_vec(vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ]);

        let syndrome = vec![false, false];
        let result = adaptive_ec.correct_errors_adaptive(&mut state, &syndrome);
        assert!(result.is_ok());

        let correction_result = result.expect("Failed to correct errors");
        assert!(correction_result.processing_time_ms >= 0.0);
    }

    #[test]
    fn test_metrics_calculation() {
        let metrics = CorrectionMetrics {
            total_corrections: 100,
            successful_corrections: 90,
            false_positives: 5,
            false_negatives: 5,
            ..Default::default()
        };

        assert_abs_diff_eq!(metrics.accuracy(), 0.9, epsilon = 1e-10);
        assert_abs_diff_eq!(metrics.precision(), 90.0 / 95.0, epsilon = 1e-10);
        assert_abs_diff_eq!(metrics.recall(), 90.0 / 95.0, epsilon = 1e-10);
    }

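    // Hedged sketch of an F1 test: F1 is the harmonic mean of precision and
    // recall, so when precision == recall it collapses to that shared value.
    #[test]
    fn test_f1_score() {
        let metrics = CorrectionMetrics {
            total_corrections: 100,
            successful_corrections: 90,
            false_positives: 5,
            false_negatives: 5,
            ..Default::default()
        };
        assert_abs_diff_eq!(metrics.f1_score(), 90.0 / 95.0, epsilon = 1e-10);
    }
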
    #[test]
    fn test_different_error_types() {
        let config = AdaptiveMLConfig::default();
        let adaptive_ec = AdaptiveMLErrorCorrection::new(config)
            .expect("Failed to create AdaptiveMLErrorCorrection");

        assert_eq!(adaptive_ec.class_to_error_type(0), ErrorType::Identity);
        assert_eq!(adaptive_ec.class_to_error_type(1), ErrorType::BitFlip);
        assert_eq!(adaptive_ec.class_to_error_type(2), ErrorType::PhaseFlip);
        assert_eq!(adaptive_ec.class_to_error_type(3), ErrorType::BitPhaseFlip);
    }
}