Skip to main content

prism_q/sim/compiled/
mod.rs

1mod accumulator;
2mod bts;
3pub(crate) mod parity;
4mod propagation;
5pub(crate) mod rng;
6#[cfg(test)]
7mod tests;
8
9use std::hash::{Hash, Hasher};
10
11use crate::circuit::{Circuit, Instruction, SmallVec};
12use crate::error::{PrismError, Result};
13#[cfg(feature = "gpu")]
14use crate::gpu::kernels::bts::GpuBtsCache;
15use crate::sim::ShotsResult;
16use rand::{RngCore, SeedableRng};
17use rand_chacha::ChaCha8Rng;
18
19use bts::{bts_batched, bts_single_pass, sample_bts_meas_major, BTS_BATCH_SHOTS};
20use rng::{binomial_sample, Xoshiro256PlusPlus};
21
22pub use accumulator::{
23    default_chunk_size, optimal_chunk_size, CorrelatorAccumulator, HistogramAccumulator,
24    MarginalsAccumulator, NullAccumulator, PauliExpectationAccumulator, ShotAccumulator,
25};
26pub use parity::ParityStats;
27use parity::{build_parity_blocks_if_useful, build_xor_dag_if_useful, minimize_flip_row_weight};
28pub(crate) use parity::{ParityBlock, ParityBlocks, SparseParity, XorDag};
29
30pub(crate) use propagation::batch_propagate_backward;
31pub(crate) use propagation::propagate_backward;
32use propagation::{
33    build_measurement_rows, colmajor_forward_sim, compute_reference_bits, rowmul_phase,
34    rowmul_phase_into,
35};
36
37#[derive(Debug, Clone, PartialEq, Eq)]
38pub(crate) struct PauliVec {
39    pub(crate) x: Vec<u64>,
40    pub(crate) z: Vec<u64>,
41}
42
43impl PauliVec {
44    pub fn new(num_words: usize) -> Self {
45        Self {
46            x: vec![0u64; num_words],
47            z: vec![0u64; num_words],
48        }
49    }
50
51    pub fn z_on_qubit(num_words: usize, qubit: usize) -> Self {
52        let mut pv = Self::new(num_words);
53        pv.z[qubit / 64] |= 1u64 << (qubit % 64);
54        pv
55    }
56
57    #[inline(always)]
58    pub fn is_diagonal(&self) -> bool {
59        self.x.iter().all(|&w| w == 0)
60    }
61
62    #[inline(always)]
63    pub fn has_x_or_y(&self, qubit: usize) -> bool {
64        get_bit(&self.x, qubit)
65    }
66}
67
68impl Hash for PauliVec {
69    fn hash<H: Hasher>(&self, state: &mut H) {
70        self.x.hash(state);
71        self.z.hash(state);
72    }
73}
74
/// Reads bit `qubit` from a packed word slice (bit `qubit % 64` of word `qubit / 64`).
///
/// Panics when the word index is out of bounds.
#[inline(always)]
pub(crate) fn get_bit(words: &[u64], qubit: usize) -> bool {
    let mask = 1u64 << (qubit % 64);
    words[qubit / 64] & mask != 0
}
79
/// Writes `val` into bit `qubit` of a packed word slice, leaving all other bits untouched.
///
/// Panics when the word index is out of bounds.
#[inline(always)]
pub(crate) fn set_bit(words: &mut [u64], qubit: usize, val: bool) {
    let mask = 1u64 << (qubit % 64);
    let slot = &mut words[qubit / 64];
    *slot = if val { *slot | mask } else { *slot & !mask };
}
90
/// Toggles bit `qubit` of a packed word slice.
///
/// Panics when the word index is out of bounds.
#[inline(always)]
pub(crate) fn flip_bit(words: &mut [u64], qubit: usize) {
    let (word, bit) = (qubit / 64, qubit % 64);
    words[word] ^= 1u64 << bit;
}
95
/// Forward Gaussian elimination over GF(2) on bit-packed rows, in place.
///
/// Columns `0..num_cols` are scanned in order. For each column the first row at or
/// below the current pivot position with that bit set is swapped into place and
/// XORed into every later row that also has the bit set. Rows above the pivot are
/// not reduced, so the result is row-echelon form, not reduced row-echelon form.
///
/// Returns `(rank, pivot_cols)`: the number of pivots found and the column index of
/// each pivot in increasing order. The first `rank` rows hold the pivot rows.
///
/// All rows are expected to have the same length and to cover `num_cols` bits.
fn gaussian_eliminate(rows: &mut [Vec<u64>], num_cols: usize) -> (usize, Vec<usize>) {
    let num_rows = rows.len();
    let mut pivot_cols: Vec<usize> = Vec::with_capacity(num_rows.min(num_cols));
    let mut current_row = 0;

    for col in 0..num_cols {
        // Every row already carries a pivot: no later column can add another one.
        if current_row == num_rows {
            break;
        }

        let word = col / 64;
        let mask = 1u64 << (col % 64);

        let Some(pivot_row) = rows[current_row..]
            .iter()
            .position(|row| row[word] & mask != 0)
            .map(|offset| offset + current_row)
        else {
            continue;
        };

        rows.swap(pivot_row, current_row);

        // Split so the pivot row can be read while the rows below it are mutated.
        let (top, rest) = rows.split_at_mut(current_row + 1);
        let pivot_data = &top[current_row];
        for row in rest.iter_mut().filter(|row| row[word] & mask != 0) {
            for (dst, &src) in row.iter_mut().zip(pivot_data) {
                *dst ^= src;
            }
        }

        pivot_cols.push(col);
        current_row += 1;
    }

    (current_row, pivot_cols)
}
136
/// Number of flip rows folded into one lookup-table group (one table index byte per group).
const LUT_GROUP_SIZE: usize = 8;
// NOTE(review): presumably the smallest rank for which a LUT is built; its use is not
// visible in this part of the file.
const LUT_MIN_RANK: usize = 8;

/// Precomputed XOR combinations of flip rows, grouped `LUT_GROUP_SIZE` rows at a time.
///
/// Entry `byte` of group `g` is the XOR of every row `g * LUT_GROUP_SIZE + i` whose
/// bit `i` is set in `byte`. Each entry is `m_words` words long and every group
/// reserves `1 << LUT_GROUP_SIZE` entries, including a trailing partial group.
struct FlipLut {
    /// Flat table: `[group][byte][word]`.
    data: Vec<u64>,
    /// Words per table entry.
    m_words: usize,
    /// Number of groups that hold exactly `LUT_GROUP_SIZE` rows.
    num_full_groups: usize,
    /// Rows in the trailing partial group (0 when the rank is a multiple of the group size).
    remainder_size: usize,
}

impl FlipLut {
    /// Builds the table from `flip_rows`, using the first `m_words` words of each row.
    fn build(flip_rows: &[Vec<u64>], m_words: usize) -> Self {
        let entries_per_group = 1usize << LUT_GROUP_SIZE;
        let group_stride = entries_per_group * m_words;
        let num_full_groups = flip_rows.len() / LUT_GROUP_SIZE;
        let remainder_size = flip_rows.len() % LUT_GROUP_SIZE;
        let total_groups = flip_rows.len().div_ceil(LUT_GROUP_SIZE);

        let mut data = vec![0u64; total_groups * group_stride];

        for (g, group_rows) in flip_rows.chunks(LUT_GROUP_SIZE).enumerate() {
            let table = &mut data[g * group_stride..(g + 1) * group_stride];

            // Entry 0 stays all-zero. Every other entry equals an already-built entry
            // (the same index with its lowest set bit cleared) XOR that bit's row.
            for combo in 1..(1usize << group_rows.len()) {
                let low_bit = combo.trailing_zeros() as usize;
                let without_low_bit = combo & (combo - 1);

                let (built, pending) = table.split_at_mut(combo * m_words);
                let base = &built[without_low_bit * m_words..][..m_words];
                let row = &group_rows[low_bit][..m_words];

                for ((out, &b), &r) in pending[..m_words].iter_mut().zip(base).zip(row) {
                    *out = b ^ r;
                }
            }
        }

        Self {
            data,
            m_words,
            num_full_groups,
            remainder_size,
        }
    }

    /// Returns the `m_words`-long entry for index `byte` of `group`.
    #[inline(always)]
    fn lookup(&self, group: usize, byte: usize) -> &[u64] {
        let entry = group * (1 << LUT_GROUP_SIZE) + byte;
        let start = entry * self.m_words;
        &self.data[start..start + self.m_words]
    }
}
194
/// Sampler that draws measurement records by XORing randomly selected flip rows
/// onto a packed reference record.
pub struct CompiledSampler {
    /// Bit-packed flip rows; row `j` is XORed into a shot when the `j`-th random bit is set.
    flip_rows: Vec<Vec<u64>>,
    /// Bit-packed reference record that is XORed with the accumulated flips.
    ref_bits_packed: Vec<u64>,
    /// Number of flip rows (random bits) consumed per shot.
    rank: usize,
    /// Number of measurement bits in one shot.
    num_measurements: usize,
    /// Randomness source for all sampling paths.
    rng: ChaCha8Rng,
    /// Optional byte-indexed lookup table over `flip_rows`; used by the shot-major paths when present.
    lut: Option<FlipLut>,
    /// Optional sparse parity form; required by the BTS and GPU BTS paths.
    sparse: Option<SparseParity>,
    /// Optional XOR DAG; when present the GPU BTS path is disabled.
    xor_dag: Option<XorDag>,
    /// Optional parity blocks; when present BTS sampling runs block by block and the GPU BTS path is disabled.
    parity_blocks: Option<ParityBlocks>,
    /// GPU context installed by `with_gpu`.
    #[cfg(feature = "gpu")]
    gpu_context: Option<std::sync::Arc<crate::gpu::GpuContext>>,
    /// GPU BTS cache, built on first use and cleared by `with_gpu`.
    #[cfg(feature = "gpu")]
    gpu_bts_cache: Option<GpuBtsCache>,
}
210
/// Compiled sampler for measurement, detector, and observable records.
pub struct CompiledDetectorSampler {
    /// Underlying sampler that produces the raw measurement record.
    measurement_sampler: CompiledSampler,
    // NOTE(review): presumably one list of measurement indices per detector — confirm
    // against the code that consumes this field (not visible here).
    detector_rows: Vec<Vec<usize>>,
    // NOTE(review): presumably one list of measurement indices per observable — confirm
    // against the code that consumes this field (not visible here).
    observable_rows: Vec<Vec<usize>>,
    /// Number of measurement bits per shot.
    num_measurements: usize,
}
218
/// Packed measurement, detector, and observable samples from one shot batch.
#[derive(Debug, Clone)]
pub struct DetectorSampleBatch {
    /// Packed measurement samples.
    pub measurements: PackedShots,
    /// Packed detector samples.
    pub detectors: PackedShots,
    /// Packed observable samples.
    pub observables: PackedShots,
}
226
/// Packs a slice of booleans into `u64` words, least-significant bit first.
///
/// Element `i` ends up at bit `i % 64` of word `i / 64`; unused high bits of the
/// last word are zero. An empty slice yields an empty vector.
fn pack_bools(bools: &[bool]) -> Vec<u64> {
    bools
        .chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u64, |word, (bit, &flag)| word | (u64::from(flag) << bit))
        })
        .collect()
}
237
/// Raw `*mut u64` wrapper that can be shared with parallel workers.
///
/// The wrapper performs no bounds or aliasing checks; callers must guarantee that
/// concurrent writers target disjoint ranges of the underlying buffer.
#[cfg(feature = "parallel")]
#[derive(Clone, Copy)]
struct SendPtrU64(*mut u64);
#[cfg(feature = "parallel")]
// SAFETY: SendPtrU64 is used only for packed shot buffers partitioned by
// disjoint word ranges before entering parallel workers.
unsafe impl Send for SendPtrU64 {}
#[cfg(feature = "parallel")]
// SAFETY: The raw pointer wrapper is shared, but each worker writes only its
// assigned non-overlapping range.
unsafe impl Sync for SendPtrU64 {}
#[cfg(feature = "parallel")]
impl SendPtrU64 {
    /// Copies `src` into the buffer starting `dst_offset` words past the wrapped pointer.
    ///
    /// # Safety
    ///
    /// The destination range `dst_offset..dst_offset + src.len()` must lie inside the
    /// allocation behind the pointer, must not overlap `src`, and must not be
    /// accessed by any other thread while the copy runs.
    #[inline(always)]
    unsafe fn copy_from_slice(self, dst_offset: usize, src: &[u64]) {
        std::ptr::copy_nonoverlapping(src.as_ptr(), self.0.add(dst_offset), src.len());
    }
}
256
/// Copies consecutive rows of `src` into the rows of `dst` named by `meas_indices`.
///
/// `src` holds `meas_indices.len()` rows of `src_s_words` words each. Row `i` is
/// written to the start of row `meas_indices[i]` in `dst`, whose rows are
/// `dst_s_words` words long. Only the first `src_s_words` words of each destination
/// row are overwritten.
#[inline(always)]
fn scatter_meas_major_rows(
    dst: &mut [u64],
    dst_s_words: usize,
    src: &[u64],
    src_s_words: usize,
    meas_indices: &[usize],
) {
    debug_assert_eq!(src.len(), meas_indices.len() * src_s_words);
    let mut src_start = 0;
    for &global_m in meas_indices {
        let src_end = src_start + src_s_words;
        let row = &src[src_start..src_end];
        let dst_start = global_m * dst_s_words;
        dst[dst_start..dst_start + src_s_words].copy_from_slice(row);
        src_start = src_end;
    }
}
272
/// Re-indexes a bit-packed row from local to global measurement positions.
///
/// For every set bit `l` of `local_row`, bit `mapping[l]` is set in the returned
/// row, which is `m_words` words long.
#[inline(always)]
fn project_local_flip_row(local_row: &[u64], mapping: &[usize], m_words: usize) -> Vec<u64> {
    let mut projected = vec![0u64; m_words];
    for (word_idx, &word) in local_row.iter().enumerate() {
        // Walk the set bits from lowest to highest by repeatedly clearing the lowest one.
        let remaining = std::iter::successors((word != 0).then_some(word), |&rest| {
            let next = rest & (rest - 1);
            (next != 0).then_some(next)
        });
        for rest in remaining {
            let local_m = word_idx * 64 + rest.trailing_zeros() as usize;
            debug_assert!(local_m < mapping.len());
            let target = mapping[local_m];
            projected[target / 64] |= 1u64 << (target % 64);
        }
    }
    projected
}
289
290fn build_sparse_from_filtered_blocks(
291    block_samplers: &[CompiledSampler],
292    meas_map: &[(usize, usize)],
293    rank_offsets: &[usize],
294    num_global_measurements: usize,
295) -> SparseParity {
296    let mut row_offsets = Vec::with_capacity(num_global_measurements + 1);
297    let mut col_indices = Vec::new();
298
299    for &(block_idx, local_meas) in meas_map {
300        row_offsets.push(col_indices.len() as u32);
301        let sparse = block_samplers[block_idx]
302            .sparse
303            .as_ref()
304            .expect("filtered block samplers must retain sparse parity");
305        let rank_offset = rank_offsets[block_idx] as u32;
306        for &col in sparse.row_cols(local_meas) {
307            col_indices.push(rank_offset + col);
308        }
309    }
310    row_offsets.push(col_indices.len() as u32);
311
312    let non_det_rows: Vec<u32> = (0..num_global_measurements as u32)
313        .filter(|&m| row_offsets[m as usize + 1] != row_offsets[m as usize])
314        .collect();
315
316    SparseParity {
317        col_indices,
318        row_offsets,
319        num_rows: num_global_measurements,
320        non_det_rows,
321    }
322}
323
324fn build_filtered_parity_blocks(
325    block_samplers: &[CompiledSampler],
326    local_to_global: &[Vec<usize>],
327) -> Option<ParityBlocks> {
328    let mut blocks = Vec::new();
329
330    for (block_idx, sampler) in block_samplers.iter().enumerate() {
331        let mapping = &local_to_global[block_idx];
332        if mapping.is_empty() {
333            continue;
334        }
335
336        if let Some(parity_blocks) = &sampler.parity_blocks {
337            for block in &parity_blocks.blocks {
338                let meas_indices = block
339                    .meas_indices
340                    .iter()
341                    .map(|&local_m| mapping[local_m])
342                    .collect();
343                blocks.push(ParityBlock {
344                    meas_indices,
345                    sparse: block.sparse.clone(),
346                    block_rank: block.block_rank,
347                    ref_bits_packed: block.ref_bits_packed.clone(),
348                });
349            }
350            continue;
351        }
352
353        let sparse = sampler
354            .sparse
355            .as_ref()
356            .expect("filtered block samplers must retain sparse parity")
357            .clone();
358        blocks.push(ParityBlock {
359            meas_indices: mapping.clone(),
360            sparse,
361            block_rank: sampler.rank,
362            ref_bits_packed: sampler.ref_bits_packed.clone(),
363        });
364    }
365
366    ParityBlocks::from_blocks_if_useful(blocks)
367}
368
369impl CompiledSampler {
    /// Returns the sampler's rank: the number of flip rows (random bits) drawn per shot.
    pub fn rank(&self) -> usize {
        self.rank
    }
373
    /// Opt in to GPU-accelerated BTS sampling. The sampler routes
    /// `sample_bulk_packed` through the GPU path when the circuit compiled
    /// to a flat sparse parity (no parity blocks) and the shot count
    /// crosses the GPU BTS threshold.
    ///
    /// Any previously built GPU BTS cache is dropped, so it is rebuilt against the
    /// new context on next use.
    #[cfg(feature = "gpu")]
    pub fn with_gpu(mut self, context: std::sync::Arc<crate::gpu::GpuContext>) -> Self {
        self.gpu_context = Some(context);
        self.gpu_bts_cache = None;
        self
    }
384
385    #[cfg(feature = "gpu")]
386    fn can_use_gpu_bts(&self) -> bool {
387        self.gpu_context.is_some()
388            && self.sparse.is_some()
389            && self.rank > 0
390            && self.parity_blocks.is_none()
391            && self.xor_dag.is_none()
392    }
393
394    #[cfg(feature = "gpu")]
395    pub(crate) fn should_use_gpu_bts(&self, num_shots: usize) -> bool {
396        let Some(sparse) = &self.sparse else {
397            return false;
398        };
399        if !self.can_use_gpu_bts()
400            || num_shots < crate::gpu::bts_min_shots()
401            || self.rank < crate::gpu::bts_min_rank()
402        {
403            return false;
404        }
405
406        let stats = sparse.stats();
407        let min_total_weight = sparse
408            .num_rows
409            .saturating_mul(crate::gpu::bts_min_weight_factor());
410        stats.total_weight >= min_total_weight
411    }
412
    /// Returns `true` when a GPU context has been installed on this sampler.
    #[cfg(feature = "gpu")]
    pub(crate) fn has_gpu_context(&self) -> bool {
        self.gpu_context.is_some()
    }
417
    /// Returns a clone of the installed GPU context handle (an `Arc`), if any.
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_context(&self) -> Option<std::sync::Arc<crate::gpu::GpuContext>> {
        self.gpu_context.clone()
    }
422
423    #[cfg(feature = "gpu")]
424    fn should_try_gpu_counts(&self, total_shots: usize) -> bool {
425        if !self.should_use_gpu_bts(total_shots) || self.rank <= MAX_RANK_FOR_RANK_SPACE {
426            return false;
427        }
428
429        let m_words = self.num_measurements.div_ceil(64);
430        if m_words == 0 || m_words > crate::gpu::kernels::bts::GPU_COUNTS_MAX_WORDS {
431            return false;
432        }
433
434        let entropy_guard = total_shots.ilog2() as usize + 8;
435        self.rank <= entropy_guard
436    }
437
438    #[cfg(feature = "gpu")]
439    fn ensure_gpu_bts_cache(&mut self) -> Result<&mut GpuBtsCache> {
440        if self.gpu_bts_cache.is_none() {
441            let ctx = self
442                .gpu_context
443                .as_ref()
444                .expect("gpu BTS cache requested without gpu_context")
445                .clone();
446            let sparse = self
447                .sparse
448                .as_ref()
449                .expect("gpu BTS cache requested without sparse parity");
450            self.gpu_bts_cache = Some(GpuBtsCache::new(&ctx, sparse, &self.ref_bits_packed)?);
451        }
452        Ok(self
453            .gpu_bts_cache
454            .as_mut()
455            .expect("gpu BTS cache initialized above"))
456    }
457
    /// Returns the number of measurement bits in one shot.
    pub fn num_measurements(&self) -> usize {
        self.num_measurements
    }
461
    /// Returns the sparse parity representation, if the sampler has one.
    pub fn sparse(&self) -> Option<&SparseParity> {
        self.sparse.as_ref()
    }
465
    /// Returns the statistics of the sparse parity, or `None` when there is no sparse parity.
    pub fn parity_stats(&self) -> Option<ParityStats> {
        self.sparse.as_ref().map(|s| s.stats())
    }
469
470    pub fn sample(&mut self) -> Vec<bool> {
471        let num_meas_words = self.num_measurements.div_ceil(64);
472        let mut accum = vec![0u64; num_meas_words];
473        self.sample_into(&mut accum);
474        self.unpack_result(&accum)
475    }
476
477    pub(crate) fn sample_bulk_words_shot_major(&mut self, num_shots: usize) -> (Vec<u64>, usize) {
478        let m_words = self.num_measurements.div_ceil(64);
479        let mut accum = vec![0u64; num_shots * m_words];
480        let mut rand_buf = Vec::new();
481        self.sample_bulk_words_shot_major_reuse(&mut accum, &mut rand_buf, num_shots);
482        (accum, m_words)
483    }
484
    /// Samples `num_shots` shots of packed flip words into caller-provided buffers.
    ///
    /// `accum` is resized to `num_shots * m_words` words and zeroed, then shot `s`
    /// receives the XOR of its randomly selected flip rows at
    /// `accum[s * m_words..(s + 1) * m_words]`. `rand_buf` is scratch space for the
    /// random bytes of the lookup-table path. The reference bits are NOT applied;
    /// callers XOR them in afterwards.
    ///
    /// Returns `m_words`, the number of words per shot.
    pub(crate) fn sample_bulk_words_shot_major_reuse(
        &mut self,
        accum: &mut Vec<u64>,
        rand_buf: &mut Vec<u8>,
        num_shots: usize,
    ) -> usize {
        let m_words = self.num_measurements.div_ceil(64);
        let needed = num_shots * m_words;
        accum.resize(needed, 0);
        accum[..needed].fill(0);
        // Nothing random to add: leave the zeroed buffer as the result.
        if num_shots == 0 || self.num_measurements == 0 || self.rank == 0 {
            return m_words;
        }

        if let Some(lut) = &self.lut {
            // One random byte per LUT group per shot, laid out shot-major.
            let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
            let bytes_per_shot = total_groups;
            let total_bytes = num_shots * bytes_per_shot;
            rand_buf.resize(total_bytes, 0);
            {
                // Fill eight bytes at a time from one `next_u64` (little-endian),
                // then a partial tail from one more draw.
                let full_chunks = total_bytes / 8;
                let tail = full_chunks * 8;
                for i in 0..full_chunks {
                    let r = self.rng.next_u64();
                    rand_buf[i * 8..(i + 1) * 8].copy_from_slice(&r.to_le_bytes());
                }
                if tail < total_bytes {
                    let r = self.rng.next_u64();
                    let bytes = r.to_le_bytes();
                    rand_buf[tail..total_bytes].copy_from_slice(&bytes[..total_bytes - tail]);
                }
                // The partial group only has `remainder_size` rows, so its index
                // byte is masked down to that many bits.
                if lut.remainder_size > 0 {
                    let remainder_mask = (1u8 << lut.remainder_size) - 1;
                    let last_group = lut.num_full_groups;
                    for s in 0..num_shots {
                        rand_buf[s * bytes_per_shot + last_group] &= remainder_mask;
                    }
                }
            }

            // Tile size in shots: a tile's accumulator words span at most 256 KiB,
            // with a floor of 64 shots (presumably a cache-residency heuristic).
            let max_batch = if m_words > 0 {
                (256 * 1024 / (m_words * 8)).max(64)
            } else {
                num_shots
            };

            #[cfg(feature = "parallel")]
            const PAR_SHOT_THRESHOLD: usize = 256;

            #[cfg(feature = "parallel")]
            if num_shots >= PAR_SHOT_THRESHOLD {
                use rayon::prelude::*;
                // Each parallel chunk owns a contiguous range of whole shots.
                let shots_per_chunk =
                    (num_shots.div_ceil(rayon::current_num_threads())).max(max_batch);
                let chunk_m = shots_per_chunk * m_words;
                accum
                    .par_chunks_mut(chunk_m)
                    .enumerate()
                    .for_each(|(ci, chunk)| {
                        let chunk_shots = chunk.len() / m_words;
                        let chunk_start = ci * shots_per_chunk;
                        for tile_start in (0..chunk_shots).step_by(max_batch) {
                            let tile_end = (tile_start + max_batch).min(chunk_shots);
                            // Group-major inside a tile: one LUT group is applied
                            // to every shot of the tile before moving on.
                            for g in 0..total_groups {
                                for s in tile_start..tile_end {
                                    // `gs` is the global shot index used for the random bytes.
                                    let gs = chunk_start + s;
                                    let byte = rand_buf[gs * bytes_per_shot + g] as usize;
                                    let entry = lut.lookup(g, byte);
                                    let base = s * m_words;
                                    xor_words(&mut chunk[base..base + m_words], entry);
                                }
                            }
                        }
                    });
            } else {
                // Below the parallel threshold: same tiled loop on the calling thread.
                for tile_start in (0..num_shots).step_by(max_batch) {
                    let tile_end = (tile_start + max_batch).min(num_shots);
                    for g in 0..total_groups {
                        for s in tile_start..tile_end {
                            let byte = rand_buf[s * bytes_per_shot + g] as usize;
                            let entry = lut.lookup(g, byte);
                            let shot_base = s * m_words;
                            xor_words(&mut accum[shot_base..shot_base + m_words], entry);
                        }
                    }
                }
            }

            // Build without the `parallel` feature: always the serial tiled loop.
            #[cfg(not(feature = "parallel"))]
            {
                for tile_start in (0..num_shots).step_by(max_batch) {
                    let tile_end = (tile_start + max_batch).min(num_shots);
                    for g in 0..total_groups {
                        for s in tile_start..tile_end {
                            let byte = rand_buf[s * bytes_per_shot + g] as usize;
                            let entry = lut.lookup(g, byte);
                            let shot_base = s * m_words;
                            xor_words(&mut accum[shot_base..shot_base + m_words], entry);
                        }
                    }
                }
            }

            m_words
        } else {
            // No LUT: draw one `next_u32` per flip row per shot and use its low bit
            // to decide whether that row is XORed into the shot.
            for s in 0..num_shots {
                let shot_base = s * m_words;
                let shot_accum = &mut accum[shot_base..shot_base + m_words];
                for j in 0..self.rank {
                    let bit = self.rng.next_u32() & 1;
                    if bit != 0 {
                        let row = &self.flip_rows[j];
                        xor_words(shot_accum, row);
                    }
                }
            }
            m_words
        }
    }
604
605    fn should_use_bts(&self, num_shots: usize) -> bool {
606        if let Some(sparse) = &self.sparse {
607            if self.rank == 0 {
608                return false;
609            }
610            let m_words = self.num_measurements.div_ceil(64) as u64;
611            let lut_groups = (self.rank.div_ceil(LUT_GROUP_SIZE)) as u64;
612
613            let lut_alloc_bytes = num_shots as u64 * (lut_groups + m_words * 8);
614            if lut_alloc_bytes > MAX_LUT_ALLOC_BYTES {
615                return true;
616            }
617
618            let s_words = num_shots.div_ceil(64);
619            let stats = sparse.stats();
620            let bts_work = stats.total_weight as u64 * s_words as u64;
621            let lut_work = num_shots as u64 * lut_groups * m_words;
622            bts_work < lut_work
623        } else {
624            false
625        }
626    }
627
    /// Returns the bit-packed reference record that sampled flips are XORed with.
    pub(crate) fn ref_bits_packed(&self) -> &[u64] {
        &self.ref_bits_packed
    }
631
    /// Samples `num_shots` shots and expands the packed result into one `Vec<bool>` per shot.
    ///
    /// Thin wrapper over `sample_bulk_packed` followed by `to_shots`.
    pub fn sample_bulk(&mut self, num_shots: usize) -> Vec<Vec<bool>> {
        self.sample_bulk_packed(num_shots).to_shots()
    }
635
    /// Materialise packed shots directly on the GPU.
    ///
    /// This is available only when the sampler has a GPU context and the
    /// compiled circuit routes through the flat sparse BTS path. Use
    /// [`DevicePackedShots::to_host`] to copy the full packed payload back, or
    /// [`DevicePackedShots::marginals`] / [`DevicePackedShots::counts`] to
    /// reduce on device first.
    ///
    /// # Errors
    ///
    /// Returns `PrismError::BackendUnsupported` when `can_use_gpu_bts` rejects the
    /// sampler, and forwards any error from building the GPU cache or launching
    /// the kernel.
    #[cfg(feature = "gpu")]
    pub fn sample_bulk_packed_device(&mut self, num_shots: usize) -> Result<DevicePackedShots> {
        // Only the structural check applies here; the shot/rank thresholds of
        // `should_use_gpu_bts` are not consulted.
        if !self.can_use_gpu_bts() {
            return Err(PrismError::BackendUnsupported {
                backend: "CompiledSampler".to_string(),
                operation: "sample_bulk_packed_device requires with_gpu() and flat sparse BTS"
                    .to_string(),
            });
        }

        let ctx = self
            .gpu_context
            .as_ref()
            .expect("gpu BTS selected without gpu_context")
            .clone();
        let rank = self.rank;
        // Derive the fast RNG before borrowing the cache: the cache reference keeps
        // `self` mutably borrowed, so `self.rng` cannot be touched while it is live.
        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
        let cache = self.ensure_gpu_bts_cache()?;
        let data = crate::gpu::kernels::bts::launch_bts_sample_device(
            &ctx,
            &mut fast_rng,
            rank,
            num_shots,
            cache,
        )?;

        // The device buffer is described as measurement-major.
        Ok(DevicePackedShots {
            context: ctx,
            data,
            num_shots,
            num_measurements: self.num_measurements,
            m_words: self.num_measurements.div_ceil(64),
            s_words: num_shots.div_ceil(64),
            layout: ShotLayout::MeasMajor,
            rank,
        })
    }
680
    /// GPU BTS sampling with on-device meas-major to shot-major bit-transpose.
    ///
    /// Returns `Some(Ok(data))` when the compiled circuit matches the GPU BTS
    /// path (flat sparse, no xor_dag, shot threshold crossed). `data` is in
    /// shot-major layout (`num_shots * m_words` u64s) so callers can skip the
    /// host `into_shot_major_data()` transpose for downstream
    /// shot-major consumers (noise apply, etc.). `None` signals "use the CPU
    /// path for this sampler/shot-count combination".
    ///
    /// `Some(Err(_))` reports a failure while building the GPU cache or running
    /// the kernel.
    #[cfg(feature = "gpu")]
    pub(crate) fn try_sample_bulk_shot_major_gpu(
        &mut self,
        num_shots: usize,
    ) -> Option<Result<Vec<u64>>> {
        if !self.should_use_gpu_bts(num_shots) {
            return None;
        }
        let ctx = self
            .gpu_context
            .as_ref()
            .expect("gpu BTS selected without gpu_context")
            .clone();
        let rank = self.rank;
        // Derive the fast RNG before borrowing the cache: the cache reference keeps
        // `self` mutably borrowed, so `self.rng` cannot be touched while it is live.
        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
        let cache = match self.ensure_gpu_bts_cache() {
            Ok(c) => c,
            Err(e) => return Some(Err(e)),
        };
        Some(crate::gpu::kernels::bts::launch_bts_sample_shot_major_host(
            &ctx,
            &mut fast_rng,
            rank,
            num_shots,
            cache,
        ))
    }
716
717    #[inline(always)]
718    fn sample_into(&mut self, accum: &mut [u64]) {
719        if self.rank == 0 {
720            return;
721        }
722
723        if let Some(lut) = &self.lut {
724            let mut rand_buf = 0u64;
725            let mut rand_pos = 8usize;
726
727            for g in 0..lut.num_full_groups {
728                if rand_pos >= 8 {
729                    rand_buf = self.rng.next_u64();
730                    rand_pos = 0;
731                }
732                let byte = ((rand_buf >> (rand_pos * 8)) & 0xFF) as usize;
733                rand_pos += 1;
734                let entry = lut.lookup(g, byte);
735                xor_words(accum, entry);
736            }
737            if lut.remainder_size > 0 {
738                if rand_pos >= 8 {
739                    rand_buf = self.rng.next_u64();
740                }
741                let mask = (1u64 << lut.remainder_size) - 1;
742                let byte = (rand_buf & mask) as usize;
743                let entry = lut.lookup(lut.num_full_groups, byte);
744                xor_words(accum, entry);
745            }
746        } else {
747            for j in 0..self.rank {
748                let bit = self.rng.next_u32() & 1;
749                if bit != 0 {
750                    let row = &self.flip_rows[j];
751                    xor_words(accum, row);
752                }
753            }
754        }
755    }
756
757    #[inline(always)]
758    fn unpack_result(&self, accum: &[u64]) -> Vec<bool> {
759        let mut result = Vec::with_capacity(self.num_measurements);
760        for m in 0..self.num_measurements {
761            let w = m / 64;
762            let ref_word = if w < self.ref_bits_packed.len() {
763                self.ref_bits_packed[w]
764            } else {
765                0
766            };
767            let bit = ((accum[w] ^ ref_word) >> (m % 64)) & 1 != 0;
768            result.push(bit);
769        }
770        result
771    }
772
773    pub fn sample_bulk_packed(&mut self, num_shots: usize) -> PackedShots {
774        match self.sample_bulk_packed_inner(num_shots, false) {
775            Ok(packed) => packed,
776            Err(_) => unreachable!("compiled sampler CPU fallback should not fail"),
777        }
778    }
779
    /// Fallible variant of `sample_bulk_packed`.
    ///
    /// Runs the inner routine with `propagate_gpu_errors` set to `true` and returns
    /// its `Result` unchanged instead of panicking on error.
    pub fn try_sample_bulk_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
        self.sample_bulk_packed_inner(num_shots, true)
    }
783
784    fn sample_bulk_packed_inner(
785        &mut self,
786        num_shots: usize,
787        propagate_gpu_errors: bool,
788    ) -> Result<PackedShots> {
789        let m_words = self.num_measurements.div_ceil(64);
790        let s_words = num_shots.div_ceil(64);
791        if num_shots == 0 || self.num_measurements == 0 {
792            return Ok(PackedShots {
793                data: Vec::new(),
794                num_shots,
795                num_measurements: self.num_measurements,
796                m_words,
797                s_words,
798                layout: ShotLayout::ShotMajor,
799            });
800        }
801        if self.rank == 0 {
802            let mut data = vec![0u64; num_shots * m_words];
803            for s in 0..num_shots {
804                let base = s * m_words;
805                data[base..base + m_words].copy_from_slice(&self.ref_bits_packed);
806            }
807            return Ok(PackedShots {
808                data,
809                num_shots,
810                num_measurements: self.num_measurements,
811                m_words,
812                s_words,
813                layout: ShotLayout::ShotMajor,
814            });
815        }
816
817        if self.should_use_bts(num_shots) {
818            return self.sample_bulk_packed_bts(num_shots, m_words, s_words, propagate_gpu_errors);
819        }
820
821        let (mut data, _) = self.sample_bulk_words_shot_major(num_shots);
822        for s in 0..num_shots {
823            let base = s * m_words;
824            xor_words(&mut data[base..base + m_words], &self.ref_bits_packed);
825        }
826        Ok(PackedShots {
827            data,
828            num_shots,
829            num_measurements: self.num_measurements,
830            m_words,
831            s_words,
832            layout: ShotLayout::ShotMajor,
833        })
834    }
835
    /// Samples `num_shots` shots and returns them as a meas-major [`PackedShots`].
    ///
    /// Three strategies are tried in order:
    /// 1. If `parity_blocks` is present, every block is sampled independently with its own
    ///    seed (drawn from `self.rng`) and its rows are scattered into the global output.
    /// 2. With the `gpu` feature, if `should_use_gpu_bts` agrees, the GPU kernel is used.
    ///    A GPU error is returned only when `propagate_gpu_errors` is set; otherwise it is
    ///    discarded and the CPU path below runs instead.
    /// 3. Otherwise the CPU path runs `bts_single_pass` for up to `BTS_BATCH_SHOTS` shots
    ///    and `bts_batched` beyond that.
    ///
    /// `m_words` and `s_words` are stored verbatim in the result; `s_words` is also used as
    /// the row stride of the meas-major buffer (presumably `num_shots.div_ceil(64)` —
    /// confirm against callers).
    ///
    /// # Panics
    /// Panics if the non-block CPU path is reached while `self.sparse` is `None`.
    fn sample_bulk_packed_bts(
        &mut self,
        num_shots: usize,
        m_words: usize,
        s_words: usize,
        propagate_gpu_errors: bool,
    ) -> Result<PackedShots> {
        // Without the GPU feature the flag is unused; silence the warning.
        #[cfg(not(feature = "gpu"))]
        let _ = propagate_gpu_errors;

        let num_meas = self.num_measurements;

        if let Some(pb) = &self.parity_blocks {
            // One seed per block, drawn sequentially from the master RNG before any block
            // is sampled, so the seed stream does not depend on parallel scheduling.
            let block_seeds: Vec<u64> = (0..pb.blocks.len()).map(|_| self.rng.next_u64()).collect();
            let meas_major = if pb.direct_scatter {
                let total_len = num_meas * s_words;
                #[allow(clippy::uninit_vec)]
                let mut meas_major = {
                    let mut v = Vec::with_capacity(total_len);
                    // SAFETY: Every measurement row is written exactly once by one parity
                    // block before the buffer is read. The filtered parity blocks form a
                    // partition of the global measurement rows.
                    unsafe { v.set_len(total_len) };
                    v
                };

                #[cfg(feature = "parallel")]
                {
                    use rayon::prelude::*;

                    // Raw pointer wrapper so worker threads can write their own rows.
                    let ptr = SendPtrU64(meas_major.as_mut_ptr());
                    pb.blocks
                        .par_iter()
                        .zip(block_seeds.par_iter())
                        .for_each(|(block, &seed)| {
                            let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
                            let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
                            let block_data = sample_bts_meas_major(
                                &block.sparse,
                                num_shots,
                                &block.ref_bits_packed,
                                &mut block_rng,
                                block.block_rank,
                            );

                            // SAFETY: Each parity block owns a disjoint set of global
                            // measurement rows, so these row copies target non-overlapping
                            // regions of the final meas-major output.
                            unsafe {
                                for (local_m, &global_m) in block.meas_indices.iter().enumerate() {
                                    let row =
                                        &block_data[local_m * s_words..(local_m + 1) * s_words];
                                    ptr.copy_from_slice(global_m * s_words, row);
                                }
                            }
                        });
                }

                #[cfg(not(feature = "parallel"))]
                {
                    // Serial variant: sample each block and scatter its rows immediately.
                    for (block, &seed) in pb.blocks.iter().zip(block_seeds.iter()) {
                        let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
                        let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
                        let block_data = sample_bts_meas_major(
                            &block.sparse,
                            num_shots,
                            &block.ref_bits_packed,
                            &mut block_rng,
                            block.block_rank,
                        );
                        scatter_meas_major_rows(
                            &mut meas_major,
                            s_words,
                            &block_data,
                            s_words,
                            &block.meas_indices,
                        );
                    }
                }

                meas_major
            } else {
                // Without direct scatter: collect every block's rows first, then scatter
                // them into a zero-initialised output buffer.
                #[cfg(feature = "parallel")]
                let block_results: Vec<(Vec<u64>, &[usize])> = {
                    use rayon::prelude::*;
                    pb.blocks
                        .par_iter()
                        .zip(block_seeds.par_iter())
                        .map(|(block, &seed)| {
                            let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
                            let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
                            let data = sample_bts_meas_major(
                                &block.sparse,
                                num_shots,
                                &block.ref_bits_packed,
                                &mut block_rng,
                                block.block_rank,
                            );
                            (data, block.meas_indices.as_slice())
                        })
                        .collect()
                };

                #[cfg(not(feature = "parallel"))]
                let block_results: Vec<(Vec<u64>, &[usize])> = pb
                    .blocks
                    .iter()
                    .zip(block_seeds.iter())
                    .map(|(block, &seed)| {
                        let mut block_chacha = ChaCha8Rng::seed_from_u64(seed);
                        let mut block_rng = Xoshiro256PlusPlus::from_chacha(&mut block_chacha);
                        let data = sample_bts_meas_major(
                            &block.sparse,
                            num_shots,
                            &block.ref_bits_packed,
                            &mut block_rng,
                            block.block_rank,
                        );
                        (data, block.meas_indices.as_slice())
                    })
                    .collect();

                let mut meas_major = vec![0u64; num_meas * s_words];
                for (block_data, meas_indices) in &block_results {
                    scatter_meas_major_rows(
                        &mut meas_major,
                        s_words,
                        block_data,
                        s_words,
                        meas_indices,
                    );
                }
                meas_major
            };

            return Ok(PackedShots {
                data: meas_major,
                num_shots,
                num_measurements: num_meas,
                m_words,
                s_words,
                layout: ShotLayout::MeasMajor,
            });
        }

        #[cfg(feature = "gpu")]
        if self.should_use_gpu_bts(num_shots) {
            match self.sample_bulk_packed_bts_gpu(num_shots, m_words, s_words) {
                Ok(packed) => return Ok(packed),
                Err(e) if propagate_gpu_errors => return Err(e),
                // Best effort: drop the GPU error and fall through to the CPU path.
                Err(_) => {}
            }
        }

        let sparse = self
            .sparse
            .as_ref()
            .expect("sparse parity required for BTS (should_use_bts guards this)");

        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);

        // Small shot counts are handled in one pass; larger ones go through the batched
        // sampler, which additionally needs the row stride.
        if num_shots <= BTS_BATCH_SHOTS {
            let data = bts_single_pass(
                sparse,
                self.xor_dag.as_ref(),
                num_shots,
                &self.ref_bits_packed,
                &mut fast_rng,
                self.rank,
            );
            return Ok(PackedShots {
                data,
                num_shots,
                num_measurements: num_meas,
                m_words,
                s_words,
                layout: ShotLayout::MeasMajor,
            });
        }

        let data = bts_batched(
            sparse,
            self.xor_dag.as_ref(),
            num_shots,
            s_words,
            &self.ref_bits_packed,
            &mut fast_rng,
            self.rank,
        );
        Ok(PackedShots {
            data,
            num_shots,
            num_measurements: num_meas,
            m_words,
            s_words,
            layout: ShotLayout::MeasMajor,
        })
    }
1034
1035    #[cfg(feature = "gpu")]
1036    fn sample_bulk_packed_bts_gpu(
1037        &mut self,
1038        num_shots: usize,
1039        m_words: usize,
1040        s_words: usize,
1041    ) -> Result<PackedShots> {
1042        let ctx = self
1043            .gpu_context
1044            .as_ref()
1045            .expect("gpu BTS selected without gpu_context")
1046            .clone();
1047        let rank = self.rank;
1048        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
1049        let cache = self.ensure_gpu_bts_cache()?;
1050        let data = crate::gpu::kernels::bts::launch_bts_sample(
1051            &ctx,
1052            &mut fast_rng,
1053            rank,
1054            num_shots,
1055            cache,
1056        )?;
1057        Ok(PackedShots {
1058            data,
1059            num_shots,
1060            num_measurements: self.num_measurements,
1061            m_words,
1062            s_words,
1063            layout: ShotLayout::MeasMajor,
1064        })
1065    }
1066
1067    pub fn sample_chunked<A: ShotAccumulator>(&mut self, total_shots: usize, acc: &mut A) {
1068        let chunk_size = default_chunk_size(self.num_measurements);
1069        self.sample_chunked_with_size(total_shots, chunk_size, acc);
1070    }
1071
1072    pub fn sample_chunked_with_size<A: ShotAccumulator>(
1073        &mut self,
1074        total_shots: usize,
1075        chunk_size: usize,
1076        acc: &mut A,
1077    ) {
1078        let mut remaining = total_shots;
1079        while remaining > 0 {
1080            let this_batch = remaining.min(chunk_size);
1081            let packed = self.sample_bulk_packed(this_batch);
1082            acc.accumulate(&packed);
1083            remaining -= this_batch;
1084        }
1085    }
1086
1087    pub fn sample_counts(
1088        &mut self,
1089        total_shots: usize,
1090    ) -> std::collections::HashMap<Vec<u64>, u64> {
1091        match self.try_sample_counts(total_shots) {
1092            Ok(counts) => counts,
1093            Err(_) => self.sample_counts_cpu(total_shots),
1094        }
1095    }
1096
1097    pub fn try_sample_counts(
1098        &mut self,
1099        total_shots: usize,
1100    ) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
1101        #[cfg(feature = "gpu")]
1102        if self.should_try_gpu_counts(total_shots) {
1103            return self.sample_bulk_packed_device(total_shots)?.counts();
1104        }
1105
1106        Ok(self.sample_counts_cpu(total_shots))
1107    }
1108
1109    fn sample_counts_cpu(
1110        &mut self,
1111        total_shots: usize,
1112    ) -> std::collections::HashMap<Vec<u64>, u64> {
1113        if self.rank > 0 && self.parity_blocks.is_none() {
1114            let num_outcomes = 1usize << self.rank;
1115
1116            if self.rank <= MAX_RANK_FOR_MULTINOMIAL
1117                && total_shots >= num_outcomes * MIN_SHOTS_PER_OUTCOME_MULTINOMIAL
1118            {
1119                return self.sample_counts_multinomial(total_shots);
1120            }
1121
1122            if self.rank <= MAX_RANK_FOR_RANK_SPACE
1123                && total_shots >= num_outcomes * MIN_SHOTS_PER_OUTCOME
1124            {
1125                return self.sample_counts_rank_space(total_shots);
1126            }
1127        }
1128        let mut acc = HistogramAccumulator::new();
1129        self.sample_chunked(total_shots, &mut acc);
1130        acc.into_counts()
1131    }
1132
1133    fn sample_counts_multinomial(
1134        &mut self,
1135        total_shots: usize,
1136    ) -> std::collections::HashMap<Vec<u64>, u64> {
1137        use std::collections::HashMap;
1138
1139        let m_words = self.num_measurements.div_ceil(64);
1140
1141        if total_shots == 0 || self.num_measurements == 0 {
1142            return HashMap::new();
1143        }
1144        if self.rank == 0 {
1145            let mut counts = HashMap::new();
1146            counts.insert(self.ref_bits_packed[..m_words].to_vec(), total_shots as u64);
1147            return counts;
1148        }
1149
1150        let rank = self.rank;
1151        let num_outcomes = 1usize << rank;
1152        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
1153        let mut counts = HashMap::new();
1154        let mut remaining = total_shots;
1155
1156        for key in 0..num_outcomes {
1157            if remaining == 0 {
1158                break;
1159            }
1160            let outcomes_left = num_outcomes - key;
1161            let count = if outcomes_left == 1 {
1162                remaining
1163            } else {
1164                binomial_sample(&mut fast_rng, remaining, 1.0 / outcomes_left as f64)
1165            };
1166
1167            if count > 0 {
1168                let mut outcome = self.ref_bits_packed[..m_words].to_vec();
1169                if let Some(lut) = &self.lut {
1170                    let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
1171                    for g in 0..total_groups {
1172                        let byte = (key >> (g * 8)) & 0xFF;
1173                        let entry = lut.lookup(g, byte);
1174                        xor_words(&mut outcome, entry);
1175                    }
1176                } else {
1177                    for j in 0..rank {
1178                        if (key >> j) & 1 != 0 {
1179                            xor_words(&mut outcome, &self.flip_rows[j]);
1180                        }
1181                    }
1182                }
1183                counts.insert(outcome, count as u64);
1184            }
1185            remaining -= count;
1186        }
1187
1188        counts
1189    }
1190
    /// Builds the outcome histogram by sampling uniformly random keys in `0..2^rank`,
    /// counting how often each key occurs, and only then expanding the keys that occurred
    /// into packed outcome words.
    ///
    /// With a LUT present, each shot consumes one random byte per LUT group (the byte of
    /// the partial last group is masked to `remainder_size` bits) and the bytes are
    /// assembled little-endian into the key. Without a LUT, each shot consumes one `u64`
    /// whose low `rank` bits form the key.
    ///
    /// Returns an empty map when there are no shots or no measurements, and a single-entry
    /// map holding the reference bits when the rank is zero.
    fn sample_counts_rank_space(
        &mut self,
        total_shots: usize,
    ) -> std::collections::HashMap<Vec<u64>, u64> {
        use std::collections::HashMap;

        let m_words = self.num_measurements.div_ceil(64);

        if total_shots == 0 || self.num_measurements == 0 {
            return HashMap::new();
        }
        if self.rank == 0 {
            let mut counts = HashMap::new();
            counts.insert(self.ref_bits_packed[..m_words].to_vec(), total_shots as u64);
            return counts;
        }

        let rank = self.rank;
        let num_outcomes = 1usize << rank;
        // Dense table of occurrences, indexed by key.
        let mut rank_counts = vec![0u64; num_outcomes];
        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);

        if let Some(lut) = &self.lut {
            let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
            let bytes_per_shot = total_groups;
            let remainder_mask: u8 = if lut.remainder_size > 0 {
                (1u8 << lut.remainder_size) - 1
            } else {
                0xFF
            };

            // Random bytes are generated in chunks of roughly 32 MiB (at least 64 shots).
            let chunk_size = (32 * 1024 * 1024 / bytes_per_shot).max(64);
            let mut rand_buf = vec![0u8; chunk_size * bytes_per_shot];
            let mut remaining = total_shots;

            while remaining > 0 {
                let this_chunk = remaining.min(chunk_size);
                let total_bytes = this_chunk * bytes_per_shot;

                // Fill the buffer eight bytes at a time from successive `u64` draws.
                let full_chunks = total_bytes / 8;
                let tail = full_chunks * 8;
                for i in 0..full_chunks {
                    let r = fast_rng.next_u64();
                    rand_buf[i * 8..(i + 1) * 8].copy_from_slice(&r.to_le_bytes());
                }
                // A trailing partial word uses the low bytes of one more draw.
                if tail < total_bytes {
                    let r = fast_rng.next_u64();
                    let bytes = r.to_le_bytes();
                    rand_buf[tail..total_bytes].copy_from_slice(&bytes[..total_bytes - tail]);
                }

                // Clear the unused high bits of each shot's partial last group.
                if lut.remainder_size > 0 {
                    let last_group = lut.num_full_groups;
                    for s in 0..this_chunk {
                        rand_buf[s * bytes_per_shot + last_group] &= remainder_mask;
                    }
                }

                for s in 0..this_chunk {
                    let base = s * bytes_per_shot;
                    let mut key: usize = 0;
                    for g in 0..bytes_per_shot {
                        key |= (rand_buf[base + g] as usize) << (g * 8);
                    }
                    rank_counts[key] += 1;
                }

                remaining -= this_chunk;
            }
        } else {
            let mut remaining = total_shots;
            while remaining > 0 {
                let this_chunk = remaining.min(4 * 1024 * 1024);
                for _ in 0..this_chunk {
                    let bits = fast_rng.next_u64();
                    // `num_outcomes` is a power of two, so this keeps the low `rank` bits.
                    let key = (bits as usize) & (num_outcomes - 1);
                    rank_counts[key] += 1;
                }
                remaining -= this_chunk;
            }
        }

        // Expand only the keys that actually occurred.
        let mut counts = HashMap::new();
        for (key, &count) in rank_counts.iter().enumerate() {
            if count == 0 {
                continue;
            }
            let mut outcome = self.ref_bits_packed[..m_words].to_vec();
            if let Some(lut) = &self.lut {
                let total_groups = lut.num_full_groups + usize::from(lut.remainder_size > 0);
                for g in 0..total_groups {
                    let byte = (key >> (g * 8)) & 0xFF;
                    let entry = lut.lookup(g, byte);
                    xor_words(&mut outcome, entry);
                }
            } else {
                for j in 0..rank {
                    if (key >> j) & 1 != 0 {
                        xor_words(&mut outcome, &self.flip_rows[j]);
                    }
                }
            }
            counts.insert(outcome, count);
        }

        counts
    }
1298
1299    pub fn sample_marginals(&mut self, total_shots: usize) -> Vec<f64> {
1300        match self.try_sample_marginals(total_shots) {
1301            Ok(marginals) => marginals,
1302            Err(_) => self.sample_marginals_cpu(total_shots),
1303        }
1304    }
1305
1306    pub fn try_sample_marginals(&mut self, total_shots: usize) -> Result<Vec<f64>> {
1307        #[cfg(feature = "gpu")]
1308        if self.should_use_gpu_bts(total_shots) {
1309            return self.sample_bulk_packed_device(total_shots)?.marginals();
1310        }
1311
1312        Ok(self.sample_marginals_cpu(total_shots))
1313    }
1314
1315    fn sample_marginals_cpu(&mut self, total_shots: usize) -> Vec<f64> {
1316        let mut acc = MarginalsAccumulator::new(self.num_measurements);
1317        self.sample_chunked(total_shots, &mut acc);
1318        acc.marginals()
1319    }
1320
1321    pub fn sample_detection_events(
1322        &mut self,
1323        pairs: &[(usize, usize)],
1324        num_shots: usize,
1325    ) -> PackedShots {
1326        let sparse = self.sparse.as_ref().expect("sparse parity required");
1327        let det_sparse = sparse.compile_detection_events(pairs);
1328        let num_events = det_sparse.num_rows;
1329        let m_words = num_events.div_ceil(64);
1330        let s_words = num_shots.div_ceil(64);
1331
1332        if num_events == 0 || num_shots == 0 || self.rank == 0 {
1333            return PackedShots {
1334                data: vec![0u64; num_events * s_words],
1335                num_shots,
1336                num_measurements: num_events,
1337                m_words,
1338                s_words,
1339                layout: ShotLayout::MeasMajor,
1340            };
1341        }
1342
1343        let det_weight = det_sparse.stats().total_weight;
1344        let meas_weight = sparse.stats().total_weight;
1345
1346        if det_weight > meas_weight + num_events {
1347            let meas_packed = self.sample_bulk_packed(num_shots);
1348            let mut data = vec![0u64; num_events * s_words];
1349            for (e, &(m_a, m_b)) in pairs.iter().enumerate() {
1350                let src_a = &meas_packed.data[m_a * s_words..(m_a + 1) * s_words];
1351                let src_b = &meas_packed.data[m_b * s_words..(m_b + 1) * s_words];
1352                let dst = &mut data[e * s_words..(e + 1) * s_words];
1353                for (d, (&a, &b)) in dst.iter_mut().zip(src_a.iter().zip(src_b.iter())) {
1354                    *d = a ^ b;
1355                }
1356            }
1357            return PackedShots {
1358                data,
1359                num_shots,
1360                num_measurements: num_events,
1361                m_words,
1362                s_words,
1363                layout: ShotLayout::MeasMajor,
1364            };
1365        }
1366
1367        let det_ref = vec![0u64; m_words];
1368
1369        let mut fast_rng = Xoshiro256PlusPlus::from_chacha(&mut self.rng);
1370
1371        let data = if num_shots > BTS_BATCH_SHOTS {
1372            bts_batched(
1373                &det_sparse,
1374                None,
1375                num_shots,
1376                s_words,
1377                &det_ref,
1378                &mut fast_rng,
1379                self.rank,
1380            )
1381        } else {
1382            sample_bts_meas_major(&det_sparse, num_shots, &det_ref, &mut fast_rng, self.rank)
1383        };
1384
1385        PackedShots {
1386            data,
1387            num_shots,
1388            num_measurements: num_events,
1389            m_words,
1390            s_words,
1391            layout: ShotLayout::MeasMajor,
1392        }
1393    }
1394
1395    pub fn exact_counts(&self) -> Option<std::collections::HashMap<Vec<u64>, u64>> {
1396        if self.rank > MAX_RANK_FOR_GRAY_CODE {
1397            return None;
1398        }
1399        let sparse = self.sparse.as_ref()?;
1400        Some(gray_code_exact_counts(
1401            sparse,
1402            self.rank,
1403            &self.ref_bits_packed,
1404            self.num_measurements,
1405        ))
1406    }
1407
1408    pub fn marginal_probabilities(&self) -> Vec<f64> {
1409        let mut probs = vec![0.5f64; self.num_measurements];
1410        if let Some(sparse) = &self.sparse {
1411            for (m, p) in probs.iter_mut().enumerate() {
1412                if sparse.row_weight(m) == 0 {
1413                    let ref_bit = (self.ref_bits_packed[m / 64] >> (m % 64)) & 1;
1414                    *p = ref_bit as f64;
1415                }
1416            }
1417        } else {
1418            for (m, p) in probs.iter_mut().enumerate() {
1419                let mut depends_on_random = false;
1420                for row in &self.flip_rows {
1421                    let w = m / 64;
1422                    if w < row.len() && (row[w] >> (m % 64)) & 1 != 0 {
1423                        depends_on_random = true;
1424                        break;
1425                    }
1426                }
1427                if !depends_on_random {
1428                    let ref_bit = (self.ref_bits_packed[m / 64] >> (m % 64)) & 1;
1429                    *p = ref_bit as f64;
1430                }
1431            }
1432        }
1433        probs
1434    }
1435
1436    pub fn parity_report(&self) -> String {
1437        let mut report = format!(
1438            "CompiledSampler: {} measurements, rank {}, {} flip rows\n",
1439            self.num_measurements,
1440            self.rank,
1441            self.flip_rows.len()
1442        );
1443        if let Some(sparse) = &self.sparse {
1444            let stats = sparse.stats();
1445            report.push_str(&format!(
1446                "Parity matrix: {} rows, total weight {}\n\
1447                 Weight range: {} to {}, mean {:.1}\n\
1448                 Deterministic measurements: {}\n",
1449                sparse.num_rows,
1450                stats.total_weight,
1451                stats.min_weight,
1452                stats.max_weight,
1453                stats.mean_weight,
1454                stats.num_deterministic,
1455            ));
1456            let mut histogram = [0usize; 8];
1457            for m in 0..sparse.num_rows {
1458                let w = sparse.row_weight(m);
1459                let bucket = w.min(7);
1460                histogram[bucket] += 1;
1461            }
1462            report.push_str("Weight histogram: ");
1463            for (i, &count) in histogram.iter().enumerate() {
1464                if count > 0 {
1465                    if i < 7 {
1466                        report.push_str(&format!("w{}={} ", i, count));
1467                    } else {
1468                        report.push_str(&format!("w7+={} ", count));
1469                    }
1470                }
1471            }
1472            report.push('\n');
1473        } else {
1474            report.push_str("No sparse parity matrix available\n");
1475        }
1476        report
1477    }
1478
1479    pub fn detection_event_report(&self, pairs: &[(usize, usize)]) -> String {
1480        let sparse = match &self.sparse {
1481            Some(s) => s,
1482            None => return "No sparse parity matrix available\n".to_string(),
1483        };
1484        let det_sparse = sparse.compile_detection_events(pairs);
1485        let meas_stats = sparse.stats();
1486        let det_stats = det_sparse.stats();
1487
1488        let mut report = format!(
1489            "Detection events: {} pairs\n\
1490             Measurement parity: total_weight={}, mean={:.2}\n\
1491             Detection parity:   total_weight={}, mean={:.2}\n",
1492            pairs.len(),
1493            meas_stats.total_weight,
1494            meas_stats.mean_weight,
1495            det_stats.total_weight,
1496            det_stats.mean_weight,
1497        );
1498
1499        if meas_stats.total_weight > 0 {
1500            let reduction = 1.0 - det_stats.total_weight as f64 / meas_stats.total_weight as f64;
1501            report.push_str(&format!(
1502                "Weight reduction: {:.1}% ({:.1}x less work)\n",
1503                reduction * 100.0,
1504                if det_stats.total_weight > 0 {
1505                    meas_stats.total_weight as f64 / det_stats.total_weight as f64
1506                } else {
1507                    f64::INFINITY
1508                },
1509            ));
1510        }
1511
1512        let mut histogram = [0usize; 8];
1513        for m in 0..det_sparse.num_rows {
1514            let w = det_sparse.row_weight(m);
1515            histogram[w.min(7)] += 1;
1516        }
1517        report.push_str("Detection weight histogram: ");
1518        for (i, &count) in histogram.iter().enumerate() {
1519            if count > 0 {
1520                if i < 7 {
1521                    report.push_str(&format!("w{}={} ", i, count));
1522                } else {
1523                    report.push_str(&format!("w7+={} ", count));
1524                }
1525            }
1526        }
1527        report.push('\n');
1528        report
1529    }
1530}
1531
/// Memory layout of the bit matrix held by a [`PackedShots`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotLayout {
    /// One row of `m_words` words per shot; bit `m % 64` of word `m / 64` in a row is
    /// measurement `m` of that shot.
    ShotMajor,
    /// One row of `s_words` words per measurement; bit `s % 64` of word `s / 64` in a row
    /// is shot `s` of that measurement.
    MeasMajor,
}
1537
/// Bit-packed measurement results for a batch of shots, stored in one of two layouts.
#[derive(Debug, Clone)]
pub struct PackedShots {
    // Packed bits; rows are indexed by shot or by measurement depending on `layout`.
    data: Vec<u64>,
    // Number of shots in the batch.
    num_shots: usize,
    // Number of measurement bits per shot.
    num_measurements: usize,
    // Row stride in words for the shot-major layout (the `from_*` constructors use
    // `num_measurements.div_ceil(64)`).
    m_words: usize,
    // Row stride in words for the meas-major layout (the `from_*` constructors use
    // `num_shots.div_ceil(64)`).
    s_words: usize,
    // Which of the two layouts `data` uses.
    layout: ShotLayout,
}
1547
1548impl PackedShots {
    /// Number of shots stored in this batch.
    pub fn num_shots(&self) -> usize {
        self.num_shots
    }
1552
    /// Number of measurement bits recorded per shot.
    pub fn num_measurements(&self) -> usize {
        self.num_measurements
    }
1556
    /// Layout of the underlying word buffer.
    pub fn layout(&self) -> ShotLayout {
        self.layout
    }
1560
1561    #[inline(always)]
1562    pub fn get_bit(&self, shot: usize, measurement: usize) -> bool {
1563        match self.layout {
1564            ShotLayout::ShotMajor => {
1565                let base = shot * self.m_words;
1566                let w = measurement / 64;
1567                (self.data[base + w] >> (measurement % 64)) & 1 != 0
1568            }
1569            ShotLayout::MeasMajor => {
1570                let base = measurement * self.s_words;
1571                let w = shot / 64;
1572                (self.data[base + w] >> (shot % 64)) & 1 != 0
1573            }
1574        }
1575    }
1576
    /// Words per measurement row in the meas-major layout.
    pub fn s_words(&self) -> usize {
        self.s_words
    }
1580
    /// Words per shot row in the shot-major layout.
    pub fn m_words(&self) -> usize {
        self.m_words
    }
1584
1585    pub fn shot_words(&self, shot: usize) -> &[u64] {
1586        assert!(
1587            self.layout == ShotLayout::ShotMajor,
1588            "shot_words requires ShotMajor layout"
1589        );
1590        let base = shot * self.m_words;
1591        &self.data[base..base + self.m_words]
1592    }
1593
1594    pub fn meas_words(&self, m: usize) -> &[u64] {
1595        assert!(
1596            self.layout == ShotLayout::MeasMajor,
1597            "meas_words requires MeasMajor layout"
1598        );
1599        let base = m * self.s_words;
1600        &self.data[base..base + self.s_words]
1601    }
1602
1603    pub fn from_shot_major(data: Vec<u64>, num_shots: usize, num_measurements: usize) -> Self {
1604        let m_words = num_measurements.div_ceil(64);
1605        let s_words = num_shots.div_ceil(64);
1606        Self {
1607            data,
1608            num_shots,
1609            num_measurements,
1610            m_words,
1611            s_words,
1612            layout: ShotLayout::ShotMajor,
1613        }
1614    }
1615
1616    pub fn from_meas_major(data: Vec<u64>, num_shots: usize, num_measurements: usize) -> Self {
1617        let m_words = num_measurements.div_ceil(64);
1618        let s_words = num_shots.div_ceil(64);
1619        Self {
1620            data,
1621            num_shots,
1622            num_measurements,
1623            m_words,
1624            s_words,
1625            layout: ShotLayout::MeasMajor,
1626        }
1627    }
1628
    /// Borrows the underlying word buffer; interpret it according to `layout()`.
    pub fn raw_data(&self) -> &[u64] {
        &self.data
    }
1632
    /// Consumes the batch and returns the underlying word buffer in its current layout.
    pub fn into_data(self) -> Vec<u64> {
        self.data
    }
1636
1637    #[allow(dead_code)]
1638    pub(crate) fn into_shot_major_data(self) -> Vec<u64> {
1639        if self.layout == ShotLayout::ShotMajor {
1640            return self.data;
1641        }
1642
1643        let mut shot_major = vec![0u64; self.num_shots * self.m_words];
1644        for batch_start in (0..self.num_shots).step_by(64) {
1645            let batch_end = (batch_start + 64).min(self.num_shots);
1646            let batch_len = batch_end - batch_start;
1647            let sw_base = batch_start / 64;
1648            let bit_off = batch_start % 64;
1649            let batch_mask = if batch_len == 64 {
1650                u64::MAX
1651            } else {
1652                (1u64 << batch_len) - 1
1653            };
1654
1655            for m in 0..self.num_measurements {
1656                let mword = m / 64;
1657                let mbit = m % 64;
1658                let meas_row = &self.data[m * self.s_words..(m + 1) * self.s_words];
1659                let mut bits = meas_row[sw_base] >> bit_off;
1660                if bit_off != 0 && sw_base + 1 < self.s_words {
1661                    bits |= meas_row[sw_base + 1] << (64 - bit_off);
1662                }
1663                bits &= batch_mask;
1664
1665                while bits != 0 {
1666                    let shot_in_batch = bits.trailing_zeros() as usize;
1667                    let shot_base = (batch_start + shot_in_batch) * self.m_words;
1668                    shot_major[shot_base + mword] |= 1u64 << mbit;
1669                    bits &= bits - 1;
1670                }
1671            }
1672        }
1673
1674        shot_major
1675    }
1676
1677    pub fn to_shots(&self) -> Vec<Vec<bool>> {
1678        let mut shots = Vec::with_capacity(self.num_shots);
1679        for s in 0..self.num_shots {
1680            let mut shot = Vec::with_capacity(self.num_measurements);
1681            for m in 0..self.num_measurements {
1682                shot.push(self.get_bit(s, m));
1683            }
1684            shots.push(shot);
1685        }
1686        shots
1687    }
1688
1689    /// Compute packed parity rows over this packed shot matrix.
1690    ///
1691    /// Each row lists measurement indices to XOR into one output bit.
1692    pub fn parity_rows(&self, rows: &[Vec<usize>]) -> Result<PackedShots> {
1693        validate_parity_rows(rows, self.num_measurements)?;
1694
1695        match self.layout {
1696            ShotLayout::MeasMajor => {
1697                let mut data = vec![0u64; rows.len() * self.s_words];
1698
1699                #[cfg(feature = "parallel")]
1700                if rows.len() * self.s_words >= 16_384 {
1701                    use rayon::prelude::*;
1702                    data.par_chunks_mut(self.s_words)
1703                        .zip(rows.par_iter())
1704                        .for_each(|(dst, row)| {
1705                            for &measurement in row {
1706                                let src_start = measurement * self.s_words;
1707                                xor_words(dst, &self.data[src_start..src_start + self.s_words]);
1708                            }
1709                        });
1710                    return Ok(PackedShots::from_meas_major(
1711                        data,
1712                        self.num_shots,
1713                        rows.len(),
1714                    ));
1715                }
1716
1717                for (out_idx, row) in rows.iter().enumerate() {
1718                    let dst = &mut data[out_idx * self.s_words..(out_idx + 1) * self.s_words];
1719                    for &measurement in row {
1720                        let src_start = measurement * self.s_words;
1721                        xor_words(dst, &self.data[src_start..src_start + self.s_words]);
1722                    }
1723                }
1724                Ok(PackedShots::from_meas_major(
1725                    data,
1726                    self.num_shots,
1727                    rows.len(),
1728                ))
1729            }
1730            ShotLayout::ShotMajor => {
1731                let out_words = rows.len().div_ceil(64);
1732                let mut data = vec![0u64; self.num_shots * out_words];
1733                for shot in 0..self.num_shots {
1734                    let base = shot * out_words;
1735                    for (out_idx, row) in rows.iter().enumerate() {
1736                        let mut bit = false;
1737                        for &measurement in row {
1738                            bit ^= self.get_bit(shot, measurement);
1739                        }
1740                        if bit {
1741                            data[base + out_idx / 64] |= 1u64 << (out_idx % 64);
1742                        }
1743                    }
1744                }
1745                Ok(PackedShots::from_shot_major(
1746                    data,
1747                    self.num_shots,
1748                    rows.len(),
1749                ))
1750            }
1751        }
1752    }
1753
1754    pub fn counts(&self) -> std::collections::HashMap<Vec<u64>, u64> {
1755        if self.m_words > 8 {
1756            return self.counts_wide();
1757        }
1758        self.counts_packed()
1759            .into_iter()
1760            .map(|(k, v)| (k[..self.m_words].to_vec(), v))
1761            .collect()
1762    }
1763
1764    fn counts_wide(&self) -> std::collections::HashMap<Vec<u64>, u64> {
1765        use std::collections::HashMap;
1766
1767        let mut map = HashMap::new();
1768        let mw = self.m_words;
1769
1770        match self.layout {
1771            ShotLayout::ShotMajor => {
1772                for s in 0..self.num_shots {
1773                    let base = s * mw;
1774                    *map.entry(self.data[base..base + mw].to_vec()).or_insert(0) += 1;
1775                }
1776            }
1777            ShotLayout::MeasMajor => {
1778                let batch_size = 64;
1779                let mut shot_buf = vec![0u64; batch_size * mw];
1780                for batch_start in (0..self.num_shots).step_by(batch_size) {
1781                    let batch_end = (batch_start + batch_size).min(self.num_shots);
1782                    let batch_len = batch_end - batch_start;
1783                    shot_buf[..batch_len * mw].fill(0);
1784
1785                    let sw_base = batch_start / 64;
1786                    let bit_off = batch_start % 64;
1787
1788                    for m in 0..self.num_measurements {
1789                        let mword = m / 64;
1790                        let mbit = m % 64;
1791                        let meas_row = &self.data[m * self.s_words..];
1792                        let word = meas_row[sw_base];
1793                        let shifted = word >> bit_off;
1794                        for s in 0..batch_len {
1795                            if (shifted >> s) & 1 != 0 {
1796                                shot_buf[s * mw + mword] |= 1u64 << mbit;
1797                            }
1798                        }
1799                    }
1800
1801                    for s in 0..batch_len {
1802                        let base = s * mw;
1803                        *map.entry(shot_buf[base..base + mw].to_vec()).or_insert(0) += 1;
1804                    }
1805                }
1806            }
1807        }
1808
1809        map
1810    }
1811
1812    fn counts_packed(&self) -> std::collections::HashMap<[u64; 8], u64> {
1813        use std::collections::HashMap;
1814        let mut map = HashMap::new();
1815        let mw = self.m_words;
1816        debug_assert!(mw <= 8);
1817
1818        match self.layout {
1819            ShotLayout::ShotMajor => {
1820                for s in 0..self.num_shots {
1821                    let base = s * mw;
1822                    let mut key = [0u64; 8];
1823                    key[..mw].copy_from_slice(&self.data[base..base + mw]);
1824                    *map.entry(key).or_insert(0) += 1;
1825                }
1826            }
1827            ShotLayout::MeasMajor => {
1828                let batch_size = 64;
1829                let mut shot_buf = vec![0u64; batch_size * mw];
1830                for batch_start in (0..self.num_shots).step_by(batch_size) {
1831                    let batch_end = (batch_start + batch_size).min(self.num_shots);
1832                    let batch_len = batch_end - batch_start;
1833                    shot_buf[..batch_len * mw].fill(0);
1834
1835                    let sw_base = batch_start / 64;
1836                    let bit_off = batch_start % 64;
1837
1838                    for m in 0..self.num_measurements {
1839                        let mword = m / 64;
1840                        let mbit = m % 64;
1841                        let meas_row = &self.data[m * self.s_words..];
1842                        let word = meas_row[sw_base];
1843                        let shifted = word >> bit_off;
1844                        for s in 0..batch_len {
1845                            if (shifted >> s) & 1 != 0 {
1846                                shot_buf[s * mw + mword] |= 1u64 << mbit;
1847                            }
1848                        }
1849                    }
1850
1851                    for s in 0..batch_len {
1852                        let base = s * mw;
1853                        let mut key = [0u64; 8];
1854                        key[..mw].copy_from_slice(&shot_buf[base..base + mw]);
1855                        *map.entry(key).or_insert(0) += 1;
1856                    }
1857                }
1858            }
1859        }
1860        map
1861    }
1862}
1863
1864fn validate_parity_rows(rows: &[Vec<usize>], num_measurements: usize) -> Result<()> {
1865    for row in rows {
1866        for &measurement in row {
1867            if measurement >= num_measurements {
1868                return Err(PrismError::InvalidParameter {
1869                    message: format!(
1870                        "measurement index {measurement} out of bounds for {num_measurements} measurements"
1871                    ),
1872                });
1873            }
1874        }
1875    }
1876    Ok(())
1877}
1878
1879impl CompiledDetectorSampler {
1880    pub fn num_measurements(&self) -> usize {
1881        self.num_measurements
1882    }
1883
1884    pub fn num_detectors(&self) -> usize {
1885        self.detector_rows.len()
1886    }
1887
1888    pub fn num_observables(&self) -> usize {
1889        self.observable_rows.len()
1890    }
1891
1892    pub fn rank(&self) -> usize {
1893        self.measurement_sampler.rank()
1894    }
1895
1896    pub fn detector_rows(&self) -> &[Vec<usize>] {
1897        &self.detector_rows
1898    }
1899
1900    pub fn observable_rows(&self) -> &[Vec<usize>] {
1901        &self.observable_rows
1902    }
1903
1904    #[cfg(feature = "gpu")]
1905    pub fn with_gpu(mut self, context: std::sync::Arc<crate::gpu::GpuContext>) -> Self {
1906        self.measurement_sampler = self.measurement_sampler.with_gpu(context);
1907        self
1908    }
1909
1910    pub fn sample_measurements_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
1911        self.measurement_sampler.try_sample_bulk_packed(num_shots)
1912    }
1913
1914    pub fn sample_detectors_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
1915        let measurements = self.sample_measurements_packed(num_shots)?;
1916        measurements.parity_rows(&self.detector_rows)
1917    }
1918
1919    pub fn sample_observables_packed(&mut self, num_shots: usize) -> Result<PackedShots> {
1920        let measurements = self.sample_measurements_packed(num_shots)?;
1921        measurements.parity_rows(&self.observable_rows)
1922    }
1923
1924    pub fn sample_packed(&mut self, num_shots: usize) -> Result<DetectorSampleBatch> {
1925        let measurements = self.sample_measurements_packed(num_shots)?;
1926        let detectors = measurements.parity_rows(&self.detector_rows)?;
1927        let observables = measurements.parity_rows(&self.observable_rows)?;
1928        Ok(DetectorSampleBatch {
1929            measurements,
1930            detectors,
1931            observables,
1932        })
1933    }
1934
1935    pub fn sample_detectors_chunked<A: ShotAccumulator>(
1936        &mut self,
1937        total_shots: usize,
1938        acc: &mut A,
1939    ) -> Result<()> {
1940        let chunk_size = default_chunk_size(self.detector_rows.len());
1941        let mut remaining = total_shots;
1942        while remaining > 0 {
1943            let batch = remaining.min(chunk_size);
1944            let packed = self.sample_detectors_packed(batch)?;
1945            acc.accumulate(&packed);
1946            remaining -= batch;
1947        }
1948        Ok(())
1949    }
1950
1951    pub fn sample_detector_counts(
1952        &mut self,
1953        total_shots: usize,
1954    ) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
1955        let mut acc = HistogramAccumulator::new();
1956        self.sample_detectors_chunked(total_shots, &mut acc)?;
1957        Ok(acc.into_counts())
1958    }
1959}
1960
/// Device-resident packed shots emitted by the GPU BTS path.
///
/// The payload stays in measurement-major layout on the device until an
/// explicit host copy via [`Self::to_host`]. Marginals and exact counts can be
/// reduced first so higher-level workflows do not have to transfer the full
/// shot matrix.
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct DevicePackedShots {
    /// GPU context; its device handle is used when copying `data` to the host
    /// and it is passed to the GPU reduction kernels.
    context: std::sync::Arc<crate::gpu::GpuContext>,
    /// Packed shot words held in a GPU buffer.
    data: crate::gpu::GpuBuffer<u64>,
    /// Number of shots in the payload.
    num_shots: usize,
    /// Number of measurement bits per shot.
    num_measurements: usize,
    /// Words per shot; a shot-major payload holds `num_shots * m_words` words.
    m_words: usize,
    /// Words per measurement row; a measurement-major payload holds
    /// `num_measurements * s_words` words.
    s_words: usize,
    /// Which of the two packings `data` uses.
    layout: ShotLayout,
    /// Rank hint forwarded to the GPU count reductions by `counts`.
    rank: usize,
}
1979
#[cfg(feature = "gpu")]
impl DevicePackedShots {
    /// Number of shots in the payload.
    pub fn num_shots(&self) -> usize {
        self.num_shots
    }

    /// Number of measurement bits per shot.
    pub fn num_measurements(&self) -> usize {
        self.num_measurements
    }

    /// Packing used by the device buffer.
    pub fn layout(&self) -> ShotLayout {
        self.layout
    }

    /// GPU context that owns the device buffer.
    pub(crate) fn context(&self) -> &std::sync::Arc<crate::gpu::GpuContext> {
        &self.context
    }

    /// Words per shot in shot-major packing.
    pub fn m_words(&self) -> usize {
        self.m_words
    }

    /// Words per measurement row in measurement-major packing.
    pub fn s_words(&self) -> usize {
        self.s_words
    }

    /// Mutable access to the underlying device buffer.
    pub(crate) fn data_mut(&mut self) -> &mut crate::gpu::GpuBuffer<u64> {
        &mut self.data
    }

    /// Copy the full packed payload back to host memory.
    ///
    /// An empty payload skips the device copy entirely.
    ///
    /// # Errors
    ///
    /// A failed device copy is reported as `PrismError::BackendUnsupported`.
    pub fn to_host(&self) -> Result<PackedShots> {
        let word_count = match self.layout {
            ShotLayout::ShotMajor => self.num_shots * self.m_words,
            ShotLayout::MeasMajor => self.num_measurements * self.s_words,
        };

        let mut host = vec![0u64; word_count];
        if word_count > 0 {
            if let Err(e) = self.data.copy_to_host(self.context.device(), &mut host) {
                return Err(PrismError::BackendUnsupported {
                    backend: "gpu".to_string(),
                    operation: format!("copy device packed shots to host: {e}"),
                });
            }
        }

        let packed = match self.layout {
            ShotLayout::ShotMajor => {
                PackedShots::from_shot_major(host, self.num_shots, self.num_measurements)
            }
            ShotLayout::MeasMajor => {
                PackedShots::from_meas_major(host, self.num_shots, self.num_measurements)
            }
        };
        Ok(packed)
    }

    /// Return per-measurement one-counts without copying the full shot matrix.
    ///
    /// Only the measurement-major layout is reduced on the device; shot-major
    /// data is copied to the host and tallied from its outcome histogram.
    pub fn marginal_counts(&self) -> Result<Vec<u64>> {
        match self.layout {
            ShotLayout::MeasMajor => crate::gpu::kernels::bts::count_meas_major_marginals(
                &self.context,
                &self.data,
                self.num_measurements,
                self.num_shots,
                self.s_words,
            ),
            ShotLayout::ShotMajor => {
                let histogram = self.to_host()?.counts();
                let mut ones = vec![0u64; self.num_measurements];
                for (key, count) in histogram {
                    for (m, slot) in ones.iter_mut().enumerate() {
                        if (key[m / 64] >> (m % 64)) & 1 != 0 {
                            *slot += count;
                        }
                    }
                }
                Ok(ones)
            }
        }
    }

    /// Return per-measurement marginal probabilities.
    ///
    /// With zero shots every probability is reported as `0.0`.
    pub fn marginals(&self) -> Result<Vec<f64>> {
        if self.num_shots == 0 {
            return Ok(vec![0.0; self.num_measurements]);
        }
        let shots = self.num_shots as f64;
        let ones = self.marginal_counts()?;
        Ok(ones.into_iter().map(|count| count as f64 / shots).collect())
    }

    /// Return exact counts, using a GPU reduction when it reduces transfer.
    pub fn counts(&self) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
        self.counts_with_rank_hint(self.rank)
    }

    /// Exact counts with a caller-supplied rank hint.
    ///
    /// Tries the layout-specific GPU reduction first; when that declines
    /// (returns `None`), the payload is copied to the host and counted there.
    pub(crate) fn counts_with_rank_hint(
        &self,
        rank_hint: usize,
    ) -> Result<std::collections::HashMap<Vec<u64>, u64>> {
        let reduced_on_device = match self.layout {
            ShotLayout::MeasMajor => crate::gpu::kernels::bts::try_count_meas_major(
                &self.context,
                &self.data,
                self.num_measurements,
                self.num_shots,
                self.m_words,
                self.s_words,
                rank_hint,
            )?,
            ShotLayout::ShotMajor => {
                let raw_transfer_bytes = self
                    .num_shots
                    .saturating_mul(self.m_words)
                    .saturating_mul(std::mem::size_of::<u64>());
                crate::gpu::kernels::bts::try_count_shot_major(
                    &self.context,
                    &self.data,
                    self.num_shots,
                    self.m_words,
                    rank_hint,
                    raw_transfer_bytes,
                )?
            }
        };

        match reduced_on_device {
            Some(counts) => Ok(counts),
            None => Ok(self.to_host()?.counts()),
        }
    }
}
2115
/// XOR `src` into `dst` word by word, four `u64` lanes at a time with AVX2.
///
/// Only the common prefix (`min(dst.len(), src.len())` words) is touched.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn xor_words_avx2(dst: &mut [u64], src: &[u64]) {
    use std::arch::x86_64::{
        __m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_xor_si256,
    };

    let common = dst.len().min(src.len());
    let mut dst_quads = dst[..common].chunks_exact_mut(4);
    let mut src_quads = src[..common].chunks_exact(4);

    for (d, s) in (&mut dst_quads).zip(&mut src_quads) {
        // Each chunk is exactly four u64 = 32 bytes; unaligned load/store is used.
        let d_ptr = d.as_mut_ptr() as *mut __m256i;
        let s_ptr = s.as_ptr() as *const __m256i;
        let mixed = _mm256_xor_si256(_mm256_loadu_si256(d_ptr), _mm256_loadu_si256(s_ptr));
        _mm256_storeu_si256(d_ptr, mixed);
    }

    // Up to three leftover words are handled one at a time.
    for (d, &s) in dst_quads
        .into_remainder()
        .iter_mut()
        .zip(src_quads.remainder())
    {
        *d ^= s;
    }
}
2134
/// XOR `src` into `dst` word by word, two `u64` lanes at a time with NEON.
///
/// Only the common prefix (`min(dst.len(), src.len())` words) is touched.
///
/// # Safety
///
/// Uses NEON load/store intrinsics on raw pointers derived from the slices;
/// the caller must be running on a target where those intrinsics are valid.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
unsafe fn xor_words_neon(dst: &mut [u64], src: &[u64]) {
    use std::arch::aarch64::{veorq_u64, vld1q_u64, vst1q_u64};

    let common = dst.len().min(src.len());
    let mut dst_pairs = dst[..common].chunks_exact_mut(2);
    let mut src_pairs = src[..common].chunks_exact(2);

    for (d, s) in (&mut dst_pairs).zip(&mut src_pairs) {
        // Each chunk is exactly two u64, i.e. one 128-bit vector.
        let mixed = veorq_u64(vld1q_u64(d.as_ptr()), vld1q_u64(s.as_ptr()));
        vst1q_u64(d.as_mut_ptr(), mixed);
    }

    // At most one leftover word.
    for (d, &s) in dst_pairs
        .into_remainder()
        .iter_mut()
        .zip(src_pairs.remainder())
    {
        *d ^= s;
    }
}
2154
2155#[inline(always)]
2156pub(crate) fn xor_words(dst: &mut [u64], src: &[u64]) {
2157    #[cfg(target_arch = "x86_64")]
2158    {
2159        if is_x86_feature_detected!("avx2") && dst.len() >= 4 {
2160            // SAFETY: AVX2 detected, pointers are valid u64 slices
2161            unsafe {
2162                xor_words_avx2(dst, src);
2163            }
2164            return;
2165        }
2166    }
2167    #[cfg(target_arch = "aarch64")]
2168    {
2169        if dst.len() >= 2 {
2170            // SAFETY: NEON is baseline on aarch64, pointers are valid u64 slices
2171            unsafe {
2172                xor_words_neon(dst, src);
2173            }
2174            return;
2175        }
2176    }
2177    for (d, &s) in dst.iter_mut().zip(src) {
2178        *d ^= s;
2179    }
2180}
2181
2182fn gray_code_exact_counts(
2183    sparse: &SparseParity,
2184    rank: usize,
2185    ref_bits: &[u64],
2186    num_measurements: usize,
2187) -> std::collections::HashMap<Vec<u64>, u64> {
2188    use std::collections::HashMap;
2189
2190    let m_words = num_measurements.div_ceil(64);
2191    let mut meas_vec = ref_bits[..m_words].to_vec();
2192    let total: u64 = 1u64 << rank;
2193
2194    let mut col_words: Vec<Vec<u64>> = Vec::with_capacity(rank);
2195    for col in 0..rank {
2196        let mut cw = vec![0u64; m_words];
2197        for m in 0..num_measurements {
2198            let start = sparse.row_offsets[m] as usize;
2199            let end = sparse.row_offsets[m + 1] as usize;
2200            for &c in &sparse.col_indices[start..end] {
2201                if c as usize == col {
2202                    cw[m / 64] |= 1u64 << (m % 64);
2203                }
2204            }
2205        }
2206        col_words.push(cw);
2207    }
2208
2209    let mut counts: HashMap<Vec<u64>, u64> = HashMap::new();
2210    *counts.entry(meas_vec.clone()).or_insert(0) += 1;
2211
2212    for step in 1..total {
2213        let bit_to_flip = step.trailing_zeros() as usize;
2214        let col = &col_words[bit_to_flip];
2215        for (mw, cw) in meas_vec.iter_mut().zip(col.iter()) {
2216            *mw ^= cw;
2217        }
2218        *counts.entry(meas_vec.clone()).or_insert(0) += 1;
2219    }
2220
2221    counts
2222}
2223
// Tuning thresholds.
// NOTE(review): the use sites are outside this section; judging by the names
// these gate which sampling/counting strategy is chosen for a given rank and
// shot count — confirm at the call sites before changing any value.
const MAX_RANK_FOR_GRAY_CODE: usize = 25;
const MAX_RANK_FOR_MULTINOMIAL: usize = 22;
const MIN_SHOTS_PER_OUTCOME_MULTINOMIAL: usize = 8;
const MAX_RANK_FOR_RANK_SPACE: usize = 20;
const MIN_SHOTS_PER_OUTCOME: usize = 4;
// 256 MiB.
const MAX_LUT_ALLOC_BYTES: u64 = 256 * 1024 * 1024;
2230
/// Compile a Clifford-only circuit into a `CompiledSampler` by running the
/// forward simulation once and recording, for every measurement, a reference
/// bit plus the set of random bits that flip it.
///
/// Measurements are taken in instruction order; the classical-bit target of
/// each `Measure` is collected but not used here. `seed` seeds the sampler's
/// `ChaCha8Rng`.
///
/// A circuit without measurements yields an empty sampler (rank 0, no rows).
///
/// # Errors
///
/// Returns `PrismError::IncompatibleBackend` when the circuit is not
/// Clifford-only, and propagates any error from `colmajor_forward_sim`.
pub fn compile_forward(circuit: &Circuit, seed: u64) -> Result<CompiledSampler> {
    if !circuit.is_clifford_only() {
        return Err(PrismError::IncompatibleBackend {
            backend: "CompiledSampler".to_string(),
            reason: "circuit contains non-Clifford gates".to_string(),
        });
    }

    // (qubit, classical_bit) for every measurement, in instruction order.
    let measurements: Vec<(usize, usize)> = circuit
        .instructions
        .iter()
        .filter_map(|inst| match inst {
            Instruction::Measure {
                qubit,
                classical_bit,
            } => Some((*qubit, *classical_bit)),
            _ => None,
        })
        .collect();

    let num_measurements = measurements.len();
    if num_measurements == 0 {
        return Ok(CompiledSampler {
            flip_rows: Vec::new(),
            ref_bits_packed: Vec::new(),
            rank: 0,
            num_measurements: 0,
            rng: ChaCha8Rng::seed_from_u64(seed),
            lut: None,
            sparse: None,
            xor_dag: None,
            parity_blocks: None,
            #[cfg(feature = "gpu")]
            gpu_context: None,
            #[cfg(feature = "gpu")]
            gpu_bts_cache: None,
        });
    }

    let n = circuit.num_qubits;

    // `xz` is a flat row table of 2n rows, `stride` words per row; `phase`
    // holds one flag per row.
    // NOTE(review): presumably rows 0..n are destabilizers and n..2n are
    // stabilizers, with X bits in the first `nw` words of a row and Z bits in
    // the next `nw` — confirm against `colmajor_forward_sim`.
    let (mut xz, mut phase, nw) = colmajor_forward_sim(n, &circuit.instructions)?;
    let stride = 2 * nw;
    let m = num_measurements;
    let m_words = m.div_ceil(64);

    // `gen_dep[r]` is a bitset over random-bit indices attached to row `r`;
    // the extra entry at `total_rows` is scratch for deterministic outcomes.
    // The rank can never exceed the measurement count, so `m_words` suffices.
    let rank_words = m_words;
    let total_rows = 2 * n;
    let mut gen_dep: Vec<Vec<u64>> = vec![vec![0u64; rank_words]; total_rows + 1];
    let mut ref_bits: Vec<bool> = vec![false; m];
    let mut rank = 0usize;

    // `flip_rows[k]` is a bitset over measurements flipped by random bit `k`.
    let mut flip_rows: Vec<Vec<u64>> = Vec::with_capacity(m);
    let mut p_data: Vec<u64> = vec![0u64; stride];
    let mut p_dep: Vec<u64> = vec![0u64; rank_words];
    let mut scratch: Vec<u64> = vec![0u64; stride];
    let scratch_idx = total_rows;

    for (meas_idx, &(qubit, _)) in measurements.iter().enumerate() {
        let word = qubit / 64;
        let bit_mask = 1u64 << (qubit % 64);

        // First row in n..2n whose first-half bit for this qubit is set.
        let mut p: Option<usize> = None;
        for i in n..2 * n {
            if xz[i * stride + word] & bit_mask != 0 {
                p = Some(i);
                break;
            }
        }

        if let Some(p_row) = p {
            // Random measurement, this is the k-th random degree of freedom
            let k = rank;
            rank += 1;
            flip_rows.push(vec![0u64; m_words]);

            flip_rows[k][meas_idx / 64] |= 1u64 << (meas_idx % 64);

            // Snapshot the pivot row so it can be multiplied into other rows
            // while `xz` is mutably borrowed.
            let p_base = p_row * stride;
            p_data.copy_from_slice(&xz[p_base..p_base + stride]);
            let p_phase = phase[p_row];
            p_dep.copy_from_slice(&gen_dep[p_row][..rank_words]);

            // Multiply the pivot into every other row that has the same bit
            // set, carrying the dependency bitset along.
            for r in 0..total_rows {
                if r == p_row {
                    continue;
                }
                if xz[r * stride + word] & bit_mask == 0 {
                    continue;
                }

                let r_base = r * stride;
                phase[r] = rowmul_phase(&p_data, &mut xz, r_base, nw, p_phase, phase[r]);
                xor_words(&mut gen_dep[r][..rank_words], &p_dep[..rank_words]);
            }

            // Move the old pivot row (data, phase, dependencies) to row
            // `p_row - n`.
            let dest_idx = p_row - n;
            let dest_base = dest_idx * stride;
            xz.copy_within(p_row * stride..p_row * stride + stride, dest_base);
            phase[dest_idx] = p_phase;
            gen_dep[dest_idx][..rank_words].copy_from_slice(&p_dep);

            // Replace the pivot row with a single second-half bit on the
            // measured qubit and a cleared phase.
            let p_base = p_row * stride;
            xz[p_base..p_base + stride].fill(0);
            xz[p_base + nw + word] |= bit_mask;
            phase[p_row] = false;

            // The new row depends on exactly the fresh random bit `k`.
            gen_dep[p_row][..rank_words].fill(0);
            gen_dep[p_row][k / 64] |= 1u64 << (k % 64);

            ref_bits[meas_idx] = false;
        } else {
            // No pivot found: the outcome is a fixed reference bit plus a
            // combination of earlier random bits, accumulated into scratch.
            scratch[..stride].fill(0);
            let mut scratch_phase = false;
            gen_dep[scratch_idx][..rank_words].fill(0);

            for g in 0..n {
                let d_base = g * stride;
                if xz[d_base + word] & bit_mask == 0 {
                    continue;
                }

                // Row `g` has the bit set, so fold its partner row `g + n`
                // into the scratch row.
                let s_base = (g + n) * stride;
                let s_phase = phase[g + n];
                scratch_phase =
                    rowmul_phase_into(&xz, s_base, &mut scratch, nw, s_phase, scratch_phase);

                // Split so the scratch entry can be written while reading row g + n.
                let (lo, hi) = gen_dep.split_at_mut(scratch_idx);
                for (dst, &src) in hi[0][..rank_words].iter_mut().zip(&lo[g + n][..rank_words]) {
                    *dst ^= src;
                }
            }

            ref_bits[meas_idx] = scratch_phase;

            // Every random bit in the accumulated dependency set flips this
            // measurement.
            for (w, &dep_word) in gen_dep[scratch_idx][..rank_words].iter().enumerate() {
                let mut bits = dep_word;
                while bits != 0 {
                    let bit_pos = bits.trailing_zeros() as usize;
                    let k = w * 64 + bit_pos;
                    if k < rank {
                        flip_rows[k][meas_idx / 64] |= 1u64 << (meas_idx % 64);
                    }
                    bits &= bits - 1;
                }
            }
        }
    }

    let num_meas_words = m_words;
    minimize_flip_row_weight(&mut flip_rows);

    // The lookup table is only built once the rank reaches `LUT_MIN_RANK`.
    let lut = if rank >= LUT_MIN_RANK {
        Some(FlipLut::build(&flip_rows, num_meas_words))
    } else {
        None
    };

    // Derived representations of the same flip rows, built by helpers
    // defined elsewhere in the crate.
    let sparse = SparseParity::from_flip_rows(&flip_rows, num_measurements);
    let xor_dag = build_xor_dag_if_useful(&sparse);
    let ref_bits_packed = pack_bools(&ref_bits);
    let parity_blocks = build_parity_blocks_if_useful(&sparse, rank, &ref_bits_packed);

    Ok(CompiledSampler {
        flip_rows,
        ref_bits_packed,
        rank,
        num_measurements,
        rng: ChaCha8Rng::seed_from_u64(seed),
        lut,
        sparse: Some(sparse),
        xor_dag,
        parity_blocks,
        #[cfg(feature = "gpu")]
        gpu_context: None,
        #[cfg(feature = "gpu")]
        gpu_bts_cache: None,
    })
}
2410
2411fn compile_measurements_filtered(
2412    circuit: &Circuit,
2413    blocks: &[Vec<usize>],
2414    seed: u64,
2415) -> Result<CompiledSampler> {
2416    let num_global_measurements: usize = circuit
2417        .instructions
2418        .iter()
2419        .filter(|i| matches!(i, Instruction::Measure { .. }))
2420        .count();
2421
2422    if num_global_measurements == 0 {
2423        return Ok(CompiledSampler {
2424            flip_rows: Vec::new(),
2425            ref_bits_packed: Vec::new(),
2426            rank: 0,
2427            num_measurements: 0,
2428            rng: ChaCha8Rng::seed_from_u64(seed),
2429            lut: None,
2430            sparse: None,
2431            xor_dag: None,
2432            parity_blocks: None,
2433            #[cfg(feature = "gpu")]
2434            gpu_context: None,
2435            #[cfg(feature = "gpu")]
2436            gpu_bts_cache: None,
2437        });
2438    }
2439
2440    let mut qubit_to_block: Vec<usize> = vec![0; circuit.num_qubits];
2441    for (bi, block) in blocks.iter().enumerate() {
2442        for &q in block {
2443            qubit_to_block[q] = bi;
2444        }
2445    }
2446
2447    let mut block_samplers: Vec<CompiledSampler> = Vec::with_capacity(blocks.len());
2448    for (bi, block) in blocks.iter().enumerate() {
2449        let (sub_circuit, _qubit_map, _classical_map) = circuit.extract_subcircuit(block);
2450        let block_seed = seed.wrapping_add(bi as u64 * 0x1234_5678);
2451        block_samplers.push(compile_measurements(&sub_circuit, block_seed)?);
2452    }
2453
2454    let mut meas_map: Vec<(usize, usize)> = Vec::with_capacity(num_global_measurements);
2455    let mut block_meas_count: Vec<usize> = vec![0; blocks.len()];
2456    for inst in &circuit.instructions {
2457        if let Instruction::Measure { qubit, .. } = inst {
2458            let bi = qubit_to_block[*qubit];
2459            let local_idx = block_meas_count[bi];
2460            block_meas_count[bi] += 1;
2461            meas_map.push((bi, local_idx));
2462        }
2463    }
2464
2465    let m_words = num_global_measurements.div_ceil(64);
2466    let total_rank: usize = block_samplers.iter().map(|s| s.rank).sum();
2467    let mut ref_bits_packed: Vec<u64> = vec![0u64; num_global_measurements.div_ceil(64)];
2468
2469    for (gi, &(bi, li)) in meas_map.iter().enumerate() {
2470        let src = &block_samplers[bi].ref_bits_packed;
2471        let bit = (src[li / 64] >> (li % 64)) & 1;
2472        if bit != 0 {
2473            ref_bits_packed[gi / 64] |= 1u64 << (gi % 64);
2474        }
2475    }
2476
2477    let mut local_to_global: Vec<Vec<usize>> = vec![Vec::new(); blocks.len()];
2478    for (gi, &(bi, _li)) in meas_map.iter().enumerate() {
2479        local_to_global[bi].push(gi);
2480    }
2481
2482    let mut rank_offsets = Vec::with_capacity(block_samplers.len());
2483    let mut rank_prefix = 0usize;
2484    for sampler in &block_samplers {
2485        rank_offsets.push(rank_prefix);
2486        rank_prefix += sampler.rank;
2487    }
2488    debug_assert_eq!(rank_prefix, total_rank);
2489
2490    let mut flip_rows: Vec<Vec<u64>> = Vec::with_capacity(total_rank);
2491    for (bi, sampler) in block_samplers.iter().enumerate() {
2492        let mapping = &local_to_global[bi];
2493        for local_row in &sampler.flip_rows {
2494            flip_rows.push(project_local_flip_row(local_row, mapping, m_words));
2495        }
2496    }
2497
2498    let lut = if total_rank >= LUT_MIN_RANK {
2499        Some(FlipLut::build(&flip_rows, m_words))
2500    } else {
2501        None
2502    };
2503
2504    let sparse = build_sparse_from_filtered_blocks(
2505        &block_samplers,
2506        &meas_map,
2507        &rank_offsets,
2508        num_global_measurements,
2509    );
2510    let xor_dag = build_xor_dag_if_useful(&sparse);
2511    let parity_blocks =
2512        build_filtered_parity_blocks(&block_samplers, &local_to_global).map(|mut blocks| {
2513            blocks.direct_scatter = true;
2514            blocks
2515        });
2516
2517    Ok(CompiledSampler {
2518        flip_rows,
2519        ref_bits_packed,
2520        rank: total_rank,
2521        num_measurements: num_global_measurements,
2522        rng: ChaCha8Rng::seed_from_u64(seed),
2523        lut,
2524        sparse: Some(sparse),
2525        xor_dag,
2526        parity_blocks,
2527        #[cfg(feature = "gpu")]
2528        gpu_context: None,
2529        #[cfg(feature = "gpu")]
2530        gpu_bts_cache: None,
2531    })
2532}
2533
2534/// Convert reset reuse into fresh qubit aliases and move measurements to the end.
2535pub(crate) fn defer_measure_reset_circuit(circuit: &Circuit) -> Result<Circuit> {
2536    if !circuit.is_clifford_only() {
2537        return Err(PrismError::IncompatibleBackend {
2538            backend: "CompiledSampler".to_string(),
2539            reason: "deferred measurement sampling requires Clifford gates".to_string(),
2540        });
2541    }
2542
2543    let has_measurements = circuit
2544        .instructions
2545        .iter()
2546        .any(|inst| matches!(inst, Instruction::Measure { .. }));
2547    if !has_measurements {
2548        return Err(PrismError::IncompatibleBackend {
2549            backend: "CompiledSampler".to_string(),
2550            reason: "deferred measurement sampling requires measurements".to_string(),
2551        });
2552    }
2553
2554    let mut aliases: Vec<usize> = (0..circuit.num_qubits).collect();
2555    let mut measured_aliases = vec![false; circuit.num_qubits];
2556    let mut next_qubit = circuit.num_qubits;
2557    let mut deferred_measurements: Vec<(usize, usize)> = Vec::new();
2558    let mut transformed = Circuit::new(circuit.num_qubits, circuit.num_classical_bits);
2559
2560    for inst in &circuit.instructions {
2561        match inst {
2562            Instruction::Gate { gate, targets } => {
2563                let mapped = map_deferred_targets(targets, &aliases, &measured_aliases)?;
2564                transformed.instructions.push(Instruction::Gate {
2565                    gate: gate.clone(),
2566                    targets: mapped,
2567                });
2568            }
2569            Instruction::Measure {
2570                qubit,
2571                classical_bit,
2572            } => {
2573                if *qubit >= aliases.len() {
2574                    return Err(PrismError::InvalidQubit {
2575                        index: *qubit,
2576                        register_size: aliases.len(),
2577                    });
2578                }
2579                if *classical_bit >= circuit.num_classical_bits {
2580                    return Err(PrismError::InvalidClassicalBit {
2581                        index: *classical_bit,
2582                        register_size: circuit.num_classical_bits,
2583                    });
2584                }
2585                let alias = aliases[*qubit];
2586                deferred_measurements.push((alias, *classical_bit));
2587                measured_aliases[alias] = true;
2588            }
2589            Instruction::Reset { qubit } => {
2590                if *qubit >= aliases.len() {
2591                    return Err(PrismError::InvalidQubit {
2592                        index: *qubit,
2593                        register_size: aliases.len(),
2594                    });
2595                }
2596                aliases[*qubit] = next_qubit;
2597                next_qubit += 1;
2598                measured_aliases.push(false);
2599                transformed.num_qubits = next_qubit;
2600            }
2601            Instruction::Barrier { qubits } => {
2602                let mut mapped = SmallVec::<[usize; 4]>::with_capacity(qubits.len());
2603                for &q in qubits {
2604                    if q >= aliases.len() {
2605                        return Err(PrismError::InvalidQubit {
2606                            index: q,
2607                            register_size: aliases.len(),
2608                        });
2609                    }
2610                    mapped.push(aliases[q]);
2611                }
2612                transformed
2613                    .instructions
2614                    .push(Instruction::Barrier { qubits: mapped });
2615            }
2616            Instruction::Conditional { .. } => {
2617                return Err(PrismError::IncompatibleBackend {
2618                    backend: "CompiledSampler".to_string(),
2619                    reason: "deferred measurement sampling does not support classical conditionals"
2620                        .to_string(),
2621                });
2622            }
2623        }
2624    }
2625
2626    transformed.num_qubits = next_qubit;
2627    transformed
2628        .instructions
2629        .reserve(deferred_measurements.len());
2630    for (qubit, classical_bit) in deferred_measurements {
2631        transformed.instructions.push(Instruction::Measure {
2632            qubit,
2633            classical_bit,
2634        });
2635    }
2636
2637    Ok(transformed)
2638}
2639
2640fn map_deferred_targets(
2641    targets: &SmallVec<[usize; 4]>,
2642    aliases: &[usize],
2643    measured_aliases: &[bool],
2644) -> Result<SmallVec<[usize; 4]>> {
2645    let mut mapped = SmallVec::<[usize; 4]>::with_capacity(targets.len());
2646    for &target in targets {
2647        if target >= aliases.len() {
2648            return Err(PrismError::InvalidQubit {
2649                index: target,
2650                register_size: aliases.len(),
2651            });
2652        }
2653        let alias = aliases[target];
2654        if measured_aliases[alias] {
2655            return Err(PrismError::IncompatibleBackend {
2656                backend: "CompiledSampler".to_string(),
2657                reason:
2658                    "deferred measurement sampling requires reset before reusing a measured qubit"
2659                        .to_string(),
2660            });
2661        }
2662        mapped.push(alias);
2663    }
2664    Ok(mapped)
2665}
2666
2667/// Compile a Clifford measurement circuit plus detector parity metadata.
2668///
2669/// `detector_rows` and `observable_rows` contain measurement record indices
2670/// in circuit measurement order. Circuits with reset reuse are rewritten to
2671/// terminal measurements before compiling.
2672pub fn compile_detector_sampler(
2673    circuit: &Circuit,
2674    detector_rows: Vec<Vec<usize>>,
2675    observable_rows: Vec<Vec<usize>>,
2676    seed: u64,
2677) -> Result<CompiledDetectorSampler> {
2678    let measurement_circuit = if circuit.has_resets() || !circuit.has_terminal_measurements_only() {
2679        defer_measure_reset_circuit(circuit)?
2680    } else {
2681        circuit.clone()
2682    };
2683
2684    let num_measurements = measurement_circuit
2685        .instructions
2686        .iter()
2687        .filter(|inst| matches!(inst, Instruction::Measure { .. }))
2688        .count();
2689    validate_parity_rows(&detector_rows, num_measurements)?;
2690    validate_parity_rows(&observable_rows, num_measurements)?;
2691
2692    let measurement_sampler = compile_measurements(&measurement_circuit, seed)?;
2693    Ok(CompiledDetectorSampler {
2694        measurement_sampler,
2695        detector_rows,
2696        observable_rows,
2697        num_measurements,
2698    })
2699}
2700
2701/// Compile a Clifford circuit's measurements into a fast sampler.
2702///
2703/// Selects forward (SGI stabilizer + dependency tracking) or backward (Pauli
2704/// propagation + Gaussian elimination) based on circuit depth. Forward wins
2705/// for deep circuits (gate_count >= 5×measurements).
2706/// Requires terminal measurements and does not support reset or conditional
2707/// operations on measured data.
2708pub fn compile_measurements(circuit: &Circuit, seed: u64) -> Result<CompiledSampler> {
2709    if !circuit.is_clifford_only() {
2710        return Err(PrismError::IncompatibleBackend {
2711            backend: "CompiledSampler".to_string(),
2712            reason: "circuit contains non-Clifford gates".to_string(),
2713        });
2714    }
2715
2716    let has_measurements = circuit
2717        .instructions
2718        .iter()
2719        .any(|inst| matches!(inst, Instruction::Measure { .. }));
2720    if has_measurements && circuit.has_resets() {
2721        return Err(PrismError::IncompatibleBackend {
2722            backend: "CompiledSampler".to_string(),
2723            reason: "compiled measurement sampling does not support reset instructions".to_string(),
2724        });
2725    }
2726    if has_measurements && !circuit.has_terminal_measurements_only() {
2727        return Err(PrismError::IncompatibleBackend {
2728            backend: "CompiledSampler".to_string(),
2729            reason: "compiled measurement sampling requires terminal measurements and does not support classical conditionals".to_string(),
2730        });
2731    }
2732
2733    if circuit.num_qubits >= 4 {
2734        let blocks = circuit.independent_subsystems();
2735        if blocks.len() > 1 {
2736            let max_block = blocks.iter().map(|b| b.len()).max().unwrap_or(0);
2737            if max_block < circuit.num_qubits {
2738                return compile_measurements_filtered(circuit, &blocks, seed);
2739            }
2740        }
2741    }
2742
2743    if circuit.num_qubits >= 64 {
2744        return compile_forward(circuit, seed);
2745    }
2746
2747    let measurement_rows = build_measurement_rows(circuit);
2748    let num_measurements = measurement_rows.len();
2749
2750    if num_measurements == 0 {
2751        return Ok(CompiledSampler {
2752            flip_rows: Vec::new(),
2753            ref_bits_packed: Vec::new(),
2754            rank: 0,
2755            num_measurements: 0,
2756            rng: ChaCha8Rng::seed_from_u64(seed),
2757            lut: None,
2758            sparse: None,
2759            xor_dag: None,
2760            parity_blocks: None,
2761            #[cfg(feature = "gpu")]
2762            gpu_context: None,
2763            #[cfg(feature = "gpu")]
2764            gpu_bts_cache: None,
2765        });
2766    }
2767
2768    let n = circuit.num_qubits;
2769
2770    let x_rows: Vec<Vec<u64>> = measurement_rows
2771        .iter()
2772        .map(|(p, _, _)| p.x.clone())
2773        .collect();
2774    let signs: Vec<bool> = measurement_rows.iter().map(|(_, _, s)| *s).collect();
2775
2776    let mut x_copy = x_rows.clone();
2777    let (rank, pivot_cols) = gaussian_eliminate(&mut x_copy, n);
2778
2779    let gate_count = circuit
2780        .instructions
2781        .iter()
2782        .filter(|i| {
2783            matches!(
2784                i,
2785                Instruction::Gate { .. } | Instruction::Conditional { .. }
2786            )
2787        })
2788        .count();
2789
2790    let ref_bits: Vec<bool> = if gate_count > 2 * num_measurements {
2791        let mini_outcomes = compute_reference_bits(&measurement_rows, n);
2792        mini_outcomes
2793            .iter()
2794            .zip(signs.iter())
2795            .map(|(&outcome, &sign)| outcome ^ sign)
2796            .collect()
2797    } else {
2798        use crate::backend::stabilizer::StabilizerBackend;
2799        use crate::backend::Backend;
2800        let mut stab = StabilizerBackend::new(seed);
2801        stab.init(circuit.num_qubits, circuit.num_classical_bits)?;
2802        stab.apply_instructions(&circuit.instructions)?;
2803        let ref_classical = stab.classical_results().to_vec();
2804        let classical_bit_order: Vec<usize> = measurement_rows.iter().map(|(_, c, _)| *c).collect();
2805        classical_bit_order
2806            .iter()
2807            .map(|&cbit| {
2808                if cbit < ref_classical.len() {
2809                    ref_classical[cbit]
2810                } else {
2811                    false
2812                }
2813            })
2814            .collect()
2815    };
2816
2817    let num_meas_words = num_measurements.div_ceil(64);
2818    let mut flip_rows: Vec<Vec<u64>> = vec![vec![0u64; num_meas_words]; rank];
2819
2820    for (j, &pcol) in pivot_cols.iter().enumerate() {
2821        for (i, x_row) in x_rows.iter().enumerate() {
2822            if get_bit(x_row, pcol) {
2823                flip_rows[j][i / 64] |= 1u64 << (i % 64);
2824            }
2825        }
2826    }
2827
2828    minimize_flip_row_weight(&mut flip_rows);
2829
2830    let lut = if rank >= LUT_MIN_RANK {
2831        Some(FlipLut::build(&flip_rows, num_meas_words))
2832    } else {
2833        None
2834    };
2835
2836    let sparse = SparseParity::from_flip_rows(&flip_rows, num_measurements);
2837    let xor_dag = build_xor_dag_if_useful(&sparse);
2838    let ref_bits_packed = pack_bools(&ref_bits);
2839    let parity_blocks = build_parity_blocks_if_useful(&sparse, rank, &ref_bits_packed);
2840
2841    Ok(CompiledSampler {
2842        flip_rows,
2843        ref_bits_packed,
2844        rank,
2845        num_measurements,
2846        rng: ChaCha8Rng::seed_from_u64(seed),
2847        lut,
2848        sparse: Some(sparse),
2849        xor_dag,
2850        parity_blocks,
2851        #[cfg(feature = "gpu")]
2852        gpu_context: None,
2853        #[cfg(feature = "gpu")]
2854        gpu_bts_cache: None,
2855    })
2856}
2857
2858/// Sample shots via the compiled (Heisenberg-picture) path.
2859///
2860/// Returns `Vec<Vec<bool>>`, inherently O(num_shots) memory.
2861/// For bounded-memory streaming at large shot counts, use
2862/// `compile_measurements` + `sample_chunked` / `sample_counts` directly.
2863///
2864/// Requires the same circuit subset as [`compile_measurements`], namely
2865/// terminal measurements with no resets or classical conditionals.
2866pub fn run_shots_compiled(circuit: &Circuit, num_shots: usize, seed: u64) -> Result<ShotsResult> {
2867    let mut sampler = compile_measurements(circuit, seed)?;
2868    let packed = sampler.sample_bulk_packed(num_shots);
2869    Ok(ShotsResult {
2870        shots: packed.to_shots(),
2871        num_classical_bits: circuit.num_classical_bits,
2872    })
2873}
2874
/// GPU-enabled counterpart of [`run_shots_compiled`].
///
/// The compiled sampler is given `context` via `with_gpu` before sampling, so
/// BTS sampling is routed through the GPU when the circuit compiles to a flat
/// sparse parity matrix and `num_shots` reaches the GPU BTS threshold. That
/// threshold defaults to [`crate::gpu::BTS_MIN_SHOTS_DEFAULT`] and can be
/// overridden with `PRISM_GPU_BTS_MIN_SHOTS`. Requires the `gpu` feature and a
/// working CUDA context.
///
/// # Errors
/// Propagates any error returned by [`compile_measurements`].
#[cfg(feature = "gpu")]
pub fn run_shots_compiled_with_gpu(
    circuit: &Circuit,
    num_shots: usize,
    seed: u64,
    context: std::sync::Arc<crate::gpu::GpuContext>,
) -> Result<ShotsResult> {
    let compiled = compile_measurements(circuit, seed)?;
    let mut sampler = compiled.with_gpu(context);
    let shots = sampler.sample_bulk_packed(num_shots).to_shots();
    Ok(ShotsResult {
        shots,
        num_classical_bits: circuit.num_classical_bits,
    })
}