1use crate::error::{Result, SimulatorError};
20use crate::stim_executor::{DetectorRecord, ObservableRecord, StimExecutor};
21use crate::stim_parser::{PauliTarget, PauliType, StimCircuit, StimInstruction};
22use std::collections::{HashMap, HashSet};
23
24#[derive(Debug, Clone)]
26pub struct DEMError {
27 pub probability: f64,
29 pub detector_targets: Vec<usize>,
31 pub observable_targets: Vec<usize>,
33 pub source_location: Option<ErrorLocation>,
35}
36
/// Origin of an error mechanism inside the circuit it was derived from.
///
/// Equality and hashing compare every field, so locations can be deduplicated
/// or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ErrorLocation {
    /// Position of the originating instruction in the circuit's instruction list.
    pub instruction_index: usize,
    /// Human-readable label of the mechanism, such as `X_ERROR` or a
    /// `CORRELATED_ERROR ...` description of its Pauli targets.
    pub error_type: String,
    /// Qubits the mechanism acts on.
    pub qubits: Vec<usize>,
}
47
48#[derive(Debug, Clone)]
50pub struct DetectorErrorModel {
51 pub num_detectors: usize,
53 pub num_observables: usize,
55 pub errors: Vec<DEMError>,
57 pub coordinate_shifts: Vec<Vec<f64>>,
59 pub detector_coords: HashMap<usize, Vec<f64>>,
61}
62
63impl DetectorErrorModel {
64 #[must_use]
66 pub fn new(num_detectors: usize, num_observables: usize) -> Self {
67 Self {
68 num_detectors,
69 num_observables,
70 errors: Vec::new(),
71 coordinate_shifts: Vec::new(),
72 detector_coords: HashMap::new(),
73 }
74 }
75
76 pub fn from_circuit(circuit: &StimCircuit) -> Result<Self> {
83 let mut clean_executor = StimExecutor::from_circuit(circuit);
85 let clean_result = clean_executor.execute(circuit)?;
86
87 let num_detectors = clean_result.num_detectors;
88 let num_observables = clean_result.num_observables;
89
90 let mut dem = Self::new(num_detectors, num_observables);
91
92 for detector in clean_executor.detectors() {
94 if !detector.coordinates.is_empty() {
95 dem.detector_coords
96 .insert(detector.index, detector.coordinates.clone());
97 }
98 }
99
100 let mut instruction_index = 0;
102 for instruction in &circuit.instructions {
103 match instruction {
104 StimInstruction::XError {
105 probability,
106 qubits,
107 }
108 | StimInstruction::YError {
109 probability,
110 qubits,
111 }
112 | StimInstruction::ZError {
113 probability,
114 qubits,
115 } => {
116 let error_type = match instruction {
117 StimInstruction::XError { .. } => "X",
118 StimInstruction::YError { .. } => "Y",
119 _ => "Z",
120 };
121
122 for &qubit in qubits {
123 let dem_error = Self::analyze_single_qubit_error(
124 circuit,
125 instruction_index,
126 error_type,
127 qubit,
128 *probability,
129 &clean_result.detector_values,
130 &clean_result.observable_values,
131 )?;
132
133 if !dem_error.detector_targets.is_empty()
134 || !dem_error.observable_targets.is_empty()
135 {
136 dem.errors.push(dem_error);
137 }
138 }
139 }
140
141 StimInstruction::Depolarize1 {
142 probability,
143 qubits,
144 } => {
145 let per_pauli_prob = probability / 3.0;
147 for &qubit in qubits {
148 for error_type in &["X", "Y", "Z"] {
149 let dem_error = Self::analyze_single_qubit_error(
150 circuit,
151 instruction_index,
152 error_type,
153 qubit,
154 per_pauli_prob,
155 &clean_result.detector_values,
156 &clean_result.observable_values,
157 )?;
158
159 if !dem_error.detector_targets.is_empty()
160 || !dem_error.observable_targets.is_empty()
161 {
162 dem.errors.push(dem_error);
163 }
164 }
165 }
166 }
167
168 StimInstruction::CorrelatedError {
169 probability,
170 targets,
171 }
172 | StimInstruction::ElseCorrelatedError {
173 probability,
174 targets,
175 } => {
176 let dem_error = Self::analyze_correlated_error(
177 circuit,
178 instruction_index,
179 targets,
180 *probability,
181 &clean_result.detector_values,
182 &clean_result.observable_values,
183 )?;
184
185 if !dem_error.detector_targets.is_empty()
186 || !dem_error.observable_targets.is_empty()
187 {
188 dem.errors.push(dem_error);
189 }
190 }
191
192 StimInstruction::Depolarize2 {
193 probability,
194 qubit_pairs,
195 } => {
196 let per_pauli_prob = probability / 15.0;
198 for &(q1, q2) in qubit_pairs {
199 for p1 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
200 for p2 in &[PauliType::I, PauliType::X, PauliType::Y, PauliType::Z] {
201 if *p1 == PauliType::I && *p2 == PauliType::I {
202 continue; }
204 let targets = vec![
205 PauliTarget {
206 pauli: *p1,
207 qubit: q1,
208 },
209 PauliTarget {
210 pauli: *p2,
211 qubit: q2,
212 },
213 ];
214 let dem_error = Self::analyze_correlated_error(
215 circuit,
216 instruction_index,
217 &targets,
218 per_pauli_prob,
219 &clean_result.detector_values,
220 &clean_result.observable_values,
221 )?;
222
223 if !dem_error.detector_targets.is_empty()
224 || !dem_error.observable_targets.is_empty()
225 {
226 dem.errors.push(dem_error);
227 }
228 }
229 }
230 }
231 }
232
233 _ => {}
234 }
235 instruction_index += 1;
236 }
237
238 dem.merge_duplicate_errors();
240
241 Ok(dem)
242 }
243
244 fn analyze_single_qubit_error(
246 circuit: &StimCircuit,
247 instruction_index: usize,
248 error_type: &str,
249 qubit: usize,
250 probability: f64,
251 clean_detectors: &[bool],
252 clean_observables: &[bool],
253 ) -> Result<DEMError> {
254 let mut modified_circuit = circuit.clone();
256
257 let mut detector_targets = Vec::new();
265 let mut observable_targets = Vec::new();
266
267 let mut executor = StimExecutor::from_circuit(circuit);
269 Ok(DEMError {
273 probability,
274 detector_targets,
275 observable_targets,
276 source_location: Some(ErrorLocation {
277 instruction_index,
278 error_type: format!("{}_ERROR", error_type),
279 qubits: vec![qubit],
280 }),
281 })
282 }
283
284 fn analyze_correlated_error(
286 circuit: &StimCircuit,
287 instruction_index: usize,
288 targets: &[PauliTarget],
289 probability: f64,
290 clean_detectors: &[bool],
291 clean_observables: &[bool],
292 ) -> Result<DEMError> {
293 let qubits: Vec<usize> = targets.iter().map(|t| t.qubit).collect();
294 let error_type = targets
295 .iter()
296 .map(|t| format!("{:?}{}", t.pauli, t.qubit))
297 .collect::<Vec<_>>()
298 .join(" ");
299
300 let mut detector_targets = Vec::new();
301 let mut observable_targets = Vec::new();
302
303 Ok(DEMError {
307 probability,
308 detector_targets,
309 observable_targets,
310 source_location: Some(ErrorLocation {
311 instruction_index,
312 error_type: format!("CORRELATED_ERROR {}", error_type),
313 qubits,
314 }),
315 })
316 }
317
318 fn merge_duplicate_errors(&mut self) {
320 let mut merged: HashMap<(Vec<usize>, Vec<usize>), DEMError> = HashMap::new();
321
322 for error in self.errors.drain(..) {
323 let key = (
324 error.detector_targets.clone(),
325 error.observable_targets.clone(),
326 );
327
328 if let Some(existing) = merged.get_mut(&key) {
329 existing.probability += error.probability;
332 } else {
333 merged.insert(key, error);
334 }
335 }
336
337 self.errors = merged.into_values().collect();
338 }
339
340 #[must_use]
342 pub fn to_dem_string(&self) -> String {
343 let mut output = String::new();
344
345 output.push_str("# Detector Error Model\n");
347 output.push_str(&format!(
348 "# {} detectors, {} observables\n",
349 self.num_detectors, self.num_observables
350 ));
351 output.push('\n');
352
353 let mut sorted_detectors: Vec<_> = self.detector_coords.iter().collect();
355 sorted_detectors.sort_by_key(|(k, _)| *k);
356 for (det_idx, coords) in sorted_detectors {
357 output.push_str(&format!(
358 "detector D{} ({}) # coordinates: {:?}\n",
359 det_idx,
360 coords
361 .iter()
362 .map(|c| c.to_string())
363 .collect::<Vec<_>>()
364 .join(", "),
365 coords
366 ));
367 }
368 if !self.detector_coords.is_empty() {
369 output.push('\n');
370 }
371
372 for error in &self.errors {
374 if error.probability > 0.0 {
375 output.push_str(&format!("error({:.6})", error.probability));
376
377 for &det in &error.detector_targets {
378 output.push_str(&format!(" D{}", det));
379 }
380
381 for &obs in &error.observable_targets {
382 output.push_str(&format!(" L{}", obs));
383 }
384
385 if let Some(ref loc) = error.source_location {
386 output.push_str(&format!(" # {}", loc.error_type));
387 }
388
389 output.push('\n');
390 }
391 }
392
393 output
394 }
395
396 pub fn from_dem_string(s: &str) -> Result<Self> {
398 let mut num_detectors = 0;
399 let mut num_observables = 0;
400 let mut errors = Vec::new();
401 let mut detector_coords = HashMap::new();
402
403 for line in s.lines() {
404 let line = line.trim();
405
406 if line.is_empty() || line.starts_with('#') {
408 continue;
409 }
410
411 if line.starts_with("detector") {
413 continue;
416 }
417
418 if line.starts_with("error(") {
420 let (prob_str, rest) = line
421 .strip_prefix("error(")
422 .and_then(|s| s.split_once(')'))
423 .ok_or_else(|| {
424 SimulatorError::InvalidOperation("Invalid error line format".to_string())
425 })?;
426
427 let probability = prob_str.parse::<f64>().map_err(|_| {
428 SimulatorError::InvalidOperation(format!("Invalid probability: {}", prob_str))
429 })?;
430
431 let mut detector_targets = Vec::new();
432 let mut observable_targets = Vec::new();
433
434 let targets_str = rest.split('#').next().unwrap_or(rest);
436 for token in targets_str.split_whitespace() {
437 if let Some(stripped) = token.strip_prefix('D') {
438 let idx = stripped.parse::<usize>().map_err(|_| {
439 SimulatorError::InvalidOperation(format!("Invalid detector: {}", token))
440 })?;
441 detector_targets.push(idx);
442 num_detectors = num_detectors.max(idx + 1);
443 } else if let Some(stripped) = token.strip_prefix('L') {
444 let idx = stripped.parse::<usize>().map_err(|_| {
445 SimulatorError::InvalidOperation(format!(
446 "Invalid observable: {}",
447 token
448 ))
449 })?;
450 observable_targets.push(idx);
451 num_observables = num_observables.max(idx + 1);
452 }
453 }
454
455 errors.push(DEMError {
456 probability,
457 detector_targets,
458 observable_targets,
459 source_location: None,
460 });
461 }
462 }
463
464 Ok(Self {
465 num_detectors,
466 num_observables,
467 errors,
468 coordinate_shifts: Vec::new(),
469 detector_coords,
470 })
471 }
472
473 pub fn sample(&self) -> (Vec<bool>, Vec<bool>) {
477 use scirs2_core::random::prelude::*;
478 let mut rng = thread_rng();
479
480 let mut detector_flips = vec![false; self.num_detectors];
481 let mut observable_flips = vec![false; self.num_observables];
482
483 for error in &self.errors {
484 if rng.gen_bool(error.probability.min(1.0)) {
485 for &det in &error.detector_targets {
487 if det < detector_flips.len() {
488 detector_flips[det] ^= true;
489 }
490 }
491 for &obs in &error.observable_targets {
492 if obs < observable_flips.len() {
493 observable_flips[obs] ^= true;
494 }
495 }
496 }
497 }
498
499 (detector_flips, observable_flips)
500 }
501
502 pub fn sample_batch(&self, num_shots: usize) -> Vec<(Vec<bool>, Vec<bool>)> {
504 (0..num_shots).map(|_| self.sample()).collect()
505 }
506
507 #[must_use]
509 pub fn total_error_probability(&self) -> f64 {
510 self.errors.iter().map(|e| e.probability).sum()
511 }
512
513 #[must_use]
515 pub fn num_error_mechanisms(&self) -> usize {
516 self.errors.len()
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an error mechanism that carries no source location.
    fn mechanism(probability: f64, detectors: &[usize], observables: &[usize]) -> DEMError {
        DEMError {
            probability,
            detector_targets: detectors.to_vec(),
            observable_targets: observables.to_vec(),
            source_location: None,
        }
    }

    /// Absolute-tolerance comparison used for probability checks.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_empty_dem() {
        let model = DetectorErrorModel::new(5, 2);
        assert_eq!((model.num_detectors, model.num_observables), (5, 2));
        assert!(model.errors.is_empty());
    }

    #[test]
    fn test_dem_to_string() {
        let mut model = DetectorErrorModel::new(2, 1);
        model.errors.push(mechanism(0.01, &[0, 1], &[0]));

        let rendered = model.to_dem_string();
        for expected in &["error(0.010000)", "D0", "D1", "L0"] {
            assert!(
                rendered.contains(*expected),
                "missing {:?} in:\n{}",
                expected,
                rendered
            );
        }
    }

    #[test]
    fn test_dem_parse_roundtrip() {
        let text = "
            # Test DEM
            error(0.01) D0 D1
            error(0.02) D2 L0
        ";

        let model = DetectorErrorModel::from_dem_string(text).unwrap();
        assert_eq!(model.num_detectors, 3);
        assert_eq!(model.num_observables, 1);
        assert_eq!(model.errors.len(), 2);

        let first = &model.errors[0];
        assert!(close(first.probability, 0.01));
        assert_eq!(first.detector_targets, vec![0, 1]);

        let second = &model.errors[1];
        assert!(close(second.probability, 0.02));
        assert_eq!(second.detector_targets, vec![2]);
        assert_eq!(second.observable_targets, vec![0]);
    }

    #[test]
    fn test_dem_sample() {
        let mut model = DetectorErrorModel::new(3, 1);
        // A certain mechanism on D0 must flip exactly that detector.
        model.errors.push(mechanism(1.0, &[0], &[]));

        let (flips, _) = model.sample();
        assert_eq!(flips, vec![true, false, false]);
    }

    #[test]
    fn test_from_circuit_basic() {
        let source = "
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        ";

        let circuit = StimCircuit::from_str(source).unwrap();
        let model = DetectorErrorModel::from_circuit(&circuit).unwrap();

        assert_eq!((model.num_detectors, model.num_observables), (1, 0));
    }

    #[test]
    fn test_dem_total_probability() {
        let mut model = DetectorErrorModel::new(2, 0);
        model.errors.push(mechanism(0.01, &[0], &[]));
        model.errors.push(mechanism(0.02, &[1], &[]));

        assert!(close(model.total_error_probability(), 0.03));
    }
}