1use crate::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::{multi::*, single::*, GateOp},
14 matrix_ops::{DenseMatrix, QuantumMatrix},
15 qubit::QubitId,
16 synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
17};
18use rustc_hash::FxHashMap;
19use scirs2_core::ndarray::{s, Array1, Array2};
20use scirs2_core::Complex;
21use std::f64::consts::PI;
22
/// Result of decomposing a two-qubit unitary into single-qubit layers around a
/// non-local interaction.
#[derive(Debug, Clone)]
pub struct CartanDecomposition {
    /// Single-qubit decompositions for the first (`.0`) and second (`.1`) qubit;
    /// `CartanDecomposer::to_gates` emits these before the interaction gates.
    pub left_gates: (SingleQubitDecomposition, SingleQubitDecomposition),
    /// Single-qubit decompositions for the first (`.0`) and second (`.1`) qubit;
    /// `CartanDecomposer::to_gates` emits these after the interaction gates.
    pub right_gates: (SingleQubitDecomposition, SingleQubitDecomposition),
    /// XX/YY/ZZ coefficients of the non-local interaction between the two layers.
    pub interaction: CartanCoefficients,
    /// Global phase not represented by the gates. `OptimizedCartanDecomposer`
    /// folds a non-zero value into `left_gates.0.global_phase` and resets this to 0.
    pub global_phase: f64,
}
35
/// Strengths of the XX, YY and ZZ terms of a two-qubit gate's non-local part
/// (presumably in the convention `exp(i(xx·XX + yy·YY + zz·ZZ))` — confirm
/// against the decomposer that produced them).
///
/// `Default` yields `(0, 0, 0)`, i.e. no interaction.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct CartanCoefficients {
    /// Coefficient of the `X⊗X` term.
    pub xx: f64,
    /// Coefficient of the `Y⊗Y` term.
    pub yy: f64,
    /// Coefficient of the `Z⊗Z` term.
    pub zz: f64,
}
46
47impl CartanCoefficients {
48 pub const fn new(xx: f64, yy: f64, zz: f64) -> Self {
50 Self { xx, yy, zz }
51 }
52
53 pub fn is_identity(&self, tolerance: f64) -> bool {
55 self.xx.abs() < tolerance && self.yy.abs() < tolerance && self.zz.abs() < tolerance
56 }
57
58 pub fn cnot_count(&self, tolerance: f64) -> usize {
60 let eps = tolerance;
61
62 if self.is_identity(eps) {
64 0
65 } else if (self.xx - self.yy).abs() < eps && self.zz.abs() < eps {
66 2
68 } else if (self.xx - PI / 4.0).abs() < eps
69 && (self.yy - PI / 4.0).abs() < eps
70 && (self.zz - PI / 4.0).abs() < eps
71 {
72 3
74 } else if self.xx.abs() < eps || self.yy.abs() < eps || self.zz.abs() < eps {
75 2
77 } else {
78 3
80 }
81 }
82
83 pub fn canonicalize(&mut self) {
85 let mut vals = [
87 (self.xx.abs(), self.xx, 0),
88 (self.yy.abs(), self.yy, 1),
89 (self.zz.abs(), self.zz, 2),
90 ];
91 vals.sort_by(|a, b| {
92 b.0.partial_cmp(&a.0)
93 .expect("Failed to compare Cartan coefficients in CartanCoefficients::canonicalize")
94 });
95
96 self.xx = vals[0].1;
97 self.yy = vals[1].1;
98 self.zz = vals[2].1;
99 }
100}
101
102pub struct CartanDecomposer {
104 tolerance: f64,
106 #[allow(dead_code)]
108 cache: FxHashMap<u64, CartanDecomposition>,
109}
110
111impl CartanDecomposer {
112 pub fn new() -> Self {
114 Self {
115 tolerance: 1e-10,
116 cache: FxHashMap::default(),
117 }
118 }
119
120 pub fn with_tolerance(tolerance: f64) -> Self {
122 Self {
123 tolerance,
124 cache: FxHashMap::default(),
125 }
126 }
127
128 pub fn decompose(
130 &mut self,
131 unitary: &Array2<Complex<f64>>,
132 ) -> QuantRS2Result<CartanDecomposition> {
133 if unitary.shape() != [4, 4] {
135 return Err(QuantRS2Error::InvalidInput(
136 "Cartan decomposition requires 4x4 unitary".to_string(),
137 ));
138 }
139
140 let mat = DenseMatrix::new(unitary.clone())?;
142 if !mat.is_unitary(self.tolerance)? {
143 return Err(QuantRS2Error::InvalidInput(
144 "Matrix is not unitary".to_string(),
145 ));
146 }
147
148 let magic_basis = Self::get_magic_basis();
150 let u_magic = Self::to_magic_basis(unitary, &magic_basis);
151
152 let u_magic_t = u_magic.t().to_owned();
154 let m = u_magic_t.dot(&u_magic);
155
156 let (d, p) = Self::diagonalize_symmetric(&m)?;
158
159 let coeffs = Self::extract_coefficients(&d);
161
162 let (left_gates, right_gates) = self.compute_local_gates(unitary, &u_magic, &p, &coeffs)?;
164
165 let global_phase = Self::compute_global_phase(unitary, &left_gates, &right_gates, &coeffs)?;
167
168 Ok(CartanDecomposition {
169 left_gates,
170 right_gates,
171 interaction: coeffs,
172 global_phase,
173 })
174 }
175
176 fn get_magic_basis() -> Array2<Complex<f64>> {
178 let sqrt2 = 2.0_f64.sqrt();
179 Array2::from_shape_vec(
180 (4, 4),
181 vec![
182 Complex::new(1.0, 0.0),
183 Complex::new(0.0, 0.0),
184 Complex::new(0.0, 0.0),
185 Complex::new(1.0, 0.0),
186 Complex::new(0.0, 0.0),
187 Complex::new(1.0, 0.0),
188 Complex::new(1.0, 0.0),
189 Complex::new(0.0, 0.0),
190 Complex::new(0.0, 0.0),
191 Complex::new(1.0, 0.0),
192 Complex::new(-1.0, 0.0),
193 Complex::new(0.0, 0.0),
194 Complex::new(1.0, 0.0),
195 Complex::new(0.0, 0.0),
196 Complex::new(0.0, 0.0),
197 Complex::new(-1.0, 0.0),
198 ],
199 )
200 .expect("Failed to create magic basis matrix in CartanDecomposer::get_magic_basis")
201 / Complex::new(sqrt2, 0.0)
202 }
203
204 fn to_magic_basis(
206 u: &Array2<Complex<f64>>,
207 magic: &Array2<Complex<f64>>,
208 ) -> Array2<Complex<f64>> {
209 let magic_dag = magic.mapv(|z| z.conj()).t().to_owned();
210 magic_dag.dot(u).dot(magic)
211 }
212
213 fn diagonalize_symmetric(
219 m: &Array2<Complex<f64>>,
220 ) -> QuantRS2Result<(Array1<Complex<f64>>, Array2<Complex<f64>>)> {
221 let n = m.nrows();
222 let mut h = m.to_owned();
225 let mut q = Array2::<Complex<f64>>::eye(n);
226
227 for k in 0..n.saturating_sub(2) {
229 let col: Vec<Complex<f64>> = (k + 1..n).map(|i| h[[i, k]]).collect();
231 let sigma_sq: f64 = col.iter().map(|z| z.norm_sqr()).sum();
232 let sigma = sigma_sq.sqrt();
233 if sigma < 1e-14 {
234 continue;
235 }
236 let phase = if col[0].norm() > 1e-14 {
238 col[0] / col[0].norm()
239 } else {
240 Complex::new(1.0, 0.0)
241 };
242 let mut v = col.clone();
243 v[0] = v[0] + phase * sigma;
244 let v_norm_sq: f64 = v.iter().map(|z| z.norm_sqr()).sum();
245 if v_norm_sq < 1e-28 {
246 continue;
247 }
248 let m_len = v.len(); for j in 0..n {
252 let dot: Complex<f64> = (0..m_len).map(|i| v[i].conj() * h[[k + 1 + i, j]]).sum();
253 let scale = dot * Complex::new(2.0 / v_norm_sq, 0.0);
254 for i in 0..m_len {
255 h[[k + 1 + i, j]] = h[[k + 1 + i, j]] - v[i] * scale;
256 }
257 }
258 for i in 0..n {
260 let dot: Complex<f64> = (0..m_len).map(|j| h[[i, k + 1 + j]] * v[j]).sum();
261 let scale = dot * Complex::new(2.0 / v_norm_sq, 0.0);
262 for j in 0..m_len {
263 h[[i, k + 1 + j]] = h[[i, k + 1 + j]] - scale * v[j].conj();
264 }
265 }
266 for i in 0..n {
268 let dot: Complex<f64> = (0..m_len).map(|j| q[[i, k + 1 + j]] * v[j]).sum();
269 let scale = dot * Complex::new(2.0 / v_norm_sq, 0.0);
270 for j in 0..m_len {
271 q[[i, k + 1 + j]] = q[[i, k + 1 + j]] - scale * v[j].conj();
272 }
273 }
274 }
275
276 let max_iter = 300 * n;
278 let mut active = n;
279 for _iter in 0..max_iter {
280 if active <= 1 {
281 break;
282 }
283 while active > 1 {
285 let off = h[[active - 1, active - 2]].norm();
286 let d1 = h[[active - 1, active - 1]].norm();
287 let d0 = h[[active - 2, active - 2]].norm();
288 if off < 1e-12 * (d1 + d0) {
289 active -= 1;
290 } else {
291 break;
292 }
293 }
294 if active <= 1 {
295 break;
296 }
297
298 let a = active;
300 let s = h[[a - 1, a - 1]];
301
302 for k in 0..a - 1 {
305 let x = h[[k, k]] - s;
307 let y = h[[k + 1, k]];
308 let r = (x.norm_sqr() + y.norm_sqr()).sqrt();
309 if r < 1e-14 {
310 continue;
311 }
312 let c_val = x / r;
313 let s_val = -y / r;
314
315 for j in 0..n {
317 let tmp0 = c_val * h[[k, j]] - s_val.conj() * h[[k + 1, j]];
318 let tmp1 = s_val * h[[k, j]] + c_val.conj() * h[[k + 1, j]];
319 h[[k, j]] = tmp0;
320 h[[k + 1, j]] = tmp1;
321 }
322 for i in 0..n {
324 let tmp0 = c_val.conj() * h[[i, k]] - s_val.conj() * h[[i, k + 1]];
325 let tmp1 = s_val * h[[i, k]] + c_val * h[[i, k + 1]];
326 h[[i, k]] = tmp0;
327 h[[i, k + 1]] = tmp1;
328 }
329 for i in 0..n {
331 let tmp0 = c_val.conj() * q[[i, k]] - s_val.conj() * q[[i, k + 1]];
332 let tmp1 = s_val * q[[i, k]] + c_val * q[[i, k + 1]];
333 q[[i, k]] = tmp0;
334 q[[i, k + 1]] = tmp1;
335 }
336 }
337 }
338
339 let mut eigenvalues = Array1::zeros(n);
341 for i in 0..n {
342 eigenvalues[i] = h[[i, i]];
343 }
344
345 Ok((eigenvalues, q))
346 }
347
348 fn extract_coefficients(eigenvalues: &Array1<Complex<f64>>) -> CartanCoefficients {
356 let mut phases: Vec<f64> = eigenvalues.iter().map(|z| z.arg() / 2.0).collect();
359 phases.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
361
362 let p0 = phases.first().copied().unwrap_or(0.0);
373 let p1 = phases.get(1).copied().unwrap_or(0.0);
374 let p2 = phases.get(2).copied().unwrap_or(0.0);
375 let p3 = phases.get(3).copied().unwrap_or(0.0);
376
377 let a = (p3 + p0) / 2.0;
382 let b_plus_c = (p3 - p0) / 2.0;
383 let c_minus_b = (p2 - p1) / 2.0;
384 let b = (b_plus_c - c_minus_b) / 2.0;
385 let c = (b_plus_c + c_minus_b) / 2.0;
386
387 let mut coeffs = CartanCoefficients::new(a, b, c);
388 coeffs.canonicalize();
389 coeffs
390 }
391
392 fn compute_local_gates(
394 &self,
395 u: &Array2<Complex<f64>>,
396 _u_magic: &Array2<Complex<f64>>,
397 _p: &Array2<Complex<f64>>,
398 coeffs: &CartanCoefficients,
399 ) -> QuantRS2Result<(
400 (SingleQubitDecomposition, SingleQubitDecomposition),
401 (SingleQubitDecomposition, SingleQubitDecomposition),
402 )> {
403 let _canonical = Self::build_canonical_gate(coeffs);
405
406 let a1 = u.slice(s![..2, ..2]).to_owned();
413 let b1 = u.slice(s![2..4, 2..4]).to_owned();
414
415 let left_a = decompose_single_qubit_zyz(&a1.view())?;
416 let left_b = decompose_single_qubit_zyz(&b1.view())?;
417
418 let ident = Array2::eye(2);
421 let right_a = decompose_single_qubit_zyz(&ident.view())?;
422 let right_b = decompose_single_qubit_zyz(&ident.view())?;
423
424 Ok(((left_a, left_b), (right_a, right_b)))
425 }
426
427 fn build_canonical_gate(coeffs: &CartanCoefficients) -> Array2<Complex<f64>> {
429 let a = coeffs.xx;
431 let b = coeffs.yy;
432 let c = coeffs.zz;
433
434 let cos_a = a.cos();
436 let sin_a = a.sin();
437 let cos_b = b.cos();
438 let sin_b = b.sin();
439 let cos_c = c.cos();
440 let sin_c = c.sin();
441
442 let mut result = Array2::zeros((4, 4));
444
445 result[[0, 0]] = Complex::new(cos_a * cos_b * cos_c, sin_c);
447 result[[0, 3]] = Complex::new(0.0, sin_a * cos_b * cos_c);
448 result[[1, 1]] = Complex::new(cos_a * cos_c, -sin_a * sin_b * sin_c);
449 result[[1, 2]] = Complex::new(0.0, cos_a.mul_add(sin_c, sin_a * sin_b * cos_c));
450 result[[2, 1]] = Complex::new(0.0, cos_a.mul_add(sin_c, -(sin_a * sin_b * cos_c)));
451 result[[2, 2]] = Complex::new(cos_a * cos_c, sin_a * sin_b * sin_c);
452 result[[3, 0]] = Complex::new(0.0, sin_a * cos_b * cos_c);
453 result[[3, 3]] = Complex::new(cos_a * cos_b * cos_c, -sin_c);
454
455 result
456 }
457
458 const fn compute_global_phase(
460 _u: &Array2<Complex<f64>>,
461 _left: &(SingleQubitDecomposition, SingleQubitDecomposition),
462 _right: &(SingleQubitDecomposition, SingleQubitDecomposition),
463 _coeffs: &CartanCoefficients,
464 ) -> QuantRS2Result<f64> {
465 Ok(0.0)
468 }
469
470 pub fn to_gates(
472 &self,
473 decomp: &CartanDecomposition,
474 qubit_ids: &[QubitId],
475 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
476 if qubit_ids.len() != 2 {
477 return Err(QuantRS2Error::InvalidInput(
478 "Cartan decomposition requires exactly 2 qubits".to_string(),
479 ));
480 }
481
482 let q0 = qubit_ids[0];
483 let q1 = qubit_ids[1];
484 let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
485
486 gates.extend(self.single_qubit_to_gates(&decomp.left_gates.0, q0));
488 gates.extend(self.single_qubit_to_gates(&decomp.left_gates.1, q1));
489
490 gates.extend(self.canonical_to_gates(&decomp.interaction, q0, q1)?);
492
493 gates.extend(self.single_qubit_to_gates(&decomp.right_gates.0, q0));
495 gates.extend(self.single_qubit_to_gates(&decomp.right_gates.1, q1));
496
497 Ok(gates)
498 }
499
500 fn single_qubit_to_gates(
502 &self,
503 decomp: &SingleQubitDecomposition,
504 qubit: QubitId,
505 ) -> Vec<Box<dyn GateOp>> {
506 let mut gates = Vec::new();
507
508 if decomp.theta1.abs() > self.tolerance {
509 gates.push(Box::new(RotationZ {
510 target: qubit,
511 theta: decomp.theta1,
512 }) as Box<dyn GateOp>);
513 }
514
515 if decomp.phi.abs() > self.tolerance {
516 gates.push(Box::new(RotationY {
517 target: qubit,
518 theta: decomp.phi,
519 }) as Box<dyn GateOp>);
520 }
521
522 if decomp.theta2.abs() > self.tolerance {
523 gates.push(Box::new(RotationZ {
524 target: qubit,
525 theta: decomp.theta2,
526 }) as Box<dyn GateOp>);
527 }
528
529 gates
530 }
531
532 fn canonical_to_gates(
534 &self,
535 coeffs: &CartanCoefficients,
536 q0: QubitId,
537 q1: QubitId,
538 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
539 let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
540 let cnots = coeffs.cnot_count(self.tolerance);
541
542 match cnots {
543 0 => {
544 }
546 1 => {
547 gates.push(Box::new(CNOT {
549 control: q0,
550 target: q1,
551 }));
552 }
553 2 => {
554 if coeffs.xx.abs() > self.tolerance {
557 gates.push(Box::new(RotationX {
558 target: q0,
559 theta: coeffs.xx * 2.0,
560 }));
561 }
562
563 gates.push(Box::new(CNOT {
564 control: q0,
565 target: q1,
566 }));
567
568 if coeffs.zz.abs() > self.tolerance {
569 gates.push(Box::new(RotationZ {
570 target: q1,
571 theta: coeffs.zz * 2.0,
572 }));
573 }
574
575 gates.push(Box::new(CNOT {
576 control: q0,
577 target: q1,
578 }));
579 }
580 3 => {
581 gates.push(Box::new(CNOT {
583 control: q0,
584 target: q1,
585 }));
586
587 gates.push(Box::new(RotationZ {
588 target: q0,
589 theta: coeffs.xx * 2.0,
590 }));
591 gates.push(Box::new(RotationZ {
592 target: q1,
593 theta: coeffs.yy * 2.0,
594 }));
595
596 gates.push(Box::new(CNOT {
597 control: q1,
598 target: q0,
599 }));
600
601 gates.push(Box::new(RotationZ {
602 target: q0,
603 theta: coeffs.zz * 2.0,
604 }));
605
606 gates.push(Box::new(CNOT {
607 control: q0,
608 target: q1,
609 }));
610 }
611 _ => unreachable!("CNOT count should be 0-3"),
612 }
613
614 Ok(gates)
615 }
616}
617
/// Cartan decomposer that recognises CNOT, CZ and SWAP (up to global phase)
/// before falling back to a general [`CartanDecomposer`].
pub struct OptimizedCartanDecomposer {
    /// Underlying general decomposer; its `tolerance` is also used when
    /// matching the special-case gates.
    pub base: CartanDecomposer,
    // When true, `decompose` first checks for CNOT, CZ and SWAP and returns a
    // precomputed decomposition on a match.
    optimize_special_cases: bool,
    // When true, a non-zero `global_phase` from the base decomposer is folded
    // into `left_gates.0.global_phase`.
    optimize_phase: bool,
}
626
627impl OptimizedCartanDecomposer {
628 pub fn new() -> Self {
630 Self {
631 base: CartanDecomposer::new(),
632 optimize_special_cases: true,
633 optimize_phase: true,
634 }
635 }
636
637 pub fn decompose(
639 &mut self,
640 unitary: &Array2<Complex<f64>>,
641 ) -> QuantRS2Result<CartanDecomposition> {
642 if self.optimize_special_cases {
644 if let Some(special) = self.check_special_cases(unitary)? {
645 return Ok(special);
646 }
647 }
648
649 let mut decomp = self.base.decompose(unitary)?;
651
652 if self.optimize_phase {
654 self.optimize_global_phase(&mut decomp);
655 }
656
657 Ok(decomp)
658 }
659
660 fn check_special_cases(
662 &self,
663 unitary: &Array2<Complex<f64>>,
664 ) -> QuantRS2Result<Option<CartanDecomposition>> {
665 if self.is_cnot(unitary) {
667 return Ok(Some(Self::cnot_decomposition()));
668 }
669
670 if self.is_cz(unitary) {
672 return Ok(Some(Self::cz_decomposition()));
673 }
674
675 if self.is_swap(unitary) {
677 return Ok(Some(Self::swap_decomposition()));
678 }
679
680 Ok(None)
681 }
682
683 fn is_cnot(&self, u: &Array2<Complex<f64>>) -> bool {
685 let cnot = Array2::from_shape_vec(
686 (4, 4),
687 vec![
688 Complex::new(1.0, 0.0),
689 Complex::new(0.0, 0.0),
690 Complex::new(0.0, 0.0),
691 Complex::new(0.0, 0.0),
692 Complex::new(0.0, 0.0),
693 Complex::new(1.0, 0.0),
694 Complex::new(0.0, 0.0),
695 Complex::new(0.0, 0.0),
696 Complex::new(0.0, 0.0),
697 Complex::new(0.0, 0.0),
698 Complex::new(0.0, 0.0),
699 Complex::new(1.0, 0.0),
700 Complex::new(0.0, 0.0),
701 Complex::new(0.0, 0.0),
702 Complex::new(1.0, 0.0),
703 Complex::new(0.0, 0.0),
704 ],
705 )
706 .expect("Failed to create CNOT matrix in OptimizedCartanDecomposer::is_cnot");
707
708 self.matrices_equal(u, &cnot)
709 }
710
711 fn is_cz(&self, u: &Array2<Complex<f64>>) -> bool {
713 let cz = Array2::from_shape_vec(
714 (4, 4),
715 vec![
716 Complex::new(1.0, 0.0),
717 Complex::new(0.0, 0.0),
718 Complex::new(0.0, 0.0),
719 Complex::new(0.0, 0.0),
720 Complex::new(0.0, 0.0),
721 Complex::new(1.0, 0.0),
722 Complex::new(0.0, 0.0),
723 Complex::new(0.0, 0.0),
724 Complex::new(0.0, 0.0),
725 Complex::new(0.0, 0.0),
726 Complex::new(1.0, 0.0),
727 Complex::new(0.0, 0.0),
728 Complex::new(0.0, 0.0),
729 Complex::new(0.0, 0.0),
730 Complex::new(0.0, 0.0),
731 Complex::new(-1.0, 0.0),
732 ],
733 )
734 .expect("Failed to create CZ matrix in OptimizedCartanDecomposer::is_cz");
735
736 self.matrices_equal(u, &cz)
737 }
738
739 fn is_swap(&self, u: &Array2<Complex<f64>>) -> bool {
741 let swap = Array2::from_shape_vec(
742 (4, 4),
743 vec![
744 Complex::new(1.0, 0.0),
745 Complex::new(0.0, 0.0),
746 Complex::new(0.0, 0.0),
747 Complex::new(0.0, 0.0),
748 Complex::new(0.0, 0.0),
749 Complex::new(0.0, 0.0),
750 Complex::new(1.0, 0.0),
751 Complex::new(0.0, 0.0),
752 Complex::new(0.0, 0.0),
753 Complex::new(1.0, 0.0),
754 Complex::new(0.0, 0.0),
755 Complex::new(0.0, 0.0),
756 Complex::new(0.0, 0.0),
757 Complex::new(0.0, 0.0),
758 Complex::new(0.0, 0.0),
759 Complex::new(1.0, 0.0),
760 ],
761 )
762 .expect("Failed to create SWAP matrix in OptimizedCartanDecomposer::is_swap");
763
764 self.matrices_equal(u, &swap)
765 }
766
767 fn matrices_equal(&self, a: &Array2<Complex<f64>>, b: &Array2<Complex<f64>>) -> bool {
769 let mut phase = Complex::new(1.0, 0.0);
771 for i in 0..4 {
772 for j in 0..4 {
773 if b[[i, j]].norm() > self.base.tolerance {
774 phase = a[[i, j]] / b[[i, j]];
775 break;
776 }
777 }
778 }
779
780 for i in 0..4 {
782 for j in 0..4 {
783 if (a[[i, j]] - phase * b[[i, j]]).norm() > self.base.tolerance {
784 return false;
785 }
786 }
787 }
788
789 true
790 }
791
792 fn cnot_decomposition() -> CartanDecomposition {
794 let ident = Array2::eye(2);
795 let ident_decomp = decompose_single_qubit_zyz(&ident.view()).expect(
796 "Failed to decompose identity in OptimizedCartanDecomposer::cnot_decomposition",
797 );
798
799 CartanDecomposition {
800 left_gates: (ident_decomp.clone(), ident_decomp.clone()),
801 right_gates: (ident_decomp.clone(), ident_decomp),
802 interaction: CartanCoefficients::new(PI / 4.0, PI / 4.0, 0.0),
803 global_phase: 0.0,
804 }
805 }
806
807 fn cz_decomposition() -> CartanDecomposition {
809 let ident = Array2::eye(2);
810 let ident_decomp = decompose_single_qubit_zyz(&ident.view())
811 .expect("Failed to decompose identity in OptimizedCartanDecomposer::cz_decomposition");
812
813 CartanDecomposition {
814 left_gates: (ident_decomp.clone(), ident_decomp.clone()),
815 right_gates: (ident_decomp.clone(), ident_decomp),
816 interaction: CartanCoefficients::new(0.0, 0.0, PI / 4.0),
817 global_phase: 0.0,
818 }
819 }
820
821 fn swap_decomposition() -> CartanDecomposition {
823 let ident = Array2::eye(2);
824 let ident_decomp = decompose_single_qubit_zyz(&ident.view()).expect(
825 "Failed to decompose identity in OptimizedCartanDecomposer::swap_decomposition",
826 );
827
828 CartanDecomposition {
829 left_gates: (ident_decomp.clone(), ident_decomp.clone()),
830 right_gates: (ident_decomp.clone(), ident_decomp),
831 interaction: CartanCoefficients::new(PI / 4.0, PI / 4.0, PI / 4.0),
832 global_phase: 0.0,
833 }
834 }
835
836 fn optimize_global_phase(&self, decomp: &mut CartanDecomposition) {
838 if decomp.global_phase.abs() > self.base.tolerance {
840 decomp.left_gates.0.global_phase += decomp.global_phase;
841 decomp.global_phase = 0.0;
842 }
843 }
844}
845
846pub fn cartan_decompose(unitary: &Array2<Complex<f64>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
848 let mut decomposer = CartanDecomposer::new();
849 let decomp = decomposer.decompose(unitary)?;
850 let qubit_ids = vec![QubitId(0), QubitId(1)];
851 decomposer.to_gates(&decomp, &qubit_ids)
852}
853
854impl Default for OptimizedCartanDecomposer {
855 fn default() -> Self {
856 Self::new()
857 }
858}
859
860impl Default for CartanDecomposer {
861 fn default() -> Self {
862 Self::new()
863 }
864}
865
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex;

    /// Builds a 4x4 complex matrix from 16 real row-major entries, panicking
    /// with `context` if the entry count is wrong.
    fn real_matrix(entries: &[f64], context: &str) -> Array2<Complex<f64>> {
        let data: Vec<Complex<f64>> = entries.iter().map(|&x| Complex::new(x, 0.0)).collect();
        Array2::from_shape_vec((4, 4), data).expect(context)
    }

    #[test]
    fn test_cartan_coefficients() {
        let tol = 1e-10;

        let generic = CartanCoefficients::new(0.1, 0.2, 0.3);
        assert!(!generic.is_identity(tol));
        assert_eq!(generic.cnot_count(tol), 3);

        let trivial = CartanCoefficients::new(0.0, 0.0, 0.0);
        assert!(trivial.is_identity(tol));
        assert_eq!(trivial.cnot_count(tol), 0);
    }

    #[test]
    fn test_cartan_cnot() {
        #[rustfmt::skip]
        let cnot = real_matrix(
            &[
                1.0, 0.0, 0.0, 0.0,
                0.0, 1.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 1.0,
                0.0, 0.0, 1.0, 0.0,
            ],
            "Failed to create CNOT matrix in test_cartan_cnot",
        );

        let decomp = CartanDecomposer::new()
            .decompose(&cnot)
            .expect("Failed to decompose CNOT in test_cartan_cnot");

        assert!(decomp.interaction.cnot_count(1e-10) <= 1);
    }

    #[test]
    fn test_optimized_special_cases() {
        #[rustfmt::skip]
        let swap = real_matrix(
            &[
                1.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0,
                0.0, 1.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 1.0,
            ],
            "Failed to create SWAP matrix in test_optimized_special_cases",
        );

        let decomp = OptimizedCartanDecomposer::new()
            .decompose(&swap)
            .expect("Failed to decompose SWAP in test_optimized_special_cases");

        assert_eq!(decomp.interaction.cnot_count(1e-10), 3);
    }

    #[test]
    fn test_cartan_identity() {
        let identity_complex = Array2::<f64>::eye(4).mapv(|x| Complex::new(x, 0.0));

        let decomp = CartanDecomposer::new()
            .decompose(&identity_complex)
            .expect("Failed to decompose identity in test_cartan_identity");

        assert!(decomp.interaction.is_identity(1e-10));
        assert_eq!(decomp.interaction.cnot_count(1e-10), 0);
    }
}