quantrs2_core/
kak_multiqubit.rs

1//! KAK decomposition for multi-qubit unitaries
2//!
3//! This module extends the Cartan (KAK) decomposition to handle arbitrary
4//! n-qubit unitaries through recursive application and generalized
5//! decomposition techniques.
6
7use crate::{
8    cartan::{CartanDecomposer, CartanDecomposition},
9    error::{QuantRS2Error, QuantRS2Result},
10    gate::{multi::*, single::*, GateOp},
11    matrix_ops::{DenseMatrix, QuantumMatrix},
12    qubit::QubitId,
13    shannon::ShannonDecomposer,
14    synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use rustc_hash::FxHashMap;
17use scirs2_core::ndarray::{s, Array2};
18use scirs2_core::Complex;
19
20/// Result of multi-qubit KAK decomposition
21#[derive(Debug, Clone)]
22pub struct MultiQubitKAK {
23    /// The decomposed gate sequence
24    pub gates: Vec<Box<dyn GateOp>>,
25    /// Decomposition tree structure
26    pub tree: DecompositionTree,
27    /// Total CNOT count
28    pub cnot_count: usize,
29    /// Total single-qubit gate count
30    pub single_qubit_count: usize,
31    /// Circuit depth
32    pub depth: usize,
33}
34
35/// Tree structure representing the hierarchical decomposition
36#[derive(Debug, Clone)]
37pub enum DecompositionTree {
38    /// Leaf node - single or two-qubit gate
39    Leaf {
40        qubits: Vec<QubitId>,
41        gate_type: LeafType,
42    },
43    /// Internal node - recursive decomposition
44    Node {
45        qubits: Vec<QubitId>,
46        method: DecompositionMethod,
47        children: Vec<Self>,
48    },
49}
50
51/// Type of leaf decomposition
52#[derive(Debug, Clone)]
53pub enum LeafType {
54    SingleQubit(SingleQubitDecomposition),
55    TwoQubit(CartanDecomposition),
56}
57
/// Strategy applied at an internal node of a [`DecompositionTree`].
#[derive(Debug, Clone)]
pub enum DecompositionMethod {
    /// Cosine-sine decomposition; `pivot` is the number of leading qubits that
    /// form the left tensor factor.
    CSD { pivot: usize },
    /// Quantum Shannon decomposition; `partition` is the split point recorded
    /// for this node.
    Shannon { partition: usize },
    /// Block-diagonal split; `block_size` is the number of qubits spanned by
    /// each diagonal block.
    BlockDiagonal { block_size: usize },
    /// Direct Cartan (KAK) decomposition of a two-qubit block.
    Cartan,
}
70
71/// Multi-qubit KAK decomposer
72pub struct MultiQubitKAKDecomposer {
73    /// Tolerance for numerical comparisons
74    tolerance: f64,
75    /// Maximum recursion depth
76    max_depth: usize,
77    /// Cache for decompositions
78    #[allow(dead_code)]
79    cache: FxHashMap<u64, MultiQubitKAK>,
80    /// Use optimized methods
81    use_optimization: bool,
82    /// Cartan decomposer for two-qubit blocks
83    cartan: CartanDecomposer,
84}
85
86impl MultiQubitKAKDecomposer {
87    /// Create a new multi-qubit KAK decomposer
88    pub fn new() -> Self {
89        Self {
90            tolerance: 1e-10,
91            max_depth: 20,
92            cache: FxHashMap::default(),
93            use_optimization: true,
94            cartan: CartanDecomposer::new(),
95        }
96    }
97
98    /// Create with custom tolerance
99    pub fn with_tolerance(tolerance: f64) -> Self {
100        Self {
101            tolerance,
102            max_depth: 20,
103            cache: FxHashMap::default(),
104            use_optimization: true,
105            cartan: CartanDecomposer::with_tolerance(tolerance),
106        }
107    }
108
109    /// Decompose an n-qubit unitary
110    pub fn decompose(
111        &mut self,
112        unitary: &Array2<Complex<f64>>,
113        qubit_ids: &[QubitId],
114    ) -> QuantRS2Result<MultiQubitKAK> {
115        let n = qubit_ids.len();
116        let size = 1 << n;
117
118        // Validate input
119        if unitary.shape() != [size, size] {
120            return Err(QuantRS2Error::InvalidInput(format!(
121                "Unitary size {} doesn't match {} qubits",
122                unitary.shape()[0],
123                n
124            )));
125        }
126
127        // Check unitarity
128        let mat = DenseMatrix::new(unitary.clone())?;
129        if !mat.is_unitary(self.tolerance)? {
130            return Err(QuantRS2Error::InvalidInput(
131                "Matrix is not unitary".to_string(),
132            ));
133        }
134
135        // Check cache
136        if let Some(cached) = self.check_cache(unitary) {
137            return Ok(cached.clone());
138        }
139
140        // Perform decomposition
141        let (tree, gates) = self.decompose_recursive(unitary, qubit_ids, 0)?;
142
143        // Count gates
144        let mut cnot_count = 0;
145        let mut single_qubit_count = 0;
146
147        for gate in &gates {
148            match gate.name() {
149                "CNOT" | "CZ" | "SWAP" => cnot_count += self.count_cnots(gate.name()),
150                _ => single_qubit_count += 1,
151            }
152        }
153
154        let result = MultiQubitKAK {
155            gates,
156            tree,
157            cnot_count,
158            single_qubit_count,
159            depth: 0, // TODO: Calculate actual depth
160        };
161
162        // Cache result
163        self.cache_result(unitary, &result);
164
165        Ok(result)
166    }
167
168    /// Recursive decomposition algorithm
169    fn decompose_recursive(
170        &mut self,
171        unitary: &Array2<Complex<f64>>,
172        qubit_ids: &[QubitId],
173        depth: usize,
174    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
175        if depth > self.max_depth {
176            return Err(QuantRS2Error::InvalidInput(
177                "Maximum recursion depth exceeded".to_string(),
178            ));
179        }
180
181        let n = qubit_ids.len();
182
183        // Base cases
184        match n {
185            0 => {
186                let tree = DecompositionTree::Leaf {
187                    qubits: vec![],
188                    gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
189                        global_phase: 0.0,
190                        theta1: 0.0,
191                        phi: 0.0,
192                        theta2: 0.0,
193                        basis: "ZYZ".to_string(),
194                    }),
195                };
196                Ok((tree, vec![]))
197            }
198            1 => {
199                let decomp = decompose_single_qubit_zyz(&unitary.view())?;
200                let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
201                let tree = DecompositionTree::Leaf {
202                    qubits: qubit_ids.to_vec(),
203                    gate_type: LeafType::SingleQubit(decomp),
204                };
205                Ok((tree, gates))
206            }
207            2 => {
208                let decomp = self.cartan.decompose(unitary)?;
209                let gates = self.cartan.to_gates(&decomp, qubit_ids)?;
210                let tree = DecompositionTree::Leaf {
211                    qubits: qubit_ids.to_vec(),
212                    gate_type: LeafType::TwoQubit(decomp),
213                };
214                Ok((tree, gates))
215            }
216            _ => {
217                // For n > 2, choose decomposition method
218                let method = self.choose_decomposition_method(unitary, n);
219
220                match method {
221                    DecompositionMethod::CSD { pivot } => {
222                        self.decompose_csd(unitary, qubit_ids, pivot, depth)
223                    }
224                    DecompositionMethod::Shannon { partition } => {
225                        self.decompose_shannon(unitary, qubit_ids, partition, depth)
226                    }
227                    DecompositionMethod::BlockDiagonal { block_size } => {
228                        self.decompose_block_diagonal(unitary, qubit_ids, block_size, depth)
229                    }
230                    DecompositionMethod::Cartan => unreachable!("Invalid method for n > 2"),
231                }
232            }
233        }
234    }
235
236    /// Choose optimal decomposition method based on matrix structure
237    fn choose_decomposition_method(
238        &self,
239        unitary: &Array2<Complex<f64>>,
240        n: usize,
241    ) -> DecompositionMethod {
242        if self.use_optimization {
243            // Analyze matrix structure to choose optimal method
244            if self.has_block_structure(unitary, n) {
245                DecompositionMethod::BlockDiagonal { block_size: n / 2 }
246            } else if n % 2 == 0 {
247                // Even number of qubits - use CSD at midpoint
248                DecompositionMethod::CSD { pivot: n / 2 }
249            } else {
250                // Odd number - use Shannon decomposition
251                DecompositionMethod::Shannon { partition: n / 2 }
252            }
253        } else {
254            // Default to CSD
255            DecompositionMethod::CSD { pivot: n / 2 }
256        }
257    }
258
259    /// Decompose using Cosine-Sine Decomposition
260    fn decompose_csd(
261        &mut self,
262        unitary: &Array2<Complex<f64>>,
263        qubit_ids: &[QubitId],
264        pivot: usize,
265        depth: usize,
266    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
267        let n = qubit_ids.len();
268        // let _size = 1 << n;
269        let pivot_size = 1 << pivot;
270
271        // Split unitary into blocks based on pivot
272        // U = [A B]
273        //     [C D]
274        let a = unitary.slice(s![..pivot_size, ..pivot_size]).to_owned();
275        let b = unitary.slice(s![..pivot_size, pivot_size..]).to_owned();
276        let c = unitary.slice(s![pivot_size.., ..pivot_size]).to_owned();
277        let d = unitary.slice(s![pivot_size.., pivot_size..]).to_owned();
278
279        // Apply CSD to find:
280        // U = (U1 ⊗ V1) · Σ · (U2 ⊗ V2)
281        // where Σ is diagonal in the CSD basis
282
283        // This is a simplified version - full CSD would use SVD
284        let (u1, v1, sigma, u2, v2) = self.compute_csd(&a, &b, &c, &d)?;
285
286        let mut gates = Vec::new();
287        let mut children = Vec::new();
288
289        // Decompose U2 and V2 (right multiplications)
290        let left_qubits = &qubit_ids[..pivot];
291        let right_qubits = &qubit_ids[pivot..];
292
293        let (u2_tree, u2_gates) = self.decompose_recursive(&u2, left_qubits, depth + 1)?;
294        let (v2_tree, v2_gates) = self.decompose_recursive(&v2, right_qubits, depth + 1)?;
295
296        gates.extend(u2_gates);
297        gates.extend(v2_gates);
298        children.push(u2_tree);
299        children.push(v2_tree);
300
301        // Apply diagonal gates (controlled rotations)
302        let diag_gates = self.diagonal_to_gates(&sigma, qubit_ids)?;
303        gates.extend(diag_gates);
304
305        // Decompose U1 and V1 (left multiplications)
306        let (u1_tree, u1_gates) = self.decompose_recursive(&u1, left_qubits, depth + 1)?;
307        let (v1_tree, v1_gates) = self.decompose_recursive(&v1, right_qubits, depth + 1)?;
308
309        gates.extend(u1_gates);
310        gates.extend(v1_gates);
311        children.push(u1_tree);
312        children.push(v1_tree);
313
314        let tree = DecompositionTree::Node {
315            qubits: qubit_ids.to_vec(),
316            method: DecompositionMethod::CSD { pivot },
317            children,
318        };
319
320        Ok((tree, gates))
321    }
322
323    /// Decompose using Shannon decomposition
324    fn decompose_shannon(
325        &self,
326        unitary: &Array2<Complex<f64>>,
327        qubit_ids: &[QubitId],
328        partition: usize,
329        _depth: usize,
330    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
331        // Use the Shannon decomposer for this
332        let mut shannon = ShannonDecomposer::new();
333        let decomp = shannon.decompose(unitary, qubit_ids)?;
334
335        // Build tree structure
336        let tree = DecompositionTree::Node {
337            qubits: qubit_ids.to_vec(),
338            method: DecompositionMethod::Shannon { partition },
339            children: vec![], // Shannon decomposer doesn't provide tree structure
340        };
341
342        Ok((tree, decomp.gates))
343    }
344
345    /// Decompose block diagonal matrix
346    fn decompose_block_diagonal(
347        &mut self,
348        unitary: &Array2<Complex<f64>>,
349        qubit_ids: &[QubitId],
350        block_size: usize,
351        depth: usize,
352    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
353        let n = qubit_ids.len();
354        let num_blocks = n / block_size;
355
356        let mut gates = Vec::new();
357        let mut children = Vec::new();
358
359        // Decompose each block independently
360        for i in 0..num_blocks {
361            let start = i * block_size;
362            let end = (i + 1) * block_size;
363            let block_qubits = &qubit_ids[start..end];
364
365            // Extract block from unitary
366            let block = self.extract_block(unitary, i, block_size)?;
367
368            let (block_tree, block_gates) =
369                self.decompose_recursive(&block, block_qubits, depth + 1)?;
370            gates.extend(block_gates);
371            children.push(block_tree);
372        }
373
374        let tree = DecompositionTree::Node {
375            qubits: qubit_ids.to_vec(),
376            method: DecompositionMethod::BlockDiagonal { block_size },
377            children,
378        };
379
380        Ok((tree, gates))
381    }
382
383    /// Compute Cosine-Sine Decomposition
384    fn compute_csd(
385        &self,
386        a: &Array2<Complex<f64>>,
387        b: &Array2<Complex<f64>>,
388        c: &Array2<Complex<f64>>,
389        d: &Array2<Complex<f64>>,
390    ) -> QuantRS2Result<(
391        Array2<Complex<f64>>, // U1
392        Array2<Complex<f64>>, // V1
393        Array2<Complex<f64>>, // Sigma
394        Array2<Complex<f64>>, // U2
395        Array2<Complex<f64>>, // V2
396    )> {
397        // This is a simplified placeholder
398        // Full CSD implementation would use specialized algorithms
399
400        let size = a.shape()[0];
401        let identity = Array2::eye(size);
402        let _zero: Array2<Complex<f64>> = Array2::zeros((size, size));
403
404        // For now, return identity transformations
405        let u1 = identity.clone();
406        let v1 = identity.clone();
407        let u2 = identity.clone();
408        let v2 = identity;
409
410        // Sigma would contain the CS angles
411        let mut sigma = Array2::zeros((size * 2, size * 2));
412        sigma.slice_mut(s![..size, ..size]).assign(a);
413        sigma.slice_mut(s![..size, size..]).assign(b);
414        sigma.slice_mut(s![size.., ..size]).assign(c);
415        sigma.slice_mut(s![size.., size..]).assign(d);
416
417        Ok((u1, v1, sigma, u2, v2))
418    }
419
420    /// Convert diagonal matrix to controlled rotation gates
421    fn diagonal_to_gates(
422        &self,
423        diagonal: &Array2<Complex<f64>>,
424        qubit_ids: &[QubitId],
425    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
426        let mut gates = Vec::new();
427
428        // Extract diagonal elements
429        let n = diagonal.shape()[0];
430        for i in 0..n {
431            let phase = diagonal[[i, i]].arg();
432            if phase.abs() > self.tolerance {
433                // Determine which qubits are in state |1⟩ for this diagonal element
434                let mut control_pattern = Vec::new();
435                let mut temp = i;
436                for j in 0..qubit_ids.len() {
437                    if temp & 1 == 1 {
438                        control_pattern.push(j);
439                    }
440                    temp >>= 1;
441                }
442
443                // Create multi-controlled phase gate
444                if control_pattern.is_empty() {
445                    // Global phase - can be ignored
446                } else if control_pattern.len() == 1 {
447                    // Single-qubit phase
448                    gates.push(Box::new(RotationZ {
449                        target: qubit_ids[control_pattern[0]],
450                        theta: phase,
451                    }) as Box<dyn GateOp>);
452                } else {
453                    // Multi-controlled phase - decompose further
454                    // For now, use simple decomposition
455                    // Note: control_pattern.len() >= 2 at this point, so pop is safe
456                    let target_idx = control_pattern.pop().unwrap_or(0);
457                    let target = qubit_ids[target_idx];
458                    for &control_idx in &control_pattern {
459                        gates.push(Box::new(CNOT {
460                            control: qubit_ids[control_idx],
461                            target,
462                        }));
463                    }
464
465                    gates.push(Box::new(RotationZ {
466                        target,
467                        theta: phase,
468                    }) as Box<dyn GateOp>);
469
470                    // Uncompute CNOTs
471                    for &control_idx in control_pattern.iter().rev() {
472                        gates.push(Box::new(CNOT {
473                            control: qubit_ids[control_idx],
474                            target,
475                        }));
476                    }
477                }
478            }
479        }
480
481        Ok(gates)
482    }
483
484    /// Check if matrix has block diagonal structure
485    fn has_block_structure(&self, unitary: &Array2<Complex<f64>>, _n: usize) -> bool {
486        // Simple check - look for zeros in off-diagonal blocks
487        let size = unitary.shape()[0];
488        let block_size = size / 2;
489
490        let mut off_diagonal_norm = 0.0;
491
492        // Check upper-right block
493        for i in 0..block_size {
494            for j in block_size..size {
495                off_diagonal_norm += unitary[[i, j]].norm_sqr();
496            }
497        }
498
499        // Check lower-left block
500        for i in block_size..size {
501            for j in 0..block_size {
502                off_diagonal_norm += unitary[[i, j]].norm_sqr();
503            }
504        }
505
506        off_diagonal_norm.sqrt() < self.tolerance
507    }
508
509    /// Extract a block from block-diagonal matrix
510    fn extract_block(
511        &self,
512        unitary: &Array2<Complex<f64>>,
513        block_idx: usize,
514        block_size: usize,
515    ) -> QuantRS2Result<Array2<Complex<f64>>> {
516        let size = 1 << block_size;
517        let start = block_idx * size;
518        let end = (block_idx + 1) * size;
519
520        Ok(unitary.slice(s![start..end, start..end]).to_owned())
521    }
522
523    /// Convert single-qubit decomposition to gates
524    fn single_qubit_to_gates(
525        &self,
526        decomp: &SingleQubitDecomposition,
527        qubit: QubitId,
528    ) -> Vec<Box<dyn GateOp>> {
529        let mut gates = Vec::new();
530
531        if decomp.theta1.abs() > self.tolerance {
532            gates.push(Box::new(RotationZ {
533                target: qubit,
534                theta: decomp.theta1,
535            }) as Box<dyn GateOp>);
536        }
537
538        if decomp.phi.abs() > self.tolerance {
539            gates.push(Box::new(RotationY {
540                target: qubit,
541                theta: decomp.phi,
542            }) as Box<dyn GateOp>);
543        }
544
545        if decomp.theta2.abs() > self.tolerance {
546            gates.push(Box::new(RotationZ {
547                target: qubit,
548                theta: decomp.theta2,
549            }) as Box<dyn GateOp>);
550        }
551
552        gates
553    }
554
555    /// Count CNOTs for different gate types
556    fn count_cnots(&self, gate_name: &str) -> usize {
557        match gate_name {
558            "CNOT" | "CZ" => 1, // CZ = H·CNOT·H
559            "SWAP" => 3,        // SWAP uses 3 CNOTs
560            _ => 0,
561        }
562    }
563
564    /// Check cache for existing decomposition
565    const fn check_cache(&self, _unitary: &Array2<Complex<f64>>) -> Option<&MultiQubitKAK> {
566        // Simple hash based on first few elements
567        // Real implementation would use better hashing
568        None
569    }
570
571    /// Cache decomposition result
572    const fn cache_result(&self, _unitary: &Array2<Complex<f64>>, _result: &MultiQubitKAK) {
573        // Cache implementation
574    }
575}
576
577impl Default for MultiQubitKAKDecomposer {
578    fn default() -> Self {
579        Self::new()
580    }
581}
582
583/// Analyze decomposition tree structure
584pub struct KAKTreeAnalyzer {
585    /// Track statistics
586    stats: DecompositionStats,
587}
588
589#[derive(Debug, Default, Clone)]
590pub struct DecompositionStats {
591    pub total_nodes: usize,
592    pub leaf_nodes: usize,
593    pub max_depth: usize,
594    pub method_counts: FxHashMap<String, usize>,
595    pub cnot_distribution: FxHashMap<usize, usize>,
596}
597
598impl KAKTreeAnalyzer {
599    /// Create new analyzer
600    pub fn new() -> Self {
601        Self {
602            stats: DecompositionStats::default(),
603        }
604    }
605
606    /// Analyze decomposition tree
607    pub fn analyze(&mut self, tree: &DecompositionTree) -> DecompositionStats {
608        self.stats = DecompositionStats::default();
609        self.analyze_recursive(tree, 0);
610        self.stats.clone()
611    }
612
613    fn analyze_recursive(&mut self, tree: &DecompositionTree, depth: usize) {
614        self.stats.total_nodes += 1;
615        self.stats.max_depth = self.stats.max_depth.max(depth);
616
617        match tree {
618            DecompositionTree::Leaf {
619                qubits: _qubits,
620                gate_type,
621            } => {
622                self.stats.leaf_nodes += 1;
623
624                match gate_type {
625                    LeafType::SingleQubit(_) => {
626                        *self
627                            .stats
628                            .method_counts
629                            .entry("single_qubit".to_string())
630                            .or_insert(0) += 1;
631                    }
632                    LeafType::TwoQubit(cartan) => {
633                        *self
634                            .stats
635                            .method_counts
636                            .entry("two_qubit".to_string())
637                            .or_insert(0) += 1;
638                        let cnots = cartan.interaction.cnot_count(1e-10);
639                        *self.stats.cnot_distribution.entry(cnots).or_insert(0) += 1;
640                    }
641                }
642            }
643            DecompositionTree::Node {
644                method, children, ..
645            } => {
646                let method_name = match method {
647                    DecompositionMethod::CSD { .. } => "csd",
648                    DecompositionMethod::Shannon { .. } => "shannon",
649                    DecompositionMethod::BlockDiagonal { .. } => "block_diagonal",
650                    DecompositionMethod::Cartan => "cartan",
651                };
652                *self
653                    .stats
654                    .method_counts
655                    .entry(method_name.to_string())
656                    .or_insert(0) += 1;
657
658                for child in children {
659                    self.analyze_recursive(child, depth + 1);
660                }
661            }
662        }
663    }
664}
665
666/// Utility function for quick multi-qubit KAK decomposition
667pub fn kak_decompose_multiqubit(
668    unitary: &Array2<Complex<f64>>,
669    qubit_ids: &[QubitId],
670) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
671    let mut decomposer = MultiQubitKAKDecomposer::new();
672    let decomp = decomposer.decompose(unitary, qubit_ids)?;
673    Ok(decomp.gates)
674}
675
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::Complex;

    /// Purely real complex amplitude.
    fn re(x: f64) -> Complex<f64> {
        Complex::new(x, 0.0)
    }

    /// ZYZ decomposition with every angle and the global phase set to zero.
    fn trivial_zyz() -> SingleQubitDecomposition {
        SingleQubitDecomposition {
            global_phase: 0.0,
            theta1: 0.0,
            phi: 0.0,
            theta2: 0.0,
            basis: "ZYZ".to_string(),
        }
    }

    #[test]
    fn test_multiqubit_kak_single() {
        let hadamard = Array2::from_shape_vec((2, 2), vec![re(1.0), re(1.0), re(1.0), re(-1.0)])
            .expect("Failed to create Hadamard matrix")
            / re(2.0_f64.sqrt());

        let decomp = MultiQubitKAKDecomposer::new()
            .decompose(&hadamard, &[QubitId(0)])
            .expect("Single-qubit KAK decomposition failed");

        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::SingleQubit(_),
                    ..
                }
            ),
            "Expected single-qubit leaf"
        );
    }

    #[test]
    fn test_multiqubit_kak_two() {
        let mut cnot = Array2::<Complex<f64>>::zeros((4, 4));
        for &(row, col) in &[(0, 0), (1, 1), (2, 3), (3, 2)] {
            cnot[[row, col]] = re(1.0);
        }

        let decomp = MultiQubitKAKDecomposer::new()
            .decompose(&cnot, &[QubitId(0), QubitId(1)])
            .expect("Two-qubit KAK decomposition failed");

        assert!(decomp.cnot_count <= 1);
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::TwoQubit(_),
                    ..
                }
            ),
            "Expected two-qubit leaf"
        );
    }

    #[test]
    fn test_multiqubit_kak_three() {
        let identity = Array2::<Complex<f64>>::eye(8);
        let qubits = [QubitId(0), QubitId(1), QubitId(2)];

        let decomp = MultiQubitKAKDecomposer::new()
            .decompose(&identity, &qubits)
            .expect("Three-qubit KAK decomposition failed");

        // The identity must not cost any gates.
        assert_eq!(decomp.gates.len(), 0);
        assert_eq!(decomp.cnot_count, 0);
        assert_eq!(decomp.single_qubit_count, 0);
    }

    #[test]
    fn test_tree_analyzer() {
        let two_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(0), QubitId(1)],
            gate_type: LeafType::TwoQubit(CartanDecomposition {
                left_gates: (trivial_zyz(), trivial_zyz()),
                right_gates: (trivial_zyz(), trivial_zyz()),
                interaction: crate::prelude::CartanCoefficients::new(0.0, 0.0, 0.0),
                global_phase: 0.0,
            }),
        };
        let single_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(2)],
            gate_type: LeafType::SingleQubit(trivial_zyz()),
        };
        let tree = DecompositionTree::Node {
            qubits: vec![QubitId(0), QubitId(1), QubitId(2)],
            method: DecompositionMethod::CSD { pivot: 2 },
            children: vec![two_qubit_leaf, single_qubit_leaf],
        };

        let stats = KAKTreeAnalyzer::new().analyze(&tree);

        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.leaf_nodes, 2);
        assert_eq!(stats.max_depth, 1);
        assert_eq!(stats.method_counts.get("csd"), Some(&1));
    }

    #[test]
    fn test_block_structure_detection() {
        let decomposer = MultiQubitKAKDecomposer::new();

        // The 4x4 identity is trivially block diagonal.
        let mut matrix = Array2::<Complex<f64>>::eye(4);
        assert!(decomposer.has_block_structure(&matrix, 2));

        // One entry in the upper-right block breaks the structure.
        matrix[[0, 2]] = re(1.0);
        assert!(!decomposer.has_block_structure(&matrix, 2));
    }
}