quantrs2_sim/
error_mitigation.rs

1//! Error mitigation strategies for quantum computing
2//!
3//! This module provides state-of-the-art error mitigation techniques to improve
4//! the accuracy of noisy quantum simulations and real quantum hardware results.
5//!
6//! # Supported Techniques
7//!
//! - **Zero-Noise Extrapolation (ZNE)**: Extrapolate results to the zero-noise limit
//! - **Probabilistic Error Cancellation (PEC)**: Use quasi-probability decompositions to
//!   cancel errors (planned; not yet implemented in this module)
//! - **Clifford Data Regression (CDR)**: Learn a noise correction from classically simulable
//!   near-Clifford training circuits (planned; not yet implemented in this module)
//! - **Measurement Error Mitigation**: Correct readout errors using a calibration matrix
//! - **Symmetry Verification**: Verify conservation laws and post-select results
13//!
14//! # Example
15//!
16//! ```ignore
17//! use quantrs2_sim::error_mitigation::{ZeroNoiseExtrapolation, ExtrapolationMethod};
18//!
19//! let zne = ZeroNoiseExtrapolation::new(ExtrapolationMethod::Richardson);
20//! let mitigated_result = zne.apply(&noisy_results, &noise_scales)?;
21//! ```
22
23use crate::error::{Result, SimulatorError};
24use scirs2_core::ndarray::{Array1, Array2};
25use scirs2_core::random::prelude::*;
26use scirs2_core::Complex64;
27use std::collections::HashMap;
28
29// ============================================================================
30// Zero-Noise Extrapolation (ZNE)
31// ============================================================================
32
33/// Extrapolation methods for ZNE
34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub enum ExtrapolationMethod {
36    /// Linear extrapolation (2 points)
37    Linear,
38    /// Richardson extrapolation (3+ points)
39    Richardson,
40    /// Polynomial fit
41    Polynomial { degree: usize },
42    /// Exponential fit
43    Exponential,
44}
45
46/// Zero-Noise Extrapolation for error mitigation
47///
48/// ZNE runs the same circuit at different noise levels and extrapolates
49/// the result to the zero-noise limit.
50#[derive(Debug, Clone)]
51pub struct ZeroNoiseExtrapolation {
52    /// Extrapolation method
53    method: ExtrapolationMethod,
54    /// Noise scale factors to use (e.g., [1.0, 1.5, 2.0])
55    scale_factors: Vec<f64>,
56}
57
58impl ZeroNoiseExtrapolation {
59    /// Create a new ZNE instance with default scale factors
60    pub fn new(method: ExtrapolationMethod) -> Self {
61        Self {
62            method,
63            scale_factors: vec![1.0, 1.5, 2.0, 2.5, 3.0],
64        }
65    }
66
67    /// Create ZNE with custom scale factors
68    pub fn with_scale_factors(
69        method: ExtrapolationMethod,
70        scale_factors: Vec<f64>,
71    ) -> Result<Self> {
72        if scale_factors.is_empty() {
73            return Err(SimulatorError::InvalidInput(
74                "Scale factors cannot be empty".to_string(),
75            ));
76        }
77
78        // Verify scale factors are in ascending order and start with 1.0
79        if scale_factors[0] != 1.0 {
80            return Err(SimulatorError::InvalidInput(
81                "First scale factor must be 1.0 (original noise level)".to_string(),
82            ));
83        }
84
85        for i in 1..scale_factors.len() {
86            if scale_factors[i] <= scale_factors[i - 1] {
87                return Err(SimulatorError::InvalidInput(
88                    "Scale factors must be in strictly ascending order".to_string(),
89                ));
90            }
91        }
92
93        Ok(Self {
94            method,
95            scale_factors,
96        })
97    }
98
99    /// Apply ZNE to a set of noisy expectation values
100    ///
101    /// # Arguments
102    ///
103    /// * `noisy_values` - Expectation values at different noise scales
104    /// * `noise_scales` - Corresponding noise scale factors
105    pub fn apply(&self, noisy_values: &[f64], noise_scales: &[f64]) -> Result<f64> {
106        if noisy_values.len() != noise_scales.len() {
107            return Err(SimulatorError::InvalidInput(
108                "Number of values must match number of scales".to_string(),
109            ));
110        }
111
112        if noisy_values.len() < 2 {
113            return Err(SimulatorError::InvalidInput(
114                "At least 2 data points required for extrapolation".to_string(),
115            ));
116        }
117
118        match self.method {
119            ExtrapolationMethod::Linear => self.linear_extrapolation(noisy_values, noise_scales),
120            ExtrapolationMethod::Richardson => {
121                self.richardson_extrapolation(noisy_values, noise_scales)
122            }
123            ExtrapolationMethod::Polynomial { degree } => {
124                self.polynomial_extrapolation(noisy_values, noise_scales, degree)
125            }
126            ExtrapolationMethod::Exponential => {
127                self.exponential_extrapolation(noisy_values, noise_scales)
128            }
129        }
130    }
131
132    /// Linear extrapolation using first two points
133    fn linear_extrapolation(&self, values: &[f64], scales: &[f64]) -> Result<f64> {
134        if values.len() < 2 {
135            return Err(SimulatorError::InvalidInput(
136                "Linear extrapolation requires at least 2 points".to_string(),
137            ));
138        }
139
140        let x1 = scales[0];
141        let y1 = values[0];
142        let x2 = scales[1];
143        let y2 = values[1];
144
145        // Extrapolate to x=0 (zero noise)
146        let slope = (y2 - y1) / (x2 - x1);
147        Ok(y1 - slope * x1)
148    }
149
150    /// Richardson extrapolation (higher order)
151    fn richardson_extrapolation(&self, values: &[f64], scales: &[f64]) -> Result<f64> {
152        if values.len() < 3 {
153            return Err(SimulatorError::InvalidInput(
154                "Richardson extrapolation requires at least 3 points".to_string(),
155            ));
156        }
157
158        // Use quadratic Richardson extrapolation
159        let x0 = scales[0];
160        let x1 = scales[1];
161        let x2 = scales[2];
162        let y0 = values[0];
163        let y1 = values[1];
164        let y2 = values[2];
165
166        // Fit quadratic: y = a*x^2 + b*x + c
167        // Extrapolate to x=0 gives c
168        let denom = (x0 - x1) * (x0 - x2) * (x1 - x2);
169        if denom.abs() < 1e-10 {
170            return Err(SimulatorError::InvalidInput(
171                "Scale factors too close for stable extrapolation".to_string(),
172            ));
173        }
174
175        let a = (x2 * (y1 - y0) + x0 * (y2 - y1) + x1 * (y0 - y2)) / denom;
176        let b = (x2 * x2 * (y0 - y1) + x0 * x0 * (y1 - y2) + x1 * x1 * (y2 - y0)) / denom;
177        let c = (x1 * x2 * (x1 - x2) * y0 + x2 * x0 * (x2 - x0) * y1 + x0 * x1 * (x0 - x1) * y2)
178            / denom;
179
180        Ok(c)
181    }
182
183    /// Polynomial fit extrapolation
184    fn polynomial_extrapolation(
185        &self,
186        values: &[f64],
187        scales: &[f64],
188        degree: usize,
189    ) -> Result<f64> {
190        if values.len() <= degree {
191            return Err(SimulatorError::InvalidInput(format!(
192                "Need at least {} points for degree {} polynomial",
193                degree + 1,
194                degree
195            )));
196        }
197
198        // Simple polynomial fit using least squares
199        // For now, use Richardson for degree 2, linear for degree 1
200        match degree {
201            1 => self.linear_extrapolation(values, scales),
202            2 => self.richardson_extrapolation(values, scales),
203            _ => Err(SimulatorError::NotImplemented(
204                "Polynomial degree > 2 not yet implemented".to_string(),
205            )),
206        }
207    }
208
209    /// Exponential fit extrapolation: y = a * exp(-b*x) + c
210    fn exponential_extrapolation(&self, values: &[f64], scales: &[f64]) -> Result<f64> {
211        if values.len() < 3 {
212            return Err(SimulatorError::InvalidInput(
213                "Exponential extrapolation requires at least 3 points".to_string(),
214            ));
215        }
216
217        // Simplified: assume y = a * exp(-b*x) + c where c is the zero-noise value
218        // Use first derivative at x=0 from first two points
219        let x0 = scales[0];
220        let x1 = scales[1];
221        let y0 = values[0];
222        let y1 = values[1];
223
224        // Estimate decay rate
225        if (y1 - y0).abs() < 1e-10 {
226            return Ok(y0); // No decay observed
227        }
228
229        let b_estimate = -((y1 - y0) / (x1 - x0)) / y0.max(1e-10);
230        let c_estimate = y0 / (1.0 + b_estimate * x0).max(1e-10);
231
232        Ok(c_estimate)
233    }
234
235    /// Get the scale factors
236    pub fn scale_factors(&self) -> &[f64] {
237        &self.scale_factors
238    }
239}
240
241// ============================================================================
242// Measurement Error Mitigation
243// ============================================================================
244
245/// Measurement error mitigation using calibration matrices
246///
247/// Corrects readout errors by inverting the noise transfer matrix
248/// measured through calibration circuits.
249#[derive(Debug, Clone)]
250pub struct MeasurementErrorMitigation {
251    /// Calibration matrix: M[i][j] = P(measure i | prepared j)
252    calibration_matrix: Array2<f64>,
253    /// Inverse calibration matrix (for mitigation)
254    inverse_matrix: Option<Array2<f64>>,
255    /// Number of qubits
256    n_qubits: usize,
257}
258
259impl MeasurementErrorMitigation {
260    /// Create a new measurement error mitigation instance
261    ///
262    /// # Arguments
263    ///
264    /// * `n_qubits` - Number of qubits
265    pub fn new(n_qubits: usize) -> Self {
266        let dim = 1 << n_qubits; // 2^n
267        let calibration_matrix = Array2::eye(dim); // Identity by default (no error)
268
269        Self {
270            calibration_matrix,
271            inverse_matrix: None,
272            n_qubits,
273        }
274    }
275
276    /// Set the calibration matrix from measurements
277    ///
278    /// The calibration matrix `M[i][j]` represents the probability of
279    /// measuring bitstring i when bitstring j was prepared.
280    pub fn set_calibration_matrix(&mut self, matrix: Array2<f64>) -> Result<()> {
281        let expected_dim = 1 << self.n_qubits;
282        if matrix.nrows() != expected_dim || matrix.ncols() != expected_dim {
283            return Err(SimulatorError::InvalidInput(format!(
284                "Calibration matrix must be {}x{} for {} qubits",
285                expected_dim, expected_dim, self.n_qubits
286            )));
287        }
288
289        // Verify matrix is stochastic (columns sum to 1)
290        for col in 0..matrix.ncols() {
291            let sum: f64 = (0..matrix.nrows()).map(|row| matrix[[row, col]]).sum();
292            if (sum - 1.0).abs() > 1e-6 {
293                return Err(SimulatorError::InvalidInput(format!(
294                    "Column {} does not sum to 1.0 (sum = {})",
295                    col, sum
296                )));
297            }
298        }
299
300        self.calibration_matrix = matrix;
301        self.inverse_matrix = None; // Will be computed on demand
302        Ok(())
303    }
304
305    /// Compute and cache the inverse calibration matrix
306    fn compute_inverse(&mut self) -> Result<()> {
307        if self.inverse_matrix.is_some() {
308            return Ok(());
309        }
310
311        // Use pseudo-inverse for stability
312        // For now, use simple matrix inversion (can be improved with SVD)
313        let matrix = &self.calibration_matrix;
314        let n = matrix.nrows();
315
316        // Simple Gauss-Jordan elimination for small matrices
317        let mut augmented = Array2::zeros((n, 2 * n));
318        for i in 0..n {
319            for j in 0..n {
320                augmented[[i, j]] = matrix[[i, j]];
321                augmented[[i, j + n]] = if i == j { 1.0 } else { 0.0 };
322            }
323        }
324
325        // Forward elimination (simplified, may need pivoting for stability)
326        for i in 0..n {
327            // Find pivot
328            let pivot = augmented[[i, i]];
329            if pivot.abs() < 1e-10 {
330                return Err(SimulatorError::InvalidInput(
331                    "Calibration matrix is singular or nearly singular".to_string(),
332                ));
333            }
334
335            // Scale row
336            for j in 0..(2 * n) {
337                augmented[[i, j]] /= pivot;
338            }
339
340            // Eliminate
341            for k in 0..n {
342                if k != i {
343                    let factor = augmented[[k, i]];
344                    for j in 0..(2 * n) {
345                        augmented[[k, j]] -= factor * augmented[[i, j]];
346                    }
347                }
348            }
349        }
350
351        // Extract inverse from augmented matrix
352        let mut inverse = Array2::zeros((n, n));
353        for i in 0..n {
354            for j in 0..n {
355                inverse[[i, j]] = augmented[[i, j + n]];
356            }
357        }
358
359        self.inverse_matrix = Some(inverse);
360        Ok(())
361    }
362
363    /// Apply measurement error mitigation to noisy counts
364    ///
365    /// # Arguments
366    ///
367    /// * `noisy_counts` - Raw measurement counts from the quantum circuit
368    ///
369    /// # Returns
370    ///
371    /// Mitigated counts (may contain negative values due to inversion)
372    pub fn apply(&mut self, noisy_counts: &HashMap<String, usize>) -> Result<HashMap<String, f64>> {
373        self.compute_inverse()?;
374        let inverse = self.inverse_matrix.as_ref().unwrap();
375
376        let dim = 1 << self.n_qubits;
377        let total_shots: usize = noisy_counts.values().sum();
378
379        // Convert counts to probability vector
380        let mut noisy_probs = Array1::zeros(dim);
381        for (bitstring, count) in noisy_counts {
382            if bitstring.len() != self.n_qubits {
383                return Err(SimulatorError::InvalidInput(format!(
384                    "Bitstring {} has wrong length (expected {})",
385                    bitstring, self.n_qubits
386                )));
387            }
388            let index = usize::from_str_radix(bitstring, 2).map_err(|_| {
389                SimulatorError::InvalidInput(format!("Invalid bitstring: {}", bitstring))
390            })?;
391            noisy_probs[index] = *count as f64 / total_shots as f64;
392        }
393
394        // Apply inverse: mitigated_probs = M^(-1) * noisy_probs
395        let mitigated_probs = inverse.dot(&noisy_probs);
396
397        // Convert back to counts
398        let mut mitigated_counts = HashMap::new();
399        for i in 0..dim {
400            let bitstring = format!("{:0width$b}", i, width = self.n_qubits);
401            let mitigated_count = mitigated_probs[i] * total_shots as f64;
402            if mitigated_count.abs() > 1e-10 {
403                mitigated_counts.insert(bitstring, mitigated_count);
404            }
405        }
406
407        Ok(mitigated_counts)
408    }
409
410    /// Generate calibration matrix from error rates
411    ///
412    /// # Arguments
413    ///
414    /// * `readout_error_0` - Probability of measuring 1 when prepared in |0⟩
415    /// * `readout_error_1` - Probability of measuring 0 when prepared in |1⟩
416    pub fn from_error_rates(
417        n_qubits: usize,
418        readout_error_0: f64,
419        readout_error_1: f64,
420    ) -> Result<Self> {
421        if !(0.0..=1.0).contains(&readout_error_0) {
422            return Err(SimulatorError::InvalidInput(
423                "readout_error_0 must be in [0, 1]".to_string(),
424            ));
425        }
426        if !(0.0..=1.0).contains(&readout_error_1) {
427            return Err(SimulatorError::InvalidInput(
428                "readout_error_1 must be in [0, 1]".to_string(),
429            ));
430        }
431
432        let dim = 1 << n_qubits;
433        let mut matrix = Array2::zeros((dim, dim));
434
435        // Build tensor product of single-qubit error matrices
436        // Single-qubit matrix: [[1-e0, e1], [e0, 1-e1]]
437        for prepared in 0..dim {
438            for measured in 0..dim {
439                let mut prob = 1.0;
440                for qubit in 0..n_qubits {
441                    let prepared_bit = (prepared >> qubit) & 1;
442                    let measured_bit = (measured >> qubit) & 1;
443
444                    let p = if prepared_bit == 0 {
445                        if measured_bit == 0 {
446                            1.0 - readout_error_0
447                        } else {
448                            readout_error_0
449                        }
450                    } else if measured_bit == 1 {
451                        1.0 - readout_error_1
452                    } else {
453                        readout_error_1
454                    };
455
456                    prob *= p;
457                }
458                matrix[[measured, prepared]] = prob;
459            }
460        }
461
462        let mut mem = Self::new(n_qubits);
463        mem.set_calibration_matrix(matrix)?;
464        Ok(mem)
465    }
466}
467
468// ============================================================================
469// Symmetry Verification
470// ============================================================================
471
/// Symmetry verification for post-selection.
///
/// Verifies that measurement results respect known symmetries
/// (e.g. particle number conservation, parity) and rejects
/// results that violate them.
#[derive(Debug, Clone)]
pub struct SymmetryVerification {
    /// Type of symmetry to verify.
    symmetry_type: SymmetryType,
    /// Expected symmetry value; `None` means every bitstring is accepted.
    expected_value: Option<i32>,
}

/// Types of symmetries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymmetryType {
    /// Particle number conservation: the symmetry value is the number of `'1'` characters.
    ParticleNumber,
    /// Parity: the symmetry value is `0` for an even and `1` for an odd number of `'1'`s.
    Parity,
    /// Custom symmetry (user-defined). Not yet implemented: every bitstring currently
    /// has symmetry value `0`.
    Custom,
}

impl SymmetryVerification {
    /// Create a new symmetry verification instance with no expected value.
    /// Every bitstring passes until an expected value is set.
    pub fn new(symmetry_type: SymmetryType) -> Self {
        Self {
            symmetry_type,
            expected_value: None,
        }
    }

    /// Set the expected symmetry value (builder style).
    pub fn with_expected_value(mut self, value: i32) -> Self {
        self.expected_value = Some(value);
        self
    }

    /// Verify whether a bitstring satisfies the symmetry.
    ///
    /// Always `true` when no expected value has been set.
    pub fn verify(&self, bitstring: &str) -> bool {
        match self.expected_value {
            Some(expected) => self.compute_symmetry(bitstring) == expected,
            None => true,
        }
    }

    /// Compute the symmetry value for a bitstring.
    fn compute_symmetry(&self, bitstring: &str) -> i32 {
        match self.symmetry_type {
            SymmetryType::ParticleNumber => bitstring.chars().filter(|&c| c == '1').count() as i32,
            SymmetryType::Parity => {
                let ones = bitstring.chars().filter(|&c| c == '1').count();
                (ones % 2) as i32
            }
            SymmetryType::Custom => 0, // Placeholder until custom symmetries are supported
        }
    }

    /// Keep only the measurement counts whose bitstrings satisfy the symmetry.
    pub fn filter_counts(&self, counts: &HashMap<String, usize>) -> HashMap<String, usize> {
        counts
            .iter()
            .filter(|(bitstring, _)| self.verify(bitstring))
            .map(|(k, v)| (k.clone(), *v))
            .collect()
    }

    /// Filter counts by symmetry, then rescale the survivors to sum to the original total.
    ///
    /// Rescaled counts are apportioned with the largest-remainder method in exact integer
    /// arithmetic:
    /// 1. Each entry first receives `floor(count * total / kept)`.
    /// 2. The leftover shots go one each to the entries with the largest fractional parts.
    ///    Ties are broken by bitstring order, so the result does not depend on `HashMap`
    ///    iteration order.
    ///
    /// Rounding each entry independently would let the total drift. For example, two
    /// survivors of 1 shot each out of 3 would both round 1.5 up to 2.
    ///
    /// Returns an empty map if no shots survive the filter.
    pub fn filter_and_normalize(&self, counts: &HashMap<String, usize>) -> HashMap<String, usize> {
        let total_shots: usize = counts.values().sum();
        let filtered = self.filter_counts(counts);
        let filtered_total: usize = filtered.values().sum();

        if filtered_total == 0 {
            return HashMap::new();
        }

        // u128 keeps `count * total_shots` from overflowing.
        let target = total_shots as u128;
        let kept = filtered_total as u128;
        let mut shares: Vec<(String, u128, u128)> = filtered
            .into_iter()
            .map(|(bitstring, count)| {
                let scaled = count as u128 * target;
                (bitstring, scaled / kept, scaled % kept)
            })
            .collect();

        let assigned: u128 = shares.iter().map(|(_, quota, _)| quota).sum();
        // Remainders sum to an exact multiple of `kept`, so fewer than `shares.len()` shots
        // are left to hand out.
        let leftover = (target - assigned) as usize;

        shares.sort_unstable_by(|a, b| b.2.cmp(&a.2).then_with(|| a.0.cmp(&b.0)));
        shares
            .into_iter()
            .enumerate()
            .map(|(rank, (bitstring, quota, _))| {
                let bonus = if rank < leftover { 1 } else { 0 };
                (bitstring, (quota + bonus) as usize)
            })
            .collect()
    }
}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} (±{}), got {}",
            expected,
            tolerance,
            actual
        );
    }

    /// Builds a counts map from `(bitstring, shots)` pairs.
    fn counts_from(pairs: &[(&str, usize)]) -> HashMap<String, usize> {
        pairs
            .iter()
            .map(|&(bits, shots)| (bits.to_string(), shots))
            .collect()
    }

    #[test]
    fn test_zne_linear_extrapolation() {
        let extrapolator = ZeroNoiseExtrapolation::new(ExtrapolationMethod::Linear);

        // The samples lie on y = 1 - 0.2x, whose zero-noise intercept is 1.0.
        let estimate = extrapolator
            .apply(&[0.8, 0.6, 0.4], &[1.0, 2.0, 3.0])
            .unwrap();

        assert_close(estimate, 1.0, 0.01);
    }

    #[test]
    fn test_zne_richardson_extrapolation() {
        let extrapolator = ZeroNoiseExtrapolation::new(ExtrapolationMethod::Richardson);

        // Quadratic decay y = 1 - 0.1x^2 sampled at x = 0, 1, 2.
        let estimate = extrapolator
            .apply(&[1.0, 0.9, 0.6], &[0.0, 1.0, 2.0])
            .unwrap();

        assert_close(estimate, 1.0, 0.01);
    }

    #[test]
    fn test_zne_invalid_input() {
        let extrapolator = ZeroNoiseExtrapolation::new(ExtrapolationMethod::Linear);

        // One value but two scales: the lengths disagree.
        assert!(extrapolator.apply(&[0.8], &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_measurement_error_mitigation_identity() {
        let mut mitigator = MeasurementErrorMitigation::new(2);
        let raw = counts_from(&[("00", 100), ("11", 50)]);

        let corrected = mitigator.apply(&raw).unwrap();

        // An identity calibration matrix must leave the counts untouched.
        assert_close(corrected["00"], 100.0, 1e-6);
        assert_close(corrected["11"], 50.0, 1e-6);
    }

    #[test]
    fn test_measurement_error_mitigation_from_error_rates() {
        let mitigator = MeasurementErrorMitigation::from_error_rates(1, 0.1, 0.05).unwrap();
        let m = &mitigator.calibration_matrix;

        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);

        // Entries are P(measured | prepared), indexed [[measured, prepared]].
        assert_close(m[[0, 0]], 0.9, 1e-10); // P(0|0) = 1 - e0
        assert_close(m[[1, 0]], 0.1, 1e-10); // P(1|0) = e0
        assert_close(m[[0, 1]], 0.05, 1e-10); // P(0|1) = e1
        assert_close(m[[1, 1]], 0.95, 1e-10); // P(1|1) = 1 - e1
    }

    #[test]
    fn test_symmetry_particle_number() {
        let checker =
            SymmetryVerification::new(SymmetryType::ParticleNumber).with_expected_value(2);

        for bits in ["0011", "1100", "1010"] {
            assert!(checker.verify(bits), "{} has two particles", bits);
        }
        for bits in ["0001", "1111"] {
            assert!(!checker.verify(bits), "{} does not have two particles", bits);
        }
    }

    #[test]
    fn test_symmetry_parity() {
        let checker = SymmetryVerification::new(SymmetryType::Parity).with_expected_value(0);

        for even in ["0011", "1100"] {
            assert!(checker.verify(even), "{} has even parity", even);
        }
        for odd in ["0001", "0111"] {
            assert!(!checker.verify(odd), "{} has odd parity", odd);
        }
    }

    #[test]
    fn test_symmetry_filter_counts() {
        let checker =
            SymmetryVerification::new(SymmetryType::ParticleNumber).with_expected_value(2);
        let raw = counts_from(&[("0011", 100), ("1100", 50), ("0001", 30), ("1111", 20)]);

        let kept = checker.filter_counts(&raw);

        assert_eq!(kept.len(), 2);
        assert_eq!(kept["0011"], 100);
        assert_eq!(kept["1100"], 50);
        // Bitstrings with the wrong particle number are dropped.
        assert!(!kept.contains_key("0001"));
        assert!(!kept.contains_key("1111"));
    }

    #[test]
    fn test_zne_custom_scale_factors() {
        let accepted = ZeroNoiseExtrapolation::with_scale_factors(
            ExtrapolationMethod::Linear,
            vec![1.0, 1.5, 2.0],
        );
        assert!(accepted.is_ok());

        // Rejected: the sequence does not start at the original noise level 1.0.
        let rejected = ZeroNoiseExtrapolation::with_scale_factors(
            ExtrapolationMethod::Linear,
            vec![2.0, 1.0],
        );
        assert!(rejected.is_err());
    }
}