quantrs2_core/
cartan.rs

1//! Cartan (KAK) decomposition for two-qubit unitaries
2//!
3//! This module implements the Cartan decomposition, which decomposes any
4//! two-qubit unitary into a canonical form with at most 3 CNOT gates.
5//! The decomposition has the form:
6//!
7//! U = (A₁ ⊗ B₁) · exp(i(aXX + bYY + cZZ)) · (A₂ ⊗ B₂)
8//!
9//! where A₁, B₁, A₂, B₂ are single-qubit unitaries and a, b, c are real.
10
11use crate::{
12    error::{QuantRS2Error, QuantRS2Result},
13    gate::{multi::*, single::*, GateOp},
14    matrix_ops::{DenseMatrix, QuantumMatrix},
15    qubit::QubitId,
16    synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
17};
18use rustc_hash::FxHashMap;
19use scirs2_core::ndarray::{s, Array1, Array2};
20use scirs2_core::Complex;
21use std::f64::consts::PI;
22
23/// Result of Cartan decomposition for a two-qubit unitary
24#[derive(Debug, Clone)]
25pub struct CartanDecomposition {
26    /// Left single-qubit gates (A₁, B₁)
27    pub left_gates: (SingleQubitDecomposition, SingleQubitDecomposition),
28    /// Right single-qubit gates (A₂, B₂)
29    pub right_gates: (SingleQubitDecomposition, SingleQubitDecomposition),
30    /// Interaction coefficients (a, b, c) for exp(i(aXX + bYY + cZZ))
31    pub interaction: CartanCoefficients,
32    /// Global phase
33    pub global_phase: f64,
34}
35
/// Cartan interaction coefficients `(a, b, c)` of the canonical gate
/// `exp(i(a·XX + b·YY + c·ZZ))`.
#[derive(Debug, Clone, Copy)]
pub struct CartanCoefficients {
    /// Coefficient for XX interaction
    pub xx: f64,
    /// Coefficient for YY interaction
    pub yy: f64,
    /// Coefficient for ZZ interaction
    pub zz: f64,
}

impl CartanCoefficients {
    /// Create new coefficients
    pub const fn new(xx: f64, yy: f64, zz: f64) -> Self {
        Self { xx, yy, zz }
    }

    /// Check whether all three stored coefficients are within `tolerance` of zero.
    ///
    /// This is a literal test on the raw values. Unlike [`Self::cnot_count`], it
    /// does not treat coefficients that differ by multiples of π/2 as equivalent.
    pub fn is_identity(&self, tolerance: f64) -> bool {
        self.xx.abs() < tolerance && self.yy.abs() < tolerance && self.zz.abs() < tolerance
    }

    /// Minimum number of CNOTs needed to implement the canonical gate when
    /// single-qubit gates are free.
    ///
    /// Shifting a coefficient by π/2 only multiplies the gate by a local Pauli
    /// product (`exp(iπ/2·XX) = i·XX`). Sign flips and permutations of the
    /// coefficients do not change the CNOT cost. Each coefficient is therefore
    /// folded into `[0, π/4]` and the triple sorted so that `a ≥ b ≥ c`. Then:
    ///
    /// - `0` if `a ≈ 0` (locally equivalent to the identity),
    /// - `1` if `(a, b, c) ≈ (π/4, 0, 0)` (the CNOT/CZ class),
    /// - `2` if `c ≈ 0`,
    /// - `3` otherwise, including when any coefficient is not finite.
    pub fn cnot_count(&self, tolerance: f64) -> usize {
        if !(self.xx.is_finite() && self.yy.is_finite() && self.zz.is_finite()) {
            // Garbage input: report the general (worst) case rather than guessing.
            return 3;
        }

        let [a, b, c] = self.folded_magnitudes();
        if a < tolerance {
            0
        } else if b < tolerance && c < tolerance && (a - PI / 4.0).abs() < tolerance {
            1
        } else if c < tolerance {
            2
        } else {
            3
        }
    }

    /// Fold each coefficient into `[0, π/4]` (period π/2, sign ignored) and
    /// return the three values sorted in descending order.
    fn folded_magnitudes(&self) -> [f64; 3] {
        let period = PI / 2.0;
        let fold = |x: f64| {
            // `rem_euclid` may round up to exactly `period`; `min` maps that to 0.
            let r = x.rem_euclid(period);
            r.min(period - r)
        };
        let mut folded = [fold(self.xx), fold(self.yy), fold(self.zz)];
        folded.sort_unstable_by(|p, q| q.total_cmp(p));
        folded
    }

    /// Reorder the coefficients so that `|xx| ≥ |yy| ≥ |zz|`, keeping their signs.
    ///
    /// Ties keep their original relative order, because the sort is stable. A
    /// NaN coefficient is ordered first, as the largest magnitude, instead of
    /// causing a panic.
    pub fn canonicalize(&mut self) {
        let mut vals = [self.xx, self.yy, self.zz];
        vals.sort_by(|p, q| q.abs().total_cmp(&p.abs()));

        let [first, second, third] = vals;
        self.xx = first;
        self.yy = second;
        self.zz = third;
    }
}
101
102/// Cartan decomposer for two-qubit gates
103pub struct CartanDecomposer {
104    /// Tolerance for numerical comparisons
105    tolerance: f64,
106    /// Cache for common gates
107    #[allow(dead_code)]
108    cache: FxHashMap<u64, CartanDecomposition>,
109}
110
111impl CartanDecomposer {
112    /// Create a new Cartan decomposer
113    pub fn new() -> Self {
114        Self {
115            tolerance: 1e-10,
116            cache: FxHashMap::default(),
117        }
118    }
119
120    /// Create with custom tolerance
121    pub fn with_tolerance(tolerance: f64) -> Self {
122        Self {
123            tolerance,
124            cache: FxHashMap::default(),
125        }
126    }
127
128    /// Decompose a two-qubit unitary using Cartan decomposition
129    pub fn decompose(
130        &mut self,
131        unitary: &Array2<Complex<f64>>,
132    ) -> QuantRS2Result<CartanDecomposition> {
133        // Validate input
134        if unitary.shape() != [4, 4] {
135            return Err(QuantRS2Error::InvalidInput(
136                "Cartan decomposition requires 4x4 unitary".to_string(),
137            ));
138        }
139
140        // Check unitarity
141        let mat = DenseMatrix::new(unitary.clone())?;
142        if !mat.is_unitary(self.tolerance)? {
143            return Err(QuantRS2Error::InvalidInput(
144                "Matrix is not unitary".to_string(),
145            ));
146        }
147
148        // Transform to magic basis
149        let magic_basis = Self::get_magic_basis();
150        let u_magic = Self::to_magic_basis(unitary, &magic_basis);
151
152        // Compute M = U_magic^T · U_magic
153        let u_magic_t = u_magic.t().to_owned();
154        let m = u_magic_t.dot(&u_magic);
155
156        // Diagonalize M to find the canonical form
157        let (d, p) = Self::diagonalize_symmetric(&m)?;
158
159        // Extract interaction coefficients from eigenvalues
160        let coeffs = Self::extract_coefficients(&d);
161
162        // Compute single-qubit gates
163        let (left_gates, right_gates) = self.compute_local_gates(unitary, &u_magic, &p, &coeffs)?;
164
165        // Compute global phase
166        let global_phase = Self::compute_global_phase(unitary, &left_gates, &right_gates, &coeffs)?;
167
168        Ok(CartanDecomposition {
169            left_gates,
170            right_gates,
171            interaction: coeffs,
172            global_phase,
173        })
174    }
175
176    /// Get the magic basis transformation matrix
177    fn get_magic_basis() -> Array2<Complex<f64>> {
178        let sqrt2 = 2.0_f64.sqrt();
179        Array2::from_shape_vec(
180            (4, 4),
181            vec![
182                Complex::new(1.0, 0.0),
183                Complex::new(0.0, 0.0),
184                Complex::new(0.0, 0.0),
185                Complex::new(1.0, 0.0),
186                Complex::new(0.0, 0.0),
187                Complex::new(1.0, 0.0),
188                Complex::new(1.0, 0.0),
189                Complex::new(0.0, 0.0),
190                Complex::new(0.0, 0.0),
191                Complex::new(1.0, 0.0),
192                Complex::new(-1.0, 0.0),
193                Complex::new(0.0, 0.0),
194                Complex::new(1.0, 0.0),
195                Complex::new(0.0, 0.0),
196                Complex::new(0.0, 0.0),
197                Complex::new(-1.0, 0.0),
198            ],
199        )
200        .expect("Failed to create magic basis matrix in CartanDecomposer::get_magic_basis")
201            / Complex::new(sqrt2, 0.0)
202    }
203
204    /// Transform matrix to magic basis
205    fn to_magic_basis(
206        u: &Array2<Complex<f64>>,
207        magic: &Array2<Complex<f64>>,
208    ) -> Array2<Complex<f64>> {
209        let magic_dag = magic.mapv(|z| z.conj()).t().to_owned();
210        magic_dag.dot(u).dot(magic)
211    }
212
213    /// Diagonalize a symmetric complex matrix
214    fn diagonalize_symmetric(
215        m: &Array2<Complex<f64>>,
216    ) -> QuantRS2Result<(Array1<f64>, Array2<Complex<f64>>)> {
217        // For a unitary in magic basis, M = U^T U has special structure
218        // Its eigenvalues come in pairs (λ, λ*) and determine the interaction
219
220        // Simplified diagonalization for 4x4 case
221        // In practice, would use proper eigendecomposition
222
223        // For now, extract diagonal approximation
224        let mut eigenvalues = Array1::zeros(4);
225        let eigenvectors = Array2::eye(4);
226
227        // This is a placeholder - real implementation needs eigendecomposition
228        for i in 0..4 {
229            eigenvalues[i] = m[[i, i]].norm();
230        }
231
232        Ok((eigenvalues, eigenvectors))
233    }
234
235    /// Extract Cartan coefficients from eigenvalues
236    fn extract_coefficients(eigenvalues: &Array1<f64>) -> CartanCoefficients {
237        // The eigenvalues of M determine the interaction coefficients
238        // For U = exp(i(aXX + bYY + cZZ)), the eigenvalues are:
239        // exp(2i(a+b+c)), exp(2i(a-b-c)), exp(2i(-a+b-c)), exp(2i(-a-b+c))
240
241        // Extract phases from eigenvalues
242        let phases: Vec<f64> = eigenvalues
243            .iter()
244            .map(|&lambda| lambda.ln() / 2.0)
245            .collect();
246
247        // Solve for a, b, c
248        // This is simplified - proper implementation uses the correct formula
249        let a = (phases[0] - phases[3]) / 4.0;
250        let b = (phases[0] - phases[2]) / 4.0;
251        let c = (phases[0] - phases[1]) / 4.0;
252
253        let mut coeffs = CartanCoefficients::new(a, b, c);
254        coeffs.canonicalize();
255        coeffs
256    }
257
258    /// Compute single-qubit gates from decomposition
259    fn compute_local_gates(
260        &self,
261        u: &Array2<Complex<f64>>,
262        _u_magic: &Array2<Complex<f64>>,
263        _p: &Array2<Complex<f64>>,
264        coeffs: &CartanCoefficients,
265    ) -> QuantRS2Result<(
266        (SingleQubitDecomposition, SingleQubitDecomposition),
267        (SingleQubitDecomposition, SingleQubitDecomposition),
268    )> {
269        // Build the canonical gate
270        let _canonical = Self::build_canonical_gate(coeffs);
271
272        // The local gates satisfy:
273        // U = (A₁ ⊗ B₁) · canonical · (A₂ ⊗ B₂)
274
275        // Extract 2x2 blocks to find single-qubit gates
276        // This is simplified - proper implementation uses the full KAK theorem
277
278        let a1 = u.slice(s![..2, ..2]).to_owned();
279        let b1 = u.slice(s![2..4, 2..4]).to_owned();
280
281        let left_a = decompose_single_qubit_zyz(&a1.view())?;
282        let left_b = decompose_single_qubit_zyz(&b1.view())?;
283
284        // For right gates, we'd compute from the decomposition
285        // For now, use identity
286        let ident = Array2::eye(2);
287        let right_a = decompose_single_qubit_zyz(&ident.view())?;
288        let right_b = decompose_single_qubit_zyz(&ident.view())?;
289
290        Ok(((left_a, left_b), (right_a, right_b)))
291    }
292
293    /// Build the canonical gate from coefficients
294    fn build_canonical_gate(coeffs: &CartanCoefficients) -> Array2<Complex<f64>> {
295        // exp(i(aXX + bYY + cZZ))
296        let a = coeffs.xx;
297        let b = coeffs.yy;
298        let c = coeffs.zz;
299
300        // Direct computation of matrix exponential for this special form
301        let cos_a = a.cos();
302        let sin_a = a.sin();
303        let cos_b = b.cos();
304        let sin_b = b.sin();
305        let cos_c = c.cos();
306        let sin_c = c.sin();
307
308        // Build the 4x4 matrix
309        let mut result = Array2::zeros((4, 4));
310
311        // This is the explicit form of exp(i(aXX + bYY + cZZ))
312        result[[0, 0]] = Complex::new(cos_a * cos_b * cos_c, sin_c);
313        result[[0, 3]] = Complex::new(0.0, sin_a * cos_b * cos_c);
314        result[[1, 1]] = Complex::new(cos_a * cos_c, -sin_a * sin_b * sin_c);
315        result[[1, 2]] = Complex::new(0.0, cos_a.mul_add(sin_c, sin_a * sin_b * cos_c));
316        result[[2, 1]] = Complex::new(0.0, cos_a.mul_add(sin_c, -(sin_a * sin_b * cos_c)));
317        result[[2, 2]] = Complex::new(cos_a * cos_c, sin_a * sin_b * sin_c);
318        result[[3, 0]] = Complex::new(0.0, sin_a * cos_b * cos_c);
319        result[[3, 3]] = Complex::new(cos_a * cos_b * cos_c, -sin_c);
320
321        result
322    }
323
324    /// Compute global phase
325    const fn compute_global_phase(
326        _u: &Array2<Complex<f64>>,
327        _left: &(SingleQubitDecomposition, SingleQubitDecomposition),
328        _right: &(SingleQubitDecomposition, SingleQubitDecomposition),
329        _coeffs: &CartanCoefficients,
330    ) -> QuantRS2Result<f64> {
331        // Global phase is the phase difference between U and the reconstructed gate
332        // For now, return 0
333        Ok(0.0)
334    }
335
336    /// Convert Cartan decomposition to gate sequence
337    pub fn to_gates(
338        &self,
339        decomp: &CartanDecomposition,
340        qubit_ids: &[QubitId],
341    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
342        if qubit_ids.len() != 2 {
343            return Err(QuantRS2Error::InvalidInput(
344                "Cartan decomposition requires exactly 2 qubits".to_string(),
345            ));
346        }
347
348        let q0 = qubit_ids[0];
349        let q1 = qubit_ids[1];
350        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
351
352        // Left single-qubit gates
353        gates.extend(self.single_qubit_to_gates(&decomp.left_gates.0, q0));
354        gates.extend(self.single_qubit_to_gates(&decomp.left_gates.1, q1));
355
356        // Canonical two-qubit gate
357        gates.extend(self.canonical_to_gates(&decomp.interaction, q0, q1)?);
358
359        // Right single-qubit gates
360        gates.extend(self.single_qubit_to_gates(&decomp.right_gates.0, q0));
361        gates.extend(self.single_qubit_to_gates(&decomp.right_gates.1, q1));
362
363        Ok(gates)
364    }
365
366    /// Convert single-qubit decomposition to gates
367    fn single_qubit_to_gates(
368        &self,
369        decomp: &SingleQubitDecomposition,
370        qubit: QubitId,
371    ) -> Vec<Box<dyn GateOp>> {
372        let mut gates = Vec::new();
373
374        if decomp.theta1.abs() > self.tolerance {
375            gates.push(Box::new(RotationZ {
376                target: qubit,
377                theta: decomp.theta1,
378            }) as Box<dyn GateOp>);
379        }
380
381        if decomp.phi.abs() > self.tolerance {
382            gates.push(Box::new(RotationY {
383                target: qubit,
384                theta: decomp.phi,
385            }) as Box<dyn GateOp>);
386        }
387
388        if decomp.theta2.abs() > self.tolerance {
389            gates.push(Box::new(RotationZ {
390                target: qubit,
391                theta: decomp.theta2,
392            }) as Box<dyn GateOp>);
393        }
394
395        gates
396    }
397
398    /// Convert canonical coefficients to gate sequence
399    fn canonical_to_gates(
400        &self,
401        coeffs: &CartanCoefficients,
402        q0: QubitId,
403        q1: QubitId,
404    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
405        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
406        let cnots = coeffs.cnot_count(self.tolerance);
407
408        match cnots {
409            0 => {
410                // Identity - no gates needed
411            }
412            1 => {
413                // Special case: can be done with 1 CNOT
414                gates.push(Box::new(CNOT {
415                    control: q0,
416                    target: q1,
417                }));
418            }
419            2 => {
420                // Can be done with 2 CNOTs
421                // Add rotations
422                if coeffs.xx.abs() > self.tolerance {
423                    gates.push(Box::new(RotationX {
424                        target: q0,
425                        theta: coeffs.xx * 2.0,
426                    }));
427                }
428
429                gates.push(Box::new(CNOT {
430                    control: q0,
431                    target: q1,
432                }));
433
434                if coeffs.zz.abs() > self.tolerance {
435                    gates.push(Box::new(RotationZ {
436                        target: q1,
437                        theta: coeffs.zz * 2.0,
438                    }));
439                }
440
441                gates.push(Box::new(CNOT {
442                    control: q0,
443                    target: q1,
444                }));
445            }
446            3 => {
447                // General case: 3 CNOTs with intermediate rotations
448                gates.push(Box::new(CNOT {
449                    control: q0,
450                    target: q1,
451                }));
452
453                gates.push(Box::new(RotationZ {
454                    target: q0,
455                    theta: coeffs.xx * 2.0,
456                }));
457                gates.push(Box::new(RotationZ {
458                    target: q1,
459                    theta: coeffs.yy * 2.0,
460                }));
461
462                gates.push(Box::new(CNOT {
463                    control: q1,
464                    target: q0,
465                }));
466
467                gates.push(Box::new(RotationZ {
468                    target: q0,
469                    theta: coeffs.zz * 2.0,
470                }));
471
472                gates.push(Box::new(CNOT {
473                    control: q0,
474                    target: q1,
475                }));
476            }
477            _ => unreachable!("CNOT count should be 0-3"),
478        }
479
480        Ok(gates)
481    }
482}
483
484/// Optimized Cartan decomposer with special case handling
485pub struct OptimizedCartanDecomposer {
486    pub base: CartanDecomposer,
487    /// Enable special case optimizations
488    optimize_special_cases: bool,
489    /// Enable phase optimization
490    optimize_phase: bool,
491}
492
493impl OptimizedCartanDecomposer {
494    /// Create new optimized decomposer
495    pub fn new() -> Self {
496        Self {
497            base: CartanDecomposer::new(),
498            optimize_special_cases: true,
499            optimize_phase: true,
500        }
501    }
502
503    /// Decompose with optimizations
504    pub fn decompose(
505        &mut self,
506        unitary: &Array2<Complex<f64>>,
507    ) -> QuantRS2Result<CartanDecomposition> {
508        // Check for special cases first
509        if self.optimize_special_cases {
510            if let Some(special) = self.check_special_cases(unitary)? {
511                return Ok(special);
512            }
513        }
514
515        // Use base decomposition
516        let mut decomp = self.base.decompose(unitary)?;
517
518        // Optimize phase if enabled
519        if self.optimize_phase {
520            self.optimize_global_phase(&mut decomp);
521        }
522
523        Ok(decomp)
524    }
525
526    /// Check for special gate cases
527    fn check_special_cases(
528        &self,
529        unitary: &Array2<Complex<f64>>,
530    ) -> QuantRS2Result<Option<CartanDecomposition>> {
531        // Check for CNOT
532        if self.is_cnot(unitary) {
533            return Ok(Some(Self::cnot_decomposition()));
534        }
535
536        // Check for controlled-Z
537        if self.is_cz(unitary) {
538            return Ok(Some(Self::cz_decomposition()));
539        }
540
541        // Check for SWAP
542        if self.is_swap(unitary) {
543            return Ok(Some(Self::swap_decomposition()));
544        }
545
546        Ok(None)
547    }
548
549    /// Check if matrix is CNOT
550    fn is_cnot(&self, u: &Array2<Complex<f64>>) -> bool {
551        let cnot = Array2::from_shape_vec(
552            (4, 4),
553            vec![
554                Complex::new(1.0, 0.0),
555                Complex::new(0.0, 0.0),
556                Complex::new(0.0, 0.0),
557                Complex::new(0.0, 0.0),
558                Complex::new(0.0, 0.0),
559                Complex::new(1.0, 0.0),
560                Complex::new(0.0, 0.0),
561                Complex::new(0.0, 0.0),
562                Complex::new(0.0, 0.0),
563                Complex::new(0.0, 0.0),
564                Complex::new(0.0, 0.0),
565                Complex::new(1.0, 0.0),
566                Complex::new(0.0, 0.0),
567                Complex::new(0.0, 0.0),
568                Complex::new(1.0, 0.0),
569                Complex::new(0.0, 0.0),
570            ],
571        )
572        .expect("Failed to create CNOT matrix in OptimizedCartanDecomposer::is_cnot");
573
574        self.matrices_equal(u, &cnot)
575    }
576
577    /// Check if matrix is CZ
578    fn is_cz(&self, u: &Array2<Complex<f64>>) -> bool {
579        let cz = Array2::from_shape_vec(
580            (4, 4),
581            vec![
582                Complex::new(1.0, 0.0),
583                Complex::new(0.0, 0.0),
584                Complex::new(0.0, 0.0),
585                Complex::new(0.0, 0.0),
586                Complex::new(0.0, 0.0),
587                Complex::new(1.0, 0.0),
588                Complex::new(0.0, 0.0),
589                Complex::new(0.0, 0.0),
590                Complex::new(0.0, 0.0),
591                Complex::new(0.0, 0.0),
592                Complex::new(1.0, 0.0),
593                Complex::new(0.0, 0.0),
594                Complex::new(0.0, 0.0),
595                Complex::new(0.0, 0.0),
596                Complex::new(0.0, 0.0),
597                Complex::new(-1.0, 0.0),
598            ],
599        )
600        .expect("Failed to create CZ matrix in OptimizedCartanDecomposer::is_cz");
601
602        self.matrices_equal(u, &cz)
603    }
604
605    /// Check if matrix is SWAP
606    fn is_swap(&self, u: &Array2<Complex<f64>>) -> bool {
607        let swap = Array2::from_shape_vec(
608            (4, 4),
609            vec![
610                Complex::new(1.0, 0.0),
611                Complex::new(0.0, 0.0),
612                Complex::new(0.0, 0.0),
613                Complex::new(0.0, 0.0),
614                Complex::new(0.0, 0.0),
615                Complex::new(0.0, 0.0),
616                Complex::new(1.0, 0.0),
617                Complex::new(0.0, 0.0),
618                Complex::new(0.0, 0.0),
619                Complex::new(1.0, 0.0),
620                Complex::new(0.0, 0.0),
621                Complex::new(0.0, 0.0),
622                Complex::new(0.0, 0.0),
623                Complex::new(0.0, 0.0),
624                Complex::new(0.0, 0.0),
625                Complex::new(1.0, 0.0),
626            ],
627        )
628        .expect("Failed to create SWAP matrix in OptimizedCartanDecomposer::is_swap");
629
630        self.matrices_equal(u, &swap)
631    }
632
633    /// Check matrix equality up to global phase
634    fn matrices_equal(&self, a: &Array2<Complex<f64>>, b: &Array2<Complex<f64>>) -> bool {
635        // Find first non-zero element
636        let mut phase = Complex::new(1.0, 0.0);
637        for i in 0..4 {
638            for j in 0..4 {
639                if b[[i, j]].norm() > self.base.tolerance {
640                    phase = a[[i, j]] / b[[i, j]];
641                    break;
642                }
643            }
644        }
645
646        // Check all elements match up to phase
647        for i in 0..4 {
648            for j in 0..4 {
649                if (a[[i, j]] - phase * b[[i, j]]).norm() > self.base.tolerance {
650                    return false;
651                }
652            }
653        }
654
655        true
656    }
657
658    /// Decomposition for CNOT
659    fn cnot_decomposition() -> CartanDecomposition {
660        let ident = Array2::eye(2);
661        let ident_decomp = decompose_single_qubit_zyz(&ident.view()).expect(
662            "Failed to decompose identity in OptimizedCartanDecomposer::cnot_decomposition",
663        );
664
665        CartanDecomposition {
666            left_gates: (ident_decomp.clone(), ident_decomp.clone()),
667            right_gates: (ident_decomp.clone(), ident_decomp),
668            interaction: CartanCoefficients::new(PI / 4.0, PI / 4.0, 0.0),
669            global_phase: 0.0,
670        }
671    }
672
673    /// Decomposition for CZ
674    fn cz_decomposition() -> CartanDecomposition {
675        let ident = Array2::eye(2);
676        let ident_decomp = decompose_single_qubit_zyz(&ident.view())
677            .expect("Failed to decompose identity in OptimizedCartanDecomposer::cz_decomposition");
678
679        CartanDecomposition {
680            left_gates: (ident_decomp.clone(), ident_decomp.clone()),
681            right_gates: (ident_decomp.clone(), ident_decomp),
682            interaction: CartanCoefficients::new(0.0, 0.0, PI / 4.0),
683            global_phase: 0.0,
684        }
685    }
686
687    /// Decomposition for SWAP
688    fn swap_decomposition() -> CartanDecomposition {
689        let ident = Array2::eye(2);
690        let ident_decomp = decompose_single_qubit_zyz(&ident.view()).expect(
691            "Failed to decompose identity in OptimizedCartanDecomposer::swap_decomposition",
692        );
693
694        CartanDecomposition {
695            left_gates: (ident_decomp.clone(), ident_decomp.clone()),
696            right_gates: (ident_decomp.clone(), ident_decomp),
697            interaction: CartanCoefficients::new(PI / 4.0, PI / 4.0, PI / 4.0),
698            global_phase: 0.0,
699        }
700    }
701
702    /// Optimize global phase
703    fn optimize_global_phase(&self, decomp: &mut CartanDecomposition) {
704        // Absorb global phase into one of the single-qubit gates
705        if decomp.global_phase.abs() > self.base.tolerance {
706            decomp.left_gates.0.global_phase += decomp.global_phase;
707            decomp.global_phase = 0.0;
708        }
709    }
710}
711
712/// Utility function for quick Cartan decomposition
713pub fn cartan_decompose(unitary: &Array2<Complex<f64>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
714    let mut decomposer = CartanDecomposer::new();
715    let decomp = decomposer.decompose(unitary)?;
716    let qubit_ids = vec![QubitId(0), QubitId(1)];
717    decomposer.to_gates(&decomp, &qubit_ids)
718}
719
720impl Default for OptimizedCartanDecomposer {
721    fn default() -> Self {
722        Self::new()
723    }
724}
725
726impl Default for CartanDecomposer {
727    fn default() -> Self {
728        Self::new()
729    }
730}
731
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex;

    /// 4×4 permutation matrix with a 1 at `(row, columns[row])` for every row.
    fn permutation_unitary(columns: [usize; 4]) -> Array2<Complex<f64>> {
        Array2::from_shape_fn((4, 4), |(row, col)| {
            if columns[row] == col {
                Complex::new(1.0, 0.0)
            } else {
                Complex::new(0.0, 0.0)
            }
        })
    }

    #[test]
    fn test_cartan_coefficients() {
        let tol = 1e-10;

        let generic = CartanCoefficients::new(0.1, 0.2, 0.3);
        assert!(!generic.is_identity(tol));
        assert_eq!(generic.cnot_count(tol), 3);

        let zero = CartanCoefficients::new(0.0, 0.0, 0.0);
        assert!(zero.is_identity(tol));
        assert_eq!(zero.cnot_count(tol), 0);
    }

    #[test]
    fn test_cartan_cnot() {
        let cnot = permutation_unitary([0, 1, 3, 2]);

        let decomp = CartanDecomposer::new()
            .decompose(&cnot)
            .expect("Failed to decompose CNOT in test_cartan_cnot");

        // CNOT should have specific interaction coefficients
        assert!(decomp.interaction.cnot_count(1e-10) <= 1);
    }

    #[test]
    fn test_optimized_special_cases() {
        let swap = permutation_unitary([0, 2, 1, 3]);

        let decomp = OptimizedCartanDecomposer::new()
            .decompose(&swap)
            .expect("Failed to decompose SWAP in test_optimized_special_cases");

        // SWAP requires exactly 3 CNOTs
        assert_eq!(decomp.interaction.cnot_count(1e-10), 3);
    }

    #[test]
    fn test_cartan_identity() {
        let identity = Array2::<f64>::eye(4).mapv(|x| Complex::new(x, 0.0));

        let decomp = CartanDecomposer::new()
            .decompose(&identity)
            .expect("Failed to decompose identity in test_cartan_identity");

        // Identity should have zero interaction
        assert!(decomp.interaction.is_identity(1e-10));
        assert_eq!(decomp.interaction.cnot_count(1e-10), 0);
    }
}