quantrs2_core/
synthesis.rs

//! Gate synthesis from unitary matrices.
//!
//! This module provides algorithms to decompose unitary matrices into
//! sequences of quantum gates:
//! - Single-qubit Euler-angle decomposition (ZYZ and XYX bases)
//! - Two-qubit unitary decomposition (KAK/Cartan)
//! - General n-qubit synthesis via Cosine-Sine decomposition (planned; not yet
//!   implemented, so three or more qubits are currently rejected)
8
9use crate::cartan::{CartanDecomposer, CartanDecomposition};
10// use crate::controlled::{make_controlled, ControlledGate};
11use crate::error::{QuantRS2Error, QuantRS2Result};
12use crate::gate::{single::*, GateOp};
13use crate::matrix_ops::{matrices_approx_equal, DenseMatrix, QuantumMatrix};
14use crate::qubit::QubitId;
15use scirs2_core::ndarray::{Array2, ArrayView2};
16use scirs2_core::Complex64;
17use std::f64::consts::PI;
18
19/// Result of single-qubit decomposition
20#[derive(Debug, Clone)]
21pub struct SingleQubitDecomposition {
22    /// Global phase
23    pub global_phase: f64,
24    /// First rotation angle (Z or X depending on basis)
25    pub theta1: f64,
26    /// Middle rotation angle (Y)
27    pub phi: f64,
28    /// Last rotation angle (Z or X depending on basis)
29    pub theta2: f64,
30    /// The basis used (e.g., "ZYZ", "XYX")
31    pub basis: String,
32}
33
34/// Decompose a single-qubit unitary into ZYZ rotations
35pub fn decompose_single_qubit_zyz(
36    unitary: &ArrayView2<Complex64>,
37) -> QuantRS2Result<SingleQubitDecomposition> {
38    if unitary.shape() != &[2, 2] {
39        return Err(QuantRS2Error::InvalidInput(
40            "Single-qubit unitary must be 2x2".to_string(),
41        ));
42    }
43
44    // Check unitarity
45    let matrix = DenseMatrix::new(unitary.to_owned())?;
46    if !matrix.is_unitary(1e-10)? {
47        return Err(QuantRS2Error::InvalidInput(
48            "Matrix is not unitary".to_string(),
49        ));
50    }
51
52    // Extract matrix elements
53    let a = unitary[[0, 0]];
54    let b = unitary[[0, 1]];
55    let c = unitary[[1, 0]];
56    let d = unitary[[1, 1]];
57
58    // Calculate global phase from determinant
59    let det = a * d - b * c;
60    let global_phase = det.arg() / 2.0;
61
62    // Normalize by the determinant to make the matrix special unitary
63    let det_sqrt = det.sqrt();
64    let a = a / det_sqrt;
65    let b = b / det_sqrt;
66    let c = c / det_sqrt;
67    let d = d / det_sqrt;
68
69    // Decompose into ZYZ angles
70    // U = e^(i*global_phase) * Rz(theta2) * Ry(phi) * Rz(theta1)
71
72    let phi = 2.0 * a.norm().acos();
73
74    let (theta1, theta2) = if phi.abs() < 1e-10 {
75        // Identity or phase gate
76        let phase = if a.norm() > 0.5 {
77            a.arg() * 2.0
78        } else {
79            d.arg() * 2.0
80        };
81        (0.0, phase)
82    } else if (phi - PI).abs() < 1e-10 {
83        // Pi rotation
84        (-b.arg() + c.arg() + PI, 0.0)
85    } else {
86        let theta1 = c.arg() - b.arg();
87        let theta2 = a.arg() + b.arg();
88        (theta1, theta2)
89    };
90
91    Ok(SingleQubitDecomposition {
92        global_phase,
93        theta1,
94        phi,
95        theta2,
96        basis: "ZYZ".to_string(),
97    })
98}
99
100/// Decompose a single-qubit unitary into XYX rotations
101pub fn decompose_single_qubit_xyx(
102    unitary: &ArrayView2<Complex64>,
103) -> QuantRS2Result<SingleQubitDecomposition> {
104    // Convert to Pauli basis and use ZYZ decomposition
105    // Safety: 2x2 shape with 4 elements is guaranteed valid
106    let h_gate = Array2::from_shape_vec(
107        (2, 2),
108        vec![
109            Complex64::new(1.0, 0.0),
110            Complex64::new(1.0, 0.0),
111            Complex64::new(1.0, 0.0),
112            Complex64::new(-1.0, 0.0),
113        ],
114    )
115    .expect("2x2 Hadamard matrix shape is always valid")
116        / Complex64::new(2.0_f64.sqrt(), 0.0);
117
118    // Transform: U' = H * U * H
119    let u_transformed = h_gate.dot(unitary).dot(&h_gate);
120    let decomp = decompose_single_qubit_zyz(&u_transformed.view())?;
121
122    Ok(SingleQubitDecomposition {
123        global_phase: decomp.global_phase,
124        theta1: decomp.theta1,
125        phi: decomp.phi,
126        theta2: decomp.theta2,
127        basis: "XYX".to_string(),
128    })
129}
130
131/// Convert single-qubit decomposition to gate sequence
132pub fn single_qubit_gates(
133    decomp: &SingleQubitDecomposition,
134    qubit: QubitId,
135) -> Vec<Box<dyn GateOp>> {
136    let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
137
138    match decomp.basis.as_str() {
139        "ZYZ" => {
140            if decomp.theta1.abs() > 1e-10 {
141                gates.push(Box::new(RotationZ {
142                    target: qubit,
143                    theta: decomp.theta1,
144                }));
145            }
146            if decomp.phi.abs() > 1e-10 {
147                gates.push(Box::new(RotationY {
148                    target: qubit,
149                    theta: decomp.phi,
150                }));
151            }
152            if decomp.theta2.abs() > 1e-10 {
153                gates.push(Box::new(RotationZ {
154                    target: qubit,
155                    theta: decomp.theta2,
156                }));
157            }
158        }
159        "XYX" => {
160            if decomp.theta1.abs() > 1e-10 {
161                gates.push(Box::new(RotationX {
162                    target: qubit,
163                    theta: decomp.theta1,
164                }));
165            }
166            if decomp.phi.abs() > 1e-10 {
167                gates.push(Box::new(RotationY {
168                    target: qubit,
169                    theta: decomp.phi,
170                }));
171            }
172            if decomp.theta2.abs() > 1e-10 {
173                gates.push(Box::new(RotationX {
174                    target: qubit,
175                    theta: decomp.theta2,
176                }));
177            }
178        }
179        _ => {} // Unknown basis
180    }
181
182    gates
183}
184
185/// Result of two-qubit KAK decomposition (alias for CartanDecomposition)
186pub type KAKDecomposition = CartanDecomposition;
187
188/// Decompose a two-qubit unitary using KAK decomposition
189pub fn decompose_two_qubit_kak(
190    unitary: &ArrayView2<Complex64>,
191) -> QuantRS2Result<KAKDecomposition> {
192    // Use Cartan decomposer for KAK decomposition
193    let mut decomposer = CartanDecomposer::new();
194    let owned_unitary = unitary.to_owned();
195    decomposer.decompose(&owned_unitary)
196}
197
198/// Convert KAK decomposition to gate sequence
199pub fn kak_to_gates(
200    decomp: &KAKDecomposition,
201    qubit1: QubitId,
202    qubit2: QubitId,
203) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
204    // Use CartanDecomposer to convert to gates
205    let decomposer = CartanDecomposer::new();
206    let qubit_ids = vec![qubit1, qubit2];
207    decomposer.to_gates(decomp, &qubit_ids)
208}
209
210/// Synthesize an arbitrary unitary matrix into quantum gates
211pub fn synthesize_unitary(
212    unitary: &ArrayView2<Complex64>,
213    qubits: &[QubitId],
214) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
215    let n = unitary.nrows();
216
217    if n != unitary.ncols() {
218        return Err(QuantRS2Error::InvalidInput(
219            "Matrix must be square".to_string(),
220        ));
221    }
222
223    let num_qubits = (n as f64).log2() as usize;
224    if (1 << num_qubits) != n {
225        return Err(QuantRS2Error::InvalidInput(
226            "Matrix dimension must be a power of 2".to_string(),
227        ));
228    }
229
230    if qubits.len() != num_qubits {
231        return Err(QuantRS2Error::InvalidInput(format!(
232            "Need {} qubits, got {}",
233            num_qubits,
234            qubits.len()
235        )));
236    }
237
238    // Check unitarity
239    let matrix = DenseMatrix::new(unitary.to_owned())?;
240    if !matrix.is_unitary(1e-10)? {
241        return Err(QuantRS2Error::InvalidInput(
242            "Matrix is not unitary".to_string(),
243        ));
244    }
245
246    match num_qubits {
247        1 => {
248            let decomp = decompose_single_qubit_zyz(unitary)?;
249            Ok(single_qubit_gates(&decomp, qubits[0]))
250        }
251        2 => {
252            let decomp = decompose_two_qubit_kak(unitary)?;
253            kak_to_gates(&decomp, qubits[0], qubits[1])
254        }
255        _ => {
256            // For n-qubit gates, use recursive decomposition
257            // This is a placeholder - would implement Cosine-Sine decomposition
258            Err(QuantRS2Error::UnsupportedOperation(format!(
259                "Synthesis for {num_qubits}-qubit gates not yet implemented"
260            )))
261        }
262    }
263}
264
265/// Check if a unitary is close to a known gate
266pub fn identify_gate(unitary: &ArrayView2<Complex64>, tolerance: f64) -> Option<String> {
267    let n = unitary.nrows();
268
269    match n {
270        2 => {
271            // Check common single-qubit gates
272            let gates = vec![
273                ("I", Array2::eye(2)),
274                // Safety: All 2x2 shapes with 4 elements are guaranteed valid
275                (
276                    "X",
277                    Array2::from_shape_vec(
278                        (2, 2),
279                        vec![
280                            Complex64::new(0.0, 0.0),
281                            Complex64::new(1.0, 0.0),
282                            Complex64::new(1.0, 0.0),
283                            Complex64::new(0.0, 0.0),
284                        ],
285                    )
286                    .expect("2x2 X gate shape is always valid"),
287                ),
288                (
289                    "Y",
290                    Array2::from_shape_vec(
291                        (2, 2),
292                        vec![
293                            Complex64::new(0.0, 0.0),
294                            Complex64::new(0.0, -1.0),
295                            Complex64::new(0.0, 1.0),
296                            Complex64::new(0.0, 0.0),
297                        ],
298                    )
299                    .expect("2x2 Y gate shape is always valid"),
300                ),
301                (
302                    "Z",
303                    Array2::from_shape_vec(
304                        (2, 2),
305                        vec![
306                            Complex64::new(1.0, 0.0),
307                            Complex64::new(0.0, 0.0),
308                            Complex64::new(0.0, 0.0),
309                            Complex64::new(-1.0, 0.0),
310                        ],
311                    )
312                    .expect("2x2 Z gate shape is always valid"),
313                ),
314                (
315                    "H",
316                    Array2::from_shape_vec(
317                        (2, 2),
318                        vec![
319                            Complex64::new(1.0, 0.0),
320                            Complex64::new(1.0, 0.0),
321                            Complex64::new(1.0, 0.0),
322                            Complex64::new(-1.0, 0.0),
323                        ],
324                    )
325                    .expect("2x2 H gate shape is always valid")
326                        / Complex64::new(2.0_f64.sqrt(), 0.0),
327                ),
328            ];
329
330            for (name, gate) in gates {
331                if matrices_approx_equal(unitary, &gate.view(), tolerance) {
332                    return Some(name.to_string());
333                }
334            }
335        }
336        4 => {
337            // Check common two-qubit gates
338            let mut cnot = Array2::eye(4);
339            cnot[[2, 2]] = Complex64::new(0.0, 0.0);
340            cnot[[2, 3]] = Complex64::new(1.0, 0.0);
341            cnot[[3, 2]] = Complex64::new(1.0, 0.0);
342            cnot[[3, 3]] = Complex64::new(0.0, 0.0);
343
344            if matrices_approx_equal(unitary, &cnot.view(), tolerance) {
345                return Some("CNOT".to_string());
346            }
347        }
348        _ => {}
349    }
350
351    None
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 2x2 matrix from row-major entries; `what` is the panic message.
    fn matrix2(entries: [Complex64; 4], what: &str) -> Array2<Complex64> {
        Array2::from_shape_vec((2, 2), entries.to_vec()).expect(what)
    }

    /// Shorthand for a purely real complex number.
    fn real(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Rz(theta) = diag(e^(-i*theta/2), e^(i*theta/2)).
    fn rz_matrix(theta: f64, what: &str) -> Array2<Complex64> {
        let zero = real(0.0);
        matrix2(
            [
                Complex64::new(0.0, -theta / 2.0).exp(),
                zero,
                zero,
                Complex64::new(0.0, theta / 2.0).exp(),
            ],
            what,
        )
    }

    /// Ry(theta) = [[cos(theta/2), -sin(theta/2)], [sin(theta/2), cos(theta/2)]].
    fn ry_matrix(theta: f64, what: &str) -> Array2<Complex64> {
        let (sin, cos) = (theta / 2.0).sin_cos();
        matrix2([real(cos), real(-sin), real(sin), real(cos)], what)
    }

    #[test]
    #[ignore] // Re-enable once the ZYZ decomposition round-trips this matrix.
    fn test_single_qubit_decomposition() {
        // Hadamard gate
        let h = matrix2(
            [real(1.0), real(1.0), real(1.0), real(-1.0)],
            "Hadamard matrix shape is always valid 2x2",
        ) / Complex64::new(2.0_f64.sqrt(), 0.0);

        let decomp =
            decompose_single_qubit_zyz(&h.view()).expect("ZYZ decomposition should succeed");

        let rz1 = rz_matrix(decomp.theta1, "Rz1 matrix shape is always valid 2x2");
        let ry = ry_matrix(decomp.phi, "Ry matrix shape is always valid 2x2");
        let rz2 = rz_matrix(decomp.theta2, "Rz2 matrix shape is always valid 2x2");

        // e^(i*global_phase) * Rz(theta2) * Ry(phi) * Rz(theta1) must give back H.
        let phase = Complex64::new(0.0, decomp.global_phase).exp();
        let reconstructed = phase * rz2.dot(&ry).dot(&rz1);

        assert!(matrices_approx_equal(
            &h.view(),
            &reconstructed.view(),
            1e-10
        ));
    }

    #[test]
    fn test_gate_identification() {
        let x = matrix2(
            [real(0.0), real(1.0), real(1.0), real(0.0)],
            "X gate matrix shape is always valid 2x2",
        );

        assert_eq!(identify_gate(&x.view(), 1e-10), Some("X".to_string()));
    }
}