1use crate::{
8 cartan::{CartanCoefficients, CartanDecomposer, CartanDecomposition},
9 error::{QuantRS2Error, QuantRS2Result},
10 gate::{multi::*, single::*, GateOp},
11 matrix_ops::{DenseMatrix, QuantumMatrix},
12 qubit::QubitId,
13 shannon::ShannonDecomposer,
14 synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use ndarray::{s, Array1, Array2, ArrayView2, Axis};
17use num_complex::Complex;
18use rustc_hash::FxHashMap;
19use std::f64::consts::PI;
20
21#[derive(Debug, Clone)]
23pub struct MultiQubitKAK {
24 pub gates: Vec<Box<dyn GateOp>>,
26 pub tree: DecompositionTree,
28 pub cnot_count: usize,
30 pub single_qubit_count: usize,
32 pub depth: usize,
34}
35
36#[derive(Debug, Clone)]
38pub enum DecompositionTree {
39 Leaf {
41 qubits: Vec<QubitId>,
42 gate_type: LeafType,
43 },
44 Node {
46 qubits: Vec<QubitId>,
47 method: DecompositionMethod,
48 children: Vec<DecompositionTree>,
49 },
50}
51
52#[derive(Debug, Clone)]
54pub enum LeafType {
55 SingleQubit(SingleQubitDecomposition),
56 TwoQubit(CartanDecomposition),
57}
58
/// Strategy recorded for an interior node of the decomposition tree.
#[derive(Clone, Debug)]
pub enum DecompositionMethod {
    /// Cosine-sine decomposition; `pivot` is the qubit index at which the
    /// register is split.
    CSD { pivot: usize },
    /// Shannon decomposition; `partition` is the recorded partition index.
    Shannon { partition: usize },
    /// Block-diagonal decomposition; `block_size` is the number of qubits
    /// per block.
    BlockDiagonal { block_size: usize },
    /// Direct Cartan decomposition.
    Cartan,
}
71
72pub struct MultiQubitKAKDecomposer {
74 tolerance: f64,
76 max_depth: usize,
78 cache: FxHashMap<u64, MultiQubitKAK>,
80 use_optimization: bool,
82 cartan: CartanDecomposer,
84}
85
86impl MultiQubitKAKDecomposer {
87 pub fn new() -> Self {
89 Self {
90 tolerance: 1e-10,
91 max_depth: 20,
92 cache: FxHashMap::default(),
93 use_optimization: true,
94 cartan: CartanDecomposer::new(),
95 }
96 }
97
98 pub fn with_tolerance(tolerance: f64) -> Self {
100 Self {
101 tolerance,
102 max_depth: 20,
103 cache: FxHashMap::default(),
104 use_optimization: true,
105 cartan: CartanDecomposer::with_tolerance(tolerance),
106 }
107 }
108
109 pub fn decompose(
111 &mut self,
112 unitary: &Array2<Complex<f64>>,
113 qubit_ids: &[QubitId],
114 ) -> QuantRS2Result<MultiQubitKAK> {
115 let n = qubit_ids.len();
116 let size = 1 << n;
117
118 if unitary.shape() != [size, size] {
120 return Err(QuantRS2Error::InvalidInput(format!(
121 "Unitary size {} doesn't match {} qubits",
122 unitary.shape()[0],
123 n
124 )));
125 }
126
127 let mat = DenseMatrix::new(unitary.clone())?;
129 if !mat.is_unitary(self.tolerance)? {
130 return Err(QuantRS2Error::InvalidInput(
131 "Matrix is not unitary".to_string(),
132 ));
133 }
134
135 if let Some(cached) = self.check_cache(unitary) {
137 return Ok(cached.clone());
138 }
139
140 let (tree, gates) = self.decompose_recursive(unitary, qubit_ids, 0)?;
142
143 let mut cnot_count = 0;
145 let mut single_qubit_count = 0;
146
147 for gate in &gates {
148 match gate.name() {
149 "CNOT" | "CZ" | "SWAP" => cnot_count += self.count_cnots(gate.name()),
150 _ => single_qubit_count += 1,
151 }
152 }
153
154 let result = MultiQubitKAK {
155 gates,
156 tree,
157 cnot_count,
158 single_qubit_count,
159 depth: 0, };
161
162 self.cache_result(unitary, &result);
164
165 Ok(result)
166 }
167
168 fn decompose_recursive(
170 &mut self,
171 unitary: &Array2<Complex<f64>>,
172 qubit_ids: &[QubitId],
173 depth: usize,
174 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
175 if depth > self.max_depth {
176 return Err(QuantRS2Error::InvalidInput(
177 "Maximum recursion depth exceeded".to_string(),
178 ));
179 }
180
181 let n = qubit_ids.len();
182
183 match n {
185 0 => {
186 let tree = DecompositionTree::Leaf {
187 qubits: vec![],
188 gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
189 global_phase: 0.0,
190 theta1: 0.0,
191 phi: 0.0,
192 theta2: 0.0,
193 basis: "ZYZ".to_string(),
194 }),
195 };
196 Ok((tree, vec![]))
197 }
198 1 => {
199 let decomp = decompose_single_qubit_zyz(&unitary.view())?;
200 let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
201 let tree = DecompositionTree::Leaf {
202 qubits: qubit_ids.to_vec(),
203 gate_type: LeafType::SingleQubit(decomp),
204 };
205 Ok((tree, gates))
206 }
207 2 => {
208 let decomp = self.cartan.decompose(unitary)?;
209 let gates = self.cartan.to_gates(&decomp, qubit_ids)?;
210 let tree = DecompositionTree::Leaf {
211 qubits: qubit_ids.to_vec(),
212 gate_type: LeafType::TwoQubit(decomp),
213 };
214 Ok((tree, gates))
215 }
216 _ => {
217 let method = self.choose_decomposition_method(unitary, n);
219
220 match method {
221 DecompositionMethod::CSD { pivot } => {
222 self.decompose_csd(unitary, qubit_ids, pivot, depth)
223 }
224 DecompositionMethod::Shannon { partition } => {
225 self.decompose_shannon(unitary, qubit_ids, partition, depth)
226 }
227 DecompositionMethod::BlockDiagonal { block_size } => {
228 self.decompose_block_diagonal(unitary, qubit_ids, block_size, depth)
229 }
230 _ => unreachable!("Invalid method for n > 2"),
231 }
232 }
233 }
234 }
235
236 fn choose_decomposition_method(
238 &self,
239 unitary: &Array2<Complex<f64>>,
240 n: usize,
241 ) -> DecompositionMethod {
242 if self.use_optimization {
243 if self.has_block_structure(unitary, n) {
245 DecompositionMethod::BlockDiagonal { block_size: n / 2 }
246 } else if n % 2 == 0 {
247 DecompositionMethod::CSD { pivot: n / 2 }
249 } else {
250 DecompositionMethod::Shannon { partition: n / 2 }
252 }
253 } else {
254 DecompositionMethod::CSD { pivot: n / 2 }
256 }
257 }
258
259 fn decompose_csd(
261 &mut self,
262 unitary: &Array2<Complex<f64>>,
263 qubit_ids: &[QubitId],
264 pivot: usize,
265 depth: usize,
266 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
267 let n = qubit_ids.len();
268 let size = 1 << n;
269 let pivot_size = 1 << pivot;
270
271 let a = unitary.slice(s![..pivot_size, ..pivot_size]).to_owned();
275 let b = unitary.slice(s![..pivot_size, pivot_size..]).to_owned();
276 let c = unitary.slice(s![pivot_size.., ..pivot_size]).to_owned();
277 let d = unitary.slice(s![pivot_size.., pivot_size..]).to_owned();
278
279 let (u1, v1, sigma, u2, v2) = self.compute_csd(&a, &b, &c, &d)?;
285
286 let mut gates = Vec::new();
287 let mut children = Vec::new();
288
289 let left_qubits = &qubit_ids[..pivot];
291 let right_qubits = &qubit_ids[pivot..];
292
293 let (u2_tree, u2_gates) = self.decompose_recursive(&u2, left_qubits, depth + 1)?;
294 let (v2_tree, v2_gates) = self.decompose_recursive(&v2, right_qubits, depth + 1)?;
295
296 gates.extend(u2_gates);
297 gates.extend(v2_gates);
298 children.push(u2_tree);
299 children.push(v2_tree);
300
301 let diag_gates = self.diagonal_to_gates(&sigma, qubit_ids)?;
303 gates.extend(diag_gates);
304
305 let (u1_tree, u1_gates) = self.decompose_recursive(&u1, left_qubits, depth + 1)?;
307 let (v1_tree, v1_gates) = self.decompose_recursive(&v1, right_qubits, depth + 1)?;
308
309 gates.extend(u1_gates);
310 gates.extend(v1_gates);
311 children.push(u1_tree);
312 children.push(v1_tree);
313
314 let tree = DecompositionTree::Node {
315 qubits: qubit_ids.to_vec(),
316 method: DecompositionMethod::CSD { pivot },
317 children,
318 };
319
320 Ok((tree, gates))
321 }
322
323 fn decompose_shannon(
325 &mut self,
326 unitary: &Array2<Complex<f64>>,
327 qubit_ids: &[QubitId],
328 partition: usize,
329 depth: usize,
330 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
331 let mut shannon = ShannonDecomposer::new();
333 let decomp = shannon.decompose(unitary, qubit_ids)?;
334
335 let tree = DecompositionTree::Node {
337 qubits: qubit_ids.to_vec(),
338 method: DecompositionMethod::Shannon { partition },
339 children: vec![], };
341
342 Ok((tree, decomp.gates))
343 }
344
345 fn decompose_block_diagonal(
347 &mut self,
348 unitary: &Array2<Complex<f64>>,
349 qubit_ids: &[QubitId],
350 block_size: usize,
351 depth: usize,
352 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
353 let n = qubit_ids.len();
354 let num_blocks = n / block_size;
355
356 let mut gates = Vec::new();
357 let mut children = Vec::new();
358
359 for i in 0..num_blocks {
361 let start = i * block_size;
362 let end = (i + 1) * block_size;
363 let block_qubits = &qubit_ids[start..end];
364
365 let block = self.extract_block(unitary, i, block_size)?;
367
368 let (block_tree, block_gates) =
369 self.decompose_recursive(&block, block_qubits, depth + 1)?;
370 gates.extend(block_gates);
371 children.push(block_tree);
372 }
373
374 let tree = DecompositionTree::Node {
375 qubits: qubit_ids.to_vec(),
376 method: DecompositionMethod::BlockDiagonal { block_size },
377 children,
378 };
379
380 Ok((tree, gates))
381 }
382
383 fn compute_csd(
385 &self,
386 a: &Array2<Complex<f64>>,
387 b: &Array2<Complex<f64>>,
388 c: &Array2<Complex<f64>>,
389 d: &Array2<Complex<f64>>,
390 ) -> QuantRS2Result<(
391 Array2<Complex<f64>>, Array2<Complex<f64>>, Array2<Complex<f64>>, Array2<Complex<f64>>, Array2<Complex<f64>>, )> {
397 let size = a.shape()[0];
401 let identity = Array2::eye(size);
402 let zero: Array2<Complex<f64>> = Array2::zeros((size, size));
403
404 let u1 = identity.clone();
406 let v1 = identity.clone();
407 let u2 = identity.clone();
408 let v2 = identity.clone();
409
410 let mut sigma = Array2::zeros((size * 2, size * 2));
412 sigma.slice_mut(s![..size, ..size]).assign(a);
413 sigma.slice_mut(s![..size, size..]).assign(b);
414 sigma.slice_mut(s![size.., ..size]).assign(c);
415 sigma.slice_mut(s![size.., size..]).assign(d);
416
417 Ok((u1, v1, sigma, u2, v2))
418 }
419
420 fn diagonal_to_gates(
422 &self,
423 diagonal: &Array2<Complex<f64>>,
424 qubit_ids: &[QubitId],
425 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
426 let mut gates = Vec::new();
427
428 let n = diagonal.shape()[0];
430 for i in 0..n {
431 let phase = diagonal[[i, i]].arg();
432 if phase.abs() > self.tolerance {
433 let mut control_pattern = Vec::new();
435 let mut temp = i;
436 for j in 0..qubit_ids.len() {
437 if temp & 1 == 1 {
438 control_pattern.push(j);
439 }
440 temp >>= 1;
441 }
442
443 if control_pattern.is_empty() {
445 } else if control_pattern.len() == 1 {
447 gates.push(Box::new(RotationZ {
449 target: qubit_ids[control_pattern[0]],
450 theta: phase,
451 }) as Box<dyn GateOp>);
452 } else {
453 let target = qubit_ids[control_pattern.pop().unwrap()];
456 for &control_idx in &control_pattern {
457 gates.push(Box::new(CNOT {
458 control: qubit_ids[control_idx],
459 target,
460 }));
461 }
462
463 gates.push(Box::new(RotationZ {
464 target,
465 theta: phase,
466 }) as Box<dyn GateOp>);
467
468 for &control_idx in control_pattern.iter().rev() {
470 gates.push(Box::new(CNOT {
471 control: qubit_ids[control_idx],
472 target,
473 }));
474 }
475 }
476 }
477 }
478
479 Ok(gates)
480 }
481
482 fn has_block_structure(&self, unitary: &Array2<Complex<f64>>, n: usize) -> bool {
484 let size = unitary.shape()[0];
486 let block_size = size / 2;
487
488 let mut off_diagonal_norm = 0.0;
489
490 for i in 0..block_size {
492 for j in block_size..size {
493 off_diagonal_norm += unitary[[i, j]].norm_sqr();
494 }
495 }
496
497 for i in block_size..size {
499 for j in 0..block_size {
500 off_diagonal_norm += unitary[[i, j]].norm_sqr();
501 }
502 }
503
504 off_diagonal_norm.sqrt() < self.tolerance
505 }
506
507 fn extract_block(
509 &self,
510 unitary: &Array2<Complex<f64>>,
511 block_idx: usize,
512 block_size: usize,
513 ) -> QuantRS2Result<Array2<Complex<f64>>> {
514 let size = 1 << block_size;
515 let start = block_idx * size;
516 let end = (block_idx + 1) * size;
517
518 Ok(unitary.slice(s![start..end, start..end]).to_owned())
519 }
520
521 fn single_qubit_to_gates(
523 &self,
524 decomp: &SingleQubitDecomposition,
525 qubit: QubitId,
526 ) -> Vec<Box<dyn GateOp>> {
527 let mut gates = Vec::new();
528
529 if decomp.theta1.abs() > self.tolerance {
530 gates.push(Box::new(RotationZ {
531 target: qubit,
532 theta: decomp.theta1,
533 }) as Box<dyn GateOp>);
534 }
535
536 if decomp.phi.abs() > self.tolerance {
537 gates.push(Box::new(RotationY {
538 target: qubit,
539 theta: decomp.phi,
540 }) as Box<dyn GateOp>);
541 }
542
543 if decomp.theta2.abs() > self.tolerance {
544 gates.push(Box::new(RotationZ {
545 target: qubit,
546 theta: decomp.theta2,
547 }) as Box<dyn GateOp>);
548 }
549
550 gates
551 }
552
553 fn count_cnots(&self, gate_name: &str) -> usize {
555 match gate_name {
556 "CNOT" => 1,
557 "CZ" => 1, "SWAP" => 3, _ => 0,
560 }
561 }
562
563 fn check_cache(&self, unitary: &Array2<Complex<f64>>) -> Option<&MultiQubitKAK> {
565 None
568 }
569
570 fn cache_result(&mut self, unitary: &Array2<Complex<f64>>, result: &MultiQubitKAK) {
572 }
574}
575
576pub struct KAKTreeAnalyzer {
578 stats: DecompositionStats,
580}
581
582#[derive(Debug, Default, Clone)]
583pub struct DecompositionStats {
584 pub total_nodes: usize,
585 pub leaf_nodes: usize,
586 pub max_depth: usize,
587 pub method_counts: FxHashMap<String, usize>,
588 pub cnot_distribution: FxHashMap<usize, usize>,
589}
590
591impl KAKTreeAnalyzer {
592 pub fn new() -> Self {
594 Self {
595 stats: DecompositionStats::default(),
596 }
597 }
598
599 pub fn analyze(&mut self, tree: &DecompositionTree) -> DecompositionStats {
601 self.stats = DecompositionStats::default();
602 self.analyze_recursive(tree, 0);
603 self.stats.clone()
604 }
605
606 fn analyze_recursive(&mut self, tree: &DecompositionTree, depth: usize) {
607 self.stats.total_nodes += 1;
608 self.stats.max_depth = self.stats.max_depth.max(depth);
609
610 match tree {
611 DecompositionTree::Leaf { qubits, gate_type } => {
612 self.stats.leaf_nodes += 1;
613
614 match gate_type {
615 LeafType::SingleQubit(_) => {
616 *self
617 .stats
618 .method_counts
619 .entry("single_qubit".to_string())
620 .or_insert(0) += 1;
621 }
622 LeafType::TwoQubit(cartan) => {
623 *self
624 .stats
625 .method_counts
626 .entry("two_qubit".to_string())
627 .or_insert(0) += 1;
628 let cnots = cartan.interaction.cnot_count(1e-10);
629 *self.stats.cnot_distribution.entry(cnots).or_insert(0) += 1;
630 }
631 }
632 }
633 DecompositionTree::Node {
634 method, children, ..
635 } => {
636 let method_name = match method {
637 DecompositionMethod::CSD { .. } => "csd",
638 DecompositionMethod::Shannon { .. } => "shannon",
639 DecompositionMethod::BlockDiagonal { .. } => "block_diagonal",
640 DecompositionMethod::Cartan => "cartan",
641 };
642 *self
643 .stats
644 .method_counts
645 .entry(method_name.to_string())
646 .or_insert(0) += 1;
647
648 for child in children {
649 self.analyze_recursive(child, depth + 1);
650 }
651 }
652 }
653 }
654}
655
656pub fn kak_decompose_multiqubit(
658 unitary: &Array2<Complex<f64>>,
659 qubit_ids: &[QubitId],
660) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
661 let mut decomposer = MultiQubitKAKDecomposer::new();
662 let decomp = decomposer.decompose(unitary, qubit_ids)?;
663 Ok(decomp.gates)
664}
665
666#[cfg(test)]
667mod tests {
668 use super::*;
669 use ndarray::Array2;
670 use num_complex::Complex;
671
#[test]
fn test_multiqubit_kak_single() {
    // Hadamard: all entries +1 except the bottom-right -1, scaled by 1/sqrt(2).
    let scale = Complex::new(2.0_f64.sqrt(), 0.0);
    let h = Array2::from_shape_fn((2, 2), |(row, col)| {
        let sign = if row == 1 && col == 1 { -1.0 } else { 1.0 };
        Complex::new(sign, 0.0)
    }) / scale;

    let mut decomposer = MultiQubitKAKDecomposer::new();
    let decomp = decomposer.decompose(&h, &[QubitId(0)]).unwrap();

    assert!(decomp.single_qubit_count <= 3);
    assert_eq!(decomp.cnot_count, 0);

    // A single qubit must be recorded as a single-qubit leaf.
    assert!(
        matches!(
            &decomp.tree,
            DecompositionTree::Leaf {
                gate_type: LeafType::SingleQubit(_),
                ..
            }
        ),
        "Expected single-qubit leaf"
    );
}
704
#[test]
fn test_multiqubit_kak_two() {
    // CNOT as a permutation matrix: row `r` has its single 1 in column
    // `image[r]`, i.e. the last two basis states are swapped.
    let image = [0, 1, 3, 2];
    let cnot = Array2::from_shape_fn((4, 4), |(row, col)| {
        let amplitude = if image[row] == col { 1.0 } else { 0.0 };
        Complex::new(amplitude, 0.0)
    });

    let mut decomposer = MultiQubitKAKDecomposer::new();
    let decomp = decomposer
        .decompose(&cnot, &[QubitId(0), QubitId(1)])
        .unwrap();

    assert!(decomp.cnot_count <= 1);

    // Two qubits must be recorded as a two-qubit leaf.
    assert!(
        matches!(
            &decomp.tree,
            DecompositionTree::Leaf {
                gate_type: LeafType::TwoQubit(_),
                ..
            }
        ),
        "Expected two-qubit leaf"
    );
}
747
#[test]
fn test_multiqubit_kak_three() {
    // The three-qubit identity should decompose into no gates at all.
    let identity_complex: Array2<Complex<f64>> =
        Array2::<f64>::eye(8).mapv(|x| Complex::new(x, 0.0));
    let qubit_ids = [QubitId(0), QubitId(1), QubitId(2)];

    let mut decomposer = MultiQubitKAKDecomposer::new();
    let decomp = decomposer.decompose(&identity_complex, &qubit_ids).unwrap();

    assert_eq!(decomp.gates.len(), 0);
    assert_eq!(decomp.cnot_count, 0);
    assert_eq!(decomp.single_qubit_count, 0);
}
764
#[test]
fn test_tree_analyzer() {
    // All-zero ZYZ decomposition, reused for every single-qubit slot below.
    let zero_rotation = || SingleQubitDecomposition {
        global_phase: 0.0,
        theta1: 0.0,
        phi: 0.0,
        theta2: 0.0,
        basis: "ZYZ".to_string(),
    };

    let two_qubit_leaf = DecompositionTree::Leaf {
        qubits: vec![QubitId(0), QubitId(1)],
        gate_type: LeafType::TwoQubit(CartanDecomposition {
            left_gates: (zero_rotation(), zero_rotation()),
            right_gates: (zero_rotation(), zero_rotation()),
            interaction: CartanCoefficients::new(0.0, 0.0, 0.0),
            global_phase: 0.0,
        }),
    };
    let single_qubit_leaf = DecompositionTree::Leaf {
        qubits: vec![QubitId(2)],
        gate_type: LeafType::SingleQubit(zero_rotation()),
    };

    // One CSD node with a two-qubit leaf and a single-qubit leaf below it.
    let tree = DecompositionTree::Node {
        qubits: vec![QubitId(0), QubitId(1), QubitId(2)],
        method: DecompositionMethod::CSD { pivot: 2 },
        children: vec![two_qubit_leaf, single_qubit_leaf],
    };

    let stats = KAKTreeAnalyzer::new().analyze(&tree);

    assert_eq!(stats.total_nodes, 3);
    assert_eq!(stats.leaf_nodes, 2);
    assert_eq!(stats.max_depth, 1);
    assert_eq!(stats.method_counts.get("csd"), Some(&1));
}
833
#[test]
fn test_block_structure_detection() {
    let decomposer = MultiQubitKAKDecomposer::new();

    // Start from the 4x4 identity, which is trivially block diagonal.
    let mut block_diag: Array2<Complex<f64>> = Array2::zeros((4, 4));
    for k in 0..4 {
        block_diag[[k, k]] = Complex::new(1.0, 0.0);
    }
    assert!(decomposer.has_block_structure(&block_diag, 2));

    // A single entry in the upper-right quadrant breaks the structure.
    block_diag[[0, 2]] = Complex::new(1.0, 0.0);
    assert!(!decomposer.has_block_structure(&block_diag, 2));
}
851}