Skip to main content

quantrs2_core/
cartan.rs

1//! Cartan (KAK) decomposition for two-qubit unitaries
2//!
3//! This module implements the Cartan decomposition, which decomposes any
4//! two-qubit unitary into a canonical form with at most 3 CNOT gates.
5//! The decomposition has the form:
6//!
7//! U = (A₁ ⊗ B₁) · exp(i(aXX + bYY + cZZ)) · (A₂ ⊗ B₂)
8//!
9//! where A₁, B₁, A₂, B₂ are single-qubit unitaries and a, b, c are real.
10
11use crate::{
12    error::{QuantRS2Error, QuantRS2Result},
13    gate::{multi::*, single::*, GateOp},
14    matrix_ops::{DenseMatrix, QuantumMatrix},
15    qubit::QubitId,
16    synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
17};
18use rustc_hash::FxHashMap;
19use scirs2_core::ndarray::{s, Array1, Array2};
20use scirs2_core::Complex;
21use std::f64::consts::PI;
22
23/// Result of Cartan decomposition for a two-qubit unitary
24#[derive(Debug, Clone)]
25pub struct CartanDecomposition {
26    /// Left single-qubit gates (A₁, B₁)
27    pub left_gates: (SingleQubitDecomposition, SingleQubitDecomposition),
28    /// Right single-qubit gates (A₂, B₂)
29    pub right_gates: (SingleQubitDecomposition, SingleQubitDecomposition),
30    /// Interaction coefficients (a, b, c) for exp(i(aXX + bYY + cZZ))
31    pub interaction: CartanCoefficients,
32    /// Global phase
33    pub global_phase: f64,
34}
35
/// Cartan interaction coefficients `(a, b, c)` of the non-local factor
/// `exp(i(a·XX + b·YY + c·ZZ))`.
#[derive(Debug, Clone, Copy)]
pub struct CartanCoefficients {
    /// Coefficient for XX interaction
    pub xx: f64,
    /// Coefficient for YY interaction
    pub yy: f64,
    /// Coefficient for ZZ interaction
    pub zz: f64,
}

impl CartanCoefficients {
    /// Create new coefficients
    pub const fn new(xx: f64, yy: f64, zz: f64) -> Self {
        Self { xx, yy, zz }
    }

    /// Check if this is equivalent to identity (all coefficients strictly within
    /// `tolerance` of zero).
    pub fn is_identity(&self, tolerance: f64) -> bool {
        self.xx.abs() < tolerance && self.yy.abs() < tolerance && self.zz.abs() < tolerance
    }

    /// Number of CNOTs needed to realise the interaction term.
    ///
    /// - `0`: all coefficients are ~0 (the gate is local);
    /// - `1`: two coefficients are ~0 and the remaining one has magnitude ~π/4
    ///   (the CNOT / CZ equivalence class);
    /// - `2`: at least one coefficient is ~0;
    /// - `3`: otherwise.
    ///
    /// Coefficients are compared exactly as stored; no reduction modulo π/2 is
    /// performed, so callers should pass reduced/canonical values.
    pub fn cnot_count(&self, tolerance: f64) -> usize {
        let coeffs = [self.xx, self.yy, self.zz];
        let near_zero = |x: f64| x.abs() < tolerance;
        let zero_count = coeffs.iter().filter(|&&x| near_zero(x)).count();

        if zero_count == 3 {
            0
        } else if zero_count == 2
            && coeffs
                .iter()
                .any(|&x| !near_zero(x) && (x.abs() - PI / 4.0).abs() < tolerance)
        {
            // exp(±iπ/4·PP) for a single Pauli pair is locally equivalent to CNOT.
            1
        } else if zero_count >= 1 {
            2
        } else {
            3
        }
    }

    /// Permute the coefficients so that `|xx| >= |yy| >= |zz|`.
    ///
    /// Signs are kept; ties keep their original relative order (stable sort).
    /// A NaN coefficient does not panic: it is treated as the largest magnitude.
    pub fn canonicalize(&mut self) {
        let mut vals = [self.xx, self.yy, self.zz];
        // `total_cmp` is a total order, unlike `partial_cmp`, so NaN cannot panic.
        vals.sort_by(|a, b| b.abs().total_cmp(&a.abs()));

        self.xx = vals[0];
        self.yy = vals[1];
        self.zz = vals[2];
    }
}
101
102/// Cartan decomposer for two-qubit gates
103pub struct CartanDecomposer {
104    /// Tolerance for numerical comparisons
105    tolerance: f64,
106    /// Cache for common gates
107    #[allow(dead_code)]
108    cache: FxHashMap<u64, CartanDecomposition>,
109}
110
111impl CartanDecomposer {
112    /// Create a new Cartan decomposer
113    pub fn new() -> Self {
114        Self {
115            tolerance: 1e-10,
116            cache: FxHashMap::default(),
117        }
118    }
119
120    /// Create with custom tolerance
121    pub fn with_tolerance(tolerance: f64) -> Self {
122        Self {
123            tolerance,
124            cache: FxHashMap::default(),
125        }
126    }
127
128    /// Decompose a two-qubit unitary using Cartan decomposition
129    pub fn decompose(
130        &mut self,
131        unitary: &Array2<Complex<f64>>,
132    ) -> QuantRS2Result<CartanDecomposition> {
133        // Validate input
134        if unitary.shape() != [4, 4] {
135            return Err(QuantRS2Error::InvalidInput(
136                "Cartan decomposition requires 4x4 unitary".to_string(),
137            ));
138        }
139
140        // Check unitarity
141        let mat = DenseMatrix::new(unitary.clone())?;
142        if !mat.is_unitary(self.tolerance)? {
143            return Err(QuantRS2Error::InvalidInput(
144                "Matrix is not unitary".to_string(),
145            ));
146        }
147
148        // Transform to magic basis
149        let magic_basis = Self::get_magic_basis();
150        let u_magic = Self::to_magic_basis(unitary, &magic_basis);
151
152        // Compute M = U_magic^T · U_magic
153        let u_magic_t = u_magic.t().to_owned();
154        let m = u_magic_t.dot(&u_magic);
155
156        // Diagonalize M to find the canonical form
157        let (d, p) = Self::diagonalize_symmetric(&m)?;
158
159        // Extract interaction coefficients from eigenvalues
160        let coeffs = Self::extract_coefficients(&d);
161
162        // Compute single-qubit gates
163        let (left_gates, right_gates) = self.compute_local_gates(unitary, &u_magic, &p, &coeffs)?;
164
165        // Compute global phase
166        let global_phase = Self::compute_global_phase(unitary, &left_gates, &right_gates, &coeffs)?;
167
168        Ok(CartanDecomposition {
169            left_gates,
170            right_gates,
171            interaction: coeffs,
172            global_phase,
173        })
174    }
175
176    /// Get the magic basis transformation matrix
177    fn get_magic_basis() -> Array2<Complex<f64>> {
178        let sqrt2 = 2.0_f64.sqrt();
179        Array2::from_shape_vec(
180            (4, 4),
181            vec![
182                Complex::new(1.0, 0.0),
183                Complex::new(0.0, 0.0),
184                Complex::new(0.0, 0.0),
185                Complex::new(1.0, 0.0),
186                Complex::new(0.0, 0.0),
187                Complex::new(1.0, 0.0),
188                Complex::new(1.0, 0.0),
189                Complex::new(0.0, 0.0),
190                Complex::new(0.0, 0.0),
191                Complex::new(1.0, 0.0),
192                Complex::new(-1.0, 0.0),
193                Complex::new(0.0, 0.0),
194                Complex::new(1.0, 0.0),
195                Complex::new(0.0, 0.0),
196                Complex::new(0.0, 0.0),
197                Complex::new(-1.0, 0.0),
198            ],
199        )
200        .expect("Failed to create magic basis matrix in CartanDecomposer::get_magic_basis")
201            / Complex::new(sqrt2, 0.0)
202    }
203
204    /// Transform matrix to magic basis
205    fn to_magic_basis(
206        u: &Array2<Complex<f64>>,
207        magic: &Array2<Complex<f64>>,
208    ) -> Array2<Complex<f64>> {
209        let magic_dag = magic.mapv(|z| z.conj()).t().to_owned();
210        magic_dag.dot(u).dot(magic)
211    }
212
213    /// Diagonalize a symmetric complex matrix via QR iteration
214    ///
215    /// For the Cartan decomposition, M = U^T U is complex symmetric.
216    /// Its eigenvalues are complex numbers on the unit circle: exp(2i*phi_k).
217    /// Returns (eigenvalues as Complex, approximate eigenvectors).
218    fn diagonalize_symmetric(
219        m: &Array2<Complex<f64>>,
220    ) -> QuantRS2Result<(Array1<Complex<f64>>, Array2<Complex<f64>>)> {
221        let n = m.nrows();
222        // QR iteration with Francis shifts to find all eigenvalues of the complex matrix M.
223        // We work with a Hessenberg reduction first, then apply shifted QR steps.
224        let mut h = m.to_owned();
225        let mut q = Array2::<Complex<f64>>::eye(n);
226
227        // Reduce to upper Hessenberg form using Householder reflections
228        for k in 0..n.saturating_sub(2) {
229            // Build Householder vector from column k, rows k+1..n
230            let col: Vec<Complex<f64>> = (k + 1..n).map(|i| h[[i, k]]).collect();
231            let sigma_sq: f64 = col.iter().map(|z| z.norm_sqr()).sum();
232            let sigma = sigma_sq.sqrt();
233            if sigma < 1e-14 {
234                continue;
235            }
236            // Choose Householder sign to maximise numerical stability
237            let phase = if col[0].norm() > 1e-14 {
238                col[0] / col[0].norm()
239            } else {
240                Complex::new(1.0, 0.0)
241            };
242            let mut v = col.clone();
243            v[0] = v[0] + phase * sigma;
244            let v_norm_sq: f64 = v.iter().map(|z| z.norm_sqr()).sum();
245            if v_norm_sq < 1e-28 {
246                continue;
247            }
248            let m_len = v.len(); // = n - (k+1)
249
250            // Apply H from the left: h[k+1.., ..] -= (2/v^†v) v (v^† h[k+1.., ..])
251            for j in 0..n {
252                let dot: Complex<f64> = (0..m_len).map(|i| v[i].conj() * h[[k + 1 + i, j]]).sum();
253                let scale = dot * Complex::new(2.0 / v_norm_sq, 0.0);
254                for i in 0..m_len {
255                    h[[k + 1 + i, j]] = h[[k + 1 + i, j]] - v[i] * scale;
256                }
257            }
258            // Apply H from the right: h[.., k+1..] -= (2/v^†v) (h[.., k+1..] v) v^†
259            for i in 0..n {
260                let dot: Complex<f64> = (0..m_len).map(|j| h[[i, k + 1 + j]] * v[j]).sum();
261                let scale = dot * Complex::new(2.0 / v_norm_sq, 0.0);
262                for j in 0..m_len {
263                    h[[i, k + 1 + j]] = h[[i, k + 1 + j]] - scale * v[j].conj();
264                }
265            }
266            // Accumulate Q
267            for i in 0..n {
268                let dot: Complex<f64> = (0..m_len).map(|j| q[[i, k + 1 + j]] * v[j]).sum();
269                let scale = dot * Complex::new(2.0 / v_norm_sq, 0.0);
270                for j in 0..m_len {
271                    q[[i, k + 1 + j]] = q[[i, k + 1 + j]] - scale * v[j].conj();
272                }
273            }
274        }
275
276        // Francis double-shift QR iteration on the Hessenberg matrix
277        let max_iter = 300 * n;
278        let mut active = n;
279        for _iter in 0..max_iter {
280            if active <= 1 {
281                break;
282            }
283            // Deflate converged eigenvalues at the bottom
284            while active > 1 {
285                let off = h[[active - 1, active - 2]].norm();
286                let d1 = h[[active - 1, active - 1]].norm();
287                let d0 = h[[active - 2, active - 2]].norm();
288                if off < 1e-12 * (d1 + d0) {
289                    active -= 1;
290                } else {
291                    break;
292                }
293            }
294            if active <= 1 {
295                break;
296            }
297
298            // Wilkinson (single complex) shift: eigenvalue of bottom 2x2 closest to h[a-1,a-1]
299            let a = active;
300            let s = h[[a - 1, a - 1]];
301
302            // Single-shift QR step: compute Givens rotations to push shift through
303            // Apply shift: h' = h - s*I, QR decompose, then h'' = RQ + s*I
304            for k in 0..a - 1 {
305                // Compute Givens rotation to zero h[k+1, k]
306                let x = h[[k, k]] - s;
307                let y = h[[k + 1, k]];
308                let r = (x.norm_sqr() + y.norm_sqr()).sqrt();
309                if r < 1e-14 {
310                    continue;
311                }
312                let c_val = x / r;
313                let s_val = -y / r;
314
315                // Apply Givens rotation from left: rows k and k+1
316                for j in 0..n {
317                    let tmp0 = c_val * h[[k, j]] - s_val.conj() * h[[k + 1, j]];
318                    let tmp1 = s_val * h[[k, j]] + c_val.conj() * h[[k + 1, j]];
319                    h[[k, j]] = tmp0;
320                    h[[k + 1, j]] = tmp1;
321                }
322                // Apply Givens rotation from right: cols k and k+1
323                for i in 0..n {
324                    let tmp0 = c_val.conj() * h[[i, k]] - s_val.conj() * h[[i, k + 1]];
325                    let tmp1 = s_val * h[[i, k]] + c_val * h[[i, k + 1]];
326                    h[[i, k]] = tmp0;
327                    h[[i, k + 1]] = tmp1;
328                }
329                // Accumulate in Q
330                for i in 0..n {
331                    let tmp0 = c_val.conj() * q[[i, k]] - s_val.conj() * q[[i, k + 1]];
332                    let tmp1 = s_val * q[[i, k]] + c_val * q[[i, k + 1]];
333                    q[[i, k]] = tmp0;
334                    q[[i, k + 1]] = tmp1;
335                }
336            }
337        }
338
339        // Extract eigenvalues from the diagonal of h
340        let mut eigenvalues = Array1::zeros(n);
341        for i in 0..n {
342            eigenvalues[i] = h[[i, i]];
343        }
344
345        Ok((eigenvalues, q))
346    }
347
348    /// Extract Cartan coefficients from complex eigenvalues of M = U^T U
349    ///
350    /// The eigenvalues are exp(2i·phi_k). For U = exp(i(aXX + bYY + cZZ)),
351    /// the phases come in pairs: {+(a+b+c), +(a-b-c), +(-a+b-c), +(-a-b+c)}.
352    /// Sorting and averaging the phases gives:
353    ///   a = (phi_0 + phi_1 - phi_2 - phi_3) / 4   (after appropriate ordering)
354    /// More robustly: solve the 4×4 linear system.
355    fn extract_coefficients(eigenvalues: &Array1<Complex<f64>>) -> CartanCoefficients {
356        // Extract phases phi_k from eigenvalues exp(2i*phi_k)
357        // The arg() gives 2*phi, so phi = arg/2
358        let mut phases: Vec<f64> = eigenvalues.iter().map(|z| z.arg() / 2.0).collect();
359        // Sort phases for stable extraction
360        phases.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
361
362        // For the four Cartan phases {a+b+c, a-b-c, -a+b-c, -a-b+c}:
363        // Sum = 0, so use differences.
364        // Ordered phases p0 <= p1 <= p2 <= p3 with p0+p3 ≈ 0, p1+p2 ≈ 0
365        // a = (p3 - p2 + p1 - p0) / 4 ... but sign ordering depends on values.
366        // Use the symmetric formula: after sorting ascending,
367        //   a+b+c corresponds to the largest magnitude phase
368        //   We identify:
369        //     c = (p3 - p0) / 4   (half the spread of extreme phases)
370        //     b = (p2 - p1) / 4   (half the spread of middle phases)
371        //     a ≈ (p3 + p2 - p1 - p0) / 4
372        let p0 = phases.first().copied().unwrap_or(0.0);
373        let p1 = phases.get(1).copied().unwrap_or(0.0);
374        let p2 = phases.get(2).copied().unwrap_or(0.0);
375        let p3 = phases.get(3).copied().unwrap_or(0.0);
376
377        // Solve: a+b+c=p3, a-b-c=p0, -a+b-c=p1, -a-b+c=p2  (one consistent assignment)
378        // Adding all: 0 = p0+p1+p2+p3 (true up to 2π ambiguity)
379        // From (p3+p0)/2 = a and (p3-p0)/2 = b+c
380        // From (p2+p1)/2 = -a and (p2-p1)/2 = c-b
381        let a = (p3 + p0) / 2.0;
382        let b_plus_c = (p3 - p0) / 2.0;
383        let c_minus_b = (p2 - p1) / 2.0;
384        let b = (b_plus_c - c_minus_b) / 2.0;
385        let c = (b_plus_c + c_minus_b) / 2.0;
386
387        let mut coeffs = CartanCoefficients::new(a, b, c);
388        coeffs.canonicalize();
389        coeffs
390    }
391
392    /// Compute single-qubit gates from decomposition
393    fn compute_local_gates(
394        &self,
395        u: &Array2<Complex<f64>>,
396        _u_magic: &Array2<Complex<f64>>,
397        _p: &Array2<Complex<f64>>,
398        coeffs: &CartanCoefficients,
399    ) -> QuantRS2Result<(
400        (SingleQubitDecomposition, SingleQubitDecomposition),
401        (SingleQubitDecomposition, SingleQubitDecomposition),
402    )> {
403        // Build the canonical gate
404        let _canonical = Self::build_canonical_gate(coeffs);
405
406        // The local gates satisfy:
407        // U = (A₁ ⊗ B₁) · canonical · (A₂ ⊗ B₂)
408
409        // Extract 2x2 blocks to find single-qubit gates
410        // This is simplified - proper implementation uses the full KAK theorem
411
412        let a1 = u.slice(s![..2, ..2]).to_owned();
413        let b1 = u.slice(s![2..4, 2..4]).to_owned();
414
415        let left_a = decompose_single_qubit_zyz(&a1.view())?;
416        let left_b = decompose_single_qubit_zyz(&b1.view())?;
417
418        // For right gates, we'd compute from the decomposition
419        // For now, use identity
420        let ident = Array2::eye(2);
421        let right_a = decompose_single_qubit_zyz(&ident.view())?;
422        let right_b = decompose_single_qubit_zyz(&ident.view())?;
423
424        Ok(((left_a, left_b), (right_a, right_b)))
425    }
426
427    /// Build the canonical gate from coefficients
428    fn build_canonical_gate(coeffs: &CartanCoefficients) -> Array2<Complex<f64>> {
429        // exp(i(aXX + bYY + cZZ))
430        let a = coeffs.xx;
431        let b = coeffs.yy;
432        let c = coeffs.zz;
433
434        // Direct computation of matrix exponential for this special form
435        let cos_a = a.cos();
436        let sin_a = a.sin();
437        let cos_b = b.cos();
438        let sin_b = b.sin();
439        let cos_c = c.cos();
440        let sin_c = c.sin();
441
442        // Build the 4x4 matrix
443        let mut result = Array2::zeros((4, 4));
444
445        // This is the explicit form of exp(i(aXX + bYY + cZZ))
446        result[[0, 0]] = Complex::new(cos_a * cos_b * cos_c, sin_c);
447        result[[0, 3]] = Complex::new(0.0, sin_a * cos_b * cos_c);
448        result[[1, 1]] = Complex::new(cos_a * cos_c, -sin_a * sin_b * sin_c);
449        result[[1, 2]] = Complex::new(0.0, cos_a.mul_add(sin_c, sin_a * sin_b * cos_c));
450        result[[2, 1]] = Complex::new(0.0, cos_a.mul_add(sin_c, -(sin_a * sin_b * cos_c)));
451        result[[2, 2]] = Complex::new(cos_a * cos_c, sin_a * sin_b * sin_c);
452        result[[3, 0]] = Complex::new(0.0, sin_a * cos_b * cos_c);
453        result[[3, 3]] = Complex::new(cos_a * cos_b * cos_c, -sin_c);
454
455        result
456    }
457
458    /// Compute global phase
459    const fn compute_global_phase(
460        _u: &Array2<Complex<f64>>,
461        _left: &(SingleQubitDecomposition, SingleQubitDecomposition),
462        _right: &(SingleQubitDecomposition, SingleQubitDecomposition),
463        _coeffs: &CartanCoefficients,
464    ) -> QuantRS2Result<f64> {
465        // Global phase is the phase difference between U and the reconstructed gate
466        // For now, return 0
467        Ok(0.0)
468    }
469
470    /// Convert Cartan decomposition to gate sequence
471    pub fn to_gates(
472        &self,
473        decomp: &CartanDecomposition,
474        qubit_ids: &[QubitId],
475    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
476        if qubit_ids.len() != 2 {
477            return Err(QuantRS2Error::InvalidInput(
478                "Cartan decomposition requires exactly 2 qubits".to_string(),
479            ));
480        }
481
482        let q0 = qubit_ids[0];
483        let q1 = qubit_ids[1];
484        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
485
486        // Left single-qubit gates
487        gates.extend(self.single_qubit_to_gates(&decomp.left_gates.0, q0));
488        gates.extend(self.single_qubit_to_gates(&decomp.left_gates.1, q1));
489
490        // Canonical two-qubit gate
491        gates.extend(self.canonical_to_gates(&decomp.interaction, q0, q1)?);
492
493        // Right single-qubit gates
494        gates.extend(self.single_qubit_to_gates(&decomp.right_gates.0, q0));
495        gates.extend(self.single_qubit_to_gates(&decomp.right_gates.1, q1));
496
497        Ok(gates)
498    }
499
500    /// Convert single-qubit decomposition to gates
501    fn single_qubit_to_gates(
502        &self,
503        decomp: &SingleQubitDecomposition,
504        qubit: QubitId,
505    ) -> Vec<Box<dyn GateOp>> {
506        let mut gates = Vec::new();
507
508        if decomp.theta1.abs() > self.tolerance {
509            gates.push(Box::new(RotationZ {
510                target: qubit,
511                theta: decomp.theta1,
512            }) as Box<dyn GateOp>);
513        }
514
515        if decomp.phi.abs() > self.tolerance {
516            gates.push(Box::new(RotationY {
517                target: qubit,
518                theta: decomp.phi,
519            }) as Box<dyn GateOp>);
520        }
521
522        if decomp.theta2.abs() > self.tolerance {
523            gates.push(Box::new(RotationZ {
524                target: qubit,
525                theta: decomp.theta2,
526            }) as Box<dyn GateOp>);
527        }
528
529        gates
530    }
531
532    /// Convert canonical coefficients to gate sequence
533    fn canonical_to_gates(
534        &self,
535        coeffs: &CartanCoefficients,
536        q0: QubitId,
537        q1: QubitId,
538    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
539        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
540        let cnots = coeffs.cnot_count(self.tolerance);
541
542        match cnots {
543            0 => {
544                // Identity - no gates needed
545            }
546            1 => {
547                // Special case: can be done with 1 CNOT
548                gates.push(Box::new(CNOT {
549                    control: q0,
550                    target: q1,
551                }));
552            }
553            2 => {
554                // Can be done with 2 CNOTs
555                // Add rotations
556                if coeffs.xx.abs() > self.tolerance {
557                    gates.push(Box::new(RotationX {
558                        target: q0,
559                        theta: coeffs.xx * 2.0,
560                    }));
561                }
562
563                gates.push(Box::new(CNOT {
564                    control: q0,
565                    target: q1,
566                }));
567
568                if coeffs.zz.abs() > self.tolerance {
569                    gates.push(Box::new(RotationZ {
570                        target: q1,
571                        theta: coeffs.zz * 2.0,
572                    }));
573                }
574
575                gates.push(Box::new(CNOT {
576                    control: q0,
577                    target: q1,
578                }));
579            }
580            3 => {
581                // General case: 3 CNOTs with intermediate rotations
582                gates.push(Box::new(CNOT {
583                    control: q0,
584                    target: q1,
585                }));
586
587                gates.push(Box::new(RotationZ {
588                    target: q0,
589                    theta: coeffs.xx * 2.0,
590                }));
591                gates.push(Box::new(RotationZ {
592                    target: q1,
593                    theta: coeffs.yy * 2.0,
594                }));
595
596                gates.push(Box::new(CNOT {
597                    control: q1,
598                    target: q0,
599                }));
600
601                gates.push(Box::new(RotationZ {
602                    target: q0,
603                    theta: coeffs.zz * 2.0,
604                }));
605
606                gates.push(Box::new(CNOT {
607                    control: q0,
608                    target: q1,
609                }));
610            }
611            _ => unreachable!("CNOT count should be 0-3"),
612        }
613
614        Ok(gates)
615    }
616}
617
618/// Optimized Cartan decomposer with special case handling
619pub struct OptimizedCartanDecomposer {
620    pub base: CartanDecomposer,
621    /// Enable special case optimizations
622    optimize_special_cases: bool,
623    /// Enable phase optimization
624    optimize_phase: bool,
625}
626
627impl OptimizedCartanDecomposer {
628    /// Create new optimized decomposer
629    pub fn new() -> Self {
630        Self {
631            base: CartanDecomposer::new(),
632            optimize_special_cases: true,
633            optimize_phase: true,
634        }
635    }
636
637    /// Decompose with optimizations
638    pub fn decompose(
639        &mut self,
640        unitary: &Array2<Complex<f64>>,
641    ) -> QuantRS2Result<CartanDecomposition> {
642        // Check for special cases first
643        if self.optimize_special_cases {
644            if let Some(special) = self.check_special_cases(unitary)? {
645                return Ok(special);
646            }
647        }
648
649        // Use base decomposition
650        let mut decomp = self.base.decompose(unitary)?;
651
652        // Optimize phase if enabled
653        if self.optimize_phase {
654            self.optimize_global_phase(&mut decomp);
655        }
656
657        Ok(decomp)
658    }
659
660    /// Check for special gate cases
661    fn check_special_cases(
662        &self,
663        unitary: &Array2<Complex<f64>>,
664    ) -> QuantRS2Result<Option<CartanDecomposition>> {
665        // Check for CNOT
666        if self.is_cnot(unitary) {
667            return Ok(Some(Self::cnot_decomposition()));
668        }
669
670        // Check for controlled-Z
671        if self.is_cz(unitary) {
672            return Ok(Some(Self::cz_decomposition()));
673        }
674
675        // Check for SWAP
676        if self.is_swap(unitary) {
677            return Ok(Some(Self::swap_decomposition()));
678        }
679
680        Ok(None)
681    }
682
683    /// Check if matrix is CNOT
684    fn is_cnot(&self, u: &Array2<Complex<f64>>) -> bool {
685        let cnot = Array2::from_shape_vec(
686            (4, 4),
687            vec![
688                Complex::new(1.0, 0.0),
689                Complex::new(0.0, 0.0),
690                Complex::new(0.0, 0.0),
691                Complex::new(0.0, 0.0),
692                Complex::new(0.0, 0.0),
693                Complex::new(1.0, 0.0),
694                Complex::new(0.0, 0.0),
695                Complex::new(0.0, 0.0),
696                Complex::new(0.0, 0.0),
697                Complex::new(0.0, 0.0),
698                Complex::new(0.0, 0.0),
699                Complex::new(1.0, 0.0),
700                Complex::new(0.0, 0.0),
701                Complex::new(0.0, 0.0),
702                Complex::new(1.0, 0.0),
703                Complex::new(0.0, 0.0),
704            ],
705        )
706        .expect("Failed to create CNOT matrix in OptimizedCartanDecomposer::is_cnot");
707
708        self.matrices_equal(u, &cnot)
709    }
710
711    /// Check if matrix is CZ
712    fn is_cz(&self, u: &Array2<Complex<f64>>) -> bool {
713        let cz = Array2::from_shape_vec(
714            (4, 4),
715            vec![
716                Complex::new(1.0, 0.0),
717                Complex::new(0.0, 0.0),
718                Complex::new(0.0, 0.0),
719                Complex::new(0.0, 0.0),
720                Complex::new(0.0, 0.0),
721                Complex::new(1.0, 0.0),
722                Complex::new(0.0, 0.0),
723                Complex::new(0.0, 0.0),
724                Complex::new(0.0, 0.0),
725                Complex::new(0.0, 0.0),
726                Complex::new(1.0, 0.0),
727                Complex::new(0.0, 0.0),
728                Complex::new(0.0, 0.0),
729                Complex::new(0.0, 0.0),
730                Complex::new(0.0, 0.0),
731                Complex::new(-1.0, 0.0),
732            ],
733        )
734        .expect("Failed to create CZ matrix in OptimizedCartanDecomposer::is_cz");
735
736        self.matrices_equal(u, &cz)
737    }
738
739    /// Check if matrix is SWAP
740    fn is_swap(&self, u: &Array2<Complex<f64>>) -> bool {
741        let swap = Array2::from_shape_vec(
742            (4, 4),
743            vec![
744                Complex::new(1.0, 0.0),
745                Complex::new(0.0, 0.0),
746                Complex::new(0.0, 0.0),
747                Complex::new(0.0, 0.0),
748                Complex::new(0.0, 0.0),
749                Complex::new(0.0, 0.0),
750                Complex::new(1.0, 0.0),
751                Complex::new(0.0, 0.0),
752                Complex::new(0.0, 0.0),
753                Complex::new(1.0, 0.0),
754                Complex::new(0.0, 0.0),
755                Complex::new(0.0, 0.0),
756                Complex::new(0.0, 0.0),
757                Complex::new(0.0, 0.0),
758                Complex::new(0.0, 0.0),
759                Complex::new(1.0, 0.0),
760            ],
761        )
762        .expect("Failed to create SWAP matrix in OptimizedCartanDecomposer::is_swap");
763
764        self.matrices_equal(u, &swap)
765    }
766
767    /// Check matrix equality up to global phase
768    fn matrices_equal(&self, a: &Array2<Complex<f64>>, b: &Array2<Complex<f64>>) -> bool {
769        // Find first non-zero element
770        let mut phase = Complex::new(1.0, 0.0);
771        for i in 0..4 {
772            for j in 0..4 {
773                if b[[i, j]].norm() > self.base.tolerance {
774                    phase = a[[i, j]] / b[[i, j]];
775                    break;
776                }
777            }
778        }
779
780        // Check all elements match up to phase
781        for i in 0..4 {
782            for j in 0..4 {
783                if (a[[i, j]] - phase * b[[i, j]]).norm() > self.base.tolerance {
784                    return false;
785                }
786            }
787        }
788
789        true
790    }
791
792    /// Decomposition for CNOT
793    fn cnot_decomposition() -> CartanDecomposition {
794        let ident = Array2::eye(2);
795        let ident_decomp = decompose_single_qubit_zyz(&ident.view()).expect(
796            "Failed to decompose identity in OptimizedCartanDecomposer::cnot_decomposition",
797        );
798
799        CartanDecomposition {
800            left_gates: (ident_decomp.clone(), ident_decomp.clone()),
801            right_gates: (ident_decomp.clone(), ident_decomp),
802            interaction: CartanCoefficients::new(PI / 4.0, PI / 4.0, 0.0),
803            global_phase: 0.0,
804        }
805    }
806
807    /// Decomposition for CZ
808    fn cz_decomposition() -> CartanDecomposition {
809        let ident = Array2::eye(2);
810        let ident_decomp = decompose_single_qubit_zyz(&ident.view())
811            .expect("Failed to decompose identity in OptimizedCartanDecomposer::cz_decomposition");
812
813        CartanDecomposition {
814            left_gates: (ident_decomp.clone(), ident_decomp.clone()),
815            right_gates: (ident_decomp.clone(), ident_decomp),
816            interaction: CartanCoefficients::new(0.0, 0.0, PI / 4.0),
817            global_phase: 0.0,
818        }
819    }
820
821    /// Decomposition for SWAP
822    fn swap_decomposition() -> CartanDecomposition {
823        let ident = Array2::eye(2);
824        let ident_decomp = decompose_single_qubit_zyz(&ident.view()).expect(
825            "Failed to decompose identity in OptimizedCartanDecomposer::swap_decomposition",
826        );
827
828        CartanDecomposition {
829            left_gates: (ident_decomp.clone(), ident_decomp.clone()),
830            right_gates: (ident_decomp.clone(), ident_decomp),
831            interaction: CartanCoefficients::new(PI / 4.0, PI / 4.0, PI / 4.0),
832            global_phase: 0.0,
833        }
834    }
835
836    /// Optimize global phase
837    fn optimize_global_phase(&self, decomp: &mut CartanDecomposition) {
838        // Absorb global phase into one of the single-qubit gates
839        if decomp.global_phase.abs() > self.base.tolerance {
840            decomp.left_gates.0.global_phase += decomp.global_phase;
841            decomp.global_phase = 0.0;
842        }
843    }
844}
845
846/// Utility function for quick Cartan decomposition
847pub fn cartan_decompose(unitary: &Array2<Complex<f64>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
848    let mut decomposer = CartanDecomposer::new();
849    let decomp = decomposer.decompose(unitary)?;
850    let qubit_ids = vec![QubitId(0), QubitId(1)];
851    decomposer.to_gates(&decomp, &qubit_ids)
852}
853
854impl Default for OptimizedCartanDecomposer {
855    fn default() -> Self {
856        Self::new()
857    }
858}
859
860impl Default for CartanDecomposer {
861    fn default() -> Self {
862        Self::new()
863    }
864}
865
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex;

    /// Row-major entries of CNOT: swaps basis states |10⟩ and |11⟩.
    #[rustfmt::skip]
    const CNOT_ENTRIES: [f64; 16] = [
        1.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 1.0,
        0.0, 0.0, 1.0, 0.0,
    ];

    /// Row-major entries of SWAP: swaps basis states |01⟩ and |10⟩.
    #[rustfmt::skip]
    const SWAP_ENTRIES: [f64; 16] = [
        1.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 1.0,
    ];

    /// Build a 4×4 complex matrix from 16 real, row-major entries.
    ///
    /// `context` is the panic message used if the shape check fails.
    fn real_4x4(entries: [f64; 16], context: &str) -> Array2<Complex<f64>> {
        let data: Vec<Complex<f64>> = entries.iter().map(|&re| Complex::new(re, 0.0)).collect();
        Array2::from_shape_vec((4, 4), data).expect(context)
    }

    #[test]
    fn test_cartan_coefficients() {
        let tol = 1e-10;

        let generic = CartanCoefficients::new(0.1, 0.2, 0.3);
        assert!(!generic.is_identity(tol));
        assert_eq!(generic.cnot_count(tol), 3);

        let trivial = CartanCoefficients::new(0.0, 0.0, 0.0);
        assert!(trivial.is_identity(tol));
        assert_eq!(trivial.cnot_count(tol), 0);
    }

    #[test]
    fn test_cartan_cnot() {
        let mut decomposer = CartanDecomposer::new();
        let cnot = real_4x4(CNOT_ENTRIES, "Failed to create CNOT matrix in test_cartan_cnot");

        let decomp = decomposer
            .decompose(&cnot)
            .expect("Failed to decompose CNOT in test_cartan_cnot");

        // CNOT's interaction content should never require more than one CNOT.
        assert!(decomp.interaction.cnot_count(1e-10) <= 1);
    }

    #[test]
    fn test_optimized_special_cases() {
        let mut opt_decomposer = OptimizedCartanDecomposer::new();
        let swap = real_4x4(
            SWAP_ENTRIES,
            "Failed to create SWAP matrix in test_optimized_special_cases",
        );

        let decomp = opt_decomposer
            .decompose(&swap)
            .expect("Failed to decompose SWAP in test_optimized_special_cases");

        // SWAP is the maximally entangling point and needs the full three CNOTs.
        assert_eq!(decomp.interaction.cnot_count(1e-10), 3);
    }

    #[test]
    fn test_cartan_identity() {
        let mut decomposer = CartanDecomposer::new();
        let identity_complex = Array2::<f64>::eye(4).mapv(|x| Complex::new(x, 0.0));

        let decomp = decomposer
            .decompose(&identity_complex)
            .expect("Failed to decompose identity in test_cartan_identity");

        // The identity carries no two-qubit interaction at all.
        assert!(decomp.interaction.is_identity(1e-10));
        assert_eq!(decomp.interaction.cnot_count(1e-10), 0);
    }
}