Skip to main content

quantrs2_core/
synthesis.rs

1//! Gate synthesis from unitary matrices
2//!
3//! This module provides algorithms to decompose arbitrary unitary matrices
4//! into sequences of quantum gates, including:
5//! - Single-qubit unitary decomposition (ZYZ, XYX, etc.)
6//! - Two-qubit unitary decomposition (KAK/Cartan)
7//! - General n-qubit synthesis using Cosine-Sine decomposition (planned; not yet implemented)
8
9use crate::cartan::{CartanDecomposer, CartanDecomposition};
10// use crate::controlled::{make_controlled, ControlledGate};
11use crate::error::{QuantRS2Error, QuantRS2Result};
12use crate::gate::{single::*, GateOp};
13use crate::matrix_ops::{matrices_approx_equal, DenseMatrix, QuantumMatrix};
14use crate::qubit::QubitId;
15use scirs2_core::ndarray::{Array2, ArrayView2};
16use scirs2_core::Complex64;
17use std::f64::consts::PI;
18
/// Euler-angle description of a single-qubit unitary.
///
/// The three angles are listed in the order the rotations are applied to the
/// state: `theta1` first, then `phi`, then `theta2`. Which axes the outer
/// rotations use is recorded in `basis`.
#[derive(Clone, Debug)]
pub struct SingleQubitDecomposition {
    /// Global phase factor angle (the unitary carries `e^(i * global_phase)`)
    pub global_phase: f64,
    /// Angle of the first outer rotation (about Z or X, depending on `basis`)
    pub theta1: f64,
    /// Angle of the middle rotation (about Y)
    pub phi: f64,
    /// Angle of the last outer rotation (about Z or X, depending on `basis`)
    pub theta2: f64,
    /// Name of the rotation basis, e.g. "ZYZ" or "XYX"
    pub basis: String,
}
33
34/// Decompose a single-qubit unitary into ZYZ rotations
35pub fn decompose_single_qubit_zyz(
36    unitary: &ArrayView2<Complex64>,
37) -> QuantRS2Result<SingleQubitDecomposition> {
38    if unitary.shape() != &[2, 2] {
39        return Err(QuantRS2Error::InvalidInput(
40            "Single-qubit unitary must be 2x2".to_string(),
41        ));
42    }
43
44    // Check unitarity
45    let matrix = DenseMatrix::new(unitary.to_owned())?;
46    if !matrix.is_unitary(1e-10)? {
47        return Err(QuantRS2Error::InvalidInput(
48            "Matrix is not unitary".to_string(),
49        ));
50    }
51
52    // Extract matrix elements
53    let a = unitary[[0, 0]];
54    let b = unitary[[0, 1]];
55    let c = unitary[[1, 0]];
56    let d = unitary[[1, 1]];
57
58    // Calculate global phase from determinant
59    let det = a * d - b * c;
60    let global_phase = det.arg() / 2.0;
61
62    // Normalize by the determinant to make the matrix special unitary
63    let det_sqrt = det.sqrt();
64    let a = a / det_sqrt;
65    let b = b / det_sqrt;
66    let c = c / det_sqrt;
67    let d = d / det_sqrt;
68
69    // Decompose into ZYZ angles
70    // U = e^(i*global_phase) * Rz(theta2) * Ry(phi) * Rz(theta1)
71
72    let phi = 2.0 * a.norm().acos();
73
74    let (theta1, theta2) = if phi.abs() < 1e-10 {
75        // Identity or phase gate
76        let phase = if a.norm() > 0.5 {
77            a.arg() * 2.0
78        } else {
79            d.arg() * 2.0
80        };
81        (0.0, phase)
82    } else if (phi - PI).abs() < 1e-10 {
83        // Pi rotation
84        (-b.arg() + c.arg() + PI, 0.0)
85    } else {
86        // From Rz(θ₂)·Ry(φ)·Rz(θ₁) decomposition of SU(2) matrix [[a,b],[c,d]]:
87        //   arg(a) = -(θ₁+θ₂)/2  and  arg(c) = (θ₂-θ₁)/2
88        // Solving: θ₁ = -arg(a) - arg(c),  θ₂ = arg(c) - arg(a)
89        let theta1 = -a.arg() - c.arg();
90        let theta2 = c.arg() - a.arg();
91        (theta1, theta2)
92    };
93
94    Ok(SingleQubitDecomposition {
95        global_phase,
96        theta1,
97        phi,
98        theta2,
99        basis: "ZYZ".to_string(),
100    })
101}
102
103/// Decompose a single-qubit unitary into XYX rotations
104pub fn decompose_single_qubit_xyx(
105    unitary: &ArrayView2<Complex64>,
106) -> QuantRS2Result<SingleQubitDecomposition> {
107    // Convert to Pauli basis and use ZYZ decomposition
108    // Safety: 2x2 shape with 4 elements is guaranteed valid
109    let h_gate = Array2::from_shape_vec(
110        (2, 2),
111        vec![
112            Complex64::new(1.0, 0.0),
113            Complex64::new(1.0, 0.0),
114            Complex64::new(1.0, 0.0),
115            Complex64::new(-1.0, 0.0),
116        ],
117    )
118    .expect("2x2 Hadamard matrix shape is always valid")
119        / Complex64::new(2.0_f64.sqrt(), 0.0);
120
121    // Transform: U' = H * U * H
122    let u_transformed = h_gate.dot(unitary).dot(&h_gate);
123    let decomp = decompose_single_qubit_zyz(&u_transformed.view())?;
124
125    Ok(SingleQubitDecomposition {
126        global_phase: decomp.global_phase,
127        theta1: decomp.theta1,
128        phi: decomp.phi,
129        theta2: decomp.theta2,
130        basis: "XYX".to_string(),
131    })
132}
133
134/// Convert single-qubit decomposition to gate sequence
135pub fn single_qubit_gates(
136    decomp: &SingleQubitDecomposition,
137    qubit: QubitId,
138) -> Vec<Box<dyn GateOp>> {
139    let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
140
141    match decomp.basis.as_str() {
142        "ZYZ" => {
143            if decomp.theta1.abs() > 1e-10 {
144                gates.push(Box::new(RotationZ {
145                    target: qubit,
146                    theta: decomp.theta1,
147                }));
148            }
149            if decomp.phi.abs() > 1e-10 {
150                gates.push(Box::new(RotationY {
151                    target: qubit,
152                    theta: decomp.phi,
153                }));
154            }
155            if decomp.theta2.abs() > 1e-10 {
156                gates.push(Box::new(RotationZ {
157                    target: qubit,
158                    theta: decomp.theta2,
159                }));
160            }
161        }
162        "XYX" => {
163            if decomp.theta1.abs() > 1e-10 {
164                gates.push(Box::new(RotationX {
165                    target: qubit,
166                    theta: decomp.theta1,
167                }));
168            }
169            if decomp.phi.abs() > 1e-10 {
170                gates.push(Box::new(RotationY {
171                    target: qubit,
172                    theta: decomp.phi,
173                }));
174            }
175            if decomp.theta2.abs() > 1e-10 {
176                gates.push(Box::new(RotationX {
177                    target: qubit,
178                    theta: decomp.theta2,
179                }));
180            }
181        }
182        _ => {} // Unknown basis
183    }
184
185    gates
186}
187
188/// Result of two-qubit KAK decomposition (alias for CartanDecomposition)
189pub type KAKDecomposition = CartanDecomposition;
190
191/// Decompose a two-qubit unitary using KAK decomposition
192pub fn decompose_two_qubit_kak(
193    unitary: &ArrayView2<Complex64>,
194) -> QuantRS2Result<KAKDecomposition> {
195    // Use Cartan decomposer for KAK decomposition
196    let mut decomposer = CartanDecomposer::new();
197    let owned_unitary = unitary.to_owned();
198    decomposer.decompose(&owned_unitary)
199}
200
201/// Convert KAK decomposition to gate sequence
202pub fn kak_to_gates(
203    decomp: &KAKDecomposition,
204    qubit1: QubitId,
205    qubit2: QubitId,
206) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
207    // Use CartanDecomposer to convert to gates
208    let decomposer = CartanDecomposer::new();
209    let qubit_ids = vec![qubit1, qubit2];
210    decomposer.to_gates(decomp, &qubit_ids)
211}
212
213/// Synthesize an arbitrary unitary matrix into quantum gates
214pub fn synthesize_unitary(
215    unitary: &ArrayView2<Complex64>,
216    qubits: &[QubitId],
217) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
218    let n = unitary.nrows();
219
220    if n != unitary.ncols() {
221        return Err(QuantRS2Error::InvalidInput(
222            "Matrix must be square".to_string(),
223        ));
224    }
225
226    let num_qubits = (n as f64).log2() as usize;
227    if (1 << num_qubits) != n {
228        return Err(QuantRS2Error::InvalidInput(
229            "Matrix dimension must be a power of 2".to_string(),
230        ));
231    }
232
233    if qubits.len() != num_qubits {
234        return Err(QuantRS2Error::InvalidInput(format!(
235            "Need {} qubits, got {}",
236            num_qubits,
237            qubits.len()
238        )));
239    }
240
241    // Check unitarity
242    let matrix = DenseMatrix::new(unitary.to_owned())?;
243    if !matrix.is_unitary(1e-10)? {
244        return Err(QuantRS2Error::InvalidInput(
245            "Matrix is not unitary".to_string(),
246        ));
247    }
248
249    match num_qubits {
250        1 => {
251            let decomp = decompose_single_qubit_zyz(unitary)?;
252            Ok(single_qubit_gates(&decomp, qubits[0]))
253        }
254        2 => {
255            let decomp = decompose_two_qubit_kak(unitary)?;
256            kak_to_gates(&decomp, qubits[0], qubits[1])
257        }
258        _ => {
259            // For n-qubit gates, use recursive decomposition
260            // This is a placeholder - would implement Cosine-Sine decomposition
261            Err(QuantRS2Error::UnsupportedOperation(format!(
262                "Synthesis for {num_qubits}-qubit gates not yet implemented"
263            )))
264        }
265    }
266}
267
268/// Check if a unitary is close to a known gate
269pub fn identify_gate(unitary: &ArrayView2<Complex64>, tolerance: f64) -> Option<String> {
270    let n = unitary.nrows();
271
272    match n {
273        2 => {
274            // Check common single-qubit gates
275            let gates = vec![
276                ("I", Array2::eye(2)),
277                // Safety: All 2x2 shapes with 4 elements are guaranteed valid
278                (
279                    "X",
280                    Array2::from_shape_vec(
281                        (2, 2),
282                        vec![
283                            Complex64::new(0.0, 0.0),
284                            Complex64::new(1.0, 0.0),
285                            Complex64::new(1.0, 0.0),
286                            Complex64::new(0.0, 0.0),
287                        ],
288                    )
289                    .expect("2x2 X gate shape is always valid"),
290                ),
291                (
292                    "Y",
293                    Array2::from_shape_vec(
294                        (2, 2),
295                        vec![
296                            Complex64::new(0.0, 0.0),
297                            Complex64::new(0.0, -1.0),
298                            Complex64::new(0.0, 1.0),
299                            Complex64::new(0.0, 0.0),
300                        ],
301                    )
302                    .expect("2x2 Y gate shape is always valid"),
303                ),
304                (
305                    "Z",
306                    Array2::from_shape_vec(
307                        (2, 2),
308                        vec![
309                            Complex64::new(1.0, 0.0),
310                            Complex64::new(0.0, 0.0),
311                            Complex64::new(0.0, 0.0),
312                            Complex64::new(-1.0, 0.0),
313                        ],
314                    )
315                    .expect("2x2 Z gate shape is always valid"),
316                ),
317                (
318                    "H",
319                    Array2::from_shape_vec(
320                        (2, 2),
321                        vec![
322                            Complex64::new(1.0, 0.0),
323                            Complex64::new(1.0, 0.0),
324                            Complex64::new(1.0, 0.0),
325                            Complex64::new(-1.0, 0.0),
326                        ],
327                    )
328                    .expect("2x2 H gate shape is always valid")
329                        / Complex64::new(2.0_f64.sqrt(), 0.0),
330                ),
331            ];
332
333            for (name, gate) in gates {
334                if matrices_approx_equal(unitary, &gate.view(), tolerance) {
335                    return Some(name.to_string());
336                }
337            }
338        }
339        4 => {
340            // Check common two-qubit gates
341            let mut cnot = Array2::eye(4);
342            cnot[[2, 2]] = Complex64::new(0.0, 0.0);
343            cnot[[2, 3]] = Complex64::new(1.0, 0.0);
344            cnot[[3, 2]] = Complex64::new(1.0, 0.0);
345            cnot[[3, 3]] = Complex64::new(0.0, 0.0);
346
347            if matrices_approx_equal(unitary, &cnot.view(), tolerance) {
348                return Some("CNOT".to_string());
349            }
350        }
351        _ => {}
352    }
353
354    None
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2x2 matrix from row-major entries; `why` is the panic message
    /// used should the (always valid) shape ever be rejected.
    fn mat2(entries: [Complex64; 4], why: &str) -> Array2<Complex64> {
        Array2::from_shape_vec((2, 2), entries.to_vec()).expect(why)
    }

    /// Rz(theta) = diag(e^(-i*theta/2), e^(i*theta/2)).
    fn rz(theta: f64, why: &str) -> Array2<Complex64> {
        let zero = Complex64::new(0.0, 0.0);
        mat2(
            [
                Complex64::new(0.0, -theta / 2.0).exp(),
                zero,
                zero,
                Complex64::new(0.0, theta / 2.0).exp(),
            ],
            why,
        )
    }

    /// Ry(phi) = [[cos(phi/2), -sin(phi/2)], [sin(phi/2), cos(phi/2)]].
    fn ry(phi: f64, why: &str) -> Array2<Complex64> {
        let half = phi / 2.0;
        mat2(
            [
                Complex64::new(half.cos(), 0.0),
                Complex64::new(-half.sin(), 0.0),
                Complex64::new(half.sin(), 0.0),
                Complex64::new(half.cos(), 0.0),
            ],
            why,
        )
    }

    #[test]
    fn test_single_qubit_decomposition() {
        // Hadamard gate as the unitary under test.
        let one = Complex64::new(1.0, 0.0);
        let hadamard = mat2(
            [one, one, one, Complex64::new(-1.0, 0.0)],
            "Hadamard matrix shape is always valid 2x2",
        ) / Complex64::new(2.0_f64.sqrt(), 0.0);

        let decomp = decompose_single_qubit_zyz(&hadamard.view())
            .expect("ZYZ decomposition should succeed");

        // Rebuild the three rotations from the returned angles.
        let first = rz(decomp.theta1, "Rz1 matrix shape is always valid 2x2");
        let middle = ry(decomp.phi, "Ry matrix shape is always valid 2x2");
        let last = rz(decomp.theta2, "Rz2 matrix shape is always valid 2x2");

        // e^(i*global_phase) * Rz(theta2) * Ry(phi) * Rz(theta1)
        let phase = Complex64::new(0.0, decomp.global_phase).exp();
        let reconstructed = phase * last.dot(&middle).dot(&first);

        // The product must give back the Hadamard gate.
        assert!(matrices_approx_equal(
            &hadamard.view(),
            &reconstructed.view(),
            1e-10
        ));
    }

    #[test]
    fn test_gate_identification() {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let pauli_x = mat2(
            [zero, one, one, zero],
            "X gate matrix shape is always valid 2x2",
        );

        assert_eq!(identify_gate(&pauli_x.view(), 1e-10), Some("X".to_string()));
    }
}