1use std::collections::HashMap;
2
3use rand::Rng;
4use rand::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6
7use super::compiled;
8use super::Probabilities;
9
/// Per-shot classical measurement record of a sampling run.
///
/// Each entry of `shots` is one shot; within a shot, index `i` holds the value
/// of classical bit `i`. [`ShotsResult::counts`] packs index `i` into bit
/// `i % 64` of word `i / 64` when tallying outcomes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShotsResult {
    /// One `Vec<bool>` per shot, indexed by classical bit.
    pub shots: Vec<Vec<bool>>,
    /// Width of the classical register; determines the key width used by
    /// `counts` and the string width used when displaying the result.
    pub(crate) num_classical_bits: usize,
}
18
19impl ShotsResult {
20 pub fn counts(&self) -> HashMap<Vec<u64>, u64> {
25 let m_words = self.num_classical_bits.div_ceil(64).max(1);
26 let mut counts: HashMap<Vec<u64>, u64> = HashMap::new();
27 for shot in &self.shots {
28 let mut key = vec![0u64; m_words];
29 for (i, &b) in shot.iter().enumerate() {
30 if b {
31 key[i / 64] |= 1u64 << (i % 64);
32 }
33 }
34 *counts.entry(key).or_insert(0) += 1;
35 }
36 counts
37 }
38
39 pub fn num_shots(&self) -> usize {
40 self.shots.len()
41 }
42
43 pub fn num_classical_bits(&self) -> usize {
44 self.num_classical_bits
45 }
46}
47
48impl std::fmt::Display for ShotsResult {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let counts = self.counts();
51 let mut entries: Vec<_> = counts.into_iter().collect();
52 entries.sort_by_key(|e| std::cmp::Reverse(e.1));
53 for (bits, count) in &entries {
54 let bs = bitstring(bits, self.num_classical_bits);
55 writeln!(f, "{bs}: {count}")?;
56 }
57 Ok(())
58 }
59}
60
/// Renders the low `num_bits` bits of a packed key as a `'0'`/`'1'` string.
///
/// Bit `i` is read from bit `i % 64` of word `i / 64` of `key` and becomes
/// character `i` of the result, so bit 0 is the leftmost character. Positions
/// whose word lies beyond the end of `key` render as `'0'`.
pub fn bitstring(key: &[u64], num_bits: usize) -> String {
    (0..num_bits)
        .map(|pos| {
            let set = key
                .get(pos / 64)
                .is_some_and(|&word| (word >> (pos % 64)) & 1 == 1);
            if set {
                '1'
            } else {
                '0'
            }
        })
        .collect()
}
77
/// Builds the cumulative distribution of `probs` for inverse-transform sampling.
///
/// Entry `i` is `probs[0] + … + probs[i]`. Floating-point summation rarely
/// lands exactly on 1.0, so the CDF is pinned to exactly 1.0 starting at the
/// last outcome with positive probability. Pinning only the final entry, as a
/// naive implementation would, hands the rounding slack to trailing
/// zero-probability outcomes, which could then be sampled. If no entry is
/// positive, only the last entry is pinned. An empty input yields an empty CDF.
fn build_cdf(probs: &[f64]) -> Vec<f64> {
    let mut acc = 0.0;
    let mut cdf: Vec<f64> = probs
        .iter()
        .map(|&p| {
            acc += p;
            acc
        })
        .collect();
    let tail_start = probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or_else(|| probs.len().saturating_sub(1));
    for c in cdf.iter_mut().skip(tail_start) {
        *c = 1.0;
    }
    cdf
}
90
/// Maps a uniform draw `r ∈ [0, 1)` to an outcome index via the CDF.
///
/// Returns the first index `i` with `cdf[i] > r`. The strict comparison is
/// what makes outcome `i` selected with probability `cdf[i] - cdf[i - 1]`: an
/// exact hit such as `r == cdf[i]` belongs to the *next* bucket. Otherwise
/// `r == 0.0` or a repeated CDF value could select a zero-probability outcome.
/// The result is clamped to the last index, and an empty CDF yields `0` rather
/// than underflowing.
fn sample_from_cdf(cdf: &[f64], r: f64) -> usize {
    cdf.partition_point(|&c| c <= r)
        .min(cdf.len().saturating_sub(1))
}
97
98pub(super) fn sample_shots(
99 probs: &Probabilities,
100 meas_map: &[(usize, usize)],
101 num_classical_bits: usize,
102 num_shots: usize,
103 seed: u64,
104) -> Vec<Vec<bool>> {
105 let mut rng = ChaCha8Rng::seed_from_u64(seed);
106
107 if meas_map.is_empty() {
108 return vec![vec![false; num_classical_bits]; num_shots];
109 }
110
111 let mut indices = Vec::with_capacity(num_shots);
112
113 match probs {
114 Probabilities::Dense(v) => {
115 let cdf = build_cdf(v);
116 for _ in 0..num_shots {
117 let r: f64 = rng.random();
118 indices.push(sample_from_cdf(&cdf, r));
119 }
120 }
121 Probabilities::Factored { blocks, .. } => {
122 let block_cdfs: Vec<Vec<f64>> = blocks.iter().map(|b| build_cdf(&b.probs)).collect();
123 for _ in 0..num_shots {
124 let mut global_idx = 0usize;
125 for (block, cdf) in blocks.iter().zip(block_cdfs.iter()) {
126 let r: f64 = rng.random();
127 let local_idx = sample_from_cdf(cdf, r);
128 let mut m = block.mask;
129 let mut bit = 0;
130 while m != 0 {
131 let pos = m.trailing_zeros() as usize;
132 if local_idx & (1 << bit) != 0 {
133 global_idx |= 1 << pos;
134 }
135 bit += 1;
136 m &= m.wrapping_sub(1);
137 }
138 }
139 indices.push(global_idx);
140 }
141 }
142 }
143
144 let mut flat = vec![false; num_shots * num_classical_bits];
145 for (s, &state_idx) in indices.iter().enumerate() {
146 let base = s * num_classical_bits;
147 for &(qubit, cbit) in meas_map {
148 flat[base + cbit] = (state_idx >> qubit) & 1 == 1;
149 }
150 }
151
152 let mut shots = Vec::with_capacity(num_shots);
153 for chunk in flat.chunks_exact(num_classical_bits) {
154 shots.push(chunk.to_vec());
155 }
156 shots
157}
158
159pub(super) fn packed_shots_to_classical_bits(
160 packed: &compiled::PackedShots,
161 meas_map: &[(usize, usize)],
162 num_classical_bits: usize,
163) -> Vec<Vec<bool>> {
164 let dense_identity_map = meas_map.len() == num_classical_bits
165 && meas_map
166 .iter()
167 .enumerate()
168 .all(|(idx, &(_, classical_bit))| idx == classical_bit);
169 if dense_identity_map {
170 return packed.to_shots();
171 }
172
173 let mut shots = vec![vec![false; num_classical_bits]; packed.num_shots()];
174 for (measurement, &(_, classical_bit)) in meas_map.iter().enumerate() {
175 if classical_bit >= num_classical_bits {
176 continue;
177 }
178 for (shot_idx, shot) in shots.iter_mut().enumerate() {
179 shot[classical_bit] = packed.get_bit(shot_idx, measurement);
180 }
181 }
182 shots
183}