1use crate::{
8 cartan::{CartanDecomposer, CartanDecomposition},
9 error::{QuantRS2Error, QuantRS2Result},
10 gate::{multi::*, single::*, GateOp},
11 matrix_ops::{DenseMatrix, QuantumMatrix},
12 qubit::QubitId,
13 shannon::ShannonDecomposer,
14 synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use ndarray::{s, Array2};
17use num_complex::Complex;
18use rustc_hash::FxHashMap;
19
/// Result of decomposing a multi-qubit unitary with `MultiQubitKAKDecomposer`.
#[derive(Debug, Clone)]
pub struct MultiQubitKAK {
    /// Gate sequence produced by the decomposition, in application order.
    pub gates: Vec<Box<dyn GateOp>>,
    /// Tree recording which decomposition method was applied at each recursion level.
    pub tree: DecompositionTree,
    /// CNOT-equivalent count of the emitted two-qubit gates (CNOT and CZ weigh 1, SWAP weighs 3).
    pub cnot_count: usize,
    /// Number of emitted gates the decomposer tallied as single-qubit gates.
    pub single_qubit_count: usize,
    /// Circuit depth as reported by the decomposer.
    pub depth: usize,
}
34
/// Recursive record of how a unitary was broken down.
#[derive(Debug, Clone)]
pub enum DecompositionTree {
    /// Terminal step handled directly by a one- or two-qubit decomposition.
    Leaf {
        /// Qubits the leaf acts on (empty for the zero-qubit case).
        qubits: Vec<QubitId>,
        gate_type: LeafType,
    },
    /// Internal step that split the unitary with `method`.
    Node {
        /// Qubits covered by this step.
        qubits: Vec<QubitId>,
        method: DecompositionMethod,
        /// Sub-decompositions; may be empty when the step delegated to another decomposer.
        children: Vec<DecompositionTree>,
    },
}
50
/// Kind of decomposition stored at a tree leaf.
#[derive(Debug, Clone)]
pub enum LeafType {
    /// Single-qubit ZYZ Euler decomposition.
    SingleQubit(SingleQubitDecomposition),
    /// Two-qubit Cartan (KAK) decomposition.
    TwoQubit(CartanDecomposition),
}
57
/// Strategy used at an internal node of a `DecompositionTree`.
#[derive(Debug, Clone)]
pub enum DecompositionMethod {
    /// Cosine–sine decomposition; `pivot` is the qubit position at which the qubit list is split.
    CSD { pivot: usize },
    /// Quantum Shannon decomposition; `partition` is the split point recorded for this step.
    Shannon { partition: usize },
    /// Block-diagonal decomposition; `block_size` is recorded in qubits.
    BlockDiagonal { block_size: usize },
    /// Two-qubit Cartan (KAK) decomposition.
    Cartan,
}
70
/// Recursive decomposer for unitaries on an arbitrary number of qubits.
///
/// One- and two-qubit unitaries are handled directly (ZYZ and Cartan/KAK);
/// larger ones are split by the method picked in `choose_decomposition_method`.
pub struct MultiQubitKAKDecomposer {
    /// Numerical tolerance used for the unitarity check and for dropping negligible angles.
    tolerance: f64,
    /// Maximum recursion depth; exceeding it makes decomposition fail with an error.
    max_depth: usize,
    // Result cache keyed by a `u64` (presumably a hash of the unitary — confirm).
    // It is never read because the cache lookup always misses, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    cache: FxHashMap<u64, MultiQubitKAK>,
    /// Enables structure-based method selection (block-diagonal detection, parity-based choice).
    use_optimization: bool,
    /// Two-qubit Cartan (KAK) decomposer used for two-qubit leaves.
    cartan: CartanDecomposer,
}
85
86impl MultiQubitKAKDecomposer {
87 pub fn new() -> Self {
89 Self {
90 tolerance: 1e-10,
91 max_depth: 20,
92 cache: FxHashMap::default(),
93 use_optimization: true,
94 cartan: CartanDecomposer::new(),
95 }
96 }
97
98 pub fn with_tolerance(tolerance: f64) -> Self {
100 Self {
101 tolerance,
102 max_depth: 20,
103 cache: FxHashMap::default(),
104 use_optimization: true,
105 cartan: CartanDecomposer::with_tolerance(tolerance),
106 }
107 }
108
109 pub fn decompose(
111 &mut self,
112 unitary: &Array2<Complex<f64>>,
113 qubit_ids: &[QubitId],
114 ) -> QuantRS2Result<MultiQubitKAK> {
115 let n = qubit_ids.len();
116 let size = 1 << n;
117
118 if unitary.shape() != [size, size] {
120 return Err(QuantRS2Error::InvalidInput(format!(
121 "Unitary size {} doesn't match {} qubits",
122 unitary.shape()[0],
123 n
124 )));
125 }
126
127 let mat = DenseMatrix::new(unitary.clone())?;
129 if !mat.is_unitary(self.tolerance)? {
130 return Err(QuantRS2Error::InvalidInput(
131 "Matrix is not unitary".to_string(),
132 ));
133 }
134
135 if let Some(cached) = self.check_cache(unitary) {
137 return Ok(cached.clone());
138 }
139
140 let (tree, gates) = self.decompose_recursive(unitary, qubit_ids, 0)?;
142
143 let mut cnot_count = 0;
145 let mut single_qubit_count = 0;
146
147 for gate in &gates {
148 match gate.name() {
149 "CNOT" | "CZ" | "SWAP" => cnot_count += self.count_cnots(gate.name()),
150 _ => single_qubit_count += 1,
151 }
152 }
153
154 let result = MultiQubitKAK {
155 gates,
156 tree,
157 cnot_count,
158 single_qubit_count,
159 depth: 0, };
161
162 self.cache_result(unitary, &result);
164
165 Ok(result)
166 }
167
168 fn decompose_recursive(
170 &mut self,
171 unitary: &Array2<Complex<f64>>,
172 qubit_ids: &[QubitId],
173 depth: usize,
174 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
175 if depth > self.max_depth {
176 return Err(QuantRS2Error::InvalidInput(
177 "Maximum recursion depth exceeded".to_string(),
178 ));
179 }
180
181 let n = qubit_ids.len();
182
183 match n {
185 0 => {
186 let tree = DecompositionTree::Leaf {
187 qubits: vec![],
188 gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
189 global_phase: 0.0,
190 theta1: 0.0,
191 phi: 0.0,
192 theta2: 0.0,
193 basis: "ZYZ".to_string(),
194 }),
195 };
196 Ok((tree, vec![]))
197 }
198 1 => {
199 let decomp = decompose_single_qubit_zyz(&unitary.view())?;
200 let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
201 let tree = DecompositionTree::Leaf {
202 qubits: qubit_ids.to_vec(),
203 gate_type: LeafType::SingleQubit(decomp),
204 };
205 Ok((tree, gates))
206 }
207 2 => {
208 let decomp = self.cartan.decompose(unitary)?;
209 let gates = self.cartan.to_gates(&decomp, qubit_ids)?;
210 let tree = DecompositionTree::Leaf {
211 qubits: qubit_ids.to_vec(),
212 gate_type: LeafType::TwoQubit(decomp),
213 };
214 Ok((tree, gates))
215 }
216 _ => {
217 let method = self.choose_decomposition_method(unitary, n);
219
220 match method {
221 DecompositionMethod::CSD { pivot } => {
222 self.decompose_csd(unitary, qubit_ids, pivot, depth)
223 }
224 DecompositionMethod::Shannon { partition } => {
225 self.decompose_shannon(unitary, qubit_ids, partition, depth)
226 }
227 DecompositionMethod::BlockDiagonal { block_size } => {
228 self.decompose_block_diagonal(unitary, qubit_ids, block_size, depth)
229 }
230 _ => unreachable!("Invalid method for n > 2"),
231 }
232 }
233 }
234 }
235
236 fn choose_decomposition_method(
238 &self,
239 unitary: &Array2<Complex<f64>>,
240 n: usize,
241 ) -> DecompositionMethod {
242 if self.use_optimization {
243 if self.has_block_structure(unitary, n) {
245 DecompositionMethod::BlockDiagonal { block_size: n / 2 }
246 } else if n % 2 == 0 {
247 DecompositionMethod::CSD { pivot: n / 2 }
249 } else {
250 DecompositionMethod::Shannon { partition: n / 2 }
252 }
253 } else {
254 DecompositionMethod::CSD { pivot: n / 2 }
256 }
257 }
258
259 fn decompose_csd(
261 &mut self,
262 unitary: &Array2<Complex<f64>>,
263 qubit_ids: &[QubitId],
264 pivot: usize,
265 depth: usize,
266 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
267 let n = qubit_ids.len();
268 let _size = 1 << n;
269 let pivot_size = 1 << pivot;
270
271 let a = unitary.slice(s![..pivot_size, ..pivot_size]).to_owned();
275 let b = unitary.slice(s![..pivot_size, pivot_size..]).to_owned();
276 let c = unitary.slice(s![pivot_size.., ..pivot_size]).to_owned();
277 let d = unitary.slice(s![pivot_size.., pivot_size..]).to_owned();
278
279 let (u1, v1, sigma, u2, v2) = self.compute_csd(&a, &b, &c, &d)?;
285
286 let mut gates = Vec::new();
287 let mut children = Vec::new();
288
289 let left_qubits = &qubit_ids[..pivot];
291 let right_qubits = &qubit_ids[pivot..];
292
293 let (u2_tree, u2_gates) = self.decompose_recursive(&u2, left_qubits, depth + 1)?;
294 let (v2_tree, v2_gates) = self.decompose_recursive(&v2, right_qubits, depth + 1)?;
295
296 gates.extend(u2_gates);
297 gates.extend(v2_gates);
298 children.push(u2_tree);
299 children.push(v2_tree);
300
301 let diag_gates = self.diagonal_to_gates(&sigma, qubit_ids)?;
303 gates.extend(diag_gates);
304
305 let (u1_tree, u1_gates) = self.decompose_recursive(&u1, left_qubits, depth + 1)?;
307 let (v1_tree, v1_gates) = self.decompose_recursive(&v1, right_qubits, depth + 1)?;
308
309 gates.extend(u1_gates);
310 gates.extend(v1_gates);
311 children.push(u1_tree);
312 children.push(v1_tree);
313
314 let tree = DecompositionTree::Node {
315 qubits: qubit_ids.to_vec(),
316 method: DecompositionMethod::CSD { pivot },
317 children,
318 };
319
320 Ok((tree, gates))
321 }
322
323 fn decompose_shannon(
325 &mut self,
326 unitary: &Array2<Complex<f64>>,
327 qubit_ids: &[QubitId],
328 partition: usize,
329 _depth: usize,
330 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
331 let mut shannon = ShannonDecomposer::new();
333 let decomp = shannon.decompose(unitary, qubit_ids)?;
334
335 let tree = DecompositionTree::Node {
337 qubits: qubit_ids.to_vec(),
338 method: DecompositionMethod::Shannon { partition },
339 children: vec![], };
341
342 Ok((tree, decomp.gates))
343 }
344
345 fn decompose_block_diagonal(
347 &mut self,
348 unitary: &Array2<Complex<f64>>,
349 qubit_ids: &[QubitId],
350 block_size: usize,
351 depth: usize,
352 ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
353 let n = qubit_ids.len();
354 let num_blocks = n / block_size;
355
356 let mut gates = Vec::new();
357 let mut children = Vec::new();
358
359 for i in 0..num_blocks {
361 let start = i * block_size;
362 let end = (i + 1) * block_size;
363 let block_qubits = &qubit_ids[start..end];
364
365 let block = self.extract_block(unitary, i, block_size)?;
367
368 let (block_tree, block_gates) =
369 self.decompose_recursive(&block, block_qubits, depth + 1)?;
370 gates.extend(block_gates);
371 children.push(block_tree);
372 }
373
374 let tree = DecompositionTree::Node {
375 qubits: qubit_ids.to_vec(),
376 method: DecompositionMethod::BlockDiagonal { block_size },
377 children,
378 };
379
380 Ok((tree, gates))
381 }
382
383 fn compute_csd(
385 &self,
386 a: &Array2<Complex<f64>>,
387 b: &Array2<Complex<f64>>,
388 c: &Array2<Complex<f64>>,
389 d: &Array2<Complex<f64>>,
390 ) -> QuantRS2Result<(
391 Array2<Complex<f64>>, Array2<Complex<f64>>, Array2<Complex<f64>>, Array2<Complex<f64>>, Array2<Complex<f64>>, )> {
397 let size = a.shape()[0];
401 let identity = Array2::eye(size);
402 let _zero: Array2<Complex<f64>> = Array2::zeros((size, size));
403
404 let u1 = identity.clone();
406 let v1 = identity.clone();
407 let u2 = identity.clone();
408 let v2 = identity.clone();
409
410 let mut sigma = Array2::zeros((size * 2, size * 2));
412 sigma.slice_mut(s![..size, ..size]).assign(a);
413 sigma.slice_mut(s![..size, size..]).assign(b);
414 sigma.slice_mut(s![size.., ..size]).assign(c);
415 sigma.slice_mut(s![size.., size..]).assign(d);
416
417 Ok((u1, v1, sigma, u2, v2))
418 }
419
420 fn diagonal_to_gates(
422 &self,
423 diagonal: &Array2<Complex<f64>>,
424 qubit_ids: &[QubitId],
425 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
426 let mut gates = Vec::new();
427
428 let n = diagonal.shape()[0];
430 for i in 0..n {
431 let phase = diagonal[[i, i]].arg();
432 if phase.abs() > self.tolerance {
433 let mut control_pattern = Vec::new();
435 let mut temp = i;
436 for j in 0..qubit_ids.len() {
437 if temp & 1 == 1 {
438 control_pattern.push(j);
439 }
440 temp >>= 1;
441 }
442
443 if control_pattern.is_empty() {
445 } else if control_pattern.len() == 1 {
447 gates.push(Box::new(RotationZ {
449 target: qubit_ids[control_pattern[0]],
450 theta: phase,
451 }) as Box<dyn GateOp>);
452 } else {
453 let target = qubit_ids[control_pattern.pop().unwrap()];
456 for &control_idx in &control_pattern {
457 gates.push(Box::new(CNOT {
458 control: qubit_ids[control_idx],
459 target,
460 }));
461 }
462
463 gates.push(Box::new(RotationZ {
464 target,
465 theta: phase,
466 }) as Box<dyn GateOp>);
467
468 for &control_idx in control_pattern.iter().rev() {
470 gates.push(Box::new(CNOT {
471 control: qubit_ids[control_idx],
472 target,
473 }));
474 }
475 }
476 }
477 }
478
479 Ok(gates)
480 }
481
482 fn has_block_structure(&self, unitary: &Array2<Complex<f64>>, _n: usize) -> bool {
484 let size = unitary.shape()[0];
486 let block_size = size / 2;
487
488 let mut off_diagonal_norm = 0.0;
489
490 for i in 0..block_size {
492 for j in block_size..size {
493 off_diagonal_norm += unitary[[i, j]].norm_sqr();
494 }
495 }
496
497 for i in block_size..size {
499 for j in 0..block_size {
500 off_diagonal_norm += unitary[[i, j]].norm_sqr();
501 }
502 }
503
504 off_diagonal_norm.sqrt() < self.tolerance
505 }
506
507 fn extract_block(
509 &self,
510 unitary: &Array2<Complex<f64>>,
511 block_idx: usize,
512 block_size: usize,
513 ) -> QuantRS2Result<Array2<Complex<f64>>> {
514 let size = 1 << block_size;
515 let start = block_idx * size;
516 let end = (block_idx + 1) * size;
517
518 Ok(unitary.slice(s![start..end, start..end]).to_owned())
519 }
520
521 fn single_qubit_to_gates(
523 &self,
524 decomp: &SingleQubitDecomposition,
525 qubit: QubitId,
526 ) -> Vec<Box<dyn GateOp>> {
527 let mut gates = Vec::new();
528
529 if decomp.theta1.abs() > self.tolerance {
530 gates.push(Box::new(RotationZ {
531 target: qubit,
532 theta: decomp.theta1,
533 }) as Box<dyn GateOp>);
534 }
535
536 if decomp.phi.abs() > self.tolerance {
537 gates.push(Box::new(RotationY {
538 target: qubit,
539 theta: decomp.phi,
540 }) as Box<dyn GateOp>);
541 }
542
543 if decomp.theta2.abs() > self.tolerance {
544 gates.push(Box::new(RotationZ {
545 target: qubit,
546 theta: decomp.theta2,
547 }) as Box<dyn GateOp>);
548 }
549
550 gates
551 }
552
553 fn count_cnots(&self, gate_name: &str) -> usize {
555 match gate_name {
556 "CNOT" => 1,
557 "CZ" => 1, "SWAP" => 3, _ => 0,
560 }
561 }
562
563 fn check_cache(&self, _unitary: &Array2<Complex<f64>>) -> Option<&MultiQubitKAK> {
565 None
568 }
569
570 fn cache_result(&mut self, _unitary: &Array2<Complex<f64>>, _result: &MultiQubitKAK) {
572 }
574}
575
576impl Default for MultiQubitKAKDecomposer {
577 fn default() -> Self {
578 Self::new()
579 }
580}
581
582pub struct KAKTreeAnalyzer {
584 stats: DecompositionStats,
586}
587
/// Statistics gathered by `KAKTreeAnalyzer::analyze`.
#[derive(Debug, Default, Clone)]
pub struct DecompositionStats {
    /// Number of tree nodes (leaves and internal nodes).
    pub total_nodes: usize,
    /// Number of leaf nodes.
    pub leaf_nodes: usize,
    /// Deepest level visited, with the root at depth 0.
    pub max_depth: usize,
    /// Occurrences per method/leaf kind (e.g. "csd", "single_qubit", "two_qubit").
    pub method_counts: FxHashMap<String, usize>,
    /// Histogram mapping a two-qubit leaf's CNOT count to the number of such leaves.
    pub cnot_distribution: FxHashMap<usize, usize>,
}
596
597impl KAKTreeAnalyzer {
598 pub fn new() -> Self {
600 Self {
601 stats: DecompositionStats::default(),
602 }
603 }
604
605 pub fn analyze(&mut self, tree: &DecompositionTree) -> DecompositionStats {
607 self.stats = DecompositionStats::default();
608 self.analyze_recursive(tree, 0);
609 self.stats.clone()
610 }
611
612 fn analyze_recursive(&mut self, tree: &DecompositionTree, depth: usize) {
613 self.stats.total_nodes += 1;
614 self.stats.max_depth = self.stats.max_depth.max(depth);
615
616 match tree {
617 DecompositionTree::Leaf {
618 qubits: _qubits,
619 gate_type,
620 } => {
621 self.stats.leaf_nodes += 1;
622
623 match gate_type {
624 LeafType::SingleQubit(_) => {
625 *self
626 .stats
627 .method_counts
628 .entry("single_qubit".to_string())
629 .or_insert(0) += 1;
630 }
631 LeafType::TwoQubit(cartan) => {
632 *self
633 .stats
634 .method_counts
635 .entry("two_qubit".to_string())
636 .or_insert(0) += 1;
637 let cnots = cartan.interaction.cnot_count(1e-10);
638 *self.stats.cnot_distribution.entry(cnots).or_insert(0) += 1;
639 }
640 }
641 }
642 DecompositionTree::Node {
643 method, children, ..
644 } => {
645 let method_name = match method {
646 DecompositionMethod::CSD { .. } => "csd",
647 DecompositionMethod::Shannon { .. } => "shannon",
648 DecompositionMethod::BlockDiagonal { .. } => "block_diagonal",
649 DecompositionMethod::Cartan => "cartan",
650 };
651 *self
652 .stats
653 .method_counts
654 .entry(method_name.to_string())
655 .or_insert(0) += 1;
656
657 for child in children {
658 self.analyze_recursive(child, depth + 1);
659 }
660 }
661 }
662 }
663}
664
665pub fn kak_decompose_multiqubit(
667 unitary: &Array2<Complex<f64>>,
668 qubit_ids: &[QubitId],
669) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
670 let mut decomposer = MultiQubitKAKDecomposer::new();
671 let decomp = decomposer.decompose(unitary, qubit_ids)?;
672 Ok(decomp.gates)
673}
674
// Unit tests for the multi-qubit KAK decomposer and the tree analyzer.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use num_complex::Complex;

    // A Hadamard on one qubit must become a single ZYZ leaf with no CNOTs.
    #[test]
    fn test_multiqubit_kak_single() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // H = (1/sqrt(2)) [[1, 1], [1, -1]]
        let h = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex::new(1.0, 0.0),
                Complex::new(1.0, 0.0),
                Complex::new(1.0, 0.0),
                Complex::new(-1.0, 0.0),
            ],
        )
        .unwrap()
            / Complex::new(2.0_f64.sqrt(), 0.0);

        let qubit_ids = vec![QubitId(0)];
        let decomp = decomposer.decompose(&h, &qubit_ids).unwrap();

        // At most RZ-RY-RZ.
        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);

        match &decomp.tree {
            DecompositionTree::Leaf {
                gate_type: LeafType::SingleQubit(_),
                ..
            } => {}
            _ => panic!("Expected single-qubit leaf"),
        }
    }

    // A CNOT (swapping |10> and |11>) must become a two-qubit Cartan leaf.
    #[test]
    fn test_multiqubit_kak_two() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        let cnot = Array2::from_shape_vec(
            (4, 4),
            vec![
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(0.0, 0.0),
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 0.0),
            ],
        )
        .unwrap();

        let qubit_ids = vec![QubitId(0), QubitId(1)];
        let decomp = decomposer.decompose(&cnot, &qubit_ids).unwrap();

        assert!(decomp.cnot_count <= 1);

        match &decomp.tree {
            DecompositionTree::Leaf {
                gate_type: LeafType::TwoQubit(_),
                ..
            } => {}
            _ => panic!("Expected two-qubit leaf"),
        }
    }

    // The 3-qubit identity must decompose into an empty circuit.
    #[test]
    fn test_multiqubit_kak_three() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        let identity = Array2::eye(8);
        let identity_complex = identity.mapv(|x| Complex::new(x, 0.0));

        let qubit_ids = vec![QubitId(0), QubitId(1), QubitId(2)];
        let decomp = decomposer.decompose(&identity_complex, &qubit_ids).unwrap();

        assert_eq!(decomp.gates.len(), 0);
        assert_eq!(decomp.cnot_count, 0);
        assert_eq!(decomp.single_qubit_count, 0);
    }

    // A hand-built CSD node with one two-qubit and one single-qubit leaf:
    // 3 nodes, 2 leaves, depth 1, one "csd" entry.
    #[test]
    fn test_tree_analyzer() {
        let mut analyzer = KAKTreeAnalyzer::new();

        let tree = DecompositionTree::Node {
            qubits: vec![QubitId(0), QubitId(1), QubitId(2)],
            method: DecompositionMethod::CSD { pivot: 2 },
            children: vec![
                DecompositionTree::Leaf {
                    qubits: vec![QubitId(0), QubitId(1)],
                    gate_type: LeafType::TwoQubit(CartanDecomposition {
                        left_gates: (
                            SingleQubitDecomposition {
                                global_phase: 0.0,
                                theta1: 0.0,
                                phi: 0.0,
                                theta2: 0.0,
                                basis: "ZYZ".to_string(),
                            },
                            SingleQubitDecomposition {
                                global_phase: 0.0,
                                theta1: 0.0,
                                phi: 0.0,
                                theta2: 0.0,
                                basis: "ZYZ".to_string(),
                            },
                        ),
                        right_gates: (
                            SingleQubitDecomposition {
                                global_phase: 0.0,
                                theta1: 0.0,
                                phi: 0.0,
                                theta2: 0.0,
                                basis: "ZYZ".to_string(),
                            },
                            SingleQubitDecomposition {
                                global_phase: 0.0,
                                theta1: 0.0,
                                phi: 0.0,
                                theta2: 0.0,
                                basis: "ZYZ".to_string(),
                            },
                        ),
                        interaction: crate::prelude::CartanCoefficients::new(0.0, 0.0, 0.0),
                        global_phase: 0.0,
                    }),
                },
                DecompositionTree::Leaf {
                    qubits: vec![QubitId(2)],
                    gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
                        global_phase: 0.0,
                        theta1: 0.0,
                        phi: 0.0,
                        theta2: 0.0,
                        basis: "ZYZ".to_string(),
                    }),
                },
            ],
        };

        let stats = analyzer.analyze(&tree);

        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.leaf_nodes, 2);
        assert_eq!(stats.max_depth, 1);
        assert_eq!(stats.method_counts.get("csd"), Some(&1));
    }

    // The 4x4 identity is block-diagonal; a non-zero upper-right entry breaks it.
    #[test]
    fn test_block_structure_detection() {
        let decomposer = MultiQubitKAKDecomposer::new();

        let mut block_diag = Array2::zeros((4, 4));
        block_diag[[0, 0]] = Complex::new(1.0, 0.0);
        block_diag[[1, 1]] = Complex::new(1.0, 0.0);
        block_diag[[2, 2]] = Complex::new(1.0, 0.0);
        block_diag[[3, 3]] = Complex::new(1.0, 0.0);

        assert!(decomposer.has_block_structure(&block_diag, 2));

        // Entry in the off-diagonal (upper-right) quadrant.
        block_diag[[0, 2]] = Complex::new(1.0, 0.0);
        assert!(!decomposer.has_block_structure(&block_diag, 2));
    }
}