1mod accumulator;
2mod bts;
3pub(crate) mod parity;
4mod propagation;
5pub(crate) mod rng;
6#[cfg(test)]
7mod tests;
8
9use std::hash::{Hash, Hasher};
10
11use crate::circuit::{Circuit, Instruction, SmallVec};
12use crate::error::{PrismError, Result};
13#[cfg(feature = "gpu")]
14use crate::gpu::kernels::bts::GpuBtsCache;
15use crate::sim::ShotsResult;
16use rand::{RngCore, SeedableRng};
17use rand_chacha::ChaCha8Rng;
18
19use bts::{bts_batched, bts_single_pass, sample_bts_meas_major, BTS_BATCH_SHOTS};
20use rng::{binomial_sample, Xoshiro256PlusPlus};
21
22pub use accumulator::{
23 default_chunk_size, optimal_chunk_size, CorrelatorAccumulator, HistogramAccumulator,
24 MarginalsAccumulator, NullAccumulator, PauliExpectationAccumulator, ShotAccumulator,
25};
26pub use parity::ParityStats;
27use parity::{build_parity_blocks_if_useful, build_xor_dag_if_useful, minimize_flip_row_weight};
28pub(crate) use parity::{ParityBlock, ParityBlocks, SparseParity, XorDag};
29
30pub(crate) use propagation::batch_propagate_backward;
31pub(crate) use propagation::propagate_backward;
32use propagation::{
33 build_measurement_rows, colmajor_forward_sim, compute_reference_bits, rowmul_phase,
34 rowmul_phase_into,
35};
36
/// A multi-qubit Pauli operator in bit-packed (x, z) form.
///
/// Qubit `q` lives at bit `q % 64` of word `q / 64` in both vectors. A set `x`
/// bit means the operator has an X or Y component on that qubit; a set `z` bit
/// means a Z or Y component.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PauliVec {
    pub(crate) x: Vec<u64>,
    pub(crate) z: Vec<u64>,
}

impl PauliVec {
    /// The identity operator over `num_words * 64` qubits (all bits clear).
    pub fn new(num_words: usize) -> Self {
        let zeros = || vec![0u64; num_words];
        Self { x: zeros(), z: zeros() }
    }

    /// A single `Z` acting on `qubit`, identity elsewhere.
    ///
    /// Panics if `qubit >= num_words * 64`.
    pub fn z_on_qubit(num_words: usize, qubit: usize) -> Self {
        let mut pauli = Self::new(num_words);
        pauli.z[qubit >> 6] |= 1u64 << (qubit & 63);
        pauli
    }

    /// True when no qubit carries an X or Y component (only I/Z factors).
    #[inline(always)]
    pub fn is_diagonal(&self) -> bool {
        !self.x.iter().any(|&word| word != 0)
    }

    /// True when the operator has an X or Y component on `qubit`.
    #[inline(always)]
    pub fn has_x_or_y(&self, qubit: usize) -> bool {
        (self.x[qubit / 64] >> (qubit % 64)) & 1 == 1
    }
}

/// Hashes `x` then `z`, consistent with the derived `PartialEq`/`Eq`.
impl Hash for PauliVec {
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        let Self { x, z } = self;
        x.hash(hasher);
        z.hash(hasher);
    }
}
74
/// Reads bit `qubit` from a little-endian packed bit vector (bit `q % 64` of word `q / 64`).
///
/// Panics if `qubit / 64` is out of range for `words`.
#[inline(always)]
pub(crate) fn get_bit(words: &[u64], qubit: usize) -> bool {
    let mask = 1u64 << (qubit & 63);
    words[qubit >> 6] & mask != 0
}
79
/// Sets bit `qubit` of a packed bit vector to `val` (bit `q % 64` of word `q / 64`).
///
/// Panics if `qubit / 64` is out of range for `words`.
#[inline(always)]
pub(crate) fn set_bit(words: &mut [u64], qubit: usize, val: bool) {
    let mask = 1u64 << (qubit % 64);
    let slot = &mut words[qubit / 64];
    *slot = if val { *slot | mask } else { *slot & !mask };
}
90
/// Toggles bit `qubit` of a packed bit vector (bit `q % 64` of word `q / 64`).
///
/// Panics if `qubit / 64` is out of range for `words`.
#[inline(always)]
pub(crate) fn flip_bit(words: &mut [u64], qubit: usize) {
    let (word, bit) = (qubit / 64, qubit % 64);
    words[word] ^= 1u64 << bit;
}
95
/// Reduces `rows` (packed GF(2) row vectors) to row-echelon form in place.
///
/// Columns are scanned left to right (column `c` is bit `c % 64` of word
/// `c / 64`). For each column, the first row at or below the current pivot
/// position that has the bit set becomes the pivot and is swapped into place.
/// Its bit is then cleared from every row *below* it by XOR; rows above are
/// not touched, so the result is echelon form, not reduced echelon form.
///
/// Returns `(rank, pivot_cols)`, where `pivot_cols[i]` is the pivot column of
/// row `i` for `i < rank`. Rows `rank..` end up all-zero within the scanned
/// columns.
///
/// Every row is assumed to hold at least `num_cols.div_ceil(64)` words.
fn gaussian_eliminate(rows: &mut [Vec<u64>], num_cols: usize) -> (usize, Vec<usize>) {
    let num_rows = rows.len();
    let mut pivot_cols: Vec<usize> = Vec::with_capacity(num_rows.min(num_cols));
    let mut current_row = 0;

    for col in 0..num_cols {
        // Every row already holds a pivot; the remaining columns cannot
        // produce another, so stop scanning instead of probing empty slices.
        if current_row == num_rows {
            break;
        }

        let word = col / 64;
        let mask = 1u64 << (col % 64);

        let Some(offset) = rows[current_row..]
            .iter()
            .position(|row| row[word] & mask != 0)
        else {
            continue;
        };
        rows.swap(current_row + offset, current_row);

        // Split so the pivot can be read while the rows below are mutated.
        let (top, rest) = rows.split_at_mut(current_row + 1);
        let pivot = &top[current_row];
        for row in rest.iter_mut().filter(|row| row[word] & mask != 0) {
            for (dst, &src) in row.iter_mut().zip(pivot.iter()) {
                *dst ^= src;
            }
        }

        pivot_cols.push(col);
        current_row += 1;
    }

    (current_row, pivot_cols)
}
136
/// Number of consecutive flip rows folded into one lookup-table group; each
/// group is indexed by one random byte.
const LUT_GROUP_SIZE: usize = 8;
/// Rank threshold associated with LUT construction. It is not read in this
/// part of the file; presumably the minimum rank at which a `FlipLut` is
/// built (TODO confirm at the use site).
const LUT_MIN_RANK: usize = 8;

/// Byte-indexed table of precomputed XOR combinations of flip rows.
///
/// Flip rows are split into groups of `LUT_GROUP_SIZE`, with a trailing
/// partial group of `remainder_size` rows when the rank is not a multiple of
/// the group size. For group `g` and index `byte`, the entry is the XOR of
/// every row `g * LUT_GROUP_SIZE + i` whose bit `i` is set in `byte`.
///
/// Layout of `data`: `total_groups * 256` entries of `m_words` words each,
/// group-major.
struct FlipLut {
    data: Vec<u64>,
    m_words: usize,
    num_full_groups: usize,
    remainder_size: usize,
}

impl FlipLut {
    /// Builds the table from `flip_rows`, each holding at least `m_words` words.
    ///
    /// Each entry is derived from an already-filled entry with the lowest set
    /// bit removed, so building costs one row XOR per entry.
    fn build(flip_rows: &[Vec<u64>], m_words: usize) -> Self {
        let rank = flip_rows.len();
        let group_stride = (1usize << LUT_GROUP_SIZE) * m_words;
        let mut data = vec![0u64; rank.div_ceil(LUT_GROUP_SIZE) * group_stride];

        for (group, group_rows) in flip_rows.chunks(LUT_GROUP_SIZE).enumerate() {
            let base = group * group_stride;
            // Entry 0 (no rows selected) stays all-zero.
            for byte in 1..(1usize << group_rows.len()) {
                let row = &group_rows[byte.trailing_zeros() as usize];
                let dst = base + byte * m_words;
                let src = base + (byte & (byte - 1)) * m_words;
                for w in 0..m_words {
                    data[dst + w] = data[src + w] ^ row[w];
                }
            }
        }

        Self {
            data,
            m_words,
            num_full_groups: rank / LUT_GROUP_SIZE,
            remainder_size: rank % LUT_GROUP_SIZE,
        }
    }

    /// Returns the `m_words`-word XOR combination for `byte` within `group`.
    #[inline(always)]
    fn lookup(&self, group: usize, byte: usize) -> &[u64] {
        let start = ((group << LUT_GROUP_SIZE) + byte) * self.m_words;
        &self.data[start..][..self.m_words]
    }
}
194
/// Pre-compiled sampler for measurement outcomes.
///
/// A shot is the reference outcome `ref_bits_packed` XORed with a random
/// combination of the `rank` flip rows, where each row is included
/// independently with probability 1/2. Several interchangeable backends are
/// kept: a byte lookup table (`lut`), a sparse parity form for bit-transposed
/// sampling (`sparse`, optionally `xor_dag` / `parity_blocks`) and, with the
/// `gpu` feature, a GPU path.
pub struct CompiledSampler {
    /// One packed row per random generator (`0..rank`), `num_measurements.div_ceil(64)`
    /// words each; XORed into a shot when that generator's random bit is 1.
    flip_rows: Vec<Vec<u64>>,
    /// Reference outcome, packed 64 measurements per word; applied on top of
    /// the sampled flips (and emitted unchanged when `rank == 0`).
    ref_bits_packed: Vec<u64>,
    /// Number of independent random bits drawn per shot.
    rank: usize,
    /// Number of measurement bits in every shot.
    num_measurements: usize,
    /// Master RNG; the faster Xoshiro streams used by bulk paths are derived from it.
    rng: ChaCha8Rng,
    /// Optional byte-grouped lookup table over `flip_rows`.
    lut: Option<FlipLut>,
    /// Sparse parity representation of the flip matrix; required by the BTS paths.
    sparse: Option<SparseParity>,
    /// Optional XOR DAG handed to the CPU BTS kernels.
    xor_dag: Option<XorDag>,
    /// Optional decomposition into independently sampled blocks; when present,
    /// BTS samples each block with its own RNG stream.
    parity_blocks: Option<ParityBlocks>,
    /// GPU context attached via `with_gpu`.
    #[cfg(feature = "gpu")]
    gpu_context: Option<std::sync::Arc<crate::gpu::GpuContext>>,
    /// Lazily built GPU BTS state; cleared whenever the context changes.
    #[cfg(feature = "gpu")]
    gpu_bts_cache: Option<GpuBtsCache>,
}
210
/// Sampler that produces detector and observable bits on top of raw measurements.
pub struct CompiledDetectorSampler {
    /// Underlying measurement sampler.
    measurement_sampler: CompiledSampler,
    /// One entry per detector. Presumably each inner `Vec` lists the
    /// measurement indices whose parity defines the detector — TODO confirm.
    detector_rows: Vec<Vec<usize>>,
    /// One entry per logical observable. Presumably measurement index lists
    /// as for `detector_rows` — TODO confirm.
    observable_rows: Vec<Vec<usize>>,
    /// Number of measurements per shot.
    num_measurements: usize,
}
218
/// One batch of detector-sampler output: raw measurements plus the derived
/// detector and observable bits, each as a packed shot table.
#[derive(Debug, Clone)]
pub struct DetectorSampleBatch {
    /// Raw measurement outcomes.
    pub measurements: PackedShots,
    /// Detector bits.
    pub detectors: PackedShots,
    /// Logical observable bits.
    pub observables: PackedShots,
}
226
/// Packs booleans into little-endian `u64` words: element `i` becomes bit
/// `i % 64` of word `i / 64`. Returns `bools.len().div_ceil(64)` words; any
/// unused high bits of the last word are zero.
fn pack_bools(bools: &[bool]) -> Vec<u64> {
    bools
        .chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u64, |word, (bit, &set)| word | (u64::from(set) << bit))
        })
        .collect()
}
237
/// Raw `*mut u64` wrapper that can be captured by rayon closures.
///
/// Raw pointers are neither `Send` nor `Sync`, so a closure capturing one
/// cannot run on the rayon pool; this wrapper opts back in. It performs no
/// checking: every soundness obligation (in bounds, disjoint writes, target
/// outlives the parallel section) falls on the code that uses it.
#[cfg(feature = "parallel")]
#[derive(Clone, Copy)]
struct SendPtrU64(*mut u64);
// SAFETY: the wrapper only carries an address. Moving it to another thread is
// sound as long as every write through it targets a region no other thread
// accesses concurrently, which users of `copy_from_slice` must guarantee.
#[cfg(feature = "parallel")]
unsafe impl Send for SendPtrU64 {}
// SAFETY: shared references only allow copying the address out; the same
// disjoint-write obligation as for `Send` applies to every write made with it.
#[cfg(feature = "parallel")]
unsafe impl Sync for SendPtrU64 {}
#[cfg(feature = "parallel")]
impl SendPtrU64 {
    /// Copies `src` to the `src.len()` words starting `dst_offset` words past
    /// the wrapped pointer.
    ///
    /// # Safety
    ///
    /// `ptr + dst_offset .. ptr + dst_offset + src.len()` must lie inside one
    /// live allocation that is valid for writes. It must not overlap `src`,
    /// and no other thread may read or write it during the call.
    #[inline(always)]
    unsafe fn copy_from_slice(self, dst_offset: usize, src: &[u64]) {
        std::ptr::copy_nonoverlapping(src.as_ptr(), self.0.add(dst_offset), src.len());
    }
}
256
/// Copies measurement-major rows from `src` into selected rows of `dst`.
///
/// `src` holds `meas_indices.len()` rows of `src_s_words` words. Source row
/// `k` is written to the first `src_s_words` words of destination row
/// `meas_indices[k]`, whose rows are `dst_s_words` wide.
#[inline(always)]
fn scatter_meas_major_rows(
    dst: &mut [u64],
    dst_s_words: usize,
    src: &[u64],
    src_s_words: usize,
    meas_indices: &[usize],
) {
    debug_assert_eq!(src.len(), meas_indices.len() * src_s_words);
    meas_indices
        .iter()
        .enumerate()
        .for_each(|(row, &target)| {
            let from = row * src_s_words;
            let to = target * dst_s_words;
            dst[to..to + src_s_words].copy_from_slice(&src[from..from + src_s_words]);
        });
}
272
/// Re-indexes a packed bit row from local to global measurement indices.
///
/// Every set bit `local_m` of `local_row` becomes set bit `mapping[local_m]`
/// in a fresh `m_words`-word row.
#[inline(always)]
fn project_local_flip_row(local_row: &[u64], mapping: &[usize], m_words: usize) -> Vec<u64> {
    let mut projected = vec![0u64; m_words];
    for (word_idx, &word) in local_row.iter().enumerate() {
        let mut remaining = word;
        while remaining != 0 {
            let local_m = word_idx * 64 + remaining.trailing_zeros() as usize;
            // Clear the lowest set bit before handling it.
            remaining &= remaining - 1;
            debug_assert!(local_m < mapping.len());
            let global_m = mapping[local_m];
            projected[global_m / 64] |= 1u64 << (global_m % 64);
        }
    }
    projected
}
289
290fn build_sparse_from_filtered_blocks(
291 block_samplers: &[CompiledSampler],
292 meas_map: &[(usize, usize)],
293 rank_offsets: &[usize],
294 num_global_measurements: usize,
295) -> SparseParity {
296 let mut row_offsets = Vec::with_capacity(num_global_measurements + 1);
297 let mut col_indices = Vec::new();
298
299 for &(block_idx, local_meas) in meas_map {
300 row_offsets.push(col_indices.len() as u32);
301 let sparse = block_samplers[block_idx]
302 .sparse
303 .as_ref()
304 .expect("filtered block samplers must retain sparse parity");
305 let rank_offset = rank_offsets[block_idx] as u32;
306 for &col in sparse.row_cols(local_meas) {
307 col_indices.push(rank_offset + col);
308 }
309 }
310 row_offsets.push(col_indices.len() as u32);
311
312 let non_det_rows: Vec<u32> = (0..num_global_measurements as u32)
313 .filter(|&m| row_offsets[m as usize + 1] != row_offsets[m as usize])
314 .collect();
315
316 SparseParity {
317 col_indices,
318 row_offsets,
319 num_rows: num_global_measurements,
320 non_det_rows,
321 }
322}
323
324fn build_filtered_parity_blocks(
325 block_samplers: &[CompiledSampler],
326 local_to_global: &[Vec<usize>],
327) -> Option<ParityBlocks> {
328 let mut blocks = Vec::new();
329
330 for (block_idx, sampler) in block_samplers.iter().enumerate() {
331 let mapping = &local_to_global[block_idx];
332 if mapping.is_empty() {
333 continue;
334 }
335
336 if let Some(parity_blocks) = &sampler.parity_blocks {
337 for block in &parity_blocks.blocks {
338 let meas_indices = block
339 .meas_indices
340 .iter()
341 .map(|&local_m| mapping[local_m])
342 .collect();
343 blocks.push(ParityBlock {
344 meas_indices,
345 sparse: block.sparse.clone(),
346 block_rank: block.block_rank,
347 ref_bits_packed: block.ref_bits_packed.clone(),
348 });
349 }
350 continue;
351 }
352
353 let sparse = sampler
354 .sparse
355 .as_ref()
356 .expect("filtered block samplers must retain sparse parity")
357 .clone();
358 blocks.push(ParityBlock {
359 meas_indices: mapping.clone(),
360 sparse,
361 block_rank: sampler.rank,
362 ref_bits_packed: sampler.ref_bits_packed.clone(),
363 });
364 }
365
366 ParityBlocks::from_blocks_if_useful(blocks)
367}
368
369impl CompiledSampler {
370 pub fn rank(&self) -> usize {
371 self.rank
372 }
373
374 #[cfg(feature = "gpu")]
379 pub fn with_gpu(mut self, context: std::sync::Arc<crate::gpu::GpuContext>) -> Self {
380 self.gpu_context = Some(context);
381 self.gpu_bts_cache = None;
382 self
383 }
384
385 #[cfg(feature = "gpu")]
386 fn can_use_gpu_bts(&self) -> bool {
387 self.gpu_context.is_some()
388 && self.sparse.is_some()
389 && self.rank > 0
390 && self.parity_blocks.is_none()
391 && self.xor_dag.is_none()
392 }
393
394 #[cfg(feature = "gpu")]
395 pub(crate) fn should_use_gpu_bts(&self, num_shots: usize) -> bool {
396 let Some(sparse) = &self.sparse else {
397 return false;
398 };
399 if !self.can_use_gpu_bts()
400 || num_shots < crate::gpu::bts_min_shots()
401 || self.rank < crate::gpu::bts_min_rank()
402 {
403 return false;
404 }
405
406 let stats = sparse.stats();
407 let min_total_weight = sparse
408 .num_rows
409 .saturating_mul(crate::gpu::bts_min_weight_factor());
410 stats.total_weight >= min_total_weight
411 }
412
413 #[cfg(feature = "gpu")]
414 pub(crate) fn has_gpu_context(&self) -> bool {
415 self.gpu_context.is_some()
416 }
417
418 #[cfg(feature = "gpu")]
419 pub(crate) fn gpu_context(&self) -> Option<std::sync::Arc<crate::gpu::GpuContext>> {
420 self.gpu_context.clone()
421 }
422
423 #[cfg(feature = "gpu")]
424 fn should_try_gpu_counts(&self, total_shots: usize) -> bool {
425 if !self.should_use_gpu_bts(total_shots) || self.rank <= MAX_RANK_FOR_RANK_SPACE {
426 return false;
427 }
428
429 let m_words = self.num_measurements.div_ceil(64);
430 if m_words == 0 || m_words > crate::gpu::kernels::bts::GPU_COUNTS_MAX_WORDS {
431 return false;
432 }
433
434 let entropy_guard = total_shots.ilog2() as usize + 8;
435 self.rank <= entropy_guard
436 }
437
438 #[cfg(feature = "gpu")]
439 fn ensure_gpu_bts_cache(&mut self) -> Result<&mut GpuBtsCache> {
440 if self.gpu_bts_cache.is_none() {
441 let ctx = self
442 .gpu_context
443 .as_ref()
444 .expect("gpu BTS cache requested without gpu_context")
445 .clone();
446 let sparse = self
447 .sparse
448 .as_ref()
449 .expect("gpu BTS cache requested without sparse parity");
450 self.gpu_bts_cache = Some(GpuBtsCache::new(&ctx, sparse, &self.ref_bits_packed)?);
451 }
452 Ok(self
453 .gpu_bts_cache
454 .as_mut()
455 .expect("gpu BTS cache initialized above"))
456 }
457
458 pub fn num_measurements(&self) -> usize {
459 self.num_measurements
460 }
461
462 pub fn sparse(&self) -> Option<&SparseParity> {
463 self.sparse.as_ref()
464 }
465
466 pub fn parity_stats(&self) -> Option<ParityStats> {
467 self.sparse.as_ref().map(|s| s.stats())
468 }
469
470 pub fn sample(&mut self) -> Vec<bool> {
471 let num_meas_words = self.num_measurements.div_ceil(64);
472 let mut accum = vec![0u64; num_meas_words];
473 self.sample_into(&mut accum);
474 self.unpack_result(&accum)
475 }
476
477 pub(crate) fn sample_bulk_words_shot_major(&mut self, num_shots: usize) -> (Vec<u64>, usize) {
478 let m_words = self.num_measurements.div_ceil(64);
479 let mut accum = vec![0u64; num_shots * m_words];
480 let mut rand_buf = Vec::new();
481 self.sample_bulk_words_shot_major_reuse(&mut accum, &mut rand_buf, num_shots);
482 (accum, m_words)
483 }
484
    /// Samples `num_shots` flip patterns into caller-provided buffers.
    ///
    /// `accum` is resized to `num_shots * m_words`, zeroed, and filled
    /// shot-major with the XOR of the selected flip rows; reference bits are NOT
    /// applied. `rand_buf` is scratch space for random bytes on the LUT path.
    /// Both buffers can be reused across calls to avoid reallocation.
    ///
    /// Returns `m_words = num_measurements.div_ceil(64)`.
    ///
    /// The order in which random bytes are drawn from `self.rng` is part of the
    /// seeded output stream, so the fill loops below must keep their order.
    pub(crate) fn sample_bulk_words_shot_major_reuse(
        &mut self,
        accum: &mut Vec<u64>,
        rand_buf: &mut Vec<u8>,
        num_shots: usize,
    ) -> usize {
        let m_words = self.num_measurements.div_ceil(64);
        let needed = num_shots * m_words;
        accum.resize(needed, 0);
        // `resize` keeps any old contents, so clear explicitly.
        accum[..needed].fill(0);
        if num_shots == 0 || self.num_measurements == 0 || self.rank == 0 {
            return m_words;
        }

        if let Some(lut) = &self.lut {
            // One random byte per LUT group per shot, laid out shot-major.
            let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
            let bytes_per_shot = total_groups;
            let total_bytes = num_shots * bytes_per_shot;
            rand_buf.resize(total_bytes, 0);
            {
                // Fill with whole little-endian u64 draws, then a partial draw for the tail.
                let full_chunks = total_bytes / 8;
                let tail = full_chunks * 8;
                for i in 0..full_chunks {
                    let r = self.rng.next_u64();
                    rand_buf[i * 8..(i + 1) * 8].copy_from_slice(&r.to_le_bytes());
                }
                if tail < total_bytes {
                    let r = self.rng.next_u64();
                    let bytes = r.to_le_bytes();
                    rand_buf[tail..total_bytes].copy_from_slice(&bytes[..total_bytes - tail]);
                }
                // The partial last group only has `remainder_size` valid LUT index bits.
                if lut.remainder_size > 0 {
                    let remainder_mask = (1u8 << lut.remainder_size) - 1;
                    let last_group = lut.num_full_groups;
                    for s in 0..num_shots {
                        rand_buf[s * bytes_per_shot + last_group] &= remainder_mask;
                    }
                }
            }

            // Tile size in shots: about 256 KiB of output per tile, at least 64 shots.
            // (`m_words > 0` always holds here because `num_measurements > 0`.)
            let max_batch = if m_words > 0 {
                (256 * 1024 / (m_words * 8)).max(64)
            } else {
                num_shots
            };

            #[cfg(feature = "parallel")]
            const PAR_SHOT_THRESHOLD: usize = 256;

            #[cfg(feature = "parallel")]
            if num_shots >= PAR_SHOT_THRESHOLD {
                use rayon::prelude::*;
                // Each rayon task owns a contiguous run of shots (disjoint output slices).
                let shots_per_chunk =
                    (num_shots.div_ceil(rayon::current_num_threads())).max(max_batch);
                let chunk_m = shots_per_chunk * m_words;
                accum
                    .par_chunks_mut(chunk_m)
                    .enumerate()
                    .for_each(|(ci, chunk)| {
                        let chunk_shots = chunk.len() / m_words;
                        let chunk_start = ci * shots_per_chunk;
                        // Within a tile, loop groups outermost so one LUT group's
                        // 256 entries are reused across all shots of the tile.
                        for tile_start in (0..chunk_shots).step_by(max_batch) {
                            let tile_end = (tile_start + max_batch).min(chunk_shots);
                            for g in 0..total_groups {
                                for s in tile_start..tile_end {
                                    let gs = chunk_start + s;
                                    let byte = rand_buf[gs * bytes_per_shot + g] as usize;
                                    let entry = lut.lookup(g, byte);
                                    let base = s * m_words;
                                    xor_words(&mut chunk[base..base + m_words], entry);
                                }
                            }
                        }
                    });
            } else {
                // Small batches: same tiled loop, single-threaded.
                for tile_start in (0..num_shots).step_by(max_batch) {
                    let tile_end = (tile_start + max_batch).min(num_shots);
                    for g in 0..total_groups {
                        for s in tile_start..tile_end {
                            let byte = rand_buf[s * bytes_per_shot + g] as usize;
                            let entry = lut.lookup(g, byte);
                            let shot_base = s * m_words;
                            xor_words(&mut accum[shot_base..shot_base + m_words], entry);
                        }
                    }
                }
            }

            #[cfg(not(feature = "parallel"))]
            {
                for tile_start in (0..num_shots).step_by(max_batch) {
                    let tile_end = (tile_start + max_batch).min(num_shots);
                    for g in 0..total_groups {
                        for s in tile_start..tile_end {
                            let byte = rand_buf[s * bytes_per_shot + g] as usize;
                            let entry = lut.lookup(g, byte);
                            let shot_base = s * m_words;
                            xor_words(&mut accum[shot_base..shot_base + m_words], entry);
                        }
                    }
                }
            }

            m_words
        } else {
            // No LUT: one RNG call per generator per shot, using only its low bit.
            for s in 0..num_shots {
                let shot_base = s * m_words;
                let shot_accum = &mut accum[shot_base..shot_base + m_words];
                for j in 0..self.rank {
                    let bit = self.rng.next_u32() & 1;
                    if bit != 0 {
                        let row = &self.flip_rows[j];
                        xor_words(shot_accum, row);
                    }
                }
            }
            m_words
        }
    }
604
605 fn should_use_bts(&self, num_shots: usize) -> bool {
606 if let Some(sparse) = &self.sparse {
607 if self.rank == 0 {
608 return false;
609 }
610 let m_words = self.num_measurements.div_ceil(64) as u64;
611 let lut_groups = (self.rank.div_ceil(LUT_GROUP_SIZE)) as u64;
612
613 let lut_alloc_bytes = num_shots as u64 * (lut_groups + m_words * 8);
614 if lut_alloc_bytes > MAX_LUT_ALLOC_BYTES {
615 return true;
616 }
617
618 let s_words = num_shots.div_ceil(64);
619 let stats = sparse.stats();
620 let bts_work = stats.total_weight as u64 * s_words as u64;
621 let lut_work = num_shots as u64 * lut_groups * m_words;
622 bts_work < lut_work
623 } else {
624 false
625 }
626 }
627
628 pub(crate) fn ref_bits_packed(&self) -> &[u64] {
629 &self.ref_bits_packed
630 }
631
632 pub fn sample_bulk(&mut self, num_shots: usize) -> Vec<Vec<bool>> {
633 self.sample_bulk_packed(num_shots).to_shots()
634 }
635
636 #[cfg(feature = "gpu")]
644 pub fn sample_bulk_packed_device(&mut self, num_shots: usize) -> Result<DevicePackedShots> {
645 if !self.can_use_gpu_bts() {
646 return Err(PrismError::BackendUnsupported {
647 backend: "CompiledSampler".to_string(),
648 operation: "sample_bulk_packed_device requires with_gpu() and flat sparse BTS"
649 .to_string(),
650 });
651 }
652
653 let ctx = self
654 .gpu_context
655 .as_ref()
656 .expect("gpu BTS selected without gpu_context")
657 .clone();
658 let rank = self.rank;
659 let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
660 let cache = self.ensure_gpu_bts_cache()?;
661 let data = crate::gpu::kernels::bts::launch_bts_sample_device(
662 &ctx,
663 &mut fast_rng,
664 rank,
665 num_shots,
666 cache,
667 )?;
668
669 Ok(DevicePackedShots {
670 context: ctx,
671 data,
672 num_shots,
673 num_measurements: self.num_measurements,
674 m_words: self.num_measurements.div_ceil(64),
675 s_words: num_shots.div_ceil(64),
676 layout: ShotLayout::MeasMajor,
677 rank,
678 })
679 }
680
681 #[cfg(feature = "gpu")]
690 pub(crate) fn try_sample_bulk_shot_major_gpu(
691 &mut self,
692 num_shots: usize,
693 ) -> Option<Result<Vec<u64>>> {
694 if !self.should_use_gpu_bts(num_shots) {
695 return None;
696 }
697 let ctx = self
698 .gpu_context
699 .as_ref()
700 .expect("gpu BTS selected without gpu_context")
701 .clone();
702 let rank = self.rank;
703 let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
704 let cache = match self.ensure_gpu_bts_cache() {
705 Ok(c) => c,
706 Err(e) => return Some(Err(e)),
707 };
708 Some(crate::gpu::kernels::bts::launch_bts_sample_shot_major_host(
709 &ctx,
710 &mut fast_rng,
711 rank,
712 num_shots,
713 cache,
714 ))
715 }
716
717 #[inline(always)]
718 fn sample_into(&mut self, accum: &mut [u64]) {
719 if self.rank == 0 {
720 return;
721 }
722
723 if let Some(lut) = &self.lut {
724 let mut rand_buf = 0u64;
725 let mut rand_pos = 8usize;
726
727 for g in 0..lut.num_full_groups {
728 if rand_pos >= 8 {
729 rand_buf = self.rng.next_u64();
730 rand_pos = 0;
731 }
732 let byte = ((rand_buf >> (rand_pos * 8)) & 0xFF) as usize;
733 rand_pos += 1;
734 let entry = lut.lookup(g, byte);
735 xor_words(accum, entry);
736 }
737 if lut.remainder_size > 0 {
738 if rand_pos >= 8 {
739 rand_buf = self.rng.next_u64();
740 }
741 let mask = (1u64 << lut.remainder_size) - 1;
742 let byte = (rand_buf & mask) as usize;
743 let entry = lut.lookup(lut.num_full_groups, byte);
744 xor_words(accum, entry);
745 }
746 } else {
747 for j in 0..self.rank {
748 let bit = self.rng.next_u32() & 1;
749 if bit != 0 {
750 let row = &self.flip_rows[j];
751 xor_words(accum, row);
752 }
753 }
754 }
755 }
756
757 #[inline(always)]
758 fn unpack_result(&self, accum: &[u64]) -> Vec<bool> {
759 let mut result = Vec::with_capacity(self.num_measurements);
760 for m in 0..self.num_measurements {
761 let w = m / 64;
762 let ref_word = if w < self.ref_bits_packed.len() {
763 self.ref_bits_packed[w]
764 } else {
765 0
766 };
767 let bit = ((accum[w] ^ ref_word) >> (m % 64)) & 1 != 0;
768 result.push(bit);
769 }
770 result
771 }
772
773 pub fn sample_bulk_packed(&mut self, num_shots: usize) -> PackedShots {
774 match self.sample_bulk_packed_inner(num_shots, false) {
775 Ok(packed) => packed,
776 Err(_) => unreachable!("compiled sampler CPU fallback should not fail"),
777 }
778 }
779
780 pub fn try_sample_bulk_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
781 self.sample_bulk_packed_inner(num_shots, true)
782 }
783
784 fn sample_bulk_packed_inner(
785 &mut self,
786 num_shots: usize,
787 propagate_gpu_errors: bool,
788 ) -> Result<PackedShots> {
789 let m_words = self.num_measurements.div_ceil(64);
790 let s_words = num_shots.div_ceil(64);
791 if num_shots == 0 || self.num_measurements == 0 {
792 return Ok(PackedShots {
793 data: Vec::new(),
794 num_shots,
795 num_measurements: self.num_measurements,
796 m_words,
797 s_words,
798 layout: ShotLayout::ShotMajor,
799 });
800 }
801 if self.rank == 0 {
802 let mut data = vec![0u64; num_shots * m_words];
803 for s in 0..num_shots {
804 let base = s * m_words;
805 data[base..base + m_words].copy_from_slice(&self.ref_bits_packed);
806 }
807 return Ok(PackedShots {
808 data,
809 num_shots,
810 num_measurements: self.num_measurements,
811 m_words,
812 s_words,
813 layout: ShotLayout::ShotMajor,
814 });
815 }
816
817 if self.should_use_bts(num_shots) {
818 return self.sample_bulk_packed_bts(num_shots, m_words, s_words, propagate_gpu_errors);
819 }
820
821 let (mut data, _) = self.sample_bulk_words_shot_major(num_shots);
822 for s in 0..num_shots {
823 let base = s * m_words;
824 xor_words(&mut data[base..base + m_words], &self.ref_bits_packed);
825 }
826 Ok(PackedShots {
827 data,
828 num_shots,
829 num_measurements: self.num_measurements,
830 m_words,
831 s_words,
832 layout: ShotLayout::ShotMajor,
833 })
834 }
835
836 fn sample_bulk_packed_bts(
837 &mut self,
838 num_shots: usize,
839 m_words: usize,
840 s_words: usize,
841 propagate_gpu_errors: bool,
842 ) -> Result<PackedShots> {
843 #[cfg(not(feature = "gpu"))]
844 let _ = propagate_gpu_errors;
845
846 let num_meas = self.num_measurements;
847
848 if let Some(pb) = &self.parity_blocks {
849 let block_seeds: Vec<u64> = (0..pb.blocks.len()).map(|_| self.rng.next_u64()).collect();
850 let meas_major = if pb.direct_scatter {
851 let total_len = num_meas * s_words;
852 #[allow(clippy::uninit_vec)]
853 let mut meas_major = {
854 let mut v = Vec::with_capacity(total_len);
855 unsafe { v.set_len(total_len) };
859 v
860 };
861
862 #[cfg(feature = "parallel")]
863 {
864 use rayon::prelude::*;
865
866 let ptr = SendPtrU64(meas_major.as_mut_ptr());
867 pb.blocks
868 .par_iter()
869 .zip(block_seeds.par_iter())
870 .for_each(|(block, &seed)| {
871 let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
872 let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
873 let block_data = sample_bts_meas_major(
874 &block.sparse,
875 num_shots,
876 &block.ref_bits_packed,
877 &mut block_rng,
878 block.block_rank,
879 );
880
881 unsafe {
885 for (local_m, &global_m) in block.meas_indices.iter().enumerate() {
886 let row =
887 &block_data[local_m * s_words..(local_m + 1) * s_words];
888 ptr.copy_from_slice(global_m * s_words, row);
889 }
890 }
891 });
892 }
893
894 #[cfg(not(feature = "parallel"))]
895 {
896 for (block, &seed) in pb.blocks.iter().zip(block_seeds.iter()) {
897 let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
898 let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
899 let block_data = sample_bts_meas_major(
900 &block.sparse,
901 num_shots,
902 &block.ref_bits_packed,
903 &mut block_rng,
904 block.block_rank,
905 );
906 scatter_meas_major_rows(
907 &mut meas_major,
908 s_words,
909 &block_data,
910 s_words,
911 &block.meas_indices,
912 );
913 }
914 }
915
916 meas_major
917 } else {
918 #[cfg(feature = "parallel")]
919 let block_results: Vec<(Vec<u64>, &[usize])> = {
920 use rayon::prelude::*;
921 pb.blocks
922 .par_iter()
923 .zip(block_seeds.par_iter())
924 .map(|(block, &seed)| {
925 let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
926 let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
927 let data = sample_bts_meas_major(
928 &block.sparse,
929 num_shots,
930 &block.ref_bits_packed,
931 &mut block_rng,
932 block.block_rank,
933 );
934 (data, block.meas_indices.as_slice())
935 })
936 .collect()
937 };
938
939 #[cfg(not(feature = "parallel"))]
940 let block_results: Vec<(Vec<u64>, &[usize])> = pb
941 .blocks
942 .iter()
943 .zip(block_seeds.iter())
944 .map(|(block, &seed)| {
945 let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
946 let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
947 let data = sample_bts_meas_major(
948 &block.sparse,
949 num_shots,
950 &block.ref_bits_packed,
951 &mut block_rng,
952 block.block_rank,
953 );
954 (data, block.meas_indices.as_slice())
955 })
956 .collect();
957
958 let mut meas_major = vec![0u64; num_meas * s_words];
959 for (block_data, meas_indices) in &block_results {
960 scatter_meas_major_rows(
961 &mut meas_major,
962 s_words,
963 block_data,
964 s_words,
965 meas_indices,
966 );
967 }
968 meas_major
969 };
970
971 return Ok(PackedShots {
972 data: meas_major,
973 num_shots,
974 num_measurements: num_meas,
975 m_words,
976 s_words,
977 layout: ShotLayout::MeasMajor,
978 });
979 }
980
981 #[cfg(feature = "gpu")]
982 if self.should_use_gpu_bts(num_shots) {
983 match self.sample_bulk_packed_bts_gpu(num_shots, m_words, s_words) {
984 Ok(packed) => return Ok(packed),
985 Err(e) if propagate_gpu_errors => return Err(e),
986 Err(_) => {}
987 }
988 }
989
990 let sparse = self
991 .sparse
992 .as_ref()
993 .expect("sparse parity required for BTS (should_use_bts guards this)");
994
995 let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
996
997 if num_shots <= BTS_BATCH_SHOTS {
998 let data = bts_single_pass(
999 sparse,
1000 self.xor_dag.as_ref(),
1001 num_shots,
1002 &self.ref_bits_packed,
1003 &mut fast_rng,
1004 self.rank,
1005 );
1006 return Ok(PackedShots {
1007 data,
1008 num_shots,
1009 num_measurements: num_meas,
1010 m_words,
1011 s_words,
1012 layout: ShotLayout::MeasMajor,
1013 });
1014 }
1015
1016 let data = bts_batched(
1017 sparse,
1018 self.xor_dag.as_ref(),
1019 num_shots,
1020 s_words,
1021 &self.ref_bits_packed,
1022 &mut fast_rng,
1023 self.rank,
1024 );
1025 Ok(PackedShots {
1026 data,
1027 num_shots,
1028 num_measurements: num_meas,
1029 m_words,
1030 s_words,
1031 layout: ShotLayout::MeasMajor,
1032 })
1033 }
1034
1035 #[cfg(feature = "gpu")]
1036 fn sample_bulk_packed_bts_gpu(
1037 &mut self,
1038 num_shots: usize,
1039 m_words: usize,
1040 s_words: usize,
1041 ) -> Result<PackedShots> {
1042 let ctx = self
1043 .gpu_context
1044 .as_ref()
1045 .expect("gpu BTS selected without gpu_context")
1046 .clone();
1047 let rank = self.rank;
1048 let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
1049 let cache = self.ensure_gpu_bts_cache()?;
1050 let data = crate::gpu::kernels::bts::launch_bts_sample(
1051 &ctx,
1052 &mut fast_rng,
1053 rank,
1054 num_shots,
1055 cache,
1056 )?;
1057 Ok(PackedShots {
1058 data,
1059 num_shots,
1060 num_measurements: self.num_measurements,
1061 m_words,
1062 s_words,
1063 layout: ShotLayout::MeasMajor,
1064 })
1065 }
1066
1067 pub fn sample_chunked<A: ShotAccumulator>(&mut self, total_shots: usize, acc: &mut A) {
1068 let chunk_size = default_chunk_size(self.num_measurements);
1069 self.sample_chunked_with_size(total_shots, chunk_size, acc);
1070 }
1071
1072 pub fn sample_chunked_with_size<A: ShotAccumulator>(
1073 &mut self,
1074 total_shots: usize,
1075 chunk_size: usize,
1076 acc: &mut A,
1077 ) {
1078 let mut remaining = total_shots;
1079 while remaining > 0 {
1080 let this_batch = remaining.min(chunk_size);
1081 let packed = self.sample_bulk_packed(this_batch);
1082 acc.accumulate(&packed);
1083 remaining -= this_batch;
1084 }
1085 }
1086
1087 pub fn sample_counts(
1088 &mut self,
1089 total_shots: usize,
1090 ) -> std::collections::HashMap<Vec<u64>, u64> {
1091 match self.try_sample_counts(total_shots) {
1092 Ok(counts) => counts,
1093 Err(_) => self.sample_counts_cpu(total_shots),
1094 }
1095 }
1096
1097 pub fn try_sample_counts(
1098 &mut self,
1099 total_shots: usize,
1100 ) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
1101 #[cfg(feature = "gpu")]
1102 if self.should_try_gpu_counts(total_shots) {
1103 return self.sample_bulk_packed_device(total_shots)?.counts();
1104 }
1105
1106 Ok(self.sample_counts_cpu(total_shots))
1107 }
1108
1109 fn sample_counts_cpu(
1110 &mut self,
1111 total_shots: usize,
1112 ) -> std::collections::HashMap<Vec<u64>, u64> {
1113 if self.rank > 0 && self.parity_blocks.is_none() {
1114 let num_outcomes = 1usize << self.rank;
1115
1116 if self.rank <= MAX_RANK_FOR_MULTINOMIAL
1117 && total_shots >= num_outcomes * MIN_SHOTS_PER_OUTCOME_MULTINOMIAL
1118 {
1119 return self.sample_counts_multinomial(total_shots);
1120 }
1121
1122 if self.rank <= MAX_RANK_FOR_RANK_SPACE
1123 && total_shots >= num_outcomes * MIN_SHOTS_PER_OUTCOME
1124 {
1125 return self.sample_counts_rank_space(total_shots);
1126 }
1127 }
1128 let mut acc = HistogramAccumulator::new();
1129 self.sample_chunked(total_shots, &mut acc);
1130 acc.into_counts()
1131 }
1132
    /// Histogram sampler for small ranks that never materialises individual shots.
    ///
    /// The `2^rank` generator patterns are equally likely, so pattern `key`
    /// receives `Binomial(remaining, 1 / outcomes_left)` of the shots not yet
    /// assigned, and the last pattern takes whatever remains. Each pattern
    /// with a non-zero count is then mapped to its packed outcome. Binomial
    /// draws happen in increasing `key` order, which fixes how the RNG stream
    /// is consumed.
    ///
    /// Callers must keep `rank` small: `1usize << rank` is computed unchecked
    /// and the loop visits every pattern.
    fn sample_counts_multinomial(
        &mut self,
        total_shots: usize,
    ) -> std::collections::HashMap<Vec<u64>, u64> {
        use std::collections::HashMap;

        let m_words = self.num_measurements.div_ceil(64);

        if total_shots == 0 || self.num_measurements == 0 {
            return HashMap::new();
        }
        if self.rank == 0 {
            // Every shot equals the reference outcome.
            let mut counts = HashMap::new();
            counts.insert(self.ref_bits_packed[..m_words].to_vec(), total_shots as u64);
            return counts;
        }

        let rank = self.rank;
        let num_outcomes = 1usize << rank;
        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
        let mut counts = HashMap::new();
        let mut remaining = total_shots;

        for key in 0..num_outcomes {
            if remaining == 0 {
                break;
            }
            let outcomes_left = num_outcomes - key;
            let count = if outcomes_left == 1 {
                remaining
            } else {
                binomial_sample(&mut fast_rng, remaining, 1.0 / outcomes_left as f64)
            };

            if count > 0 {
                // Bit j of `key` selects flip row j, applied on top of the reference.
                let mut outcome = self.ref_bits_packed[..m_words].to_vec();
                if let Some(lut) = &self.lut {
                    // Byte g of `key` indexes LUT group g (flip rows 8g..8g+7).
                    let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
                    for g in 0..total_groups {
                        let byte = (key >> (g * 8)) & 0xFF;
                        let entry = lut.lookup(g, byte);
                        xor_words(&mut outcome, entry);
                    }
                } else {
                    for j in 0..rank {
                        if (key >> j) & 1 != 0 {
                            xor_words(&mut outcome, &self.flip_rows[j]);
                        }
                    }
                }
                counts.insert(outcome, count as u64);
            }
            remaining -= count;
        }

        counts
    }
1190
1191 fn sample_counts_rank_space(
1192 &mut self,
1193 total_shots: usize,
1194 ) -> std::collections::HashMap<Vec<u64>, u64> {
1195 use std::collections::HashMap;
1196
1197 let m_words = self.num_measurements.div_ceil(64);
1198
1199 if total_shots == 0 || self.num_measurements == 0 {
1200 return HashMap::new();
1201 }
1202 if self.rank == 0 {
1203 let mut counts = HashMap::new();
1204 counts.insert(self.ref_bits_packed[..m_words].to_vec(), total_shots as u64);
1205 return counts;
1206 }
1207
1208 let rank = self.rank;
1209 let num_outcomes = 1usize << rank;
1210 let mut rank_counts = vec![0u64; num_outcomes];
1211 let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
1212
1213 if let Some(lut) = &self.lut {
1214 let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
1215 let bytes_per_shot = total_groups;
1216 let remainder_mask: u8 = if lut.remainder_size > 0 {
1217 (1u8 << lut.remainder_size) - 1
1218 } else {
1219 0xFF
1220 };
1221
1222 let chunk_size = (32 * 1024 * 1024 / bytes_per_shot).max(64);
1223 let mut rand_buf = vec![0u8; chunk_size * bytes_per_shot];
1224 let mut remaining = total_shots;
1225
1226 while remaining > 0 {
1227 let this_chunk = remaining.min(chunk_size);
1228 let total_bytes = this_chunk * bytes_per_shot;
1229
1230 let full_chunks = total_bytes / 8;
1231 let tail = full_chunks * 8;
1232 for i in 0..full_chunks {
1233 let r = fast_rng.next_u64();
1234 rand_buf[i * 8..(i + 1) * 8].copy_from_slice(&r.to_le_bytes());
1235 }
1236 if tail < total_bytes {
1237 let r = fast_rng.next_u64();
1238 let bytes = r.to_le_bytes();
1239 rand_buf[tail..total_bytes].copy_from_slice(&bytes[..total_bytes - tail]);
1240 }
1241
1242 if lut.remainder_size > 0 {
1243 let last_group = lut.num_full_groups;
1244 for s in 0..this_chunk {
1245 rand_buf[s * bytes_per_shot + last_group] &= remainder_mask;
1246 }
1247 }
1248
1249 for s in 0..this_chunk {
1250 let base = s * bytes_per_shot;
1251 let mut key: usize = 0;
1252 for g in 0..bytes_per_shot {
1253 key |= (rand_buf[base + g] as usize) << (g * 8);
1254 }
1255 rank_counts[key] += 1;
1256 }
1257
1258 remaining -= this_chunk;
1259 }
1260 } else {
1261 let mut remaining = total_shots;
1262 while remaining > 0 {
1263 let this_chunk = remaining.min(4 * 1024 * 1024);
1264 for _ in 0..this_chunk {
1265 let bits = fast_rng.next_u64();
1266 let key = (bits as usize) & (num_outcomes - 1);
1267 rank_counts[key] += 1;
1268 }
1269 remaining -= this_chunk;
1270 }
1271 }
1272
1273 let mut counts = HashMap::new();
1274 for (key, &count) in rank_counts.iter().enumerate() {
1275 if count == 0 {
1276 continue;
1277 }
1278 let mut outcome = self.ref_bits_packed[..m_words].to_vec();
1279 if let Some(lut) = &self.lut {
1280 let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
1281 for g in 0..total_groups {
1282 let byte = (key >> (g * 8)) & 0xFF;
1283 let entry = lut.lookup(g, byte);
1284 xor_words(&mut outcome, entry);
1285 }
1286 } else {
1287 for j in 0..rank {
1288 if (key >> j) & 1 != 0 {
1289 xor_words(&mut outcome, &self.flip_rows[j]);
1290 }
1291 }
1292 }
1293 counts.insert(outcome, count);
1294 }
1295
1296 counts
1297 }
1298
1299 pub fn sample_marginals(&mut self, total_shots: usize) -> Vec<f64> {
1300 match self.try_sample_marginals(total_shots) {
1301 Ok(marginals) => marginals,
1302 Err(_) => self.sample_marginals_cpu(total_shots),
1303 }
1304 }
1305
1306 pub fn try_sample_marginals(&mut self, total_shots: usize) -> Result<Vec<f64>> {
1307 #[cfg(feature = "gpu")]
1308 if self.should_use_gpu_bts(total_shots) {
1309 return self.sample_bulk_packed_device(total_shots)?.marginals();
1310 }
1311
1312 Ok(self.sample_marginals_cpu(total_shots))
1313 }
1314
1315 fn sample_marginals_cpu(&mut self, total_shots: usize) -> Vec<f64> {
1316 let mut acc = MarginalsAccumulator::new(self.num_measurements);
1317 self.sample_chunked(total_shots, &mut acc);
1318 acc.marginals()
1319 }
1320
1321 pub fn sample_detection_events(
1322 &mut self,
1323 pairs: &[(usize, usize)],
1324 num_shots: usize,
1325 ) -> PackedShots {
1326 let sparse = self.sparse.as_ref().expect("sparse parity required");
1327 let det_sparse = sparse.compile_detection_events(pairs);
1328 let num_events = det_sparse.num_rows;
1329 let m_words = num_events.div_ceil(64);
1330 let s_words = num_shots.div_ceil(64);
1331
1332 if num_events == 0 || num_shots == 0 || self.rank == 0 {
1333 return PackedShots {
1334 data: vec![0u64; num_events * s_words],
1335 num_shots,
1336 num_measurements: num_events,
1337 m_words,
1338 s_words,
1339 layout: ShotLayout::MeasMajor,
1340 };
1341 }
1342
1343 let det_weight = det_sparse.stats().total_weight;
1344 let meas_weight = sparse.stats().total_weight;
1345
1346 if det_weight > meas_weight + num_events {
1347 let meas_packed = self.sample_bulk_packed(num_shots);
1348 let mut data = vec![0u64; num_events * s_words];
1349 for (e, &(m_a, m_b)) in pairs.iter().enumerate() {
1350 let src_a = &meas_packed.data[m_a * s_words..(m_a + 1) * s_words];
1351 let src_b = &meas_packed.data[m_b * s_words..(m_b + 1) * s_words];
1352 let dst = &mut data[e * s_words..(e + 1) * s_words];
1353 for (d, (&a, &b)) in dst.iter_mut().zip(src_a.iter().zip(src_b.iter())) {
1354 *d = a ^ b;
1355 }
1356 }
1357 return PackedShots {
1358 data,
1359 num_shots,
1360 num_measurements: num_events,
1361 m_words,
1362 s_words,
1363 layout: ShotLayout::MeasMajor,
1364 };
1365 }
1366
1367 let det_ref = vec![0u64; m_words];
1368
1369 let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
1370
1371 let data = if num_shots > BTS_BATCH_SHOTS {
1372 bts_batched(
1373 &det_sparse,
1374 None,
1375 num_shots,
1376 s_words,
1377 &det_ref,
1378 &mut fast_rng,
1379 self.rank,
1380 )
1381 } else {
1382 sample_bts_meas_major(&det_sparse, num_shots, &det_ref, &mut fast_rng, self.rank)
1383 };
1384
1385 PackedShots {
1386 data,
1387 num_shots,
1388 num_measurements: num_events,
1389 m_words,
1390 s_words,
1391 layout: ShotLayout::MeasMajor,
1392 }
1393 }
1394
1395 pub fn exact_counts(&self) -> Option<std::collections::HashMap<Vec<u64>, u64>> {
1396 if self.rank > MAX_RANK_FOR_GRAY_CODE {
1397 return None;
1398 }
1399 let sparse = self.sparse.as_ref()?;
1400 Some(gray_code_exact_counts(
1401 sparse,
1402 self.rank,
1403 &self.ref_bits_packed,
1404 self.num_measurements,
1405 ))
1406 }
1407
1408 pub fn marginal_probabilities(&self) -> Vec<f64> {
1409 let mut probs = vec![0.5f64; self.num_measurements];
1410 if let Some(sparse) = &self.sparse {
1411 for (m, p) in probs.iter_mut().enumerate() {
1412 if sparse.row_weight(m) == 0 {
1413 let ref_bit = (self.ref_bits_packed[m / 64] >> (m % 64)) & 1;
1414 *p = ref_bit as f64;
1415 }
1416 }
1417 } else {
1418 for (m, p) in probs.iter_mut().enumerate() {
1419 let mut depends_on_random = false;
1420 for row in &self.flip_rows {
1421 let w = m / 64;
1422 if w < row.len() && (row[w] >> (m % 64)) & 1 != 0 {
1423 depends_on_random = true;
1424 break;
1425 }
1426 }
1427 if !depends_on_random {
1428 let ref_bit = (self.ref_bits_packed[m / 64] >> (m % 64)) & 1;
1429 *p = ref_bit as f64;
1430 }
1431 }
1432 }
1433 probs
1434 }
1435
1436 pub fn parity_report(&self) -> String {
1437 let mut report = format!(
1438 "CompiledSampler: {} measurements, rank {}, {} flip rows\n",
1439 self.num_measurements,
1440 self.rank,
1441 self.flip_rows.len()
1442 );
1443 if let Some(sparse) = &self.sparse {
1444 let stats = sparse.stats();
1445 report.push_str(&format!(
1446 "Parity matrix: {} rows, total weight {}\n\
1447 Weight range: {} to {}, mean {:.1}\n\
1448 Deterministic measurements: {}\n",
1449 sparse.num_rows,
1450 stats.total_weight,
1451 stats.min_weight,
1452 stats.max_weight,
1453 stats.mean_weight,
1454 stats.num_deterministic,
1455 ));
1456 let mut histogram = [0usize; 8];
1457 for m in 0..sparse.num_rows {
1458 let w = sparse.row_weight(m);
1459 let bucket = w.min(7);
1460 histogram[bucket] += 1;
1461 }
1462 report.push_str("Weight histogram: ");
1463 for (i, &count) in histogram.iter().enumerate() {
1464 if count > 0 {
1465 if i < 7 {
1466 report.push_str(&format!("w{}={} ", i, count));
1467 } else {
1468 report.push_str(&format!("w7+={} ", count));
1469 }
1470 }
1471 }
1472 report.push('\n');
1473 } else {
1474 report.push_str("No sparse parity matrix available\n");
1475 }
1476 report
1477 }
1478
1479 pub fn detection_event_report(&self, pairs: &[(usize, usize)]) -> String {
1480 let sparse = match &self.sparse {
1481 Some(s) => s,
1482 None => return "No sparse parity matrix available\n".to_string(),
1483 };
1484 let det_sparse = sparse.compile_detection_events(pairs);
1485 let meas_stats = sparse.stats();
1486 let det_stats = det_sparse.stats();
1487
1488 let mut report = format!(
1489 "Detection events: {} pairs\n\
1490 Measurement parity: total_weight={}, mean={:.2}\n\
1491 Detection parity: total_weight={}, mean={:.2}\n",
1492 pairs.len(),
1493 meas_stats.total_weight,
1494 meas_stats.mean_weight,
1495 det_stats.total_weight,
1496 det_stats.mean_weight,
1497 );
1498
1499 if meas_stats.total_weight > 0 {
1500 let reduction = 1.0 - det_stats.total_weight as f64 / meas_stats.total_weight as f64;
1501 report.push_str(&format!(
1502 "Weight reduction: {:.1}% ({:.1}x less work)\n",
1503 reduction * 100.0,
1504 if det_stats.total_weight > 0 {
1505 meas_stats.total_weight as f64 / det_stats.total_weight as f64
1506 } else {
1507 f64::INFINITY
1508 },
1509 ));
1510 }
1511
1512 let mut histogram = [0usize; 8];
1513 for m in 0..det_sparse.num_rows {
1514 let w = det_sparse.row_weight(m);
1515 histogram[w.min(7)] += 1;
1516 }
1517 report.push_str("Detection weight histogram: ");
1518 for (i, &count) in histogram.iter().enumerate() {
1519 if count > 0 {
1520 if i < 7 {
1521 report.push_str(&format!("w{}={} ", i, count));
1522 } else {
1523 report.push_str(&format!("w7+={} ", count));
1524 }
1525 }
1526 }
1527 report.push('\n');
1528 report
1529 }
1530}
1531
/// Memory order of the bits in a packed shot buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotLayout {
    /// One contiguous run of `m_words` words per shot. Within a shot's run,
    /// bit `m % 64` of word `m / 64` is measurement `m`.
    ShotMajor,
    /// One contiguous run of `s_words` words per measurement. Within a
    /// measurement's run, bit `s % 64` of word `s / 64` is shot `s`.
    MeasMajor,
}
1537
/// Bit-packed sample results for `num_shots` shots of `num_measurements`
/// measurements, stored in the order given by `layout`.
#[derive(Debug, Clone)]
pub struct PackedShots {
    // Packed bits. The expected length is `num_shots * m_words` for ShotMajor
    // and `num_measurements * s_words` for MeasMajor. The `from_*`
    // constructors do not validate this.
    data: Vec<u64>,
    num_shots: usize,
    num_measurements: usize,
    // Words per shot row (ShotMajor addressing).
    m_words: usize,
    // Words per measurement row (MeasMajor addressing).
    s_words: usize,
    layout: ShotLayout,
}
1547
1548impl PackedShots {
1549 pub fn num_shots(&self) -> usize {
1550 self.num_shots
1551 }
1552
1553 pub fn num_measurements(&self) -> usize {
1554 self.num_measurements
1555 }
1556
1557 pub fn layout(&self) -> ShotLayout {
1558 self.layout
1559 }
1560
1561 #[inline(always)]
1562 pub fn get_bit(&self, shot: usize, measurement: usize) -> bool {
1563 match self.layout {
1564 ShotLayout::ShotMajor => {
1565 let base = shot * self.m_words;
1566 let w = measurement / 64;
1567 (self.data[base + w] >> (measurement % 64)) & 1 != 0
1568 }
1569 ShotLayout::MeasMajor => {
1570 let base = measurement * self.s_words;
1571 let w = shot / 64;
1572 (self.data[base + w] >> (shot % 64)) & 1 != 0
1573 }
1574 }
1575 }
1576
1577 pub fn s_words(&self) -> usize {
1578 self.s_words
1579 }
1580
1581 pub fn m_words(&self) -> usize {
1582 self.m_words
1583 }
1584
1585 pub fn shot_words(&self, shot: usize) -> &[u64] {
1586 assert!(
1587 self.layout == ShotLayout::ShotMajor,
1588 "shot_words requires ShotMajor layout"
1589 );
1590 let base = shot * self.m_words;
1591 &self.data[base..base + self.m_words]
1592 }
1593
1594 pub fn meas_words(&self, m: usize) -> &[u64] {
1595 assert!(
1596 self.layout == ShotLayout::MeasMajor,
1597 "meas_words requires MeasMajor layout"
1598 );
1599 let base = m * self.s_words;
1600 &self.data[base..base + self.s_words]
1601 }
1602
1603 pub fn from_shot_major(data: Vec<u64>, num_shots: usize, num_measurements: usize) -> Self {
1604 let m_words = num_measurements.div_ceil(64);
1605 let s_words = num_shots.div_ceil(64);
1606 Self {
1607 data,
1608 num_shots,
1609 num_measurements,
1610 m_words,
1611 s_words,
1612 layout: ShotLayout::ShotMajor,
1613 }
1614 }
1615
1616 pub fn from_meas_major(data: Vec<u64>, num_shots: usize, num_measurements: usize) -> Self {
1617 let m_words = num_measurements.div_ceil(64);
1618 let s_words = num_shots.div_ceil(64);
1619 Self {
1620 data,
1621 num_shots,
1622 num_measurements,
1623 m_words,
1624 s_words,
1625 layout: ShotLayout::MeasMajor,
1626 }
1627 }
1628
1629 pub fn raw_data(&self) -> &[u64] {
1630 &self.data
1631 }
1632
1633 pub fn into_data(self) -> Vec<u64> {
1634 self.data
1635 }
1636
1637 #[allow(dead_code)]
1638 pub(crate) fn into_shot_major_data(self) -> Vec<u64> {
1639 if self.layout == ShotLayout::ShotMajor {
1640 return self.data;
1641 }
1642
1643 let mut shot_major = vec![0u64; self.num_shots * self.m_words];
1644 for batch_start in (0..self.num_shots).step_by(64) {
1645 let batch_end = (batch_start + 64).min(self.num_shots);
1646 let batch_len = batch_end - batch_start;
1647 let sw_base = batch_start / 64;
1648 let bit_off = batch_start % 64;
1649 let batch_mask = if batch_len == 64 {
1650 u64::MAX
1651 } else {
1652 (1u64 << batch_len) - 1
1653 };
1654
1655 for m in 0..self.num_measurements {
1656 let mword = m / 64;
1657 let mbit = m % 64;
1658 let meas_row = &self.data[m * self.s_words..(m + 1) * self.s_words];
1659 let mut bits = meas_row[sw_base] >> bit_off;
1660 if bit_off != 0 && sw_base + 1 < self.s_words {
1661 bits |= meas_row[sw_base + 1] << (64 - bit_off);
1662 }
1663 bits &= batch_mask;
1664
1665 while bits != 0 {
1666 let shot_in_batch = bits.trailing_zeros() as usize;
1667 let shot_base = (batch_start + shot_in_batch) * self.m_words;
1668 shot_major[shot_base + mword] |= 1u64 << mbit;
1669 bits &= bits - 1;
1670 }
1671 }
1672 }
1673
1674 shot_major
1675 }
1676
1677 pub fn to_shots(&self) -> Vec<Vec<bool>> {
1678 let mut shots = Vec::with_capacity(self.num_shots);
1679 for s in 0..self.num_shots {
1680 let mut shot = Vec::with_capacity(self.num_measurements);
1681 for m in 0..self.num_measurements {
1682 shot.push(self.get_bit(s, m));
1683 }
1684 shots.push(shot);
1685 }
1686 shots
1687 }
1688
1689 pub fn parity_rows(&self, rows: &[Vec<usize>]) -> Result<PackedShots> {
1693 validate_parity_rows(rows, self.num_measurements)?;
1694
1695 match self.layout {
1696 ShotLayout::MeasMajor => {
1697 let mut data = vec![0u64; rows.len() * self.s_words];
1698
1699 #[cfg(feature = "parallel")]
1700 if rows.len() * self.s_words >= 16_384 {
1701 use rayon::prelude::*;
1702 data.par_chunks_mut(self.s_words)
1703 .zip(rows.par_iter())
1704 .for_each(|(dst, row)| {
1705 for &measurement in row {
1706 let src_start = measurement * self.s_words;
1707 xor_words(dst, &self.data[src_start..src_start + self.s_words]);
1708 }
1709 });
1710 return Ok(PackedShots::from_meas_major(
1711 data,
1712 self.num_shots,
1713 rows.len(),
1714 ));
1715 }
1716
1717 for (out_idx, row) in rows.iter().enumerate() {
1718 let dst = &mut data[out_idx * self.s_words..(out_idx + 1) * self.s_words];
1719 for &measurement in row {
1720 let src_start = measurement * self.s_words;
1721 xor_words(dst, &self.data[src_start..src_start + self.s_words]);
1722 }
1723 }
1724 Ok(PackedShots::from_meas_major(
1725 data,
1726 self.num_shots,
1727 rows.len(),
1728 ))
1729 }
1730 ShotLayout::ShotMajor => {
1731 let out_words = rows.len().div_ceil(64);
1732 let mut data = vec![0u64; self.num_shots * out_words];
1733 for shot in 0..self.num_shots {
1734 let base = shot * out_words;
1735 for (out_idx, row) in rows.iter().enumerate() {
1736 let mut bit = false;
1737 for &measurement in row {
1738 bit ^= self.get_bit(shot, measurement);
1739 }
1740 if bit {
1741 data[base + out_idx / 64] |= 1u64 << (out_idx % 64);
1742 }
1743 }
1744 }
1745 Ok(PackedShots::from_shot_major(
1746 data,
1747 self.num_shots,
1748 rows.len(),
1749 ))
1750 }
1751 }
1752 }
1753
1754 pub fn counts(&self) -> std::collections::HashMap<Vec<u64>, u64> {
1755 if self.m_words > 8 {
1756 return self.counts_wide();
1757 }
1758 self.counts_packed()
1759 .into_iter()
1760 .map(|(k, v)| (k[..self.m_words].to_vec(), v))
1761 .collect()
1762 }
1763
1764 fn counts_wide(&self) -> std::collections::HashMap<Vec<u64>, u64> {
1765 use std::collections::HashMap;
1766
1767 let mut map = HashMap::new();
1768 let mw = self.m_words;
1769
1770 match self.layout {
1771 ShotLayout::ShotMajor => {
1772 for s in 0..self.num_shots {
1773 let base = s * mw;
1774 *map.entry(self.data[base..base + mw].to_vec()).or_insert(0) += 1;
1775 }
1776 }
1777 ShotLayout::MeasMajor => {
1778 let batch_size = 64;
1779 let mut shot_buf = vec![0u64; batch_size * mw];
1780 for batch_start in (0..self.num_shots).step_by(batch_size) {
1781 let batch_end = (batch_start + batch_size).min(self.num_shots);
1782 let batch_len = batch_end - batch_start;
1783 shot_buf[..batch_len * mw].fill(0);
1784
1785 let sw_base = batch_start / 64;
1786 let bit_off = batch_start % 64;
1787
1788 for m in 0..self.num_measurements {
1789 let mword = m / 64;
1790 let mbit = m % 64;
1791 let meas_row = &self.data[m * self.s_words..];
1792 let word = meas_row[sw_base];
1793 let shifted = word >> bit_off;
1794 for s in 0..batch_len {
1795 if (shifted >> s) & 1 != 0 {
1796 shot_buf[s * mw + mword] |= 1u64 << mbit;
1797 }
1798 }
1799 }
1800
1801 for s in 0..batch_len {
1802 let base = s * mw;
1803 *map.entry(shot_buf[base..base + mw].to_vec()).or_insert(0) += 1;
1804 }
1805 }
1806 }
1807 }
1808
1809 map
1810 }
1811
1812 fn counts_packed(&self) -> std::collections::HashMap<[u64; 8], u64> {
1813 use std::collections::HashMap;
1814 let mut map = HashMap::new();
1815 let mw = self.m_words;
1816 debug_assert!(mw <= 8);
1817
1818 match self.layout {
1819 ShotLayout::ShotMajor => {
1820 for s in 0..self.num_shots {
1821 let base = s * mw;
1822 let mut key = [0u64; 8];
1823 key[..mw].copy_from_slice(&self.data[base..base + mw]);
1824 *map.entry(key).or_insert(0) += 1;
1825 }
1826 }
1827 ShotLayout::MeasMajor => {
1828 let batch_size = 64;
1829 let mut shot_buf = vec![0u64; batch_size * mw];
1830 for batch_start in (0..self.num_shots).step_by(batch_size) {
1831 let batch_end = (batch_start + batch_size).min(self.num_shots);
1832 let batch_len = batch_end - batch_start;
1833 shot_buf[..batch_len * mw].fill(0);
1834
1835 let sw_base = batch_start / 64;
1836 let bit_off = batch_start % 64;
1837
1838 for m in 0..self.num_measurements {
1839 let mword = m / 64;
1840 let mbit = m % 64;
1841 let meas_row = &self.data[m * self.s_words..];
1842 let word = meas_row[sw_base];
1843 let shifted = word >> bit_off;
1844 for s in 0..batch_len {
1845 if (shifted >> s) & 1 != 0 {
1846 shot_buf[s * mw + mword] |= 1u64 << mbit;
1847 }
1848 }
1849 }
1850
1851 for s in 0..batch_len {
1852 let base = s * mw;
1853 let mut key = [0u64; 8];
1854 key[..mw].copy_from_slice(&shot_buf[base..base + mw]);
1855 *map.entry(key).or_insert(0) += 1;
1856 }
1857 }
1858 }
1859 }
1860 map
1861 }
1862}
1863
1864fn validate_parity_rows(rows: &[Vec<usize>], num_measurements: usize) -> Result<()> {
1865 for row in rows {
1866 for &measurement in row {
1867 if measurement >= num_measurements {
1868 return Err(PrismError::InvalidParameter {
1869 message: format!(
1870 "measurement index {measurement} out of bounds for {num_measurements} measurements"
1871 ),
1872 });
1873 }
1874 }
1875 }
1876 Ok(())
1877}
1878
1879impl CompiledDetectorSampler {
1880 pub fn num_measurements(&self) -> usize {
1881 self.num_measurements
1882 }
1883
1884 pub fn num_detectors(&self) -> usize {
1885 self.detector_rows.len()
1886 }
1887
1888 pub fn num_observables(&self) -> usize {
1889 self.observable_rows.len()
1890 }
1891
1892 pub fn rank(&self) -> usize {
1893 self.measurement_sampler.rank()
1894 }
1895
1896 pub fn detector_rows(&self) -> &[Vec<usize>] {
1897 &self.detector_rows
1898 }
1899
1900 pub fn observable_rows(&self) -> &[Vec<usize>] {
1901 &self.observable_rows
1902 }
1903
1904 #[cfg(feature = "gpu")]
1905 pub fn with_gpu(mut self, context: std::sync::Arc<crate::gpu::GpuContext>) -> Self {
1906 self.measurement_sampler = self.measurement_sampler.with_gpu(context);
1907 self
1908 }
1909
1910 pub fn sample_measurements_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
1911 self.measurement_sampler.try_sample_bulk_packed(num_shots)
1912 }
1913
1914 pub fn sample_detectors_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
1915 let measurements = self.sample_measurements_packed(num_shots)?;
1916 measurements.parity_rows(&self.detector_rows)
1917 }
1918
1919 pub fn sample_observables_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
1920 let measurements = self.sample_measurements_packed(num_shots)?;
1921 measurements.parity_rows(&self.observable_rows)
1922 }
1923
1924 pub fn sample_packed(&mut self, num_shots: usize) -> Result<DetectorSampleBatch> {
1925 let measurements = self.sample_measurements_packed(num_shots)?;
1926 let detectors = measurements.parity_rows(&self.detector_rows)?;
1927 let observables = measurements.parity_rows(&self.observable_rows)?;
1928 Ok(DetectorSampleBatch {
1929 measurements,
1930 detectors,
1931 observables,
1932 })
1933 }
1934
1935 pub fn sample_detectors_chunked<A: ShotAccumulator>(
1936 &mut self,
1937 total_shots: usize,
1938 acc: &mut A,
1939 ) -> Result<()> {
1940 let chunk_size = default_chunk_size(self.detector_rows.len());
1941 let mut remaining = total_shots;
1942 while remaining > 0 {
1943 let batch = remaining.min(chunk_size);
1944 let packed = self.sample_detectors_packed(batch)?;
1945 acc.accumulate(&packed);
1946 remaining -= batch;
1947 }
1948 Ok(())
1949 }
1950
1951 pub fn sample_detector_counts(
1952 &mut self,
1953 total_shots: usize,
1954 ) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
1955 let mut acc = HistogramAccumulator::new();
1956 self.sample_detectors_chunked(total_shots, &mut acc)?;
1957 Ok(acc.into_counts())
1958 }
1959}
1960
#[cfg(feature = "gpu")]
/// Packed shot data that lives in a GPU buffer.
///
/// It carries the same shape metadata as `PackedShots` and can be copied
/// back to the host with `to_host`.
#[derive(Debug)]
pub struct DevicePackedShots {
    // GPU context used for transfers and device-side counting kernels.
    context: std::sync::Arc<crate::gpu::GpuContext>,
    // Device buffer holding the packed bits in `layout` order.
    data: crate::gpu::GpuBuffer<u64>,
    num_shots: usize,
    num_measurements: usize,
    // Words per shot row (ShotMajor addressing).
    m_words: usize,
    // Words per measurement row (MeasMajor addressing).
    s_words: usize,
    layout: ShotLayout,
    // Rank hint that `counts()` forwards to the device counting kernels.
    rank: usize,
}
1979
#[cfg(feature = "gpu")]
impl DevicePackedShots {
    /// Number of shots held on the device.
    pub fn num_shots(&self) -> usize {
        self.num_shots
    }

    /// Number of measurements per shot.
    pub fn num_measurements(&self) -> usize {
        self.num_measurements
    }

    /// Bit layout of the device buffer.
    pub fn layout(&self) -> ShotLayout {
        self.layout
    }

    /// GPU context that owns the buffer.
    pub(crate) fn context(&self) -> &std::sync::Arc<crate::gpu::GpuContext> {
        &self.context
    }

    /// Words per shot row (ShotMajor addressing).
    pub fn m_words(&self) -> usize {
        self.m_words
    }

    /// Words per measurement row (MeasMajor addressing).
    pub fn s_words(&self) -> usize {
        self.s_words
    }

    /// Mutable access to the device buffer, for in-crate kernels.
    pub(crate) fn data_mut(&mut self) -> &mut crate::gpu::GpuBuffer<u64> {
        &mut self.data
    }

    /// Copies the buffer to host memory as a `PackedShots` with the same layout.
    ///
    /// Empty buffers skip the device transfer entirely.
    ///
    /// # Errors
    /// Transfer failures are reported as `PrismError::BackendUnsupported`.
    pub fn to_host(&self) -> Result<PackedShots> {
        let host_len = match self.layout {
            ShotLayout::ShotMajor => self.num_shots * self.m_words,
            ShotLayout::MeasMajor => self.num_measurements * self.s_words,
        };

        let mut host = vec![0u64; host_len];
        if host_len != 0 {
            self.data
                .copy_to_host(self.context.device(), &mut host)
                .map_err(|e| PrismError::BackendUnsupported {
                    backend: "gpu".to_string(),
                    operation: format!("copy device packed shots to host: {e}"),
                })?;
        }

        let packed = match self.layout {
            ShotLayout::ShotMajor => {
                PackedShots::from_shot_major(host, self.num_shots, self.num_measurements)
            }
            ShotLayout::MeasMajor => {
                PackedShots::from_meas_major(host, self.num_shots, self.num_measurements)
            }
        };
        Ok(packed)
    }

    /// Per-measurement count of shots that read 1.
    ///
    /// MeasMajor data is counted on the device. ShotMajor data is copied to the
    /// host and counted from its outcome histogram.
    pub fn marginal_counts(&self) -> Result<Vec<u64>> {
        match self.layout {
            ShotLayout::MeasMajor => crate::gpu::kernels::bts::count_meas_major_marginals(
                &self.context,
                &self.data,
                self.num_measurements,
                self.num_shots,
                self.s_words,
            ),
            ShotLayout::ShotMajor => {
                let mut totals = vec![0u64; self.num_measurements];
                for (key, count) in self.to_host()?.counts() {
                    for (m, total) in totals.iter_mut().enumerate() {
                        if (key[m / 64] >> (m % 64)) & 1 != 0 {
                            *total += count;
                        }
                    }
                }
                Ok(totals)
            }
        }
    }

    /// Per-measurement fraction of shots that read 1. All zeros when there are
    /// no shots.
    pub fn marginals(&self) -> Result<Vec<f64>> {
        if self.num_shots == 0 {
            return Ok(vec![0.0; self.num_measurements]);
        }
        let shots = self.num_shots as f64;
        let counts = self.marginal_counts()?;
        Ok(counts.into_iter().map(|count| count as f64 / shots).collect())
    }

    /// Outcome histogram, using the rank recorded when these shots were produced.
    pub fn counts(&self) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
        self.counts_with_rank_hint(self.rank)
    }

    /// Outcome histogram using `rank_hint` for the device counting kernels.
    ///
    /// The device kernel for the current layout is tried first. If it declines
    /// (returns `None`), the data is copied to the host and counted there.
    pub(crate) fn counts_with_rank_hint(
        &self,
        rank_hint: usize,
    ) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
        let on_device = match self.layout {
            ShotLayout::MeasMajor => crate::gpu::kernels::bts::try_count_meas_major(
                &self.context,
                &self.data,
                self.num_measurements,
                self.num_shots,
                self.m_words,
                self.s_words,
                rank_hint,
            )?,
            ShotLayout::ShotMajor => {
                let raw_transfer_bytes = self
                    .num_shots
                    .saturating_mul(self.m_words)
                    .saturating_mul(std::mem::size_of::<u64>());
                crate::gpu::kernels::bts::try_count_shot_major(
                    &self.context,
                    &self.data,
                    self.num_shots,
                    self.m_words,
                    rank_hint,
                    raw_transfer_bytes,
                )?
            }
        };

        if let Some(counts) = on_device {
            return Ok(counts);
        }
        Ok(self.to_host()?.counts())
    }
}
2115
/// XORs `src` into `dst` element-wise over the first
/// `min(dst.len(), src.len())` words.
///
/// Four words (256 bits) are handled per AVX2 operation, followed by a scalar
/// tail for the remaining `len % 4` words.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` first. Vector accesses cover
/// only the first `len / 4 * 4` words, and the unaligned load/store intrinsics
/// are used, so the slices need no particular alignment.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn xor_words_avx2(dst: &mut [u64], src: &[u64]) {
    use std::arch::x86_64::*;
    let len = dst.len().min(src.len());
    // Number of full 4-word (256-bit) lanes.
    let chunks = len / 4;
    let dp = dst.as_mut_ptr() as *mut __m256i;
    let sp = src.as_ptr() as *const __m256i;
    for i in 0..chunks {
        // `dp.add(i)` advances by one __m256i, i.e. four u64 words.
        let d = _mm256_loadu_si256(dp.add(i));
        let s = _mm256_loadu_si256(sp.add(i));
        _mm256_storeu_si256(dp.add(i), _mm256_xor_si256(d, s));
    }
    // Scalar tail; indices stay below `len`, so the unchecked accesses are in bounds.
    let tail = chunks * 4;
    for i in tail..len {
        *dst.get_unchecked_mut(i) ^= *src.get_unchecked(i);
    }
}
2134
/// XORs `src` into `dst` element-wise over the first
/// `min(dst.len(), src.len())` words.
///
/// Two words (128 bits) are handled per NEON operation, followed by a scalar
/// tail for an odd final word.
///
/// # Safety
/// This function has no `#[target_feature]` attribute, so it relies on NEON
/// being enabled for the compilation target. NOTE(review): confirm this for
/// any non-standard aarch64 target. Vector accesses cover only the first
/// `len / 2 * 2` words.
///
/// NOTE(review): `allow(dead_code)` looks unnecessary, since `xor_words` calls
/// this on aarch64. Confirm before removing it.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
unsafe fn xor_words_neon(dst: &mut [u64], src: &[u64]) {
    use std::arch::aarch64::*;
    let len = dst.len().min(src.len());
    // Number of full 2-word (128-bit) lanes.
    let chunks = len / 2;
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    for i in 0..chunks {
        let off = i * 2;
        let d = vld1q_u64(dp.add(off));
        let s = vld1q_u64(sp.add(off));
        vst1q_u64(dp.add(off), veorq_u64(d, s));
    }
    // Scalar tail; indices stay below `len`, so the unchecked accesses are in bounds.
    let tail = chunks * 2;
    for i in tail..len {
        *dst.get_unchecked_mut(i) ^= *src.get_unchecked(i);
    }
}
2154
2155#[inline(always)]
2156pub(crate) fn xor_words(dst: &mut [u64], src: &[u64]) {
2157 #[cfg(target_arch = "x86_64")]
2158 {
2159 if is_x86_feature_detected!("avx2") && dst.len() >= 4 {
2160 unsafe {
2162 xor_words_avx2(dst, src);
2163 }
2164 return;
2165 }
2166 }
2167 #[cfg(target_arch = "aarch64")]
2168 {
2169 if dst.len() >= 2 {
2170 unsafe {
2172 xor_words_neon(dst, src);
2173 }
2174 return;
2175 }
2176 }
2177 for (d, &s) in dst.iter_mut().zip(src) {
2178 *d ^= s;
2179 }
2180}
2181
2182fn gray_code_exact_counts(
2183 sparse: &SparseParity,
2184 rank: usize,
2185 ref_bits: &[u64],
2186 num_measurements: usize,
2187) -> std::collections::HashMap<Vec<u64>, u64> {
2188 use std::collections::HashMap;
2189
2190 let m_words = num_measurements.div_ceil(64);
2191 let mut meas_vec = ref_bits[..m_words].to_vec();
2192 let total: u64 = 1u64 << rank;
2193
2194 let mut col_words: Vec<Vec<u64>> = Vec::with_capacity(rank);
2195 for col in 0..rank {
2196 let mut cw = vec![0u64; m_words];
2197 for m in 0..num_measurements {
2198 let start = sparse.row_offsets[m] as usize;
2199 let end = sparse.row_offsets[m + 1] as usize;
2200 for &c in &sparse.col_indices[start..end] {
2201 if c as usize == col {
2202 cw[m / 64] |= 1u64 << (m % 64);
2203 }
2204 }
2205 }
2206 col_words.push(cw);
2207 }
2208
2209 let mut counts: HashMap<Vec<u64>, u64> = HashMap::new();
2210 *counts.entry(meas_vec.clone()).or_insert(0) += 1;
2211
2212 for step in 1..total {
2213 let bit_to_flip = step.trailing_zeros() as usize;
2214 let col = &col_words[bit_to_flip];
2215 for (mw, cw) in meas_vec.iter_mut().zip(col.iter()) {
2216 *mw ^= cw;
2217 }
2218 *counts.entry(meas_vec.clone()).or_insert(0) += 1;
2219 }
2220
2221 counts
2222}
2223
/// Largest rank for which `exact_counts` enumerates all `2^rank` random-bit
/// assignments. Above this it returns `None`.
const MAX_RANK_FOR_GRAY_CODE: usize = 25;
/// Rank ceiling for the multinomial (per-outcome binomial split) counts path.
/// NOTE(review): presumably checked by the counts strategy selection outside
/// this view; confirm at the call sites.
const MAX_RANK_FOR_MULTINOMIAL: usize = 22;
/// Presumably the minimum average shots per possible outcome (`2^rank`) before
/// the multinomial path is preferred. TODO confirm at the call sites.
const MIN_SHOTS_PER_OUTCOME_MULTINOMIAL: usize = 8;
/// Rank ceiling for the dense rank-space counts path, whose tally table has
/// `2^rank` entries. NOTE(review): presumably checked by the caller of
/// `sample_counts_rank_space`; confirm.
const MAX_RANK_FOR_RANK_SPACE: usize = 20;
/// Presumably the minimum average shots per possible outcome before the
/// rank-space path is preferred. TODO confirm at the call sites.
const MIN_SHOTS_PER_OUTCOME: usize = 4;
/// Byte budget, presumably capping the size of a `FlipLut` allocation.
/// NOTE(review): not referenced in this view; confirm its use.
const MAX_LUT_ALLOC_BYTES: u64 = 256 * 1024 * 1024;
2230
/// Compiles a Clifford-only circuit into a `CompiledSampler` using a single
/// forward stabilizer-tableau pass over its measurements.
///
/// Measurements are indexed by program order. The `classical_bit` of each
/// `Measure` is read but not used here.
///
/// For each measurement, the stabilizer rows (`n..2n`) are searched for one
/// with an X component on the measured qubit:
/// - **Found (random outcome).** A new independent random bit `k = rank` is
///   allocated and its flip row is seeded with this measurement. Every other
///   row that anticommutes is multiplied by the pivot. The pivot moves into
///   its destabilizer slot, and the pivot row is reset to the single generator
///   at offset `nw + word`, with phase false (reference outcome 0) and a
///   dependence on bit `k` only.
/// - **Not found (deterministic outcome).** The stabilizers paired with every
///   destabilizer that has an X component on the qubit are multiplied into a
///   scratch row. The scratch phase is the reference bit. The scratch row's
///   accumulated `gen_dep` says which earlier random bits flip the outcome,
///   and this measurement is added to those bits' flip rows.
///
/// Afterwards the flip rows are weight-minimized, and the sampling structures
/// are built: a LUT when `rank >= LUT_MIN_RANK`, the sparse parity matrix, an
/// optional XOR DAG, and optional parity blocks. `seed` seeds the sampler's
/// ChaCha8 RNG.
///
/// # Errors
/// Returns `PrismError::IncompatibleBackend` if the circuit contains
/// non-Clifford gates, and propagates errors from `colmajor_forward_sim`.
pub fn compile_forward(circuit: &Circuit, seed: u64) -> Result<CompiledSampler> {
    if !circuit.is_clifford_only() {
        return Err(PrismError::IncompatibleBackend {
            backend: "CompiledSampler".to_string(),
            reason: "circuit contains non-Clifford gates".to_string(),
        });
    }

    let measurements: Vec<(usize, usize)> = circuit
        .instructions
        .iter()
        .filter_map(|inst| match inst {
            Instruction::Measure {
                qubit,
                classical_bit,
            } => Some((*qubit, *classical_bit)),
            _ => None,
        })
        .collect();

    let num_measurements = measurements.len();
    if num_measurements == 0 {
        // No measurements: an empty sampler with no parity structures.
        return Ok(CompiledSampler {
            flip_rows: Vec::new(),
            ref_bits_packed: Vec::new(),
            rank: 0,
            num_measurements: 0,
            rng: ChaCha8Rng::seed_from_u64(seed),
            lut: None,
            sparse: None,
            xor_dag: None,
            parity_blocks: None,
            #[cfg(feature = "gpu")]
            gpu_context: None,
            #[cfg(feature = "gpu")]
            gpu_bts_cache: None,
        });
    }

    let n = circuit.num_qubits;

    // Tableau after the circuit's gates. Each of the 2n rows spans
    // `stride = 2 * nw` words; the X component of a qubit is tested in the
    // first `nw` words.
    let (mut xz, mut phase, nw) = colmajor_forward_sim(n, &circuit.instructions)?;
    let stride = 2 * nw;
    let m = num_measurements;
    let m_words = m.div_ceil(64);

    // The rank can never exceed the number of measurements, so m_words words
    // are enough for a bitset over random-bit indices.
    let rank_words = m_words;
    let total_rows = 2 * n;
    // gen_dep[r]: the random bits that row r's phase depends on. The extra
    // final row belongs to the scratch accumulator.
    let mut gen_dep: Vec<Vec<u64>> = vec![vec![0u64; rank_words]; total_rows + 1];
    let mut ref_bits: Vec<bool> = vec![false; m];
    let mut rank = 0usize;

    let mut flip_rows: Vec<Vec<u64>> = Vec::with_capacity(m);
    let mut p_data: Vec<u64> = vec![0u64; stride];
    let mut p_dep: Vec<u64> = vec![0u64; rank_words];
    let mut scratch: Vec<u64> = vec![0u64; stride];
    let scratch_idx = total_rows;

    for (meas_idx, &(qubit, _)) in measurements.iter().enumerate() {
        let word = qubit / 64;
        let bit_mask = 1u64 << (qubit % 64);

        // First stabilizer row with an X component on `qubit`, if any.
        let mut p: Option<usize> = None;
        for i in n..2 * n {
            if xz[i * stride + word] & bit_mask != 0 {
                p = Some(i);
                break;
            }
        }

        if let Some(p_row) = p {
            // Random outcome: allocate random bit k.
            let k = rank;
            rank += 1;
            flip_rows.push(vec![0u64; m_words]);

            flip_rows[k][meas_idx / 64] |= 1u64 << (meas_idx % 64);

            // Snapshot the pivot, since rows are rewritten below.
            let p_base = p_row * stride;
            p_data.copy_from_slice(&xz[p_base..p_base + stride]);
            let p_phase = phase[p_row];
            p_dep.copy_from_slice(&gen_dep[p_row][..rank_words]);

            // Multiply the pivot into every other row with an X component on
            // the qubit, and carry the pivot's dependence along.
            for r in 0..total_rows {
                if r == p_row {
                    continue;
                }
                if xz[r * stride + word] & bit_mask == 0 {
                    continue;
                }

                let r_base = r * stride;
                phase[r] = rowmul_phase(&p_data, &mut xz, r_base, nw, p_phase, phase[r]);
                xor_words(&mut gen_dep[r][..rank_words], &p_dep[..rank_words]);
            }

            // The old pivot becomes the paired destabilizer.
            let dest_idx = p_row - n;
            let dest_base = dest_idx * stride;
            xz.copy_within(p_row * stride..p_row * stride + stride, dest_base);
            phase[dest_idx] = p_phase;
            gen_dep[dest_idx][..rank_words].copy_from_slice(&p_dep);

            // The pivot row becomes the measured generator, with phase false
            // (reference outcome 0).
            let p_base = p_row * stride;
            xz[p_base..p_base + stride].fill(0);
            xz[p_base + nw + word] |= bit_mask;
            phase[p_row] = false;

            // Its sign is exactly random bit k.
            gen_dep[p_row][..rank_words].fill(0);
            gen_dep[p_row][k / 64] |= 1u64 << (k % 64);

            ref_bits[meas_idx] = false;
        } else {
            // Deterministic outcome: accumulate the product of the stabilizers
            // selected by the destabilizers' X components into the scratch row.
            scratch[..stride].fill(0);
            let mut scratch_phase = false;
            gen_dep[scratch_idx][..rank_words].fill(0);

            for g in 0..n {
                let d_base = g * stride;
                if xz[d_base + word] & bit_mask == 0 {
                    continue;
                }

                let s_base = (g + n) * stride;
                let s_phase = phase[g + n];
                scratch_phase =
                    rowmul_phase_into(&xz, s_base, &mut scratch, nw, s_phase, scratch_phase);

                // Split so the scratch row (index 2n) can be XOR-ed with
                // stabilizer g + n while both are borrowed.
                let (lo, hi) = gen_dep.split_at_mut(scratch_idx);
                for (dst, &src) in hi[0][..rank_words].iter_mut().zip(&lo[g + n][..rank_words]) {
                    *dst ^= src;
                }
            }

            ref_bits[meas_idx] = scratch_phase;

            // Every random bit in the scratch dependence flips this measurement.
            for (w, &dep_word) in gen_dep[scratch_idx][..rank_words].iter().enumerate() {
                let mut bits = dep_word;
                while bits != 0 {
                    let bit_pos = bits.trailing_zeros() as usize;
                    let k = w * 64 + bit_pos;
                    if k < rank {
                        flip_rows[k][meas_idx / 64] |= 1u64 << (meas_idx % 64);
                    }
                    bits &= bits - 1;
                }
            }
        }
    }

    let num_meas_words = m_words;
    minimize_flip_row_weight(&mut flip_rows);

    let lut = if rank >= LUT_MIN_RANK {
        Some(FlipLut::build(&flip_rows, num_meas_words))
    } else {
        None
    };

    let sparse = SparseParity::from_flip_rows(&flip_rows, num_measurements);
    let xor_dag = build_xor_dag_if_useful(&sparse);
    let ref_bits_packed = pack_bools(&ref_bits);
    let parity_blocks = build_parity_blocks_if_useful(&sparse, rank, &ref_bits_packed);

    Ok(CompiledSampler {
        flip_rows,
        ref_bits_packed,
        rank,
        num_measurements,
        rng: ChaCha8Rng::seed_from_u64(seed),
        lut,
        sparse: Some(sparse),
        xor_dag,
        parity_blocks,
        #[cfg(feature = "gpu")]
        gpu_context: None,
        #[cfg(feature = "gpu")]
        gpu_bts_cache: None,
    })
}
2410
2411fn compile_measurements_filtered(
2412 circuit: &Circuit,
2413 blocks: &[Vec<usize>],
2414 seed: u64,
2415) -> Result<CompiledSampler> {
2416 let num_global_measurements: usize = circuit
2417 .instructions
2418 .iter()
2419 .filter(|i| matches!(i, Instruction::Measure { .. }))
2420 .count();
2421
2422 if num_global_measurements == 0 {
2423 return Ok(CompiledSampler {
2424 flip_rows: Vec::new(),
2425 ref_bits_packed: Vec::new(),
2426 rank: 0,
2427 num_measurements: 0,
2428 rng: ChaCha8Rng::seed_from_u64(seed),
2429 lut: None,
2430 sparse: None,
2431 xor_dag: None,
2432 parity_blocks: None,
2433 #[cfg(feature = "gpu")]
2434 gpu_context: None,
2435 #[cfg(feature = "gpu")]
2436 gpu_bts_cache: None,
2437 });
2438 }
2439
2440 let mut qubit_to_block: Vec<usize> = vec![0; circuit.num_qubits];
2441 for (bi, block) in blocks.iter().enumerate() {
2442 for &q in block {
2443 qubit_to_block[q] = bi;
2444 }
2445 }
2446
2447 let mut block_samplers: Vec<CompiledSampler> = Vec::with_capacity(blocks.len());
2448 for (bi, block) in blocks.iter().enumerate() {
2449 let (sub_circuit, _qubit_map, _classical_map) = circuit.extract_subcircuit(block);
2450 let block_seed = seed.wrapping_add(bi as u64 * 0x1234_5678);
2451 block_samplers.push(compile_measurements(&sub_circuit, block_seed)?);
2452 }
2453
2454 let mut meas_map: Vec<(usize, usize)> = Vec::with_capacity(num_global_measurements);
2455 let mut block_meas_count: Vec<usize> = vec![0; blocks.len()];
2456 for inst in &circuit.instructions {
2457 if let Instruction::Measure { qubit, .. } = inst {
2458 let bi = qubit_to_block[*qubit];
2459 let local_idx = block_meas_count[bi];
2460 block_meas_count[bi] += 1;
2461 meas_map.push((bi, local_idx));
2462 }
2463 }
2464
2465 let m_words = num_global_measurements.div_ceil(64);
2466 let total_rank: usize = block_samplers.iter().map(|s| s.rank).sum();
2467 let mut ref_bits_packed: Vec<u64> = vec![0u64; num_global_measurements.div_ceil(64)];
2468
2469 for (gi, &(bi, li)) in meas_map.iter().enumerate() {
2470 let src = &block_samplers[bi].ref_bits_packed;
2471 let bit = (src[li / 64] >> (li % 64)) & 1;
2472 if bit != 0 {
2473 ref_bits_packed[gi / 64] |= 1u64 << (gi % 64);
2474 }
2475 }
2476
2477 let mut local_to_global: Vec<Vec<usize>> = vec![Vec::new(); blocks.len()];
2478 for (gi, &(bi, _li)) in meas_map.iter().enumerate() {
2479 local_to_global[bi].push(gi);
2480 }
2481
2482 let mut rank_offsets = Vec::with_capacity(block_samplers.len());
2483 let mut rank_prefix = 0usize;
2484 for sampler in &block_samplers {
2485 rank_offsets.push(rank_prefix);
2486 rank_prefix += sampler.rank;
2487 }
2488 debug_assert_eq!(rank_prefix, total_rank);
2489
2490 let mut flip_rows: Vec<Vec<u64>> = Vec::with_capacity(total_rank);
2491 for (bi, sampler) in block_samplers.iter().enumerate() {
2492 let mapping = &local_to_global[bi];
2493 for local_row in &sampler.flip_rows {
2494 flip_rows.push(project_local_flip_row(local_row, mapping, m_words));
2495 }
2496 }
2497
2498 let lut = if total_rank >= LUT_MIN_RANK {
2499 Some(FlipLut::build(&flip_rows, m_words))
2500 } else {
2501 None
2502 };
2503
2504 let sparse = build_sparse_from_filtered_blocks(
2505 &block_samplers,
2506 &meas_map,
2507 &rank_offsets,
2508 num_global_measurements,
2509 );
2510 let xor_dag = build_xor_dag_if_useful(&sparse);
2511 let parity_blocks =
2512 build_filtered_parity_blocks(&block_samplers, &local_to_global).map(|mut blocks| {
2513 blocks.direct_scatter = true;
2514 blocks
2515 });
2516
2517 Ok(CompiledSampler {
2518 flip_rows,
2519 ref_bits_packed,
2520 rank: total_rank,
2521 num_measurements: num_global_measurements,
2522 rng: ChaCha8Rng::seed_from_u64(seed),
2523 lut,
2524 sparse: Some(sparse),
2525 xor_dag,
2526 parity_blocks,
2527 #[cfg(feature = "gpu")]
2528 gpu_context: None,
2529 #[cfg(feature = "gpu")]
2530 gpu_bts_cache: None,
2531 })
2532}
2533
2534pub(crate) fn defer_measure_reset_circuit(circuit: &Circuit) -> Result<Circuit> {
2536 if !circuit.is_clifford_only() {
2537 return Err(PrismError::IncompatibleBackend {
2538 backend: "CompiledSampler".to_string(),
2539 reason: "deferred measurement sampling requires Clifford gates".to_string(),
2540 });
2541 }
2542
2543 let has_measurements = circuit
2544 .instructions
2545 .iter()
2546 .any(|inst| matches!(inst, Instruction::Measure { .. }));
2547 if !has_measurements {
2548 return Err(PrismError::IncompatibleBackend {
2549 backend: "CompiledSampler".to_string(),
2550 reason: "deferred measurement sampling requires measurements".to_string(),
2551 });
2552 }
2553
2554 let mut aliases: Vec<usize> = (0..circuit.num_qubits).collect();
2555 let mut measured_aliases = vec![false; circuit.num_qubits];
2556 let mut next_qubit = circuit.num_qubits;
2557 let mut deferred_measurements: Vec<(usize, usize)> = Vec::new();
2558 let mut transformed = Circuit::new(circuit.num_qubits, circuit.num_classical_bits);
2559
2560 for inst in &circuit.instructions {
2561 match inst {
2562 Instruction::Gate { gate, targets } => {
2563 let mapped = map_deferred_targets(targets, &aliases, &measured_aliases)?;
2564 transformed.instructions.push(Instruction::Gate {
2565 gate: gate.clone(),
2566 targets: mapped,
2567 });
2568 }
2569 Instruction::Measure {
2570 qubit,
2571 classical_bit,
2572 } => {
2573 if *qubit >= aliases.len() {
2574 return Err(PrismError::InvalidQubit {
2575 index: *qubit,
2576 register_size: aliases.len(),
2577 });
2578 }
2579 if *classical_bit >= circuit.num_classical_bits {
2580 return Err(PrismError::InvalidClassicalBit {
2581 index: *classical_bit,
2582 register_size: circuit.num_classical_bits,
2583 });
2584 }
2585 let alias = aliases[*qubit];
2586 deferred_measurements.push((alias, *classical_bit));
2587 measured_aliases[alias] = true;
2588 }
2589 Instruction::Reset { qubit } => {
2590 if *qubit >= aliases.len() {
2591 return Err(PrismError::InvalidQubit {
2592 index: *qubit,
2593 register_size: aliases.len(),
2594 });
2595 }
2596 aliases[*qubit] = next_qubit;
2597 next_qubit += 1;
2598 measured_aliases.push(false);
2599 transformed.num_qubits = next_qubit;
2600 }
2601 Instruction::Barrier { qubits } => {
2602 let mut mapped = SmallVec::<[usize; 4]>::with_capacity(qubits.len());
2603 for &q in qubits {
2604 if q >= aliases.len() {
2605 return Err(PrismError::InvalidQubit {
2606 index: q,
2607 register_size: aliases.len(),
2608 });
2609 }
2610 mapped.push(aliases[q]);
2611 }
2612 transformed
2613 .instructions
2614 .push(Instruction::Barrier { qubits: mapped });
2615 }
2616 Instruction::Conditional { .. } => {
2617 return Err(PrismError::IncompatibleBackend {
2618 backend: "CompiledSampler".to_string(),
2619 reason: "deferred measurement sampling does not support classical conditionals"
2620 .to_string(),
2621 });
2622 }
2623 }
2624 }
2625
2626 transformed.num_qubits = next_qubit;
2627 transformed
2628 .instructions
2629 .reserve(deferred_measurements.len());
2630 for (qubit, classical_bit) in deferred_measurements {
2631 transformed.instructions.push(Instruction::Measure {
2632 qubit,
2633 classical_bit,
2634 });
2635 }
2636
2637 Ok(transformed)
2638}
2639
2640fn map_deferred_targets(
2641 targets: &SmallVec<[usize; 4]>,
2642 aliases: &[usize],
2643 measured_aliases: &[bool],
2644) -> Result<SmallVec<[usize; 4]>> {
2645 let mut mapped = SmallVec::<[usize; 4]>::with_capacity(targets.len());
2646 for &target in targets {
2647 if target >= aliases.len() {
2648 return Err(PrismError::InvalidQubit {
2649 index: target,
2650 register_size: aliases.len(),
2651 });
2652 }
2653 let alias = aliases[target];
2654 if measured_aliases[alias] {
2655 return Err(PrismError::IncompatibleBackend {
2656 backend: "CompiledSampler".to_string(),
2657 reason:
2658 "deferred measurement sampling requires reset before reusing a measured qubit"
2659 .to_string(),
2660 });
2661 }
2662 mapped.push(alias);
2663 }
2664 Ok(mapped)
2665}
2666
2667pub fn compile_detector_sampler(
2673 circuit: &Circuit,
2674 detector_rows: Vec<Vec<usize>>,
2675 observable_rows: Vec<Vec<usize>>,
2676 seed: u64,
2677) -> Result<CompiledDetectorSampler> {
2678 let measurement_circuit = if circuit.has_resets() || !circuit.has_terminal_measurements_only() {
2679 defer_measure_reset_circuit(circuit)?
2680 } else {
2681 circuit.clone()
2682 };
2683
2684 let num_measurements = measurement_circuit
2685 .instructions
2686 .iter()
2687 .filter(|inst| matches!(inst, Instruction::Measure { .. }))
2688 .count();
2689 validate_parity_rows(&detector_rows, num_measurements)?;
2690 validate_parity_rows(&observable_rows, num_measurements)?;
2691
2692 let measurement_sampler = compile_measurements(&measurement_circuit, seed)?;
2693 Ok(CompiledDetectorSampler {
2694 measurement_sampler,
2695 detector_rows,
2696 observable_rows,
2697 num_measurements,
2698 })
2699}
2700
2701pub fn compile_measurements(circuit: &Circuit, seed: u64) -> Result<CompiledSampler> {
2709 if !circuit.is_clifford_only() {
2710 return Err(PrismError::IncompatibleBackend {
2711 backend: "CompiledSampler".to_string(),
2712 reason: "circuit contains non-Clifford gates".to_string(),
2713 });
2714 }
2715
2716 let has_measurements = circuit
2717 .instructions
2718 .iter()
2719 .any(|inst| matches!(inst, Instruction::Measure { .. }));
2720 if has_measurements && circuit.has_resets() {
2721 return Err(PrismError::IncompatibleBackend {
2722 backend: "CompiledSampler".to_string(),
2723 reason: "compiled measurement sampling does not support reset instructions".to_string(),
2724 });
2725 }
2726 if has_measurements && !circuit.has_terminal_measurements_only() {
2727 return Err(PrismError::IncompatibleBackend {
2728 backend: "CompiledSampler".to_string(),
2729 reason: "compiled measurement sampling requires terminal measurements and does not support classical conditionals".to_string(),
2730 });
2731 }
2732
2733 if circuit.num_qubits >= 4 {
2734 let blocks = circuit.independent_subsystems();
2735 if blocks.len() > 1 {
2736 let max_block = blocks.iter().map(|b| b.len()).max().unwrap_or(0);
2737 if max_block < circuit.num_qubits {
2738 return compile_measurements_filtered(circuit, &blocks, seed);
2739 }
2740 }
2741 }
2742
2743 if circuit.num_qubits >= 64 {
2744 return compile_forward(circuit, seed);
2745 }
2746
2747 let measurement_rows = build_measurement_rows(circuit);
2748 let num_measurements = measurement_rows.len();
2749
2750 if num_measurements == 0 {
2751 return Ok(CompiledSampler {
2752 flip_rows: Vec::new(),
2753 ref_bits_packed: Vec::new(),
2754 rank: 0,
2755 num_measurements: 0,
2756 rng: ChaCha8Rng::seed_from_u64(seed),
2757 lut: None,
2758 sparse: None,
2759 xor_dag: None,
2760 parity_blocks: None,
2761 #[cfg(feature = "gpu")]
2762 gpu_context: None,
2763 #[cfg(feature = "gpu")]
2764 gpu_bts_cache: None,
2765 });
2766 }
2767
2768 let n = circuit.num_qubits;
2769
2770 let x_rows: Vec<Vec<u64>> = measurement_rows
2771 .iter()
2772 .map(|(p, _, _)| p.x.clone())
2773 .collect();
2774 let signs: Vec<bool> = measurement_rows.iter().map(|(_, _, s)| *s).collect();
2775
2776 let mut x_copy = x_rows.clone();
2777 let (rank, pivot_cols) = gaussian_eliminate(&mut x_copy, n);
2778
2779 let gate_count = circuit
2780 .instructions
2781 .iter()
2782 .filter(|i| {
2783 matches!(
2784 i,
2785 Instruction::Gate { .. } | Instruction::Conditional { .. }
2786 )
2787 })
2788 .count();
2789
2790 let ref_bits: Vec<bool> = if gate_count > 2 * num_measurements {
2791 let mini_outcomes = compute_reference_bits(&measurement_rows, n);
2792 mini_outcomes
2793 .iter()
2794 .zip(signs.iter())
2795 .map(|(&outcome, &sign)| outcome ^ sign)
2796 .collect()
2797 } else {
2798 use crate::backend::stabilizer::StabilizerBackend;
2799 use crate::backend::Backend;
2800 let mut stab = StabilizerBackend::new(seed);
2801 stab.init(circuit.num_qubits, circuit.num_classical_bits)?;
2802 stab.apply_instructions(&circuit.instructions)?;
2803 let ref_classical = stab.classical_results().to_vec();
2804 let classical_bit_order: Vec<usize> = measurement_rows.iter().map(|(_, c, _)| *c).collect();
2805 classical_bit_order
2806 .iter()
2807 .map(|&cbit| {
2808 if cbit < ref_classical.len() {
2809 ref_classical[cbit]
2810 } else {
2811 false
2812 }
2813 })
2814 .collect()
2815 };
2816
2817 let num_meas_words = num_measurements.div_ceil(64);
2818 let mut flip_rows: Vec<Vec<u64>> = vec![vec![0u64; num_meas_words]; rank];
2819
2820 for (j, &pcol) in pivot_cols.iter().enumerate() {
2821 for (i, x_row) in x_rows.iter().enumerate() {
2822 if get_bit(x_row, pcol) {
2823 flip_rows[j][i / 64] |= 1u64 << (i % 64);
2824 }
2825 }
2826 }
2827
2828 minimize_flip_row_weight(&mut flip_rows);
2829
2830 let lut = if rank >= LUT_MIN_RANK {
2831 Some(FlipLut::build(&flip_rows, num_meas_words))
2832 } else {
2833 None
2834 };
2835
2836 let sparse = SparseParity::from_flip_rows(&flip_rows, num_measurements);
2837 let xor_dag = build_xor_dag_if_useful(&sparse);
2838 let ref_bits_packed = pack_bools(&ref_bits);
2839 let parity_blocks = build_parity_blocks_if_useful(&sparse, rank, &ref_bits_packed);
2840
2841 Ok(CompiledSampler {
2842 flip_rows,
2843 ref_bits_packed,
2844 rank,
2845 num_measurements,
2846 rng: ChaCha8Rng::seed_from_u64(seed),
2847 lut,
2848 sparse: Some(sparse),
2849 xor_dag,
2850 parity_blocks,
2851 #[cfg(feature = "gpu")]
2852 gpu_context: None,
2853 #[cfg(feature = "gpu")]
2854 gpu_bts_cache: None,
2855 })
2856}
2857
2858pub fn run_shots_compiled(circuit: &Circuit, num_shots: usize, seed: u64) -> Result<ShotsResult> {
2867 let mut sampler = compile_measurements(circuit, seed)?;
2868 let packed = sampler.sample_bulk_packed(num_shots);
2869 Ok(ShotsResult {
2870 shots: packed.to_shots(),
2871 num_classical_bits: circuit.num_classical_bits,
2872 })
2873}
2874
/// Like [`run_shots_compiled`], but attaches `context` to the compiled sampler
/// (via `with_gpu`) before sampling.
///
/// # Errors
/// Propagates any error from [`compile_measurements`].
#[cfg(feature = "gpu")]
pub fn run_shots_compiled_with_gpu(
    circuit: &Circuit,
    num_shots: usize,
    seed: u64,
    context: std::sync::Arc<crate::gpu::GpuContext>,
) -> Result<ShotsResult> {
    let compiled = compile_measurements(circuit, seed)?;
    let mut gpu_sampler = compiled.with_gpu(context);
    Ok(ShotsResult {
        shots: gpu_sampler.sample_bulk_packed(num_shots).to_shots(),
        num_classical_bits: circuit.num_classical_bits,
    })
}