Skip to main content

prism_q/backend/
sparse.rs

//! Sparse state-vector simulation backend.
//!
//! Stores only non-zero amplitudes in a `HashMap<usize, Complex64>`, giving O(k) memory
//! where k is the number of non-zero basis states. After gates that can create near-zero
//! amplitudes, basis states whose probability `|amp|²` falls below a small epsilon
//! (`1e-16` by default) are pruned to maintain sparsity.
//!
//! # When to prefer this backend
//!
//! - States with few non-zero amplitudes (computational basis states, limited superposition).
//! - Large qubit counts where the state stays sparse throughout the circuit (basis-state
//!   indices are `usize`, so at most `usize::BITS` qubits are addressable).
//! - Classical-like circuits with limited branching.
//!
//! # When NOT to use this backend
//!
//! - After a layer of Hadamard gates (state becomes maximally dense).
//! - Small qubit counts where a dense state vector is faster due to `HashMap` overhead.
17
18use std::collections::HashMap;
19
20use num_complex::Complex64;
21use rand::Rng;
22use rand::SeedableRng;
23use rand_chacha::ChaCha8Rng;
24
25#[cfg(feature = "parallel")]
26use rayon::prelude::*;
27
/// Minimum number of stored amplitudes before probability reductions are handed to rayon.
///
/// Smaller maps are reduced sequentially, presumably because rayon's scheduling overhead
/// dominates the per-amplitude work at that size.
#[cfg(feature = "parallel")]
const MIN_STATES_FOR_PAR: usize = 1 << 12;
30
31use crate::backend::{
32    dense_probability_len, dense_statevector_len, is_phase_one, reserve_dense_output, Backend,
33};
34use crate::circuit::Instruction;
35use crate::error::Result;
36use crate::gates::{DiagEntry, Gate};
37
38const DEFAULT_EPSILON: f64 = 1e-16;
39
40/// Sparse state-vector backend, O(k) where k is the number of non-zero amplitudes.
41pub struct SparseBackend {
42    num_qubits: usize,
43    state: HashMap<usize, Complex64>,
44    swap_buf: HashMap<usize, Complex64>,
45    classical_bits: Vec<bool>,
46    rng: ChaCha8Rng,
47    epsilon: f64,
48}
49
50impl SparseBackend {
51    /// Create a new sparse backend with the given RNG seed.
52    pub fn new(seed: u64) -> Self {
53        Self {
54            num_qubits: 0,
55            state: HashMap::new(),
56            swap_buf: HashMap::new(),
57            classical_bits: Vec::new(),
58            rng: ChaCha8Rng::seed_from_u64(seed),
59            epsilon: DEFAULT_EPSILON,
60        }
61    }
62
63    #[inline(always)]
64    fn prune(&mut self) {
65        let eps = self.epsilon;
66        self.state.retain(|_, amp| amp.norm_sqr() >= eps);
67    }
68
69    #[inline(always)]
70    fn apply_single_qubit(&mut self, target: usize, mat: [[Complex64; 2]; 2]) {
71        let mask = 1usize << target;
72        let zero = Complex64::new(0.0, 0.0);
73        self.swap_buf.clear();
74        self.swap_buf.reserve(self.state.len() * 2);
75
76        for (&idx, &amp) in &self.state {
77            let bit = (idx >> target) & 1;
78            let partner = idx ^ mask;
79
80            *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
81            *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
82        }
83
84        std::mem::swap(&mut self.state, &mut self.swap_buf);
85        self.prune();
86    }
87
88    /// CX is a deterministic 1:1 index mapping. No near-zero amplitudes are created.
89    #[inline(always)]
90    fn apply_cx(&mut self, control: usize, target: usize) {
91        let ctrl_mask = 1usize << control;
92        let tgt_mask = 1usize << target;
93        self.swap_buf.clear();
94        self.swap_buf.reserve(self.state.len());
95        self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
96            if idx & ctrl_mask != 0 {
97                (idx ^ tgt_mask, amp)
98            } else {
99                (idx, amp)
100            }
101        }));
102        std::mem::swap(&mut self.state, &mut self.swap_buf);
103    }
104
105    #[inline(always)]
106    fn apply_cz(&mut self, q0: usize, q1: usize) {
107        let mask0 = 1usize << q0;
108        let mask1 = 1usize << q1;
109        for (&idx, amp) in self.state.iter_mut() {
110            if idx & mask0 != 0 && idx & mask1 != 0 {
111                *amp = -*amp;
112            }
113        }
114    }
115
116    #[inline(always)]
117    fn apply_swap(&mut self, q0: usize, q1: usize) {
118        let m0 = 1usize << q0;
119        let m1 = 1usize << q1;
120        self.swap_buf.clear();
121        self.swap_buf.reserve(self.state.len());
122        self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
123            let bit0 = (idx >> q0) & 1;
124            let bit1 = (idx >> q1) & 1;
125            if bit0 != bit1 {
126                (idx ^ m0 ^ m1, amp)
127            } else {
128                (idx, amp)
129            }
130        }));
131        std::mem::swap(&mut self.state, &mut self.swap_buf);
132    }
133
134    #[inline(always)]
135    fn apply_cu(&mut self, control: usize, target: usize, mat: [[Complex64; 2]; 2]) {
136        let ctrl_mask = 1usize << control;
137        let tgt_mask = 1usize << target;
138        let zero = Complex64::new(0.0, 0.0);
139        self.swap_buf.clear();
140        self.swap_buf.reserve(self.state.len() * 2);
141
142        for (&idx, &amp) in &self.state {
143            if idx & ctrl_mask == 0 {
144                *self.swap_buf.entry(idx).or_insert(zero) += amp;
145            } else {
146                let bit = (idx >> target) & 1;
147                let partner = idx ^ tgt_mask;
148                *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
149                *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
150            }
151        }
152
153        std::mem::swap(&mut self.state, &mut self.swap_buf);
154        self.prune();
155    }
156
157    #[inline(always)]
158    fn apply_mcu(&mut self, controls: &[usize], target: usize, mat: [[Complex64; 2]; 2]) {
159        let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
160        let tgt_mask = 1usize << target;
161        let zero = Complex64::new(0.0, 0.0);
162        self.swap_buf.clear();
163        self.swap_buf.reserve(self.state.len() * 2);
164
165        for (&idx, &amp) in &self.state {
166            if idx & ctrl_mask != ctrl_mask {
167                *self.swap_buf.entry(idx).or_insert(zero) += amp;
168            } else {
169                let bit = (idx >> target) & 1;
170                let partner = idx ^ tgt_mask;
171                *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
172                *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
173            }
174        }
175
176        std::mem::swap(&mut self.state, &mut self.swap_buf);
177        self.prune();
178    }
179
180    #[inline(always)]
181    fn apply_cu_phase(&mut self, control: usize, target: usize, phase: Complex64) {
182        let ctrl_mask = 1usize << control;
183        let tgt_mask = 1usize << target;
184        for (&idx, amp) in self.state.iter_mut() {
185            if idx & ctrl_mask != 0 && idx & tgt_mask != 0 {
186                *amp *= phase;
187            }
188        }
189    }
190
191    #[inline(always)]
192    fn apply_mcu_phase(&mut self, controls: &[usize], target: usize, phase: Complex64) {
193        let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
194        let tgt_mask = 1usize << target;
195        for (&idx, amp) in self.state.iter_mut() {
196            if idx & ctrl_mask == ctrl_mask && idx & tgt_mask != 0 {
197                *amp *= phase;
198            }
199        }
200    }
201
202    #[inline(always)]
203    fn apply_rzz(&mut self, q0: usize, q1: usize, theta: f64) {
204        let phase_same = Complex64::from_polar(1.0, -theta / 2.0);
205        let phase_diff = Complex64::from_polar(1.0, theta / 2.0);
206        for (idx, amp) in self.state.iter_mut() {
207            let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
208            *amp *= if parity == 0 { phase_same } else { phase_diff };
209        }
210    }
211
212    fn apply_batch_phase(&mut self, control: usize, phases: &[(usize, Complex64)]) {
213        let ctrl_mask = 1usize << control;
214        let one = Complex64::new(1.0, 0.0);
215        for (&idx, amp) in self.state.iter_mut() {
216            if idx & ctrl_mask == 0 {
217                continue;
218            }
219            let mut combined = one;
220            for &(target, phase) in phases {
221                if idx & (1usize << target) != 0 {
222                    combined *= phase;
223                }
224            }
225            if !is_phase_one(combined) {
226                *amp *= combined;
227            }
228        }
229    }
230
231    fn apply_fused_2q(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
232        let mask0 = 1usize << q0;
233        let mask1 = 1usize << q1;
234        let zero = Complex64::new(0.0, 0.0);
235        self.swap_buf.clear();
236        self.swap_buf.reserve(self.state.len() * 2);
237
238        for (&idx, &amp) in &self.state {
239            let bit0 = (idx >> q0) & 1;
240            let bit1 = (idx >> q1) & 1;
241            let row = bit0 * 2 + bit1;
242            let base = idx & !(mask0 | mask1);
243
244            for (col, mat_row) in mat.iter().enumerate() {
245                let coeff = mat_row[row];
246                if coeff == zero {
247                    continue;
248                }
249                let col_bit0 = (col >> 1) & 1;
250                let col_bit1 = col & 1;
251                let dest = base | (col_bit0 << q0) | (col_bit1 << q1);
252                *self.swap_buf.entry(dest).or_insert(zero) += coeff * amp;
253            }
254        }
255
256        std::mem::swap(&mut self.state, &mut self.swap_buf);
257        self.prune();
258    }
259
260    fn apply_reset(&mut self, qubit: usize) {
261        let mask = 1usize << qubit;
262
263        #[cfg(feature = "parallel")]
264        let prob_zero: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
265            self.state
266                .par_iter()
267                .filter(|(&idx, _)| idx & mask == 0)
268                .map(|(_, amp)| amp.norm_sqr())
269                .sum()
270        } else {
271            self.state
272                .iter()
273                .filter(|(&idx, _)| idx & mask == 0)
274                .map(|(_, amp)| amp.norm_sqr())
275                .sum()
276        };
277
278        #[cfg(not(feature = "parallel"))]
279        let prob_zero: f64 = self
280            .state
281            .iter()
282            .filter(|(&idx, _)| idx & mask == 0)
283            .map(|(_, amp)| amp.norm_sqr())
284            .sum();
285
286        if prob_zero > 0.0 {
287            let inv_norm = 1.0 / prob_zero.sqrt();
288            self.state.retain(|&idx, amp| {
289                if idx & mask == 0 {
290                    *amp *= inv_norm;
291                    true
292                } else {
293                    false
294                }
295            });
296        } else {
297            self.state.clear();
298            self.state.insert(0, Complex64::new(1.0, 0.0));
299        }
300    }
301
302    fn apply_measure(&mut self, qubit: usize, classical_bit: usize) {
303        let mask = 1usize << qubit;
304
305        #[cfg(feature = "parallel")]
306        let prob_one: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
307            self.state
308                .par_iter()
309                .filter(|(&idx, _)| idx & mask != 0)
310                .map(|(_, amp)| amp.norm_sqr())
311                .sum()
312        } else {
313            self.state
314                .iter()
315                .filter(|(&idx, _)| idx & mask != 0)
316                .map(|(_, amp)| amp.norm_sqr())
317                .sum()
318        };
319
320        #[cfg(not(feature = "parallel"))]
321        let prob_one: f64 = self
322            .state
323            .iter()
324            .filter(|(&idx, _)| idx & mask != 0)
325            .map(|(_, amp)| amp.norm_sqr())
326            .sum();
327
328        let outcome = self.rng.random::<f64>() < prob_one;
329        self.classical_bits[classical_bit] = outcome;
330
331        let inv_norm = crate::backend::measurement_inv_norm(outcome, prob_one);
332
333        self.state.retain(|&idx, amp| {
334            let matches = (idx & mask != 0) == outcome;
335            if matches {
336                *amp *= inv_norm;
337            }
338            matches
339        });
340    }
341
342    fn dispatch_gate(&mut self, gate: &Gate, targets: &[usize]) {
343        match gate {
344            Gate::Rzz(theta) => {
345                self.apply_rzz(targets[0], targets[1], *theta);
346            }
347            Gate::Cx => {
348                self.apply_cx(targets[0], targets[1]);
349            }
350            Gate::Cz => {
351                self.apply_cz(targets[0], targets[1]);
352            }
353            Gate::Swap => {
354                self.apply_swap(targets[0], targets[1]);
355            }
356            Gate::Cu(mat) => {
357                if let Some(phase) = gate.controlled_phase() {
358                    self.apply_cu_phase(targets[0], targets[1], phase);
359                } else {
360                    self.apply_cu(targets[0], targets[1], **mat);
361                }
362            }
363            Gate::Mcu(data) => {
364                let num_ctrl = data.num_controls as usize;
365                if let Some(phase) = gate.controlled_phase() {
366                    self.apply_mcu_phase(&targets[..num_ctrl], targets[num_ctrl], phase);
367                } else {
368                    self.apply_mcu(&targets[..num_ctrl], targets[num_ctrl], data.mat);
369                }
370            }
371            Gate::BatchPhase(data) => {
372                self.apply_batch_phase(targets[0], &data.phases);
373            }
374            Gate::BatchRzz(data) => {
375                for &(q0, q1, theta) in &data.edges {
376                    self.apply_rzz(q0, q1, theta);
377                }
378            }
379            Gate::DiagonalBatch(data) => {
380                for entry in &data.entries {
381                    match entry {
382                        DiagEntry::Phase1q { qubit, d0, d1 } => {
383                            let mask = 1usize << qubit;
384                            for (idx, amp) in self.state.iter_mut() {
385                                if (*idx & mask) != 0 {
386                                    *amp *= d1;
387                                } else {
388                                    *amp *= d0;
389                                }
390                            }
391                        }
392                        DiagEntry::Phase2q { q0, q1, phase } => {
393                            let mask = (1usize << q0) | (1usize << q1);
394                            for (idx, amp) in self.state.iter_mut() {
395                                if (*idx & mask) == mask {
396                                    *amp *= phase;
397                                }
398                            }
399                        }
400                        DiagEntry::Parity2q { q0, q1, same, diff } => {
401                            for (idx, amp) in self.state.iter_mut() {
402                                let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
403                                *amp *= if parity == 0 { *same } else { *diff };
404                            }
405                        }
406                    }
407                }
408            }
409            Gate::MultiFused(data) => {
410                for &(target, mat) in &data.gates {
411                    self.apply_single_qubit(target, mat);
412                }
413            }
414            Gate::Fused2q(mat) => {
415                self.apply_fused_2q(targets[0], targets[1], mat);
416            }
417            Gate::Multi2q(data) => {
418                for &(q0, q1, ref mat) in &data.gates {
419                    self.apply_fused_2q(q0, q1, mat);
420                }
421            }
422            other => {
423                debug_assert!(
424                    targets.len() == 1,
425                    "sparse dispatch_gate: unexpected multi-qubit gate {:?}",
426                    other
427                );
428                let mat = other.matrix_2x2();
429                self.apply_single_qubit(targets[0], mat);
430            }
431        }
432    }
433}
434
435impl Backend for SparseBackend {
436    fn name(&self) -> &'static str {
437        "sparse"
438    }
439
440    fn init(&mut self, num_qubits: usize, num_classical_bits: usize) -> Result<()> {
441        self.num_qubits = num_qubits;
442        self.state.clear();
443        self.state.insert(0, Complex64::new(1.0, 0.0));
444        self.classical_bits = vec![false; num_classical_bits];
445        Ok(())
446    }
447
448    fn apply(&mut self, instruction: &Instruction) -> Result<()> {
449        match instruction {
450            Instruction::Gate { gate, targets } => self.dispatch_gate(gate, targets),
451            Instruction::Measure {
452                qubit,
453                classical_bit,
454            } => {
455                self.apply_measure(*qubit, *classical_bit);
456            }
457            Instruction::Reset { qubit } => {
458                self.apply_reset(*qubit);
459            }
460            Instruction::Barrier { .. } => {}
461            Instruction::Conditional {
462                condition,
463                gate,
464                targets,
465            } => {
466                if condition.evaluate(&self.classical_bits) {
467                    self.dispatch_gate(gate, targets);
468                }
469            }
470        }
471        Ok(())
472    }
473
474    fn reset(&mut self, qubit: usize) -> Result<()> {
475        self.apply_reset(qubit);
476        Ok(())
477    }
478
479    fn apply_1q_matrix(&mut self, qubit: usize, matrix: &[[Complex64; 2]; 2]) -> Result<()> {
480        self.apply_single_qubit(qubit, *matrix);
481        Ok(())
482    }
483
484    fn reduced_density_matrix_1q(&self, qubit: usize) -> Result<[[Complex64; 2]; 2]> {
485        let mask = 1usize << qubit;
486        let mut p0 = 0.0f64;
487        let mut p1 = 0.0f64;
488        let mut r = Complex64::new(0.0, 0.0);
489
490        for (&idx, &amp) in &self.state {
491            if idx & mask == 0 {
492                p0 += amp.norm_sqr();
493                if let Some(&amp_one) = self.state.get(&(idx | mask)) {
494                    r += amp_one * amp.conj();
495                }
496            } else {
497                p1 += amp.norm_sqr();
498            }
499        }
500
501        Ok([
502            [Complex64::new(p0, 0.0), r.conj()],
503            [r, Complex64::new(p1, 0.0)],
504        ])
505    }
506
507    fn classical_results(&self) -> &[bool] {
508        &self.classical_bits
509    }
510
511    fn probabilities(&self) -> Result<Vec<f64>> {
512        let dim = dense_probability_len(self.name(), self.num_qubits)?;
513        let mut probs = Vec::new();
514        reserve_dense_output(&mut probs, dim, self.name(), "probabilities")?;
515        probs.resize(dim, 0.0f64);
516        for (&idx, amp) in &self.state {
517            probs[idx] = amp.norm_sqr();
518        }
519        Ok(probs)
520    }
521
522    fn num_qubits(&self) -> usize {
523        self.num_qubits
524    }
525
526    fn export_statevector(&self) -> Result<Vec<Complex64>> {
527        let dim = dense_statevector_len(self.name(), "statevector export", self.num_qubits)?;
528        let mut sv = Vec::new();
529        reserve_dense_output(&mut sv, dim, self.name(), "statevector export")?;
530        sv.resize(dim, Complex64::new(0.0, 0.0));
531        for (&idx, &amp) in &self.state {
532            sv[idx] = amp;
533        }
534        Ok(sv)
535    }
536}
537
538#[cfg(test)]
539mod tests {
540    use super::*;
541    use crate::circuit::Circuit;
542    use crate::sim;
543
544    const EPS: f64 = 1e-12;
545
546    fn run_sparse(circuit: &Circuit) -> SparseBackend {
547        let mut b = SparseBackend::new(42);
548        sim::run_on(&mut b, circuit).unwrap();
549        b
550    }
551
552    fn run_sparse_probs(circuit: &Circuit) -> Vec<f64> {
553        let b = run_sparse(circuit);
554        b.probabilities().unwrap()
555    }
556
557    #[test]
558    fn test_init_zero_state() {
559        let mut b = SparseBackend::new(42);
560        b.init(3, 0).unwrap();
561        assert_eq!(b.state.len(), 1);
562        assert!((b.state[&0].re - 1.0).abs() < EPS);
563    }
564
565    #[test]
566    fn test_x_gate() {
567        let mut c = Circuit::new(1, 0);
568        c.add_gate(Gate::X, &[0]);
569        let b = run_sparse(&c);
570        assert_eq!(b.state.len(), 1);
571        assert!(b.state.contains_key(&1));
572        assert!((b.state[&1].norm() - 1.0).abs() < EPS);
573    }
574
575    #[test]
576    fn test_h_creates_superposition() {
577        let mut c = Circuit::new(1, 0);
578        c.add_gate(Gate::H, &[0]);
579        let b = run_sparse(&c);
580        assert_eq!(b.state.len(), 2);
581        assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
582        assert!((b.state[&1].norm_sqr() - 0.5).abs() < EPS);
583    }
584
585    #[test]
586    fn test_hh_is_identity() {
587        let mut c = Circuit::new(1, 0);
588        c.add_gate(Gate::H, &[0]);
589        c.add_gate(Gate::H, &[0]);
590        let b = run_sparse(&c);
591        assert_eq!(b.state.len(), 1);
592        assert!((b.state[&0].re - 1.0).abs() < EPS);
593    }
594
595    #[test]
596    fn test_cx_bell_state() {
597        let mut c = Circuit::new(2, 0);
598        c.add_gate(Gate::H, &[0]);
599        c.add_gate(Gate::Cx, &[0, 1]);
600        let b = run_sparse(&c);
601        assert_eq!(b.state.len(), 2);
602        assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
603        assert!((b.state[&3].norm_sqr() - 0.5).abs() < EPS);
604    }
605
606    #[test]
607    fn test_cz_phase() {
608        let mut c = Circuit::new(2, 0);
609        c.add_gate(Gate::X, &[0]);
610        c.add_gate(Gate::X, &[1]);
611        c.add_gate(Gate::Cz, &[0, 1]);
612        let b = run_sparse(&c);
613        assert_eq!(b.state.len(), 1);
614        assert!((b.state[&3].re - (-1.0)).abs() < EPS);
615    }
616
617    #[test]
618    fn test_swap() {
619        let mut c = Circuit::new(2, 0);
620        c.add_gate(Gate::X, &[1]);
621        c.add_gate(Gate::Swap, &[0, 1]);
622        let b = run_sparse(&c);
623        assert_eq!(b.state.len(), 1);
624        assert!(b.state.contains_key(&1));
625    }
626
627    #[test]
628    fn test_rx_pi() {
629        let mut c = Circuit::new(1, 0);
630        c.add_gate(Gate::Rx(std::f64::consts::PI), &[0]);
631        let probs = run_sparse_probs(&c);
632        assert!(probs[0].abs() < EPS);
633        assert!((probs[1] - 1.0).abs() < EPS);
634    }
635
636    #[test]
637    fn test_rz_preserves_sparsity() {
638        let mut c = Circuit::new(1, 0);
639        c.add_gate(Gate::Rz(1.234), &[0]);
640        let b = run_sparse(&c);
641        assert_eq!(b.state.len(), 1);
642        assert!((b.state[&0].norm() - 1.0).abs() < EPS);
643    }
644
645    #[test]
646    fn test_measure_collapses() {
647        let mut c = Circuit::new(1, 1);
648        c.add_gate(Gate::H, &[0]);
649        c.add_measure(0, 0);
650        let b = run_sparse(&c);
651        assert_eq!(b.state.len(), 1);
652        let outcome = b.classical_results()[0];
653        if outcome {
654            assert!(b.state.contains_key(&1));
655        } else {
656            assert!(b.state.contains_key(&0));
657        }
658    }
659
660    #[test]
661    fn test_measure_deterministic() {
662        let mut c = Circuit::new(1, 1);
663        c.add_gate(Gate::H, &[0]);
664        c.add_measure(0, 0);
665
666        let b1 = run_sparse(&c);
667        let b2 = run_sparse(&c);
668        assert_eq!(b1.classical_results()[0], b2.classical_results()[0]);
669    }
670
671    #[test]
672    fn test_probs_bell() {
673        let mut c = Circuit::new(2, 0);
674        c.add_gate(Gate::H, &[0]);
675        c.add_gate(Gate::Cx, &[0, 1]);
676        let probs = run_sparse_probs(&c);
677        assert!((probs[0] - 0.5).abs() < EPS);
678        assert!(probs[1].abs() < EPS);
679        assert!(probs[2].abs() < EPS);
680        assert!((probs[3] - 0.5).abs() < EPS);
681    }
682
683    #[test]
684    fn test_probs_zero_state() {
685        let c = Circuit::new(3, 0);
686        let probs = run_sparse_probs(&c);
687        assert!((probs[0] - 1.0).abs() < EPS);
688        let rest: f64 = probs[1..].iter().sum();
689        assert!(rest.abs() < EPS);
690    }
691
692    #[test]
693    fn test_pruning() {
694        let mut b = SparseBackend::new(42);
695        b.init(1, 0).unwrap();
696        b.state.insert(1, Complex64::new(1e-20, 0.0));
697        assert_eq!(b.state.len(), 2);
698        b.prune();
699        assert_eq!(b.state.len(), 1);
700        assert!(b.state.contains_key(&0));
701    }
702
703    #[test]
704    fn test_fused_gate() {
705        let h_mat = Gate::H.matrix_2x2();
706        let t_mat = Gate::T.matrix_2x2();
707        let zero = Complex64::new(0.0, 0.0);
708        let mut fused = [[zero; 2]; 2];
709        for i in 0..2 {
710            for j in 0..2 {
711                for k in 0..2 {
712                    fused[i][j] += t_mat[i][k] * h_mat[k][j];
713                }
714            }
715        }
716
717        let mut c1 = Circuit::new(1, 0);
718        c1.add_gate(Gate::H, &[0]);
719        c1.add_gate(Gate::T, &[0]);
720        let p1 = run_sparse_probs(&c1);
721
722        let mut c2 = Circuit::new(1, 0);
723        c2.add_gate(Gate::Fused(Box::new(fused)), &[0]);
724        let p2 = run_sparse_probs(&c2);
725
726        for (a, b) in p1.iter().zip(p2.iter()) {
727            assert!((a - b).abs() < EPS);
728        }
729    }
730
731    #[test]
732    fn test_ghz_4_sparse() {
733        let mut c = Circuit::new(4, 0);
734        c.add_gate(Gate::H, &[0]);
735        for i in 0..3 {
736            c.add_gate(Gate::Cx, &[i, i + 1]);
737        }
738        let b = run_sparse(&c);
739        assert_eq!(b.state.len(), 2);
740        assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
741        assert!((b.state[&15].norm_sqr() - 0.5).abs() < EPS);
742    }
743
744    #[test]
745    fn test_cu_phase_applies_phase() {
746        let mut c = Circuit::new(2, 0);
747        c.add_gate(Gate::X, &[0]);
748        c.add_gate(Gate::X, &[1]);
749        c.add_gate(Gate::cphase(std::f64::consts::FRAC_PI_4), &[0, 1]);
750        let b = run_sparse(&c);
751        assert_eq!(b.state.len(), 1);
752        let expected = Complex64::from_polar(1.0, std::f64::consts::FRAC_PI_4);
753        assert!((b.state[&3] - expected).norm() < EPS);
754    }
755
756    #[test]
757    fn test_cu_phase_no_action_control_zero() {
758        let mut c = Circuit::new(2, 0);
759        c.add_gate(Gate::H, &[1]);
760        c.add_gate(Gate::cphase(1.0), &[0, 1]);
761        let b = run_sparse(&c);
762        let h = 1.0 / 2.0_f64.sqrt();
763        assert!((b.state[&0].re - h).abs() < EPS);
764        assert!((b.state[&2].re - h).abs() < EPS);
765        assert!(!b.state.contains_key(&1));
766        assert!(!b.state.contains_key(&3));
767    }
768
769    #[test]
770    fn test_cu_phase_matches_cz() {
771        let mut c1 = Circuit::new(2, 0);
772        c1.add_gate(Gate::H, &[0]);
773        c1.add_gate(Gate::H, &[1]);
774        c1.add_gate(Gate::cphase(std::f64::consts::PI), &[0, 1]);
775
776        let mut c2 = Circuit::new(2, 0);
777        c2.add_gate(Gate::H, &[0]);
778        c2.add_gate(Gate::H, &[1]);
779        c2.add_gate(Gate::Cz, &[0, 1]);
780
781        let b1 = run_sparse(&c1);
782        let b2 = run_sparse(&c2);
783
784        for (&idx, &amp1) in &b1.state {
785            let amp2 = b2
786                .state
787                .get(&idx)
788                .copied()
789                .unwrap_or(Complex64::new(0.0, 0.0));
790            assert!((amp1 - amp2).norm() < EPS, "mismatch at idx {idx}");
791        }
792    }
793
794    #[test]
795    fn test_batch_phase_matches_individual() {
796        use crate::gates::BatchPhaseData;
797        use smallvec::smallvec;
798
799        let phase1 = Complex64::from_polar(1.0, 0.5);
800        let phase2 = Complex64::from_polar(1.0, 1.2);
801
802        let mut c1 = Circuit::new(3, 0);
803        c1.add_gate(Gate::H, &[0]);
804        c1.add_gate(Gate::H, &[1]);
805        c1.add_gate(Gate::H, &[2]);
806        c1.add_gate(Gate::cphase(0.5), &[0, 1]);
807        c1.add_gate(Gate::cphase(1.2), &[0, 2]);
808        let p1 = run_sparse_probs(&c1);
809
810        let mut c2 = Circuit::new(3, 0);
811        c2.add_gate(Gate::H, &[0]);
812        c2.add_gate(Gate::H, &[1]);
813        c2.add_gate(Gate::H, &[2]);
814        c2.add_gate(
815            Gate::BatchPhase(Box::new(BatchPhaseData {
816                phases: smallvec![(1, phase1), (2, phase2)],
817            })),
818            &[0, 1, 2],
819        );
820        let p2 = run_sparse_probs(&c2);
821
822        for (a, b) in p1.iter().zip(p2.iter()) {
823            assert!((a - b).abs() < EPS, "probs mismatch: {a} vs {b}");
824        }
825    }
826}