// quantrs2_sim/stim_dem.rs

//! Detector Error Model (DEM) for Stim circuit error analysis
//!
//! A DEM describes how errors in a circuit propagate to detectors and observables.
//! This enables efficient decoding without re-simulating the full circuit.
//!
//! ## DEM Format
//!
//! The DEM text format consists primarily of error instructions:
//! ```text
//! error(0.01) D0 D1
//! error(0.02) D2 L0
//! ```
//!
//! Each error line specifies:
//! - Probability of the error occurring
//! - Which detectors are flipped by this error (D0, D1, ...)
//! - Which logical observables are flipped (L0, L1, ...)

19use crate::error::{Result, SimulatorError};
20use crate::stim_executor::{DetectorRecord, ObservableRecord, StimExecutor};
21use crate::stim_parser::{PauliTarget, PauliType, StimCircuit, StimInstruction};
22use std::collections::{HashMap, HashSet};
23
24/// A single error mechanism in the DEM
25#[derive(Debug, Clone)]
26pub struct DEMError {
27    /// Probability of this error occurring
28    pub probability: f64,
29    /// Detector indices that flip when this error occurs
30    pub detector_targets: Vec<usize>,
31    /// Observable indices that flip when this error occurs
32    pub observable_targets: Vec<usize>,
33    /// Original error location (for debugging)
34    pub source_location: Option<ErrorLocation>,
35}
36
/// Where in a circuit an error mechanism originated.
#[derive(Clone, Debug)]
pub struct ErrorLocation {
    /// Position of the originating instruction within the circuit.
    pub instruction_index: usize,
    /// Human-readable description of the error (e.g. `X_ERROR`).
    pub error_type: String,
    /// Qubits the error acts on.
    pub qubits: Vec<usize>,
}
47
48/// Detector Error Model representation
49#[derive(Debug, Clone)]
50pub struct DetectorErrorModel {
51    /// Number of detectors in the circuit
52    pub num_detectors: usize,
53    /// Number of logical observables
54    pub num_observables: usize,
55    /// List of error mechanisms
56    pub errors: Vec<DEMError>,
57    /// Coordinate system shifts (for visualization)
58    pub coordinate_shifts: Vec<Vec<f64>>,
59    /// Detector coordinates
60    pub detector_coords: HashMap<usize, Vec<f64>>,
61}
62
63impl DetectorErrorModel {
64    /// Create a new empty DEM
65    #[must_use]
66    pub fn new(num_detectors: usize, num_observables: usize) -> Self {
67        Self {
68            num_detectors,
69            num_observables,
70            errors: Vec::new(),
71            coordinate_shifts: Vec::new(),
72            detector_coords: HashMap::new(),
73        }
74    }
75
76    /// Generate a DEM from a Stim circuit
77    ///
78    /// This performs error analysis by:
79    /// 1. Identifying all error mechanisms in the circuit
80    /// 2. Propagating each error through to detectors/observables
81    /// 3. Recording which detectors/observables are affected
82    pub fn from_circuit(circuit: &StimCircuit) -> Result<Self> {
83        // First, run the circuit without errors to establish baseline
84        let mut clean_executor = StimExecutor::from_circuit(circuit);
85        let clean_result = clean_executor.execute(circuit)?;
86
87        let num_detectors = clean_result.num_detectors;
88        let num_observables = clean_result.num_observables;
89
90        let mut dem = Self::new(num_detectors, num_observables);
91
92        // Collect detector coordinates
93        for detector in clean_executor.detectors() {
94            if !detector.coordinates.is_empty() {
95                dem.detector_coords
96                    .insert(detector.index, detector.coordinates.clone());
97            }
98        }
99
100        // Analyze each error instruction in the circuit
101        let mut instruction_index = 0;
102        for instruction in &circuit.instructions {
103            match instruction {
104                StimInstruction::XError {
105                    probability,
106                    qubits,
107                }
108                | StimInstruction::YError {
109                    probability,
110                    qubits,
111                }
112                | StimInstruction::ZError {
113                    probability,
114                    qubits,
115                } => {
116                    let error_type = match instruction {
117                        StimInstruction::XError { .. } => "X",
118                        StimInstruction::YError { .. } => "Y",
119                        _ => "Z",
120                    };
121
122                    for &qubit in qubits {
123                        let dem_error = Self::analyze_single_qubit_error(
124                            circuit,
125                            instruction_index,
126                            error_type,
127                            qubit,
128                            *probability,
129                            &clean_result.detector_values,
130                            &clean_result.observable_values,
131                        )?;
132
133                        if !dem_error.detector_targets.is_empty()
134                            || !dem_error.observable_targets.is_empty()
135                        {
136                            dem.errors.push(dem_error);
137                        }
138                    }
139                }
140
141                StimInstruction::Depolarize1 {
142                    probability,
143                    qubits,
144                } => {
145                    // Depolarizing noise: treat as 3 separate X/Y/Z errors
146                    let per_pauli_prob = probability / 3.0;
147                    for &qubit in qubits {
148                        for error_type in &["X", "Y", "Z"] {
149                            let dem_error = Self::analyze_single_qubit_error(
150                                circuit,
151                                instruction_index,
152                                error_type,
153                                qubit,
154                                per_pauli_prob,
155                                &clean_result.detector_values,
156                                &clean_result.observable_values,
157                            )?;
158
159                            if !dem_error.detector_targets.is_empty()
160                                || !dem_error.observable_targets.is_empty()
161                            {
162                                dem.errors.push(dem_error);
163                            }
164                        }
165                    }
166                }
167
168                StimInstruction::CorrelatedError {
169                    probability,
170                    targets,
171                }
172                | StimInstruction::ElseCorrelatedError {
173                    probability,
174                    targets,
175                } => {
176                    let dem_error = Self::analyze_correlated_error(
177                        circuit,
178                        instruction_index,
179                        targets,
180                        *probability,
181                        &clean_result.detector_values,
182                        &clean_result.observable_values,
183                    )?;
184
185                    if !dem_error.detector_targets.is_empty()
186                        || !dem_error.observable_targets.is_empty()
187                    {
188                        dem.errors.push(dem_error);
189                    }
190                }
191
192                StimInstruction::Depolarize2 {
193                    probability,
194                    qubit_pairs,
195                } => {
196                    // Two-qubit depolarizing: 15 error mechanisms
197                    let per_pauli_prob = probability / 15.0;
198                    for &(q1, q2) in qubit_pairs {
199                        for p1 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
200                            for p2 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
201                                if *p1 == PauliType::I && *p2 == PauliType::I {
202                                    continue; // Skip identity
203                                }
204                                let targets = vec![
205                                    PauliTarget {
206                                        pauli: *p1,
207                                        qubit: q1,
208                                    },
209                                    PauliTarget {
210                                        pauli: *p2,
211                                        qubit: q2,
212                                    },
213                                ];
214                                let dem_error = Self::analyze_correlated_error(
215                                    circuit,
216                                    instruction_index,
217                                    &targets,
218                                    per_pauli_prob,
219                                    &clean_result.detector_values,
220                                    &clean_result.observable_values,
221                                )?;
222
223                                if !dem_error.detector_targets.is_empty()
224                                    || !dem_error.observable_targets.is_empty()
225                                {
226                                    dem.errors.push(dem_error);
227                                }
228                            }
229                        }
230                    }
231                }
232
233                _ => {}
234            }
235            instruction_index += 1;
236        }
237
238        // Merge duplicate errors (same detector/observable targets)
239        dem.merge_duplicate_errors();
240
241        Ok(dem)
242    }
243
244    /// Analyze how a single-qubit error affects detectors/observables
245    fn analyze_single_qubit_error(
246        circuit: &StimCircuit,
247        instruction_index: usize,
248        error_type: &str,
249        qubit: usize,
250        probability: f64,
251        clean_detectors: &[bool],
252        clean_observables: &[bool],
253    ) -> Result<DEMError> {
254        // Create a modified circuit with the error applied deterministically
255        let mut modified_circuit = circuit.clone();
256
257        // Find the error instruction and modify it to have probability 1.0
258        // Actually, we need to inject a deterministic error at this point
259        // For simplicity, we'll run the circuit with the error forced on
260
261        // This is a simplified analysis - in practice, we'd trace error propagation
262        // For now, we'll use Monte Carlo sampling with forced error
263
264        let mut detector_targets = Vec::new();
265        let mut observable_targets = Vec::new();
266
267        // Run circuit with forced error
268        let mut executor = StimExecutor::from_circuit(circuit);
269        // TODO: Add method to force specific error
270        // For now, return empty targets (simplified DEM)
271
272        Ok(DEMError {
273            probability,
274            detector_targets,
275            observable_targets,
276            source_location: Some(ErrorLocation {
277                instruction_index,
278                error_type: format!("{}_ERROR", error_type),
279                qubits: vec![qubit],
280            }),
281        })
282    }
283
284    /// Analyze how a correlated error affects detectors/observables
285    fn analyze_correlated_error(
286        circuit: &StimCircuit,
287        instruction_index: usize,
288        targets: &[PauliTarget],
289        probability: f64,
290        clean_detectors: &[bool],
291        clean_observables: &[bool],
292    ) -> Result<DEMError> {
293        let qubits: Vec<usize> = targets.iter().map(|t| t.qubit).collect();
294        let error_type = targets
295            .iter()
296            .map(|t| format!("{:?}{}", t.pauli, t.qubit))
297            .collect::<Vec<_>>()
298            .join(" ");
299
300        let mut detector_targets = Vec::new();
301        let mut observable_targets = Vec::new();
302
303        // Simplified analysis - return empty targets
304        // Full implementation would trace error propagation
305
306        Ok(DEMError {
307            probability,
308            detector_targets,
309            observable_targets,
310            source_location: Some(ErrorLocation {
311                instruction_index,
312                error_type: format!("CORRELATED_ERROR {}", error_type),
313                qubits,
314            }),
315        })
316    }
317
318    /// Merge errors with the same detector/observable targets
319    fn merge_duplicate_errors(&mut self) {
320        let mut merged: HashMap<(Vec<usize>, Vec<usize>), DEMError> = HashMap::new();
321
322        for error in self.errors.drain(..) {
323            let key = (
324                error.detector_targets.clone(),
325                error.observable_targets.clone(),
326            );
327
328            if let Some(existing) = merged.get_mut(&key) {
329                // Combine probabilities: P(A or B) = P(A) + P(B) - P(A)P(B)
330                // For small probabilities, approximate as P(A) + P(B)
331                existing.probability += error.probability;
332            } else {
333                merged.insert(key, error);
334            }
335        }
336
337        self.errors = merged.into_values().collect();
338    }
339
340    /// Convert DEM to Stim DEM format string
341    #[must_use]
342    pub fn to_dem_string(&self) -> String {
343        let mut output = String::new();
344
345        // Header comments
346        output.push_str("# Detector Error Model\n");
347        output.push_str(&format!(
348            "# {} detectors, {} observables\n",
349            self.num_detectors, self.num_observables
350        ));
351        output.push('\n');
352
353        // Detector coordinates
354        let mut sorted_detectors: Vec<_> = self.detector_coords.iter().collect();
355        sorted_detectors.sort_by_key(|(k, _)| *k);
356        for (det_idx, coords) in sorted_detectors {
357            output.push_str(&format!(
358                "detector D{} ({}) # coordinates: {:?}\n",
359                det_idx,
360                coords
361                    .iter()
362                    .map(|c| c.to_string())
363                    .collect::<Vec<_>>()
364                    .join(", "),
365                coords
366            ));
367        }
368        if !self.detector_coords.is_empty() {
369            output.push('\n');
370        }
371
372        // Error mechanisms
373        for error in &self.errors {
374            if error.probability > 0.0 {
375                output.push_str(&format!("error({:.6})", error.probability));
376
377                for &det in &error.detector_targets {
378                    output.push_str(&format!(" D{}", det));
379                }
380
381                for &obs in &error.observable_targets {
382                    output.push_str(&format!(" L{}", obs));
383                }
384
385                if let Some(ref loc) = error.source_location {
386                    output.push_str(&format!(" # {}", loc.error_type));
387                }
388
389                output.push('\n');
390            }
391        }
392
393        output
394    }
395
396    /// Parse a DEM from string
397    pub fn from_dem_string(s: &str) -> Result<Self> {
398        let mut num_detectors = 0;
399        let mut num_observables = 0;
400        let mut errors = Vec::new();
401        let mut detector_coords = HashMap::new();
402
403        for line in s.lines() {
404            let line = line.trim();
405
406            // Skip empty lines and comments
407            if line.is_empty() || line.starts_with('#') {
408                continue;
409            }
410
411            // Parse detector coordinate line
412            if line.starts_with("detector") {
413                // detector D0 (x, y, z)
414                // Simplified parsing
415                continue;
416            }
417
418            // Parse error line
419            if line.starts_with("error(") {
420                let (prob_str, rest) = line
421                    .strip_prefix("error(")
422                    .and_then(|s| s.split_once(')'))
423                    .ok_or_else(|| {
424                        SimulatorError::InvalidOperation("Invalid error line format".to_string())
425                    })?;
426
427                let probability = prob_str.parse::<f64>().map_err(|_| {
428                    SimulatorError::InvalidOperation(format!("Invalid probability: {}", prob_str))
429                })?;
430
431                let mut detector_targets = Vec::new();
432                let mut observable_targets = Vec::new();
433
434                // Parse targets before any comment
435                let targets_str = rest.split('#').next().unwrap_or(rest);
436                for token in targets_str.split_whitespace() {
437                    if let Some(stripped) = token.strip_prefix('D') {
438                        let idx = stripped.parse::<usize>().map_err(|_| {
439                            SimulatorError::InvalidOperation(format!("Invalid detector: {}", token))
440                        })?;
441                        detector_targets.push(idx);
442                        num_detectors = num_detectors.max(idx + 1);
443                    } else if let Some(stripped) = token.strip_prefix('L') {
444                        let idx = stripped.parse::<usize>().map_err(|_| {
445                            SimulatorError::InvalidOperation(format!(
446                                "Invalid observable: {}",
447                                token
448                            ))
449                        })?;
450                        observable_targets.push(idx);
451                        num_observables = num_observables.max(idx + 1);
452                    }
453                }
454
455                errors.push(DEMError {
456                    probability,
457                    detector_targets,
458                    observable_targets,
459                    source_location: None,
460                });
461            }
462        }
463
464        Ok(Self {
465            num_detectors,
466            num_observables,
467            errors,
468            coordinate_shifts: Vec::new(),
469            detector_coords,
470        })
471    }
472
473    /// Sample errors according to the DEM
474    ///
475    /// Returns (detector_outcomes, observable_flips) for a single sample
476    pub fn sample(&self) -> (Vec<bool>, Vec<bool>) {
477        use scirs2_core::random::prelude::*;
478        let mut rng = thread_rng();
479
480        let mut detector_flips = vec![false; self.num_detectors];
481        let mut observable_flips = vec![false; self.num_observables];
482
483        for error in &self.errors {
484            if rng.gen_bool(error.probability.min(1.0)) {
485                // This error occurred - flip affected detectors/observables
486                for &det in &error.detector_targets {
487                    if det < detector_flips.len() {
488                        detector_flips[det] ^= true;
489                    }
490                }
491                for &obs in &error.observable_targets {
492                    if obs < observable_flips.len() {
493                        observable_flips[obs] ^= true;
494                    }
495                }
496            }
497        }
498
499        (detector_flips, observable_flips)
500    }
501
502    /// Sample multiple shots
503    pub fn sample_batch(&self, num_shots: usize) -> Vec<(Vec<bool>, Vec<bool>)> {
504        (0..num_shots).map(|_| self.sample()).collect()
505    }
506
507    /// Get the total error probability
508    #[must_use]
509    pub fn total_error_probability(&self) -> f64 {
510        self.errors.iter().map(|e| e.probability).sum()
511    }
512
513    /// Get number of error mechanisms
514    #[must_use]
515    pub fn num_error_mechanisms(&self) -> usize {
516        self.errors.len()
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a mechanism without a source location.
    fn mechanism(probability: f64, detectors: &[usize], observables: &[usize]) -> DEMError {
        DEMError {
            probability,
            detector_targets: detectors.to_vec(),
            observable_targets: observables.to_vec(),
            source_location: None,
        }
    }

    #[test]
    fn test_empty_dem() {
        let model = DetectorErrorModel::new(5, 2);
        assert_eq!(model.num_detectors, 5);
        assert_eq!(model.num_observables, 2);
        assert!(model.errors.is_empty());
    }

    #[test]
    fn test_dem_to_string() {
        let mut model = DetectorErrorModel::new(2, 1);
        model.errors.push(mechanism(0.01, &[0, 1], &[0]));

        let text = model.to_dem_string();
        for needle in ["error(0.010000)", "D0", "D1", "L0"] {
            assert!(text.contains(needle), "missing {:?} in:\n{}", needle, text);
        }
    }

    #[test]
    fn test_dem_parse_roundtrip() {
        let text = r#"
            # Test DEM
            error(0.01) D0 D1
            error(0.02) D2 L0
        "#;

        let model = DetectorErrorModel::from_dem_string(text).expect("valid DEM text");
        assert_eq!(model.num_detectors, 3);
        assert_eq!(model.num_observables, 1);
        assert_eq!(model.errors.len(), 2);

        let (first, second) = (&model.errors[0], &model.errors[1]);
        assert!((first.probability - 0.01).abs() < 1e-10);
        assert_eq!(first.detector_targets, vec![0, 1]);

        assert!((second.probability - 0.02).abs() < 1e-10);
        assert_eq!(second.detector_targets, vec![2]);
        assert_eq!(second.observable_targets, vec![0]);
    }

    #[test]
    fn test_dem_sample() {
        let mut model = DetectorErrorModel::new(3, 1);
        // A probability-1 mechanism fires on every shot.
        model.errors.push(mechanism(1.0, &[0], &[]));

        let (detector_flips, _) = model.sample();
        assert_eq!(detector_flips, vec![true, false, false]);
    }

    #[test]
    fn test_from_circuit_basic() {
        let source = r#"
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(source).expect("valid circuit");
        let model = DetectorErrorModel::from_circuit(&circuit).expect("DEM generation");

        assert_eq!(model.num_detectors, 1);
        assert_eq!(model.num_observables, 0);
    }

    #[test]
    fn test_dem_total_probability() {
        let mut model = DetectorErrorModel::new(2, 0);
        model.errors.push(mechanism(0.01, &[0], &[]));
        model.errors.push(mechanism(0.02, &[1], &[]));

        let total = model.total_error_probability();
        assert!((total - 0.03).abs() < 1e-10);
    }
}