Skip to main content

quantrs2_sim/
stim_dem.rs

1//! Detector Error Model (DEM) for Stim circuit error analysis
2//!
3//! A DEM describes how errors in a circuit propagate to detectors and observables.
4//! This enables efficient decoding without re-simulating the full circuit.
5//!
6//! ## DEM Format
7//!
8//! The DEM file format consists of error instructions:
9//! ```text
10//! error(0.01) D0 D1
11//! error(0.02) D2 L0
12//! ```
13//!
14//! Each error line specifies:
15//! - Probability of the error occurring
16//! - Which detectors are flipped by this error (D0, D1, ...)
17//! - Which logical observables are flipped (L0, L1, ...)
18
19use crate::error::{Result, SimulatorError};
20use crate::stim_executor::{DetectorRecord, ObservableRecord, StimExecutor};
21use crate::stim_parser::{PauliTarget, PauliType, StimCircuit, StimInstruction};
22use std::collections::{HashMap, HashSet};
23
/// Kind of error that can be forced onto a single qubit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorType {
    /// Bit-flip error (Pauli-X).
    PauliX,
    /// Phase-flip error (Pauli-Z).
    PauliZ,
    /// Simultaneous bit and phase flip (Pauli-Y).
    PauliY,
    /// Flip of a measurement outcome.
    Measurement,
}
36
37impl ErrorType {
38    /// Returns a human-readable label for this error type
39    #[must_use]
40    pub fn label(&self) -> &'static str {
41        match self {
42            ErrorType::PauliX => "X_ERROR",
43            ErrorType::PauliZ => "Z_ERROR",
44            ErrorType::PauliY => "Y_ERROR",
45            ErrorType::Measurement => "MEASUREMENT_ERROR",
46        }
47    }
48}
49
50/// A single error mechanism in the DEM
51#[derive(Debug, Clone)]
52pub struct DEMError {
53    /// Probability of this error occurring
54    pub probability: f64,
55    /// Detector indices that flip when this error occurs
56    pub detector_targets: Vec<usize>,
57    /// Observable indices that flip when this error occurs
58    pub observable_targets: Vec<usize>,
59    /// Original error location (for debugging)
60    pub source_location: Option<ErrorLocation>,
61}
62
/// Describes the origin of an error mechanism.
#[derive(Debug, Clone)]
pub struct ErrorLocation {
    /// Position of the originating instruction within the circuit.
    pub instruction_index: usize,
    /// Free-form label of the error kind (for example `"X_ERROR"`).
    pub error_type: String,
    /// Qubits the error acts on.
    pub qubits: Vec<usize>,
}
73
74/// Detector Error Model representation
75#[derive(Debug, Clone)]
76pub struct DetectorErrorModel {
77    /// Number of detectors in the circuit
78    pub num_detectors: usize,
79    /// Number of logical observables
80    pub num_observables: usize,
81    /// List of error mechanisms
82    pub errors: Vec<DEMError>,
83    /// Coordinate system shifts (for visualization)
84    pub coordinate_shifts: Vec<Vec<f64>>,
85    /// Detector coordinates
86    pub detector_coords: HashMap<usize, Vec<f64>>,
87}
88
89impl DetectorErrorModel {
90    /// Create a new empty DEM
91    #[must_use]
92    pub fn new(num_detectors: usize, num_observables: usize) -> Self {
93        Self {
94            num_detectors,
95            num_observables,
96            errors: Vec::new(),
97            coordinate_shifts: Vec::new(),
98            detector_coords: HashMap::new(),
99        }
100    }
101
102    /// Generate a DEM from a Stim circuit
103    ///
104    /// This performs error analysis by:
105    /// 1. Identifying all error mechanisms in the circuit
106    /// 2. Propagating each error through to detectors/observables
107    /// 3. Recording which detectors/observables are affected
108    pub fn from_circuit(circuit: &StimCircuit) -> Result<Self> {
109        // First, run the circuit without errors to establish baseline
110        let mut clean_executor = StimExecutor::from_circuit(circuit);
111        let clean_result = clean_executor.execute(circuit)?;
112
113        let num_detectors = clean_result.num_detectors;
114        let num_observables = clean_result.num_observables;
115
116        let mut dem = Self::new(num_detectors, num_observables);
117
118        // Collect detector coordinates
119        for detector in clean_executor.detectors() {
120            if !detector.coordinates.is_empty() {
121                dem.detector_coords
122                    .insert(detector.index, detector.coordinates.clone());
123            }
124        }
125
126        // Analyze each error instruction in the circuit
127        let mut instruction_index = 0;
128        for instruction in &circuit.instructions {
129            match instruction {
130                StimInstruction::XError {
131                    probability,
132                    qubits,
133                }
134                | StimInstruction::YError {
135                    probability,
136                    qubits,
137                }
138                | StimInstruction::ZError {
139                    probability,
140                    qubits,
141                } => {
142                    let error_type = match instruction {
143                        StimInstruction::XError { .. } => "X",
144                        StimInstruction::YError { .. } => "Y",
145                        _ => "Z",
146                    };
147
148                    for &qubit in qubits {
149                        let dem_error = Self::analyze_single_qubit_error(
150                            circuit,
151                            instruction_index,
152                            error_type,
153                            qubit,
154                            *probability,
155                            &clean_result.detector_values,
156                            &clean_result.observable_values,
157                        )?;
158
159                        if !dem_error.detector_targets.is_empty()
160                            || !dem_error.observable_targets.is_empty()
161                        {
162                            dem.errors.push(dem_error);
163                        }
164                    }
165                }
166
167                StimInstruction::Depolarize1 {
168                    probability,
169                    qubits,
170                } => {
171                    // Depolarizing noise: treat as 3 separate X/Y/Z errors
172                    let per_pauli_prob = probability / 3.0;
173                    for &qubit in qubits {
174                        for error_type in &["X", "Y", "Z"] {
175                            let dem_error = Self::analyze_single_qubit_error(
176                                circuit,
177                                instruction_index,
178                                error_type,
179                                qubit,
180                                per_pauli_prob,
181                                &clean_result.detector_values,
182                                &clean_result.observable_values,
183                            )?;
184
185                            if !dem_error.detector_targets.is_empty()
186                                || !dem_error.observable_targets.is_empty()
187                            {
188                                dem.errors.push(dem_error);
189                            }
190                        }
191                    }
192                }
193
194                StimInstruction::CorrelatedError {
195                    probability,
196                    targets,
197                }
198                | StimInstruction::ElseCorrelatedError {
199                    probability,
200                    targets,
201                } => {
202                    let dem_error = Self::analyze_correlated_error(
203                        circuit,
204                        instruction_index,
205                        targets,
206                        *probability,
207                        &clean_result.detector_values,
208                        &clean_result.observable_values,
209                    )?;
210
211                    if !dem_error.detector_targets.is_empty()
212                        || !dem_error.observable_targets.is_empty()
213                    {
214                        dem.errors.push(dem_error);
215                    }
216                }
217
218                StimInstruction::Depolarize2 {
219                    probability,
220                    qubit_pairs,
221                } => {
222                    // Two-qubit depolarizing: 15 error mechanisms
223                    let per_pauli_prob = probability / 15.0;
224                    for &(q1, q2) in qubit_pairs {
225                        for p1 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
226                            for p2 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
227                                if *p1 == PauliType::I && *p2 == PauliType::I {
228                                    continue; // Skip identity
229                                }
230                                let targets = vec![
231                                    PauliTarget {
232                                        pauli: *p1,
233                                        qubit: q1,
234                                    },
235                                    PauliTarget {
236                                        pauli: *p2,
237                                        qubit: q2,
238                                    },
239                                ];
240                                let dem_error = Self::analyze_correlated_error(
241                                    circuit,
242                                    instruction_index,
243                                    &targets,
244                                    per_pauli_prob,
245                                    &clean_result.detector_values,
246                                    &clean_result.observable_values,
247                                )?;
248
249                                if !dem_error.detector_targets.is_empty()
250                                    || !dem_error.observable_targets.is_empty()
251                                {
252                                    dem.errors.push(dem_error);
253                                }
254                            }
255                        }
256                    }
257                }
258
259                _ => {}
260            }
261            instruction_index += 1;
262        }
263
264        // Merge duplicate errors (same detector/observable targets)
265        dem.merge_duplicate_errors();
266
267        Ok(dem)
268    }
269
270    /// Analyze how a single-qubit error affects detectors/observables
271    fn analyze_single_qubit_error(
272        circuit: &StimCircuit,
273        instruction_index: usize,
274        error_type: &str,
275        qubit: usize,
276        probability: f64,
277        clean_detectors: &[bool],
278        clean_observables: &[bool],
279    ) -> Result<DEMError> {
280        // Create a modified circuit with the error applied deterministically
281        let mut modified_circuit = circuit.clone();
282
283        // Find the error instruction and modify it to have probability 1.0
284        // Actually, we need to inject a deterministic error at this point
285        // For simplicity, we'll run the circuit with the error forced on
286
287        // This is a simplified analysis - in practice, we'd trace error propagation
288        // For now, we'll use Monte Carlo sampling with forced error
289
290        let mut detector_targets = Vec::new();
291        let mut observable_targets = Vec::new();
292
293        // Run circuit with forced error
294        // Note: force_error is available on DetectorErrorModel for callers who need it.
295        // This analysis returns a simplified DEM entry; full error propagation requires
296        // running the stabilizer simulation with the error injected.
297        let mut executor = StimExecutor::from_circuit(circuit);
298
299        Ok(DEMError {
300            probability,
301            detector_targets,
302            observable_targets,
303            source_location: Some(ErrorLocation {
304                instruction_index,
305                error_type: format!("{}_ERROR", error_type),
306                qubits: vec![qubit],
307            }),
308        })
309    }
310
311    /// Analyze how a correlated error affects detectors/observables
312    fn analyze_correlated_error(
313        circuit: &StimCircuit,
314        instruction_index: usize,
315        targets: &[PauliTarget],
316        probability: f64,
317        clean_detectors: &[bool],
318        clean_observables: &[bool],
319    ) -> Result<DEMError> {
320        let qubits: Vec<usize> = targets.iter().map(|t| t.qubit).collect();
321        let error_type = targets
322            .iter()
323            .map(|t| format!("{:?}{}", t.pauli, t.qubit))
324            .collect::<Vec<_>>()
325            .join(" ");
326
327        let mut detector_targets = Vec::new();
328        let mut observable_targets = Vec::new();
329
330        // Simplified analysis - return empty targets
331        // Full implementation would trace error propagation
332
333        Ok(DEMError {
334            probability,
335            detector_targets,
336            observable_targets,
337            source_location: Some(ErrorLocation {
338                instruction_index,
339                error_type: format!("CORRELATED_ERROR {}", error_type),
340                qubits,
341            }),
342        })
343    }
344
345    /// Merge errors with the same detector/observable targets
346    fn merge_duplicate_errors(&mut self) {
347        let mut merged: HashMap<(Vec<usize>, Vec<usize>), DEMError> = HashMap::new();
348
349        for error in self.errors.drain(..) {
350            let key = (
351                error.detector_targets.clone(),
352                error.observable_targets.clone(),
353            );
354
355            if let Some(existing) = merged.get_mut(&key) {
356                // Combine probabilities: P(A or B) = P(A) + P(B) - P(A)P(B)
357                // For small probabilities, approximate as P(A) + P(B)
358                existing.probability += error.probability;
359            } else {
360                merged.insert(key, error);
361            }
362        }
363
364        self.errors = merged.into_values().collect();
365    }
366
367    /// Convert DEM to Stim DEM format string
368    #[must_use]
369    pub fn to_dem_string(&self) -> String {
370        let mut output = String::new();
371
372        // Header comments
373        output.push_str("# Detector Error Model\n");
374        output.push_str(&format!(
375            "# {} detectors, {} observables\n",
376            self.num_detectors, self.num_observables
377        ));
378        output.push('\n');
379
380        // Detector coordinates
381        let mut sorted_detectors: Vec<_> = self.detector_coords.iter().collect();
382        sorted_detectors.sort_by_key(|(k, _)| *k);
383        for (det_idx, coords) in sorted_detectors {
384            output.push_str(&format!(
385                "detector D{} ({}) # coordinates: {:?}\n",
386                det_idx,
387                coords
388                    .iter()
389                    .map(|c| c.to_string())
390                    .collect::<Vec<_>>()
391                    .join(", "),
392                coords
393            ));
394        }
395        if !self.detector_coords.is_empty() {
396            output.push('\n');
397        }
398
399        // Error mechanisms
400        for error in &self.errors {
401            if error.probability > 0.0 {
402                output.push_str(&format!("error({:.6})", error.probability));
403
404                for &det in &error.detector_targets {
405                    output.push_str(&format!(" D{}", det));
406                }
407
408                for &obs in &error.observable_targets {
409                    output.push_str(&format!(" L{}", obs));
410                }
411
412                if let Some(ref loc) = error.source_location {
413                    output.push_str(&format!(" # {}", loc.error_type));
414                }
415
416                output.push('\n');
417            }
418        }
419
420        output
421    }
422
423    /// Parse a DEM from string
424    pub fn from_dem_string(s: &str) -> Result<Self> {
425        let mut num_detectors = 0;
426        let mut num_observables = 0;
427        let mut errors = Vec::new();
428        let mut detector_coords = HashMap::new();
429
430        for line in s.lines() {
431            let line = line.trim();
432
433            // Skip empty lines and comments
434            if line.is_empty() || line.starts_with('#') {
435                continue;
436            }
437
438            // Parse detector coordinate line
439            if line.starts_with("detector") {
440                // detector D0 (x, y, z)
441                // Simplified parsing
442                continue;
443            }
444
445            // Parse error line
446            if line.starts_with("error(") {
447                let (prob_str, rest) = line
448                    .strip_prefix("error(")
449                    .and_then(|s| s.split_once(')'))
450                    .ok_or_else(|| {
451                        SimulatorError::InvalidOperation("Invalid error line format".to_string())
452                    })?;
453
454                let probability = prob_str.parse::<f64>().map_err(|_| {
455                    SimulatorError::InvalidOperation(format!("Invalid probability: {}", prob_str))
456                })?;
457
458                let mut detector_targets = Vec::new();
459                let mut observable_targets = Vec::new();
460
461                // Parse targets before any comment
462                let targets_str = rest.split('#').next().unwrap_or(rest);
463                for token in targets_str.split_whitespace() {
464                    if let Some(stripped) = token.strip_prefix('D') {
465                        let idx = stripped.parse::<usize>().map_err(|_| {
466                            SimulatorError::InvalidOperation(format!("Invalid detector: {}", token))
467                        })?;
468                        detector_targets.push(idx);
469                        num_detectors = num_detectors.max(idx + 1);
470                    } else if let Some(stripped) = token.strip_prefix('L') {
471                        let idx = stripped.parse::<usize>().map_err(|_| {
472                            SimulatorError::InvalidOperation(format!(
473                                "Invalid observable: {}",
474                                token
475                            ))
476                        })?;
477                        observable_targets.push(idx);
478                        num_observables = num_observables.max(idx + 1);
479                    }
480                }
481
482                errors.push(DEMError {
483                    probability,
484                    detector_targets,
485                    observable_targets,
486                    source_location: None,
487                });
488            }
489        }
490
491        Ok(Self {
492            num_detectors,
493            num_observables,
494            errors,
495            coordinate_shifts: Vec::new(),
496            detector_coords,
497        })
498    }
499
500    /// Sample errors according to the DEM
501    ///
502    /// Returns (detector_outcomes, observable_flips) for a single sample
503    pub fn sample(&self) -> (Vec<bool>, Vec<bool>) {
504        use scirs2_core::random::prelude::*;
505        let mut rng = thread_rng();
506
507        let mut detector_flips = vec![false; self.num_detectors];
508        let mut observable_flips = vec![false; self.num_observables];
509
510        for error in &self.errors {
511            if rng.random_bool(error.probability.min(1.0)) {
512                // This error occurred - flip affected detectors/observables
513                for &det in &error.detector_targets {
514                    if det < detector_flips.len() {
515                        detector_flips[det] ^= true;
516                    }
517                }
518                for &obs in &error.observable_targets {
519                    if obs < observable_flips.len() {
520                        observable_flips[obs] ^= true;
521                    }
522                }
523            }
524        }
525
526        (detector_flips, observable_flips)
527    }
528
529    /// Sample multiple shots
530    pub fn sample_batch(&self, num_shots: usize) -> Vec<(Vec<bool>, Vec<bool>)> {
531        (0..num_shots).map(|_| self.sample()).collect()
532    }
533
534    /// Get the total error probability
535    #[must_use]
536    pub fn total_error_probability(&self) -> f64 {
537        self.errors.iter().map(|e| e.probability).sum()
538    }
539
540    /// Get number of error mechanisms
541    #[must_use]
542    pub fn num_error_mechanisms(&self) -> usize {
543        self.errors.len()
544    }
545
546    /// Force a specific error on a qubit with probability 1.0.
547    ///
548    /// This inserts a deterministic error mechanism into the DEM for debugging
549    /// and testing.  The error targets no detectors or observables by default
550    /// (they must be wired by the caller via the returned index, or by
551    /// re-analysing the circuit); however, it records the qubit, error type,
552    /// and probability so that downstream samplers see it as a certain event.
553    ///
554    /// # Arguments
555    /// * `qubit`      — Index of the qubit to apply the error to.
556    /// * `error_type` — The Pauli or measurement error to force.
557    ///
558    /// # Returns
559    /// The index of the newly added error mechanism in `self.errors`.
560    pub fn force_error(&mut self, qubit: usize, error_type: ErrorType) -> usize {
561        let forced = DEMError {
562            probability: 1.0,
563            detector_targets: Vec::new(),
564            observable_targets: Vec::new(),
565            source_location: Some(ErrorLocation {
566                instruction_index: 0,
567                error_type: error_type.label().to_string(),
568                qubits: vec![qubit],
569            }),
570        };
571
572        let idx = self.errors.len();
573        self.errors.push(forced);
574        idx
575    }
576
577    /// Force a specific error and associate it with given detector and observable targets.
578    ///
579    /// Unlike [`force_error`](Self::force_error), this variant lets the caller
580    /// specify exactly which detectors and observables flip when the error
581    /// occurs, enabling accurate decoding tests.
582    pub fn force_error_with_targets(
583        &mut self,
584        qubit: usize,
585        error_type: ErrorType,
586        detector_targets: Vec<usize>,
587        observable_targets: Vec<usize>,
588    ) -> usize {
589        let forced = DEMError {
590            probability: 1.0,
591            detector_targets,
592            observable_targets,
593            source_location: Some(ErrorLocation {
594                instruction_index: 0,
595                error_type: error_type.label().to_string(),
596                qubits: vec![qubit],
597            }),
598        };
599
600        let idx = self.errors.len();
601        self.errors.push(forced);
602        idx
603    }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an error mechanism without source information.
    fn mechanism(probability: f64, detectors: &[usize], observables: &[usize]) -> DEMError {
        DEMError {
            probability,
            detector_targets: detectors.to_vec(),
            observable_targets: observables.to_vec(),
            source_location: None,
        }
    }

    /// A freshly created model carries the requested counts and no errors.
    #[test]
    fn test_empty_dem() {
        let dem = DetectorErrorModel::new(5, 2);
        assert_eq!((dem.num_detectors, dem.num_observables), (5, 2));
        assert!(dem.errors.is_empty());
    }

    /// Serialization prints the probability with six decimals and every target.
    #[test]
    fn test_dem_to_string() {
        let mut dem = DetectorErrorModel::new(2, 1);
        dem.errors.push(mechanism(0.01, &[0, 1], &[0]));

        let dem_string = dem.to_dem_string();
        for expected in ["error(0.010000)", "D0", "D1", "L0"].iter() {
            assert!(dem_string.contains(expected));
        }
    }

    /// Parsing recovers counts, probabilities and targets from DEM text.
    #[test]
    fn test_dem_parse_roundtrip() {
        let dem_str = r#"
            # Test DEM
            error(0.01) D0 D1
            error(0.02) D2 L0
        "#;

        let dem = DetectorErrorModel::from_dem_string(dem_str).unwrap();
        assert_eq!(dem.num_detectors, 3);
        assert_eq!(dem.num_observables, 1);
        assert_eq!(dem.errors.len(), 2);

        let first = &dem.errors[0];
        assert!((first.probability - 0.01).abs() < 1e-10);
        assert_eq!(first.detector_targets, vec![0, 1]);

        let second = &dem.errors[1];
        assert!((second.probability - 0.02).abs() < 1e-10);
        assert_eq!(second.detector_targets, vec![2]);
        assert_eq!(second.observable_targets, vec![0]);
    }

    /// A probability-1.0 mechanism always flips its detector and nothing else.
    #[test]
    fn test_dem_sample() {
        let mut dem = DetectorErrorModel::new(3, 1);
        dem.errors.push(mechanism(1.0, &[0], &[]));

        let (detector_flips, _) = dem.sample();
        assert!(detector_flips[0]); // Should always flip
        assert!(!detector_flips[1]); // Should never flip
        assert!(!detector_flips[2]); // Should never flip
    }

    /// A noiseless circuit with one detector yields matching counts.
    #[test]
    fn test_from_circuit_basic() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let dem = DetectorErrorModel::from_circuit(&circuit).unwrap();

        assert_eq!(dem.num_detectors, 1);
        assert_eq!(dem.num_observables, 0);
    }

    /// The total is the plain sum of the mechanism probabilities.
    #[test]
    fn test_dem_total_probability() {
        let mut dem = DetectorErrorModel::new(2, 0);
        dem.errors.push(mechanism(0.01, &[0], &[]));
        dem.errors.push(mechanism(0.02, &[1], &[]));

        let total = dem.total_error_probability();
        assert!((total - 0.03).abs() < 1e-10);
    }
}