Skip to main content

prism_q/sim/
shots.rs

1use std::collections::HashMap;
2
3use rand::Rng;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6
7use super::compiled;
8use super::Probabilities;
9
/// Result of a multi-shot simulation run.
///
/// Implements `PartialEq`/`Eq` so two results (for example from runs with the
/// same seed) can be compared directly; equality covers both the per-shot
/// records and the register width.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShotsResult {
    /// Classical measurement outcomes for each shot.
    /// `shots[i][j]` is the j-th classical bit from the i-th shot.
    pub shots: Vec<Vec<bool>>,
    /// Number of classical bits in the register. It determines how many
    /// characters a rendered bitstring has and how many `u64` words make up
    /// each key returned by `counts`.
    pub(crate) num_classical_bits: usize,
}
18
19impl ShotsResult {
20    /// Build a frequency histogram of measurement outcomes.
21    ///
22    /// Keys are packed `Vec<u64>` where bit `i` of word `i/64` corresponds
23    /// to classical bit `i`. Use [`bitstring`] to format keys for display.
24    pub fn counts(&self) -> HashMap<Vec<u64>, u64> {
25        let m_words = self.num_classical_bits.div_ceil(64).max(1);
26        let mut counts: HashMap<Vec<u64>, u64> = HashMap::new();
27        for shot in &self.shots {
28            let mut key = vec![0u64; m_words];
29            for (i, &b) in shot.iter().enumerate() {
30                if b {
31                    key[i / 64] |= 1u64 << (i % 64);
32                }
33            }
34            *counts.entry(key).or_insert(0) += 1;
35        }
36        counts
37    }
38
39    pub fn num_shots(&self) -> usize {
40        self.shots.len()
41    }
42
43    pub fn num_classical_bits(&self) -> usize {
44        self.num_classical_bits
45    }
46}
47
48impl std::fmt::Display for ShotsResult {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        let counts = self.counts();
51        let mut entries: Vec<_> = counts.into_iter().collect();
52        entries.sort_by_key(|e| std::cmp::Reverse(e.1));
53        for (bits, count) in &entries {
54            let bs = bitstring(bits, self.num_classical_bits);
55            writeln!(f, "{bs}: {count}")?;
56        }
57        Ok(())
58    }
59}
60
/// Render a packed outcome key (as produced by [`ShotsResult::counts`]) as a
/// string of `'0'`/`'1'` characters.
///
/// Character `i` shows classical bit `i`, read from bit `i % 64` of word
/// `i / 64`. Bit 0 of the first word is therefore the leftmost character.
/// Positions beyond the end of `key` render as `'0'`. The output always has
/// exactly `num_bits` characters.
pub fn bitstring(key: &[u64], num_bits: usize) -> String {
    (0..num_bits)
        .map(|pos| {
            let is_set = key
                .get(pos / 64)
                .is_some_and(|&word| (word >> (pos % 64)) & 1 == 1);
            if is_set {
                '1'
            } else {
                '0'
            }
        })
        .collect()
}
77
/// Build a normalized cumulative distribution from `probs`.
///
/// Entry `i` is `(p_0 + … + p_i) / total`, where `total` is the sum of all
/// probabilities. Dividing by the real total has two effects:
///
/// * Floating-point drift (or an unnormalized input) is spread proportionally
///   instead of being dumped onto the final index.
/// * Every entry from the last non-zero probability onward is exactly `1.0`,
///   because `x / x == 1.0`. A trailing zero-probability state therefore gets
///   an empty interval and can never be drawn by a uniform `r ∈ [0, 1)`.
///
/// If `total` is not a positive finite number (all zeros, NaN, infinity), the
/// raw running sum is returned with its last entry forced to `1.0`, so
/// sampling still yields a valid index.
///
/// Returns an empty vector for empty input.
fn build_cdf(probs: &[f64]) -> Vec<f64> {
    let mut cdf = Vec::with_capacity(probs.len());
    let mut acc = 0.0;
    for &p in probs {
        acc += p;
        cdf.push(acc);
    }
    let total = acc;
    if total.is_finite() && total > 0.0 {
        for c in &mut cdf {
            *c /= total;
        }
    } else if let Some(last) = cdf.last_mut() {
        *last = 1.0;
    }
    cdf
}
90
/// Map a uniform draw `r ∈ [0, 1)` to an index into `cdf`.
///
/// Returns the first index whose cumulative value is strictly greater than
/// `r`. Index `i` is chosen exactly when `cdf[i-1] <= r < cdf[i]`. With this
/// half-open interval, an entry with zero probability (`cdf[i] == cdf[i-1]`)
/// can never be returned, even when `r` lands exactly on a boundary.
///
/// The result is clamped to the last valid index, so a draw at or above the
/// final cumulative value stays in range. An empty `cdf` yields `0`.
fn sample_from_cdf(cdf: &[f64], r: f64) -> usize {
    cdf.partition_point(|&c| c <= r)
        .min(cdf.len().saturating_sub(1))
}
97
98pub(super) fn sample_shots(
99    probs: &Probabilities,
100    meas_map: &[(usize, usize)],
101    num_classical_bits: usize,
102    num_shots: usize,
103    seed: u64,
104) -> Vec<Vec<bool>> {
105    let mut rng = ChaCha8Rng::seed_from_u64(seed);
106
107    if meas_map.is_empty() {
108        return vec![vec![false; num_classical_bits]; num_shots];
109    }
110
111    let mut indices = Vec::with_capacity(num_shots);
112
113    match probs {
114        Probabilities::Dense(v) => {
115            let cdf = build_cdf(v);
116            for _ in 0..num_shots {
117                let r: f64 = rng.random();
118                indices.push(sample_from_cdf(&cdf, r));
119            }
120        }
121        Probabilities::Factored { blocks, .. } => {
122            let block_cdfs: Vec<Vec<f64>> = blocks.iter().map(|b| build_cdf(&b.probs)).collect();
123            for _ in 0..num_shots {
124                let mut global_idx = 0usize;
125                for (block, cdf) in blocks.iter().zip(block_cdfs.iter()) {
126                    let r: f64 = rng.random();
127                    let local_idx = sample_from_cdf(cdf, r);
128                    let mut m = block.mask;
129                    let mut bit = 0;
130                    while m != 0 {
131                        let pos = m.trailing_zeros() as usize;
132                        if local_idx & (1 << bit) != 0 {
133                            global_idx |= 1 << pos;
134                        }
135                        bit += 1;
136                        m &= m.wrapping_sub(1);
137                    }
138                }
139                indices.push(global_idx);
140            }
141        }
142    }
143
144    let mut flat = vec![false; num_shots * num_classical_bits];
145    for (s, &state_idx) in indices.iter().enumerate() {
146        let base = s * num_classical_bits;
147        for &(qubit, cbit) in meas_map {
148            flat[base + cbit] = (state_idx >> qubit) & 1 == 1;
149        }
150    }
151
152    let mut shots = Vec::with_capacity(num_shots);
153    for chunk in flat.chunks_exact(num_classical_bits) {
154        shots.push(chunk.to_vec());
155    }
156    shots
157}
158
159pub(super) fn packed_shots_to_classical_bits(
160    packed: &compiled::PackedShots,
161    meas_map: &[(usize, usize)],
162    num_classical_bits: usize,
163) -> Vec<Vec<bool>> {
164    let dense_identity_map = meas_map.len() == num_classical_bits
165        && meas_map
166            .iter()
167            .enumerate()
168            .all(|(idx, &(_, classical_bit))| idx == classical_bit);
169    if dense_identity_map {
170        return packed.to_shots();
171    }
172
173    let mut shots = vec![vec![false; num_classical_bits]; packed.num_shots()];
174    for (measurement, &(_, classical_bit)) in meas_map.iter().enumerate() {
175        if classical_bit >= num_classical_bits {
176            continue;
177        }
178        for (shot_idx, shot) in shots.iter_mut().enumerate() {
179            shot[classical_bit] = packed.get_bit(shot_idx, measurement);
180        }
181    }
182    shots
183}