1use std::collections::HashMap;
19
20use num_complex::Complex64;
21use rand::Rng;
22use rand::SeedableRng;
23use rand_chacha::ChaCha8Rng;
24
25#[cfg(feature = "parallel")]
26use rayon::prelude::*;
27
/// Minimum number of stored amplitudes before the probability sums used by
/// measurement and reset switch from sequential to rayon parallel iteration.
#[cfg(feature = "parallel")]
const MIN_STATES_FOR_PAR: usize = 1 << 12;
30
31use crate::backend::{is_phase_one, Backend, MAX_PROB_QUBITS};
32use crate::circuit::Instruction;
33use crate::error::{PrismError, Result};
34use crate::gates::{DiagEntry, Gate};
35
/// Squared-magnitude pruning threshold: stored amplitudes with
/// `norm_sqr() < DEFAULT_EPSILON` are dropped from the sparse state.
const DEFAULT_EPSILON: f64 = 1.0e-16;
37
38pub struct SparseBackend {
40 num_qubits: usize,
41 state: HashMap<usize, Complex64>,
42 swap_buf: HashMap<usize, Complex64>,
43 classical_bits: Vec<bool>,
44 rng: ChaCha8Rng,
45 epsilon: f64,
46}
47
48impl SparseBackend {
49 pub fn new(seed: u64) -> Self {
51 Self {
52 num_qubits: 0,
53 state: HashMap::new(),
54 swap_buf: HashMap::new(),
55 classical_bits: Vec::new(),
56 rng: ChaCha8Rng::seed_from_u64(seed),
57 epsilon: DEFAULT_EPSILON,
58 }
59 }
60
61 #[inline(always)]
62 fn prune(&mut self) {
63 let eps = self.epsilon;
64 self.state.retain(|_, amp| amp.norm_sqr() >= eps);
65 }
66
67 #[inline(always)]
68 fn apply_single_qubit(&mut self, target: usize, mat: [[Complex64; 2]; 2]) {
69 let mask = 1usize << target;
70 let zero = Complex64::new(0.0, 0.0);
71 self.swap_buf.clear();
72 self.swap_buf.reserve(self.state.len() * 2);
73
74 for (&idx, &) in &self.state {
75 let bit = (idx >> target) & 1;
76 let partner = idx ^ mask;
77
78 *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
79 *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
80 }
81
82 std::mem::swap(&mut self.state, &mut self.swap_buf);
83 self.prune();
84 }
85
86 #[inline(always)]
88 fn apply_cx(&mut self, control: usize, target: usize) {
89 let ctrl_mask = 1usize << control;
90 let tgt_mask = 1usize << target;
91 self.swap_buf.clear();
92 self.swap_buf.reserve(self.state.len());
93 self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
94 if idx & ctrl_mask != 0 {
95 (idx ^ tgt_mask, amp)
96 } else {
97 (idx, amp)
98 }
99 }));
100 std::mem::swap(&mut self.state, &mut self.swap_buf);
101 }
102
103 #[inline(always)]
104 fn apply_cz(&mut self, q0: usize, q1: usize) {
105 let mask0 = 1usize << q0;
106 let mask1 = 1usize << q1;
107 for (&idx, amp) in self.state.iter_mut() {
108 if idx & mask0 != 0 && idx & mask1 != 0 {
109 *amp = -*amp;
110 }
111 }
112 }
113
114 #[inline(always)]
115 fn apply_swap(&mut self, q0: usize, q1: usize) {
116 let m0 = 1usize << q0;
117 let m1 = 1usize << q1;
118 self.swap_buf.clear();
119 self.swap_buf.reserve(self.state.len());
120 self.swap_buf.extend(self.state.drain().map(|(idx, amp)| {
121 let bit0 = (idx >> q0) & 1;
122 let bit1 = (idx >> q1) & 1;
123 if bit0 != bit1 {
124 (idx ^ m0 ^ m1, amp)
125 } else {
126 (idx, amp)
127 }
128 }));
129 std::mem::swap(&mut self.state, &mut self.swap_buf);
130 }
131
132 #[inline(always)]
133 fn apply_cu(&mut self, control: usize, target: usize, mat: [[Complex64; 2]; 2]) {
134 let ctrl_mask = 1usize << control;
135 let tgt_mask = 1usize << target;
136 let zero = Complex64::new(0.0, 0.0);
137 self.swap_buf.clear();
138 self.swap_buf.reserve(self.state.len() * 2);
139
140 for (&idx, &) in &self.state {
141 if idx & ctrl_mask == 0 {
142 *self.swap_buf.entry(idx).or_insert(zero) += amp;
143 } else {
144 let bit = (idx >> target) & 1;
145 let partner = idx ^ tgt_mask;
146 *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
147 *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
148 }
149 }
150
151 std::mem::swap(&mut self.state, &mut self.swap_buf);
152 self.prune();
153 }
154
155 #[inline(always)]
156 fn apply_mcu(&mut self, controls: &[usize], target: usize, mat: [[Complex64; 2]; 2]) {
157 let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
158 let tgt_mask = 1usize << target;
159 let zero = Complex64::new(0.0, 0.0);
160 self.swap_buf.clear();
161 self.swap_buf.reserve(self.state.len() * 2);
162
163 for (&idx, &) in &self.state {
164 if idx & ctrl_mask != ctrl_mask {
165 *self.swap_buf.entry(idx).or_insert(zero) += amp;
166 } else {
167 let bit = (idx >> target) & 1;
168 let partner = idx ^ tgt_mask;
169 *self.swap_buf.entry(idx).or_insert(zero) += mat[bit][bit] * amp;
170 *self.swap_buf.entry(partner).or_insert(zero) += mat[1 - bit][bit] * amp;
171 }
172 }
173
174 std::mem::swap(&mut self.state, &mut self.swap_buf);
175 self.prune();
176 }
177
178 #[inline(always)]
179 fn apply_cu_phase(&mut self, control: usize, target: usize, phase: Complex64) {
180 let ctrl_mask = 1usize << control;
181 let tgt_mask = 1usize << target;
182 for (&idx, amp) in self.state.iter_mut() {
183 if idx & ctrl_mask != 0 && idx & tgt_mask != 0 {
184 *amp *= phase;
185 }
186 }
187 }
188
189 #[inline(always)]
190 fn apply_mcu_phase(&mut self, controls: &[usize], target: usize, phase: Complex64) {
191 let ctrl_mask: usize = controls.iter().map(|&q| 1usize << q).fold(0, |a, b| a | b);
192 let tgt_mask = 1usize << target;
193 for (&idx, amp) in self.state.iter_mut() {
194 if idx & ctrl_mask == ctrl_mask && idx & tgt_mask != 0 {
195 *amp *= phase;
196 }
197 }
198 }
199
200 fn apply_batch_phase(&mut self, control: usize, phases: &[(usize, Complex64)]) {
201 let ctrl_mask = 1usize << control;
202 let one = Complex64::new(1.0, 0.0);
203 for (&idx, amp) in self.state.iter_mut() {
204 if idx & ctrl_mask == 0 {
205 continue;
206 }
207 let mut combined = one;
208 for &(target, phase) in phases {
209 if idx & (1usize << target) != 0 {
210 combined *= phase;
211 }
212 }
213 if !is_phase_one(combined) {
214 *amp *= combined;
215 }
216 }
217 }
218
219 fn apply_fused_2q(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
220 let mask0 = 1usize << q0;
221 let mask1 = 1usize << q1;
222 let zero = Complex64::new(0.0, 0.0);
223 self.swap_buf.clear();
224 self.swap_buf.reserve(self.state.len() * 2);
225
226 for (&idx, &) in &self.state {
227 let bit0 = (idx >> q0) & 1;
228 let bit1 = (idx >> q1) & 1;
229 let row = bit0 * 2 + bit1;
230 let base = idx & !(mask0 | mask1);
231
232 for (col, mat_row) in mat.iter().enumerate() {
233 let coeff = mat_row[row];
234 if coeff == zero {
235 continue;
236 }
237 let col_bit0 = (col >> 1) & 1;
238 let col_bit1 = col & 1;
239 let dest = base | (col_bit0 << q0) | (col_bit1 << q1);
240 *self.swap_buf.entry(dest).or_insert(zero) += coeff * amp;
241 }
242 }
243
244 std::mem::swap(&mut self.state, &mut self.swap_buf);
245 self.prune();
246 }
247
248 fn apply_reset(&mut self, qubit: usize) {
249 let mask = 1usize << qubit;
250
251 #[cfg(feature = "parallel")]
252 let prob_zero: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
253 self.state
254 .par_iter()
255 .filter(|(&idx, _)| idx & mask == 0)
256 .map(|(_, amp)| amp.norm_sqr())
257 .sum()
258 } else {
259 self.state
260 .iter()
261 .filter(|(&idx, _)| idx & mask == 0)
262 .map(|(_, amp)| amp.norm_sqr())
263 .sum()
264 };
265
266 #[cfg(not(feature = "parallel"))]
267 let prob_zero: f64 = self
268 .state
269 .iter()
270 .filter(|(&idx, _)| idx & mask == 0)
271 .map(|(_, amp)| amp.norm_sqr())
272 .sum();
273
274 if prob_zero > 0.0 {
275 let inv_norm = 1.0 / prob_zero.sqrt();
276 self.state.retain(|&idx, amp| {
277 if idx & mask == 0 {
278 *amp *= inv_norm;
279 true
280 } else {
281 false
282 }
283 });
284 } else {
285 self.state.clear();
286 self.state.insert(0, Complex64::new(1.0, 0.0));
287 }
288 }
289
290 fn apply_measure(&mut self, qubit: usize, classical_bit: usize) {
291 let mask = 1usize << qubit;
292
293 #[cfg(feature = "parallel")]
294 let prob_one: f64 = if self.state.len() >= MIN_STATES_FOR_PAR {
295 self.state
296 .par_iter()
297 .filter(|(&idx, _)| idx & mask != 0)
298 .map(|(_, amp)| amp.norm_sqr())
299 .sum()
300 } else {
301 self.state
302 .iter()
303 .filter(|(&idx, _)| idx & mask != 0)
304 .map(|(_, amp)| amp.norm_sqr())
305 .sum()
306 };
307
308 #[cfg(not(feature = "parallel"))]
309 let prob_one: f64 = self
310 .state
311 .iter()
312 .filter(|(&idx, _)| idx & mask != 0)
313 .map(|(_, amp)| amp.norm_sqr())
314 .sum();
315
316 let outcome = self.rng.random::<f64>() < prob_one;
317 self.classical_bits[classical_bit] = outcome;
318
319 let inv_norm = crate::backend::measurement_inv_norm(outcome, prob_one);
320
321 self.state.retain(|&idx, amp| {
322 let matches = (idx & mask != 0) == outcome;
323 if matches {
324 *amp *= inv_norm;
325 }
326 matches
327 });
328 }
329
330 fn dispatch_gate(&mut self, gate: &Gate, targets: &[usize]) {
331 match gate {
332 Gate::Rzz(theta) => {
333 let phase_same = Complex64::from_polar(1.0, -theta / 2.0);
334 let phase_diff = Complex64::from_polar(1.0, theta / 2.0);
335 let q0 = targets[0];
336 let q1 = targets[1];
337 for (idx, amp) in self.state.iter_mut() {
338 let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
339 *amp *= if parity == 0 { phase_same } else { phase_diff };
340 }
341 }
342 Gate::Cx => {
343 self.apply_cx(targets[0], targets[1]);
344 }
345 Gate::Cz => {
346 self.apply_cz(targets[0], targets[1]);
347 }
348 Gate::Swap => {
349 self.apply_swap(targets[0], targets[1]);
350 }
351 Gate::Cu(mat) => {
352 if let Some(phase) = gate.controlled_phase() {
353 self.apply_cu_phase(targets[0], targets[1], phase);
354 } else {
355 self.apply_cu(targets[0], targets[1], **mat);
356 }
357 }
358 Gate::Mcu(data) => {
359 let num_ctrl = data.num_controls as usize;
360 if let Some(phase) = gate.controlled_phase() {
361 self.apply_mcu_phase(&targets[..num_ctrl], targets[num_ctrl], phase);
362 } else {
363 self.apply_mcu(&targets[..num_ctrl], targets[num_ctrl], data.mat);
364 }
365 }
366 Gate::BatchPhase(data) => {
367 self.apply_batch_phase(targets[0], &data.phases);
368 }
369 Gate::BatchRzz(data) => {
370 for &(q0, q1, theta) in &data.edges {
371 let phase_same = Complex64::from_polar(1.0, -theta / 2.0);
372 let phase_diff = Complex64::from_polar(1.0, theta / 2.0);
373 for (idx, amp) in self.state.iter_mut() {
374 let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
375 *amp *= if parity == 0 { phase_same } else { phase_diff };
376 }
377 }
378 }
379 Gate::DiagonalBatch(data) => {
380 for entry in &data.entries {
381 match entry {
382 DiagEntry::Phase1q { qubit, d0, d1 } => {
383 let mask = 1usize << qubit;
384 for (idx, amp) in self.state.iter_mut() {
385 if (*idx & mask) != 0 {
386 *amp *= d1;
387 } else {
388 *amp *= d0;
389 }
390 }
391 }
392 DiagEntry::Phase2q { q0, q1, phase } => {
393 let mask = (1usize << q0) | (1usize << q1);
394 for (idx, amp) in self.state.iter_mut() {
395 if (*idx & mask) == mask {
396 *amp *= phase;
397 }
398 }
399 }
400 DiagEntry::Parity2q { q0, q1, same, diff } => {
401 for (idx, amp) in self.state.iter_mut() {
402 let parity = ((*idx >> q0) ^ (*idx >> q1)) & 1;
403 *amp *= if parity == 0 { *same } else { *diff };
404 }
405 }
406 }
407 }
408 }
409 Gate::MultiFused(data) => {
410 for &(target, mat) in &data.gates {
411 self.apply_single_qubit(target, mat);
412 }
413 }
414 Gate::Fused2q(mat) => {
415 self.apply_fused_2q(targets[0], targets[1], mat);
416 }
417 Gate::Multi2q(data) => {
418 for &(q0, q1, ref mat) in &data.gates {
419 self.apply_fused_2q(q0, q1, mat);
420 }
421 }
422 other => {
423 debug_assert!(
424 targets.len() == 1,
425 "sparse dispatch_gate: unexpected multi-qubit gate {:?}",
426 other
427 );
428 let mat = other.matrix_2x2();
429 self.apply_single_qubit(targets[0], mat);
430 }
431 }
432 }
433}
434
435impl Backend for SparseBackend {
436 fn name(&self) -> &'static str {
437 "sparse"
438 }
439
440 fn init(&mut self, num_qubits: usize, num_classical_bits: usize) -> Result<()> {
441 self.num_qubits = num_qubits;
442 self.state.clear();
443 self.state.insert(0, Complex64::new(1.0, 0.0));
444 self.classical_bits = vec![false; num_classical_bits];
445 Ok(())
446 }
447
448 fn apply(&mut self, instruction: &Instruction) -> Result<()> {
449 match instruction {
450 Instruction::Gate { gate, targets } => self.dispatch_gate(gate, targets),
451 Instruction::Measure {
452 qubit,
453 classical_bit,
454 } => {
455 self.apply_measure(*qubit, *classical_bit);
456 }
457 Instruction::Reset { qubit } => {
458 self.apply_reset(*qubit);
459 }
460 Instruction::Barrier { .. } => {}
461 Instruction::Conditional {
462 condition,
463 gate,
464 targets,
465 } => {
466 if condition.evaluate(&self.classical_bits) {
467 self.dispatch_gate(gate, targets);
468 }
469 }
470 }
471 Ok(())
472 }
473
474 fn reset(&mut self, qubit: usize) -> Result<()> {
475 self.apply_reset(qubit);
476 Ok(())
477 }
478
479 fn reduced_density_matrix_1q(&self, qubit: usize) -> Result<[[Complex64; 2]; 2]> {
480 let mask = 1usize << qubit;
481 let mut p0 = 0.0f64;
482 let mut p1 = 0.0f64;
483 let mut r = Complex64::new(0.0, 0.0);
484
485 for (&idx, &) in &self.state {
486 if idx & mask == 0 {
487 p0 += amp.norm_sqr();
488 if let Some(&_one) = self.state.get(&(idx | mask)) {
489 r += amp_one * amp.conj();
490 }
491 } else {
492 p1 += amp.norm_sqr();
493 }
494 }
495
496 Ok([
497 [Complex64::new(p0, 0.0), r.conj()],
498 [r, Complex64::new(p1, 0.0)],
499 ])
500 }
501
502 fn classical_results(&self) -> &[bool] {
503 &self.classical_bits
504 }
505
506 fn probabilities(&self) -> Result<Vec<f64>> {
507 if self.num_qubits > MAX_PROB_QUBITS {
508 return Err(PrismError::BackendUnsupported {
509 backend: self.name().to_string(),
510 operation: format!(
511 "probabilities for {} qubits (max {})",
512 self.num_qubits, MAX_PROB_QUBITS
513 ),
514 });
515 }
516 let dim = 1usize << self.num_qubits;
517 let mut probs = vec![0.0f64; dim];
518 for (&idx, amp) in &self.state {
519 probs[idx] = amp.norm_sqr();
520 }
521 Ok(probs)
522 }
523
524 fn num_qubits(&self) -> usize {
525 self.num_qubits
526 }
527
528 fn export_statevector(&self) -> Result<Vec<Complex64>> {
529 if self.num_qubits > MAX_PROB_QUBITS {
530 return Err(PrismError::BackendUnsupported {
531 backend: self.name().to_string(),
532 operation: format!(
533 "statevector export for {} qubits (max {})",
534 self.num_qubits, MAX_PROB_QUBITS
535 ),
536 });
537 }
538 let dim = 1usize << self.num_qubits;
539 let mut sv = vec![Complex64::new(0.0, 0.0); dim];
540 for (&idx, &) in &self.state {
541 sv[idx] = amp;
542 }
543 Ok(sv)
544 }
545}
546
547#[cfg(test)]
548mod tests {
549 use super::*;
550 use crate::circuit::Circuit;
551 use crate::sim;
552
553 const EPS: f64 = 1e-12;
554
555 fn run_sparse(circuit: &Circuit) -> SparseBackend {
556 let mut b = SparseBackend::new(42);
557 sim::run_on(&mut b, circuit).unwrap();
558 b
559 }
560
561 fn run_sparse_probs(circuit: &Circuit) -> Vec<f64> {
562 let b = run_sparse(circuit);
563 b.probabilities().unwrap()
564 }
565
566 #[test]
567 fn test_init_zero_state() {
568 let mut b = SparseBackend::new(42);
569 b.init(3, 0).unwrap();
570 assert_eq!(b.state.len(), 1);
571 assert!((b.state[&0].re - 1.0).abs() < EPS);
572 }
573
574 #[test]
575 fn test_x_gate() {
576 let mut c = Circuit::new(1, 0);
577 c.add_gate(Gate::X, &[0]);
578 let b = run_sparse(&c);
579 assert_eq!(b.state.len(), 1);
580 assert!(b.state.contains_key(&1));
581 assert!((b.state[&1].norm() - 1.0).abs() < EPS);
582 }
583
584 #[test]
585 fn test_h_creates_superposition() {
586 let mut c = Circuit::new(1, 0);
587 c.add_gate(Gate::H, &[0]);
588 let b = run_sparse(&c);
589 assert_eq!(b.state.len(), 2);
590 assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
591 assert!((b.state[&1].norm_sqr() - 0.5).abs() < EPS);
592 }
593
594 #[test]
595 fn test_hh_is_identity() {
596 let mut c = Circuit::new(1, 0);
597 c.add_gate(Gate::H, &[0]);
598 c.add_gate(Gate::H, &[0]);
599 let b = run_sparse(&c);
600 assert_eq!(b.state.len(), 1);
601 assert!((b.state[&0].re - 1.0).abs() < EPS);
602 }
603
604 #[test]
605 fn test_cx_bell_state() {
606 let mut c = Circuit::new(2, 0);
607 c.add_gate(Gate::H, &[0]);
608 c.add_gate(Gate::Cx, &[0, 1]);
609 let b = run_sparse(&c);
610 assert_eq!(b.state.len(), 2);
611 assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
612 assert!((b.state[&3].norm_sqr() - 0.5).abs() < EPS);
613 }
614
615 #[test]
616 fn test_cz_phase() {
617 let mut c = Circuit::new(2, 0);
618 c.add_gate(Gate::X, &[0]);
619 c.add_gate(Gate::X, &[1]);
620 c.add_gate(Gate::Cz, &[0, 1]);
621 let b = run_sparse(&c);
622 assert_eq!(b.state.len(), 1);
623 assert!((b.state[&3].re - (-1.0)).abs() < EPS);
624 }
625
626 #[test]
627 fn test_swap() {
628 let mut c = Circuit::new(2, 0);
629 c.add_gate(Gate::X, &[1]);
630 c.add_gate(Gate::Swap, &[0, 1]);
631 let b = run_sparse(&c);
632 assert_eq!(b.state.len(), 1);
633 assert!(b.state.contains_key(&1));
634 }
635
636 #[test]
637 fn test_rx_pi() {
638 let mut c = Circuit::new(1, 0);
639 c.add_gate(Gate::Rx(std::f64::consts::PI), &[0]);
640 let probs = run_sparse_probs(&c);
641 assert!(probs[0].abs() < EPS);
642 assert!((probs[1] - 1.0).abs() < EPS);
643 }
644
645 #[test]
646 fn test_rz_preserves_sparsity() {
647 let mut c = Circuit::new(1, 0);
648 c.add_gate(Gate::Rz(1.234), &[0]);
649 let b = run_sparse(&c);
650 assert_eq!(b.state.len(), 1);
651 assert!((b.state[&0].norm() - 1.0).abs() < EPS);
652 }
653
654 #[test]
655 fn test_measure_collapses() {
656 let mut c = Circuit::new(1, 1);
657 c.add_gate(Gate::H, &[0]);
658 c.add_measure(0, 0);
659 let b = run_sparse(&c);
660 assert_eq!(b.state.len(), 1);
661 let outcome = b.classical_results()[0];
662 if outcome {
663 assert!(b.state.contains_key(&1));
664 } else {
665 assert!(b.state.contains_key(&0));
666 }
667 }
668
669 #[test]
670 fn test_measure_deterministic() {
671 let mut c = Circuit::new(1, 1);
672 c.add_gate(Gate::H, &[0]);
673 c.add_measure(0, 0);
674
675 let b1 = run_sparse(&c);
676 let b2 = run_sparse(&c);
677 assert_eq!(b1.classical_results()[0], b2.classical_results()[0]);
678 }
679
680 #[test]
681 fn test_probs_bell() {
682 let mut c = Circuit::new(2, 0);
683 c.add_gate(Gate::H, &[0]);
684 c.add_gate(Gate::Cx, &[0, 1]);
685 let probs = run_sparse_probs(&c);
686 assert!((probs[0] - 0.5).abs() < EPS);
687 assert!(probs[1].abs() < EPS);
688 assert!(probs[2].abs() < EPS);
689 assert!((probs[3] - 0.5).abs() < EPS);
690 }
691
692 #[test]
693 fn test_probs_zero_state() {
694 let c = Circuit::new(3, 0);
695 let probs = run_sparse_probs(&c);
696 assert!((probs[0] - 1.0).abs() < EPS);
697 let rest: f64 = probs[1..].iter().sum();
698 assert!(rest.abs() < EPS);
699 }
700
701 #[test]
702 fn test_pruning() {
703 let mut b = SparseBackend::new(42);
704 b.init(1, 0).unwrap();
705 b.state.insert(1, Complex64::new(1e-20, 0.0));
706 assert_eq!(b.state.len(), 2);
707 b.prune();
708 assert_eq!(b.state.len(), 1);
709 assert!(b.state.contains_key(&0));
710 }
711
712 #[test]
713 fn test_fused_gate() {
714 let h_mat = Gate::H.matrix_2x2();
715 let t_mat = Gate::T.matrix_2x2();
716 let zero = Complex64::new(0.0, 0.0);
717 let mut fused = [[zero; 2]; 2];
718 for i in 0..2 {
719 for j in 0..2 {
720 for k in 0..2 {
721 fused[i][j] += t_mat[i][k] * h_mat[k][j];
722 }
723 }
724 }
725
726 let mut c1 = Circuit::new(1, 0);
727 c1.add_gate(Gate::H, &[0]);
728 c1.add_gate(Gate::T, &[0]);
729 let p1 = run_sparse_probs(&c1);
730
731 let mut c2 = Circuit::new(1, 0);
732 c2.add_gate(Gate::Fused(Box::new(fused)), &[0]);
733 let p2 = run_sparse_probs(&c2);
734
735 for (a, b) in p1.iter().zip(p2.iter()) {
736 assert!((a - b).abs() < EPS);
737 }
738 }
739
740 #[test]
741 fn test_ghz_4_sparse() {
742 let mut c = Circuit::new(4, 0);
743 c.add_gate(Gate::H, &[0]);
744 for i in 0..3 {
745 c.add_gate(Gate::Cx, &[i, i + 1]);
746 }
747 let b = run_sparse(&c);
748 assert_eq!(b.state.len(), 2);
749 assert!((b.state[&0].norm_sqr() - 0.5).abs() < EPS);
750 assert!((b.state[&15].norm_sqr() - 0.5).abs() < EPS);
751 }
752
753 #[test]
754 fn test_cu_phase_applies_phase() {
755 let mut c = Circuit::new(2, 0);
756 c.add_gate(Gate::X, &[0]);
757 c.add_gate(Gate::X, &[1]);
758 c.add_gate(Gate::cphase(std::f64::consts::FRAC_PI_4), &[0, 1]);
759 let b = run_sparse(&c);
760 assert_eq!(b.state.len(), 1);
761 let expected = Complex64::from_polar(1.0, std::f64::consts::FRAC_PI_4);
762 assert!((b.state[&3] - expected).norm() < EPS);
763 }
764
765 #[test]
766 fn test_cu_phase_no_action_control_zero() {
767 let mut c = Circuit::new(2, 0);
768 c.add_gate(Gate::H, &[1]);
769 c.add_gate(Gate::cphase(1.0), &[0, 1]);
770 let b = run_sparse(&c);
771 let h = 1.0 / 2.0_f64.sqrt();
772 assert!((b.state[&0].re - h).abs() < EPS);
773 assert!((b.state[&2].re - h).abs() < EPS);
774 assert!(!b.state.contains_key(&1));
775 assert!(!b.state.contains_key(&3));
776 }
777
778 #[test]
779 fn test_cu_phase_matches_cz() {
780 let mut c1 = Circuit::new(2, 0);
781 c1.add_gate(Gate::H, &[0]);
782 c1.add_gate(Gate::H, &[1]);
783 c1.add_gate(Gate::cphase(std::f64::consts::PI), &[0, 1]);
784
785 let mut c2 = Circuit::new(2, 0);
786 c2.add_gate(Gate::H, &[0]);
787 c2.add_gate(Gate::H, &[1]);
788 c2.add_gate(Gate::Cz, &[0, 1]);
789
790 let b1 = run_sparse(&c1);
791 let b2 = run_sparse(&c2);
792
793 for (&idx, &1) in &b1.state {
794 let amp2 = b2
795 .state
796 .get(&idx)
797 .copied()
798 .unwrap_or(Complex64::new(0.0, 0.0));
799 assert!((amp1 - amp2).norm() < EPS, "mismatch at idx {idx}");
800 }
801 }
802
803 #[test]
804 fn test_batch_phase_matches_individual() {
805 use crate::gates::BatchPhaseData;
806 use smallvec::smallvec;
807
808 let phase1 = Complex64::from_polar(1.0, 0.5);
809 let phase2 = Complex64::from_polar(1.0, 1.2);
810
811 let mut c1 = Circuit::new(3, 0);
812 c1.add_gate(Gate::H, &[0]);
813 c1.add_gate(Gate::H, &[1]);
814 c1.add_gate(Gate::H, &[2]);
815 c1.add_gate(Gate::cphase(0.5), &[0, 1]);
816 c1.add_gate(Gate::cphase(1.2), &[0, 2]);
817 let p1 = run_sparse_probs(&c1);
818
819 let mut c2 = Circuit::new(3, 0);
820 c2.add_gate(Gate::H, &[0]);
821 c2.add_gate(Gate::H, &[1]);
822 c2.add_gate(Gate::H, &[2]);
823 c2.add_gate(
824 Gate::BatchPhase(Box::new(BatchPhaseData {
825 phases: smallvec![(1, phase1), (2, phase2)],
826 })),
827 &[0, 1, 2],
828 );
829 let p2 = run_sparse_probs(&c2);
830
831 for (a, b) in p1.iter().zip(p2.iter()) {
832 assert!((a - b).abs() < EPS, "probs mismatch: {a} vs {b}");
833 }
834 }
835}