// quantrs2_sim/concatenated_error_correction.rs

//! Concatenated Quantum Error Correction with Hierarchical Decoding
//!
//! This module implements concatenated quantum error correction codes, which provide
//! enhanced error protection through multiple levels of encoding. Concatenated codes
//! apply error correction codes recursively: the physical qubits of an outer code are
//! themselves logical qubits encoded with an inner code.
//!
//! This implementation features:
//! - Multiple concatenation levels for exponential error reduction
//! - Hierarchical decoding with per-level error tracking
//! - Adaptive thresholds based on error rates
//! - Resource-efficient syndrome processing
//! - Support for heterogeneous inner and outer codes

15use scirs2_core::ndarray::Array1;
16use scirs2_core::parallel_ops::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
17use scirs2_core::Complex64;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, VecDeque};
20
21use crate::circuit_interfaces::{
22    CircuitInterface, InterfaceCircuit, InterfaceGate, InterfaceGateType,
23};
24use crate::error::Result;
25// Remove the invalid imports - we'll define our own implementations
26
/// Metadata for one tier of a concatenated code.
///
/// Tiers are numbered from the innermost (`level == 0`) outwards.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ConcatenationLevel {
    /// Position of this tier: 0 is the innermost, larger values lie further out.
    pub level: usize,
    /// Code distance contributed by this tier.
    pub distance: usize,
    /// Physical qubits used per logical qubit at this tier.
    ///
    /// Despite the name, this is an expansion factor (n / k), not the
    /// conventional code rate k / n.
    pub code_rate: usize,
}
37
38/// Concatenated code specification
39#[derive(Debug)]
40pub struct ConcatenatedCodeConfig {
41    /// Concatenation levels from inner to outer
42    pub levels: Vec<ConcatenationLevel>,
43    /// Error correction codes used at each level
44    pub codes_per_level: Vec<Box<dyn ErrorCorrectionCode>>,
45    /// Decoding method
46    pub decoding_method: HierarchicalDecodingMethod,
47    /// Error rate threshold for adaptive decoding
48    pub error_threshold: f64,
49    /// Enable parallel decoding at each level
50    pub parallel_decoding: bool,
51    /// Maximum decoding iterations
52    pub max_decoding_iterations: usize,
53}
54
55/// Error correction code trait for concatenation
56pub trait ErrorCorrectionCode: Send + Sync + std::fmt::Debug {
57    /// Get code parameters
58    fn get_parameters(&self) -> CodeParameters;
59
60    /// Encode logical qubits
61    fn encode(&self, logical_state: &Array1<Complex64>) -> Result<Array1<Complex64>>;
62
63    /// Decode with syndrome extraction
64    fn decode(&self, encoded_state: &Array1<Complex64>) -> Result<DecodingResult>;
65
66    /// Generate syndrome extraction circuit
67    fn syndrome_circuit(&self, num_qubits: usize) -> Result<InterfaceCircuit>;
68
69    /// Apply error correction based on syndrome
70    fn correct_errors(&self, state: &mut Array1<Complex64>, syndrome: &[bool]) -> Result<()>;
71}
72
/// Parameters characterising a single (non-concatenated) code.
#[derive(Clone, Copy, Debug)]
pub struct CodeParameters {
    /// Logical qubits protected by one code block.
    pub n_logical: usize,
    /// Physical qubits in one code block.
    pub n_physical: usize,
    /// Minimum distance of the code.
    pub distance: usize,
    /// Maximum number of errors the code can correct (conventionally
    /// `(distance - 1) / 2`).
    pub t: usize,
}
85
/// Strategy used to decode the levels of a concatenated code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HierarchicalDecodingMethod {
    /// One pass over the levels, each level decoding the previous one's output.
    Sequential,
    /// Parallel decoding strategy.
    Parallel,
    /// Extra decoding attempts when the observed error rate is high.
    Adaptive,
    /// Iterative passes with smoothed per-level confidences ("beliefs").
    BeliefPropagation,
}
98
99/// Decoding result with error information
100#[derive(Debug, Clone)]
101pub struct DecodingResult {
102    /// Corrected state
103    pub corrected_state: Array1<Complex64>,
104    /// Syndrome measurements
105    pub syndrome: Vec<bool>,
106    /// Error pattern detected
107    pub error_pattern: Vec<ErrorType>,
108    /// Decoding confidence
109    pub confidence: f64,
110    /// Number of errors corrected
111    pub errors_corrected: usize,
112    /// Decoding successful
113    pub success: bool,
114}
115
116/// Types of quantum errors
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
118pub enum ErrorType {
119    /// No error
120    Identity,
121    /// Bit flip (X error)
122    BitFlip,
123    /// Phase flip (Z error)
124    PhaseFlip,
125    /// Bit-phase flip (Y error)
126    BitPhaseFlip,
127}
128
129/// Concatenated error correction result
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ConcatenatedCorrectionResult {
132    /// Final corrected state
133    pub final_state: Array1<Complex64>,
134    /// Results from each concatenation level
135    pub level_results: Vec<LevelDecodingResult>,
136    /// Overall decoding statistics
137    pub stats: ConcatenationStats,
138    /// Total execution time
139    pub execution_time_ms: f64,
140    /// Success probability estimate
141    pub success_probability: f64,
142}
143
144/// Decoding result for a single level
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct LevelDecodingResult {
147    /// Concatenation level
148    pub level: usize,
149    /// Syndrome measurements
150    pub syndromes: Vec<Vec<bool>>,
151    /// Errors corrected at this level
152    pub errors_corrected: usize,
153    /// Error patterns detected
154    pub error_patterns: Vec<String>,
155    /// Decoding confidence
156    pub confidence: f64,
157    /// Processing time for this level
158    pub processing_time_ms: f64,
159}
160
161/// Concatenation statistics
162#[derive(Debug, Clone, Default, Serialize, Deserialize)]
163pub struct ConcatenationStats {
164    /// Total physical qubits used
165    pub physical_qubits: usize,
166    /// Total logical qubits encoded
167    pub logical_qubits: usize,
168    /// Overall code distance
169    pub effective_distance: usize,
170    /// Total syndrome measurements
171    pub syndrome_measurements: usize,
172    /// Total errors corrected
173    pub total_errors_corrected: usize,
174    /// Memory overhead
175    pub memory_overhead_factor: f64,
176    /// Circuit depth overhead
177    pub circuit_depth_overhead: usize,
178    /// Decoding iterations performed
179    pub decoding_iterations: usize,
180}
181
182/// Concatenated quantum error correction implementation
183pub struct ConcatenatedErrorCorrection {
184    /// Configuration
185    config: ConcatenatedCodeConfig,
186    /// Circuit interface for compilation
187    circuit_interface: CircuitInterface,
188    /// Syndrome history for adaptive decoding
189    syndrome_history: VecDeque<Vec<Vec<bool>>>,
190    /// Error rate tracking
191    error_rates: HashMap<usize, f64>,
192    /// Statistics
193    stats: ConcatenationStats,
194}
195
196impl ConcatenatedErrorCorrection {
197    /// Create new concatenated error correction instance
198    pub fn new(config: ConcatenatedCodeConfig) -> Result<Self> {
199        let circuit_interface = CircuitInterface::new(Default::default())?;
200        let syndrome_history = VecDeque::with_capacity(100);
201        let error_rates = HashMap::new();
202
203        Ok(Self {
204            config,
205            circuit_interface,
206            syndrome_history,
207            error_rates,
208            stats: ConcatenationStats::default(),
209        })
210    }
211
212    /// Encode logical state using concatenated codes
213    pub fn encode_concatenated(
214        &mut self,
215        logical_state: &Array1<Complex64>,
216    ) -> Result<Array1<Complex64>> {
217        let mut current_state = logical_state.clone();
218
219        // Apply encoding at each concatenation level
220        for (level, code) in self.config.codes_per_level.iter().enumerate() {
221            current_state = code.encode(&current_state)?;
222
223            let params = code.get_parameters();
224            self.stats.physical_qubits = params.n_physical;
225            self.stats.logical_qubits = params.n_logical;
226
227            // Update effective distance (minimum over all levels)
228            if level == 0 {
229                self.stats.effective_distance = params.distance;
230            } else {
231                self.stats.effective_distance = self.stats.effective_distance.min(params.distance);
232            }
233        }
234
235        Ok(current_state)
236    }
237
238    /// Decode using hierarchical error correction
239    pub fn decode_hierarchical(
240        &mut self,
241        encoded_state: &Array1<Complex64>,
242    ) -> Result<ConcatenatedCorrectionResult> {
243        let start_time = std::time::Instant::now();
244
245        let result = match self.config.decoding_method {
246            HierarchicalDecodingMethod::Sequential => self.decode_sequential(encoded_state)?,
247            HierarchicalDecodingMethod::Parallel => self.decode_parallel(encoded_state)?,
248            HierarchicalDecodingMethod::Adaptive => self.decode_adaptive(encoded_state)?,
249            HierarchicalDecodingMethod::BeliefPropagation => {
250                self.decode_belief_propagation(encoded_state)?
251            }
252        };
253
254        let execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
255
256        // Update syndrome history
257        let all_syndromes: Vec<Vec<bool>> = result
258            .level_results
259            .iter()
260            .flat_map(|r| r.syndromes.iter().cloned())
261            .collect();
262        self.syndrome_history.push_back(all_syndromes);
263        if self.syndrome_history.len() > 100 {
264            self.syndrome_history.pop_front();
265        }
266
267        // Calculate success probability
268        let success_probability = self.estimate_success_probability(&result);
269
270        Ok(ConcatenatedCorrectionResult {
271            final_state: result.final_state,
272            level_results: result.level_results,
273            stats: self.stats.clone(),
274            execution_time_ms,
275            success_probability,
276        })
277    }
278
279    /// Sequential decoding from outer to inner levels
280    fn decode_sequential(
281        &mut self,
282        encoded_state: &Array1<Complex64>,
283    ) -> Result<ConcatenatedCorrectionResult> {
284        let mut current_state = encoded_state.clone();
285        let mut level_results = Vec::new();
286
287        // Decode from outermost to innermost level
288        for (level, code) in self.config.codes_per_level.iter().enumerate().rev() {
289            let level_start = std::time::Instant::now();
290
291            let decoding_result = code.decode(&current_state)?;
292            current_state = decoding_result.corrected_state;
293
294            // Convert error pattern to strings for serialization
295            let error_patterns: Vec<String> = decoding_result
296                .error_pattern
297                .iter()
298                .map(|e| format!("{e:?}"))
299                .collect();
300
301            let level_result = LevelDecodingResult {
302                level,
303                syndromes: vec![decoding_result.syndrome],
304                errors_corrected: decoding_result.errors_corrected,
305                error_patterns,
306                confidence: decoding_result.confidence,
307                processing_time_ms: level_start.elapsed().as_secs_f64() * 1000.0,
308            };
309
310            level_results.push(level_result);
311            self.stats.total_errors_corrected += decoding_result.errors_corrected;
312            self.stats.decoding_iterations += 1;
313        }
314
315        // Reverse to get correct order (inner to outer)
316        level_results.reverse();
317
318        Ok(ConcatenatedCorrectionResult {
319            final_state: current_state,
320            level_results,
321            stats: self.stats.clone(),
322            execution_time_ms: 0.0,   // Will be set by caller
323            success_probability: 0.0, // Will be calculated by caller
324        })
325    }
326
327    /// Parallel decoding across levels
328    fn decode_parallel(
329        &mut self,
330        encoded_state: &Array1<Complex64>,
331    ) -> Result<ConcatenatedCorrectionResult> {
332        if !self.config.parallel_decoding {
333            return self.decode_sequential(encoded_state);
334        }
335
336        // For true parallel decoding, we need to carefully manage state sharing
337        // This is a simplified implementation
338        let num_levels = self.config.codes_per_level.len();
339        let mut level_results = Vec::with_capacity(num_levels);
340
341        // Process levels in parallel where possible
342        let results: Vec<_> = (0..num_levels)
343            .into_par_iter()
344            .map(|level| {
345                let level_start = std::time::Instant::now();
346
347                // For parallel processing, we simulate on a copy of the state
348                let mut state_copy = encoded_state.clone();
349
350                let decoding_result = self.config.codes_per_level[level]
351                    .decode(&state_copy)
352                    .unwrap_or_else(|_| DecodingResult {
353                        corrected_state: state_copy,
354                        syndrome: vec![false],
355                        error_pattern: vec![ErrorType::Identity],
356                        confidence: 0.0,
357                        errors_corrected: 0,
358                        success: false,
359                    });
360
361                let error_patterns: Vec<String> = decoding_result
362                    .error_pattern
363                    .iter()
364                    .map(|e| format!("{e:?}"))
365                    .collect();
366
367                LevelDecodingResult {
368                    level,
369                    syndromes: vec![decoding_result.syndrome],
370                    errors_corrected: decoding_result.errors_corrected,
371                    error_patterns,
372                    confidence: decoding_result.confidence,
373                    processing_time_ms: level_start.elapsed().as_secs_f64() * 1000.0,
374                }
375            })
376            .collect();
377
378        level_results.extend(results);
379
380        // For simplicity, use the final state from sequential decoding
381        let sequential_result = self.decode_sequential(encoded_state)?;
382
383        Ok(ConcatenatedCorrectionResult {
384            final_state: sequential_result.final_state,
385            level_results,
386            stats: self.stats.clone(),
387            execution_time_ms: 0.0,
388            success_probability: 0.0,
389        })
390    }
391
392    /// Adaptive decoding based on error patterns
393    fn decode_adaptive(
394        &mut self,
395        encoded_state: &Array1<Complex64>,
396    ) -> Result<ConcatenatedCorrectionResult> {
397        // Start with sequential decoding
398        let mut result = self.decode_sequential(encoded_state)?;
399
400        // Analyze error patterns to decide if additional iterations are needed
401        let error_rate = self.calculate_current_error_rate(&result.level_results);
402
403        if error_rate > self.config.error_threshold {
404            // High error rate detected - try alternative decoding strategies
405            for iteration in 1..self.config.max_decoding_iterations {
406                let alternative_result = if iteration % 2 == 1 {
407                    self.decode_parallel(encoded_state)?
408                } else {
409                    self.decode_sequential(encoded_state)?
410                };
411
412                let alt_error_rate =
413                    self.calculate_current_error_rate(&alternative_result.level_results);
414                if alt_error_rate < error_rate {
415                    result = alternative_result;
416                    break;
417                }
418
419                self.stats.decoding_iterations += 1;
420            }
421        }
422
423        Ok(result)
424    }
425
426    /// Belief propagation decoding between levels
427    fn decode_belief_propagation(
428        &mut self,
429        encoded_state: &Array1<Complex64>,
430    ) -> Result<ConcatenatedCorrectionResult> {
431        // Simplified belief propagation - in practice this would be much more complex
432        let mut current_state = encoded_state.clone();
433        let mut level_results = Vec::new();
434
435        // Initialize belief messages
436        let num_levels = self.config.codes_per_level.len();
437        let mut beliefs: Vec<f64> = vec![1.0; num_levels];
438
439        for iteration in 0..self.config.max_decoding_iterations.min(5) {
440            for (level, code) in self.config.codes_per_level.iter().enumerate() {
441                let level_start = std::time::Instant::now();
442
443                // Decode with current beliefs
444                let decoding_result = code.decode(&current_state)?;
445
446                // Update beliefs based on decoding confidence
447                beliefs[level] = beliefs[level].mul_add(0.9, decoding_result.confidence * 0.1);
448
449                current_state = decoding_result.corrected_state;
450
451                let error_patterns: Vec<String> = decoding_result
452                    .error_pattern
453                    .iter()
454                    .map(|e| format!("{e:?}"))
455                    .collect();
456
457                let level_result = LevelDecodingResult {
458                    level,
459                    syndromes: vec![decoding_result.syndrome],
460                    errors_corrected: decoding_result.errors_corrected,
461                    error_patterns,
462                    confidence: beliefs[level],
463                    processing_time_ms: level_start.elapsed().as_secs_f64() * 1000.0,
464                };
465
466                if iteration == 0 || level_results.len() <= level {
467                    level_results.push(level_result);
468                } else {
469                    level_results[level] = level_result;
470                }
471
472                self.stats.total_errors_corrected += decoding_result.errors_corrected;
473            }
474
475            // Check convergence
476            let avg_confidence: f64 = beliefs.iter().sum::<f64>() / beliefs.len() as f64;
477            if avg_confidence > 0.95 {
478                break;
479            }
480
481            self.stats.decoding_iterations += 1;
482        }
483
484        Ok(ConcatenatedCorrectionResult {
485            final_state: current_state,
486            level_results,
487            stats: self.stats.clone(),
488            execution_time_ms: 0.0,
489            success_probability: 0.0,
490        })
491    }
492
493    /// Calculate current error rate from level results
494    fn calculate_current_error_rate(&self, level_results: &[LevelDecodingResult]) -> f64 {
495        if level_results.is_empty() {
496            return 0.0;
497        }
498
499        let total_errors: usize = level_results.iter().map(|r| r.errors_corrected).sum();
500
501        let total_qubits = self.stats.physical_qubits.max(1);
502        total_errors as f64 / total_qubits as f64
503    }
504
505    /// Estimate success probability based on decoding results
506    fn estimate_success_probability(&self, result: &ConcatenatedCorrectionResult) -> f64 {
507        if result.level_results.is_empty() {
508            return 1.0;
509        }
510
511        // Product of confidences across all levels
512        let confidence_product: f64 = result.level_results.iter().map(|r| r.confidence).product();
513
514        // Adjust based on error rate
515        let error_rate = self.calculate_current_error_rate(&result.level_results);
516        let error_penalty = (-error_rate * 10.0).exp();
517
518        (confidence_product * error_penalty).min(1.0).max(0.0)
519    }
520
521    /// Get current statistics
522    #[must_use]
523    pub const fn get_stats(&self) -> &ConcatenationStats {
524        &self.stats
525    }
526
527    /// Reset statistics
528    pub fn reset_stats(&mut self) {
529        self.stats = ConcatenationStats::default();
530        self.syndrome_history.clear();
531        self.error_rates.clear();
532    }
533}
534
/// Stateless marker for the 3-qubit bit-flip (repetition) code.
///
/// `ConcatenatedBitFlipCode` wraps this type and provides the
/// [`ErrorCorrectionCode`] implementation used for concatenation.
#[derive(Clone, Debug)]
pub struct BitFlipCode;
539
540impl Default for BitFlipCode {
541    fn default() -> Self {
542        Self::new()
543    }
544}
545
546impl BitFlipCode {
547    #[must_use]
548    pub const fn new() -> Self {
549        Self
550    }
551}
552
553/// Wrapper for bit flip code
554#[derive(Debug)]
555pub struct ConcatenatedBitFlipCode {
556    inner_code: BitFlipCode,
557}
558
559impl Default for ConcatenatedBitFlipCode {
560    fn default() -> Self {
561        Self::new()
562    }
563}
564
565impl ConcatenatedBitFlipCode {
566    #[must_use]
567    pub const fn new() -> Self {
568        Self {
569            inner_code: BitFlipCode::new(),
570        }
571    }
572}
573
574impl ErrorCorrectionCode for ConcatenatedBitFlipCode {
575    fn get_parameters(&self) -> CodeParameters {
576        CodeParameters {
577            n_logical: 1,
578            n_physical: 3,
579            distance: 3,
580            t: 1,
581        }
582    }
583
584    fn encode(&self, logical_state: &Array1<Complex64>) -> Result<Array1<Complex64>> {
585        // Simulate bit flip encoding
586        let n_logical = logical_state.len();
587        let n_physical = n_logical * 3;
588
589        let mut encoded = Array1::zeros(n_physical);
590
591        // Triple each logical qubit
592        for i in 0..n_logical {
593            let amp = logical_state[i];
594            encoded[i * 3] = amp;
595            encoded[i * 3 + 1] = amp;
596            encoded[i * 3 + 2] = amp;
597        }
598
599        Ok(encoded)
600    }
601
602    fn decode(&self, encoded_state: &Array1<Complex64>) -> Result<DecodingResult> {
603        let n_physical = encoded_state.len();
604        let n_logical = n_physical / 3;
605
606        let mut corrected_state = Array1::zeros(n_logical);
607        let mut syndrome = Vec::new();
608        let mut error_pattern = Vec::new();
609        let mut errors_corrected = 0;
610
611        for i in 0..n_logical {
612            let block_start = i * 3;
613            let a0 = encoded_state[block_start];
614            let a1 = encoded_state[block_start + 1];
615            let a2 = encoded_state[block_start + 2];
616
617            // Majority vote (simplified)
618            let distances = [(a0 - a1).norm(), (a1 - a2).norm(), (a0 - a2).norm()];
619
620            let min_dist_idx = distances
621                .iter()
622                .enumerate()
623                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
624                .map(|(idx, _)| idx)
625                .unwrap_or(0);
626
627            match min_dist_idx {
628                0 => {
629                    // a0 and a1 are closest
630                    corrected_state[i] = (a0 + a1) / 2.0;
631                    if (a2 - a0).norm() > 1e-10 {
632                        syndrome.push(true);
633                        error_pattern.push(ErrorType::BitFlip);
634                        errors_corrected += 1;
635                    } else {
636                        syndrome.push(false);
637                        error_pattern.push(ErrorType::Identity);
638                    }
639                }
640                1 => {
641                    // a1 and a2 are closest
642                    corrected_state[i] = (a1 + a2) / 2.0;
643                    if (a0 - a1).norm() > 1e-10 {
644                        syndrome.push(true);
645                        error_pattern.push(ErrorType::BitFlip);
646                        errors_corrected += 1;
647                    } else {
648                        syndrome.push(false);
649                        error_pattern.push(ErrorType::Identity);
650                    }
651                }
652                2 => {
653                    // a0 and a2 are closest
654                    corrected_state[i] = (a0 + a2) / 2.0;
655                    if (a1 - a0).norm() > 1e-10 {
656                        syndrome.push(true);
657                        error_pattern.push(ErrorType::BitFlip);
658                        errors_corrected += 1;
659                    } else {
660                        syndrome.push(false);
661                        error_pattern.push(ErrorType::Identity);
662                    }
663                }
664                _ => unreachable!(),
665            }
666        }
667
668        let confidence = 1.0 - (errors_corrected as f64 / n_logical as f64);
669
670        Ok(DecodingResult {
671            corrected_state,
672            syndrome,
673            error_pattern,
674            confidence,
675            errors_corrected,
676            success: errors_corrected <= n_logical,
677        })
678    }
679
680    fn syndrome_circuit(&self, num_qubits: usize) -> Result<InterfaceCircuit> {
681        let mut circuit = InterfaceCircuit::new(num_qubits + 2, 2);
682
683        // Simple syndrome extraction for bit flip code
684        for i in (0..num_qubits).step_by(3) {
685            if i + 2 < num_qubits {
686                circuit.add_gate(InterfaceGate::new(
687                    InterfaceGateType::CNOT,
688                    vec![i, num_qubits],
689                ));
690                circuit.add_gate(InterfaceGate::new(
691                    InterfaceGateType::CNOT,
692                    vec![i + 1, num_qubits],
693                ));
694                circuit.add_gate(InterfaceGate::new(
695                    InterfaceGateType::CNOT,
696                    vec![i + 1, num_qubits + 1],
697                ));
698                circuit.add_gate(InterfaceGate::new(
699                    InterfaceGateType::CNOT,
700                    vec![i + 2, num_qubits + 1],
701                ));
702            }
703        }
704
705        Ok(circuit)
706    }
707
708    fn correct_errors(&self, state: &mut Array1<Complex64>, syndrome: &[bool]) -> Result<()> {
709        // Apply corrections based on syndrome
710        for (i, &has_error) in syndrome.iter().enumerate() {
711            if has_error && i * 3 + 2 < state.len() {
712                // Apply majority vote correction
713                let block_start = i * 3;
714                let majority =
715                    (state[block_start] + state[block_start + 1] + state[block_start + 2]) / 3.0;
716                state[block_start] = majority;
717                state[block_start + 1] = majority;
718                state[block_start + 2] = majority;
719            }
720        }
721        Ok(())
722    }
723}
724
725/// Create concatenated error correction with predefined configuration
726pub fn create_standard_concatenated_code(levels: usize) -> Result<ConcatenatedErrorCorrection> {
727    let mut concatenation_levels = Vec::new();
728    let mut codes_per_level: Vec<Box<dyn ErrorCorrectionCode>> = Vec::new();
729
730    for level in 0..levels {
731        concatenation_levels.push(ConcatenationLevel {
732            level,
733            distance: 3,
734            code_rate: 3,
735        });
736
737        codes_per_level.push(Box::new(ConcatenatedBitFlipCode::new()));
738    }
739
740    let config = ConcatenatedCodeConfig {
741        levels: concatenation_levels,
742        codes_per_level,
743        decoding_method: HierarchicalDecodingMethod::Sequential,
744        error_threshold: 0.1,
745        parallel_decoding: true,
746        max_decoding_iterations: 10,
747    };
748
749    ConcatenatedErrorCorrection::new(config)
750}
751
752/// Benchmark concatenated error correction performance
753pub fn benchmark_concatenated_error_correction() -> Result<HashMap<String, f64>> {
754    let mut results = HashMap::new();
755
756    // Test different concatenation levels
757    let levels = vec![1, 2, 3];
758
759    for &level in &levels {
760        let start = std::time::Instant::now();
761
762        let mut concatenated = create_standard_concatenated_code(level)?;
763
764        // Create test logical state
765        let logical_state = Array1::from_vec(vec![
766            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
767            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
768        ]);
769
770        // Encode
771        let encoded = concatenated.encode_concatenated(&logical_state)?;
772
773        // Simulate some errors by adding noise
774        let mut noisy_encoded = encoded.clone();
775        for i in 0..noisy_encoded.len().min(5) {
776            noisy_encoded[i] += Complex64::new(0.01 * fastrand::f64(), 0.01 * fastrand::f64());
777        }
778
779        // Decode
780        let _result = concatenated.decode_hierarchical(&noisy_encoded)?;
781
782        let time = start.elapsed().as_secs_f64() * 1000.0;
783        results.insert(format!("level_{level}"), time);
784    }
785
786    Ok(results)
787}
788
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_concatenated_code_creation() {
        let concatenated = create_standard_concatenated_code(2);
        assert!(concatenated.is_ok());
    }

    #[test]
    fn test_bit_flip_code_parameters() {
        let code = ConcatenatedBitFlipCode::new();
        let params = code.get_parameters();

        assert_eq!(params.n_logical, 1);
        assert_eq!(params.n_physical, 3);
        assert_eq!(params.distance, 3);
        assert_eq!(params.t, 1);
    }

    #[test]
    fn test_bit_flip_encoding() {
        let code = ConcatenatedBitFlipCode::new();
        let logical_state =
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);

        let encoded = code
            .encode(&logical_state)
            .expect("Encoding should succeed in test");
        assert_eq!(encoded.len(), 6); // 2 logical -> 6 physical

        // Check triplication
        assert!((encoded[0] - logical_state[0]).norm() < 1e-10);
        assert!((encoded[1] - logical_state[0]).norm() < 1e-10);
        assert!((encoded[2] - logical_state[0]).norm() < 1e-10);
    }

    #[test]
    fn test_concatenated_encoding_decoding() {
        let mut concatenated = create_standard_concatenated_code(1)
            .expect("Concatenated code creation should succeed in test");

        let logical_state = Array1::from_vec(vec![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
        ]);

        let encoded = concatenated
            .encode_concatenated(&logical_state)
            .expect("Encoding should succeed in test");
        assert!(encoded.len() >= logical_state.len());

        let result = concatenated
            .decode_hierarchical(&encoded)
            .expect("Decoding should succeed in test");
        assert!(!result.level_results.is_empty());
        assert!(result.success_probability >= 0.0);
    }

    #[test]
    fn test_syndrome_circuit_creation() {
        let code = ConcatenatedBitFlipCode::new();
        let circuit = code
            .syndrome_circuit(6)
            .expect("Syndrome circuit creation should succeed in test");

        assert_eq!(circuit.num_qubits, 8); // 6 data + 2 syndrome qubits
        assert!(!circuit.gates.is_empty());
    }

    #[test]
    fn test_decoding_methods() {
        let mut concatenated = create_standard_concatenated_code(1)
            .expect("Concatenated code creation should succeed in test");

        let logical_state = Array1::from_vec(vec![Complex64::new(1.0, 0.0)]);
        let encoded = concatenated
            .encode_concatenated(&logical_state)
            .expect("Encoding should succeed in test");

        // Test sequential decoding
        concatenated.config.decoding_method = HierarchicalDecodingMethod::Sequential;
        let seq_result = concatenated
            .decode_hierarchical(&encoded)
            .expect("Sequential decoding should succeed in test");
        assert!(!seq_result.level_results.is_empty());

        // Test adaptive decoding
        concatenated.config.decoding_method = HierarchicalDecodingMethod::Adaptive;
        let adapt_result = concatenated
            .decode_hierarchical(&encoded)
            .expect("Adaptive decoding should succeed in test");
        assert!(!adapt_result.level_results.is_empty());
    }

    #[test]
    fn test_error_rate_calculation() {
        let concatenated = create_standard_concatenated_code(1)
            .expect("Concatenated code creation should succeed in test");

        let level_results = vec![LevelDecodingResult {
            level: 0,
            syndromes: vec![vec![true, false]],
            errors_corrected: 1,
            error_patterns: vec!["BitFlip".to_string()],
            confidence: 0.9,
            processing_time_ms: 1.0,
        }];

        let error_rate = concatenated.calculate_current_error_rate(&level_results);
        assert!(error_rate > 0.0);
    }

    /// A clean two-level codeword must decode back to the logical amplitudes.
    #[test]
    fn test_sequential_round_trip_two_levels() {
        let mut concatenated = create_standard_concatenated_code(2)
            .expect("Concatenated code creation should succeed in test");

        let amp = 1.0 / 2.0_f64.sqrt();
        let logical_state =
            Array1::from_vec(vec![Complex64::new(amp, 0.0), Complex64::new(amp, 0.0)]);

        let encoded = concatenated
            .encode_concatenated(&logical_state)
            .expect("Encoding should succeed in test");
        assert_eq!(encoded.len(), logical_state.len() * 9);

        let result = concatenated
            .decode_hierarchical(&encoded)
            .expect("Decoding should succeed in test");
        assert_eq!(result.level_results.len(), 2);
        assert_eq!(result.final_state.len(), logical_state.len());
        for (got, want) in result.final_state.iter().zip(logical_state.iter()) {
            assert_abs_diff_eq!(got.re, want.re, epsilon = 1e-12);
            assert_abs_diff_eq!(got.im, want.im, epsilon = 1e-12);
        }
        assert_eq!(result.stats.total_errors_corrected, 0);
    }

    /// Belief propagation must fully decode a clean multi-level codeword.
    #[test]
    fn test_belief_propagation_round_trip() {
        let mut concatenated = create_standard_concatenated_code(2)
            .expect("Concatenated code creation should succeed in test");
        concatenated.config.decoding_method = HierarchicalDecodingMethod::BeliefPropagation;

        let logical_state =
            Array1::from_vec(vec![Complex64::new(0.6, 0.0), Complex64::new(0.0, 0.8)]);
        let encoded = concatenated
            .encode_concatenated(&logical_state)
            .expect("Encoding should succeed in test");

        let result = concatenated
            .decode_hierarchical(&encoded)
            .expect("Belief-propagation decoding should succeed in test");
        assert_eq!(result.level_results.len(), 2);
        assert_eq!(result.final_state.len(), 2);
        assert_abs_diff_eq!(result.final_state[0].re, 0.6, epsilon = 1e-12);
        assert_abs_diff_eq!(result.final_state[1].im, 0.8, epsilon = 1e-12);
    }

    /// A single disagreeing amplitude is flagged and outvoted.
    #[test]
    fn test_bit_flip_decode_flags_single_flip() {
        let code = ConcatenatedBitFlipCode::new();
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);

        let result = code
            .decode(&Array1::from_vec(vec![one, one, zero]))
            .expect("Decoding should succeed in test");

        assert_eq!(result.corrected_state.len(), 1);
        assert_abs_diff_eq!(result.corrected_state[0].re, 1.0, epsilon = 1e-12);
        assert_eq!(result.syndrome, vec![true]);
        assert_eq!(result.error_pattern, vec![ErrorType::BitFlip]);
        assert_eq!(result.errors_corrected, 1);
    }
}