Skip to main content

prism_q/sim/
homological.rs

1use crate::circuit::{Circuit, Instruction};
2use crate::error::Result;
3use crate::sim::compiled::batch_propagate_backward;
4use crate::sim::compiled::{default_chunk_size, xor_words, PackedShots, ShotAccumulator};
5use crate::sim::noise::NoiseModel;
6use crate::sim::ShotsResult;
7use rand::SeedableRng;
8use rand_chacha::ChaCha8Rng;
9
/// Dense binary matrix over GF(2) stored as packed u64 words per row.
///
/// Row-major: row `i` occupies `data[i * row_words .. (i + 1) * row_words]`, and
/// column `c` of a row is bit `c % 64` of that row's word `c / 64`.
struct F2DenseMatrix {
    #[cfg(test)]
    num_rows: usize,
    #[cfg(test)]
    num_cols: usize,
    /// Words per row: `ceil(num_cols / 64)`.
    row_words: usize,
    /// Packed row-major storage of `num_rows * row_words` words, initially all zero.
    data: Vec<u64>,
}

impl F2DenseMatrix {
    /// Create an all-zero `num_rows × num_cols` matrix.
    fn new(num_rows: usize, num_cols: usize) -> Self {
        let row_words = num_cols.div_ceil(64);
        Self {
            #[cfg(test)]
            num_rows,
            #[cfg(test)]
            num_cols,
            row_words,
            data: vec![0u64; num_rows * row_words],
        }
    }

    /// Set entry `(row, col)` to 1.
    #[inline(always)]
    fn set(&mut self, row: usize, col: usize) {
        self.data[row * self.row_words + col / 64] |= 1u64 << (col % 64);
    }

    /// Return `true` if entry `(row, col)` is 1.
    #[inline(always)]
    fn get(&self, row: usize, col: usize) -> bool {
        (self.data[row * self.row_words + col / 64] >> (col % 64)) & 1 != 0
    }

    /// Borrow the packed words of `row`.
    #[cfg(test)]
    fn row(&self, row: usize) -> &[u64] {
        let start = row * self.row_words;
        &self.data[start..start + self.row_words]
    }

    /// Elementary row operation `row[dst_row] ^= row[src_row]`.
    ///
    /// When `dst_row == src_row` the row is cleared (x ⊕ x = 0). The previous
    /// implementation panicked in that case because the aliased `split_at_mut`
    /// left the source slice too short.
    #[cfg(test)]
    fn xor_row(&mut self, dst_row: usize, src_row: usize) {
        let rw = self.row_words;
        if dst_row == src_row {
            let start = dst_row * rw;
            self.data[start..start + rw].fill(0);
            return;
        }
        // Split at the later row so the two rows live in disjoint mutable slices.
        let (lo, hi) = (dst_row.min(src_row) * rw, dst_row.max(src_row) * rw);
        let (head, tail) = self.data.split_at_mut(hi);
        let (lo_words, hi_words) = (&mut head[lo..lo + rw], &mut tail[..rw]);
        let (dst, src) = if dst_row < src_row {
            (lo_words, hi_words)
        } else {
            (hi_words, lo_words)
        };
        for (d, s) in dst.iter_mut().zip(src.iter()) {
            *d ^= *s;
        }
    }

    /// Swap rows `a` and `b` in place (no-op when `a == b`).
    #[cfg(test)]
    fn swap_rows(&mut self, a: usize, b: usize) {
        if a == b {
            return;
        }
        let rw = self.row_words;
        let (lo, hi) = (a.min(b) * rw, a.max(b) * rw);
        let (head, tail) = self.data.split_at_mut(hi);
        head[lo..lo + rw].swap_with_slice(&mut tail[..rw]);
    }
}
79
/// Compute a basis of the kernel (null space) of a binary matrix over GF(2).
///
/// Given `M ∈ F₂^{m×n}`, returns a basis of `ker(M) = {x ∈ F₂^n : Mx = 0}`.
///
/// Method: build the `n × (m + n)` augmented matrix `[Mᵀ | I_n]` and Gauss–Jordan
/// reduce it, choosing pivots only among the first `m` columns. Row operations keep
/// the matrix in the form `[A·Mᵀ | A]` for some invertible `A`. A row whose `Mᵀ`
/// half has become zero therefore carries, in its `I_n` half, a vector `x` with
/// `Mx = 0`. Those rows are rows of `A`, hence independent, and they form a basis.
///
/// Returns one packed bitvector per basis vector (`ceil(n / 64)` words; bit `c`
/// is component `c`).
#[cfg(test)]
fn gf2_kernel(matrix: &F2DenseMatrix) -> Vec<Vec<u64>> {
    let (rows, cols) = (matrix.num_rows, matrix.num_cols);
    let mut aug = F2DenseMatrix::new(cols, rows + cols);

    // Left half receives Mᵀ, right half the identity.
    for c in 0..cols {
        for r in (0..rows).filter(|&r| matrix.get(r, c)) {
            aug.set(c, r);
        }
        aug.set(c, rows + c);
    }

    let mut next_pivot = 0;
    for col in 0..rows {
        let Some(found) = (next_pivot..cols).find(|&r| aug.get(r, col)) else {
            continue;
        };
        aug.swap_rows(next_pivot, found);
        for r in 0..cols {
            if r != next_pivot && aug.get(r, col) {
                aug.xor_row(r, next_pivot);
            }
        }
        next_pivot += 1;
    }

    // The last word of the Mᵀ half may share bits with the identity half; mask them off.
    let left_words = rows.div_ceil(64);
    let tail_mask = match rows % 64 {
        0 => u64::MAX,
        bits => (1u64 << bits) - 1,
    };
    let kernel_words = cols.div_ceil(64);

    (0..cols)
        .filter(|&r| {
            aug.row(r)[..left_words]
                .iter()
                .enumerate()
                .all(|(w, &val)| {
                    let mask = if w + 1 == left_words { tail_mask } else { u64::MAX };
                    val & mask == 0
                })
        })
        .map(|r| {
            let mut kv = vec![0u64; kernel_words];
            for c in (0..cols).filter(|&c| aug.get(r, rows + c)) {
                kv[c / 64] |= 1u64 << (c % 64);
            }
            kv
        })
        .collect()
}
152
153/// Error chain complex for a Clifford circuit with noise.
154///
155/// Represents the chain complex C₂ →∂₂→ C₁ →∂₁→ C₀ where:
156/// - C₀ = F₂^m (measurement/detector space)
157/// - C₁ = F₂^p (error location space)
158/// - C₂ = F₂^s (stabilizer space)
159/// - ∂₁ = E (the m×p error propagation matrix)
160pub struct ErrorChainComplex {
161    /// E-matrix: m × p binary matrix. E[d][e] = 1 if error e flips measurement d.
162    e_matrix: F2DenseMatrix,
163    /// Error probabilities: p_total[e] = px + py + pz for error location e.
164    error_probs: Vec<f64>,
165    /// Number of measurements (detectors)
166    num_measurements: usize,
167    /// Number of error locations
168    num_errors: usize,
169    /// dim(im(∂₂) ∩ ker(∂₁)): stabilizer generators undetectable by measurements
170    boundary_dim: usize,
171    /// dim(H₁) = dim(ker(∂₁)/im(∂₂)): independent logical error classes
172    homology_dim: usize,
173}
174
175/// Precomputed sampler for O(r + 1) per-shot noisy measurement sampling.
176///
177/// Combines a compiled sampler (quantum randomness, O(r) per shot) with
178/// precomputed syndrome class probabilities (noise randomness, O(1) per shot).
179/// The syndrome classes are elements of im(E) ⊆ F₂^m where E is the
180/// error-to-measurement propagation matrix.
181pub struct HomologicalSampler {
182    /// Compiled sampler for quantum randomness (noiseless measurement distribution)
183    compiled: crate::sim::compiled::CompiledSampler,
184    /// Syndrome rank = dim(im(E))
185    syndrome_rank: usize,
186    /// 2^r class probabilities (for diagnostics)
187    #[allow(dead_code)]
188    class_probs: Vec<f64>,
189    /// 2^r cumulative probabilities for sampling
190    class_cdf: Vec<f64>,
191    /// 2^r detection signatures: for class c, which measurements are flipped.
192    /// Stored as packed u64 vectors, each of length ceil(m/64).
193    class_detections: Vec<Vec<u64>>,
194    /// dim(im(∂₂) ∩ ker(∂₁)): undetectable stabilizer generators
195    boundary_dim: usize,
196    /// dim(H₁ = ker(∂₁)/im(∂₂)): independent logical error classes
197    homology_dim: usize,
198    /// RNG for noise sampling
199    rng: ChaCha8Rng,
200}
201
202impl ErrorChainComplex {
203    /// Build the error chain complex from a Clifford circuit and noise model.
204    ///
205    /// Uses backward Pauli propagation (same as the compiled noisy sampler)
206    /// to determine which measurements are sensitive to each error location.
207    pub fn build(circuit: &Circuit, noise: &NoiseModel, _seed: u64) -> Result<Self> {
208        let m = circuit
209            .instructions
210            .iter()
211            .filter(|i| matches!(i, Instruction::Measure { .. }))
212            .count();
213        if m == 0 {
214            return Ok(Self {
215                e_matrix: F2DenseMatrix::new(0, 0),
216                error_probs: Vec::new(),
217                num_measurements: 0,
218                num_errors: 0,
219                boundary_dim: circuit.num_qubits,
220                homology_dim: 0,
221            });
222        }
223
224        let m_words = m.div_ceil(64);
225        let n = circuit.num_qubits;
226
227        let mut x_packed: Vec<Vec<u64>> = vec![vec![0u64; m_words]; n];
228        let mut z_packed: Vec<Vec<u64>> = vec![vec![0u64; m_words]; n];
229        let mut sign_packed = vec![0u64; m_words];
230
231        let mut meas_idx = m;
232        for instr in circuit.instructions.iter().rev() {
233            if let Instruction::Measure { qubit, .. } = instr {
234                meas_idx -= 1;
235                let word = meas_idx / 64;
236                let bit = meas_idx % 64;
237                z_packed[*qubit][word] |= 1u64 << bit;
238            }
239        }
240
241        let mut error_probs = Vec::new();
242        let mut e_cols: Vec<Vec<u64>> = Vec::new();
243
244        for (instr_idx, instr) in circuit.instructions.iter().enumerate().rev() {
245            match instr {
246                Instruction::Gate { gate, targets } => {
247                    let noise_events = &noise.after_gate[instr_idx];
248                    for event in noise_events {
249                        let (px, py, pz) = event.pauli_probs();
250                        let q = event.qubit();
251                        let p_total = px + py + pz;
252                        if p_total < 1e-15 {
253                            continue;
254                        }
255
256                        let x_sens = &z_packed[q];
257                        let z_sens = &x_packed[q];
258
259                        if px > 1e-15 && x_sens.iter().any(|&w| w != 0) {
260                            error_probs.push(px);
261                            e_cols.push(x_sens.clone());
262                        }
263
264                        if pz > 1e-15 && z_sens.iter().any(|&w| w != 0) {
265                            error_probs.push(pz);
266                            e_cols.push(z_sens.clone());
267                        }
268
269                        if py > 1e-15 {
270                            let mut y_sens = vec![0u64; m_words];
271                            for w in 0..m_words {
272                                y_sens[w] = x_sens[w] ^ z_sens[w];
273                            }
274                            if y_sens.iter().any(|&w| w != 0) {
275                                error_probs.push(py);
276                                e_cols.push(y_sens);
277                            }
278                        }
279                    }
280
281                    batch_propagate_backward(
282                        &mut x_packed,
283                        &mut z_packed,
284                        &mut sign_packed,
285                        gate,
286                        targets.as_slice(),
287                        m_words,
288                    );
289                }
290                Instruction::Measure { .. }
291                | Instruction::Reset { .. }
292                | Instruction::Barrier { .. } => {}
293                Instruction::Conditional { gate, targets, .. } => {
294                    batch_propagate_backward(
295                        &mut x_packed,
296                        &mut z_packed,
297                        &mut sign_packed,
298                        gate,
299                        targets.as_slice(),
300                        m_words,
301                    );
302                }
303            }
304        }
305
306        let p = error_probs.len();
307        let mut e_matrix = F2DenseMatrix::new(m, p);
308
309        for (col, col_data) in e_cols.iter().enumerate() {
310            for (w, &word) in col_data.iter().enumerate() {
311                if word == 0 {
312                    continue;
313                }
314                let base = w * 64;
315                let mut bits = word;
316                while bits != 0 {
317                    let bit = bits.trailing_zeros() as usize;
318                    let row = base + bit;
319                    if row < m {
320                        e_matrix.set(row, col);
321                    }
322                    bits &= bits - 1;
323                }
324            }
325        }
326
327        let (boundary_dim, homology_dim) = Self::compute_boundary_space(circuit, n);
328
329        Ok(Self {
330            e_matrix,
331            error_probs,
332            num_measurements: m,
333            num_errors: p,
334            boundary_dim,
335            homology_dim,
336        })
337    }
338
339    /// Forward-propagate stabilizer generators and compute ∂₂ boundary space.
340    ///
341    /// Returns (boundary_dim, homology_dim) where:
342    /// - boundary_dim = dim(im(∂₂) ∩ ker(∂₁)) = stabilizers undetectable by measurements
343    /// - homology_dim = dim(H₁) = independent logical error classes
344    ///
345    /// Algorithm: forward-propagate Z_0,...,Z_{n-1} through the circuit to get
346    /// output stabilizer generators. Build X-projection onto measured qubits.
347    /// rank(X_proj) counts stabilizers with detectable X components;
348    /// H₁ = ker(σ) / (S ∩ ker(σ)) has dim = n - num_measured + rank(X_proj).
349    fn compute_boundary_space(circuit: &Circuit, n: usize) -> (usize, usize) {
350        if n == 0 {
351            return (0, 0);
352        }
353
354        let n_words = n.div_ceil(64);
355        let mut stab_x: Vec<Vec<u64>> = vec![vec![0u64; n_words]; n];
356        let mut stab_z: Vec<Vec<u64>> = vec![vec![0u64; n_words]; n];
357        let mut stab_sign = vec![0u64; n_words];
358
359        for i in 0..n {
360            stab_z[i][i / 64] |= 1u64 << (i % 64);
361        }
362
363        for instr in circuit.instructions.iter() {
364            match instr {
365                Instruction::Gate { gate, targets } => {
366                    batch_propagate_backward(
367                        &mut stab_x,
368                        &mut stab_z,
369                        &mut stab_sign,
370                        gate,
371                        targets.as_slice(),
372                        n_words,
373                    );
374                }
375                Instruction::Conditional { gate, targets, .. } => {
376                    batch_propagate_backward(
377                        &mut stab_x,
378                        &mut stab_z,
379                        &mut stab_sign,
380                        gate,
381                        targets.as_slice(),
382                        n_words,
383                    );
384                }
385                _ => {}
386            }
387        }
388
389        let mut measured = vec![false; n];
390        for instr in &circuit.instructions {
391            if let Instruction::Measure { qubit, .. } = instr {
392                measured[*qubit] = true;
393            }
394        }
395        let num_measured = measured.iter().filter(|&&b| b).count();
396        let measured_indices: Vec<usize> = (0..n).filter(|&q| measured[q]).collect();
397
398        if num_measured == 0 {
399            return (n, 0);
400        }
401
402        let proj_words = num_measured.div_ceil(64);
403        let mut proj = vec![0u64; n * proj_words];
404
405        for stab_idx in 0..n {
406            for (proj_col, &q) in measured_indices.iter().enumerate() {
407                let x_bit = (stab_x[q][stab_idx / 64] >> (stab_idx % 64)) & 1;
408                if x_bit != 0 {
409                    proj[stab_idx * proj_words + proj_col / 64] |= 1u64 << (proj_col % 64);
410                }
411            }
412        }
413
414        let mut rank = 0;
415        let mut pivot_row = 0;
416        for col in 0..num_measured {
417            let mut found = None;
418            for r in pivot_row..n {
419                if (proj[r * proj_words + col / 64] >> (col % 64)) & 1 != 0 {
420                    found = Some(r);
421                    break;
422                }
423            }
424            let Some(pr) = found else { continue };
425
426            if pr != pivot_row {
427                for w in 0..proj_words {
428                    proj.swap(pivot_row * proj_words + w, pr * proj_words + w);
429                }
430            }
431
432            for r in 0..n {
433                if r != pivot_row && (proj[r * proj_words + col / 64] >> (col % 64)) & 1 != 0 {
434                    for w in 0..proj_words {
435                        proj[r * proj_words + w] ^= proj[pivot_row * proj_words + w];
436                    }
437                }
438            }
439
440            pivot_row += 1;
441            rank += 1;
442        }
443
444        let boundary_dim = n - rank;
445        let homology_dim = n - num_measured + rank;
446        (boundary_dim, homology_dim)
447    }
448
449    pub fn boundary_dim(&self) -> usize {
450        self.boundary_dim
451    }
452
453    pub fn homology_dim(&self) -> usize {
454        self.homology_dim
455    }
456
457    /// Compute exact noisy marginals analytically. No sampling, no rank limit.
458    ///
459    /// For each measurement j, the noisy probability is:
460    ///   p_j^noisy = p_j + (1 - 2·p_j) · (1 - f_j) / 2
461    /// where f_j = Π_{e: E(j,e)=1} (1 - 2·p_e) is the flip attenuation factor
462    /// and p_j is the noiseless marginal (0, 0.5, or 1).
463    ///
464    /// Cost: O(nnz(E)). Works for any qubit count.
465    pub fn noisy_marginals(&self, noiseless_marginals: &[f64]) -> Vec<f64> {
466        let m = self.num_measurements;
467        let p = self.num_errors;
468        if m == 0 || p == 0 {
469            return noiseless_marginals.to_vec();
470        }
471
472        let mut flip_factor = vec![1.0f64; m];
473        let rw = self.e_matrix.row_words;
474
475        for e in 0..p {
476            let factor = 1.0 - 2.0 * self.error_probs[e];
477            if (factor - 1.0).abs() < 1e-15 {
478                continue;
479            }
480
481            let col_word = e / 64;
482            let col_bit = 1u64 << (e % 64);
483
484            for (j, ff) in flip_factor.iter_mut().enumerate() {
485                if self.e_matrix.data[j * rw + col_word] & col_bit != 0 {
486                    *ff *= factor;
487                }
488            }
489        }
490
491        let mut result = Vec::with_capacity(m);
492        for j in 0..m {
493            let p_j = noiseless_marginals[j];
494            let p_flip = (1.0 - flip_factor[j]) / 2.0;
495            result.push(p_j + (1.0 - 2.0 * p_j) * p_flip);
496        }
497        result
498    }
499}
500
501impl HomologicalSampler {
502    /// Build a sampler from a circuit and noise model.
503    ///
504    /// Computes the E-matrix (error-to-measurement propagation), finds a basis
505    /// for im(E), and precomputes 2^r syndrome class probabilities where
506    /// r = rank(E). Also builds a compiled sampler for quantum randomness.
507    ///
508    /// Total per-shot cost: O(r_quantum + 1) where r_quantum is the stabilizer
509    /// rank (number of random measurements), versus O(p) for brute-force
510    /// where p is the number of error locations.
511    pub fn compile(circuit: &Circuit, noise: &NoiseModel, seed: u64) -> Result<Self> {
512        let ecc = ErrorChainComplex::build(circuit, noise, seed)?;
513        let m = ecc.num_measurements;
514        let p = ecc.num_errors;
515        let compiled = crate::sim::compiled::compile_measurements(circuit, seed)?;
516
517        if m == 0 || p == 0 {
518            return Ok(Self {
519                compiled,
520                syndrome_rank: 0,
521                class_probs: vec![1.0],
522                class_cdf: vec![1.0],
523                class_detections: vec![vec![0u64; m.div_ceil(64)]],
524                boundary_dim: ecc.boundary_dim,
525                homology_dim: ecc.homology_dim,
526                rng: ChaCha8Rng::seed_from_u64(seed),
527            });
528        }
529
530        let m_words = m.div_ceil(64);
531
532        let mut work = ecc.e_matrix.data.clone();
533        let rw = ecc.e_matrix.row_words;
534        let mut pivot_cols = Vec::new();
535        let mut pivot_row = 0;
536
537        for col in 0..p {
538            let mut found = None;
539            for r in pivot_row..m {
540                if (work[r * rw + col / 64] >> (col % 64)) & 1 != 0 {
541                    found = Some(r);
542                    break;
543                }
544            }
545            let Some(pr) = found else { continue };
546
547            if pr != pivot_row {
548                for w in 0..rw {
549                    work.swap(pivot_row * rw + w, pr * rw + w);
550                }
551            }
552
553            for r in 0..m {
554                if r != pivot_row && (work[r * rw + col / 64] >> (col % 64)) & 1 != 0 {
555                    for w in 0..rw {
556                        work[r * rw + w] ^= work[pivot_row * rw + w];
557                    }
558                }
559            }
560
561            pivot_cols.push(col);
562            pivot_row += 1;
563        }
564
565        let r = pivot_cols.len();
566        if r > 20 {
567            return Err(crate::error::PrismError::IncompatibleBackend {
568                backend: "HomologicalSampler".to_string(),
569                reason: format!("syndrome rank {} too large (max 20)", r),
570            });
571        }
572
573        // Extract r-bit coordinates from RREF: col j's coordinate at basis i
574        // is work[i][j] in the reduced matrix.
575        let mut col_coords = vec![0usize; p];
576        for (basis_idx, &_pivot_col) in pivot_cols.iter().enumerate() {
577            for j in 0..p {
578                if (work[basis_idx * rw + j / 64] >> (j % 64)) & 1 != 0 {
579                    col_coords[j] |= 1 << basis_idx;
580                }
581            }
582        }
583
584        let num_classes = 1usize << r;
585        let mut class_detections = Vec::with_capacity(num_classes);
586        for c in 0..num_classes {
587            let mut det = vec![0u64; m_words];
588            for (basis_idx, &pivot_col) in pivot_cols.iter().enumerate() {
589                if (c >> basis_idx) & 1 != 0 {
590                    for row in 0..m {
591                        if ecc.e_matrix.get(row, pivot_col) {
592                            det[row / 64] ^= 1u64 << (row % 64);
593                        }
594                    }
595                }
596            }
597            class_detections.push(det);
598        }
599
600        // F₂^r probability convolution: P[c] = (1-p_j) P[c] + p_j P[c ⊕ coord_j]
601        let mut class_probs = vec![0.0_f64; num_classes];
602        class_probs[0] = 1.0;
603
604        for (j, &coord) in col_coords.iter().enumerate() {
605            let pj = ecc.error_probs[j];
606            if pj < 1e-15 {
607                continue;
608            }
609            if coord == 0 {
610                continue;
611            }
612            let mut new_probs = vec![0.0_f64; num_classes];
613            for c in 0..num_classes {
614                new_probs[c] = (1.0 - pj) * class_probs[c] + pj * class_probs[c ^ coord];
615            }
616            class_probs = new_probs;
617        }
618
619        let mut class_cdf = vec![0.0_f64; num_classes];
620        class_cdf[0] = class_probs[0];
621        for c in 1..num_classes {
622            class_cdf[c] = class_cdf[c - 1] + class_probs[c];
623        }
624        let total = class_cdf[num_classes - 1];
625        if total > 0.0 {
626            for v in &mut class_cdf {
627                *v /= total;
628            }
629        }
630
631        Ok(Self {
632            compiled,
633            syndrome_rank: r,
634            class_probs,
635            class_cdf,
636            class_detections,
637            boundary_dim: ecc.boundary_dim,
638            homology_dim: ecc.homology_dim,
639            rng: ChaCha8Rng::seed_from_u64(seed),
640        })
641    }
642
643    pub fn syndrome_rank(&self) -> usize {
644        self.syndrome_rank
645    }
646
647    pub fn boundary_dim(&self) -> usize {
648        self.boundary_dim
649    }
650
651    pub fn homology_dim(&self) -> usize {
652        self.homology_dim
653    }
654
655    /// Sample a single shot: returns measurement outcomes.
656    ///
657    /// Cost: O(r_quantum) for compiled sampler + O(1) for noise class lookup.
658    pub fn sample(&mut self) -> Vec<bool> {
659        let mut outcome = self.compiled.sample();
660
661        let u: f64 = rand::Rng::random(&mut self.rng);
662        let class = match self
663            .class_cdf
664            .binary_search_by(|p| p.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal))
665        {
666            Ok(i) => i,
667            Err(i) => i.min(self.class_cdf.len() - 1),
668        };
669
670        let det = &self.class_detections[class];
671        for (mi, bit) in outcome.iter_mut().enumerate() {
672            let det_bit = (det[mi / 64] >> (mi % 64)) & 1 != 0;
673            *bit ^= det_bit;
674        }
675        outcome
676    }
677
678    /// Sample multiple shots.
679    pub fn sample_bulk(&mut self, num_shots: usize) -> Vec<Vec<bool>> {
680        (0..num_shots).map(|_| self.sample()).collect()
681    }
682
683    pub fn sample_packed(&mut self, num_shots: usize) -> PackedShots {
684        let m = self.compiled.num_measurements();
685        let m_words = m.div_ceil(64);
686        if num_shots == 0 || m == 0 {
687            return PackedShots::from_shot_major(Vec::new(), num_shots, m);
688        }
689
690        let mut accum = Vec::new();
691        let mut rand_buf = Vec::new();
692        self.compiled
693            .sample_bulk_words_shot_major_reuse(&mut accum, &mut rand_buf, num_shots);
694
695        let ref_bits = self.compiled.ref_bits_packed();
696        for s in 0..num_shots {
697            let base = s * m_words;
698            xor_words(&mut accum[base..base + m_words], ref_bits);
699        }
700
701        for s in 0..num_shots {
702            let u: f64 = rand::Rng::random(&mut self.rng);
703            let class = match self
704                .class_cdf
705                .binary_search_by(|p| p.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal))
706            {
707                Ok(i) => i,
708                Err(i) => i.min(self.class_cdf.len() - 1),
709            };
710
711            let det = &self.class_detections[class];
712            let base = s * m_words;
713            xor_words(&mut accum[base..base + m_words], det);
714        }
715
716        PackedShots::from_shot_major(accum, num_shots, m)
717    }
718
719    pub fn sample_chunked<A: ShotAccumulator>(&mut self, total_shots: usize, acc: &mut A) {
720        let chunk_size = default_chunk_size(self.compiled.num_measurements());
721        let mut remaining = total_shots;
722        while remaining > 0 {
723            let batch = remaining.min(chunk_size);
724            let packed = self.sample_packed(batch);
725            acc.accumulate(&packed);
726            remaining -= batch;
727        }
728    }
729
730    pub fn sample_marginals(&mut self, total_shots: usize) -> Vec<f64> {
731        let mut acc =
732            crate::sim::compiled::MarginalsAccumulator::new(self.compiled.num_measurements());
733        self.sample_chunked(total_shots, &mut acc);
734        acc.marginals()
735    }
736}
737
738/// Run noisy shot sampling using the homological sampler.
739///
740/// For Clifford circuits where the homology dimension h is small (≤ 20),
741/// precomputes class probabilities and samples in O(1) per shot.
742pub fn run_shots_homological(
743    circuit: &Circuit,
744    noise: &NoiseModel,
745    num_shots: usize,
746    seed: u64,
747) -> Result<ShotsResult> {
748    let sampler = HomologicalSampler::compile(circuit, noise, seed)?;
749    run_shots_homological_inner(sampler, circuit, num_shots)
750}
751
752pub(crate) fn run_shots_homological_inner(
753    mut sampler: HomologicalSampler,
754    circuit: &Circuit,
755    num_shots: usize,
756) -> Result<ShotsResult> {
757    let classical_bit_order: Vec<usize> = circuit
758        .instructions
759        .iter()
760        .filter_map(|inst| match inst {
761            Instruction::Measure { classical_bit, .. } => Some(*classical_bit),
762            _ => None,
763        })
764        .collect();
765    let num_classical = circuit.num_classical_bits;
766
767    let raw_shots = sampler.sample_bulk(num_shots);
768
769    let mut shots = Vec::with_capacity(num_shots);
770    for raw in &raw_shots {
771        let mut out = vec![false; num_classical];
772        for (mi, &cbit) in classical_bit_order.iter().enumerate() {
773            if cbit < num_classical {
774                out[cbit] = raw[mi];
775            }
776        }
777        shots.push(out);
778    }
779
780    Ok(ShotsResult {
781        shots,
782        num_classical_bits: circuit.num_classical_bits,
783    })
784}
785
786/// Compute exact noisy marginals analytically. No sampling, no rank limit.
787///
788/// Builds the error chain complex and compiled sampler, then computes
789/// exact per-measurement noisy probabilities in O(nnz(E)) time.
790/// Works for any qubit count, not limited by syndrome rank.
791pub fn noisy_marginals_analytical(
792    circuit: &Circuit,
793    noise: &NoiseModel,
794    seed: u64,
795) -> Result<Vec<f64>> {
796    let ecc = ErrorChainComplex::build(circuit, noise, seed)?;
797    let compiled = crate::sim::compiled::compile_measurements(circuit, seed)?;
798    let noiseless = compiled.marginal_probabilities();
799    let noisy = ecc.noisy_marginals(&noiseless);
800
801    let classical_bit_order: Vec<usize> = circuit
802        .instructions
803        .iter()
804        .filter_map(|inst| match inst {
805            Instruction::Measure { classical_bit, .. } => Some(*classical_bit),
806            _ => None,
807        })
808        .collect();
809    let num_classical = circuit.num_classical_bits;
810
811    let mut result = vec![0.5f64; num_classical];
812    for (mi, &cbit) in classical_bit_order.iter().enumerate() {
813        if cbit < num_classical && mi < noisy.len() {
814            result[cbit] = noisy[mi];
815        }
816    }
817    Ok(result)
818}
819
820#[cfg(test)]
821mod tests {
822    use super::*;
823    use crate::circuits;
824
825    #[test]
826    fn gf2_kernel_identity() {
827        // Identity matrix: kernel is trivial (empty)
828        let mut m = F2DenseMatrix::new(3, 3);
829        m.set(0, 0);
830        m.set(1, 1);
831        m.set(2, 2);
832        let k = gf2_kernel(&m);
833        assert!(k.is_empty(), "Identity matrix should have trivial kernel");
834    }
835
836    #[test]
837    fn gf2_kernel_zero_matrix() {
838        // Zero matrix: kernel is the full space
839        let m = F2DenseMatrix::new(3, 4);
840        let k = gf2_kernel(&m);
841        assert_eq!(k.len(), 4, "Zero 3×4 matrix should have 4-dim kernel");
842    }
843
844    #[test]
845    fn gf2_kernel_rank_deficient() {
846        // [1 1 0]
847        // [0 1 1]
848        // Row 0 + Row 1 = [1 0 1], so rank = 2, kernel dim = 3 - 2 = 1
849        let mut m = F2DenseMatrix::new(2, 3);
850        m.set(0, 0);
851        m.set(0, 1);
852        m.set(1, 1);
853        m.set(1, 2);
854        let k = gf2_kernel(&m);
855        assert_eq!(k.len(), 1, "rank-2 2×3 matrix should have 1-dim kernel");
856        // Kernel vector should be [1, 1, 1] (x₀ = x₁ = x₂)
857        // Row 0: x₀ + x₁ = 0 → x₀ = x₁
858        // Row 1: x₁ + x₂ = 0 → x₁ = x₂
859        let kv = &k[0];
860        assert_eq!(kv[0] & 0b111, 0b111, "kernel vector should be [1,1,1]");
861    }
862
863    #[test]
864    fn gf2_kernel_verifies() {
865        // Verify Mx = 0 for all kernel vectors
866        let mut m = F2DenseMatrix::new(3, 5);
867        // Some arbitrary matrix
868        m.set(0, 0);
869        m.set(0, 2);
870        m.set(0, 4);
871        m.set(1, 1);
872        m.set(1, 3);
873        m.set(2, 0);
874        m.set(2, 1);
875        m.set(2, 2);
876
877        let k = gf2_kernel(&m);
878        for kv in &k {
879            // Check M · kv = 0
880            for r in 0..3 {
881                let mut dot = 0u32;
882                for c in 0..5 {
883                    if m.get(r, c) && (kv[c / 64] >> (c % 64)) & 1 != 0 {
884                        dot ^= 1;
885                    }
886                }
887                assert_eq!(dot, 0, "kernel vector should satisfy Mx = 0");
888            }
889        }
890    }
891
892    #[test]
893    fn homological_ghz_compiles() {
894        let n = 6;
895        let mut circuit = circuits::ghz_circuit(n);
896        circuit.num_classical_bits = n;
897        for i in 0..n {
898            circuit.add_measure(i, i);
899        }
900        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
901        let sampler = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
902        assert!(sampler.syndrome_rank() <= n, "syndrome rank should be ≤ n");
903    }
904
905    #[test]
906    fn homological_ghz_samples() {
907        let n = 6;
908        let mut circuit = circuits::ghz_circuit(n);
909        circuit.num_classical_bits = n;
910        for i in 0..n {
911            circuit.add_measure(i, i);
912        }
913        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
914        let mut sampler = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
915        let shots = sampler.sample_bulk(1000);
916        assert_eq!(shots.len(), 1000);
917        assert_eq!(shots[0].len(), n);
918    }
919
920    #[test]
921    fn homological_bell_pairs() {
922        let n = 4;
923        let mut circuit = circuits::independent_bell_pairs(n / 2);
924        circuit.num_classical_bits = n;
925        for i in 0..n {
926            circuit.add_measure(i, i);
927        }
928        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
929        let sampler = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
930        // Bell pairs with noise should have non-trivial syndrome rank
931        assert!(sampler.syndrome_rank() > 0);
932    }
933
934    #[test]
935    fn homological_class_probs_sum_to_one() {
936        let n = 6;
937        let mut circuit = circuits::ghz_circuit(n);
938        circuit.num_classical_bits = n;
939        for i in 0..n {
940            circuit.add_measure(i, i);
941        }
942        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
943        let sampler = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
944        let sum: f64 = sampler.class_probs.iter().sum();
945        assert!(
946            (sum - 1.0).abs() < 1e-10,
947            "class probabilities should sum to 1, got {sum}"
948        );
949    }
950
951    #[test]
952    fn homological_matches_brute_force_statistics() {
953        let n = 4;
954        let mut circuit = circuits::ghz_circuit(n);
955        circuit.num_classical_bits = n;
956        for i in 0..n {
957            circuit.add_measure(i, i);
958        }
959        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.05);
960        let num_shots = 10000;
961
962        // Homological sampler
963        let homo_result = run_shots_homological(&circuit, &noise, num_shots, 42).unwrap();
964
965        // Brute-force sampler
966        let brute_result =
967            crate::sim::noise::run_shots_noisy(&circuit, &noise, num_shots, 42).unwrap();
968
969        // Compare per-bit marginal probabilities
970        let m = n;
971        for bit in 0..m {
972            let homo_ones: usize = homo_result.shots.iter().filter(|s| s[bit]).count();
973            let brute_ones: usize = brute_result.shots.iter().filter(|s| s[bit]).count();
974            let homo_p = homo_ones as f64 / num_shots as f64;
975            let brute_p = brute_ones as f64 / num_shots as f64;
976            let diff = (homo_p - brute_p).abs();
977            assert!(
978                diff < 0.05,
979                "bit {bit}: homological p={homo_p:.4}, brute p={brute_p:.4}, diff={diff:.4}"
980            );
981        }
982    }
983
984    #[test]
985    fn boundary_trivial_circuit_has_zero_homology() {
986        let n = 4;
987        let mut circuit = crate::circuit::Circuit::new(n, n);
988        for i in 0..n {
989            circuit.add_measure(i, i);
990        }
991        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
992        let ecc = ErrorChainComplex::build(&circuit, &noise, 42).unwrap();
993        assert_eq!(ecc.boundary_dim(), n);
994        assert_eq!(ecc.homology_dim(), 0);
995    }
996
997    #[test]
998    fn boundary_ghz_has_one_logical_qubit() {
999        for n in [3, 5, 8] {
1000            let mut circuit = circuits::ghz_circuit(n);
1001            circuit.num_classical_bits = n;
1002            for i in 0..n {
1003                circuit.add_measure(i, i);
1004            }
1005            let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1006            let ecc = ErrorChainComplex::build(&circuit, &noise, 42).unwrap();
1007            assert_eq!(
1008                ecc.homology_dim(),
1009                1,
1010                "GHZ-{n} should have 1 logical error class"
1011            );
1012            assert_eq!(ecc.boundary_dim(), n - 1);
1013        }
1014    }
1015
1016    #[test]
1017    fn boundary_bell_pair_has_one_logical() {
1018        let mut circuit = crate::circuit::Circuit::new(2, 2);
1019        circuit.add_gate(crate::gates::Gate::H, &[0]);
1020        circuit.add_gate(crate::gates::Gate::Cx, &[0, 1]);
1021        circuit.add_measure(0, 0);
1022        circuit.add_measure(1, 1);
1023        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1024        let ecc = ErrorChainComplex::build(&circuit, &noise, 42).unwrap();
1025        assert_eq!(ecc.homology_dim(), 1);
1026        assert_eq!(ecc.boundary_dim(), 1);
1027    }
1028
1029    #[test]
1030    fn boundary_independent_bell_pairs() {
1031        let n_pairs = 3;
1032        let n = n_pairs * 2;
1033        let mut circuit = circuits::independent_bell_pairs(n_pairs);
1034        circuit.num_classical_bits = n;
1035        for i in 0..n {
1036            circuit.add_measure(i, i);
1037        }
1038        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1039        let ecc = ErrorChainComplex::build(&circuit, &noise, 42).unwrap();
1040        assert_eq!(
1041            ecc.homology_dim(),
1042            n_pairs,
1043            "{n_pairs} bell pairs should have {n_pairs} logical error classes"
1044        );
1045    }
1046
1047    #[test]
1048    fn boundary_exposed_via_sampler() {
1049        let n = 4;
1050        let mut circuit = circuits::ghz_circuit(n);
1051        circuit.num_classical_bits = n;
1052        for i in 0..n {
1053            circuit.add_measure(i, i);
1054        }
1055        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1056        let sampler = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1057        assert_eq!(sampler.homology_dim(), 1);
1058        assert_eq!(sampler.boundary_dim(), n - 1);
1059    }
1060
1061    #[test]
1062    fn boundary_partial_measurement() {
1063        let mut circuit = crate::circuit::Circuit::new(3, 1);
1064        circuit.add_gate(crate::gates::Gate::H, &[0]);
1065        circuit.add_gate(crate::gates::Gate::Cx, &[0, 1]);
1066        circuit.add_measure(0, 0);
1067        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1068        let ecc = ErrorChainComplex::build(&circuit, &noise, 42).unwrap();
1069        // 3 qubits, 1 measured: ker(σ) has dim 2*3-1=5
1070        // Stabilizers: X₀X₁, Z₀Z₁, Z₂ (3 generators)
1071        // X-projection on qubit 0: X₀X₁ has X on q0 → rank(A) = 1
1072        // boundary_dim = 3-1 = 2, homology_dim = 3-1+1 = 3
1073        assert_eq!(ecc.boundary_dim(), 2);
1074        assert_eq!(ecc.homology_dim(), 3);
1075    }
1076
1077    #[test]
1078    fn packed_matches_unpacked() {
1079        let n = 6;
1080        let mut circuit = circuits::ghz_circuit(n);
1081        circuit.num_classical_bits = n;
1082        for i in 0..n {
1083            circuit.add_measure(i, i);
1084        }
1085        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
1086
1087        let mut s1 = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1088        let mut s2 = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1089
1090        let unpacked = s1.sample_bulk(500);
1091        let packed = s2.sample_packed(500);
1092
1093        assert_eq!(packed.num_shots(), 500);
1094        assert_eq!(packed.num_measurements(), n);
1095
1096        for (s, shot) in unpacked.iter().enumerate() {
1097            for (m, &val) in shot.iter().enumerate() {
1098                assert_eq!(packed.get_bit(s, m), val, "mismatch at shot={s} meas={m}");
1099            }
1100        }
1101    }
1102
1103    #[test]
1104    fn marginals_matches_unpacked() {
1105        let n = 6;
1106        let mut circuit = circuits::ghz_circuit(n);
1107        circuit.num_classical_bits = n;
1108        for i in 0..n {
1109            circuit.add_measure(i, i);
1110        }
1111        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
1112
1113        let mut s1 = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1114        let mut s2 = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1115
1116        let num_shots = 10_000;
1117        let unpacked = s1.sample_bulk(num_shots);
1118        let marginals = s2.sample_marginals(num_shots);
1119
1120        assert_eq!(marginals.len(), n);
1121        for m in 0..n {
1122            let unpacked_p = unpacked.iter().filter(|s| s[m]).count() as f64 / num_shots as f64;
1123            assert!(
1124                (marginals[m] - unpacked_p).abs() < 1e-10,
1125                "marginal mismatch at meas={m}: packed={}, unpacked={unpacked_p}",
1126                marginals[m],
1127            );
1128        }
1129    }
1130
1131    #[test]
1132    fn analytical_marginals_match_sampled_small() {
1133        let n = 6;
1134        let mut circuit = circuits::ghz_circuit(n);
1135        circuit.num_classical_bits = n;
1136        for i in 0..n {
1137            circuit.add_measure(i, i);
1138        }
1139        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
1140
1141        let analytical = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1142
1143        let mut sampler = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1144        let sampled = sampler.sample_marginals(100_000);
1145
1146        assert_eq!(analytical.len(), n);
1147        assert_eq!(sampled.len(), n);
1148        for i in 0..n {
1149            assert!(
1150                (analytical[i] - sampled[i]).abs() < 0.01,
1151                "bit {i}: analytical={:.6}, sampled={:.6}",
1152                analytical[i],
1153                sampled[i],
1154            );
1155        }
1156    }
1157
1158    #[test]
1159    fn analytical_marginals_ghz_50q() {
1160        let n = 50;
1161        let mut circuit = circuits::ghz_circuit(n);
1162        circuit.num_classical_bits = n;
1163        for i in 0..n {
1164            circuit.add_measure(i, i);
1165        }
1166        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1167
1168        let marginals = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1169        assert_eq!(marginals.len(), n);
1170        for (i, &p) in marginals.iter().enumerate() {
1171            assert!(p > 0.0 && p < 1.0, "bit {i}: marginal {p} out of range");
1172            assert!(
1173                (p - 0.5).abs() < 0.05,
1174                "bit {i}: GHZ marginal should be near 0.5, got {p}"
1175            );
1176        }
1177    }
1178
1179    #[test]
1180    fn analytical_marginals_ghz_100q() {
1181        let n = 100;
1182        let mut circuit = circuits::ghz_circuit(n);
1183        circuit.num_classical_bits = n;
1184        for i in 0..n {
1185            circuit.add_measure(i, i);
1186        }
1187        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1188
1189        let marginals = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1190        assert_eq!(marginals.len(), n);
1191        for (i, &p) in marginals.iter().enumerate() {
1192            assert!(p > 0.0 && p < 1.0, "bit {i}: marginal {p} out of range");
1193        }
1194    }
1195
1196    #[test]
1197    fn analytical_marginals_bell_pairs_100q() {
1198        let n_pairs = 50;
1199        let n = n_pairs * 2;
1200        let mut circuit = circuits::independent_bell_pairs(n_pairs);
1201        circuit.num_classical_bits = n;
1202        for i in 0..n {
1203            circuit.add_measure(i, i);
1204        }
1205        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1206
1207        let marginals = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1208        assert_eq!(marginals.len(), n);
1209        for (i, &p) in marginals.iter().enumerate() {
1210            assert!(
1211                (p - 0.5).abs() < 0.05,
1212                "bit {i}: bell pair marginal should be near 0.5, got {p}"
1213            );
1214        }
1215    }
1216
1217    #[test]
1218    fn analytical_marginals_clifford_1000q() {
1219        let n = 1000;
1220        let mut circuit = circuits::clifford_heavy_circuit(n, 2, 42);
1221        circuit.num_classical_bits = n;
1222        for i in 0..n {
1223            circuit.add_measure(i, i);
1224        }
1225        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.001);
1226
1227        let marginals = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1228        assert_eq!(marginals.len(), n);
1229        for (i, &p) in marginals.iter().enumerate() {
1230            assert!(
1231                (0.0..=1.0).contains(&p),
1232                "bit {i}: marginal {p} out of range"
1233            );
1234        }
1235    }
1236
1237    #[test]
1238    fn analytical_marginals_deterministic_bits() {
1239        let mut circuit = crate::circuit::Circuit::new(4, 4);
1240        for i in 0..4 {
1241            circuit.add_gate(crate::gates::Gate::X, &[i]);
1242        }
1243        for i in 0..4 {
1244            circuit.add_measure(i, i);
1245        }
1246        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
1247
1248        let marginals = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1249        for (i, &p) in marginals.iter().enumerate() {
1250            assert!(
1251                p > 0.95,
1252                "bit {i}: X-then-measure should give p(1) near 1.0, got {p}"
1253            );
1254        }
1255    }
1256
1257    #[test]
1258    fn analytical_marginals_no_noise() {
1259        let n = 6;
1260        let mut circuit = circuits::ghz_circuit(n);
1261        circuit.num_classical_bits = n;
1262        for i in 0..n {
1263            circuit.add_measure(i, i);
1264        }
1265        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.0);
1266
1267        let marginals = noisy_marginals_analytical(&circuit, &noise, 42).unwrap();
1268        for (i, &p) in marginals.iter().enumerate() {
1269            assert!(
1270                (p - 0.5).abs() < 1e-10,
1271                "bit {i}: GHZ with no noise should have marginal 0.5, got {p}"
1272            );
1273        }
1274    }
1275
1276    #[test]
1277    fn chunked_accumulator_matches_packed() {
1278        let n = 6;
1279        let mut circuit = circuits::ghz_circuit(n);
1280        circuit.num_classical_bits = n;
1281        for i in 0..n {
1282            circuit.add_measure(i, i);
1283        }
1284        let noise = NoiseModel::uniform_depolarizing(&circuit, 0.01);
1285
1286        let mut s1 = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1287        let mut s2 = HomologicalSampler::compile(&circuit, &noise, 42).unwrap();
1288
1289        let num_shots = 5_000;
1290        let packed = s1.sample_packed(num_shots);
1291        let direct_counts = packed.counts();
1292
1293        let mut acc = super::super::compiled::HistogramAccumulator::new();
1294        s2.sample_chunked(num_shots, &mut acc);
1295        let chunked_counts = acc.into_counts();
1296
1297        assert_eq!(direct_counts, chunked_counts);
1298    }
1299}