Skip to main content

prism_q/backend/
sparse.rs

//! Sparse state-vector simulation backend.
//!
//! Stores only non-zero amplitudes in a `HashMap<usize, Complex64>`, giving O(k) memory,
//! where k is the number of non-zero basis states. After each gate that can create new
//! entries, amplitudes whose squared magnitude falls below an epsilon threshold are pruned
//! to maintain sparsity.
//!
//! # When to prefer this backend
//!
//! - States with few non-zero amplitudes (computational basis states, limited superposition).
//! - Large qubit counts where the state stays sparse throughout the circuit (basis-state
//!   indices are `usize` bitmasks, so the register is limited to `usize::BITS` qubits).
//! - Classical-like circuits with limited branching.
//!
//! # When NOT to use this backend
//!
//! - After a layer of Hadamard gates (the state becomes maximally dense).
//! - Small qubit counts, where a dense state vector is faster because it avoids `HashMap`
//!   overhead.

18use std::collections::HashMap;
19
20use num_complex::Complex64;
21use rand::Rng;
22use rand::SeedableRng;
23use rand_chacha::ChaCha8Rng;
24
25#[cfg(feature = "parallel")]
26use rayon::prelude::*;
27
/// Minimum number of stored amplitudes before the reset/measure probability sums
/// switch from a serial iterator to a rayon parallel iterator.
#[cfg(feature = "parallel")]
const MIN_STATES_FOR_PAR: usize = 4_096;
30
31use crate::backend::{is_phase_one, Backend, MAX_PROB_QUBITS};
32use crate::circuit::Instruction;
33use crate::error::{PrismError, Result};
34use crate::gates::{DiagEntry, Gate};
35
/// Default pruning threshold. It is compared against each amplitude's *squared*
/// magnitude, so amplitudes with magnitude below roughly 1e-8 are dropped.
const DEFAULT_EPSILON: f64 = 1.0e-16;
37
38/// Sparse state-vector backend, O(k) where k is the number of non-zero amplitudes.
39pub struct SparseBackend {
40    num_qubits: usize,
41    state: HashMap<usize, Complex64>,
42    swap_buf: HashMap<usize, Complex64>,
43    classical_bits: Vec<bool>,
44    rng: ChaCha8Rng,
45    epsilon: f64,
46}
47
48impl SparseBackend {
49    /// Create a new sparse backend with the given RNG seed.
50    pub fn new(seed: u64) -> Self {
51        Self {
52            num_qubits: 0,
53            state: HashMap::new(),
54            swap_buf: HashMap::new(),
55            classical_bits: Vec::new(),
56            rng: ChaCha8Rng::seed_from_u64(seed),
57            epsilon: DEFAULT_EPSILON,
58        }
59    }
60
61    #[inline(always)]
62    fn prune(&mut self) {
63        let eps = self.epsilon;
64        self.state.retain(|_, amp| amp.norm_sqr() >= eps);
65    }
66
67    #[inline(always)]
68    fn apply_single_qubit(&mut self, target: usize, mat: [[Complex64; 2]; 2]) {
69        let mask = 1usize << target;
70        let zero = Complex64::new(0.0, 0.0);
71        self.swap_buf.clear();
72        self.swap_buf.reserve(self.state.len() * 2);
73
74        for (&idx, &amp) in &self.state {
75            let bit = (idx >> target) & 1;
76            let partner = idx ^ mask;
77
78            *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
79            *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
80        }
81
82        std::mem::swap(&mut self.state, &mut self.swap_buf);
83        self.prune();
84    }
85
86    /// CX is a deterministic 1:1 index mapping. No near-zero amplitudes are created.
87    #[inline(always)]
88    fn apply_cx(&mut self, control: usize, target: usize) {
89        let ctrl_mask = 1usize << control;
90        let tgt_mask = 1usize << target;
91        self.swap_buf.clear();
92        self.swap_buf.reserve(self.state.len());
93        self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
94            if idx & ctrl_mask != 0 {
95                (idx ^ tgt_mask, amp)
96            } else {
97                (idx, amp)
98            }
99        }));
100        std::mem::swap(&mut self.state, &mut self.swap_buf);
101    }
102
103    #[inline(always)]
104    fn apply_cz(&mut self, q0: usize, q1: usize) {
105        let mask0 = 1usize << q0;
106        let mask1 = 1usize << q1;
107        for (&idx, amp) in self.state.iter_mut() {
108            if idx & mask0 != 0 && idx & mask1 != 0 {
109                *amp = -*amp;
110            }
111        }
112    }
113
114    #[inline(always)]
115    fn apply_swap(&mut self, q0: usize, q1: usize) {
116        let m0 = 1usize << q0;
117        let m1 = 1usize << q1;
118        self.swap_buf.clear();
119        self.swap_buf.reserve(self.state.len());
120        self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
121            let bit0 = (idx >> q0) & 1;
122            let bit1 = (idx >> q1) & 1;
123            if bit0 != bit1 {
124                (idx ^ m0 ^ m1, amp)
125            } else {
126                (idx, amp)
127            }
128        }));
129        std::mem::swap(&mut self.state, &mut self.swap_buf);
130    }
131
132    #[inline(always)]
133    fn apply_cu(&mut self, control: usize, target: usize, mat: [[Complex64; 2]; 2]) {
134        let ctrl_mask = 1usize << control;
135        let tgt_mask = 1usize << target;
136        let zero = Complex64::new(0.0, 0.0);
137        self.swap_buf.clear();
138        self.swap_buf.reserve(self.state.len() * 2);
139
140        for (&idx, &amp) in &self.state {
141            if idx & ctrl_mask == 0 {
142                *self.swap_buf.entry(idx).or_insert(zero) += amp;
143            } else {
144                let bit = (idx >> target) & 1;
145                let partner = idx ^ tgt_mask;
146                *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
147                *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
148            }
149        }
150
151        std::mem::swap(&mut self.state, &mut self.swap_buf);
152        self.prune();
153    }
154
155    #[inline(always)]
156    fn apply_mcu(&mut self, controls: &[usize], target: usize, mat: [[Complex64; 2]; 2]) {
157        let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
158        let tgt_mask = 1usize << target;
159        let zero = Complex64::new(0.0, 0.0);
160        self.swap_buf.clear();
161        self.swap_buf.reserve(self.state.len() * 2);
162
163        for (&idx, &amp) in &self.state {
164            if idx & ctrl_mask != ctrl_mask {
165                *self.swap_buf.entry(idx).or_insert(zero) += amp;
166            } else {
167                let bit = (idx >> target) & 1;
168                let partner = idx ^ tgt_mask;
169                *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
170                *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
171            }
172        }
173
174        std::mem::swap(&mut self.state, &mut self.swap_buf);
175        self.prune();
176    }
177
178    #[inline(always)]
179    fn apply_cu_phase(&mut self, control: usize, target: usize, phase: Complex64) {
180        let ctrl_mask = 1usize << control;
181        let tgt_mask = 1usize << target;
182        for (&idx, amp) in self.state.iter_mut() {
183            if idx & ctrl_mask != 0 && idx & tgt_mask != 0 {
184                *amp *= phase;
185            }
186        }
187    }
188
189    #[inline(always)]
190    fn apply_mcu_phase(&mut self, controls: &[usize], target: usize, phase: Complex64) {
191        let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
192        let tgt_mask = 1usize << target;
193        for (&idx, amp) in self.state.iter_mut() {
194            if idx & ctrl_mask == ctrl_mask && idx & tgt_mask != 0 {
195                *amp *= phase;
196            }
197        }
198    }
199
200    fn apply_batch_phase(&mut self, control: usize, phases: &[(usize, Complex64)]) {
201        let ctrl_mask = 1usize << control;
202        let one = Complex64::new(1.0, 0.0);
203        for (&idx, amp) in self.state.iter_mut() {
204            if idx & ctrl_mask == 0 {
205                continue;
206            }
207            let mut combined = one;
208            for &(target, phase) in phases {
209                if idx & (1usize << target) != 0 {
210                    combined *= phase;
211                }
212            }
213            if !is_phase_one(combined) {
214                *amp *= combined;
215            }
216        }
217    }
218
219    fn apply_fused_2q(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
220        let mask0 = 1usize << q0;
221        let mask1 = 1usize << q1;
222        let zero = Complex64::new(0.0, 0.0);
223        self.swap_buf.clear();
224        self.swap_buf.reserve(self.state.len() * 2);
225
226        for (&idx, &amp) in &self.state {
227            let bit0 = (idx >> q0) & 1;
228            let bit1 = (idx >> q1) & 1;
229            let row = bit0 * 2 + bit1;
230            let base = idx & !(mask0 | mask1);
231
232            for (col, mat_row) in mat.iter().enumerate() {
233                let coeff = mat_row[row];
234                if coeff == zero {
235                    continue;
236                }
237                let col_bit0 = (col >> 1) & 1;
238                let col_bit1 = col & 1;
239                let dest = base | (col_bit0 << q0) | (col_bit1 << q1);
240                *self.swap_buf.entry(dest).or_insert(zero) += coeff * amp;
241            }
242        }
243
244        std::mem::swap(&mut self.state, &mut self.swap_buf);
245        self.prune();
246    }
247
248    fn apply_reset(&mut self, qubit: usize) {
249        let mask = 1usize << qubit;
250
251        #[cfg(feature = "parallel")]
252        let prob_zero: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
253            self.state
254                .par_iter()
255                .filter(|(&idx, _)| idx & mask == 0)
256                .map(|(_, amp)| amp.norm_sqr())
257                .sum()
258        } else {
259            self.state
260                .iter()
261                .filter(|(&idx, _)| idx & mask == 0)
262                .map(|(_, amp)| amp.norm_sqr())
263                .sum()
264        };
265
266        #[cfg(not(feature = "parallel"))]
267        let prob_zero: f64 = self
268            .state
269            .iter()
270            .filter(|(&idx, _)| idx & mask == 0)
271            .map(|(_, amp)| amp.norm_sqr())
272            .sum();
273
274        if prob_zero > 0.0 {
275            let inv_norm = 1.0 / prob_zero.sqrt();
276            self.state.retain(|&idx, amp| {
277                if idx & mask == 0 {
278                    *amp *= inv_norm;
279                    true
280                } else {
281                    false
282                }
283            });
284        } else {
285            self.state.clear();
286            self.state.insert(0, Complex64::new(1.0, 0.0));
287        }
288    }
289
290    fn apply_measure(&mut self, qubit: usize, classical_bit: usize) {
291        let mask = 1usize << qubit;
292
293        #[cfg(feature = "parallel")]
294        let prob_one: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
295            self.state
296                .par_iter()
297                .filter(|(&idx, _)| idx & mask != 0)
298                .map(|(_, amp)| amp.norm_sqr())
299                .sum()
300        } else {
301            self.state
302                .iter()
303                .filter(|(&idx, _)| idx & mask != 0)
304                .map(|(_, amp)| amp.norm_sqr())
305                .sum()
306        };
307
308        #[cfg(not(feature = "parallel"))]
309        let prob_one: f64 = self
310            .state
311            .iter()
312            .filter(|(&idx, _)| idx & mask != 0)
313            .map(|(_, amp)| amp.norm_sqr())
314            .sum();
315
316        let outcome = self.rng.random::<f64>() < prob_one;
317        self.classical_bits[classical_bit] = outcome;
318
319        let inv_norm = crate::backend::measurement_inv_norm(outcome, prob_one);
320
321        self.state.retain(|&idx, amp| {
322            let matches = (idx & mask != 0) == outcome;
323            if matches {
324                *amp *= inv_norm;
325            }
326            matches
327        });
328    }
329
330    fn dispatch_gate(&mut self, gate: &Gate, targets: &[usize]) {
331        match gate {
332            Gate::Rzz(theta) => {
333                let phase_same = Complex64::from_polar(1.0, -theta / 2.0);
334                let phase_diff = Complex64::from_polar(1.0, theta / 2.0);
335                let q0 = targets[0];
336                let q1 = targets[1];
337                for (idx, amp) in self.state.iter_mut() {
338                    let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
339                    *amp *= if parity == 0 { phase_same } else { phase_diff };
340                }
341            }
342            Gate::Cx => {
343                self.apply_cx(targets[0], targets[1]);
344            }
345            Gate::Cz => {
346                self.apply_cz(targets[0], targets[1]);
347            }
348            Gate::Swap => {
349                self.apply_swap(targets[0], targets[1]);
350            }
351            Gate::Cu(mat) => {
352                if let Some(phase) = gate.controlled_phase() {
353                    self.apply_cu_phase(targets[0], targets[1], phase);
354                } else {
355                    self.apply_cu(targets[0], targets[1], **mat);
356                }
357            }
358            Gate::Mcu(data) => {
359                let num_ctrl = data.num_controls as usize;
360                if let Some(phase) = gate.controlled_phase() {
361                    self.apply_mcu_phase(&targets[..num_ctrl], targets[num_ctrl], phase);
362                } else {
363                    self.apply_mcu(&targets[..num_ctrl], targets[num_ctrl], data.mat);
364                }
365            }
366            Gate::BatchPhase(data) => {
367                self.apply_batch_phase(targets[0], &data.phases);
368            }
369            Gate::BatchRzz(data) => {
370                for &(q0, q1, theta) in &data.edges {
371                    let phase_same = Complex64::from_polar(1.0, -theta / 2.0);
372                    let phase_diff = Complex64::from_polar(1.0, theta / 2.0);
373                    for (idx, amp) in self.state.iter_mut() {
374                        let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
375                        *amp *= if parity == 0 { phase_same } else { phase_diff };
376                    }
377                }
378            }
379            Gate::DiagonalBatch(data) => {
380                for entry in &data.entries {
381                    match entry {
382                        DiagEntry::Phase1q { qubit, d0, d1 } => {
383                            let mask = 1usize << qubit;
384                            for (idx, amp) in self.state.iter_mut() {
385                                if (*idx & mask) != 0 {
386                                    *amp *= d1;
387                                } else {
388                                    *amp *= d0;
389                                }
390                            }
391                        }
392                        DiagEntry::Phase2q { q0, q1, phase } => {
393                            let mask = (1usize << q0) | (1usize << q1);
394                            for (idx, amp) in self.state.iter_mut() {
395                                if (*idx & mask) == mask {
396                                    *amp *= phase;
397                                }
398                            }
399                        }
400                        DiagEntry::Parity2q { q0, q1, same, diff } => {
401                            for (idx, amp) in self.state.iter_mut() {
402                                let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
403                                *amp *= if parity == 0 { *same } else { *diff };
404                            }
405                        }
406                    }
407                }
408            }
409            Gate::MultiFused(data) => {
410                for &(target, mat) in &data.gates {
411                    self.apply_single_qubit(target, mat);
412                }
413            }
414            Gate::Fused2q(mat) => {
415                self.apply_fused_2q(targets[0], targets[1], mat);
416            }
417            Gate::Multi2q(data) => {
418                for &(q0, q1, ref mat) in &data.gates {
419                    self.apply_fused_2q(q0, q1, mat);
420                }
421            }
422            other => {
423                debug_assert!(
424                    targets.len() == 1,
425                    "sparse dispatch_gate: unexpected multi-qubit gate {:?}",
426                    other
427                );
428                let mat = other.matrix_2x2();
429                self.apply_single_qubit(targets[0], mat);
430            }
431        }
432    }
433}
434
435impl Backend for SparseBackend {
436    fn name(&self) -> &'static str {
437        "sparse"
438    }
439
440    fn init(&mut self, num_qubits: usize, num_classical_bits: usize) -> Result<()> {
441        self.num_qubits = num_qubits;
442        self.state.clear();
443        self.state.insert(0, Complex64::new(1.0, 0.0));
444        self.classical_bits = vec![false; num_classical_bits];
445        Ok(())
446    }
447
448    fn apply(&mut self, instruction: &Instruction) -> Result<()> {
449        match instruction {
450            Instruction::Gate { gate, targets } => self.dispatch_gate(gate, targets),
451            Instruction::Measure {
452                qubit,
453                classical_bit,
454            } => {
455                self.apply_measure(*qubit, *classical_bit);
456            }
457            Instruction::Reset { qubit } => {
458                self.apply_reset(*qubit);
459            }
460            Instruction::Barrier { .. } => {}
461            Instruction::Conditional {
462                condition,
463                gate,
464                targets,
465            } => {
466                if condition.evaluate(&self.classical_bits) {
467                    self.dispatch_gate(gate, targets);
468                }
469            }
470        }
471        Ok(())
472    }
473
474    fn reset(&mut self, qubit: usize) -> Result<()> {
475        self.apply_reset(qubit);
476        Ok(())
477    }
478
479    fn reduced_density_matrix_1q(&self, qubit: usize) -> Result<[[Complex64; 2]; 2]> {
480        let mask = 1usize << qubit;
481        let mut p0 = 0.0f64;
482        let mut p1 = 0.0f64;
483        let mut r = Complex64::new(0.0, 0.0);
484
485        for (&idx, &amp) in &self.state {
486            if idx & mask == 0 {
487                p0 += amp.norm_sqr();
488                if let Some(&amp_one) = self.state.get(&(idx | mask)) {
489                    r += amp_one * amp.conj();
490                }
491            } else {
492                p1 += amp.norm_sqr();
493            }
494        }
495
496        Ok([
497            [Complex64::new(p0, 0.0), r.conj()],
498            [r, Complex64::new(p1, 0.0)],
499        ])
500    }
501
502    fn classical_results(&self) -> &[bool] {
503        &self.classical_bits
504    }
505
506    fn probabilities(&self) -> Result<Vec<f64>> {
507        if self.num_qubits > MAX_PROB_QUBITS {
508            return Err(PrismError::BackendUnsupported {
509                backend: self.name().to_string(),
510                operation: format!(
511                    "probabilities for {} qubits (max {})",
512                    self.num_qubits, MAX_PROB_QUBITS
513                ),
514            });
515        }
516        let dim = 1usize << self.num_qubits;
517        let mut probs = vec![0.0f64; dim];
518        for (&idx, amp) in &self.state {
519            probs[idx] = amp.norm_sqr();
520        }
521        Ok(probs)
522    }
523
524    fn num_qubits(&self) -> usize {
525        self.num_qubits
526    }
527
528    fn export_statevector(&self) -> Result<Vec<Complex64>> {
529        if self.num_qubits > MAX_PROB_QUBITS {
530            return Err(PrismError::BackendUnsupported {
531                backend: self.name().to_string(),
532                operation: format!(
533                    "statevector export for {} qubits (max {})",
534                    self.num_qubits, MAX_PROB_QUBITS
535                ),
536            });
537        }
538        let dim = 1usize << self.num_qubits;
539        let mut sv = vec![Complex64::new(0.0, 0.0); dim];
540        for (&idx, &amp) in &self.state {
541            sv[idx] = amp;
542        }
543        Ok(sv)
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuit::Circuit;
    use crate::sim;

    const EPS: f64 = 1e-12;

    /// Run `circuit` on a fresh seeded sparse backend and return the backend.
    fn run_sparse(circuit: &Circuit) -> SparseBackend {
        let mut b = SparseBackend::new(42);
        sim::run_on(&mut b, circuit).unwrap();
        b
    }

    /// Run `circuit` on a fresh sparse backend and return the dense probability vector.
    fn run_sparse_probs(circuit: &Circuit) -> Vec<f64> {
        let b = run_sparse(circuit);
        b.probabilities().unwrap()
    }

    #[test]
    fn test_init_zero_state() {
        let mut b = SparseBackend::new(42);
        b.init(3, 0).unwrap();
        assert_eq!(b.state.len(), 1);
        assert!((b.state[&0].re - 1.0).abs() < EPS);
    }

    #[test]
    fn test_x_gate() {
        let mut c = Circuit::new(1, 0);
        c.add_gate(Gate::X, &[0]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        assert!(b.state.contains_key(&1));
        assert!((b.state[&1].norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_h_creates_superposition() {
        let mut c = Circuit::new(1, 0);
        c.add_gate(Gate::H, &[0]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 2);
        assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
        assert!((b.state[&1].norm_sqr() - 0.5).abs() < EPS);
    }

    #[test]
    fn test_hh_is_identity() {
        let mut c = Circuit::new(1, 0);
        c.add_gate(Gate::H, &[0]);
        c.add_gate(Gate::H, &[0]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        assert!((b.state[&0].re - 1.0).abs() < EPS);
    }

    #[test]
    fn test_cx_bell_state() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::H, &[0]);
        c.add_gate(Gate::Cx, &[0, 1]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 2);
        assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
        assert!((b.state[&3].norm_sqr() - 0.5).abs() < EPS);
    }

    #[test]
    fn test_cz_phase() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::X, &[0]);
        c.add_gate(Gate::X, &[1]);
        c.add_gate(Gate::Cz, &[0, 1]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        assert!((b.state[&3].re - (-1.0)).abs() < EPS);
    }

    #[test]
    fn test_swap() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::X, &[1]);
        c.add_gate(Gate::Swap, &[0, 1]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        assert!(b.state.contains_key(&1));
    }

    #[test]
    fn test_rx_pi() {
        let mut c = Circuit::new(1, 0);
        c.add_gate(Gate::Rx(std::f64::consts::PI), &[0]);
        let probs = run_sparse_probs(&c);
        assert!(probs[0].abs() < EPS);
        assert!((probs[1] - 1.0).abs() < EPS);
    }

    #[test]
    fn test_rz_preserves_sparsity() {
        let mut c = Circuit::new(1, 0);
        c.add_gate(Gate::Rz(1.234), &[0]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        assert!((b.state[&0].norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_measure_collapses() {
        let mut c = Circuit::new(1, 1);
        c.add_gate(Gate::H, &[0]);
        c.add_measure(0, 0);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        let outcome = b.classical_results()[0];
        if outcome {
            assert!(b.state.contains_key(&1));
        } else {
            assert!(b.state.contains_key(&0));
        }
    }

    #[test]
    fn test_measure_deterministic() {
        let mut c = Circuit::new(1, 1);
        c.add_gate(Gate::H, &[0]);
        c.add_measure(0, 0);

        let b1 = run_sparse(&c);
        let b2 = run_sparse(&c);
        assert_eq!(b1.classical_results()[0], b2.classical_results()[0]);
    }

    #[test]
    fn test_measure_certain_one() {
        let mut c = Circuit::new(1, 1);
        c.add_gate(Gate::X, &[0]);
        c.add_measure(0, 0);
        let b = run_sparse(&c);
        assert!(b.classical_results()[0]);
        assert_eq!(b.state.len(), 1);
        assert!(b.state.contains_key(&1));
    }

    #[test]
    fn test_probs_bell() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::H, &[0]);
        c.add_gate(Gate::Cx, &[0, 1]);
        let probs = run_sparse_probs(&c);
        assert!((probs[0] - 0.5).abs() < EPS);
        assert!(probs[1].abs() < EPS);
        assert!(probs[2].abs() < EPS);
        assert!((probs[3] - 0.5).abs() < EPS);
    }

    #[test]
    fn test_probs_zero_state() {
        let c = Circuit::new(3, 0);
        let probs = run_sparse_probs(&c);
        assert!((probs[0] - 1.0).abs() < EPS);
        let rest: f64 = probs[1..].iter().sum();
        assert!(rest.abs() < EPS);
    }

    #[test]
    fn test_pruning() {
        let mut b = SparseBackend::new(42);
        b.init(1, 0).unwrap();
        b.state.insert(1, Complex64::new(1e-20, 0.0));
        assert_eq!(b.state.len(), 2);
        b.prune();
        assert_eq!(b.state.len(), 1);
        assert!(b.state.contains_key(&0));
    }

    #[test]
    fn test_reset_superposition_projects_to_zero() {
        let mut c = Circuit::new(1, 0);
        c.add_gate(Gate::H, &[0]);
        let mut b = run_sparse(&c);
        b.reset(0).unwrap();
        assert_eq!(b.state.len(), 1);
        assert!((b.state[&0].norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_reset_one_state_returns_to_zero() {
        let mut b = SparseBackend::new(42);
        b.init(1, 0).unwrap();
        b.dispatch_gate(&Gate::X, &[0]);
        b.reset(0).unwrap();
        assert_eq!(b.state.len(), 1);
        assert!((b.state[&0].norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_reduced_density_matrix_plus_state() {
        let mut b = SparseBackend::new(42);
        b.init(1, 0).unwrap();
        b.dispatch_gate(&Gate::H, &[0]);
        let rho = b.reduced_density_matrix_1q(0).unwrap();
        for row in &rho {
            for entry in row {
                assert!((*entry - Complex64::new(0.5, 0.0)).norm() < EPS);
            }
        }
    }

    #[test]
    fn test_export_statevector_bell() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::H, &[0]);
        c.add_gate(Gate::Cx, &[0, 1]);
        let sv = run_sparse(&c).export_statevector().unwrap();
        let h = 1.0 / 2.0_f64.sqrt();
        assert_eq!(sv.len(), 4);
        assert!((sv[0].norm() - h).abs() < EPS);
        assert!(sv[1].norm() < EPS);
        assert!(sv[2].norm() < EPS);
        assert!((sv[3].norm() - h).abs() < EPS);
    }

    #[test]
    fn test_rzz_phases_by_parity() {
        let theta = 0.7;
        let mut b = SparseBackend::new(42);
        b.init(2, 0).unwrap();

        // |00>: equal parity picks up e^{-i theta/2}.
        b.dispatch_gate(&Gate::Rzz(theta), &[0, 1]);
        let same = Complex64::from_polar(1.0, -theta / 2.0);
        assert!((b.state[&0] - same).norm() < EPS);

        // |01> (q0 set): differing parity picks up e^{+i theta/2}.
        b.dispatch_gate(&Gate::X, &[0]);
        b.dispatch_gate(&Gate::Rzz(theta), &[0, 1]);
        let diff = Complex64::from_polar(1.0, theta / 2.0);
        assert!((b.state[&1] - same * diff).norm() < EPS);
    }

    #[test]
    fn test_fused_gate() {
        let h_mat = Gate::H.matrix_2x2();
        let t_mat = Gate::T.matrix_2x2();
        let zero = Complex64::new(0.0, 0.0);
        let mut fused = [[zero; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    fused[i][j] += t_mat[i][k] * h_mat[k][j];
                }
            }
        }

        let mut c1 = Circuit::new(1, 0);
        c1.add_gate(Gate::H, &[0]);
        c1.add_gate(Gate::T, &[0]);
        let p1 = run_sparse_probs(&c1);

        let mut c2 = Circuit::new(1, 0);
        c2.add_gate(Gate::Fused(Box::new(fused)), &[0]);
        let p2 = run_sparse_probs(&c2);

        for (a, b) in p1.iter().zip(p2.iter()) {
            assert!((a - b).abs() < EPS);
        }
    }

    #[test]
    fn test_ghz_4_sparse() {
        let mut c = Circuit::new(4, 0);
        c.add_gate(Gate::H, &[0]);
        for i in 0..3 {
            c.add_gate(Gate::Cx, &[i, i + 1]);
        }
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 2);
        assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
        assert!((b.state[&15].norm_sqr() - 0.5).abs() < EPS);
    }

    #[test]
    fn test_cu_phase_applies_phase() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::X, &[0]);
        c.add_gate(Gate::X, &[1]);
        c.add_gate(Gate::cphase(std::f64::consts::FRAC_PI_4), &[0, 1]);
        let b = run_sparse(&c);
        assert_eq!(b.state.len(), 1);
        let expected = Complex64::from_polar(1.0, std::f64::consts::FRAC_PI_4);
        assert!((b.state[&3] - expected).norm() < EPS);
    }

    #[test]
    fn test_cu_phase_no_action_control_zero() {
        let mut c = Circuit::new(2, 0);
        c.add_gate(Gate::H, &[1]);
        c.add_gate(Gate::cphase(1.0), &[0, 1]);
        let b = run_sparse(&c);
        let h = 1.0 / 2.0_f64.sqrt();
        assert!((b.state[&0].re - h).abs() < EPS);
        assert!((b.state[&2].re - h).abs() < EPS);
        assert!(!b.state.contains_key(&1));
        assert!(!b.state.contains_key(&3));
    }

    #[test]
    fn test_cu_phase_matches_cz() {
        let mut c1 = Circuit::new(2, 0);
        c1.add_gate(Gate::H, &[0]);
        c1.add_gate(Gate::H, &[1]);
        c1.add_gate(Gate::cphase(std::f64::consts::PI), &[0, 1]);

        let mut c2 = Circuit::new(2, 0);
        c2.add_gate(Gate::H, &[0]);
        c2.add_gate(Gate::H, &[1]);
        c2.add_gate(Gate::Cz, &[0, 1]);

        let b1 = run_sparse(&c1);
        let b2 = run_sparse(&c2);

        for (&idx, &amp1) in &b1.state {
            let amp2 = b2
                .state
                .get(&idx)
                .copied()
                .unwrap_or(Complex64::new(0.0, 0.0));
            assert!((amp1 - amp2).norm() < EPS, "mismatch at idx {idx}");
        }
    }

    #[test]
    fn test_batch_phase_matches_individual() {
        use crate::gates::BatchPhaseData;
        use smallvec::smallvec;

        let phase1 = Complex64::from_polar(1.0, 0.5);
        let phase2 = Complex64::from_polar(1.0, 1.2);

        let mut c1 = Circuit::new(3, 0);
        c1.add_gate(Gate::H, &[0]);
        c1.add_gate(Gate::H, &[1]);
        c1.add_gate(Gate::H, &[2]);
        c1.add_gate(Gate::cphase(0.5), &[0, 1]);
        c1.add_gate(Gate::cphase(1.2), &[0, 2]);
        let p1 = run_sparse_probs(&c1);

        let mut c2 = Circuit::new(3, 0);
        c2.add_gate(Gate::H, &[0]);
        c2.add_gate(Gate::H, &[1]);
        c2.add_gate(Gate::H, &[2]);
        c2.add_gate(
            Gate::BatchPhase(Box::new(BatchPhaseData {
                phases: smallvec![(1, phase1), (2, phase2)],
            })),
            &[0, 1, 2],
        );
        let p2 = run_sparse_probs(&c2);

        for (a, b) in p1.iter().zip(p2.iter()) {
            assert!((a - b).abs() < EPS, "probs mismatch: {a} vs {b}");
        }
    }
}