// quantrs2_core/kak_multiqubit.rs

//! KAK decomposition for multi-qubit unitaries
//!
//! This module extends the Cartan (KAK) decomposition to handle arbitrary
//! n-qubit unitaries through recursive application and generalized
//! decomposition techniques.

use crate::{
    cartan::{CartanCoefficients, CartanDecomposer, CartanDecomposition},
    error::{QuantRS2Error, QuantRS2Result},
    gate::{multi::*, single::*, GateOp},
    matrix_ops::{DenseMatrix, QuantumMatrix},
    qubit::QubitId,
    shannon::ShannonDecomposer,
    synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
};
use ndarray::{s, Array1, Array2, ArrayView2, Axis};
use num_complex::Complex;
use rustc_hash::FxHashMap;
use std::collections::hash_map::DefaultHasher;
use std::f64::consts::PI;
use std::hash::{Hash, Hasher};

21/// Result of multi-qubit KAK decomposition
22#[derive(Debug, Clone)]
23pub struct MultiQubitKAK {
24    /// The decomposed gate sequence
25    pub gates: Vec<Box<dyn GateOp>>,
26    /// Decomposition tree structure
27    pub tree: DecompositionTree,
28    /// Total CNOT count
29    pub cnot_count: usize,
30    /// Total single-qubit gate count
31    pub single_qubit_count: usize,
32    /// Circuit depth
33    pub depth: usize,
34}
35
36/// Tree structure representing the hierarchical decomposition
37#[derive(Debug, Clone)]
38pub enum DecompositionTree {
39    /// Leaf node - single or two-qubit gate
40    Leaf {
41        qubits: Vec<QubitId>,
42        gate_type: LeafType,
43    },
44    /// Internal node - recursive decomposition
45    Node {
46        qubits: Vec<QubitId>,
47        method: DecompositionMethod,
48        children: Vec<DecompositionTree>,
49    },
50}
51
52/// Type of leaf decomposition
53#[derive(Debug, Clone)]
54pub enum LeafType {
55    SingleQubit(SingleQubitDecomposition),
56    TwoQubit(CartanDecomposition),
57}
58
/// Strategy recorded at an interior [`DecompositionTree::Node`].
#[derive(Clone, Debug)]
pub enum DecompositionMethod {
    /// Cosine-Sine Decomposition, splitting the qubit list at index `pivot`.
    CSD { pivot: usize },
    /// Quantum Shannon Decomposition; `partition` is the recorded split point.
    Shannon { partition: usize },
    /// Block-diagonal structure; `block_size` is measured in qubits.
    BlockDiagonal { block_size: usize },
    /// Direct Cartan (KAK) decomposition of a two-qubit unitary.
    Cartan,
}

72/// Multi-qubit KAK decomposer
73pub struct MultiQubitKAKDecomposer {
74    /// Tolerance for numerical comparisons
75    tolerance: f64,
76    /// Maximum recursion depth
77    max_depth: usize,
78    /// Cache for decompositions
79    cache: FxHashMap<u64, MultiQubitKAK>,
80    /// Use optimized methods
81    use_optimization: bool,
82    /// Cartan decomposer for two-qubit blocks
83    cartan: CartanDecomposer,
84}
85
86impl MultiQubitKAKDecomposer {
87    /// Create a new multi-qubit KAK decomposer
88    pub fn new() -> Self {
89        Self {
90            tolerance: 1e-10,
91            max_depth: 20,
92            cache: FxHashMap::default(),
93            use_optimization: true,
94            cartan: CartanDecomposer::new(),
95        }
96    }
97
98    /// Create with custom tolerance
99    pub fn with_tolerance(tolerance: f64) -> Self {
100        Self {
101            tolerance,
102            max_depth: 20,
103            cache: FxHashMap::default(),
104            use_optimization: true,
105            cartan: CartanDecomposer::with_tolerance(tolerance),
106        }
107    }
108
109    /// Decompose an n-qubit unitary
110    pub fn decompose(
111        &mut self,
112        unitary: &Array2<Complex<f64>>,
113        qubit_ids: &[QubitId],
114    ) -> QuantRS2Result<MultiQubitKAK> {
115        let n = qubit_ids.len();
116        let size = 1 << n;
117
118        // Validate input
119        if unitary.shape() != [size, size] {
120            return Err(QuantRS2Error::InvalidInput(format!(
121                "Unitary size {} doesn't match {} qubits",
122                unitary.shape()[0],
123                n
124            )));
125        }
126
127        // Check unitarity
128        let mat = DenseMatrix::new(unitary.clone())?;
129        if !mat.is_unitary(self.tolerance)? {
130            return Err(QuantRS2Error::InvalidInput(
131                "Matrix is not unitary".to_string(),
132            ));
133        }
134
135        // Check cache
136        if let Some(cached) = self.check_cache(unitary) {
137            return Ok(cached.clone());
138        }
139
140        // Perform decomposition
141        let (tree, gates) = self.decompose_recursive(unitary, qubit_ids, 0)?;
142
143        // Count gates
144        let mut cnot_count = 0;
145        let mut single_qubit_count = 0;
146
147        for gate in &gates {
148            match gate.name() {
149                "CNOT" | "CZ" | "SWAP" => cnot_count += self.count_cnots(gate.name()),
150                _ => single_qubit_count += 1,
151            }
152        }
153
154        let result = MultiQubitKAK {
155            gates,
156            tree,
157            cnot_count,
158            single_qubit_count,
159            depth: 0, // TODO: Calculate actual depth
160        };
161
162        // Cache result
163        self.cache_result(unitary, &result);
164
165        Ok(result)
166    }
167
168    /// Recursive decomposition algorithm
169    fn decompose_recursive(
170        &mut self,
171        unitary: &Array2<Complex<f64>>,
172        qubit_ids: &[QubitId],
173        depth: usize,
174    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
175        if depth > self.max_depth {
176            return Err(QuantRS2Error::InvalidInput(
177                "Maximum recursion depth exceeded".to_string(),
178            ));
179        }
180
181        let n = qubit_ids.len();
182
183        // Base cases
184        match n {
185            0 => {
186                let tree = DecompositionTree::Leaf {
187                    qubits: vec![],
188                    gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
189                        global_phase: 0.0,
190                        theta1: 0.0,
191                        phi: 0.0,
192                        theta2: 0.0,
193                        basis: "ZYZ".to_string(),
194                    }),
195                };
196                Ok((tree, vec![]))
197            }
198            1 => {
199                let decomp = decompose_single_qubit_zyz(&unitary.view())?;
200                let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
201                let tree = DecompositionTree::Leaf {
202                    qubits: qubit_ids.to_vec(),
203                    gate_type: LeafType::SingleQubit(decomp),
204                };
205                Ok((tree, gates))
206            }
207            2 => {
208                let decomp = self.cartan.decompose(unitary)?;
209                let gates = self.cartan.to_gates(&decomp, qubit_ids)?;
210                let tree = DecompositionTree::Leaf {
211                    qubits: qubit_ids.to_vec(),
212                    gate_type: LeafType::TwoQubit(decomp),
213                };
214                Ok((tree, gates))
215            }
216            _ => {
217                // For n > 2, choose decomposition method
218                let method = self.choose_decomposition_method(unitary, n);
219
220                match method {
221                    DecompositionMethod::CSD { pivot } => {
222                        self.decompose_csd(unitary, qubit_ids, pivot, depth)
223                    }
224                    DecompositionMethod::Shannon { partition } => {
225                        self.decompose_shannon(unitary, qubit_ids, partition, depth)
226                    }
227                    DecompositionMethod::BlockDiagonal { block_size } => {
228                        self.decompose_block_diagonal(unitary, qubit_ids, block_size, depth)
229                    }
230                    _ => unreachable!("Invalid method for n > 2"),
231                }
232            }
233        }
234    }
235
236    /// Choose optimal decomposition method based on matrix structure
237    fn choose_decomposition_method(
238        &self,
239        unitary: &Array2<Complex<f64>>,
240        n: usize,
241    ) -> DecompositionMethod {
242        if self.use_optimization {
243            // Analyze matrix structure to choose optimal method
244            if self.has_block_structure(unitary, n) {
245                DecompositionMethod::BlockDiagonal { block_size: n / 2 }
246            } else if n % 2 == 0 {
247                // Even number of qubits - use CSD at midpoint
248                DecompositionMethod::CSD { pivot: n / 2 }
249            } else {
250                // Odd number - use Shannon decomposition
251                DecompositionMethod::Shannon { partition: n / 2 }
252            }
253        } else {
254            // Default to CSD
255            DecompositionMethod::CSD { pivot: n / 2 }
256        }
257    }
258
259    /// Decompose using Cosine-Sine Decomposition
260    fn decompose_csd(
261        &mut self,
262        unitary: &Array2<Complex<f64>>,
263        qubit_ids: &[QubitId],
264        pivot: usize,
265        depth: usize,
266    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
267        let n = qubit_ids.len();
268        let size = 1 << n;
269        let pivot_size = 1 << pivot;
270
271        // Split unitary into blocks based on pivot
272        // U = [A B]
273        //     [C D]
274        let a = unitary.slice(s![..pivot_size, ..pivot_size]).to_owned();
275        let b = unitary.slice(s![..pivot_size, pivot_size..]).to_owned();
276        let c = unitary.slice(s![pivot_size.., ..pivot_size]).to_owned();
277        let d = unitary.slice(s![pivot_size.., pivot_size..]).to_owned();
278
279        // Apply CSD to find:
280        // U = (U1 ⊗ V1) · Σ · (U2 ⊗ V2)
281        // where Σ is diagonal in the CSD basis
282
283        // This is a simplified version - full CSD would use SVD
284        let (u1, v1, sigma, u2, v2) = self.compute_csd(&a, &b, &c, &d)?;
285
286        let mut gates = Vec::new();
287        let mut children = Vec::new();
288
289        // Decompose U2 and V2 (right multiplications)
290        let left_qubits = &qubit_ids[..pivot];
291        let right_qubits = &qubit_ids[pivot..];
292
293        let (u2_tree, u2_gates) = self.decompose_recursive(&u2, left_qubits, depth + 1)?;
294        let (v2_tree, v2_gates) = self.decompose_recursive(&v2, right_qubits, depth + 1)?;
295
296        gates.extend(u2_gates);
297        gates.extend(v2_gates);
298        children.push(u2_tree);
299        children.push(v2_tree);
300
301        // Apply diagonal gates (controlled rotations)
302        let diag_gates = self.diagonal_to_gates(&sigma, qubit_ids)?;
303        gates.extend(diag_gates);
304
305        // Decompose U1 and V1 (left multiplications)
306        let (u1_tree, u1_gates) = self.decompose_recursive(&u1, left_qubits, depth + 1)?;
307        let (v1_tree, v1_gates) = self.decompose_recursive(&v1, right_qubits, depth + 1)?;
308
309        gates.extend(u1_gates);
310        gates.extend(v1_gates);
311        children.push(u1_tree);
312        children.push(v1_tree);
313
314        let tree = DecompositionTree::Node {
315            qubits: qubit_ids.to_vec(),
316            method: DecompositionMethod::CSD { pivot },
317            children,
318        };
319
320        Ok((tree, gates))
321    }
322
323    /// Decompose using Shannon decomposition
324    fn decompose_shannon(
325        &mut self,
326        unitary: &Array2<Complex<f64>>,
327        qubit_ids: &[QubitId],
328        partition: usize,
329        depth: usize,
330    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
331        // Use the Shannon decomposer for this
332        let mut shannon = ShannonDecomposer::new();
333        let decomp = shannon.decompose(unitary, qubit_ids)?;
334
335        // Build tree structure
336        let tree = DecompositionTree::Node {
337            qubits: qubit_ids.to_vec(),
338            method: DecompositionMethod::Shannon { partition },
339            children: vec![], // Shannon decomposer doesn't provide tree structure
340        };
341
342        Ok((tree, decomp.gates))
343    }
344
345    /// Decompose block diagonal matrix
346    fn decompose_block_diagonal(
347        &mut self,
348        unitary: &Array2<Complex<f64>>,
349        qubit_ids: &[QubitId],
350        block_size: usize,
351        depth: usize,
352    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
353        let n = qubit_ids.len();
354        let num_blocks = n / block_size;
355
356        let mut gates = Vec::new();
357        let mut children = Vec::new();
358
359        // Decompose each block independently
360        for i in 0..num_blocks {
361            let start = i * block_size;
362            let end = (i + 1) * block_size;
363            let block_qubits = &qubit_ids[start..end];
364
365            // Extract block from unitary
366            let block = self.extract_block(unitary, i, block_size)?;
367
368            let (block_tree, block_gates) =
369                self.decompose_recursive(&block, block_qubits, depth + 1)?;
370            gates.extend(block_gates);
371            children.push(block_tree);
372        }
373
374        let tree = DecompositionTree::Node {
375            qubits: qubit_ids.to_vec(),
376            method: DecompositionMethod::BlockDiagonal { block_size },
377            children,
378        };
379
380        Ok((tree, gates))
381    }
382
383    /// Compute Cosine-Sine Decomposition
384    fn compute_csd(
385        &self,
386        a: &Array2<Complex<f64>>,
387        b: &Array2<Complex<f64>>,
388        c: &Array2<Complex<f64>>,
389        d: &Array2<Complex<f64>>,
390    ) -> QuantRS2Result<(
391        Array2<Complex<f64>>, // U1
392        Array2<Complex<f64>>, // V1
393        Array2<Complex<f64>>, // Sigma
394        Array2<Complex<f64>>, // U2
395        Array2<Complex<f64>>, // V2
396    )> {
397        // This is a simplified placeholder
398        // Full CSD implementation would use specialized algorithms
399
400        let size = a.shape()[0];
401        let identity = Array2::eye(size);
402        let zero: Array2<Complex<f64>> = Array2::zeros((size, size));
403
404        // For now, return identity transformations
405        let u1 = identity.clone();
406        let v1 = identity.clone();
407        let u2 = identity.clone();
408        let v2 = identity.clone();
409
410        // Sigma would contain the CS angles
411        let mut sigma = Array2::zeros((size * 2, size * 2));
412        sigma.slice_mut(s![..size, ..size]).assign(a);
413        sigma.slice_mut(s![..size, size..]).assign(b);
414        sigma.slice_mut(s![size.., ..size]).assign(c);
415        sigma.slice_mut(s![size.., size..]).assign(d);
416
417        Ok((u1, v1, sigma, u2, v2))
418    }
419
420    /// Convert diagonal matrix to controlled rotation gates
421    fn diagonal_to_gates(
422        &self,
423        diagonal: &Array2<Complex<f64>>,
424        qubit_ids: &[QubitId],
425    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
426        let mut gates = Vec::new();
427
428        // Extract diagonal elements
429        let n = diagonal.shape()[0];
430        for i in 0..n {
431            let phase = diagonal[[i, i]].arg();
432            if phase.abs() > self.tolerance {
433                // Determine which qubits are in state |1⟩ for this diagonal element
434                let mut control_pattern = Vec::new();
435                let mut temp = i;
436                for j in 0..qubit_ids.len() {
437                    if temp & 1 == 1 {
438                        control_pattern.push(j);
439                    }
440                    temp >>= 1;
441                }
442
443                // Create multi-controlled phase gate
444                if control_pattern.is_empty() {
445                    // Global phase - can be ignored
446                } else if control_pattern.len() == 1 {
447                    // Single-qubit phase
448                    gates.push(Box::new(RotationZ {
449                        target: qubit_ids[control_pattern[0]],
450                        theta: phase,
451                    }) as Box<dyn GateOp>);
452                } else {
453                    // Multi-controlled phase - decompose further
454                    // For now, use simple decomposition
455                    let target = qubit_ids[control_pattern.pop().unwrap()];
456                    for &control_idx in &control_pattern {
457                        gates.push(Box::new(CNOT {
458                            control: qubit_ids[control_idx],
459                            target,
460                        }));
461                    }
462
463                    gates.push(Box::new(RotationZ {
464                        target,
465                        theta: phase,
466                    }) as Box<dyn GateOp>);
467
468                    // Uncompute CNOTs
469                    for &control_idx in control_pattern.iter().rev() {
470                        gates.push(Box::new(CNOT {
471                            control: qubit_ids[control_idx],
472                            target,
473                        }));
474                    }
475                }
476            }
477        }
478
479        Ok(gates)
480    }
481
482    /// Check if matrix has block diagonal structure
483    fn has_block_structure(&self, unitary: &Array2<Complex<f64>>, n: usize) -> bool {
484        // Simple check - look for zeros in off-diagonal blocks
485        let size = unitary.shape()[0];
486        let block_size = size / 2;
487
488        let mut off_diagonal_norm = 0.0;
489
490        // Check upper-right block
491        for i in 0..block_size {
492            for j in block_size..size {
493                off_diagonal_norm += unitary[[i, j]].norm_sqr();
494            }
495        }
496
497        // Check lower-left block
498        for i in block_size..size {
499            for j in 0..block_size {
500                off_diagonal_norm += unitary[[i, j]].norm_sqr();
501            }
502        }
503
504        off_diagonal_norm.sqrt() < self.tolerance
505    }
506
507    /// Extract a block from block-diagonal matrix
508    fn extract_block(
509        &self,
510        unitary: &Array2<Complex<f64>>,
511        block_idx: usize,
512        block_size: usize,
513    ) -> QuantRS2Result<Array2<Complex<f64>>> {
514        let size = 1 << block_size;
515        let start = block_idx * size;
516        let end = (block_idx + 1) * size;
517
518        Ok(unitary.slice(s![start..end, start..end]).to_owned())
519    }
520
521    /// Convert single-qubit decomposition to gates
522    fn single_qubit_to_gates(
523        &self,
524        decomp: &SingleQubitDecomposition,
525        qubit: QubitId,
526    ) -> Vec<Box<dyn GateOp>> {
527        let mut gates = Vec::new();
528
529        if decomp.theta1.abs() > self.tolerance {
530            gates.push(Box::new(RotationZ {
531                target: qubit,
532                theta: decomp.theta1,
533            }) as Box<dyn GateOp>);
534        }
535
536        if decomp.phi.abs() > self.tolerance {
537            gates.push(Box::new(RotationY {
538                target: qubit,
539                theta: decomp.phi,
540            }) as Box<dyn GateOp>);
541        }
542
543        if decomp.theta2.abs() > self.tolerance {
544            gates.push(Box::new(RotationZ {
545                target: qubit,
546                theta: decomp.theta2,
547            }) as Box<dyn GateOp>);
548        }
549
550        gates
551    }
552
553    /// Count CNOTs for different gate types
554    fn count_cnots(&self, gate_name: &str) -> usize {
555        match gate_name {
556            "CNOT" => 1,
557            "CZ" => 1,   // CZ = H·CNOT·H
558            "SWAP" => 3, // SWAP uses 3 CNOTs
559            _ => 0,
560        }
561    }
562
563    /// Check cache for existing decomposition
564    fn check_cache(&self, unitary: &Array2<Complex<f64>>) -> Option<&MultiQubitKAK> {
565        // Simple hash based on first few elements
566        // Real implementation would use better hashing
567        None
568    }
569
570    /// Cache decomposition result
571    fn cache_result(&mut self, unitary: &Array2<Complex<f64>>, result: &MultiQubitKAK) {
572        // Cache implementation
573    }
574}
575
576/// Analyze decomposition tree structure
577pub struct KAKTreeAnalyzer {
578    /// Track statistics
579    stats: DecompositionStats,
580}
581
582#[derive(Debug, Default, Clone)]
583pub struct DecompositionStats {
584    pub total_nodes: usize,
585    pub leaf_nodes: usize,
586    pub max_depth: usize,
587    pub method_counts: FxHashMap<String, usize>,
588    pub cnot_distribution: FxHashMap<usize, usize>,
589}
590
591impl KAKTreeAnalyzer {
592    /// Create new analyzer
593    pub fn new() -> Self {
594        Self {
595            stats: DecompositionStats::default(),
596        }
597    }
598
599    /// Analyze decomposition tree
600    pub fn analyze(&mut self, tree: &DecompositionTree) -> DecompositionStats {
601        self.stats = DecompositionStats::default();
602        self.analyze_recursive(tree, 0);
603        self.stats.clone()
604    }
605
606    fn analyze_recursive(&mut self, tree: &DecompositionTree, depth: usize) {
607        self.stats.total_nodes += 1;
608        self.stats.max_depth = self.stats.max_depth.max(depth);
609
610        match tree {
611            DecompositionTree::Leaf { qubits, gate_type } => {
612                self.stats.leaf_nodes += 1;
613
614                match gate_type {
615                    LeafType::SingleQubit(_) => {
616                        *self
617                            .stats
618                            .method_counts
619                            .entry("single_qubit".to_string())
620                            .or_insert(0) += 1;
621                    }
622                    LeafType::TwoQubit(cartan) => {
623                        *self
624                            .stats
625                            .method_counts
626                            .entry("two_qubit".to_string())
627                            .or_insert(0) += 1;
628                        let cnots = cartan.interaction.cnot_count(1e-10);
629                        *self.stats.cnot_distribution.entry(cnots).or_insert(0) += 1;
630                    }
631                }
632            }
633            DecompositionTree::Node {
634                method, children, ..
635            } => {
636                let method_name = match method {
637                    DecompositionMethod::CSD { .. } => "csd",
638                    DecompositionMethod::Shannon { .. } => "shannon",
639                    DecompositionMethod::BlockDiagonal { .. } => "block_diagonal",
640                    DecompositionMethod::Cartan => "cartan",
641                };
642                *self
643                    .stats
644                    .method_counts
645                    .entry(method_name.to_string())
646                    .or_insert(0) += 1;
647
648                for child in children {
649                    self.analyze_recursive(child, depth + 1);
650                }
651            }
652        }
653    }
654}
655
656/// Utility function for quick multi-qubit KAK decomposition
657pub fn kak_decompose_multiqubit(
658    unitary: &Array2<Complex<f64>>,
659    qubit_ids: &[QubitId],
660) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
661    let mut decomposer = MultiQubitKAKDecomposer::new();
662    let decomp = decomposer.decompose(unitary, qubit_ids)?;
663    Ok(decomp.gates)
664}
665
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use num_complex::Complex;

    /// Real-valued complex number, to keep matrix literals short.
    fn re(x: f64) -> Complex<f64> {
        Complex::new(x, 0.0)
    }

    /// Single-qubit decomposition with every angle and the phase set to zero.
    fn zero_zyz() -> SingleQubitDecomposition {
        SingleQubitDecomposition {
            global_phase: 0.0,
            theta1: 0.0,
            phi: 0.0,
            theta2: 0.0,
            basis: "ZYZ".to_string(),
        }
    }

    #[test]
    fn test_multiqubit_kak_single() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // Hadamard: [[1, 1], [1, -1]] / sqrt(2)
        let entries: Vec<Complex<f64>> = [1.0, 1.0, 1.0, -1.0].iter().map(|&x| re(x)).collect();
        let hadamard = Array2::from_shape_vec((2, 2), entries).unwrap() / re(2.0_f64.sqrt());

        let qubit_ids = vec![QubitId(0)];
        let decomp = decomposer.decompose(&hadamard, &qubit_ids).unwrap();

        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::SingleQubit(_),
                    ..
                }
            ),
            "Expected single-qubit leaf"
        );
    }

    #[test]
    fn test_multiqubit_kak_two() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // CNOT with qubit 0 as control: swaps the |10> and |11> rows.
        let mut cnot = Array2::<Complex<f64>>::zeros((4, 4));
        for &(row, col) in &[(0, 0), (1, 1), (2, 3), (3, 2)] {
            cnot[[row, col]] = re(1.0);
        }

        let qubit_ids = vec![QubitId(0), QubitId(1)];
        let decomp = decomposer.decompose(&cnot, &qubit_ids).unwrap();

        assert!(decomp.cnot_count <= 1);
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::TwoQubit(_),
                    ..
                }
            ),
            "Expected two-qubit leaf"
        );
    }

    #[test]
    fn test_multiqubit_kak_three() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        let identity = Array2::<f64>::eye(8).mapv(re);
        let qubit_ids: Vec<QubitId> = (0..3).map(QubitId).collect();
        let decomp = decomposer.decompose(&identity, &qubit_ids).unwrap();

        // The identity needs no gates at all.
        assert_eq!(decomp.gates.len(), 0);
        assert_eq!(decomp.cnot_count, 0);
        assert_eq!(decomp.single_qubit_count, 0);
    }

    #[test]
    fn test_tree_analyzer() {
        let mut analyzer = KAKTreeAnalyzer::new();

        let two_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(0), QubitId(1)],
            gate_type: LeafType::TwoQubit(CartanDecomposition {
                left_gates: (zero_zyz(), zero_zyz()),
                right_gates: (zero_zyz(), zero_zyz()),
                interaction: CartanCoefficients::new(0.0, 0.0, 0.0),
                global_phase: 0.0,
            }),
        };
        let single_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(2)],
            gate_type: LeafType::SingleQubit(zero_zyz()),
        };
        let tree = DecompositionTree::Node {
            qubits: vec![QubitId(0), QubitId(1), QubitId(2)],
            method: DecompositionMethod::CSD { pivot: 2 },
            children: vec![two_qubit_leaf, single_qubit_leaf],
        };

        let stats = analyzer.analyze(&tree);

        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.leaf_nodes, 2);
        assert_eq!(stats.max_depth, 1);
        assert_eq!(stats.method_counts.get("csd"), Some(&1));
    }

    #[test]
    fn test_block_structure_detection() {
        let decomposer = MultiQubitKAKDecomposer::new();

        // The identity is trivially block diagonal.
        let mut matrix = Array2::<Complex<f64>>::eye(4);
        assert!(decomposer.has_block_structure(&matrix, 2));

        // An entry in the upper-right block breaks the structure.
        matrix[[0, 2]] = re(1.0);
        assert!(!decomposer.has_block_structure(&matrix, 2));
    }
}