1use std::collections::HashMap;
19
20use num_complex::Complex64;
21use rand::Rng;
22use rand::SeedableRng;
23use rand_chacha::ChaCha8Rng;
24
25#[cfg(feature = "parallel")]
26use rayon::prelude::*;
27
/// Minimum number of stored amplitudes before probability sums over the state
/// map (used by reset and measurement) are computed with a rayon parallel
/// iterator instead of a sequential one.
#[cfg(feature = "parallel")]
const MIN_STATES_FOR_PAR: usize = 4096;
30
31use crate::backend::{
32 dense_probability_len, dense_statevector_len, is_phase_one, reserve_dense_output, Backend,
33};
34use crate::circuit::Instruction;
35use crate::error::Result;
36use crate::gates::{DiagEntry, Gate};
37
/// Default pruning threshold. It is compared against each amplitude's squared
/// norm, so amplitudes with magnitude below sqrt(1e-16) = 1e-8 are dropped.
const DEFAULT_EPSILON: f64 = 1e-16;
39
/// Sparse state-vector simulator backend.
///
/// The quantum state is stored as a map from basis-state index to amplitude.
/// Basis states with no entry have amplitude zero. Gate kernels that can create
/// near-zero entries call `prune`, which drops amplitudes whose squared norm is
/// below `epsilon`.
pub struct SparseBackend {
    /// Number of qubits set by the most recent `init`.
    num_qubits: usize,
    /// Basis-state index -> amplitude for every stored (non-pruned) basis state.
    state: HashMap<usize, Complex64>,
    /// Scratch map reused by gate kernels to build the next state before it is
    /// swapped into `state`.
    swap_buf: HashMap<usize, Complex64>,
    /// Classical register; measurement outcomes are written here.
    classical_bits: Vec<bool>,
    /// Seeded RNG used to sample measurement outcomes.
    rng: ChaCha8Rng,
    /// Squared-norm threshold below which `prune` removes an amplitude.
    epsilon: f64,
}
49
50impl SparseBackend {
51 pub fn new(seed: u64) -> Self {
53 Self {
54 num_qubits: 0,
55 state: HashMap::new(),
56 swap_buf: HashMap::new(),
57 classical_bits: Vec::new(),
58 rng: ChaCha8Rng::seed_from_u64(seed),
59 epsilon: DEFAULT_EPSILON,
60 }
61 }
62
63 #[inline(always)]
64 fn prune(&mut self) {
65 let eps = self.epsilon;
66 self.state.retain(|_, amp| amp.norm_sqr() >= eps);
67 }
68
69 #[inline(always)]
70 fn apply_single_qubit(&mut self, target: usize, mat: [[Complex64; 2]; 2]) {
71 let mask = 1usize << target;
72 let zero = Complex64::new(0.0, 0.0);
73 self.swap_buf.clear();
74 self.swap_buf.reserve(self.state.len() * 2);
75
76 for (&idx, &) in &self.state {
77 let bit = (idx >> target) & 1;
78 let partner = idx ^ mask;
79
80 *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
81 *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
82 }
83
84 std::mem::swap(&mut self.state, &mut self.swap_buf);
85 self.prune();
86 }
87
88 #[inline(always)]
90 fn apply_cx(&mut self, control: usize, target: usize) {
91 let ctrl_mask = 1usize << control;
92 let tgt_mask = 1usize << target;
93 self.swap_buf.clear();
94 self.swap_buf.reserve(self.state.len());
95 self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
96 if idx & ctrl_mask != 0 {
97 (idx ^ tgt_mask, amp)
98 } else {
99 (idx, amp)
100 }
101 }));
102 std::mem::swap(&mut self.state, &mut self.swap_buf);
103 }
104
105 #[inline(always)]
106 fn apply_cz(&mut self, q0: usize, q1: usize) {
107 let mask0 = 1usize << q0;
108 let mask1 = 1usize << q1;
109 for (&idx, amp) in self.state.iter_mut() {
110 if idx & mask0 != 0 && idx & mask1 != 0 {
111 *amp = -*amp;
112 }
113 }
114 }
115
116 #[inline(always)]
117 fn apply_swap(&mut self, q0: usize, q1: usize) {
118 let m0 = 1usize << q0;
119 let m1 = 1usize << q1;
120 self.swap_buf.clear();
121 self.swap_buf.reserve(self.state.len());
122 self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
123 let bit0 = (idx >> q0) & 1;
124 let bit1 = (idx >> q1) & 1;
125 if bit0 != bit1 {
126 (idx ^ m0 ^ m1, amp)
127 } else {
128 (idx, amp)
129 }
130 }));
131 std::mem::swap(&mut self.state, &mut self.swap_buf);
132 }
133
134 #[inline(always)]
135 fn apply_cu(&mut self, control: usize, target: usize, mat: [[Complex64; 2]; 2]) {
136 let ctrl_mask = 1usize << control;
137 let tgt_mask = 1usize << target;
138 let zero = Complex64::new(0.0, 0.0);
139 self.swap_buf.clear();
140 self.swap_buf.reserve(self.state.len() * 2);
141
142 for (&idx, &) in &self.state {
143 if idx & ctrl_mask == 0 {
144 *self.swap_buf.entry(idx).or_insert(zero) += amp;
145 } else {
146 let bit = (idx >> target) & 1;
147 let partner = idx ^ tgt_mask;
148 *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
149 *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
150 }
151 }
152
153 std::mem::swap(&mut self.state, &mut self.swap_buf);
154 self.prune();
155 }
156
157 #[inline(always)]
158 fn apply_mcu(&mut self, controls: &[usize], target: usize, mat: [[Complex64; 2]; 2]) {
159 let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
160 let tgt_mask = 1usize << target;
161 let zero = Complex64::new(0.0, 0.0);
162 self.swap_buf.clear();
163 self.swap_buf.reserve(self.state.len() * 2);
164
165 for (&idx, &) in &self.state {
166 if idx & ctrl_mask != ctrl_mask {
167 *self.swap_buf.entry(idx).or_insert(zero) += amp;
168 } else {
169 let bit = (idx >> target) & 1;
170 let partner = idx ^ tgt_mask;
171 *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
172 *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
173 }
174 }
175
176 std::mem::swap(&mut self.state, &mut self.swap_buf);
177 self.prune();
178 }
179
180 #[inline(always)]
181 fn apply_cu_phase(&mut self, control: usize, target: usize, phase: Complex64) {
182 let ctrl_mask = 1usize << control;
183 let tgt_mask = 1usize << target;
184 for (&idx, amp) in self.state.iter_mut() {
185 if idx & ctrl_mask != 0 && idx & tgt_mask != 0 {
186 *amp *= phase;
187 }
188 }
189 }
190
191 #[inline(always)]
192 fn apply_mcu_phase(&mut self, controls: &[usize], target: usize, phase: Complex64) {
193 let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
194 let tgt_mask = 1usize << target;
195 for (&idx, amp) in self.state.iter_mut() {
196 if idx & ctrl_mask == ctrl_mask && idx & tgt_mask != 0 {
197 *amp *= phase;
198 }
199 }
200 }
201
202 #[inline(always)]
203 fn apply_rzz(&mut self, q0: usize, q1: usize, theta: f64) {
204 let phase_same = Complex64::from_polar(1.0, -theta / 2.0);
205 let phase_diff = Complex64::from_polar(1.0, theta / 2.0);
206 for (idx, amp) in self.state.iter_mut() {
207 let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
208 *amp *= if parity == 0 { phase_same } else { phase_diff };
209 }
210 }
211
212 fn apply_batch_phase(&mut self, control: usize, phases: &[(usize, Complex64)]) {
213 let ctrl_mask = 1usize << control;
214 let one = Complex64::new(1.0, 0.0);
215 for (&idx, amp) in self.state.iter_mut() {
216 if idx & ctrl_mask == 0 {
217 continue;
218 }
219 let mut combined = one;
220 for &(target, phase) in phases {
221 if idx & (1usize << target) != 0 {
222 combined *= phase;
223 }
224 }
225 if !is_phase_one(combined) {
226 *amp *= combined;
227 }
228 }
229 }
230
231 fn apply_fused_2q(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
232 let mask0 = 1usize << q0;
233 let mask1 = 1usize << q1;
234 let zero = Complex64::new(0.0, 0.0);
235 self.swap_buf.clear();
236 self.swap_buf.reserve(self.state.len() * 2);
237
238 for (&idx, &) in &self.state {
239 let bit0 = (idx >> q0) & 1;
240 let bit1 = (idx >> q1) & 1;
241 let row = bit0 * 2 + bit1;
242 let base = idx & !(mask0 | mask1);
243
244 for (col, mat_row) in mat.iter().enumerate() {
245 let coeff = mat_row[row];
246 if coeff == zero {
247 continue;
248 }
249 let col_bit0 = (col >> 1) & 1;
250 let col_bit1 = col & 1;
251 let dest = base | (col_bit0 << q0) | (col_bit1 << q1);
252 *self.swap_buf.entry(dest).or_insert(zero) += coeff * amp;
253 }
254 }
255
256 std::mem::swap(&mut self.state, &mut self.swap_buf);
257 self.prune();
258 }
259
260 fn apply_reset(&mut self, qubit: usize) {
261 let mask = 1usize << qubit;
262
263 #[cfg(feature = "parallel")]
264 let prob_zero: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
265 self.state
266 .par_iter()
267 .filter(|(&idx, _)| idx & mask == 0)
268 .map(|(_, amp)| amp.norm_sqr())
269 .sum()
270 } else {
271 self.state
272 .iter()
273 .filter(|(&idx, _)| idx & mask == 0)
274 .map(|(_, amp)| amp.norm_sqr())
275 .sum()
276 };
277
278 #[cfg(not(feature = "parallel"))]
279 let prob_zero: f64 = self
280 .state
281 .iter()
282 .filter(|(&idx, _)| idx & mask == 0)
283 .map(|(_, amp)| amp.norm_sqr())
284 .sum();
285
286 if prob_zero > 0.0 {
287 let inv_norm = 1.0 / prob_zero.sqrt();
288 self.state.retain(|&idx, amp| {
289 if idx & mask == 0 {
290 *amp *= inv_norm;
291 true
292 } else {
293 false
294 }
295 });
296 } else {
297 self.state.clear();
298 self.state.insert(0, Complex64::new(1.0, 0.0));
299 }
300 }
301
302 fn apply_measure(&mut self, qubit: usize, classical_bit: usize) {
303 let mask = 1usize << qubit;
304
305 #[cfg(feature = "parallel")]
306 let prob_one: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
307 self.state
308 .par_iter()
309 .filter(|(&idx, _)| idx & mask != 0)
310 .map(|(_, amp)| amp.norm_sqr())
311 .sum()
312 } else {
313 self.state
314 .iter()
315 .filter(|(&idx, _)| idx & mask != 0)
316 .map(|(_, amp)| amp.norm_sqr())
317 .sum()
318 };
319
320 #[cfg(not(feature = "parallel"))]
321 let prob_one: f64 = self
322 .state
323 .iter()
324 .filter(|(&idx, _)| idx & mask != 0)
325 .map(|(_, amp)| amp.norm_sqr())
326 .sum();
327
328 let outcome = self.rng.random::<f64>() < prob_one;
329 self.classical_bits[classical_bit] = outcome;
330
331 let inv_norm = crate::backend::measurement_inv_norm(outcome, prob_one);
332
333 self.state.retain(|&idx, amp| {
334 let matches = (idx & mask != 0) == outcome;
335 if matches {
336 *amp *= inv_norm;
337 }
338 matches
339 });
340 }
341
342 fn dispatch_gate(&mut self, gate: &Gate, targets: &[usize]) {
343 match gate {
344 Gate::Rzz(theta) => {
345 self.apply_rzz(targets[0], targets[1], *theta);
346 }
347 Gate::Cx => {
348 self.apply_cx(targets[0], targets[1]);
349 }
350 Gate::Cz => {
351 self.apply_cz(targets[0], targets[1]);
352 }
353 Gate::Swap => {
354 self.apply_swap(targets[0], targets[1]);
355 }
356 Gate::Cu(mat) => {
357 if let Some(phase) = gate.controlled_phase() {
358 self.apply_cu_phase(targets[0], targets[1], phase);
359 } else {
360 self.apply_cu(targets[0], targets[1], **mat);
361 }
362 }
363 Gate::Mcu(data) => {
364 let num_ctrl = data.num_controls as usize;
365 if let Some(phase) = gate.controlled_phase() {
366 self.apply_mcu_phase(&targets[..num_ctrl], targets[num_ctrl], phase);
367 } else {
368 self.apply_mcu(&targets[..num_ctrl], targets[num_ctrl], data.mat);
369 }
370 }
371 Gate::BatchPhase(data) => {
372 self.apply_batch_phase(targets[0], &data.phases);
373 }
374 Gate::BatchRzz(data) => {
375 for &(q0, q1, theta) in &data.edges {
376 self.apply_rzz(q0, q1, theta);
377 }
378 }
379 Gate::DiagonalBatch(data) => {
380 for entry in &data.entries {
381 match entry {
382 DiagEntry::Phase1q { qubit, d0, d1 } => {
383 let mask = 1usize << qubit;
384 for (idx, amp) in self.state.iter_mut() {
385 if (*idx & mask) != 0 {
386 *amp *= d1;
387 } else {
388 *amp *= d0;
389 }
390 }
391 }
392 DiagEntry::Phase2q { q0, q1, phase } => {
393 let mask = (1usize << q0) | (1usize << q1);
394 for (idx, amp) in self.state.iter_mut() {
395 if (*idx & mask) == mask {
396 *amp *= phase;
397 }
398 }
399 }
400 DiagEntry::Parity2q { q0, q1, same, diff } => {
401 for (idx, amp) in self.state.iter_mut() {
402 let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
403 *amp *= if parity == 0 { *same } else { *diff };
404 }
405 }
406 }
407 }
408 }
409 Gate::MultiFused(data) => {
410 for &(target, mat) in &data.gates {
411 self.apply_single_qubit(target, mat);
412 }
413 }
414 Gate::Fused2q(mat) => {
415 self.apply_fused_2q(targets[0], targets[1], mat);
416 }
417 Gate::Multi2q(data) => {
418 for &(q0, q1, ref mat) in &data.gates {
419 self.apply_fused_2q(q0, q1, mat);
420 }
421 }
422 other => {
423 debug_assert!(
424 targets.len() == 1,
425 "sparse dispatch_gate: unexpected multi-qubit gate {:?}",
426 other
427 );
428 let mat = other.matrix_2x2();
429 self.apply_single_qubit(targets[0], mat);
430 }
431 }
432 }
433}
434
435impl Backend for SparseBackend {
436 fn name(&self) -> &'static str {
437 "sparse"
438 }
439
440 fn init(&mut self, num_qubits: usize, num_classical_bits: usize) -> Result<()> {
441 self.num_qubits = num_qubits;
442 self.state.clear();
443 self.state.insert(0, Complex64::new(1.0, 0.0));
444 self.classical_bits = vec![false; num_classical_bits];
445 Ok(())
446 }
447
448 fn apply(&mut self, instruction: &Instruction) -> Result<()> {
449 match instruction {
450 Instruction::Gate { gate, targets } => self.dispatch_gate(gate, targets),
451 Instruction::Measure {
452 qubit,
453 classical_bit,
454 } => {
455 self.apply_measure(*qubit, *classical_bit);
456 }
457 Instruction::Reset { qubit } => {
458 self.apply_reset(*qubit);
459 }
460 Instruction::Barrier { .. } => {}
461 Instruction::Conditional {
462 condition,
463 gate,
464 targets,
465 } => {
466 if condition.evaluate(&self.classical_bits) {
467 self.dispatch_gate(gate, targets);
468 }
469 }
470 }
471 Ok(())
472 }
473
474 fn reset(&mut self, qubit: usize) -> Result<()> {
475 self.apply_reset(qubit);
476 Ok(())
477 }
478
479 fn apply_1q_matrix(&mut self, qubit: usize, matrix: &[[Complex64; 2]; 2]) -> Result<()> {
480 self.apply_single_qubit(qubit, *matrix);
481 Ok(())
482 }
483
484 fn reduced_density_matrix_1q(&self, qubit: usize) -> Result<[[Complex64; 2]; 2]> {
485 let mask = 1usize << qubit;
486 let mut p0 = 0.0f64;
487 let mut p1 = 0.0f64;
488 let mut r = Complex64::new(0.0, 0.0);
489
490 for (&idx, &) in &self.state {
491 if idx & mask == 0 {
492 p0 += amp.norm_sqr();
493 if let Some(&_one) = self.state.get(&(idx | mask)) {
494 r += amp_one * amp.conj();
495 }
496 } else {
497 p1 += amp.norm_sqr();
498 }
499 }
500
501 Ok([
502 [Complex64::new(p0, 0.0), r.conj()],
503 [r, Complex64::new(p1, 0.0)],
504 ])
505 }
506
507 fn classical_results(&self) -> &[bool] {
508 &self.classical_bits
509 }
510
511 fn probabilities(&self) -> Result<Vec<f64>> {
512 let dim = dense_probability_len(self.name(), self.num_qubits)?;
513 let mut probs = Vec::new();
514 reserve_dense_output(&mut probs, dim, self.name(), "probabilities")?;
515 probs.resize(dim, 0.0f64);
516 for (&idx, amp) in &self.state {
517 probs[idx] = amp.norm_sqr();
518 }
519 Ok(probs)
520 }
521
522 fn num_qubits(&self) -> usize {
523 self.num_qubits
524 }
525
526 fn export_statevector(&self) -> Result<Vec<Complex64>> {
527 let dim = dense_statevector_len(self.name(), "statevector export", self.num_qubits)?;
528 let mut sv = Vec::new();
529 reserve_dense_output(&mut sv, dim, self.name(), "statevector export")?;
530 sv.resize(dim, Complex64::new(0.0, 0.0));
531 for (&idx, &) in &self.state {
532 sv[idx] = amp;
533 }
534 Ok(sv)
535 }
536}
537
538#[cfg(test)]
539mod tests {
540 use super::*;
541 use crate::circuit::Circuit;
542 use crate::sim;
543
544 const EPS: f64 = 1e-12;
545
546 fn run_sparse(circuit: &Circuit) -> SparseBackend {
547 let mut b = SparseBackend::new(42);
548 sim::run_on(&mut b, circuit).unwrap();
549 b
550 }
551
552 fn run_sparse_probs(circuit: &Circuit) -> Vec<f64> {
553 let b = run_sparse(circuit);
554 b.probabilities().unwrap()
555 }
556
557 #[test]
558 fn test_init_zero_state() {
559 let mut b = SparseBackend::new(42);
560 b.init(3, 0).unwrap();
561 assert_eq!(b.state.len(), 1);
562 assert!((b.state[&0].re - 1.0).abs() < EPS);
563 }
564
565 #[test]
566 fn test_x_gate() {
567 let mut c = Circuit::new(1, 0);
568 c.add_gate(Gate::X, &[0]);
569 let b = run_sparse(&c);
570 assert_eq!(b.state.len(), 1);
571 assert!(b.state.contains_key(&1));
572 assert!((b.state[&1].norm() - 1.0).abs() < EPS);
573 }
574
575 #[test]
576 fn test_h_creates_superposition() {
577 let mut c = Circuit::new(1, 0);
578 c.add_gate(Gate::H, &[0]);
579 let b = run_sparse(&c);
580 assert_eq!(b.state.len(), 2);
581 assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
582 assert!((b.state[&1].norm_sqr() - 0.5).abs() < EPS);
583 }
584
585 #[test]
586 fn test_hh_is_identity() {
587 let mut c = Circuit::new(1, 0);
588 c.add_gate(Gate::H, &[0]);
589 c.add_gate(Gate::H, &[0]);
590 let b = run_sparse(&c);
591 assert_eq!(b.state.len(), 1);
592 assert!((b.state[&0].re - 1.0).abs() < EPS);
593 }
594
595 #[test]
596 fn test_cx_bell_state() {
597 let mut c = Circuit::new(2, 0);
598 c.add_gate(Gate::H, &[0]);
599 c.add_gate(Gate::Cx, &[0, 1]);
600 let b = run_sparse(&c);
601 assert_eq!(b.state.len(), 2);
602 assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
603 assert!((b.state[&3].norm_sqr() - 0.5).abs() < EPS);
604 }
605
606 #[test]
607 fn test_cz_phase() {
608 let mut c = Circuit::new(2, 0);
609 c.add_gate(Gate::X, &[0]);
610 c.add_gate(Gate::X, &[1]);
611 c.add_gate(Gate::Cz, &[0, 1]);
612 let b = run_sparse(&c);
613 assert_eq!(b.state.len(), 1);
614 assert!((b.state[&3].re - (-1.0)).abs() < EPS);
615 }
616
617 #[test]
618 fn test_swap() {
619 let mut c = Circuit::new(2, 0);
620 c.add_gate(Gate::X, &[1]);
621 c.add_gate(Gate::Swap, &[0, 1]);
622 let b = run_sparse(&c);
623 assert_eq!(b.state.len(), 1);
624 assert!(b.state.contains_key(&1));
625 }
626
627 #[test]
628 fn test_rx_pi() {
629 let mut c = Circuit::new(1, 0);
630 c.add_gate(Gate::Rx(std::f64::consts::PI), &[0]);
631 let probs = run_sparse_probs(&c);
632 assert!(probs[0].abs() < EPS);
633 assert!((probs[1] - 1.0).abs() < EPS);
634 }
635
636 #[test]
637 fn test_rz_preserves_sparsity() {
638 let mut c = Circuit::new(1, 0);
639 c.add_gate(Gate::Rz(1.234), &[0]);
640 let b = run_sparse(&c);
641 assert_eq!(b.state.len(), 1);
642 assert!((b.state[&0].norm() - 1.0).abs() < EPS);
643 }
644
645 #[test]
646 fn test_measure_collapses() {
647 let mut c = Circuit::new(1, 1);
648 c.add_gate(Gate::H, &[0]);
649 c.add_measure(0, 0);
650 let b = run_sparse(&c);
651 assert_eq!(b.state.len(), 1);
652 let outcome = b.classical_results()[0];
653 if outcome {
654 assert!(b.state.contains_key(&1));
655 } else {
656 assert!(b.state.contains_key(&0));
657 }
658 }
659
660 #[test]
661 fn test_measure_deterministic() {
662 let mut c = Circuit::new(1, 1);
663 c.add_gate(Gate::H, &[0]);
664 c.add_measure(0, 0);
665
666 let b1 = run_sparse(&c);
667 let b2 = run_sparse(&c);
668 assert_eq!(b1.classical_results()[0], b2.classical_results()[0]);
669 }
670
671 #[test]
672 fn test_probs_bell() {
673 let mut c = Circuit::new(2, 0);
674 c.add_gate(Gate::H, &[0]);
675 c.add_gate(Gate::Cx, &[0, 1]);
676 let probs = run_sparse_probs(&c);
677 assert!((probs[0] - 0.5).abs() < EPS);
678 assert!(probs[1].abs() < EPS);
679 assert!(probs[2].abs() < EPS);
680 assert!((probs[3] - 0.5).abs() < EPS);
681 }
682
683 #[test]
684 fn test_probs_zero_state() {
685 let c = Circuit::new(3, 0);
686 let probs = run_sparse_probs(&c);
687 assert!((probs[0] - 1.0).abs() < EPS);
688 let rest: f64 = probs[1..].iter().sum();
689 assert!(rest.abs() < EPS);
690 }
691
692 #[test]
693 fn test_pruning() {
694 let mut b = SparseBackend::new(42);
695 b.init(1, 0).unwrap();
696 b.state.insert(1, Complex64::new(1e-20, 0.0));
697 assert_eq!(b.state.len(), 2);
698 b.prune();
699 assert_eq!(b.state.len(), 1);
700 assert!(b.state.contains_key(&0));
701 }
702
703 #[test]
704 fn test_fused_gate() {
705 let h_mat = Gate::H.matrix_2x2();
706 let t_mat = Gate::T.matrix_2x2();
707 let zero = Complex64::new(0.0, 0.0);
708 let mut fused = [[zero; 2]; 2];
709 for i in 0..2 {
710 for j in 0..2 {
711 for k in 0..2 {
712 fused[i][j] += t_mat[i][k] * h_mat[k][j];
713 }
714 }
715 }
716
717 let mut c1 = Circuit::new(1, 0);
718 c1.add_gate(Gate::H, &[0]);
719 c1.add_gate(Gate::T, &[0]);
720 let p1 = run_sparse_probs(&c1);
721
722 let mut c2 = Circuit::new(1, 0);
723 c2.add_gate(Gate::Fused(Box::new(fused)), &[0]);
724 let p2 = run_sparse_probs(&c2);
725
726 for (a, b) in p1.iter().zip(p2.iter()) {
727 assert!((a - b).abs() < EPS);
728 }
729 }
730
731 #[test]
732 fn test_ghz_4_sparse() {
733 let mut c = Circuit::new(4, 0);
734 c.add_gate(Gate::H, &[0]);
735 for i in 0..3 {
736 c.add_gate(Gate::Cx, &[i, i + 1]);
737 }
738 let b = run_sparse(&c);
739 assert_eq!(b.state.len(), 2);
740 assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
741 assert!((b.state[&15].norm_sqr() - 0.5).abs() < EPS);
742 }
743
744 #[test]
745 fn test_cu_phase_applies_phase() {
746 let mut c = Circuit::new(2, 0);
747 c.add_gate(Gate::X, &[0]);
748 c.add_gate(Gate::X, &[1]);
749 c.add_gate(Gate::cphase(std::f64::consts::FRAC_PI_4), &[0, 1]);
750 let b = run_sparse(&c);
751 assert_eq!(b.state.len(), 1);
752 let expected = Complex64::from_polar(1.0, std::f64::consts::FRAC_PI_4);
753 assert!((b.state[&3] - expected).norm() < EPS);
754 }
755
756 #[test]
757 fn test_cu_phase_no_action_control_zero() {
758 let mut c = Circuit::new(2, 0);
759 c.add_gate(Gate::H, &[1]);
760 c.add_gate(Gate::cphase(1.0), &[0, 1]);
761 let b = run_sparse(&c);
762 let h = 1.0 / 2.0_f64.sqrt();
763 assert!((b.state[&0].re - h).abs() < EPS);
764 assert!((b.state[&2].re - h).abs() < EPS);
765 assert!(!b.state.contains_key(&1));
766 assert!(!b.state.contains_key(&3));
767 }
768
769 #[test]
770 fn test_cu_phase_matches_cz() {
771 let mut c1 = Circuit::new(2, 0);
772 c1.add_gate(Gate::H, &[0]);
773 c1.add_gate(Gate::H, &[1]);
774 c1.add_gate(Gate::cphase(std::f64::consts::PI), &[0, 1]);
775
776 let mut c2 = Circuit::new(2, 0);
777 c2.add_gate(Gate::H, &[0]);
778 c2.add_gate(Gate::H, &[1]);
779 c2.add_gate(Gate::Cz, &[0, 1]);
780
781 let b1 = run_sparse(&c1);
782 let b2 = run_sparse(&c2);
783
784 for (&idx, &1) in &b1.state {
785 let amp2 = b2
786 .state
787 .get(&idx)
788 .copied()
789 .unwrap_or(Complex64::new(0.0, 0.0));
790 assert!((amp1 - amp2).norm() < EPS, "mismatch at idx {idx}");
791 }
792 }
793
794 #[test]
795 fn test_batch_phase_matches_individual() {
796 use crate::gates::BatchPhaseData;
797 use smallvec::smallvec;
798
799 let phase1 = Complex64::from_polar(1.0, 0.5);
800 let phase2 = Complex64::from_polar(1.0, 1.2);
801
802 let mut c1 = Circuit::new(3, 0);
803 c1.add_gate(Gate::H, &[0]);
804 c1.add_gate(Gate::H, &[1]);
805 c1.add_gate(Gate::H, &[2]);
806 c1.add_gate(Gate::cphase(0.5), &[0, 1]);
807 c1.add_gate(Gate::cphase(1.2), &[0, 2]);
808 let p1 = run_sparse_probs(&c1);
809
810 let mut c2 = Circuit::new(3, 0);
811 c2.add_gate(Gate::H, &[0]);
812 c2.add_gate(Gate::H, &[1]);
813 c2.add_gate(Gate::H, &[2]);
814 c2.add_gate(
815 Gate::BatchPhase(Box::new(BatchPhaseData {
816 phases: smallvec![(1, phase1), (2, phase2)],
817 })),
818 &[0, 1, 2],
819 );
820 let p2 = run_sparse_probs(&c2);
821
822 for (a, b) in p1.iter().zip(p2.iter()) {
823 assert!((a - b).abs() < EPS, "probs mismatch: {a} vs {b}");
824 }
825 }
826}