Skip to main content

quantrs2_core/
kak_multiqubit.rs

1//! KAK decomposition for multi-qubit unitaries
2//!
3//! This module extends the Cartan (KAK) decomposition to handle arbitrary
4//! n-qubit unitaries through recursive application and generalized
5//! decomposition techniques.
6
7use crate::{
8    cartan::{CartanDecomposer, CartanDecomposition},
9    error::{QuantRS2Error, QuantRS2Result},
10    gate::{multi::*, single::*, GateOp},
11    matrix_ops::{DenseMatrix, QuantumMatrix},
12    qubit::QubitId,
13    shannon::ShannonDecomposer,
14    synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use rustc_hash::FxHashMap;
17use scirs2_core::ndarray::{s, Array2};
18use scirs2_core::Complex;
19
20/// Result of multi-qubit KAK decomposition
21#[derive(Debug, Clone)]
22pub struct MultiQubitKAK {
23    /// The decomposed gate sequence
24    pub gates: Vec<Box<dyn GateOp>>,
25    /// Decomposition tree structure
26    pub tree: DecompositionTree,
27    /// Total CNOT count
28    pub cnot_count: usize,
29    /// Total single-qubit gate count
30    pub single_qubit_count: usize,
31    /// Circuit depth
32    pub depth: usize,
33}
34
35/// Tree structure representing the hierarchical decomposition
36#[derive(Debug, Clone)]
37pub enum DecompositionTree {
38    /// Leaf node - single or two-qubit gate
39    Leaf {
40        qubits: Vec<QubitId>,
41        gate_type: LeafType,
42    },
43    /// Internal node - recursive decomposition
44    Node {
45        qubits: Vec<QubitId>,
46        method: DecompositionMethod,
47        children: Vec<Self>,
48    },
49}
50
51/// Type of leaf decomposition
52#[derive(Debug, Clone)]
53pub enum LeafType {
54    SingleQubit(SingleQubitDecomposition),
55    TwoQubit(CartanDecomposition),
56}
57
/// Strategy applied at one level of the recursive decomposition.
#[derive(Debug, Clone)]
pub enum DecompositionMethod {
    /// Cosine-Sine Decomposition; `pivot` is the number of leading qubits
    /// placed in the first partition.
    CSD { pivot: usize },
    /// Quantum Shannon Decomposition; `partition` is recorded alongside the
    /// node for reference.
    Shannon { partition: usize },
    /// Block-diagonal handling; `block_size` is the number of qubits per block.
    BlockDiagonal { block_size: usize },
    /// Direct Cartan decomposition, only meaningful for exactly two qubits.
    Cartan,
}
70
71/// Multi-qubit KAK decomposer
72pub struct MultiQubitKAKDecomposer {
73    /// Tolerance for numerical comparisons
74    tolerance: f64,
75    /// Maximum recursion depth
76    max_depth: usize,
77    /// Cache for decompositions
78    #[allow(dead_code)]
79    cache: FxHashMap<u64, MultiQubitKAK>,
80    /// Use optimized methods
81    use_optimization: bool,
82    /// Cartan decomposer for two-qubit blocks
83    cartan: CartanDecomposer,
84}
85
86impl MultiQubitKAKDecomposer {
87    /// Create a new multi-qubit KAK decomposer
88    pub fn new() -> Self {
89        Self {
90            tolerance: 1e-10,
91            max_depth: 20,
92            cache: FxHashMap::default(),
93            use_optimization: true,
94            cartan: CartanDecomposer::new(),
95        }
96    }
97
98    /// Create with custom tolerance
99    pub fn with_tolerance(tolerance: f64) -> Self {
100        Self {
101            tolerance,
102            max_depth: 20,
103            cache: FxHashMap::default(),
104            use_optimization: true,
105            cartan: CartanDecomposer::with_tolerance(tolerance),
106        }
107    }
108
109    /// Decompose an n-qubit unitary
110    pub fn decompose(
111        &mut self,
112        unitary: &Array2<Complex<f64>>,
113        qubit_ids: &[QubitId],
114    ) -> QuantRS2Result<MultiQubitKAK> {
115        let n = qubit_ids.len();
116        let size = 1 << n;
117
118        // Validate input
119        if unitary.shape() != [size, size] {
120            return Err(QuantRS2Error::InvalidInput(format!(
121                "Unitary size {} doesn't match {} qubits",
122                unitary.shape()[0],
123                n
124            )));
125        }
126
127        // Check unitarity
128        let mat = DenseMatrix::new(unitary.clone())?;
129        if !mat.is_unitary(self.tolerance)? {
130            return Err(QuantRS2Error::InvalidInput(
131                "Matrix is not unitary".to_string(),
132            ));
133        }
134
135        // Check cache
136        if let Some(cached) = self.check_cache(unitary) {
137            return Ok(cached.clone());
138        }
139
140        // Perform decomposition
141        let (tree, gates) = self.decompose_recursive(unitary, qubit_ids, 0)?;
142
143        // Count gates
144        let mut cnot_count = 0;
145        let mut single_qubit_count = 0;
146
147        for gate in &gates {
148            match gate.name() {
149                "CNOT" | "CZ" | "SWAP" => cnot_count += self.count_cnots(gate.name()),
150                _ => single_qubit_count += 1,
151            }
152        }
153
154        // Calculate actual circuit depth using critical path through DAG.
155        // For each gate in topological order:
156        //   depth[gate] = 1 + max(depth[prev_gate]) over all preceding gates sharing a qubit.
157        let depth = Self::calculate_circuit_depth(&gates);
158
159        let result = MultiQubitKAK {
160            gates,
161            tree,
162            cnot_count,
163            single_qubit_count,
164            depth,
165        };
166
167        // Cache result
168        self.cache_result(unitary, &result);
169
170        Ok(result)
171    }
172
173    /// Recursive decomposition algorithm
174    fn decompose_recursive(
175        &mut self,
176        unitary: &Array2<Complex<f64>>,
177        qubit_ids: &[QubitId],
178        depth: usize,
179    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
180        if depth > self.max_depth {
181            return Err(QuantRS2Error::InvalidInput(
182                "Maximum recursion depth exceeded".to_string(),
183            ));
184        }
185
186        let n = qubit_ids.len();
187
188        // Base cases
189        match n {
190            0 => {
191                let tree = DecompositionTree::Leaf {
192                    qubits: vec![],
193                    gate_type: LeafType::SingleQubit(SingleQubitDecomposition {
194                        global_phase: 0.0,
195                        theta1: 0.0,
196                        phi: 0.0,
197                        theta2: 0.0,
198                        basis: "ZYZ".to_string(),
199                    }),
200                };
201                Ok((tree, vec![]))
202            }
203            1 => {
204                let decomp = decompose_single_qubit_zyz(&unitary.view())?;
205                let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
206                let tree = DecompositionTree::Leaf {
207                    qubits: qubit_ids.to_vec(),
208                    gate_type: LeafType::SingleQubit(decomp),
209                };
210                Ok((tree, gates))
211            }
212            2 => {
213                let decomp = self.cartan.decompose(unitary)?;
214                let gates = self.cartan.to_gates(&decomp, qubit_ids)?;
215                let tree = DecompositionTree::Leaf {
216                    qubits: qubit_ids.to_vec(),
217                    gate_type: LeafType::TwoQubit(decomp),
218                };
219                Ok((tree, gates))
220            }
221            _ => {
222                // For n > 2, choose decomposition method
223                let method = self.choose_decomposition_method(unitary, n);
224
225                match method {
226                    DecompositionMethod::CSD { pivot } => {
227                        self.decompose_csd(unitary, qubit_ids, pivot, depth)
228                    }
229                    DecompositionMethod::Shannon { partition } => {
230                        self.decompose_shannon(unitary, qubit_ids, partition, depth)
231                    }
232                    DecompositionMethod::BlockDiagonal { block_size } => {
233                        self.decompose_block_diagonal(unitary, qubit_ids, block_size, depth)
234                    }
235                    DecompositionMethod::Cartan => unreachable!("Invalid method for n > 2"),
236                }
237            }
238        }
239    }
240
241    /// Choose optimal decomposition method based on matrix structure
242    fn choose_decomposition_method(
243        &self,
244        unitary: &Array2<Complex<f64>>,
245        n: usize,
246    ) -> DecompositionMethod {
247        if self.use_optimization {
248            // Analyze matrix structure to choose optimal method
249            if self.has_block_structure(unitary, n) {
250                DecompositionMethod::BlockDiagonal { block_size: n / 2 }
251            } else if n % 2 == 0 {
252                // Even number of qubits - use CSD at midpoint
253                DecompositionMethod::CSD { pivot: n / 2 }
254            } else {
255                // Odd number - use Shannon decomposition
256                DecompositionMethod::Shannon { partition: n / 2 }
257            }
258        } else {
259            // Default to CSD
260            DecompositionMethod::CSD { pivot: n / 2 }
261        }
262    }
263
264    /// Decompose using Cosine-Sine Decomposition
265    fn decompose_csd(
266        &mut self,
267        unitary: &Array2<Complex<f64>>,
268        qubit_ids: &[QubitId],
269        pivot: usize,
270        depth: usize,
271    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
272        let n = qubit_ids.len();
273        // let _size = 1 << n;
274        let pivot_size = 1 << pivot;
275
276        // Split unitary into blocks based on pivot
277        // U = [A B]
278        //     [C D]
279        let a = unitary.slice(s![..pivot_size, ..pivot_size]).to_owned();
280        let b = unitary.slice(s![..pivot_size, pivot_size..]).to_owned();
281        let c = unitary.slice(s![pivot_size.., ..pivot_size]).to_owned();
282        let d = unitary.slice(s![pivot_size.., pivot_size..]).to_owned();
283
284        // Apply CSD to find:
285        // U = (U1 ⊗ V1) · Σ · (U2 ⊗ V2)
286        // where Σ is diagonal in the CSD basis
287
288        // This is a simplified version - full CSD would use SVD
289        let (u1, v1, sigma, u2, v2) = self.compute_csd(&a, &b, &c, &d)?;
290
291        let mut gates = Vec::new();
292        let mut children = Vec::new();
293
294        // Decompose U2 and V2 (right multiplications)
295        let left_qubits = &qubit_ids[..pivot];
296        let right_qubits = &qubit_ids[pivot..];
297
298        let (u2_tree, u2_gates) = self.decompose_recursive(&u2, left_qubits, depth + 1)?;
299        let (v2_tree, v2_gates) = self.decompose_recursive(&v2, right_qubits, depth + 1)?;
300
301        gates.extend(u2_gates);
302        gates.extend(v2_gates);
303        children.push(u2_tree);
304        children.push(v2_tree);
305
306        // Apply diagonal gates (controlled rotations)
307        let diag_gates = self.diagonal_to_gates(&sigma, qubit_ids)?;
308        gates.extend(diag_gates);
309
310        // Decompose U1 and V1 (left multiplications)
311        let (u1_tree, u1_gates) = self.decompose_recursive(&u1, left_qubits, depth + 1)?;
312        let (v1_tree, v1_gates) = self.decompose_recursive(&v1, right_qubits, depth + 1)?;
313
314        gates.extend(u1_gates);
315        gates.extend(v1_gates);
316        children.push(u1_tree);
317        children.push(v1_tree);
318
319        let tree = DecompositionTree::Node {
320            qubits: qubit_ids.to_vec(),
321            method: DecompositionMethod::CSD { pivot },
322            children,
323        };
324
325        Ok((tree, gates))
326    }
327
328    /// Decompose using Shannon decomposition
329    fn decompose_shannon(
330        &self,
331        unitary: &Array2<Complex<f64>>,
332        qubit_ids: &[QubitId],
333        partition: usize,
334        _depth: usize,
335    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
336        // Use the Shannon decomposer for this
337        let mut shannon = ShannonDecomposer::new();
338        let decomp = shannon.decompose(unitary, qubit_ids)?;
339
340        // Build tree structure
341        let tree = DecompositionTree::Node {
342            qubits: qubit_ids.to_vec(),
343            method: DecompositionMethod::Shannon { partition },
344            children: vec![], // Shannon decomposer doesn't provide tree structure
345        };
346
347        Ok((tree, decomp.gates))
348    }
349
350    /// Decompose block diagonal matrix
351    fn decompose_block_diagonal(
352        &mut self,
353        unitary: &Array2<Complex<f64>>,
354        qubit_ids: &[QubitId],
355        block_size: usize,
356        depth: usize,
357    ) -> QuantRS2Result<(DecompositionTree, Vec<Box<dyn GateOp>>)> {
358        let n = qubit_ids.len();
359        let num_blocks = n / block_size;
360
361        let mut gates = Vec::new();
362        let mut children = Vec::new();
363
364        // Decompose each block independently
365        for i in 0..num_blocks {
366            let start = i * block_size;
367            let end = (i + 1) * block_size;
368            let block_qubits = &qubit_ids[start..end];
369
370            // Extract block from unitary
371            let block = self.extract_block(unitary, i, block_size)?;
372
373            let (block_tree, block_gates) =
374                self.decompose_recursive(&block, block_qubits, depth + 1)?;
375            gates.extend(block_gates);
376            children.push(block_tree);
377        }
378
379        let tree = DecompositionTree::Node {
380            qubits: qubit_ids.to_vec(),
381            method: DecompositionMethod::BlockDiagonal { block_size },
382            children,
383        };
384
385        Ok((tree, gates))
386    }
387
388    /// Compute Cosine-Sine Decomposition
389    fn compute_csd(
390        &self,
391        a: &Array2<Complex<f64>>,
392        b: &Array2<Complex<f64>>,
393        c: &Array2<Complex<f64>>,
394        d: &Array2<Complex<f64>>,
395    ) -> QuantRS2Result<(
396        Array2<Complex<f64>>, // U1
397        Array2<Complex<f64>>, // V1
398        Array2<Complex<f64>>, // Sigma
399        Array2<Complex<f64>>, // U2
400        Array2<Complex<f64>>, // V2
401    )> {
402        // This is a simplified placeholder
403        // Full CSD implementation would use specialized algorithms
404
405        let size = a.shape()[0];
406        let identity = Array2::eye(size);
407        let _zero: Array2<Complex<f64>> = Array2::zeros((size, size));
408
409        // For now, return identity transformations
410        let u1 = identity.clone();
411        let v1 = identity.clone();
412        let u2 = identity.clone();
413        let v2 = identity;
414
415        // Sigma would contain the CS angles
416        let mut sigma = Array2::zeros((size * 2, size * 2));
417        sigma.slice_mut(s![..size, ..size]).assign(a);
418        sigma.slice_mut(s![..size, size..]).assign(b);
419        sigma.slice_mut(s![size.., ..size]).assign(c);
420        sigma.slice_mut(s![size.., size..]).assign(d);
421
422        Ok((u1, v1, sigma, u2, v2))
423    }
424
425    /// Convert diagonal matrix to controlled rotation gates
426    fn diagonal_to_gates(
427        &self,
428        diagonal: &Array2<Complex<f64>>,
429        qubit_ids: &[QubitId],
430    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
431        let mut gates = Vec::new();
432
433        // Extract diagonal elements
434        let n = diagonal.shape()[0];
435        for i in 0..n {
436            let phase = diagonal[[i, i]].arg();
437            if phase.abs() > self.tolerance {
438                // Determine which qubits are in state |1⟩ for this diagonal element
439                let mut control_pattern = Vec::new();
440                let mut temp = i;
441                for j in 0..qubit_ids.len() {
442                    if temp & 1 == 1 {
443                        control_pattern.push(j);
444                    }
445                    temp >>= 1;
446                }
447
448                // Create multi-controlled phase gate
449                if control_pattern.is_empty() {
450                    // Global phase - can be ignored
451                } else if control_pattern.len() == 1 {
452                    // Single-qubit phase
453                    gates.push(Box::new(RotationZ {
454                        target: qubit_ids[control_pattern[0]],
455                        theta: phase,
456                    }) as Box<dyn GateOp>);
457                } else {
458                    // Multi-controlled phase - decompose further
459                    // For now, use simple decomposition
460                    // Note: control_pattern.len() >= 2 at this point, so pop is safe
461                    let target_idx = control_pattern.pop().unwrap_or(0);
462                    let target = qubit_ids[target_idx];
463                    for &control_idx in &control_pattern {
464                        gates.push(Box::new(CNOT {
465                            control: qubit_ids[control_idx],
466                            target,
467                        }));
468                    }
469
470                    gates.push(Box::new(RotationZ {
471                        target,
472                        theta: phase,
473                    }) as Box<dyn GateOp>);
474
475                    // Uncompute CNOTs
476                    for &control_idx in control_pattern.iter().rev() {
477                        gates.push(Box::new(CNOT {
478                            control: qubit_ids[control_idx],
479                            target,
480                        }));
481                    }
482                }
483            }
484        }
485
486        Ok(gates)
487    }
488
489    /// Check if matrix has block diagonal structure
490    fn has_block_structure(&self, unitary: &Array2<Complex<f64>>, _n: usize) -> bool {
491        // Simple check - look for zeros in off-diagonal blocks
492        let size = unitary.shape()[0];
493        let block_size = size / 2;
494
495        let mut off_diagonal_norm = 0.0;
496
497        // Check upper-right block
498        for i in 0..block_size {
499            for j in block_size..size {
500                off_diagonal_norm += unitary[[i, j]].norm_sqr();
501            }
502        }
503
504        // Check lower-left block
505        for i in block_size..size {
506            for j in 0..block_size {
507                off_diagonal_norm += unitary[[i, j]].norm_sqr();
508            }
509        }
510
511        off_diagonal_norm.sqrt() < self.tolerance
512    }
513
514    /// Extract a block from block-diagonal matrix
515    fn extract_block(
516        &self,
517        unitary: &Array2<Complex<f64>>,
518        block_idx: usize,
519        block_size: usize,
520    ) -> QuantRS2Result<Array2<Complex<f64>>> {
521        let size = 1 << block_size;
522        let start = block_idx * size;
523        let end = (block_idx + 1) * size;
524
525        Ok(unitary.slice(s![start..end, start..end]).to_owned())
526    }
527
528    /// Convert single-qubit decomposition to gates
529    fn single_qubit_to_gates(
530        &self,
531        decomp: &SingleQubitDecomposition,
532        qubit: QubitId,
533    ) -> Vec<Box<dyn GateOp>> {
534        let mut gates = Vec::new();
535
536        if decomp.theta1.abs() > self.tolerance {
537            gates.push(Box::new(RotationZ {
538                target: qubit,
539                theta: decomp.theta1,
540            }) as Box<dyn GateOp>);
541        }
542
543        if decomp.phi.abs() > self.tolerance {
544            gates.push(Box::new(RotationY {
545                target: qubit,
546                theta: decomp.phi,
547            }) as Box<dyn GateOp>);
548        }
549
550        if decomp.theta2.abs() > self.tolerance {
551            gates.push(Box::new(RotationZ {
552                target: qubit,
553                theta: decomp.theta2,
554            }) as Box<dyn GateOp>);
555        }
556
557        gates
558    }
559
560    /// Count CNOTs for different gate types
561    fn count_cnots(&self, gate_name: &str) -> usize {
562        match gate_name {
563            "CNOT" | "CZ" => 1, // CZ = H·CNOT·H
564            "SWAP" => 3,        // SWAP uses 3 CNOTs
565            _ => 0,
566        }
567    }
568
569    /// Check cache for existing decomposition
570    /// Calculate circuit depth as the length of the critical path through the DAG.
571    ///
572    /// For each gate in topological order (gates are already ordered):
573    ///   `depth[i] = 1 + max(depth[j])` for all j < i that share at least one qubit with gate i.
574    ///
575    /// Uses a BFS/forward-pass approach since gates are given in topological order.
576    fn calculate_circuit_depth(gates: &[Box<dyn GateOp>]) -> usize {
577        if gates.is_empty() {
578            return 0;
579        }
580
581        // depth_at[i] = the depth level at which gate i completes (1-based)
582        let mut depth_at: Vec<usize> = vec![0; gates.len()];
583        // last_qubit_depth maps qubit id -> (gate_index, depth) of the last gate on that qubit
584        let mut last_qubit_finish: FxHashMap<u32, usize> = FxHashMap::default();
585
586        for (i, gate) in gates.iter().enumerate() {
587            let qubits = gate.qubits();
588            // Find the maximum finish depth among all preceding gates on shared qubits
589            let predecessor_max_depth = qubits
590                .iter()
591                .filter_map(|q| last_qubit_finish.get(&q.0).copied())
592                .max()
593                .unwrap_or(0);
594
595            depth_at[i] = predecessor_max_depth + 1;
596
597            // Update last finish depth for each qubit this gate touches
598            for q in &qubits {
599                last_qubit_finish.insert(q.0, depth_at[i]);
600            }
601        }
602
603        depth_at.into_iter().max().unwrap_or(0)
604    }
605
606    const fn check_cache(&self, _unitary: &Array2<Complex<f64>>) -> Option<&MultiQubitKAK> {
607        // Simple hash based on first few elements
608        // Real implementation would use better hashing
609        None
610    }
611
612    /// Cache decomposition result
613    const fn cache_result(&self, _unitary: &Array2<Complex<f64>>, _result: &MultiQubitKAK) {
614        // Cache implementation
615    }
616}
617
618impl Default for MultiQubitKAKDecomposer {
619    fn default() -> Self {
620        Self::new()
621    }
622}
623
624/// Analyze decomposition tree structure
625pub struct KAKTreeAnalyzer {
626    /// Track statistics
627    stats: DecompositionStats,
628}
629
630#[derive(Debug, Default, Clone)]
631pub struct DecompositionStats {
632    pub total_nodes: usize,
633    pub leaf_nodes: usize,
634    pub max_depth: usize,
635    pub method_counts: FxHashMap<String, usize>,
636    pub cnot_distribution: FxHashMap<usize, usize>,
637}
638
639impl KAKTreeAnalyzer {
640    /// Create new analyzer
641    pub fn new() -> Self {
642        Self {
643            stats: DecompositionStats::default(),
644        }
645    }
646
647    /// Analyze decomposition tree
648    pub fn analyze(&mut self, tree: &DecompositionTree) -> DecompositionStats {
649        self.stats = DecompositionStats::default();
650        self.analyze_recursive(tree, 0);
651        self.stats.clone()
652    }
653
654    fn analyze_recursive(&mut self, tree: &DecompositionTree, depth: usize) {
655        self.stats.total_nodes += 1;
656        self.stats.max_depth = self.stats.max_depth.max(depth);
657
658        match tree {
659            DecompositionTree::Leaf {
660                qubits: _qubits,
661                gate_type,
662            } => {
663                self.stats.leaf_nodes += 1;
664
665                match gate_type {
666                    LeafType::SingleQubit(_) => {
667                        *self
668                            .stats
669                            .method_counts
670                            .entry("single_qubit".to_string())
671                            .or_insert(0) += 1;
672                    }
673                    LeafType::TwoQubit(cartan) => {
674                        *self
675                            .stats
676                            .method_counts
677                            .entry("two_qubit".to_string())
678                            .or_insert(0) += 1;
679                        let cnots = cartan.interaction.cnot_count(1e-10);
680                        *self.stats.cnot_distribution.entry(cnots).or_insert(0) += 1;
681                    }
682                }
683            }
684            DecompositionTree::Node {
685                method, children, ..
686            } => {
687                let method_name = match method {
688                    DecompositionMethod::CSD { .. } => "csd",
689                    DecompositionMethod::Shannon { .. } => "shannon",
690                    DecompositionMethod::BlockDiagonal { .. } => "block_diagonal",
691                    DecompositionMethod::Cartan => "cartan",
692                };
693                *self
694                    .stats
695                    .method_counts
696                    .entry(method_name.to_string())
697                    .or_insert(0) += 1;
698
699                for child in children {
700                    self.analyze_recursive(child, depth + 1);
701                }
702            }
703        }
704    }
705}
706
707/// Utility function for quick multi-qubit KAK decomposition
708pub fn kak_decompose_multiqubit(
709    unitary: &Array2<Complex<f64>>,
710    qubit_ids: &[QubitId],
711) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
712    let mut decomposer = MultiQubitKAKDecomposer::new();
713    let decomp = decomposer.decompose(unitary, qubit_ids)?;
714    Ok(decomp.gates)
715}
716
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::Complex;

    /// Complex number with the given real part and zero imaginary part.
    fn real(re: f64) -> Complex<f64> {
        Complex::new(re, 0.0)
    }

    /// ZYZ decomposition with all angles and the global phase set to zero.
    fn zero_zyz() -> SingleQubitDecomposition {
        SingleQubitDecomposition {
            global_phase: 0.0,
            theta1: 0.0,
            phi: 0.0,
            theta2: 0.0,
            basis: "ZYZ".to_string(),
        }
    }

    /// A Hadamard gate decomposes into a single-qubit leaf without CNOTs.
    #[test]
    fn test_multiqubit_kak_single() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // Hadamard matrix
        let entries: Vec<Complex<f64>> = [1.0, 1.0, 1.0, -1.0].iter().copied().map(real).collect();
        let h = Array2::from_shape_vec((2, 2), entries)
            .expect("Failed to create Hadamard matrix")
            / real(2.0_f64.sqrt());

        let qubit_ids = vec![QubitId(0)];
        let decomp = decomposer
            .decompose(&h, &qubit_ids)
            .expect("Single-qubit KAK decomposition failed");

        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);

        // Check tree structure
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::SingleQubit(_),
                    ..
                }
            ),
            "Expected single-qubit leaf"
        );
    }

    /// A CNOT matrix decomposes into a two-qubit leaf with at most one CNOT.
    #[test]
    fn test_multiqubit_kak_two() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // CNOT matrix, row by row
        let rows = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ];
        let entries: Vec<Complex<f64>> = rows.iter().flatten().copied().map(real).collect();
        let cnot = Array2::from_shape_vec((4, 4), entries).expect("Failed to create CNOT matrix");

        let qubit_ids = vec![QubitId(0), QubitId(1)];
        let decomp = decomposer
            .decompose(&cnot, &qubit_ids)
            .expect("Two-qubit KAK decomposition failed");

        assert!(decomp.cnot_count <= 1);

        // Check tree structure
        assert!(
            matches!(
                &decomp.tree,
                DecompositionTree::Leaf {
                    gate_type: LeafType::TwoQubit(_),
                    ..
                }
            ),
            "Expected two-qubit leaf"
        );
    }

    /// The three-qubit identity decomposes into an empty circuit.
    #[test]
    fn test_multiqubit_kak_three() {
        let mut decomposer = MultiQubitKAKDecomposer::new();

        // 3-qubit identity
        let identity_complex = Array2::eye(8).mapv(|x| Complex::new(x, 0.0));

        let qubit_ids = vec![QubitId(0), QubitId(1), QubitId(2)];
        let decomp = decomposer
            .decompose(&identity_complex, &qubit_ids)
            .expect("Three-qubit KAK decomposition failed");

        // Identity should result in empty circuit
        assert_eq!(decomp.gates.len(), 0);
        assert_eq!(decomp.cnot_count, 0);
        assert_eq!(decomp.single_qubit_count, 0);
    }

    /// The analyzer counts nodes, leaves, depth and methods of a small tree.
    #[test]
    fn test_tree_analyzer() {
        let mut analyzer = KAKTreeAnalyzer::new();

        // A CSD node with one two-qubit leaf and one single-qubit leaf.
        let two_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(0), QubitId(1)],
            gate_type: LeafType::TwoQubit(CartanDecomposition {
                left_gates: (zero_zyz(), zero_zyz()),
                right_gates: (zero_zyz(), zero_zyz()),
                interaction: crate::prelude::CartanCoefficients::new(0.0, 0.0, 0.0),
                global_phase: 0.0,
            }),
        };
        let single_qubit_leaf = DecompositionTree::Leaf {
            qubits: vec![QubitId(2)],
            gate_type: LeafType::SingleQubit(zero_zyz()),
        };
        let tree = DecompositionTree::Node {
            qubits: vec![QubitId(0), QubitId(1), QubitId(2)],
            method: DecompositionMethod::CSD { pivot: 2 },
            children: vec![two_qubit_leaf, single_qubit_leaf],
        };

        let stats = analyzer.analyze(&tree);

        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.leaf_nodes, 2);
        assert_eq!(stats.max_depth, 1);
        assert_eq!(stats.method_counts.get("csd"), Some(&1));
    }

    /// Block structure is detected for a diagonal matrix and lost once an
    /// off-diagonal-block entry is set.
    #[test]
    fn test_block_structure_detection() {
        let decomposer = MultiQubitKAKDecomposer::new();

        // Create block diagonal matrix
        let mut block_diag = Array2::zeros((4, 4));
        for k in 0..4 {
            block_diag[[k, k]] = real(1.0);
        }

        assert!(decomposer.has_block_structure(&block_diag, 2));

        // Non-block diagonal
        block_diag[[0, 2]] = real(1.0);
        assert!(!decomposer.has_block_structure(&block_diag, 2));
    }
}