use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

use crate::circuit_interfaces::CircuitInterface;
use crate::concatenated_error_correction::ErrorType;
use crate::error::Result;

/// Machine learning model types available for adaptive error correction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MLModelType {
    /// Feed-forward neural network classifier.
    NeuralNetwork,
    /// Decision-tree-based classifier.
    DecisionTree,
    /// Support vector machine.
    SVM,
    /// Q-learning agent (see `ErrorCorrectionAgent`).
    ReinforcementLearning,
    /// Combination of several model types.
    Ensemble,
}

/// Strategies for training the correction models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LearningStrategy {
    /// Learn from labeled syndrome/error pairs.
    Supervised,
    /// Learn structure without labels.
    Unsupervised,
    /// Learn from correction rewards.
    Reinforcement,
    /// Update incrementally as corrections happen.
    Online,
    /// Reuse knowledge from a related task.
    Transfer,
}

/// Configuration for adaptive ML-based error correction.
#[derive(Debug, Clone)]
pub struct AdaptiveMLConfig {
    /// Which ML model drives decoding.
    pub model_type: MLModelType,
    /// How the models are trained.
    pub learning_strategy: LearningStrategy,
    /// Step size for gradient descent and Q-learning updates.
    pub learning_rate: f64,
    /// Number of examples per training batch.
    pub batch_size: usize,
    /// Maximum number of stored training examples.
    pub max_history_size: usize,
    /// Minimum classifier confidence required to apply an ML correction.
    pub confidence_threshold: f64,
    /// Learn from each correction as it happens.
    pub real_time_learning: bool,
    /// Retrain the models after this many corrections.
    pub update_frequency: usize,
    /// How raw syndromes are converted to feature vectors.
    pub feature_extraction: FeatureExtractionMethod,
    /// Account for hardware-specific noise characteristics.
    pub hardware_aware: bool,
}

/// Methods for turning raw syndrome bits into feature vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeatureExtractionMethod {
    /// Use the syndrome bits directly.
    RawSyndrome,
    /// Discrete Fourier transform of the syndrome signal.
    FourierTransform,
    /// Projection onto learned principal components.
    PCA,
    /// Encoding through a learned autoencoder.
    Autoencoder,
    /// 1D convolution over the syndrome sequence.
    TemporalConvolution,
}

impl Default for AdaptiveMLConfig {
    fn default() -> Self {
        Self {
            model_type: MLModelType::NeuralNetwork,
            learning_strategy: LearningStrategy::Online,
            learning_rate: 0.001,
            batch_size: 32,
            max_history_size: 10000,
            confidence_threshold: 0.8,
            real_time_learning: true,
            update_frequency: 100,
            feature_extraction: FeatureExtractionMethod::RawSyndrome,
            hardware_aware: true,
        }
    }
}

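// A minimal construction sketch (hedged: `example_rl_config` is an
// illustrative helper, not part of this module's API). Struct-update syntax
// keeps the defaults and overrides only the model choice.
#[allow(dead_code)]
fn example_rl_config() -> AdaptiveMLConfig {
    AdaptiveMLConfig {
        model_type: MLModelType::ReinforcementLearning,
        learning_strategy: LearningStrategy::Reinforcement,
        ..Default::default()
    }
}
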
/// A small feed-forward network that classifies error syndromes.
#[derive(Debug, Clone)]
pub struct SyndromeClassificationNetwork {
    /// Number of input features.
    input_size: usize,
    /// Sizes of the hidden layers.
    hidden_sizes: Vec<usize>,
    /// Number of output classes.
    output_size: usize,
    /// One weight matrix per layer (rows = outputs, columns = inputs).
    weights: Vec<Array2<f64>>,
    /// One bias vector per layer.
    biases: Vec<Array1<f64>>,
    /// Gradient-descent step size.
    learning_rate: f64,
    /// Stored (input, target) training pairs.
    training_history: Vec<(Array1<f64>, Array1<f64>)>,
}

impl SyndromeClassificationNetwork {
    /// Create a network with Xavier-initialized weights and zero biases.
    pub fn new(
        input_size: usize,
        hidden_sizes: Vec<usize>,
        output_size: usize,
        learning_rate: f64,
    ) -> Self {
        let mut layer_sizes = vec![input_size];
        layer_sizes.extend(&hidden_sizes);
        layer_sizes.push(output_size);

        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for i in 0..layer_sizes.len() - 1 {
            let rows = layer_sizes[i + 1];
            let cols = layer_sizes[i];

            // Xavier/Glorot initialization keeps activation variance stable.
            let scale = (2.0 / (rows + cols) as f64).sqrt();
            let mut weight_matrix = Array2::zeros((rows, cols));
            for elem in &mut weight_matrix {
                *elem = (fastrand::f64() - 0.5) * 2.0 * scale;
            }
            weights.push(weight_matrix);

            biases.push(Array1::zeros(rows));
        }

        Self {
            input_size,
            hidden_sizes,
            output_size,
            weights,
            biases,
            learning_rate,
            training_history: Vec::new(),
        }
    }

    /// Forward pass: ReLU on hidden layers, softmax on the output layer.
    pub fn forward(&self, input: &Array1<f64>) -> Array1<f64> {
        let mut activation = input.clone();
        let last_layer = self.weights.len() - 1;

        for (layer, (weight, bias)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            activation = weight.dot(&activation) + bias;

            if layer == last_layer {
                // Numerically stable softmax: shift by the max before exp.
                let max_val = activation.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                activation.mapv_inplace(|x| (x - max_val).exp());
                let sum = activation.sum();
                activation.mapv_inplace(|x| x / sum);
            } else {
                // ReLU activation.
                activation.mapv_inplace(|x| x.max(0.0));
            }
        }

        activation
    }

    /// Train on a batch of (input, target) pairs; returns the mean loss.
    pub fn train_batch(&mut self, inputs: &[Array1<f64>], targets: &[Array1<f64>]) -> f64 {
        let batch_size = inputs.len();
        let mut total_loss = 0.0;

        // Accumulate gradients over the whole batch before updating.
        let mut weight_gradients: Vec<Array2<f64>> = self
            .weights
            .iter()
            .map(|w| Array2::zeros(w.raw_dim()))
            .collect();
        let mut bias_gradients: Vec<Array1<f64>> = self
            .biases
            .iter()
            .map(|b| Array1::zeros(b.raw_dim()))
            .collect();

        for (input, target) in inputs.iter().zip(targets.iter()) {
            let (loss, w_grads, b_grads) = self.backward(input, target);
            total_loss += loss;

            for (wg_acc, wg) in weight_gradients.iter_mut().zip(w_grads.iter()) {
                *wg_acc = &*wg_acc + wg;
            }
            for (bg_acc, bg) in bias_gradients.iter_mut().zip(b_grads.iter()) {
                *bg_acc = &*bg_acc + bg;
            }
        }

        // Apply the averaged gradient step.
        let lr = self.learning_rate / batch_size as f64;
        for (weight, gradient) in self.weights.iter_mut().zip(weight_gradients.iter()) {
            *weight = &*weight - &(gradient * lr);
        }
        for (bias, gradient) in self.biases.iter_mut().zip(bias_gradients.iter()) {
            *bias = &*bias - &(gradient * lr);
        }

        total_loss / batch_size as f64
    }

    /// Backward pass: returns (cross-entropy loss, weight grads, bias grads).
    fn backward(
        &self,
        input: &Array1<f64>,
        target: &Array1<f64>,
    ) -> (f64, Vec<Array2<f64>>, Vec<Array1<f64>>) {
        // Forward pass, caching pre-activations (z) and activations per layer.
        let mut activations = vec![input.clone()];
        let mut z_values = Vec::new();
        let last_layer = self.weights.len() - 1;

        for (layer, (weight, bias)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            let z = weight.dot(activations.last().unwrap()) + bias;
            z_values.push(z.clone());

            let mut activation = z;
            if layer == last_layer {
                // Numerically stable softmax.
                let max_val = activation.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                activation.mapv_inplace(|x| (x - max_val).exp());
                let sum = activation.sum();
                activation.mapv_inplace(|x| x / sum);
            } else {
                activation.mapv_inplace(|x| x.max(0.0));
            }
            activations.push(activation);
        }

        // Cross-entropy loss against the one-hot target.
        let output = activations.last().unwrap();
        let loss = -target
            .iter()
            .zip(output.iter())
            .map(|(&t, &o)| if t > 0.0 { t * o.ln() } else { 0.0 })
            .sum::<f64>();

        let mut weight_gradients = Vec::with_capacity(self.weights.len());
        let mut bias_gradients = Vec::with_capacity(self.biases.len());

        // For softmax + cross-entropy, the output-layer delta is (output - target).
        let mut delta = output - target;

        for i in (0..self.weights.len()).rev() {
            // Outer product: delta (column) times activation (row).
            let weight_grad = delta
                .view()
                .insert_axis(Axis(1))
                .dot(&activations[i].view().insert_axis(Axis(0)));
            weight_gradients.insert(0, weight_grad);

            bias_gradients.insert(0, delta.clone());

            if i > 0 {
                // Propagate the error backwards through the layer weights...
                delta = self.weights[i].t().dot(&delta);

                // ...and apply the ReLU derivative (zero where z <= 0).
                for (j, &z) in z_values[i - 1].iter().enumerate() {
                    if z <= 0.0 {
                        delta[j] = 0.0;
                    }
                }
            }
        }

        (loss, weight_gradients, bias_gradients)
    }

    /// Classify a syndrome: returns (predicted class index, confidence).
    pub fn predict(&self, syndrome: &Array1<f64>) -> (usize, f64) {
        let output = self.forward(syndrome);
        let max_idx = output
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        let confidence = output[max_idx];
        (max_idx, confidence)
    }
}

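// A toy training sketch (values are illustrative, not from the module):
// repeated batches on one labeled syndrome should drive the loss down.
#[allow(dead_code)]
fn example_classifier_training() {
    let mut nn = SyndromeClassificationNetwork::new(4, vec![8], 2, 0.05);
    let inputs = vec![Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0])];
    let targets = vec![Array1::from_vec(vec![0.0, 1.0])]; // one-hot class 1
    let first_loss = nn.train_batch(&inputs, &targets);
    let mut last_loss = first_loss;
    for _ in 0..50 {
        last_loss = nn.train_batch(&inputs, &targets);
    }
    // With a modest learning rate the loss typically shrinks.
    debug_assert!(last_loss <= first_loss);
}
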
/// Tabular Q-learning agent that chooses correction actions.
#[derive(Debug, Clone)]
pub struct ErrorCorrectionAgent {
    /// Q-values keyed by state string, one value per action.
    q_table: HashMap<String, Array1<f64>>,
    /// Q-learning step size.
    learning_rate: f64,
    /// Discount factor (gamma) for future rewards.
    discount_factor: f64,
    /// Exploration probability for epsilon-greedy selection.
    epsilon: f64,
    /// Number of available correction actions.
    action_space_size: usize,
    /// Total Q-updates performed so far.
    training_steps: usize,
    /// Recent per-episode rewards.
    episode_rewards: VecDeque<f64>,
}

impl ErrorCorrectionAgent {
    /// Create an agent with an empty Q-table.
    pub fn new(
        action_space_size: usize,
        learning_rate: f64,
        discount_factor: f64,
        epsilon: f64,
    ) -> Self {
        Self {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor,
            epsilon,
            action_space_size,
            training_steps: 0,
            episode_rewards: VecDeque::with_capacity(1000),
        }
    }

    /// Epsilon-greedy selection: explore with probability `epsilon`,
    /// otherwise pick the action with the highest Q-value.
    pub fn select_action(&mut self, state: &str) -> usize {
        if fastrand::f64() < self.epsilon {
            fastrand::usize(0..self.action_space_size)
        } else {
            let q_values = self
                .q_table
                .entry(state.to_string())
                .or_insert_with(|| Array1::zeros(self.action_space_size));

            q_values
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0
        }
    }

    /// Temporal-difference Q-update:
    /// Q(s, a) += lr * (reward + gamma * max Q(s', a') - Q(s, a)).
    pub fn update_q_value(
        &mut self,
        state: &str,
        action: usize,
        reward: f64,
        next_state: &str,
        done: bool,
    ) {
        let current_q = self
            .q_table
            .entry(state.to_string())
            .or_insert_with(|| Array1::zeros(self.action_space_size))
            .clone();

        let next_q_max = if done {
            0.0
        } else {
            let next_q_values = self
                .q_table
                .entry(next_state.to_string())
                .or_insert_with(|| Array1::zeros(self.action_space_size));
            next_q_values
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        };

        let td_target = self.discount_factor.mul_add(next_q_max, reward);
        let td_error = td_target - current_q[action];

        let q_values = self.q_table.get_mut(state).unwrap();
        q_values[action] += self.learning_rate * td_error;

        self.training_steps += 1;

        // Decay exploration every 1000 steps, floored at 1%.
        if self.training_steps % 1000 == 0 {
            self.epsilon = (self.epsilon * 0.995).max(0.01);
        }
    }

    /// Reward: 10 per error removed, minus the correction cost,
    /// plus a bonus of 5 when all errors are cleared.
    pub fn calculate_reward(
        &self,
        errors_before: usize,
        errors_after: usize,
        correction_cost: f64,
    ) -> f64 {
        let error_reduction = errors_before as f64 - errors_after as f64;
        let reward = error_reduction.mul_add(10.0, -correction_cost);

        if errors_after == 0 {
            reward + 5.0
        } else {
            reward
        }
    }
}

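// A miniature Q-learning loop (hypothetical state strings and reward values):
// select an action, observe the outcome, and feed the TD update back in.
#[allow(dead_code)]
fn example_agent_step() {
    let mut agent = ErrorCorrectionAgent::new(4, 0.1, 0.99, 0.1);
    let action = agent.select_action("0110");
    // Suppose the correction removed both flagged errors at unit cost.
    let reward = agent.calculate_reward(2, 0, 1.0);
    agent.update_q_value("0110", action, reward, "0000", true);
}
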
/// Adaptive ML error-correction engine: a syndrome classifier plus a
/// reinforcement-learning agent, retrained online from correction outcomes.
pub struct AdaptiveMLErrorCorrection {
    /// Configuration parameters.
    config: AdaptiveMLConfig,
    /// Neural syndrome classifier.
    classifier: SyndromeClassificationNetwork,
    /// Q-learning agent for choosing correction actions.
    rl_agent: ErrorCorrectionAgent,
    /// Converts raw syndromes into feature vectors.
    feature_extractor: FeatureExtractor,
    /// Shared buffer of recent training examples.
    training_history: Arc<Mutex<VecDeque<TrainingExample>>>,
    /// Accumulated correction statistics.
    metrics: CorrectionMetrics,
    /// Interface to the underlying circuit backend.
    circuit_interface: CircuitInterface,
    /// Corrections performed since the last retraining pass.
    update_counter: usize,
}

/// One recorded correction, used for retraining.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Extracted syndrome features.
    pub syndrome: Array1<f64>,
    /// Error type the classifier predicted.
    pub error_type: ErrorType,
    /// Action the RL agent chose.
    pub action: usize,
    /// Reward received for the correction.
    pub reward: f64,
    /// Seconds elapsed when the example was recorded.
    pub timestamp: f64,
}

/// Converts raw syndrome bits into feature vectors.
#[derive(Debug, Clone)]
pub struct FeatureExtractor {
    /// Selected extraction method.
    method: FeatureExtractionMethod,
    /// Learned PCA projection, if available.
    pca_components: Option<Array2<f64>>,
    /// Learned autoencoder, if available.
    autoencoder: Option<SyndromeClassificationNetwork>,
}

impl FeatureExtractor {
    /// Create an extractor; PCA and autoencoder models start unset.
    pub const fn new(method: FeatureExtractionMethod) -> Self {
        Self {
            method,
            pca_components: None,
            autoencoder: None,
        }
    }

    /// Extract a feature vector from raw syndrome bits.
    pub fn extract_features(&self, syndrome: &[bool]) -> Array1<f64> {
        match self.method {
            FeatureExtractionMethod::RawSyndrome => {
                let mut features: Vec<f64> = syndrome
                    .iter()
                    .map(|&b| if b { 1.0 } else { 0.0 })
                    .collect();
                // Pad to a minimum length of 4 features.
                while features.len() < 4 {
                    features.push(0.0);
                }
                Array1::from_vec(features)
            }
            FeatureExtractionMethod::FourierTransform => self.fft_features(syndrome),
            FeatureExtractionMethod::PCA => self.pca_features(syndrome),
            FeatureExtractionMethod::Autoencoder => self.autoencoder_features(syndrome),
            FeatureExtractionMethod::TemporalConvolution => self.temporal_conv_features(syndrome),
        }
    }

    /// Naive O(n^2) discrete Fourier transform over the syndrome signal,
    /// keeping real and imaginary parts of up to 8 frequency components.
    fn fft_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut signal: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();

        while signal.len() < 4 {
            signal.push(0.0);
        }

        let mut features = Vec::new();
        let n = signal.len();

        for k in 0..n.min(8) {
            let mut real_part = 0.0;
            let mut imag_part = 0.0;

            for (i, &x) in signal.iter().enumerate() {
                let angle = -2.0 * std::f64::consts::PI * k as f64 * i as f64 / n as f64;
                real_part += x * angle.cos();
                imag_part += x * angle.sin();
            }

            features.push(real_part);
            features.push(imag_part);
        }

        Array1::from_vec(features)
    }

    /// Project onto PCA components when available; otherwise return raw bits.
    fn pca_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut features: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        while features.len() < 4 {
            features.push(0.0);
        }
        let raw_features = Array1::from_vec(features);

        if let Some(ref components) = self.pca_components {
            components.dot(&raw_features)
        } else {
            raw_features
        }
    }

    /// Encode through the autoencoder when available; otherwise return raw bits.
    fn autoencoder_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut features: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        while features.len() < 4 {
            features.push(0.0);
        }
        let raw_features = Array1::from_vec(features);

        if let Some(ref encoder) = self.autoencoder {
            encoder.forward(&raw_features)
        } else {
            raw_features
        }
    }

    /// Slide a fixed size-3 weighted kernel over the syndrome signal.
    fn temporal_conv_features(&self, syndrome: &[bool]) -> Array1<f64> {
        let mut signal: Vec<f64> = syndrome
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();

        while signal.len() < 4 {
            signal.push(0.0);
        }

        let kernel_size = 3;
        let mut features = Vec::new();

        for i in 0..signal.len().saturating_sub(kernel_size - 1) {
            let mut conv_sum = 0.0;
            for j in 0..kernel_size {
                conv_sum += signal[i + j] * (j as f64 + 1.0) / kernel_size as f64;
            }
            features.push(conv_sum);
        }

        // Fall back to the raw signal if the input was shorter than the kernel.
        if features.is_empty() {
            features = signal;
        }

        Array1::from_vec(features)
    }
}

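// A quick comparison sketch (illustrative only): the same syndrome through
// two extractors; the expected lengths follow the implementations above.
#[allow(dead_code)]
fn example_feature_extraction() {
    let syndrome = vec![true, false, true, false];
    let raw = FeatureExtractor::new(FeatureExtractionMethod::RawSyndrome)
        .extract_features(&syndrome);
    assert_eq!(raw.len(), 4);
    let fourier = FeatureExtractor::new(FeatureExtractionMethod::FourierTransform)
        .extract_features(&syndrome);
    assert_eq!(fourier.len(), 8); // real + imaginary parts of 4 frequencies
}
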
/// Running statistics about correction quality.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorrectionMetrics {
    /// Total corrections attempted.
    pub total_corrections: usize,
    /// Corrections that reduced the error count.
    pub successful_corrections: usize,
    /// Corrections that increased the error count.
    pub false_positives: usize,
    /// Errors that went uncorrected.
    pub false_negatives: usize,
    /// Running average of classifier confidence.
    pub average_confidence: f64,
    /// Training losses recorded during retraining.
    pub learning_curve: Vec<f64>,
    /// Recent RL rewards (bounded at 1000 entries).
    pub reward_history: Vec<f64>,
    /// Mean correction latency in milliseconds.
    pub avg_correction_time_ms: f64,
}

impl CorrectionMetrics {
    /// Fraction of attempted corrections that succeeded (1.0 when none attempted).
    pub fn accuracy(&self) -> f64 {
        if self.total_corrections == 0 {
            return 1.0;
        }
        self.successful_corrections as f64 / self.total_corrections as f64
    }

    /// True positives over all predicted positives.
    pub fn precision(&self) -> f64 {
        let true_positives = self.successful_corrections;
        let predicted_positives = true_positives + self.false_positives;

        if predicted_positives == 0 {
            return 1.0;
        }
        true_positives as f64 / predicted_positives as f64
    }

    /// True positives over all actual positives.
    pub fn recall(&self) -> f64 {
        let true_positives = self.successful_corrections;
        let actual_positives = true_positives + self.false_negatives;

        if actual_positives == 0 {
            return 1.0;
        }
        true_positives as f64 / actual_positives as f64
    }

    /// Harmonic mean of precision and recall.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();

        if precision + recall == 0.0 {
            return 0.0;
        }
        2.0 * precision * recall / (precision + recall)
    }
}

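// Worked example with hypothetical counts: 90 successes out of 100 with 5
// false positives and 5 false negatives gives precision = recall = 90/95,
// so F1 (their harmonic mean) is also 90/95.
#[allow(dead_code)]
fn example_metrics() {
    let metrics = CorrectionMetrics {
        total_corrections: 100,
        successful_corrections: 90,
        false_positives: 5,
        false_negatives: 5,
        ..Default::default()
    };
    assert!((metrics.accuracy() - 0.9).abs() < 1e-12);
    assert!((metrics.f1_score() - 90.0 / 95.0).abs() < 1e-12);
}
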
impl AdaptiveMLErrorCorrection {
    /// Create a new adaptive error-correction engine from a configuration.
    pub fn new(config: AdaptiveMLConfig) -> Result<Self> {
        let circuit_interface = CircuitInterface::new(Default::default())?;

        let feature_extractor = FeatureExtractor::new(config.feature_extraction);

        // Probe the extractor with a dummy syndrome to size the network input.
        let test_syndrome = vec![false, false, false, false];
        let test_features = feature_extractor.extract_features(&test_syndrome);
        let input_size = test_features.len();

        let hidden_sizes = vec![input_size * 2, input_size];
        let output_size = 4; // Identity, BitFlip, PhaseFlip, BitPhaseFlip
        let classifier = SyndromeClassificationNetwork::new(
            input_size,
            hidden_sizes,
            output_size,
            config.learning_rate,
        );

        let action_space_size = 8; // RL actions; unhandled ones use the default correction
        let rl_agent = ErrorCorrectionAgent::new(
            action_space_size,
            config.learning_rate,
            0.99, // discount factor
            0.1,  // initial exploration rate
        );

        let training_history =
            Arc::new(Mutex::new(VecDeque::with_capacity(config.max_history_size)));

        Ok(Self {
            config,
            classifier,
            rl_agent,
            feature_extractor,
            training_history,
            metrics: CorrectionMetrics::default(),
            circuit_interface,
            update_counter: 0,
        })
    }

    /// Run one adaptive correction step: classify the syndrome, pick an
    /// action, apply a correction, and learn from the outcome.
    pub fn correct_errors_adaptive(
        &mut self,
        state: &mut Array1<Complex64>,
        syndrome: &[bool],
    ) -> Result<AdaptiveCorrectionResult> {
        let start_time = std::time::Instant::now();

        // Step 1: extract features from the syndrome.
        let features = self.feature_extractor.extract_features(syndrome);

        // Step 2: classify the error type.
        let (predicted_error_class, confidence) = self.classifier.predict(&features);
        let predicted_error_type = self.class_to_error_type(predicted_error_class);

        // Step 3: let the RL agent choose a correction action.
        let state_repr = self.syndrome_to_string(syndrome);
        let action = self.rl_agent.select_action(&state_repr);

        let errors_before = self.count_errors(state, syndrome);

        // Step 4: apply the ML correction only when confidence is high enough;
        // otherwise fall back to the classical syndrome-driven correction.
        let correction_applied = if confidence >= self.config.confidence_threshold {
            self.apply_ml_correction(state, predicted_error_type, action)?;
            true
        } else {
            self.apply_classical_correction(state, syndrome)?;
            false
        };

        let errors_after = self.count_errors(state, syndrome);

        // Step 5: compute the reward and update the agent.
        let reward = self
            .rl_agent
            .calculate_reward(errors_before, errors_after, 1.0);

        let next_state_repr = self.state_to_string(state);
        self.rl_agent.update_q_value(
            &state_repr,
            action,
            reward,
            &next_state_repr,
            errors_after == 0,
        );

        // Step 6: record the example for later retraining.
        if self.config.real_time_learning {
            let training_example = TrainingExample {
                syndrome: features,
                error_type: predicted_error_type,
                action,
                reward,
                timestamp: start_time.elapsed().as_secs_f64(),
            };

            {
                let mut history = self.training_history.lock().unwrap();
                history.push_back(training_example);
                if history.len() > self.config.max_history_size {
                    history.pop_front();
                }
            }
        }

        self.update_metrics(errors_before, errors_after, confidence, reward);

        // Periodically retrain the classifier on the accumulated history.
        self.update_counter += 1;
        if self.update_counter % self.config.update_frequency == 0 {
            self.retrain_models()?;
        }

        let processing_time = start_time.elapsed().as_secs_f64() * 1000.0;

        Ok(AdaptiveCorrectionResult {
            predicted_error_type,
            confidence,
            correction_applied,
            errors_corrected: errors_before.saturating_sub(errors_after),
            reward,
            processing_time_ms: processing_time,
            rl_action: action,
        })
    }

    /// Dispatch to a correction strategy based on the RL agent's action.
    fn apply_ml_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
        action: usize,
    ) -> Result<()> {
        match action {
            0 => {
                // Single-qubit correction on qubit 0.
                self.apply_single_qubit_correction(state, error_type, 0)?;
            }
            1 => {
                // Two-qubit correction on qubits 0 and 1.
                self.apply_two_qubit_correction(state, error_type, 0, 1)?;
            }
            2 => {
                // Syndrome-guided correction.
                self.apply_syndrome_based_correction(state, error_type)?;
            }
            3 => {
                // Probabilistic correction across all qubits.
                self.apply_probabilistic_correction(state, error_type)?;
            }
            _ => {
                // Default: single-qubit correction on qubit 0.
                self.apply_single_qubit_correction(state, error_type, 0)?;
            }
        }
        Ok(())
    }

    /// Apply a Pauli correction to one qubit of the state vector.
    fn apply_single_qubit_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
        qubit: usize,
    ) -> Result<()> {
        let n_qubits = (state.len() as f64).log2().ceil() as usize;
        if qubit >= n_qubits {
            return Ok(());
        }

        match error_type {
            ErrorType::BitFlip => {
                // Pauli-X: swap amplitudes of basis states differing in this qubit.
                for i in 0..state.len() {
                    if (i >> qubit) & 1 == 0 {
                        let partner = i | (1 << qubit);
                        if partner < state.len() {
                            state.swap(i, partner);
                        }
                    }
                }
            }
            ErrorType::PhaseFlip => {
                // Pauli-Z: negate amplitudes where this qubit is |1>.
                for i in 0..state.len() {
                    if (i >> qubit) & 1 == 1 {
                        state[i] *= -1.0;
                    }
                }
            }
            ErrorType::BitPhaseFlip => {
                // Pauli-Y up to a global phase: apply Z, then X.
                self.apply_single_qubit_correction(state, ErrorType::PhaseFlip, qubit)?;
                self.apply_single_qubit_correction(state, ErrorType::BitFlip, qubit)?;
            }
            ErrorType::Identity => {
                // No correction needed.
            }
        }

        Ok(())
    }

    /// Apply the same single-qubit correction to two qubits.
    fn apply_two_qubit_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
        qubit1: usize,
        qubit2: usize,
    ) -> Result<()> {
        self.apply_single_qubit_correction(state, error_type, qubit1)?;
        self.apply_single_qubit_correction(state, error_type, qubit2)?;
        Ok(())
    }

    /// Simplified syndrome-guided correction: picks a random target qubit.
    fn apply_syndrome_based_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
    ) -> Result<()> {
        let n_qubits = (state.len() as f64).log2().ceil() as usize;
        let target_qubit = fastrand::usize(0..n_qubits);
        self.apply_single_qubit_correction(state, error_type, target_qubit)?;
        Ok(())
    }

    /// Correct each qubit independently with an error-type-dependent probability.
    fn apply_probabilistic_correction(
        &self,
        state: &mut Array1<Complex64>,
        error_type: ErrorType,
    ) -> Result<()> {
        let n_qubits = (state.len() as f64).log2().ceil() as usize;

        for qubit in 0..n_qubits {
            let prob = match error_type {
                ErrorType::BitFlip => 0.3,
                ErrorType::PhaseFlip => 0.2,
                ErrorType::BitPhaseFlip => 0.1,
                ErrorType::Identity => 0.0,
            };

            if fastrand::f64() < prob {
                self.apply_single_qubit_correction(state, error_type, qubit)?;
            }
        }

        Ok(())
    }

    /// Fallback: apply a bit-flip correction wherever the syndrome flags an error.
    fn apply_classical_correction(
        &self,
        state: &mut Array1<Complex64>,
        syndrome: &[bool],
    ) -> Result<()> {
        for (i, &has_error) in syndrome.iter().enumerate() {
            if has_error {
                self.apply_single_qubit_correction(state, ErrorType::BitFlip, i)?;
            }
        }
        Ok(())
    }

    /// Count errors as the number of set syndrome bits.
    fn count_errors(&self, _state: &Array1<Complex64>, syndrome: &[bool]) -> usize {
        syndrome.iter().map(|&b| usize::from(b)).sum()
    }

    /// Map a classifier output class to an error type.
    const fn class_to_error_type(&self, class: usize) -> ErrorType {
        match class {
            0 => ErrorType::Identity,
            1 => ErrorType::BitFlip,
            2 => ErrorType::PhaseFlip,
            3 => ErrorType::BitPhaseFlip,
            _ => ErrorType::Identity,
        }
    }

    /// Encode a syndrome as a bit string for Q-table lookup.
    fn syndrome_to_string(&self, syndrome: &[bool]) -> String {
        syndrome
            .iter()
            .map(|&b| if b { '1' } else { '0' })
            .collect()
    }

    /// Encode a state as its amplitude magnitudes for Q-table lookup.
    fn state_to_string(&self, state: &Array1<Complex64>) -> String {
        let amplitudes: Vec<f64> = state.iter().map(|c| c.norm()).collect();
        format!("{amplitudes:.3?}")
    }

    /// Update running statistics after a correction attempt.
    fn update_metrics(
        &mut self,
        errors_before: usize,
        errors_after: usize,
        confidence: f64,
        reward: f64,
    ) {
        self.metrics.total_corrections += 1;

        if errors_after < errors_before {
            self.metrics.successful_corrections += 1;
        } else if errors_after > errors_before {
            self.metrics.false_positives += 1;
        }

        // Incremental running average of the classifier confidence.
        self.metrics.average_confidence = self
            .metrics
            .average_confidence
            .mul_add((self.metrics.total_corrections - 1) as f64, confidence)
            / self.metrics.total_corrections as f64;

        // Keep a bounded window of recent rewards.
        self.metrics.reward_history.push(reward);
        if self.metrics.reward_history.len() > 1000 {
            self.metrics.reward_history.remove(0);
        }
    }

    /// Retrain the classifier on the accumulated training history.
    fn retrain_models(&mut self) -> Result<()> {
        let history = self.training_history.lock().unwrap();
        if history.len() < self.config.batch_size {
            return Ok(());
        }

        // Build a supervised dataset from the stored examples.
        let mut inputs = Vec::new();
        let mut targets = Vec::new();

        for example in history.iter() {
            inputs.push(example.syndrome.clone());

            // One-hot encode the error type as the training target.
            let mut target = Array1::zeros(4);
            let error_class = match example.error_type {
                ErrorType::Identity => 0,
                ErrorType::BitFlip => 1,
                ErrorType::PhaseFlip => 2,
                ErrorType::BitPhaseFlip => 3,
            };
            target[error_class] = 1.0;
            targets.push(target);
        }
        // Release the lock before the (potentially slow) training loop.
        drop(history);

        let batch_size = self.config.batch_size.min(inputs.len());
        for (input_chunk, target_chunk) in inputs.chunks(batch_size).zip(targets.chunks(batch_size))
        {
            let loss = self.classifier.train_batch(input_chunk, target_chunk);
            self.metrics.learning_curve.push(loss);
        }

        Ok(())
    }

    /// Borrow the accumulated correction metrics.
    pub const fn get_metrics(&self) -> &CorrectionMetrics {
        &self.metrics
    }

    /// Reset metrics, training history, and the retraining counter.
    pub fn reset(&mut self) {
        self.metrics = CorrectionMetrics::default();
        self.training_history.lock().unwrap().clear();
        self.update_counter = 0;
    }
}

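// End-to-end sketch (illustrative state and syndrome, mirroring the benchmark
// below): run one adaptive correction step and inspect the result.
#[allow(dead_code)]
fn example_adaptive_correction() -> Result<()> {
    let mut engine = AdaptiveMLErrorCorrection::new(AdaptiveMLConfig::default())?;
    let mut state = Array1::from_vec(vec![
        Complex64::new(1.0, 0.0),
        Complex64::new(0.0, 0.0),
        Complex64::new(0.0, 0.0),
        Complex64::new(0.0, 0.0),
    ]);
    let syndrome = vec![true, false];
    let result = engine.correct_errors_adaptive(&mut state, &syndrome)?;
    println!(
        "corrected {} error(s) with confidence {:.2}",
        result.errors_corrected, result.confidence
    );
    Ok(())
}
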
/// Outcome of a single adaptive correction step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveCorrectionResult {
    /// Error type predicted by the classifier.
    pub predicted_error_type: ErrorType,
    /// Classifier confidence in the prediction.
    pub confidence: f64,
    /// Whether the ML correction (vs. the classical fallback) was applied.
    pub correction_applied: bool,
    /// Number of errors removed by the correction.
    pub errors_corrected: usize,
    /// Reward granted to the RL agent.
    pub reward: f64,
    /// Wall-clock processing time in milliseconds.
    pub processing_time_ms: f64,
    /// Action index chosen by the RL agent.
    pub rl_action: usize,
}

/// Benchmark adaptive ML error correction across several configurations,
/// returning per-configuration wall-clock times in milliseconds.
pub fn benchmark_adaptive_ml_error_correction() -> Result<HashMap<String, f64>> {
    let mut results = HashMap::new();

    // Compare an online neural-network decoder against an RL-driven one.
    let configs = vec![
        AdaptiveMLConfig {
            model_type: MLModelType::NeuralNetwork,
            learning_strategy: LearningStrategy::Online,
            ..Default::default()
        },
        AdaptiveMLConfig {
            model_type: MLModelType::ReinforcementLearning,
            learning_strategy: LearningStrategy::Reinforcement,
            ..Default::default()
        },
    ];

    for (i, config) in configs.into_iter().enumerate() {
        let start = std::time::Instant::now();

        let mut adaptive_ec = AdaptiveMLErrorCorrection::new(config)?;

        // Run 100 correction cycles on a 2-qubit test state.
        for _ in 0..100 {
            let mut test_state = Array1::from_vec(vec![
                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
            ]);

            let syndrome = vec![true, false, true, false];
            let _result = adaptive_ec.correct_errors_adaptive(&mut test_state, &syndrome)?;
        }

        let time = start.elapsed().as_secs_f64() * 1000.0;
        results.insert(format!("config_{i}"), time);
    }

    Ok(results)
}

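// Hypothetical driver for the benchmark above: print each configuration's timing.
#[allow(dead_code)]
fn example_run_benchmark() -> Result<()> {
    for (name, ms) in benchmark_adaptive_ml_error_correction()? {
        println!("{name}: {ms:.1} ms");
    }
    Ok(())
}
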
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neural_network_creation() {
        let nn = SyndromeClassificationNetwork::new(4, vec![8, 4], 2, 0.01);
        assert_eq!(nn.input_size, 4);
        assert_eq!(nn.output_size, 2);
        // Layers 4 -> 8 -> 4 -> 2 give three weight matrices.
        assert_eq!(nn.weights.len(), 3);
    }

    #[test]
    fn test_neural_network_forward() {
        let nn = SyndromeClassificationNetwork::new(3, vec![4], 2, 0.01);
        let input = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let output = nn.forward(&input);

        assert_eq!(output.len(), 2);
        // Softmax output is a probability distribution.
        assert_abs_diff_eq!(output.sum(), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_rl_agent_creation() {
        let agent = ErrorCorrectionAgent::new(4, 0.1, 0.99, 0.1);
        assert_eq!(agent.action_space_size, 4);
        assert!(agent.q_table.is_empty());
    }

    #[test]
    fn test_rl_agent_action_selection() {
        // Epsilon of 0.0 forces greedy (deterministic) selection.
        let mut agent = ErrorCorrectionAgent::new(3, 0.1, 0.99, 0.0);
        let state = "001";

        let action = agent.select_action(state);
        assert!(action < 3);
    }

    #[test]
    fn test_feature_extraction() {
        let extractor = FeatureExtractor::new(FeatureExtractionMethod::RawSyndrome);
        let syndrome = vec![true, false, true, false];
        let features = extractor.extract_features(&syndrome);

        assert_eq!(features.len(), 4);
        assert_abs_diff_eq!(features[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(features[1], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(features[2], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(features[3], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_adaptive_ml_error_correction_creation() {
        let config = AdaptiveMLConfig::default();
        let adaptive_ec = AdaptiveMLErrorCorrection::new(config);
        assert!(adaptive_ec.is_ok());
    }

    #[test]
    fn test_error_correction_application() {
        let config = AdaptiveMLConfig::default();
        let mut adaptive_ec = AdaptiveMLErrorCorrection::new(config).unwrap();

        let mut state = Array1::from_vec(vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ]);

        let syndrome = vec![false, false];
        let result = adaptive_ec.correct_errors_adaptive(&mut state, &syndrome);
        assert!(result.is_ok());

        let correction_result = result.unwrap();
        assert!(correction_result.processing_time_ms >= 0.0);
    }

    #[test]
    fn test_metrics_calculation() {
        let mut metrics = CorrectionMetrics::default();
        metrics.total_corrections = 100;
        metrics.successful_corrections = 90;
        metrics.false_positives = 5;
        metrics.false_negatives = 5;

        assert_abs_diff_eq!(metrics.accuracy(), 0.9, epsilon = 1e-10);
        assert_abs_diff_eq!(metrics.precision(), 90.0 / 95.0, epsilon = 1e-10);
        assert_abs_diff_eq!(metrics.recall(), 90.0 / 95.0, epsilon = 1e-10);
    }

    #[test]
    fn test_different_error_types() {
        let config = AdaptiveMLConfig::default();
        let adaptive_ec = AdaptiveMLErrorCorrection::new(config).unwrap();

        assert_eq!(adaptive_ec.class_to_error_type(0), ErrorType::Identity);
        assert_eq!(adaptive_ec.class_to_error_type(1), ErrorType::BitFlip);
        assert_eq!(adaptive_ec.class_to_error_type(2), ErrorType::PhaseFlip);
        assert_eq!(adaptive_ec.class_to_error_type(3), ErrorType::BitPhaseFlip);
    }
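
    // A hedged addition: with precision equal to recall (90/95 each, the same
    // hypothetical counts as test_metrics_calculation), F1 equals that value.
    #[test]
    fn test_f1_score_harmonic_mean() {
        let mut metrics = CorrectionMetrics::default();
        metrics.total_corrections = 100;
        metrics.successful_corrections = 90;
        metrics.false_positives = 5;
        metrics.false_negatives = 5;

        assert_abs_diff_eq!(metrics.f1_score(), 90.0 / 95.0, epsilon = 1e-10);
    }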
}