1use crate::error::{Result, SimulatorError};
20use crate::stim_executor::{DetectorRecord, ObservableRecord, StimExecutor};
21use crate::stim_parser::{PauliTarget, PauliType, StimCircuit, StimInstruction};
22use std::collections::{HashMap, HashSet};
23
/// Kind of an error forced into a model via
/// [`DetectorErrorModel::force_error`]; see [`ErrorType::label`] for the
/// text recorded for each kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// Pauli X error.
    PauliX,
    /// Pauli Z error.
    PauliZ,
    /// Pauli Y error.
    PauliY,
    /// Measurement error (labelled `MEASUREMENT_ERROR`).
    Measurement,
}
36
37impl ErrorType {
38 #[must_use]
40 pub fn label(&self) -> &'static str {
41 match self {
42 ErrorType::PauliX => "X_ERROR",
43 ErrorType::PauliZ => "Z_ERROR",
44 ErrorType::PauliY => "Y_ERROR",
45 ErrorType::Measurement => "MEASUREMENT_ERROR",
46 }
47 }
48}
49
50#[derive(Debug, Clone)]
52pub struct DEMError {
53 pub probability: f64,
55 pub detector_targets: Vec<usize>,
57 pub observable_targets: Vec<usize>,
59 pub source_location: Option<ErrorLocation>,
61}
62
/// Describes where an error mechanism came from.
#[derive(Clone, Debug)]
pub struct ErrorLocation {
    /// Position of the originating instruction in the circuit's instruction list.
    pub instruction_index: usize,
    /// Human-readable description, e.g. `"X_ERROR"` or `"CORRELATED_ERROR X0 Z1"`.
    pub error_type: String,
    /// Qubits the error acts on.
    pub qubits: Vec<usize>,
}
73
74#[derive(Debug, Clone)]
76pub struct DetectorErrorModel {
77 pub num_detectors: usize,
79 pub num_observables: usize,
81 pub errors: Vec<DEMError>,
83 pub coordinate_shifts: Vec<Vec<f64>>,
85 pub detector_coords: HashMap<usize, Vec<f64>>,
87}
88
89impl DetectorErrorModel {
90 #[must_use]
92 pub fn new(num_detectors: usize, num_observables: usize) -> Self {
93 Self {
94 num_detectors,
95 num_observables,
96 errors: Vec::new(),
97 coordinate_shifts: Vec::new(),
98 detector_coords: HashMap::new(),
99 }
100 }
101
102 pub fn from_circuit(circuit: &StimCircuit) -> Result<Self> {
109 let mut clean_executor = StimExecutor::from_circuit(circuit);
111 let clean_result = clean_executor.execute(circuit)?;
112
113 let num_detectors = clean_result.num_detectors;
114 let num_observables = clean_result.num_observables;
115
116 let mut dem = Self::new(num_detectors, num_observables);
117
118 for detector in clean_executor.detectors() {
120 if !detector.coordinates.is_empty() {
121 dem.detector_coords
122 .insert(detector.index, detector.coordinates.clone());
123 }
124 }
125
126 let mut instruction_index = 0;
128 for instruction in &circuit.instructions {
129 match instruction {
130 StimInstruction::XError {
131 probability,
132 qubits,
133 }
134 | StimInstruction::YError {
135 probability,
136 qubits,
137 }
138 | StimInstruction::ZError {
139 probability,
140 qubits,
141 } => {
142 let error_type = match instruction {
143 StimInstruction::XError { .. } => "X",
144 StimInstruction::YError { .. } => "Y",
145 _ => "Z",
146 };
147
148 for &qubit in qubits {
149 let dem_error = Self::analyze_single_qubit_error(
150 circuit,
151 instruction_index,
152 error_type,
153 qubit,
154 *probability,
155 &clean_result.detector_values,
156 &clean_result.observable_values,
157 )?;
158
159 if !dem_error.detector_targets.is_empty()
160 || !dem_error.observable_targets.is_empty()
161 {
162 dem.errors.push(dem_error);
163 }
164 }
165 }
166
167 StimInstruction::Depolarize1 {
168 probability,
169 qubits,
170 } => {
171 let per_pauli_prob = probability / 3.0;
173 for &qubit in qubits {
174 for error_type in &["X", "Y", "Z"] {
175 let dem_error = Self::analyze_single_qubit_error(
176 circuit,
177 instruction_index,
178 error_type,
179 qubit,
180 per_pauli_prob,
181 &clean_result.detector_values,
182 &clean_result.observable_values,
183 )?;
184
185 if !dem_error.detector_targets.is_empty()
186 || !dem_error.observable_targets.is_empty()
187 {
188 dem.errors.push(dem_error);
189 }
190 }
191 }
192 }
193
194 StimInstruction::CorrelatedError {
195 probability,
196 targets,
197 }
198 | StimInstruction::ElseCorrelatedError {
199 probability,
200 targets,
201 } => {
202 let dem_error = Self::analyze_correlated_error(
203 circuit,
204 instruction_index,
205 targets,
206 *probability,
207 &clean_result.detector_values,
208 &clean_result.observable_values,
209 )?;
210
211 if !dem_error.detector_targets.is_empty()
212 || !dem_error.observable_targets.is_empty()
213 {
214 dem.errors.push(dem_error);
215 }
216 }
217
218 StimInstruction::Depolarize2 {
219 probability,
220 qubit_pairs,
221 } => {
222 let per_pauli_prob = probability / 15.0;
224 for &(q1, q2) in qubit_pairs {
225 for p1 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
226 for p2 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
227 if *p1 == PauliType::I && *p2 == PauliType::I {
228 continue; }
230 let targets = vec![
231 PauliTarget {
232 pauli: *p1,
233 qubit: q1,
234 },
235 PauliTarget {
236 pauli: *p2,
237 qubit: q2,
238 },
239 ];
240 let dem_error = Self::analyze_correlated_error(
241 circuit,
242 instruction_index,
243 &targets,
244 per_pauli_prob,
245 &clean_result.detector_values,
246 &clean_result.observable_values,
247 )?;
248
249 if !dem_error.detector_targets.is_empty()
250 || !dem_error.observable_targets.is_empty()
251 {
252 dem.errors.push(dem_error);
253 }
254 }
255 }
256 }
257 }
258
259 _ => {}
260 }
261 instruction_index += 1;
262 }
263
264 dem.merge_duplicate_errors();
266
267 Ok(dem)
268 }
269
270 fn analyze_single_qubit_error(
272 circuit: &StimCircuit,
273 instruction_index: usize,
274 error_type: &str,
275 qubit: usize,
276 probability: f64,
277 clean_detectors: &[bool],
278 clean_observables: &[bool],
279 ) -> Result<DEMError> {
280 let mut modified_circuit = circuit.clone();
282
283 let mut detector_targets = Vec::new();
291 let mut observable_targets = Vec::new();
292
293 let mut executor = StimExecutor::from_circuit(circuit);
298
299 Ok(DEMError {
300 probability,
301 detector_targets,
302 observable_targets,
303 source_location: Some(ErrorLocation {
304 instruction_index,
305 error_type: format!("{}_ERROR", error_type),
306 qubits: vec![qubit],
307 }),
308 })
309 }
310
311 fn analyze_correlated_error(
313 circuit: &StimCircuit,
314 instruction_index: usize,
315 targets: &[PauliTarget],
316 probability: f64,
317 clean_detectors: &[bool],
318 clean_observables: &[bool],
319 ) -> Result<DEMError> {
320 let qubits: Vec<usize> = targets.iter().map(|t| t.qubit).collect();
321 let error_type = targets
322 .iter()
323 .map(|t| format!("{:?}{}", t.pauli, t.qubit))
324 .collect::<Vec<_>>()
325 .join(" ");
326
327 let mut detector_targets = Vec::new();
328 let mut observable_targets = Vec::new();
329
330 Ok(DEMError {
334 probability,
335 detector_targets,
336 observable_targets,
337 source_location: Some(ErrorLocation {
338 instruction_index,
339 error_type: format!("CORRELATED_ERROR {}", error_type),
340 qubits,
341 }),
342 })
343 }
344
345 fn merge_duplicate_errors(&mut self) {
347 let mut merged: HashMap<(Vec<usize>, Vec<usize>), DEMError> = HashMap::new();
348
349 for error in self.errors.drain(..) {
350 let key = (
351 error.detector_targets.clone(),
352 error.observable_targets.clone(),
353 );
354
355 if let Some(existing) = merged.get_mut(&key) {
356 existing.probability += error.probability;
359 } else {
360 merged.insert(key, error);
361 }
362 }
363
364 self.errors = merged.into_values().collect();
365 }
366
367 #[must_use]
369 pub fn to_dem_string(&self) -> String {
370 let mut output = String::new();
371
372 output.push_str("# Detector Error Model\n");
374 output.push_str(&format!(
375 "# {} detectors, {} observables\n",
376 self.num_detectors, self.num_observables
377 ));
378 output.push('\n');
379
380 let mut sorted_detectors: Vec<_> = self.detector_coords.iter().collect();
382 sorted_detectors.sort_by_key(|(k, _)| *k);
383 for (det_idx, coords) in sorted_detectors {
384 output.push_str(&format!(
385 "detector D{} ({}) # coordinates: {:?}\n",
386 det_idx,
387 coords
388 .iter()
389 .map(|c| c.to_string())
390 .collect::<Vec<_>>()
391 .join(", "),
392 coords
393 ));
394 }
395 if !self.detector_coords.is_empty() {
396 output.push('\n');
397 }
398
399 for error in &self.errors {
401 if error.probability > 0.0 {
402 output.push_str(&format!("error({:.6})", error.probability));
403
404 for &det in &error.detector_targets {
405 output.push_str(&format!(" D{}", det));
406 }
407
408 for &obs in &error.observable_targets {
409 output.push_str(&format!(" L{}", obs));
410 }
411
412 if let Some(ref loc) = error.source_location {
413 output.push_str(&format!(" # {}", loc.error_type));
414 }
415
416 output.push('\n');
417 }
418 }
419
420 output
421 }
422
423 pub fn from_dem_string(s: &str) -> Result<Self> {
425 let mut num_detectors = 0;
426 let mut num_observables = 0;
427 let mut errors = Vec::new();
428 let mut detector_coords = HashMap::new();
429
430 for line in s.lines() {
431 let line = line.trim();
432
433 if line.is_empty() || line.starts_with('#') {
435 continue;
436 }
437
438 if line.starts_with("detector") {
440 continue;
443 }
444
445 if line.starts_with("error(") {
447 let (prob_str, rest) = line
448 .strip_prefix("error(")
449 .and_then(|s| s.split_once(')'))
450 .ok_or_else(|| {
451 SimulatorError::InvalidOperation("Invalid error line format".to_string())
452 })?;
453
454 let probability = prob_str.parse::<f64>().map_err(|_| {
455 SimulatorError::InvalidOperation(format!("Invalid probability: {}", prob_str))
456 })?;
457
458 let mut detector_targets = Vec::new();
459 let mut observable_targets = Vec::new();
460
461 let targets_str = rest.split('#').next().unwrap_or(rest);
463 for token in targets_str.split_whitespace() {
464 if let Some(stripped) = token.strip_prefix('D') {
465 let idx = stripped.parse::<usize>().map_err(|_| {
466 SimulatorError::InvalidOperation(format!("Invalid detector: {}", token))
467 })?;
468 detector_targets.push(idx);
469 num_detectors = num_detectors.max(idx + 1);
470 } else if let Some(stripped) = token.strip_prefix('L') {
471 let idx = stripped.parse::<usize>().map_err(|_| {
472 SimulatorError::InvalidOperation(format!(
473 "Invalid observable: {}",
474 token
475 ))
476 })?;
477 observable_targets.push(idx);
478 num_observables = num_observables.max(idx + 1);
479 }
480 }
481
482 errors.push(DEMError {
483 probability,
484 detector_targets,
485 observable_targets,
486 source_location: None,
487 });
488 }
489 }
490
491 Ok(Self {
492 num_detectors,
493 num_observables,
494 errors,
495 coordinate_shifts: Vec::new(),
496 detector_coords,
497 })
498 }
499
500 pub fn sample(&self) -> (Vec<bool>, Vec<bool>) {
504 use scirs2_core::random::prelude::*;
505 let mut rng = thread_rng();
506
507 let mut detector_flips = vec![false; self.num_detectors];
508 let mut observable_flips = vec![false; self.num_observables];
509
510 for error in &self.errors {
511 if rng.random_bool(error.probability.min(1.0)) {
512 for &det in &error.detector_targets {
514 if det < detector_flips.len() {
515 detector_flips[det] ^= true;
516 }
517 }
518 for &obs in &error.observable_targets {
519 if obs < observable_flips.len() {
520 observable_flips[obs] ^= true;
521 }
522 }
523 }
524 }
525
526 (detector_flips, observable_flips)
527 }
528
529 pub fn sample_batch(&self, num_shots: usize) -> Vec<(Vec<bool>, Vec<bool>)> {
531 (0..num_shots).map(|_| self.sample()).collect()
532 }
533
534 #[must_use]
536 pub fn total_error_probability(&self) -> f64 {
537 self.errors.iter().map(|e| e.probability).sum()
538 }
539
540 #[must_use]
542 pub fn num_error_mechanisms(&self) -> usize {
543 self.errors.len()
544 }
545
546 pub fn force_error(&mut self, qubit: usize, error_type: ErrorType) -> usize {
561 let forced = DEMError {
562 probability: 1.0,
563 detector_targets: Vec::new(),
564 observable_targets: Vec::new(),
565 source_location: Some(ErrorLocation {
566 instruction_index: 0,
567 error_type: error_type.label().to_string(),
568 qubits: vec![qubit],
569 }),
570 };
571
572 let idx = self.errors.len();
573 self.errors.push(forced);
574 idx
575 }
576
577 pub fn force_error_with_targets(
583 &mut self,
584 qubit: usize,
585 error_type: ErrorType,
586 detector_targets: Vec<usize>,
587 observable_targets: Vec<usize>,
588 ) -> usize {
589 let forced = DEMError {
590 probability: 1.0,
591 detector_targets,
592 observable_targets,
593 source_location: Some(ErrorLocation {
594 instruction_index: 0,
595 error_type: error_type.label().to_string(),
596 qubits: vec![qubit],
597 }),
598 };
599
600 let idx = self.errors.len();
601 self.errors.push(forced);
602 idx
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mechanism with no source location.
    fn mechanism(probability: f64, detectors: Vec<usize>, observables: Vec<usize>) -> DEMError {
        DEMError {
            probability,
            detector_targets: detectors,
            observable_targets: observables,
            source_location: None,
        }
    }

    #[test]
    fn test_empty_dem() {
        let model = DetectorErrorModel::new(5, 2);
        assert_eq!((model.num_detectors, model.num_observables), (5, 2));
        assert!(model.errors.is_empty());
    }

    #[test]
    fn test_dem_to_string() {
        let mut model = DetectorErrorModel::new(2, 1);
        model.errors.push(mechanism(0.01, vec![0, 1], vec![0]));

        let text = model.to_dem_string();
        for needle in ["error(0.010000)", "D0", "D1", "L0"] {
            assert!(text.contains(needle), "missing {:?} in {}", needle, text);
        }
    }

    #[test]
    fn test_dem_parse_roundtrip() {
        let text = r#"
        # Test DEM
        error(0.01) D0 D1
        error(0.02) D2 L0
        "#;

        let model = DetectorErrorModel::from_dem_string(text).expect("valid DEM text");
        assert_eq!(model.num_detectors, 3);
        assert_eq!(model.num_observables, 1);
        assert_eq!(model.errors.len(), 2);

        let (first, second) = (&model.errors[0], &model.errors[1]);
        assert!((first.probability - 0.01).abs() < 1e-10);
        assert_eq!(first.detector_targets, vec![0, 1]);

        assert!((second.probability - 0.02).abs() < 1e-10);
        assert_eq!(second.detector_targets, vec![2]);
        assert_eq!(second.observable_targets, vec![0]);
    }

    #[test]
    fn test_dem_sample() {
        let mut model = DetectorErrorModel::new(3, 1);
        // A certain error on D0 must flip exactly that detector.
        model.errors.push(mechanism(1.0, vec![0], vec![]));

        let (detectors, _) = model.sample();
        assert_eq!(detectors, vec![true, false, false]);
    }

    #[test]
    fn test_from_circuit_basic() {
        let source = r#"
        H 0
        CNOT 0 1
        M 0 1
        DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(source).expect("valid circuit");
        let model = DetectorErrorModel::from_circuit(&circuit).expect("analysis succeeds");

        assert_eq!(model.num_detectors, 1);
        assert_eq!(model.num_observables, 0);
    }

    #[test]
    fn test_dem_total_probability() {
        let mut model = DetectorErrorModel::new(2, 0);
        model.errors.extend([
            mechanism(0.01, vec![0], vec![]),
            mechanism(0.02, vec![1], vec![]),
        ]);

        let total = model.total_error_probability();
        assert!((total - 0.03).abs() < 1e-10);
    }
}