quantrs2_sim/
stim_sampler.rs1use crate::error::{Result, SimulatorError};
17use crate::stim_dem::DetectorErrorModel;
18use crate::stim_executor::{ExecutionResult, StimExecutor};
19use crate::stim_parser::StimCircuit;
20use scirs2_core::random::prelude::*;
21
22#[derive(Debug, Clone)]
24pub struct CompiledStimCircuit {
25 circuit: StimCircuit,
27 pub num_qubits: usize,
29 pub num_measurements: usize,
31 pub num_detectors: usize,
33 pub num_observables: usize,
35 dem: Option<DetectorErrorModel>,
37}
38
39impl CompiledStimCircuit {
40 pub fn compile(circuit: &StimCircuit) -> Result<Self> {
42 let mut executor = StimExecutor::from_circuit(circuit);
44 let result = executor.execute(circuit)?;
45
46 Ok(Self {
47 circuit: circuit.clone(),
48 num_qubits: circuit.num_qubits,
49 num_measurements: result.num_measurements,
50 num_detectors: result.num_detectors,
51 num_observables: result.num_observables,
52 dem: None,
53 })
54 }
55
56 pub fn compile_with_dem(circuit: &StimCircuit) -> Result<Self> {
58 let mut compiled = Self::compile(circuit)?;
59 compiled.dem = Some(DetectorErrorModel::from_circuit(circuit)?);
60 Ok(compiled)
61 }
62
63 #[must_use]
65 pub fn circuit(&self) -> &StimCircuit {
66 &self.circuit
67 }
68
69 #[must_use]
71 pub fn has_dem(&self) -> bool {
72 self.dem.is_some()
73 }
74}
75
76#[derive(Debug)]
78pub struct DetectorSampler {
79 compiled: CompiledStimCircuit,
81}
82
83impl DetectorSampler {
84 #[must_use]
86 pub fn new(compiled: CompiledStimCircuit) -> Self {
87 Self { compiled }
88 }
89
90 pub fn compile(circuit: &StimCircuit) -> Result<Self> {
92 Ok(Self::new(CompiledStimCircuit::compile(circuit)?))
93 }
94
95 pub fn compile_with_dem(circuit: &StimCircuit) -> Result<Self> {
97 Ok(Self::new(CompiledStimCircuit::compile_with_dem(circuit)?))
98 }
99
100 pub fn sample(&self) -> Result<ExecutionResult> {
102 let mut executor = StimExecutor::from_circuit(&self.compiled.circuit);
103 executor.execute(&self.compiled.circuit)
104 }
105
106 pub fn sample_detectors(&self) -> Result<Vec<bool>> {
108 let result = self.sample()?;
109 Ok(result.detector_values)
110 }
111
112 pub fn sample_measurements(&self) -> Result<Vec<bool>> {
114 let result = self.sample()?;
115 Ok(result.measurement_record)
116 }
117
118 pub fn sample_batch(&self, num_shots: usize) -> Result<Vec<ExecutionResult>> {
120 (0..num_shots).map(|_| self.sample()).collect()
121 }
122
123 pub fn sample_batch_detectors(&self, num_shots: usize) -> Result<Vec<Vec<bool>>> {
125 (0..num_shots).map(|_| self.sample_detectors()).collect()
126 }
127
128 pub fn sample_batch_detectors_packed(&self, num_shots: usize) -> Result<Vec<Vec<u8>>> {
130 let samples = self.sample_batch_detectors(num_shots)?;
131 Ok(samples.into_iter().map(|s| pack_bits(&s)).collect())
132 }
133
134 pub fn sample_batch_measurements_packed(&self, num_shots: usize) -> Result<Vec<Vec<u8>>> {
136 let samples: Vec<Vec<bool>> = (0..num_shots)
137 .map(|_| self.sample_measurements())
138 .collect::<Result<Vec<_>>>()?;
139 Ok(samples.into_iter().map(|s| pack_bits(&s)).collect())
140 }
141
142 pub fn sample_statistics(&self, num_shots: usize) -> Result<SampleStatistics> {
144 let samples = self.sample_batch(num_shots)?;
145
146 let mut detector_fire_counts = vec![0usize; self.compiled.num_detectors];
147 let mut measurement_one_counts = vec![0usize; self.compiled.num_measurements];
148 let mut total_detector_fires = 0;
149
150 for result in &samples {
151 for (i, &val) in result.detector_values.iter().enumerate() {
152 if val {
153 detector_fire_counts[i] += 1;
154 total_detector_fires += 1;
155 }
156 }
157 for (i, &val) in result.measurement_record.iter().enumerate() {
158 if val {
159 measurement_one_counts[i] += 1;
160 }
161 }
162 }
163
164 Ok(SampleStatistics {
165 num_shots,
166 num_detectors: self.compiled.num_detectors,
167 num_measurements: self.compiled.num_measurements,
168 detector_fire_counts,
169 measurement_one_counts,
170 total_detector_fires,
171 logical_error_rate: 0.0, })
173 }
174
175 #[must_use]
177 pub fn num_detectors(&self) -> usize {
178 self.compiled.num_detectors
179 }
180
181 #[must_use]
183 pub fn num_measurements(&self) -> usize {
184 self.compiled.num_measurements
185 }
186
187 #[must_use]
189 pub fn num_qubits(&self) -> usize {
190 self.compiled.num_qubits
191 }
192}
193
194#[derive(Debug, Clone)]
196pub struct SampleStatistics {
197 pub num_shots: usize,
199 pub num_detectors: usize,
201 pub num_measurements: usize,
203 pub detector_fire_counts: Vec<usize>,
205 pub measurement_one_counts: Vec<usize>,
207 pub total_detector_fires: usize,
209 pub logical_error_rate: f64,
211}
212
213impl SampleStatistics {
214 #[must_use]
216 pub fn detector_fire_rate(&self, detector_idx: usize) -> f64 {
217 if detector_idx < self.detector_fire_counts.len() && self.num_shots > 0 {
218 self.detector_fire_counts[detector_idx] as f64 / self.num_shots as f64
219 } else {
220 0.0
221 }
222 }
223
224 #[must_use]
226 pub fn average_detector_fires(&self) -> f64 {
227 if self.num_shots > 0 {
228 self.total_detector_fires as f64 / self.num_shots as f64
229 } else {
230 0.0
231 }
232 }
233
234 #[must_use]
236 pub fn any_detector_fire_rate(&self) -> f64 {
237 let shots_with_fire = self.detector_fire_counts.iter().filter(|&&c| c > 0).count();
238 if self.num_shots > 0 {
239 shots_with_fire as f64 / self.num_shots as f64
240 } else {
241 0.0
242 }
243 }
244}
245
/// Packs booleans into bytes, eight per byte, least-significant bit first.
///
/// Input bit `i` lands in output byte `i / 8` at bit position `i % 8`. A
/// trailing partial byte leaves its unused high bits at zero, and an empty
/// slice produces an empty vector.
fn pack_bits(bits: &[bool]) -> Vec<u8> {
    bits.chunks(8)
        .map(|octet| {
            octet
                .iter()
                .enumerate()
                .filter(|&(_, &set)| set)
                .fold(0u8, |byte, (pos, _)| byte | (1u8 << pos))
        })
        .collect()
}
260
/// Unpacks bytes produced by `pack_bits` (eight bits per byte, LSB first) back
/// into booleans.
///
/// Bit `i` is read from byte `i / 8` at position `i % 8`. At most `num_bits`
/// bits are returned. If `bytes` holds fewer than `num_bits` bits, the result
/// is correspondingly shorter rather than padded.
// Currently only used by the unit tests; keep it available as the inverse of
// `pack_bits` without tripping the `dead_code` lint in non-test builds.
#[cfg_attr(not(test), allow(dead_code))]
fn unpack_bits(bytes: &[u8], num_bits: usize) -> Vec<bool> {
    // Never read past either the requested length or the supplied bytes.
    let available = bytes.len().saturating_mul(8).min(num_bits);
    (0..available)
        .map(|i| (bytes[i / 8] >> (i % 8)) & 1 == 1)
        .collect()
}
274
275pub fn compile_sampler(circuit: &StimCircuit) -> Result<DetectorSampler> {
277 DetectorSampler::compile(circuit)
278}
279
280pub fn compile_sampler_with_dem(circuit: &StimCircuit) -> Result<DetectorSampler> {
282 DetectorSampler::compile_with_dem(circuit)
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_sampler() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        assert_eq!(sampler.num_qubits(), 2);
        assert_eq!(sampler.num_measurements(), 2);
        assert_eq!(sampler.num_detectors(), 1);
    }

    #[test]
    fn test_sample_basic() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let result = sampler.sample().unwrap();
        assert_eq!(result.measurement_record.len(), 2);
        // H + CNOT prepares a Bell pair, so both outcomes must agree.
        assert_eq!(result.measurement_record[0], result.measurement_record[1]);
    }

    #[test]
    fn test_sample_batch() {
        let circuit_str = r#"
            M 0
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let results = sampler.sample_batch(10).unwrap();
        assert_eq!(results.len(), 10);
        for result in &results {
            // An untouched qubit is measured as 0.
            assert!(!result.measurement_record[0]);
        }
    }

    #[test]
    fn test_sample_detectors() {
        let circuit_str = r#"
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let detectors = sampler.sample_detectors().unwrap();
        assert_eq!(detectors.len(), 1);
        assert!(!detectors[0]);
    }

    #[test]
    fn test_sample_batch_packed() {
        let circuit_str = r#"
            M 0 1 2 3 4 5 6 7 8
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let packed = sampler.sample_batch_measurements_packed(5).unwrap();
        assert_eq!(packed.len(), 5);
        // Nine measurements need two bytes; every qubit stays at 0.
        for shot in &packed {
            assert_eq!(shot, &vec![0u8, 0u8]);
        }
    }

    #[test]
    fn test_sample_statistics() {
        let circuit_str = r#"
            M 0
            DETECTOR rec[-1]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let stats = sampler.sample_statistics(100).unwrap();
        assert_eq!(stats.num_shots, 100);
        assert_eq!(stats.num_detectors, 1);
        assert_eq!(stats.num_measurements, 1);
        // The measured qubit is always 0, so neither the detector nor the
        // measurement ever reports a 1.
        assert_eq!(stats.detector_fire_counts, vec![0]);
        assert_eq!(stats.measurement_one_counts, vec![0]);
        assert_eq!(stats.total_detector_fires, 0);
    }

    #[test]
    fn test_pack_unpack_bits() {
        let bits = vec![true, false, true, true, false, false, true, false, true];
        let packed = pack_bits(&bits);
        let unpacked = unpack_bits(&packed, bits.len());
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn test_pack_bits_is_lsb_first() {
        assert!(pack_bits(&[]).is_empty());
        assert_eq!(
            pack_bits(&[true, false, true, true, false, false, true, false, true]),
            vec![0b0100_1101, 0b0000_0001]
        );
    }

    #[test]
    fn test_unpack_bits_limits_output() {
        assert_eq!(unpack_bits(&[0xFF, 0xFF], 3), vec![true; 3]);
        // Fewer bytes than requested bits: only the available bits come back.
        assert_eq!(unpack_bits(&[0x01], 20).len(), 8);
    }

    #[test]
    fn test_compile_with_dem() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler_with_dem(&circuit).unwrap();

        assert!(sampler.compiled.has_dem());
    }
}
409}