quantrs2_sim/
stim_sampler.rs

1//! Stim circuit sampler for efficient batch sampling
2//!
3//! This module provides optimized sampling capabilities:
4//! - `compile_sampler()` - Compile a circuit for repeated sampling
5//! - `sample()` / `sample_batch()` - Efficient sampling methods
6//! - Bit-packed output formats
7//!
8//! ## Example
9//!
10//! ```ignore
11//! let circuit = StimCircuit::from_str("H 0\nCNOT 0 1\nM 0 1").unwrap();
//! let sampler = DetectorSampler::compile(&circuit).unwrap();
//! let samples = sampler.sample_batch(1000).unwrap();
14//! ```
15
16use crate::error::{Result, SimulatorError};
17use crate::stim_dem::DetectorErrorModel;
18use crate::stim_executor::{ExecutionResult, StimExecutor};
19use crate::stim_parser::StimCircuit;
20use scirs2_core::random::prelude::*;
21
22/// Compiled circuit for efficient sampling
23#[derive(Debug, Clone)]
24pub struct CompiledStimCircuit {
25    /// Original circuit
26    circuit: StimCircuit,
27    /// Number of qubits
28    pub num_qubits: usize,
29    /// Number of measurements
30    pub num_measurements: usize,
31    /// Number of detectors
32    pub num_detectors: usize,
33    /// Number of observables
34    pub num_observables: usize,
35    /// Pre-computed DEM for fast sampling (optional)
36    dem: Option<DetectorErrorModel>,
37}
38
39impl CompiledStimCircuit {
40    /// Compile a Stim circuit for sampling
41    pub fn compile(circuit: &StimCircuit) -> Result<Self> {
42        // Run once to determine counts
43        let mut executor = StimExecutor::from_circuit(circuit);
44        let result = executor.execute(circuit)?;
45
46        Ok(Self {
47            circuit: circuit.clone(),
48            num_qubits: circuit.num_qubits,
49            num_measurements: result.num_measurements,
50            num_detectors: result.num_detectors,
51            num_observables: result.num_observables,
52            dem: None,
53        })
54    }
55
56    /// Compile with DEM for faster error-only sampling
57    pub fn compile_with_dem(circuit: &StimCircuit) -> Result<Self> {
58        let mut compiled = Self::compile(circuit)?;
59        compiled.dem = Some(DetectorErrorModel::from_circuit(circuit)?);
60        Ok(compiled)
61    }
62
63    /// Get the underlying circuit
64    #[must_use]
65    pub fn circuit(&self) -> &StimCircuit {
66        &self.circuit
67    }
68
69    /// Check if DEM is available
70    #[must_use]
71    pub fn has_dem(&self) -> bool {
72        self.dem.is_some()
73    }
74}
75
76/// Detector sampler for efficient batch sampling
77#[derive(Debug)]
78pub struct DetectorSampler {
79    /// Compiled circuit
80    compiled: CompiledStimCircuit,
81}
82
83impl DetectorSampler {
84    /// Create a new detector sampler from a compiled circuit
85    #[must_use]
86    pub fn new(compiled: CompiledStimCircuit) -> Self {
87        Self { compiled }
88    }
89
90    /// Compile and create a sampler from a circuit
91    pub fn compile(circuit: &StimCircuit) -> Result<Self> {
92        Ok(Self::new(CompiledStimCircuit::compile(circuit)?))
93    }
94
95    /// Compile with DEM for faster sampling
96    pub fn compile_with_dem(circuit: &StimCircuit) -> Result<Self> {
97        Ok(Self::new(CompiledStimCircuit::compile_with_dem(circuit)?))
98    }
99
100    /// Sample once, returning full execution result
101    pub fn sample(&self) -> Result<ExecutionResult> {
102        let mut executor = StimExecutor::from_circuit(&self.compiled.circuit);
103        executor.execute(&self.compiled.circuit)
104    }
105
106    /// Sample once, returning only detector values
107    pub fn sample_detectors(&self) -> Result<Vec<bool>> {
108        let result = self.sample()?;
109        Ok(result.detector_values)
110    }
111
112    /// Sample once, returning only measurement values
113    pub fn sample_measurements(&self) -> Result<Vec<bool>> {
114        let result = self.sample()?;
115        Ok(result.measurement_record)
116    }
117
118    /// Sample batch with full results
119    pub fn sample_batch(&self, num_shots: usize) -> Result<Vec<ExecutionResult>> {
120        (0..num_shots).map(|_| self.sample()).collect()
121    }
122
123    /// Sample batch, returning only detector values
124    pub fn sample_batch_detectors(&self, num_shots: usize) -> Result<Vec<Vec<bool>>> {
125        (0..num_shots).map(|_| self.sample_detectors()).collect()
126    }
127
128    /// Sample batch, returning bit-packed detector values
129    pub fn sample_batch_detectors_packed(&self, num_shots: usize) -> Result<Vec<Vec<u8>>> {
130        let samples = self.sample_batch_detectors(num_shots)?;
131        Ok(samples.into_iter().map(|s| pack_bits(&s)).collect())
132    }
133
134    /// Sample batch, returning bit-packed measurement values
135    pub fn sample_batch_measurements_packed(&self, num_shots: usize) -> Result<Vec<Vec<u8>>> {
136        let samples: Vec<Vec<bool>> = (0..num_shots)
137            .map(|_| self.sample_measurements())
138            .collect::<Result<Vec<_>>>()?;
139        Ok(samples.into_iter().map(|s| pack_bits(&s)).collect())
140    }
141
142    /// Get statistics from samples
143    pub fn sample_statistics(&self, num_shots: usize) -> Result<SampleStatistics> {
144        let samples = self.sample_batch(num_shots)?;
145
146        let mut detector_fire_counts = vec![0usize; self.compiled.num_detectors];
147        let mut measurement_one_counts = vec![0usize; self.compiled.num_measurements];
148        let mut total_detector_fires = 0;
149
150        for result in &samples {
151            for (i, &val) in result.detector_values.iter().enumerate() {
152                if val {
153                    detector_fire_counts[i] += 1;
154                    total_detector_fires += 1;
155                }
156            }
157            for (i, &val) in result.measurement_record.iter().enumerate() {
158                if val {
159                    measurement_one_counts[i] += 1;
160                }
161            }
162        }
163
164        Ok(SampleStatistics {
165            num_shots,
166            num_detectors: self.compiled.num_detectors,
167            num_measurements: self.compiled.num_measurements,
168            detector_fire_counts,
169            measurement_one_counts,
170            total_detector_fires,
171            logical_error_rate: 0.0, // Would need observable tracking
172        })
173    }
174
175    /// Get number of detectors
176    #[must_use]
177    pub fn num_detectors(&self) -> usize {
178        self.compiled.num_detectors
179    }
180
181    /// Get number of measurements
182    #[must_use]
183    pub fn num_measurements(&self) -> usize {
184        self.compiled.num_measurements
185    }
186
187    /// Get number of qubits
188    #[must_use]
189    pub fn num_qubits(&self) -> usize {
190        self.compiled.num_qubits
191    }
192}
193
194/// Statistics from batch sampling
195#[derive(Debug, Clone)]
196pub struct SampleStatistics {
197    /// Number of shots taken
198    pub num_shots: usize,
199    /// Number of detectors
200    pub num_detectors: usize,
201    /// Number of measurements per shot
202    pub num_measurements: usize,
203    /// Number of times each detector fired
204    pub detector_fire_counts: Vec<usize>,
205    /// Number of times each measurement was 1
206    pub measurement_one_counts: Vec<usize>,
207    /// Total number of detector fires across all shots
208    pub total_detector_fires: usize,
209    /// Estimated logical error rate (if observables tracked)
210    pub logical_error_rate: f64,
211}
212
213impl SampleStatistics {
214    /// Get the fire rate for a specific detector
215    #[must_use]
216    pub fn detector_fire_rate(&self, detector_idx: usize) -> f64 {
217        if detector_idx < self.detector_fire_counts.len() && self.num_shots > 0 {
218            self.detector_fire_counts[detector_idx] as f64 / self.num_shots as f64
219        } else {
220            0.0
221        }
222    }
223
224    /// Get the average number of detector fires per shot
225    #[must_use]
226    pub fn average_detector_fires(&self) -> f64 {
227        if self.num_shots > 0 {
228            self.total_detector_fires as f64 / self.num_shots as f64
229        } else {
230            0.0
231        }
232    }
233
234    /// Get the probability of any detector firing
235    #[must_use]
236    pub fn any_detector_fire_rate(&self) -> f64 {
237        let shots_with_fire = self.detector_fire_counts.iter().filter(|&&c| c > 0).count();
238        if self.num_shots > 0 {
239            shots_with_fire as f64 / self.num_shots as f64
240        } else {
241            0.0
242        }
243    }
244}
245
/// Pack booleans into bytes, eight per byte, least-significant bit first.
///
/// `bits[0]` lands in bit 0 of the first byte. A trailing partial chunk leaves
/// its unused high bits zero, and an empty slice yields an empty `Vec`.
fn pack_bits(bits: &[bool]) -> Vec<u8> {
    bits.chunks(8)
        .map(|chunk| {
            chunk
                .iter()
                .zip(0u32..)
                .filter(|&(&set, _)| set)
                .fold(0u8, |acc, (_, shift)| acc | (1u8 << shift))
        })
        .collect()
}
260
/// Unpack bytes into booleans, least-significant bit of each byte first.
///
/// This is the inverse of `pack_bits`. At most `num_bits` values are
/// returned. If `bytes` holds fewer than `num_bits` bits, only the
/// `8 * bytes.len()` available bits are returned, with no padding. Bytes past
/// the first `num_bits` bits are never read.
// Only the unit tests call this. Without the attribute, non-test builds emit a
// `dead_code` warning, which is fatal under `-D warnings`.
#[cfg_attr(not(test), allow(dead_code))]
fn unpack_bits(bytes: &[u8], num_bits: usize) -> Vec<bool> {
    // Size the buffer by what can actually be produced, not by the requested count.
    let produced = num_bits.min(bytes.len().saturating_mul(8));
    let mut bits = Vec::with_capacity(produced);
    bits.extend(
        bytes
            .iter()
            .flat_map(|&byte| (0..8).map(move |shift| (byte >> shift) & 1 == 1))
            .take(produced),
    );
    bits
}
274
275/// Compile a circuit for sampling (convenience function)
276pub fn compile_sampler(circuit: &StimCircuit) -> Result<DetectorSampler> {
277    DetectorSampler::compile(circuit)
278}
279
280/// Compile a circuit with DEM for faster sampling (convenience function)
281pub fn compile_sampler_with_dem(circuit: &StimCircuit) -> Result<DetectorSampler> {
282    DetectorSampler::compile_with_dem(circuit)
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_sampler() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        assert_eq!(sampler.num_qubits(), 2);
        assert_eq!(sampler.num_measurements(), 2);
        assert_eq!(sampler.num_detectors(), 1);
    }

    #[test]
    fn test_sample_basic() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let result = sampler.sample().unwrap();
        assert_eq!(result.measurement_record.len(), 2);
        // Bell state: measurements should be correlated
        assert_eq!(result.measurement_record[0], result.measurement_record[1]);
    }

    #[test]
    fn test_sample_batch() {
        let circuit_str = r#"
            M 0
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let results = sampler.sample_batch(10).unwrap();
        assert_eq!(results.len(), 10);
        // |0⟩ state should always give 0
        for result in &results {
            assert!(!result.measurement_record[0]);
        }
    }

    #[test]
    fn test_sample_detectors() {
        let circuit_str = r#"
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let detectors = sampler.sample_detectors().unwrap();
        assert_eq!(detectors.len(), 1);
        assert!(!detectors[0]); // |00⟩: XOR = 0, no fire
    }

    #[test]
    fn test_sample_batch_packed() {
        let circuit_str = r#"
            M 0 1 2 3 4 5 6 7 8
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let packed = sampler.sample_batch_measurements_packed(5).unwrap();
        assert_eq!(packed.len(), 5);
        // 9 measurements = 2 bytes per shot, all zero for |0...0⟩
        for shot in &packed {
            assert_eq!(shot, &vec![0u8, 0u8]);
        }
    }

    #[test]
    fn test_sample_batch_detectors_packed() {
        let circuit_str = r#"
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let packed = sampler.sample_batch_detectors_packed(3).unwrap();
        assert_eq!(packed.len(), 3);
        // One detector fits in one byte; |00⟩ never fires it.
        for shot in &packed {
            assert_eq!(shot, &vec![0u8]);
        }
    }

    #[test]
    fn test_sample_statistics() {
        let circuit_str = r#"
            M 0
            DETECTOR rec[-1]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler(&circuit).unwrap();

        let stats = sampler.sample_statistics(100).unwrap();
        assert_eq!(stats.num_shots, 100);
        assert_eq!(stats.num_detectors, 1);
        assert_eq!(stats.num_measurements, 1);

        // |0⟩ always measures 0, so the detector never fires.
        assert_eq!(stats.detector_fire_counts, vec![0]);
        assert_eq!(stats.measurement_one_counts, vec![0]);
        assert_eq!(stats.total_detector_fires, 0);
        assert_eq!(stats.average_detector_fires(), 0.0);
        assert_eq!(stats.detector_fire_rate(0), 0.0);
        // Out-of-range detector index reports 0.0 rather than panicking.
        assert_eq!(stats.detector_fire_rate(1), 0.0);
        assert_eq!(stats.any_detector_fire_rate(), 0.0);
    }

    #[test]
    fn test_pack_unpack_bits() {
        let bits = vec![true, false, true, true, false, false, true, false, true];
        let packed = pack_bits(&bits);
        let unpacked = unpack_bits(&packed, bits.len());
        assert_eq!(bits, unpacked);
    }

    #[test]
    fn test_pack_bits_layout() {
        assert!(pack_bits(&[]).is_empty());
        // LSB first: bits[0] -> bit 0.
        assert_eq!(pack_bits(&[true, false, true]), vec![0b0000_0101]);
        assert_eq!(pack_bits(&[true; 8]), vec![0xFF]);
    }

    #[test]
    fn test_unpack_bits_limits() {
        assert!(unpack_bits(&[], 0).is_empty());
        assert_eq!(unpack_bits(&[0xFF], 3), vec![true; 3]);
        // Asking for more bits than the bytes hold returns only what is available.
        assert_eq!(unpack_bits(&[0x01], 16).len(), 8);
    }

    #[test]
    fn test_compile_with_dem() {
        let circuit_str = r#"
            H 0
            CNOT 0 1
            M 0 1
            DETECTOR rec[-1] rec[-2]
        "#;

        let circuit = StimCircuit::from_str(circuit_str).unwrap();
        let sampler = compile_sampler_with_dem(&circuit).unwrap();

        assert!(sampler.compiled.has_dem());
    }
}