quantrs2_core/
kak_multiqubit.rs

1//! KAK decomposition for multi-qubit unitaries
2//!
3//! This module extends the Cartan (KAK) decomposition to handle arbitrary
4//! n-qubit unitaries through recursive application and generalized
5//! decomposition techniques.
6
7use crate::{
8    cartan::{CartanDecomposer, CartanDecomposition},
9    error::{QuantRS2Error, QuantRS2Result},
10    gate::{multi::*, single::*, GateOp},
11    matrix_ops::{DenseMatrix, QuantumMatrix},
12    qubit::QubitId,
13    shannon::ShannonDecomposer,
14    synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use ndarray::{s, Array2};
17use num_complex::Complex;
18use rustc_hash::FxHashMap;
19
/// Result of multi-qubit KAK decomposition.
///
/// Returned by [`MultiQubitKAKDecomposer::decompose`]. It carries the flat
/// gate list together with the tree describing how the unitary was split,
/// plus summary counts over `gates`.
#[derive(Debug, Clone)]
pub struct MultiQubitKAK {
    /// The decomposed gate sequence
    pub gates: Vec<Box<dyn GateOp>>,
    /// Hierarchical record of how the unitary was recursively decomposed
    pub tree: DecompositionTree,
    /// Total CNOT count, costing CZ as 1 CNOT and SWAP as 3 CNOTs
    pub cnot_count: usize,
    /// Total single-qubit gate count
    pub single_qubit_count: usize,
    /// Circuit depth of `gates`, as reported by the decomposer
    pub depth: usize,
}
34
/// Tree structure representing the hierarchical decomposition.
///
/// Leaves hold the one- and two-qubit decompositions produced at the bottom
/// of the recursion; internal nodes record which method split a larger
/// unitary and the subtrees it produced. The zero-qubit case is represented
/// as a `Leaf` with an empty qubit list.
#[derive(Debug, Clone)]
pub enum DecompositionTree {
    /// Leaf node - single or two-qubit gate
    Leaf {
        /// Qubits the leaf acts on
        qubits: Vec<QubitId>,
        /// Decomposition data stored for this leaf
        gate_type: LeafType,
    },
    /// Internal node - recursive decomposition
    Node {
        /// Qubits covered by this node
        qubits: Vec<QubitId>,
        /// Method used to split the unitary at this level
        method: DecompositionMethod,
        /// Sub-decompositions; may be empty when the method does not expose
        /// its internal structure (e.g. the Shannon path)
        children: Vec<DecompositionTree>,
    },
}
50
/// Type of leaf decomposition
#[derive(Debug, Clone)]
pub enum LeafType {
    /// One-qubit unitary, stored as its ZYZ Euler decomposition
    SingleQubit(SingleQubitDecomposition),
    /// Two-qubit unitary, stored as its Cartan (KAK) decomposition
    TwoQubit(CartanDecomposition),
}
57
/// Method used for decomposition at this level
#[derive(Debug, Clone)]
pub enum DecompositionMethod {
    /// Cosine-Sine Decomposition; `pivot` is the index in the node's qubit
    /// list at which the qubits are split into two groups
    CSD { pivot: usize },
    /// Quantum Shannon Decomposition; `partition` is recorded in the tree.
    /// NOTE(review): the Shannon path does not appear to consume it — confirm.
    Shannon { partition: usize },
    /// Block diagonalization; `block_size` is measured in qubits, i.e. each
    /// diagonal block is `2^block_size` rows wide
    BlockDiagonal { block_size: usize },
    /// Direct Cartan for 2 qubits. Two-qubit unitaries are emitted as
    /// `Leaf` nodes, so this variant is not selected for n > 2.
    Cartan,
}
70
/// Multi-qubit KAK decomposer.
///
/// Handles one- and two-qubit unitaries directly (ZYZ and Cartan
/// decompositions) and larger unitaries by recursive splitting.
pub struct MultiQubitKAKDecomposer {
    /// Tolerance for numerical comparisons (unitarity check, structure
    /// detection, dropping negligible rotation angles)
    tolerance: f64,
    /// Maximum recursion depth before decomposition aborts with an error
    max_depth: usize,
    /// Cache for decompositions. NOTE(review): the `u64` key is presumably a
    /// hash of the input matrix; nothing computes it yet.
    // `check_cache` / `cache_result` are stubs, so this field is never read;
    // that is why the `dead_code` lint is silenced here.
    #[allow(dead_code)]
    cache: FxHashMap<u64, MultiQubitKAK>,
    /// Use structure-aware method selection in `choose_decomposition_method`
    use_optimization: bool,
    /// Cartan decomposer for two-qubit blocks
    cartan: CartanDecomposer,
}
85
86impl MultiQubitKAKDecomposer {
87    /// Create a new multi-qubit KAK decomposer
88    pub fn new() -> Self {
89        Self {
90            tolerance: 1e-10,
91            max_depth: 20,
92            cache: FxHashMap::default(),
93            use_optimization: true,
94            cartan: CartanDecomposer::new(),
95        }
96    }
97
98    /// Create with custom tolerance
99    pub fn with_tolerance(tolerance: f64) -> Self {
100        Self {
101            tolerance,
102            max_depth: 20,
103            cache: FxHashMap::default(),
104            use_optimization: true,
105            cartan: CartanDecomposer::with_tolerance(tolerance),
106        }
107    }
108
109    /// Decompose an n-qubit unitary
110    pub fn decompose(
111        &mut self,
112        unitary: &Array2<Complex<f64>>,
113        qubit_ids: &[QubitId],
114    ) -> QuantRS2Result<MultiQubitKAK> {
115        let n = qubit_ids.len();
116        let size = 1 << n;
117
118        // Validate input
119        if unitary.shape() != [size, size] {
120            return Err(QuantRS2Error::InvalidInput(format!(
121                "Unitary size {} doesn't match {} qubits",
122                unitary.shape()[0],
123                n
124            )));
125        }
126
127        // Check unitarity
128        let mat = DenseMatrix::new(unitary.clone())?;
129        if !mat.is_unitary(self.tolerance)? {
130            return Err(QuantRS2Error::InvalidInput(
131                "Matrix is not unitary".to_string(),
132            ));
133        }
134
135        // Check cache
136        if let Some(cached) = self.check_cache(unitary) {
137            return Ok(cached.clone());
138        }
139
140        // Perform decomposition
141        let (tree, gates) = self.decompose_recursive(unitary, qubit_ids, 0)?;
142
143        // Count gates
144        let mut cnot_count = 0;
145        let mut single_qubit_count = 0;
146
147        for gate in &gates {
148            match gate.name() {
149                "CNOT" | "CZ" | "SWAP" => cnot_count += self.count_cnots(gate.name()),
150                _ => single_qubit_count += 1,
151            }
152        }
153
154        let result = MultiQubitKAK {
155            gates,
156            tree,
157            cnot_count,
158            single_qubit_count,
159            depth: 0, // TODO: Calculate actual depth
160        };
161
162        // Cache result
163        self.cache_result(unitary, &result);
164
165        Ok(result)
166    }
167
168    /// Recursive decomposition algorithm
169    fn decompose_recursive(
170        &mut self,
171        unitary: &Array2<Complex<f64>>,
172        qubit_ids: &[QubitId],
173        depth: usize,
174    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
175        if depth > self.max_depth {
176            return Err(QuantRS2Error::InvalidInput(
177                "Maximum recursion depth exceeded".to_string(),
178            ));
179        }
180
181        let n = qubit_ids.len();
182
183        // Base cases
184        match n {
185            0 => {
186                let tree = DecompositionTree::Leaf {
187                    qubits: vec![],
188                    gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
189                        global_phase: 0.0,
190                        theta1: 0.0,
191                        phi: 0.0,
192                        theta2: 0.0,
193                        basis: "ZYZ".to_string(),
194                    }),
195                };
196                Ok((tree, vec![]))
197            }
198            1 => {
199                let decomp = decompose_single_qubit_zyz(&unitary.view())?;
200                let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
201                let tree = DecompositionTree::Leaf {
202                    qubits: qubit_ids.to_vec(),
203                    gate_type: LeafType::SingleQubit(decomp),
204                };
205                Ok((tree, gates))
206            }
207            2 => {
208                let decomp = self.cartan.decompose(unitary)?;
209                let gates = self.cartan.to_gates(&decomp, qubit_ids)?;
210                let tree = DecompositionTree::Leaf {
211                    qubits: qubit_ids.to_vec(),
212                    gate_type: LeafType::TwoQubit(decomp),
213                };
214                Ok((tree, gates))
215            }
216            _ => {
217                // For n > 2, choose decomposition method
218                let method = self.choose_decomposition_method(unitary, n);
219
220                match method {
221                    DecompositionMethod::CSD { pivot } => {
222                        self.decompose_csd(unitary, qubit_ids, pivot, depth)
223                    }
224                    DecompositionMethod::Shannon { partition } => {
225                        self.decompose_shannon(unitary, qubit_ids, partition, depth)
226                    }
227                    DecompositionMethod::BlockDiagonal { block_size } => {
228                        self.decompose_block_diagonal(unitary, qubit_ids, block_size, depth)
229                    }
230                    _ => unreachable!("Invalid method for n > 2"),
231                }
232            }
233        }
234    }
235
236    /// Choose optimal decomposition method based on matrix structure
237    fn choose_decomposition_method(
238        &self,
239        unitary: &Array2<Complex<f64>>,
240        n: usize,
241    ) -> DecompositionMethod {
242        if self.use_optimization {
243            // Analyze matrix structure to choose optimal method
244            if self.has_block_structure(unitary, n) {
245                DecompositionMethod::BlockDiagonal { block_size: n / 2 }
246            } else if n % 2 == 0 {
247                // Even number of qubits - use CSD at midpoint
248                DecompositionMethod::CSD { pivot: n / 2 }
249            } else {
250                // Odd number - use Shannon decomposition
251                DecompositionMethod::Shannon { partition: n / 2 }
252            }
253        } else {
254            // Default to CSD
255            DecompositionMethod::CSD { pivot: n / 2 }
256        }
257    }
258
259    /// Decompose using Cosine-Sine Decomposition
260    fn decompose_csd(
261        &mut self,
262        unitary: &Array2<Complex<f64>>,
263        qubit_ids: &[QubitId],
264        pivot: usize,
265        depth: usize,
266    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
267        let n = qubit_ids.len();
268        let _size = 1 << n;
269        let pivot_size = 1 << pivot;
270
271        // Split unitary into blocks based on pivot
272        // U = [A B]
273        //     [C D]
274        let a = unitary.slice(s![..pivot_size, ..pivot_size]).to_owned();
275        let b = unitary.slice(s![..pivot_size, pivot_size..]).to_owned();
276        let c = unitary.slice(s![pivot_size.., ..pivot_size]).to_owned();
277        let d = unitary.slice(s![pivot_size.., pivot_size..]).to_owned();
278
279        // Apply CSD to find:
280        // U = (U1 ⊗ V1) · Σ · (U2 ⊗ V2)
281        // where Σ is diagonal in the CSD basis
282
283        // This is a simplified version - full CSD would use SVD
284        let (u1, v1, sigma, u2, v2) = self.compute_csd(&a, &b, &c, &d)?;
285
286        let mut gates = Vec::new();
287        let mut children = Vec::new();
288
289        // Decompose U2 and V2 (right multiplications)
290        let left_qubits = &qubit_ids[..pivot];
291        let right_qubits = &qubit_ids[pivot..];
292
293        let (u2_tree, u2_gates) = self.decompose_recursive(&u2, left_qubits, depth + 1)?;
294        let (v2_tree, v2_gates) = self.decompose_recursive(&v2, right_qubits, depth + 1)?;
295
296        gates.extend(u2_gates);
297        gates.extend(v2_gates);
298        children.push(u2_tree);
299        children.push(v2_tree);
300
301        // Apply diagonal gates (controlled rotations)
302        let diag_gates = self.diagonal_to_gates(&sigma, qubit_ids)?;
303        gates.extend(diag_gates);
304
305        // Decompose U1 and V1 (left multiplications)
306        let (u1_tree, u1_gates) = self.decompose_recursive(&u1, left_qubits, depth + 1)?;
307        let (v1_tree, v1_gates) = self.decompose_recursive(&v1, right_qubits, depth + 1)?;
308
309        gates.extend(u1_gates);
310        gates.extend(v1_gates);
311        children.push(u1_tree);
312        children.push(v1_tree);
313
314        let tree = DecompositionTree::Node {
315            qubits: qubit_ids.to_vec(),
316            method: DecompositionMethod::CSD { pivot },
317            children,
318        };
319
320        Ok((tree, gates))
321    }
322
323    /// Decompose using Shannon decomposition
324    fn decompose_shannon(
325        &mut self,
326        unitary: &Array2<Complex<f64>>,
327        qubit_ids: &[QubitId],
328        partition: usize,
329        _depth: usize,
330    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
331        // Use the Shannon decomposer for this
332        let mut shannon = ShannonDecomposer::new();
333        let decomp = shannon.decompose(unitary, qubit_ids)?;
334
335        // Build tree structure
336        let tree = DecompositionTree::Node {
337            qubits: qubit_ids.to_vec(),
338            method: DecompositionMethod::Shannon { partition },
339            children: vec![], // Shannon decomposer doesn't provide tree structure
340        };
341
342        Ok((tree, decomp.gates))
343    }
344
345    /// Decompose block diagonal matrix
346    fn decompose_block_diagonal(
347        &mut self,
348        unitary: &Array2<Complex<f64>>,
349        qubit_ids: &[QubitId],
350        block_size: usize,
351        depth: usize,
352    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
353        let n = qubit_ids.len();
354        let num_blocks = n / block_size;
355
356        let mut gates = Vec::new();
357        let mut children = Vec::new();
358
359        // Decompose each block independently
360        for i in 0..num_blocks {
361            let start = i * block_size;
362            let end = (i + 1) * block_size;
363            let block_qubits = &qubit_ids[start..end];
364
365            // Extract block from unitary
366            let block = self.extract_block(unitary, i, block_size)?;
367
368            let (block_tree, block_gates) =
369                self.decompose_recursive(&block, block_qubits, depth + 1)?;
370            gates.extend(block_gates);
371            children.push(block_tree);
372        }
373
374        let tree = DecompositionTree::Node {
375            qubits: qubit_ids.to_vec(),
376            method: DecompositionMethod::BlockDiagonal { block_size },
377            children,
378        };
379
380        Ok((tree, gates))
381    }
382
383    /// Compute Cosine-Sine Decomposition
384    fn compute_csd(
385        &self,
386        a: &Array2<Complex<f64>>,
387        b: &Array2<Complex<f64>>,
388        c: &Array2<Complex<f64>>,
389        d: &Array2<Complex<f64>>,
390    ) -> QuantRS2Result<(
391        Array2<Complex<f64>>, // U1
392        Array2<Complex<f64>>, // V1
393        Array2<Complex<f64>>, // Sigma
394        Array2<Complex<f64>>, // U2
395        Array2<Complex<f64>>, // V2
396    )> {
397        // This is a simplified placeholder
398        // Full CSD implementation would use specialized algorithms
399
400        let size = a.shape()[0];
401        let identity = Array2::eye(size);
402        let _zero: Array2<Complex<f64>> = Array2::zeros((size, size));
403
404        // For now, return identity transformations
405        let u1 = identity.clone();
406        let v1 = identity.clone();
407        let u2 = identity.clone();
408        let v2 = identity.clone();
409
410        // Sigma would contain the CS angles
411        let mut sigma = Array2::zeros((size * 2, size * 2));
412        sigma.slice_mut(s![..size, ..size]).assign(a);
413        sigma.slice_mut(s![..size, size..]).assign(b);
414        sigma.slice_mut(s![size.., ..size]).assign(c);
415        sigma.slice_mut(s![size.., size..]).assign(d);
416
417        Ok((u1, v1, sigma, u2, v2))
418    }
419
420    /// Convert diagonal matrix to controlled rotation gates
421    fn diagonal_to_gates(
422        &self,
423        diagonal: &Array2<Complex<f64>>,
424        qubit_ids: &[QubitId],
425    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
426        let mut gates = Vec::new();
427
428        // Extract diagonal elements
429        let n = diagonal.shape()[0];
430        for i in 0..n {
431            let phase = diagonal[[i, i]].arg();
432            if phase.abs() > self.tolerance {
433                // Determine which qubits are in state |1⟩ for this diagonal element
434                let mut control_pattern = Vec::new();
435                let mut temp = i;
436                for j in 0..qubit_ids.len() {
437                    if temp & 1 == 1 {
438                        control_pattern.push(j);
439                    }
440                    temp >>= 1;
441                }
442
443                // Create multi-controlled phase gate
444                if control_pattern.is_empty() {
445                    // Global phase - can be ignored
446                } else if control_pattern.len() == 1 {
447                    // Single-qubit phase
448                    gates.push(Box::new(RotationZ {
449                        target: qubit_ids[control_pattern[0]],
450                        theta: phase,
451                    }) as Box<dyn GateOp>);
452                } else {
453                    // Multi-controlled phase - decompose further
454                    // For now, use simple decomposition
455                    let target = qubit_ids[control_pattern.pop().unwrap()];
456                    for &control_idx in &control_pattern {
457                        gates.push(Box::new(CNOT {
458                            control: qubit_ids[control_idx],
459                            target,
460                        }));
461                    }
462
463                    gates.push(Box::new(RotationZ {
464                        target,
465                        theta: phase,
466                    }) as Box<dyn GateOp>);
467
468                    // Uncompute CNOTs
469                    for &control_idx in control_pattern.iter().rev() {
470                        gates.push(Box::new(CNOT {
471                            control: qubit_ids[control_idx],
472                            target,
473                        }));
474                    }
475                }
476            }
477        }
478
479        Ok(gates)
480    }
481
482    /// Check if matrix has block diagonal structure
483    fn has_block_structure(&self, unitary: &Array2<Complex<f64>>, _n: usize) -> bool {
484        // Simple check - look for zeros in off-diagonal blocks
485        let size = unitary.shape()[0];
486        let block_size = size / 2;
487
488        let mut off_diagonal_norm = 0.0;
489
490        // Check upper-right block
491        for i in 0..block_size {
492            for j in block_size..size {
493                off_diagonal_norm += unitary[[i, j]].norm_sqr();
494            }
495        }
496
497        // Check lower-left block
498        for i in block_size..size {
499            for j in 0..block_size {
500                off_diagonal_norm += unitary[[i, j]].norm_sqr();
501            }
502        }
503
504        off_diagonal_norm.sqrt() < self.tolerance
505    }
506
507    /// Extract a block from block-diagonal matrix
508    fn extract_block(
509        &self,
510        unitary: &Array2<Complex<f64>>,
511        block_idx: usize,
512        block_size: usize,
513    ) -> QuantRS2Result<Array2<Complex<f64>>> {
514        let size = 1 << block_size;
515        let start = block_idx * size;
516        let end = (block_idx + 1) * size;
517
518        Ok(unitary.slice(s![start..end, start..end]).to_owned())
519    }
520
521    /// Convert single-qubit decomposition to gates
522    fn single_qubit_to_gates(
523        &self,
524        decomp: &SingleQubitDecomposition,
525        qubit: QubitId,
526    ) -> Vec<Box<dyn GateOp>> {
527        let mut gates = Vec::new();
528
529        if decomp.theta1.abs() > self.tolerance {
530            gates.push(Box::new(RotationZ {
531                target: qubit,
532                theta: decomp.theta1,
533            }) as Box<dyn GateOp>);
534        }
535
536        if decomp.phi.abs() > self.tolerance {
537            gates.push(Box::new(RotationY {
538                target: qubit,
539                theta: decomp.phi,
540            }) as Box<dyn GateOp>);
541        }
542
543        if decomp.theta2.abs() > self.tolerance {
544            gates.push(Box::new(RotationZ {
545                target: qubit,
546                theta: decomp.theta2,
547            }) as Box<dyn GateOp>);
548        }
549
550        gates
551    }
552
553    /// Count CNOTs for different gate types
554    fn count_cnots(&self, gate_name: &str) -> usize {
555        match gate_name {
556            "CNOT" => 1,
557            "CZ" => 1,   // CZ = H·CNOT·H
558            "SWAP" => 3, // SWAP uses 3 CNOTs
559            _ => 0,
560        }
561    }
562
563    /// Check cache for existing decomposition
564    fn check_cache(&self, _unitary: &Array2<Complex<f64>>) -> Option<&MultiQubitKAK> {
565        // Simple hash based on first few elements
566        // Real implementation would use better hashing
567        None
568    }
569
570    /// Cache decomposition result
571    fn cache_result(&mut self, _unitary: &Array2<Complex<f64>>, _result: &MultiQubitKAK) {
572        // Cache implementation
573    }
574}
575
576impl Default for MultiQubitKAKDecomposer {
577    fn default() -> Self {
578        Self::new()
579    }
580}
581
582/// Analyze decomposition tree structure
583pub struct KAKTreeAnalyzer {
584    /// Track statistics
585    stats: DecompositionStats,
586}
587
/// Summary statistics gathered by [`KAKTreeAnalyzer::analyze`].
#[derive(Debug, Default, Clone)]
pub struct DecompositionStats {
    /// Number of nodes (leaves and internal nodes) visited
    pub total_nodes: usize,
    /// Number of leaf nodes
    pub leaf_nodes: usize,
    /// Deepest node depth, with the root at depth 0
    pub max_depth: usize,
    /// Occurrences per node kind, keyed by `"single_qubit"`, `"two_qubit"`,
    /// `"csd"`, `"shannon"`, `"block_diagonal"` or `"cartan"`
    pub method_counts: FxHashMap<String, usize>,
    /// Histogram over two-qubit leaves: CNOT count -> number of leaves that
    /// need that many CNOTs
    pub cnot_distribution: FxHashMap<usize, usize>,
}
596
597impl KAKTreeAnalyzer {
598    /// Create new analyzer
599    pub fn new() -> Self {
600        Self {
601            stats: DecompositionStats::default(),
602        }
603    }
604
605    /// Analyze decomposition tree
606    pub fn analyze(&mut self, tree: &DecompositionTree) -> DecompositionStats {
607        self.stats = DecompositionStats::default();
608        self.analyze_recursive(tree, 0);
609        self.stats.clone()
610    }
611
612    fn analyze_recursive(&mut self, tree: &DecompositionTree, depth: usize) {
613        self.stats.total_nodes += 1;
614        self.stats.max_depth = self.stats.max_depth.max(depth);
615
616        match tree {
617            DecompositionTree::Leaf {
618                qubits: _qubits,
619                gate_type,
620            } => {
621                self.stats.leaf_nodes += 1;
622
623                match gate_type {
624                    LeafType::SingleQubit(_) => {
625                        *self
626                            .stats
627                            .method_counts
628                            .entry("single_qubit".to_string())
629                            .or_insert(0) += 1;
630                    }
631                    LeafType::TwoQubit(cartan) => {
632                        *self
633                            .stats
634                            .method_counts
635                            .entry("two_qubit".to_string())
636                            .or_insert(0) += 1;
637                        let cnots = cartan.interaction.cnot_count(1e-10);
638                        *self.stats.cnot_distribution.entry(cnots).or_insert(0) += 1;
639                    }
640                }
641            }
642            DecompositionTree::Node {
643                method, children, ..
644            } => {
645                let method_name = match method {
646                    DecompositionMethod::CSD { .. } => "csd",
647                    DecompositionMethod::Shannon { .. } => "shannon",
648                    DecompositionMethod::BlockDiagonal { .. } => "block_diagonal",
649                    DecompositionMethod::Cartan => "cartan",
650                };
651                *self
652                    .stats
653                    .method_counts
654                    .entry(method_name.to_string())
655                    .or_insert(0) += 1;
656
657                for child in children {
658                    self.analyze_recursive(child, depth + 1);
659                }
660            }
661        }
662    }
663}
664
665/// Utility function for quick multi-qubit KAK decomposition
666pub fn kak_decompose_multiqubit(
667    unitary: &Array2<Complex<f64>>,
668    qubit_ids: &[QubitId],
669) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
670    let mut decomposer = MultiQubitKAKDecomposer::new();
671    let decomp = decomposer.decompose(unitary, qubit_ids)?;
672    Ok(decomp.gates)
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use num_complex::Complex;

    /// ZYZ decomposition with every angle zero (the identity).
    fn zero_zyz() -> SingleQubitDecomposition {
        SingleQubitDecomposition {
            global_phase: 0.0,
            theta1: 0.0,
            phi: 0.0,
            theta2: 0.0,
            basis: "ZYZ".to_string(),
        }
    }

    /// Lift a real matrix to complex entries with zero imaginary part.
    fn complexify(real: Array2<f64>) -> Array2<Complex<f64>> {
        real.mapv(|x| Complex::new(x, 0.0))
    }

    #[test]
    fn test_multiqubit_kak_single() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // Hadamard = [[1, 1], [1, -1]] / sqrt(2)
        let h = complexify(Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 1.0, -1.0]).unwrap())
            / Complex::new(2.0_f64.sqrt(), 0.0);

        let qubit_ids = vec![QubitId(0)];
        let decomp = decomposer.decompose(&h, &qubit_ids).unwrap();

        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::SingleQubit(_),
                    ..
                }
            ),
            "Expected single-qubit leaf"
        );
    }

    #[test]
    fn test_multiqubit_kak_two() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // CNOT: identity on the first two basis states, swap of the last two.
        let mut real = Array2::<f64>::zeros((4, 4));
        for &(row, col) in &[(0, 0), (1, 1), (2, 3), (3, 2)] {
            real[[row, col]] = 1.0;
        }
        let cnot = complexify(real);

        let qubit_ids = vec![QubitId(0), QubitId(1)];
        let decomp = decomposer.decompose(&cnot, &qubit_ids).unwrap();

        assert!(decomp.cnot_count <= 1);
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::TwoQubit(_),
                    ..
                }
            ),
            "Expected two-qubit leaf"
        );
    }

    #[test]
    fn test_multiqubit_kak_three() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        let identity = complexify(Array2::eye(8));
        let qubit_ids = vec![QubitId(0), QubitId(1), QubitId(2)];
        let decomp = decomposer.decompose(&identity, &qubit_ids).unwrap();

        // Identity should result in empty circuit
        assert_eq!(decomp.gates.len(), 0);
        assert_eq!(decomp.cnot_count, 0);
        assert_eq!(decomp.single_qubit_count, 0);
    }

    #[test]
    fn test_tree_analyzer() {
        let mut analyzer = KAKTreeAnalyzer::new();

        let two_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(0), QubitId(1)],
            gate_type: LeafType::TwoQubit(CartanDecomposition {
                left_gates: (zero_zyz(), zero_zyz()),
                right_gates: (zero_zyz(), zero_zyz()),
                interaction: crate::prelude::CartanCoefficients::new(0.0, 0.0, 0.0),
                global_phase: 0.0,
            }),
        };
        let single_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(2)],
            gate_type: LeafType::SingleQubit(zero_zyz()),
        };
        let tree = DecompositionTree::Node {
            qubits: vec![QubitId(0), QubitId(1), QubitId(2)],
            method: DecompositionMethod::CSD { pivot: 2 },
            children: vec![two_qubit_leaf, single_qubit_leaf],
        };

        let stats = analyzer.analyze(&tree);

        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.leaf_nodes, 2);
        assert_eq!(stats.max_depth, 1);
        assert_eq!(stats.method_counts.get("csd"), Some(&1));
    }

    #[test]
    fn test_block_structure_detection() {
        let decomposer = MultiQubitKAKDecomposer::new();

        // The 4x4 identity is trivially block diagonal.
        let mut block_diag: Array2<Complex<f64>> = Array2::eye(4);
        assert!(decomposer.has_block_structure(&block_diag, 2));

        // An entry in the upper-right block breaks the structure.
        block_diag[[0, 2]] = Complex::new(1.0, 0.0);
        assert!(!decomposer.has_block_structure(&block_diag, 2));
    }
}