quantrs2_core/error_correction/
concatenated.rs1use super::pauli::{Pauli, PauliString};
4use super::stabilizer::StabilizerCode;
5use crate::error::{QuantRS2Error, QuantRS2Result};
6use scirs2_core::Complex64;
7
8#[derive(Debug, Clone)]
10pub struct ConcatenatedCode {
11 pub inner_code: StabilizerCode,
13 pub outer_code: StabilizerCode,
15}
16
17impl ConcatenatedCode {
18 pub const fn new(inner_code: StabilizerCode, outer_code: StabilizerCode) -> Self {
20 Self {
21 inner_code,
22 outer_code,
23 }
24 }
25
26 pub const fn total_qubits(&self) -> usize {
28 self.inner_code.n * self.outer_code.n
29 }
30
31 pub const fn logical_qubits(&self) -> usize {
33 self.inner_code.k * self.outer_code.k
34 }
35
36 pub const fn distance(&self) -> usize {
38 self.inner_code.d * self.outer_code.d
39 }
40
41 pub fn encode(&self, logical_state: &[Complex64]) -> QuantRS2Result<Vec<Complex64>> {
43 if logical_state.len() != 1 << self.logical_qubits() {
44 return Err(QuantRS2Error::InvalidInput(
45 "Logical state dimension mismatch".to_string(),
46 ));
47 }
48
49 let outer_encoded = self.encode_with_code(logical_state, &self.outer_code)?;
51
52 let mut final_encoded = vec![Complex64::new(0.0, 0.0); 1 << self.total_qubits()];
54
55 for (i, &litude) in outer_encoded.iter().enumerate() {
58 if amplitude.norm() > 1e-10 {
59 final_encoded[i * (1 << self.inner_code.n)] = amplitude;
60 }
61 }
62
63 Ok(final_encoded)
64 }
65
66 pub fn correct_error(
68 &self,
69 encoded_state: &[Complex64],
70 error: &PauliString,
71 ) -> QuantRS2Result<Vec<Complex64>> {
72 if error.paulis.len() != self.total_qubits() {
73 return Err(QuantRS2Error::InvalidInput(
74 "Error must act on all physical qubits".to_string(),
75 ));
76 }
77
78 let mut corrected = encoded_state.to_vec();
81
82 for (i, &pauli) in error.paulis.iter().enumerate() {
84 if pauli != Pauli::I && i < corrected.len() {
85 corrected[i] *= -1.0;
87 }
88 }
89
90 Ok(corrected)
91 }
92
93 fn encode_with_code(
95 &self,
96 state: &[Complex64],
97 code: &StabilizerCode,
98 ) -> QuantRS2Result<Vec<Complex64>> {
99 let mut encoded = vec![Complex64::new(0.0, 0.0); 1 << code.n];
101
102 for (i, &litude) in state.iter().enumerate() {
103 if i < encoded.len() {
104 encoded[i * (1 << (code.n - code.k))] = amplitude;
105 }
106 }
107
108 Ok(encoded)
109 }
110}