1use crate::{
8 cartan::OptimizedCartanDecomposer,
9 controlled::make_controlled,
10 error::{QuantRS2Error, QuantRS2Result},
11 gate::{single::*, GateOp},
12 matrix_ops::{DenseMatrix, QuantumMatrix},
13 qubit::QubitId,
14 synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use rustc_hash::FxHashMap;
17use scirs2_core::ndarray::{s, Array2};
18use scirs2_core::Complex;
19use std::f64::consts::PI;
20
/// Result of a Shannon decomposition: the emitted gate sequence plus cost counters.
#[derive(Debug, Clone)]
pub struct ShannonDecomposition {
    /// Gates in the order the decomposer emitted them.
    pub gates: Vec<Box<dyn GateOp>>,
    /// Number of CNOTs attributed to the sequence. For controlled rotations this is a
    /// cost estimate (each one is charged 2 CNOTs) rather than a count of `CNOT` gate
    /// objects present in `gates`.
    pub cnot_count: usize,
    /// Number of single-qubit gates attributed to the sequence (likewise an estimate
    /// where controlled rotations are involved).
    pub single_qubit_count: usize,
    /// The decomposers in this module set this to the number of entries in `gates`,
    /// not to the parallel depth of the circuit.
    pub depth: usize,
}
33
/// Decomposes multi-qubit unitaries into one- and two-qubit gates, recursing on
/// half-sized blocks for three or more qubits.
pub struct ShannonDecomposer {
    /// Numerical tolerance for the unitarity check, identity detection and for
    /// dropping near-zero rotation angles.
    tolerance: f64,
    /// Decomposition cache keyed by a `u64`.
    /// NOTE(review): not read or written anywhere in this module — confirm intent.
    cache: FxHashMap<u64, ShannonDecomposition>,
    /// Recursion depth beyond which the recursive split returns an error.
    max_depth: usize,
}
43
44impl ShannonDecomposer {
45 pub fn new() -> Self {
47 Self {
48 tolerance: 1e-10,
49 cache: FxHashMap::default(),
50 max_depth: 20,
51 }
52 }
53
54 pub fn with_tolerance(tolerance: f64) -> Self {
56 Self {
57 tolerance,
58 cache: FxHashMap::default(),
59 max_depth: 20,
60 }
61 }
62
63 pub fn decompose(
65 &mut self,
66 unitary: &Array2<Complex<f64>>,
67 qubit_ids: &[QubitId],
68 ) -> QuantRS2Result<ShannonDecomposition> {
69 let n = qubit_ids.len();
70 let size = 1 << n;
71
72 if unitary.shape() != [size, size] {
74 return Err(QuantRS2Error::InvalidInput(format!(
75 "Unitary size {} doesn't match {} qubits",
76 unitary.shape()[0],
77 n
78 )));
79 }
80
81 let mat = DenseMatrix::new(unitary.clone())?;
83 if !mat.is_unitary(self.tolerance)? {
84 return Err(QuantRS2Error::InvalidInput(
85 "Matrix is not unitary".to_string(),
86 ));
87 }
88
89 if n == 0 {
91 return Ok(ShannonDecomposition {
92 gates: vec![],
93 cnot_count: 0,
94 single_qubit_count: 0,
95 depth: 0,
96 });
97 }
98
99 if n == 1 {
100 let decomp = decompose_single_qubit_zyz(&unitary.view())?;
102 let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
103 let count = gates.len();
104
105 return Ok(ShannonDecomposition {
106 gates,
107 cnot_count: 0,
108 single_qubit_count: count,
109 depth: count,
110 });
111 }
112
113 if n == 2 {
114 return self.decompose_two_qubit(unitary, qubit_ids);
116 }
117
118 self.decompose_recursive(unitary, qubit_ids, 0)
120 }
121
122 fn decompose_recursive(
124 &mut self,
125 unitary: &Array2<Complex<f64>>,
126 qubit_ids: &[QubitId],
127 depth: usize,
128 ) -> QuantRS2Result<ShannonDecomposition> {
129 if depth > self.max_depth {
130 return Err(QuantRS2Error::InvalidInput(
131 "Maximum recursion depth exceeded".to_string(),
132 ));
133 }
134
135 let n = qubit_ids.len();
136 let half_size = 1 << (n - 1);
137
138 let a = unitary.slice(s![..half_size, ..half_size]).to_owned();
142 let b = unitary.slice(s![..half_size, half_size..]).to_owned();
143 let c = unitary.slice(s![half_size.., ..half_size]).to_owned();
144 let d = unitary.slice(s![half_size.., half_size..]).to_owned();
145
146 let (v, w, u_diag) = self.block_diagonalize(&a, &b, &c, &d)?;
150
151 let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
152 let mut cnot_count = 0;
153 let mut single_qubit_count = 0;
154
155 if !self.is_identity(&w) {
157 let w_decomp = self.decompose_recursive(&w, &qubit_ids[1..], depth + 1)?;
158 gates.extend(w_decomp.gates);
159 cnot_count += w_decomp.cnot_count;
160 single_qubit_count += w_decomp.single_qubit_count;
161 }
162
163 let diag_gates = self.decompose_controlled_diagonal(&u_diag, qubit_ids)?;
165 cnot_count += diag_gates.1;
166 single_qubit_count += diag_gates.2;
167 gates.extend(diag_gates.0);
168
169 if !self.is_identity(&v) {
171 let v_dag = v.mapv(|z| z.conj()).t().to_owned();
172 let v_decomp = self.decompose_recursive(&v_dag, &qubit_ids[1..], depth + 1)?;
173 gates.extend(v_decomp.gates);
174 cnot_count += v_decomp.cnot_count;
175 single_qubit_count += v_decomp.single_qubit_count;
176 }
177
178 let depth = gates.len();
180
181 Ok(ShannonDecomposition {
182 gates,
183 cnot_count,
184 single_qubit_count,
185 depth,
186 })
187 }
188
189 fn block_diagonalize(
191 &self,
192 a: &Array2<Complex<f64>>,
193 b: &Array2<Complex<f64>>,
194 c: &Array2<Complex<f64>>,
195 d: &Array2<Complex<f64>>,
196 ) -> QuantRS2Result<(
197 Array2<Complex<f64>>,
198 Array2<Complex<f64>>,
199 Array2<Complex<f64>>,
200 )> {
201 let size = a.shape()[0];
202
203 let b_norm = b.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
212 let c_norm = c.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
213
214 if b_norm < self.tolerance && c_norm < self.tolerance {
215 let identity = Array2::eye(size);
216 let combined = self.combine_blocks(a, b, c, d);
217 return Ok((identity.clone(), identity, combined));
218 }
219
220 let combined = self.combine_blocks(a, b, c, d);
223
224 let identity = Array2::eye(size);
227 Ok((identity.clone(), identity, combined))
228 }
229
230 fn combine_blocks(
232 &self,
233 a: &Array2<Complex<f64>>,
234 b: &Array2<Complex<f64>>,
235 c: &Array2<Complex<f64>>,
236 d: &Array2<Complex<f64>>,
237 ) -> Array2<Complex<f64>> {
238 let size = a.shape()[0];
239 let total_size = 2 * size;
240 let mut result = Array2::zeros((total_size, total_size));
241
242 result.slice_mut(s![..size, ..size]).assign(a);
243 result.slice_mut(s![..size, size..]).assign(b);
244 result.slice_mut(s![size.., ..size]).assign(c);
245 result.slice_mut(s![size.., size..]).assign(d);
246
247 result
248 }
249
250 fn decompose_controlled_diagonal(
252 &self,
253 diagonal: &Array2<Complex<f64>>,
254 qubit_ids: &[QubitId],
255 ) -> QuantRS2Result<(Vec<Box<dyn GateOp>>, usize, usize)> {
256 let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
257 let mut cnot_count = 0;
258 let mut single_qubit_count = 0;
259
260 let n = diagonal.shape()[0];
262 let mut phases = Vec::with_capacity(n);
263
264 for i in 0..n {
265 let phase = diagonal[[i, i]].arg();
266 phases.push(phase);
267 }
268
269 let control = qubit_ids[0];
272
273 for (i, &phase) in phases.iter().enumerate() {
274 if phase.abs() > self.tolerance {
275 if i == 0 {
276 let gate: Box<dyn GateOp> = Box::new(RotationZ {
278 target: control,
279 theta: phase,
280 });
281 gates.push(gate);
282 single_qubit_count += 1;
283 } else {
284 let base_gate = Box::new(RotationZ {
288 target: qubit_ids[1],
289 theta: phase,
290 });
291
292 let controlled = Box::new(make_controlled(vec![control], *base_gate));
293 gates.push(controlled);
294 cnot_count += 2; single_qubit_count += 3; }
297 }
298 }
299
300 Ok((gates, cnot_count, single_qubit_count))
301 }
302
303 fn decompose_two_qubit(
305 &self,
306 unitary: &Array2<Complex<f64>>,
307 qubit_ids: &[QubitId],
308 ) -> QuantRS2Result<ShannonDecomposition> {
309 if self.is_identity(unitary) {
311 return Ok(ShannonDecomposition {
312 gates: vec![],
313 cnot_count: 0,
314 single_qubit_count: 0,
315 depth: 0,
316 });
317 }
318
319 let mut cartan_decomposer = OptimizedCartanDecomposer::new();
321 let cartan_decomp = cartan_decomposer.decompose(unitary)?;
322 let gates = cartan_decomposer.base.to_gates(&cartan_decomp, qubit_ids)?;
323
324 let mut cnot_count = 0;
326 let mut single_qubit_count = 0;
327
328 for gate in &gates {
329 match gate.name() {
330 "CNOT" => cnot_count += 1,
331 _ => single_qubit_count += 1,
332 }
333 }
334
335 let depth = gates.len();
336
337 Ok(ShannonDecomposition {
338 gates,
339 cnot_count,
340 single_qubit_count,
341 depth,
342 })
343 }
344
345 fn single_qubit_to_gates(
347 &self,
348 decomp: &SingleQubitDecomposition,
349 qubit: QubitId,
350 ) -> Vec<Box<dyn GateOp>> {
351 let mut gates = Vec::new();
352
353 if decomp.theta1.abs() > self.tolerance {
355 gates.push(Box::new(RotationZ {
356 target: qubit,
357 theta: decomp.theta1,
358 }) as Box<dyn GateOp>);
359 }
360
361 if decomp.phi.abs() > self.tolerance {
363 gates.push(Box::new(RotationY {
364 target: qubit,
365 theta: decomp.phi,
366 }) as Box<dyn GateOp>);
367 }
368
369 if decomp.theta2.abs() > self.tolerance {
371 gates.push(Box::new(RotationZ {
372 target: qubit,
373 theta: decomp.theta2,
374 }) as Box<dyn GateOp>);
375 }
376
377 gates
380 }
381
382 fn is_identity(&self, matrix: &Array2<Complex<f64>>) -> bool {
384 let n = matrix.shape()[0];
385
386 for i in 0..n {
387 for j in 0..n {
388 let expected = if i == j {
389 Complex::new(1.0, 0.0)
390 } else {
391 Complex::new(0.0, 0.0)
392 };
393 if (matrix[[i, j]] - expected).norm() > self.tolerance {
394 return false;
395 }
396 }
397 }
398
399 true
400 }
401}
402
/// Shannon decomposer followed by optional gate-level post-processing passes.
pub struct OptimizedShannonDecomposer {
    /// Underlying decomposer that produces the raw gate sequence.
    base: ShannonDecomposer,
    /// Whether to run the peephole pass (adjacent-gate cancellation and rotation merging).
    peephole: bool,
    /// Whether to run the commutation pass.
    /// NOTE(review): that pass currently returns its input unchanged.
    commutation: bool,
}
411
412impl OptimizedShannonDecomposer {
413 pub fn new() -> Self {
415 Self {
416 base: ShannonDecomposer::new(),
417 peephole: true,
418 commutation: true,
419 }
420 }
421
422 pub fn decompose(
424 &mut self,
425 unitary: &Array2<Complex<f64>>,
426 qubit_ids: &[QubitId],
427 ) -> QuantRS2Result<ShannonDecomposition> {
428 let mut decomp = self.base.decompose(unitary, qubit_ids)?;
430
431 if self.peephole {
432 decomp = self.apply_peephole_optimization(decomp)?;
433 }
434
435 if self.commutation {
436 decomp = self.apply_commutation_optimization(decomp)?;
437 }
438
439 Ok(decomp)
440 }
441
442 fn apply_peephole_optimization(
444 &self,
445 mut decomp: ShannonDecomposition,
446 ) -> QuantRS2Result<ShannonDecomposition> {
447 let mut optimized_gates = Vec::new();
453 let mut i = 0;
454
455 while i < decomp.gates.len() {
456 if i + 1 < decomp.gates.len() {
457 if self.gates_cancel(&decomp.gates[i], &decomp.gates[i + 1]) {
459 i += 2;
461 decomp.cnot_count =
462 decomp
463 .cnot_count
464 .saturating_sub(if decomp.gates[i - 2].name() == "CNOT" {
465 2
466 } else {
467 0
468 });
469 decomp.single_qubit_count = decomp.single_qubit_count.saturating_sub(
470 if decomp.gates[i - 2].name() == "CNOT" {
471 0
472 } else {
473 2
474 },
475 );
476 continue;
477 }
478
479 if let Some(merged) =
481 self.try_merge_rotations(&decomp.gates[i], &decomp.gates[i + 1])
482 {
483 optimized_gates.push(merged);
484 i += 2;
485 decomp.single_qubit_count = decomp.single_qubit_count.saturating_sub(1);
486 continue;
487 }
488 }
489
490 optimized_gates.push(decomp.gates[i].clone());
491 i += 1;
492 }
493
494 decomp.gates = optimized_gates;
495 decomp.depth = decomp.gates.len();
496
497 Ok(decomp)
498 }
499
500 const fn apply_commutation_optimization(
502 &self,
503 decomp: ShannonDecomposition,
504 ) -> QuantRS2Result<ShannonDecomposition> {
505 Ok(decomp)
510 }
511
512 fn gates_cancel(&self, gate1: &Box<dyn GateOp>, gate2: &Box<dyn GateOp>) -> bool {
514 if gate1.name() == gate2.name() && gate1.qubits() == gate2.qubits() {
516 match gate1.name() {
517 "X" | "Y" | "Z" | "H" | "CNOT" | "SWAP" => true,
518 _ => false,
519 }
520 } else {
521 false
522 }
523 }
524
525 fn try_merge_rotations(
527 &self,
528 gate1: &Box<dyn GateOp>,
529 gate2: &Box<dyn GateOp>,
530 ) -> Option<Box<dyn GateOp>> {
531 if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
533 return None;
534 }
535
536 let qubit = gate1.qubits()[0];
537
538 match (gate1.name(), gate2.name()) {
539 ("RZ", "RZ") => {
540 let theta1 = gate1.as_any().downcast_ref::<RotationZ>()?.theta;
541 let theta2 = gate2.as_any().downcast_ref::<RotationZ>()?.theta;
542 Some(Box::new(RotationZ {
543 target: qubit,
544 theta: theta1 + theta2,
545 }))
546 }
547 ("RX", "RX") => {
548 let theta1 = gate1.as_any().downcast_ref::<RotationX>()?.theta;
549 let theta2 = gate2.as_any().downcast_ref::<RotationX>()?.theta;
550 Some(Box::new(RotationX {
551 target: qubit,
552 theta: theta1 + theta2,
553 }))
554 }
555 ("RY", "RY") => {
556 let theta1 = gate1.as_any().downcast_ref::<RotationY>()?.theta;
557 let theta2 = gate2.as_any().downcast_ref::<RotationY>()?.theta;
558 Some(Box::new(RotationY {
559 target: qubit,
560 theta: theta1 + theta2,
561 }))
562 }
563 _ => None,
564 }
565 }
566}
567
568pub fn shannon_decompose(
570 unitary: &Array2<Complex<f64>>,
571 qubit_ids: &[QubitId],
572) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
573 let mut decomposer = ShannonDecomposer::new();
574 let decomp = decomposer.decompose(unitary, qubit_ids)?;
575 Ok(decomp.gates)
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::Complex;

    /// Complex number with zero imaginary part.
    fn real(x: f64) -> Complex<f64> {
        Complex::new(x, 0.0)
    }

    /// Boxed `RotationZ` on `qubit` by `theta`.
    fn boxed_rz(qubit: QubitId, theta: f64) -> Box<dyn GateOp> {
        Box::new(RotationZ {
            target: qubit,
            theta,
        })
    }

    /// Boxed `RotationX` on `qubit` by `theta`.
    fn boxed_rx(qubit: QubitId, theta: f64) -> Box<dyn GateOp> {
        Box::new(RotationX {
            target: qubit,
            theta,
        })
    }

    #[test]
    fn test_shannon_single_qubit() {
        let mut decomposer = ShannonDecomposer::new();

        // Hadamard = [[1, 1], [1, -1]] / sqrt(2)
        let entries = vec![real(1.0), real(1.0), real(1.0), real(-1.0)];
        let hadamard = Array2::from_shape_vec((2, 2), entries)
            .expect("Failed to create Hadamard matrix")
            / real(2.0_f64.sqrt());

        let decomp = decomposer
            .decompose(&hadamard, &[QubitId(0)])
            .expect("Failed to decompose Hadamard gate");

        // A ZYZ sequence has at most three rotations and no CNOTs.
        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);
    }

    #[test]
    fn test_shannon_two_qubit() {
        let mut decomposer = ShannonDecomposer::new();

        // CNOT swaps basis states 2 and 3 and fixes states 0 and 1.
        let mut entries = vec![real(0.0); 16];
        for (row, col) in [(0usize, 0usize), (1, 1), (2, 3), (3, 2)] {
            entries[row * 4 + col] = real(1.0);
        }
        let cnot = Array2::from_shape_vec((4, 4), entries).expect("Failed to create CNOT matrix");

        let decomp = decomposer
            .decompose(&cnot, &[QubitId(0), QubitId(1)])
            .expect("Failed to decompose CNOT gate");

        assert!(decomp.cnot_count <= 3);
    }

    #[test]
    fn test_optimized_decomposer() {
        let mut decomposer = OptimizedShannonDecomposer::new();
        let identity: Array2<Complex<f64>> = Array2::eye(4);

        let decomp = decomposer
            .decompose(&identity, &[QubitId(0), QubitId(1)])
            .expect("Failed to decompose identity matrix");

        assert!(decomp.gates.is_empty());
    }

    #[test]
    fn test_merge_rz_rotations() {
        let decomposer = OptimizedShannonDecomposer::new();
        let q = QubitId(0);

        let merged = decomposer
            .try_merge_rotations(&boxed_rz(q, 0.3), &boxed_rz(q, 0.4))
            .expect("should merge RZ+RZ");
        let theta = merged
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("merged gate must be RotationZ")
            .theta;

        assert!(
            (theta - 0.7).abs() < 1e-10,
            "merged theta should be 0.7, got {}",
            theta
        );
    }

    #[test]
    fn test_merge_rx_rotations() {
        let decomposer = OptimizedShannonDecomposer::new();
        let q = QubitId(0);

        let merged = decomposer
            .try_merge_rotations(&boxed_rx(q, 0.5), &boxed_rx(q, 0.3))
            .expect("should merge RX+RX");
        let theta = merged
            .as_any()
            .downcast_ref::<RotationX>()
            .expect("merged gate must be RotationX")
            .theta;

        assert!(
            (theta - 0.8).abs() < 1e-10,
            "merged theta should be 0.8, got {}",
            theta
        );
    }

    #[test]
    fn test_no_merge_different_axes() {
        let decomposer = OptimizedShannonDecomposer::new();
        let q = QubitId(0);

        let outcome = decomposer.try_merge_rotations(&boxed_rz(q, 0.3), &boxed_rx(q, 0.4));

        assert!(outcome.is_none(), "RZ and RX should not merge");
    }
}