Skip to main content

prism_q/gates/
mod.rs

1//! Gate definitions and matrix representations.
2//!
3//! Gates are represented as an enum for fast dispatch without trait-object overhead
4//! in the simulation hot path. Matrix representations use stack-allocated arrays
5//! to avoid heap allocation during gate application.
6//!
7//! # Hot-path design notes
//! - `Gate` methods take `&self`; the enum is 16 bytes on 64-bit targets
//!   because large payloads (`Fused` matrices, batch data, ...) are boxed.
9//! - `matrix_2x2` returns `[[Complex64; 2]; 2]` on the stack.
10//! - Two-qubit gates (CX, CZ, SWAP) have dedicated application routines in
11//!   backends rather than materializing a 4×4 matrix.
12
13use num_complex::Complex64;
14use smallvec::SmallVec;
15use std::f64::consts::{FRAC_1_SQRT_2, PI};
16use std::fmt;
17
/// Squared-magnitude cutoff below which a matrix element counts as zero.
///
/// `preserves_sparsity()` compares entries' `norm_sqr()` against this value to
/// decide whether a 2×2 matrix is diagonal or anti-diagonal, i.e. whether the gate
/// has phase/permutation structure. Corresponds to a magnitude cutoff of 1e-12.
const NEAR_ZERO_NORM_SQ: f64 = 1.0e-24;

/// Magnitude tolerance for "effectively zero" / "effectively one" matrix elements.
///
/// `is_diagonal_1q()` uses it to test that a fused gate's off-diagonal entries
/// vanish; `controlled_phase()` uses it to recognise the `[[1,0],[0,e^{iθ}]]`
/// structure of a controlled-phase matrix.
const IDENTITY_EPS: f64 = 1.0e-12;
29
30/// Quantum gate identifier.
31///
32/// Covers the v0 supported gate set. Most variants are data-free or carry an `f64`
33/// parameter inline. The `Fused` variant uses `Box` to keep the enum at 16 bytes.
34#[derive(Debug, Clone, PartialEq)]
35pub enum Gate {
36    /// Identity.
37    Id,
38    /// Pauli-X (bit flip).
39    X,
40    /// Pauli-Y.
41    Y,
42    /// Pauli-Z (phase flip).
43    Z,
44    /// Hadamard.
45    H,
46    /// S gate (√Z).
47    S,
48    /// S† gate.
49    Sdg,
50    /// T gate (π/8).
51    T,
52    /// T† gate.
53    Tdg,
54    /// √X gate.
55    SX,
56    /// √X† gate.
57    SXdg,
58
59    /// Rotation about X-axis by angle (radians).
60    Rx(f64),
61    /// Rotation about Y-axis by angle (radians).
62    Ry(f64),
63    /// Rotation about Z-axis by angle (radians).
64    Rz(f64),
65    /// Phase gate `[[1,0],[0,e^{iθ}]]`.
66    P(f64),
67
68    /// ZZ rotation: diag(e^{-iθ/2}, e^{iθ/2}, e^{iθ/2}, e^{-iθ/2}).
69    /// Qubit order: [q0, q1] (symmetric).
70    Rzz(f64),
71
72    /// Controlled-X (CNOT). Qubit order: [control, target].
73    Cx,
74    /// Controlled-Z. Qubit order: [q0, q1] (symmetric).
75    Cz,
76    /// SWAP. Qubit order: [q0, q1] (symmetric).
77    Swap,
78
79    /// Controlled-unitary. Applies the boxed 2×2 matrix to the target qubit
80    /// only when the control qubit is |1⟩. Qubit order: [control, target].
81    /// Boxed to keep `Gate` at 16 bytes.
82    Cu(Box<[[Complex64; 2]; 2]>),
83
84    /// Multi-controlled unitary. Applies the 2×2 matrix to the target qubit
85    /// only when all control qubits are |1⟩. Qubit order:
86    /// `[ctrl_0, ctrl_1, ..., ctrl_{k-1}, target]`.
87    /// Boxed to keep `Gate` at 16 bytes.
88    Mcu(Box<McuData>),
89
90    /// Pre-fused single-qubit unitary (product of consecutive gates on the same target).
91    /// Boxed to keep `Gate` at 16 bytes for cache-friendly instruction streams.
92    Fused(Box<[[Complex64; 2]; 2]>),
93
94    /// Batched controlled-phase: multiple cphase gates sharing a control qubit,
95    /// fused into a single pass over the statevector. Created by the cphase
96    /// fusion pass. Targets: `[control]`. The `BatchPhaseData` holds per-target
97    /// phases. Boxed to keep `Gate` at 16 bytes.
98    BatchPhase(Box<BatchPhaseData>),
99
100    /// Batched ZZ rotations: multiple Rzz gates fused into a single pass.
101    /// Created by the batch-Rzz fusion pass. The `BatchRzzData` holds per-edge
102    /// angles. Boxed to keep `Gate` at 16 bytes.
103    BatchRzz(Box<BatchRzzData>),
104
105    /// Batched diagonal gates: a contiguous run of diagonal 1q and 2q gates
106    /// collapsed into a single state-vector sweep with a precomputed phase LUT.
107    /// Subsumes BatchPhase and BatchRzz for mixed diagonal runs. Created by the
108    /// diagonal batch fusion pass. Boxed to keep `Gate` at 16 bytes.
109    DiagonalBatch(Box<DiagonalBatchData>),
110
111    /// Multiple single-qubit gates on distinct qubits, batched for a single
112    /// tiled pass over the statevector. Created by the multi-gate fusion pass.
113    /// Boxed to keep `Gate` at 16 bytes.
114    MultiFused(Box<MultiFusedData>),
115
116    /// Pre-fused two-qubit unitary (4×4 matrix). Created by the 2q fusion pass
117    /// which absorbs adjacent single-qubit gates into a two-qubit gate.
118    /// Boxed to keep `Gate` at 16 bytes.
119    Fused2q(Box<[[Complex64; 4]; 4]>),
120
121    /// Multiple two-qubit gates batched for a single tiled pass over the
122    /// statevector. Created by the multi-2q fusion pass. Each entry stores
123    /// `(q0, q1, 4×4 matrix)`. Boxed to keep `Gate` at 16 bytes.
124    Multi2q(Box<Multi2qData>),
125
126    /// Quantum Fourier Transform on `start..start+num`.
127    ///
128    /// The CPU statevector backend has a fast whole-state FFT path. Subrange
129    /// blocks and non-native backends expand to textbook H, cphase, and swap
130    /// gates before execution.
131    /// Boxless: `(u8, u8)` fits within the 16-byte enum slot.
132    QftBlock { start: u8, num: u8 },
133}
134
/// Data for a multi-controlled unitary gate (`Gate::Mcu`).
#[derive(Debug, Clone, PartialEq)]
pub struct McuData {
    /// 2×2 unitary applied to the target qubit when every control qubit is |1⟩.
    pub mat: [[Complex64; 2]; 2],
    /// Number of control qubits (≥ 2 by convention; `Gate::mcu` does not check
    /// this). `Gate::num_qubits` reports `num_controls + 1`.
    pub num_controls: u8,
}

/// Data for a batched controlled-phase gate.
///
/// Multiple cphase gates sharing a control qubit are fused into one pass.
/// Each entry is `(target_qubit, phase)`, where `phase` is the complex phase
/// factor of the corresponding cphase (`Gate::inverse` conjugates it). The
/// control qubit is stored in the instruction's `targets[0]`.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchPhaseData {
    /// `(target_qubit, phase)` pairs; up to 8 are stored inline by `SmallVec`.
    pub phases: SmallVec<[(usize, Complex64); 8]>,
}

/// Data for batched ZZ rotations.
///
/// Multiple Rzz gates batched into a single pass over the statevector.
/// Each entry is `(qubit_0, qubit_1, theta)`, where `theta` is the per-edge
/// rotation angle (negated by `Gate::inverse`, as for `Gate::Rzz`). All qubits
/// are stored in the instruction's `targets`.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchRzzData {
    pub edges: Vec<(usize, usize, f64)>,
}
163
/// An individual diagonal phase contribution in a [`DiagonalBatchData`].
///
/// Each variant multiplies a statevector amplitude `state[i]` by a factor
/// chosen from the bits of at most two qubits in the basis index `i`.
#[derive(Debug, Clone, PartialEq)]
pub enum DiagEntry {
    /// Diagonal on a single qubit: `state[i] *= d0` when the qubit's bit is 0,
    /// `*= d1` when it is 1 (matrix `diag(d0, d1)`).
    Phase1q {
        qubit: usize,
        d0: Complex64,
        d1: Complex64,
    },
    /// Phase on a qubit pair: `state[i] *= phase` when both bits are set (CZ/CPhase);
    /// all other amplitudes are unchanged.
    Phase2q {
        q0: usize,
        q1: usize,
        phase: Complex64,
    },
    /// Parity-dependent phase (Rzz): `state[i] *= same` when the two bits have
    /// even parity (00 or 11), `state[i] *= diff` when parity is odd (01 or 10).
    Parity2q {
        q0: usize,
        q1: usize,
        same: Complex64,
        diff: Complex64,
    },
}
188
189impl DiagEntry {
190    pub fn as_1q_matrix(&self) -> Option<(usize, [[Complex64; 2]; 2])> {
191        match *self {
192            DiagEntry::Phase1q { qubit, d0, d1 } => {
193                let z = Complex64::new(0.0, 0.0);
194                Some((qubit, [[d0, z], [z, d1]]))
195            }
196            _ => None,
197        }
198    }
199
200    pub fn as_2q_matrix(&self) -> Option<(usize, usize, [[Complex64; 4]; 4])> {
201        let z = Complex64::new(0.0, 0.0);
202        let one = Complex64::new(1.0, 0.0);
203        match *self {
204            DiagEntry::Phase2q { q0, q1, phase } => Some((
205                q0,
206                q1,
207                [
208                    [one, z, z, z],
209                    [z, one, z, z],
210                    [z, z, one, z],
211                    [z, z, z, phase],
212                ],
213            )),
214            DiagEntry::Parity2q {
215                q0, q1, same, diff, ..
216            } => Some((
217                q0,
218                q1,
219                [
220                    [same, z, z, z],
221                    [z, diff, z, z],
222                    [z, z, diff, z],
223                    [z, z, z, same],
224                ],
225            )),
226            _ => None,
227        }
228    }
229}
230
/// Data for a batched diagonal gate pass.
///
/// A contiguous run of diagonal gates collapsed into a precomputed phase LUT.
/// The `entries` describe individual phase contributions; the kernel extracts
/// unique qubits, builds a LUT indexed by their bits, and applies in one sweep.
#[derive(Debug, Clone, PartialEq)]
pub struct DiagonalBatchData {
    /// Phase contributions. All are diagonal, so they commute and their order
    /// does not affect the mathematical result.
    pub entries: Vec<DiagEntry>,
}

/// Data for multi-gate single-pass fusion.
///
/// Batches consecutive single-qubit gates on distinct qubits into one tiled
/// pass over the statevector. Each entry is `(target_qubit, 2×2 matrix)`.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiFusedData {
    /// `(target_qubit, matrix)` pairs; targets are expected to be distinct
    /// (`Gate::num_qubits` reports `gates.len()`).
    pub gates: Vec<(usize, [[Complex64; 2]; 2])>,
    /// NOTE(review): presumably set by the fusion pass when every matrix in
    /// `gates` is diagonal; it is not validated here. `Gate::inverse` copies it
    /// unchanged. Confirm against the producer.
    pub all_diagonal: bool,
}

/// Data for multi-2q tiled pass fusion.
///
/// Batches consecutive two-qubit gates into a single cache-tiled pass over the
/// statevector. Each entry is `(q0, q1, 4×4 matrix)`. Gate order is preserved
/// (and reversed by `Gate::inverse`).
#[derive(Debug, Clone, PartialEq)]
pub struct Multi2qData {
    pub gates: Vec<(usize, usize, [[Complex64; 4]; 4])>,
}
259
260/// Kronecker product of two 2×2 matrices: A ⊗ B → 4×4.
261///
262/// Result indices: `(i*2+j, k*2+l) = A[i][k] * B[j][l]`
263/// where i,k index A (targets\[0\]) and j,l index B (targets\[1\]).
264#[inline]
265pub(crate) fn kron_2x2(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2]) -> [[Complex64; 4]; 4] {
266    let mut result = [[Complex64::new(0.0, 0.0); 4]; 4];
267    for i in 0..2 {
268        for k in 0..2 {
269            let aik = a[i][k];
270            for j in 0..2 {
271                for l in 0..2 {
272                    result[i * 2 + j][k * 2 + l] = aik * b[j][l];
273                }
274            }
275        }
276    }
277    result
278}
279
280/// Product of two 4×4 matrices: A · B.
281#[inline]
282pub(crate) fn mat_mul_4x4(a: &[[Complex64; 4]; 4], b: &[[Complex64; 4]; 4]) -> [[Complex64; 4]; 4] {
283    let zero = Complex64::new(0.0, 0.0);
284    let mut result = [[zero; 4]; 4];
285    for i in 0..4 {
286        for j in 0..4 {
287            let mut sum = zero;
288            for k in 0..4 {
289                sum += a[i][k] * b[k][j];
290            }
291            result[i][j] = sum;
292        }
293    }
294    result
295}
296
297/// Conjugate-transpose of a 4×4 matrix (U†).
298fn adjoint_4x4(m: &[[Complex64; 4]; 4]) -> [[Complex64; 4]; 4] {
299    let mut result = [[Complex64::new(0.0, 0.0); 4]; 4];
300    for i in 0..4 {
301        for j in 0..4 {
302            result[i][j] = m[j][i].conj();
303        }
304    }
305    result
306}
307
308/// Conjugate-transpose of a 2×2 matrix (U†).
309fn adjoint_2x2(m: &[[Complex64; 2]; 2]) -> [[Complex64; 2]; 2] {
310    [
311        [m[0][0].conj(), m[1][0].conj()],
312        [m[0][1].conj(), m[1][1].conj()],
313    ]
314}
315
316/// Product of two 2×2 matrices: A · B.
317#[inline]
318pub(crate) fn mat_mul_2x2(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2]) -> [[Complex64; 2]; 2] {
319    [
320        [
321            a[0][0] * b[0][0] + a[0][1] * b[1][0],
322            a[0][0] * b[0][1] + a[0][1] * b[1][1],
323        ],
324        [
325            a[1][0] * b[0][0] + a[1][1] * b[1][0],
326            a[1][0] * b[0][1] + a[1][1] * b[1][1],
327        ],
328    ]
329}
330
331impl Gate {
332    /// Number of qubits this gate acts on.
333    #[inline]
334    pub fn num_qubits(&self) -> usize {
335        match self {
336            Gate::Rzz(_) | Gate::Cx | Gate::Cz | Gate::Swap | Gate::Cu(_) | Gate::Fused2q(_) => 2,
337            Gate::Mcu(data) => data.num_controls as usize + 1,
338            Gate::BatchPhase(data) => 1 + data.phases.len(),
339            Gate::QftBlock { num, .. } => *num as usize,
340            Gate::BatchRzz(data) => {
341                let mut count = 0;
342                let mut seen = [false; 64];
343                for &(q0, q1, _) in &data.edges {
344                    if !seen[q0] {
345                        seen[q0] = true;
346                        count += 1;
347                    }
348                    if !seen[q1] {
349                        seen[q1] = true;
350                        count += 1;
351                    }
352                }
353                count
354            }
355            Gate::DiagonalBatch(data) => {
356                let mut count = 0;
357                let mut seen = [false; 64];
358                for e in &data.entries {
359                    let qs = match e {
360                        DiagEntry::Phase1q { qubit, .. } => [*qubit, usize::MAX],
361                        DiagEntry::Phase2q { q0, q1, .. } | DiagEntry::Parity2q { q0, q1, .. } => {
362                            [*q0, *q1]
363                        }
364                    };
365                    for &q in &qs {
366                        if q < 64 && !seen[q] {
367                            seen[q] = true;
368                            count += 1;
369                        }
370                    }
371                }
372                count
373            }
374            Gate::MultiFused(data) => data.gates.len(),
375            Gate::Multi2q(data) => {
376                let mut count = 0;
377                let mut seen = [false; 64];
378                for &(q0, q1, _) in &data.gates {
379                    if !seen[q0] {
380                        seen[q0] = true;
381                        count += 1;
382                    }
383                    if !seen[q1] {
384                        seen[q1] = true;
385                        count += 1;
386                    }
387                }
388                count
389            }
390            _ => 1,
391        }
392    }
393
394    /// Returns the 2×2 unitary matrix for single-qubit gates.
395    ///
396    /// # Panics
397    /// Panics if called on a multi-qubit or batch gate (`Cx`, `Cz`, `Swap`,
398    /// `Cu`, `Mcu`, `BatchPhase`, `MultiFused`, `Fused2q`, `Multi2q`).
399    #[inline]
400    pub fn matrix_2x2(&self) -> [[Complex64; 2]; 2] {
401        let zero = Complex64::new(0.0, 0.0);
402        let one = Complex64::new(1.0, 0.0);
403        let i = Complex64::new(0.0, 1.0);
404        let neg_i = Complex64::new(0.0, -1.0);
405        let h = Complex64::new(FRAC_1_SQRT_2, 0.0);
406
407        match self {
408            Gate::Id => [[one, zero], [zero, one]],
409            Gate::X => [[zero, one], [one, zero]],
410            Gate::Y => [[zero, neg_i], [i, zero]],
411            Gate::Z => [[one, zero], [zero, -one]],
412            Gate::H => [[h, h], [h, -h]],
413            Gate::S => [[one, zero], [zero, i]],
414            Gate::Sdg => [[one, zero], [zero, neg_i]],
415            Gate::T => {
416                let phase = Complex64::from_polar(1.0, PI / 4.0);
417                [[one, zero], [zero, phase]]
418            }
419            Gate::Tdg => {
420                let phase = Complex64::from_polar(1.0, -PI / 4.0);
421                [[one, zero], [zero, phase]]
422            }
423            Gate::SX => {
424                let half = Complex64::new(0.5, 0.0);
425                let half_i = Complex64::new(0.0, 0.5);
426                [
427                    [half + half_i, half - half_i],
428                    [half - half_i, half + half_i],
429                ]
430            }
431            Gate::SXdg => {
432                let half = Complex64::new(0.5, 0.0);
433                let half_i = Complex64::new(0.0, 0.5);
434                [
435                    [half - half_i, half + half_i],
436                    [half + half_i, half - half_i],
437                ]
438            }
439            Gate::Rx(theta) => {
440                let c = Complex64::new((theta / 2.0).cos(), 0.0);
441                let s = Complex64::new(0.0, -(theta / 2.0).sin());
442                [[c, s], [s, c]]
443            }
444            Gate::Ry(theta) => {
445                let c = Complex64::new((theta / 2.0).cos(), 0.0);
446                let s = Complex64::new((theta / 2.0).sin(), 0.0);
447                [[c, -s], [s, c]]
448            }
449            Gate::Rz(theta) => {
450                let e_neg = Complex64::from_polar(1.0, -theta / 2.0);
451                let e_pos = Complex64::from_polar(1.0, theta / 2.0);
452                [[e_neg, zero], [zero, e_pos]]
453            }
454            Gate::P(theta) => {
455                let phase = Complex64::from_polar(1.0, *theta);
456                [[one, zero], [zero, phase]]
457            }
458            Gate::Fused(mat) => **mat,
459            Gate::Rzz(_)
460            | Gate::Cx
461            | Gate::Cz
462            | Gate::Swap
463            | Gate::Cu(_)
464            | Gate::Mcu(_)
465            | Gate::BatchPhase(_)
466            | Gate::QftBlock { .. }
467            | Gate::BatchRzz(_)
468            | Gate::DiagonalBatch(_)
469            | Gate::MultiFused(_)
470            | Gate::Fused2q(_)
471            | Gate::Multi2q(_) => {
472                panic!(
473                    "matrix_2x2 called on {}-qubit gate `{}`; use dedicated backend routine",
474                    self.num_qubits(),
475                    self.name()
476                )
477            }
478        }
479    }
480
481    /// Returns the 4×4 unitary matrix for two-qubit gates.
482    ///
483    /// Matrix indices follow the convention: row/col `i*2+j` where `i` indexes
484    /// `targets[0]` and `j` indexes `targets[1]`.
485    ///
486    /// # Panics
487    /// Panics on gates other than `Cx`, `Cz`, `Swap`, `Cu`, or `Fused2q`.
488    pub fn matrix_4x4(&self) -> [[Complex64; 4]; 4] {
489        let z = Complex64::new(0.0, 0.0);
490        let o = Complex64::new(1.0, 0.0);
491        let m = Complex64::new(-1.0, 0.0);
492        match self {
493            Gate::Rzz(theta) => {
494                let ps = Complex64::from_polar(1.0, -theta / 2.0);
495                let pd = Complex64::from_polar(1.0, theta / 2.0);
496                [[ps, z, z, z], [z, pd, z, z], [z, z, pd, z], [z, z, z, ps]]
497            }
498            Gate::Cx => [[o, z, z, z], [z, o, z, z], [z, z, z, o], [z, z, o, z]],
499            Gate::Cz => [[o, z, z, z], [z, o, z, z], [z, z, o, z], [z, z, z, m]],
500            Gate::Swap => [[o, z, z, z], [z, z, o, z], [z, o, z, z], [z, z, z, o]],
501            Gate::Cu(mat) => [
502                [o, z, z, z],
503                [z, o, z, z],
504                [z, z, mat[0][0], mat[0][1]],
505                [z, z, mat[1][0], mat[1][1]],
506            ],
507            Gate::Fused2q(mat) => **mat,
508            _ => panic!(
509                "matrix_4x4 called on non-standard-2q gate `{}`",
510                self.name()
511            ),
512        }
513    }
514
515    /// Human-readable gate name (for errors, logs, and OpenQASM round-tripping).
516    #[inline]
517    pub fn name(&self) -> &'static str {
518        match self {
519            Gate::Id => "id",
520            Gate::X => "x",
521            Gate::Y => "y",
522            Gate::Z => "z",
523            Gate::H => "h",
524            Gate::S => "s",
525            Gate::Sdg => "sdg",
526            Gate::T => "t",
527            Gate::Tdg => "tdg",
528            Gate::SX => "sx",
529            Gate::SXdg => "sxdg",
530            Gate::Rx(_) => "rx",
531            Gate::Ry(_) => "ry",
532            Gate::Rz(_) => "rz",
533            Gate::P(_) => "p",
534            Gate::Rzz(_) => "rzz",
535            Gate::Cx => "cx",
536            Gate::Cz => "cz",
537            Gate::Swap => "swap",
538            Gate::Cu(_) => "cu",
539            Gate::Mcu(_) => "mcu",
540            Gate::Fused(_) => "fused",
541            Gate::BatchPhase(_) => "batch_phase",
542            Gate::QftBlock { .. } => "qft_block",
543            Gate::BatchRzz(_) => "batch_rzz",
544            Gate::DiagonalBatch(_) => "diagonal_batch",
545            Gate::MultiFused(_) => "multi_fused",
546            Gate::Fused2q(_) => "fused_2q",
547            Gate::Multi2q(_) => "multi_2q",
548        }
549    }
550
551    /// Compute the inverse (adjoint) of this gate.
552    pub fn inverse(&self) -> Gate {
553        match self {
554            Gate::Id | Gate::X | Gate::Y | Gate::Z | Gate::H => self.clone(),
555            Gate::S => Gate::Sdg,
556            Gate::Sdg => Gate::S,
557            Gate::T => Gate::Tdg,
558            Gate::Tdg => Gate::T,
559            Gate::SX => Gate::SXdg,
560            Gate::SXdg => Gate::SX,
561            Gate::Rx(theta) => Gate::Rx(-theta),
562            Gate::Ry(theta) => Gate::Ry(-theta),
563            Gate::Rz(theta) => Gate::Rz(-theta),
564            Gate::P(theta) => Gate::P(-theta),
565            Gate::Rzz(theta) => Gate::Rzz(-theta),
566            Gate::Cx | Gate::Cz | Gate::Swap => self.clone(),
567            Gate::Cu(mat) => Gate::cu(adjoint_2x2(mat)),
568            Gate::Mcu(data) => Gate::mcu(adjoint_2x2(&data.mat), data.num_controls),
569            Gate::Fused(mat) => Gate::Fused(Box::new(adjoint_2x2(mat))),
570            Gate::BatchPhase(data) => Gate::BatchPhase(Box::new(BatchPhaseData {
571                phases: data.phases.iter().map(|&(q, p)| (q, p.conj())).collect(),
572            })),
573            Gate::QftBlock { .. } => {
574                panic!(
575                    "Gate::QftBlock has no in-place inverse. Run \
576                     circuit::expand_qft_blocks before applying `inv @` or any \
577                     transform that calls Gate::inverse()."
578                )
579            }
580            Gate::BatchRzz(data) => Gate::BatchRzz(Box::new(BatchRzzData {
581                edges: data
582                    .edges
583                    .iter()
584                    .map(|&(q0, q1, theta)| (q0, q1, -theta))
585                    .collect(),
586            })),
587            Gate::DiagonalBatch(data) => Gate::DiagonalBatch(Box::new(DiagonalBatchData {
588                entries: data
589                    .entries
590                    .iter()
591                    .map(|e| match e {
592                        DiagEntry::Phase1q { qubit, d0, d1 } => DiagEntry::Phase1q {
593                            qubit: *qubit,
594                            d0: d0.conj(),
595                            d1: d1.conj(),
596                        },
597                        DiagEntry::Phase2q { q0, q1, phase } => DiagEntry::Phase2q {
598                            q0: *q0,
599                            q1: *q1,
600                            phase: phase.conj(),
601                        },
602                        DiagEntry::Parity2q { q0, q1, same, diff } => DiagEntry::Parity2q {
603                            q0: *q0,
604                            q1: *q1,
605                            same: same.conj(),
606                            diff: diff.conj(),
607                        },
608                    })
609                    .collect(),
610            })),
611            Gate::MultiFused(data) => Gate::MultiFused(Box::new(MultiFusedData {
612                gates: data
613                    .gates
614                    .iter()
615                    .map(|&(target, mat)| (target, adjoint_2x2(&mat)))
616                    .collect(),
617                all_diagonal: data.all_diagonal,
618            })),
619            Gate::Fused2q(mat) => Gate::Fused2q(Box::new(adjoint_4x4(mat))),
620            Gate::Multi2q(data) => Gate::Multi2q(Box::new(Multi2qData {
621                gates: data
622                    .gates
623                    .iter()
624                    .rev()
625                    .map(|&(q0, q1, ref mat)| (q0, q1, adjoint_4x4(mat)))
626                    .collect(),
627            })),
628        }
629    }
630
631    /// Compute integer power of a single-qubit gate.
632    ///
633    /// Returns the gate raised to the `k`-th power. Negative `k` inverts first.
634    /// Only valid for single-qubit gates.
635    pub fn matrix_power(&self, k: i64) -> Gate {
636        debug_assert_eq!(
637            self.num_qubits(),
638            1,
639            "matrix_power only for single-qubit gates"
640        );
641        if k == 0 {
642            return Gate::Id;
643        }
644        if k == 1 {
645            return self.clone();
646        }
647        let base = if k < 0 { self.inverse() } else { self.clone() };
648        let n = k.unsigned_abs() as usize;
649        if n == 1 {
650            return base;
651        }
652        let base_mat = base.matrix_2x2();
653        let mut acc = base_mat;
654        for _ in 1..n {
655            acc = mat_mul_2x2(&base_mat, &acc);
656        }
657        Gate::Fused(Box::new(acc))
658    }
659
660    /// Create a single-controlled unitary gate with the given 2x2 matrix.
661    pub fn cu(mat: [[Complex64; 2]; 2]) -> Gate {
662        Gate::Cu(Box::new(mat))
663    }
664
665    /// Create a multi-controlled unitary gate with `num_controls` control qubits.
666    pub fn mcu(mat: [[Complex64; 2]; 2], num_controls: u8) -> Gate {
667        Gate::Mcu(Box::new(McuData { mat, num_controls }))
668    }
669
670    /// Create a controlled-phase gate CPhase(θ) = Cu(\[\[1,0\],\[0,e^{iθ}\]\]).
671    ///
672    /// Applies phase e^{iθ} to |11⟩ and identity to all other basis states.
673    pub fn cphase(theta: f64) -> Gate {
674        let one = Complex64::new(1.0, 0.0);
675        let zero = Complex64::new(0.0, 0.0);
676        let phase = Complex64::from_polar(1.0, theta);
677        Gate::cu([[one, zero], [zero, phase]])
678    }
679
680    /// Returns the phase if this is a controlled-phase gate (Cu/Mcu with
681    /// diagonal matrix `[[1,0],[0,e^{iθ}]]`).
682    ///
683    /// Used by backends to dispatch to optimized phase-only kernels that
684    /// touch half the memory of the generic controlled-unitary kernel.
685    #[inline]
686    pub fn controlled_phase(&self) -> Option<Complex64> {
687        let mat = match self {
688            Gate::Cu(mat) => &**mat,
689            Gate::Mcu(data) => &data.mat,
690            _ => return None,
691        };
692        if (mat[0][0].re - 1.0).abs() < IDENTITY_EPS
693            && mat[0][0].im.abs() < IDENTITY_EPS
694            && mat[0][1].norm() < IDENTITY_EPS
695            && mat[1][0].norm() < IDENTITY_EPS
696            && (mat[1][1].norm() - 1.0).abs() < IDENTITY_EPS
697        {
698            Some(mat[1][1])
699        } else {
700            None
701        }
702    }
703
704    /// True if this is a diagonal single-qubit gate (matrix is `[[a,0],[0,b]]`).
705    ///
706    /// Diagonal gates commute with CX on the control qubit and with CZ on
707    /// either qubit. Used by the commutation-aware reordering pass.
708    #[inline]
709    pub fn is_diagonal_1q(&self) -> bool {
710        match self {
711            Gate::Id
712            | Gate::Z
713            | Gate::S
714            | Gate::Sdg
715            | Gate::T
716            | Gate::Tdg
717            | Gate::Rz(_)
718            | Gate::P(_) => true,
719            Gate::Fused(m) => m[0][1].norm() < IDENTITY_EPS && m[1][0].norm() < IDENTITY_EPS,
720            _ => false,
721        }
722    }
723
724    /// True if this is a self-inverse two-qubit gate (applying it twice = identity).
725    #[inline]
726    pub fn is_self_inverse_2q(&self) -> bool {
727        matches!(self, Gate::Cx | Gate::Cz | Gate::Swap)
728    }
729
730    /// True if this gate maps computational basis states to computational basis
731    /// states (with at most a phase). Such gates preserve the number of non-zero
732    /// amplitudes, making the sparse backend optimal (O(1) memory for |0...0⟩).
733    ///
734    /// Includes diagonal gates (Z, S, T, Rz, P, CZ) and permutation gates
735    /// (X, Y, CX, SWAP). Excludes superposition-creating gates (H, Rx, Ry, SX).
736    #[inline]
737    pub fn preserves_sparsity(&self) -> bool {
738        match self {
739            Gate::Id | Gate::X | Gate::Y | Gate::Z => true,
740            Gate::S | Gate::Sdg | Gate::T | Gate::Tdg => true,
741            Gate::Rz(_) | Gate::P(_) => true,
742            Gate::Rzz(_) | Gate::Cx | Gate::Cz | Gate::Swap => true,
743            Gate::Cu(mat) | Gate::Fused(mat) => {
744                let is_diag = mat[0][1].norm_sqr() < NEAR_ZERO_NORM_SQ
745                    && mat[1][0].norm_sqr() < NEAR_ZERO_NORM_SQ;
746                let is_antidiag = mat[0][0].norm_sqr() < NEAR_ZERO_NORM_SQ
747                    && mat[1][1].norm_sqr() < NEAR_ZERO_NORM_SQ;
748                is_diag || is_antidiag
749            }
750            Gate::Mcu(data) => {
751                let m = &data.mat;
752                let is_diag = m[0][1].norm_sqr() < NEAR_ZERO_NORM_SQ
753                    && m[1][0].norm_sqr() < NEAR_ZERO_NORM_SQ;
754                let is_antidiag = m[0][0].norm_sqr() < NEAR_ZERO_NORM_SQ
755                    && m[1][1].norm_sqr() < NEAR_ZERO_NORM_SQ;
756                is_diag || is_antidiag
757            }
758            Gate::BatchPhase(_) | Gate::BatchRzz(_) | Gate::DiagonalBatch(_) => true,
759            _ => false,
760        }
761    }
762
763    /// Try to recognize a 2x2 unitary matrix as a named gate (up to global phase).
764    ///
765    /// Used by the fusion pass to emit named gate variants instead of opaque
766    /// `Gate::Fused` matrices, enabling downstream passes (e.g. `clifford_prefix_split`)
767    /// to identify Clifford gates that arose from fusion (e.g. T·T → S).
768    pub fn recognize_matrix(mat: &[[Complex64; 2]; 2]) -> Option<Gate> {
769        const EPS: f64 = 1e-10;
770
771        // Check each candidate gate. For each, compute the global phase ratio
772        // mat[i][j] / ref[i][j] using the first non-zero entry, then verify
773        // all other entries match under that same phase.
774        let candidates: &[Gate] = &[
775            Gate::H,
776            Gate::X,
777            Gate::Y,
778            Gate::Z,
779            Gate::S,
780            Gate::Sdg,
781            Gate::T,
782            Gate::Tdg,
783            Gate::SX,
784            Gate::SXdg,
785        ];
786
787        for candidate in candidates {
788            let ref_mat = candidate.matrix_2x2();
789            if matrices_equal_up_to_phase(mat, &ref_mat, EPS) {
790                return Some(candidate.clone());
791            }
792        }
793
794        // Identity check: all off-diagonal zero, diagonal entries equal
795        if mat[0][1].norm_sqr() < EPS
796            && mat[1][0].norm_sqr() < EPS
797            && (mat[0][0] - mat[1][1]).norm_sqr() < EPS
798            && mat[0][0].norm_sqr() > EPS
799        {
800            return Some(Gate::Id);
801        }
802
803        None
804    }
805
806    /// True if this gate is a Clifford gate (relevant for stabilizer backend).
807    #[inline]
808    pub fn is_clifford(&self) -> bool {
809        matches!(
810            self,
811            Gate::Id
812                | Gate::X
813                | Gate::Y
814                | Gate::Z
815                | Gate::H
816                | Gate::S
817                | Gate::Sdg
818                | Gate::SX
819                | Gate::SXdg
820                | Gate::Cx
821                | Gate::Cz
822                | Gate::Swap
823        )
824    }
825}
826
827/// Check if two 2x2 unitary matrices are equal up to a global phase factor.
828fn matrices_equal_up_to_phase(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2], eps: f64) -> bool {
829    // Find the first non-zero entry in b to determine the phase ratio
830    let mut phase = None;
831    for i in 0..2 {
832        for j in 0..2 {
833            if b[i][j].norm_sqr() > eps {
834                if a[i][j].norm_sqr() < eps {
835                    return false;
836                }
837                phase = Some(a[i][j] / b[i][j]);
838                break;
839            }
840        }
841        if phase.is_some() {
842            break;
843        }
844    }
845
846    let phase = match phase {
847        Some(p) => p,
848        None => return true, // Both are zero matrices
849    };
850
851    // Verify all entries match under the same phase
852    for i in 0..2 {
853        for j in 0..2 {
854            let expected = phase * b[i][j];
855            if (a[i][j] - expected).norm_sqr() > eps {
856                return false;
857            }
858        }
859    }
860    true
861}
862
863fn format_angle(theta: f64) -> String {
864    const FRACTIONS: &[(f64, &str)] = &[
865        (1.0, "π"),
866        (-1.0, "-π"),
867        (0.5, "π/2"),
868        (-0.5, "-π/2"),
869        (0.25, "π/4"),
870        (-0.25, "-π/4"),
871        (1.0 / 3.0, "π/3"),
872        (-1.0 / 3.0, "-π/3"),
873        (2.0 / 3.0, "2π/3"),
874        (-2.0 / 3.0, "-2π/3"),
875        (1.0 / 6.0, "π/6"),
876        (-1.0 / 6.0, "-π/6"),
877        (5.0 / 6.0, "5π/6"),
878        (-5.0 / 6.0, "-5π/6"),
879        (1.0 / 8.0, "π/8"),
880        (-1.0 / 8.0, "-π/8"),
881        (3.0 / 8.0, "3π/8"),
882        (-3.0 / 8.0, "-3π/8"),
883        (1.5, "3π/2"),
884        (-1.5, "-3π/2"),
885        (2.0, "2π"),
886        (-2.0, "-2π"),
887    ];
888    let ratio = theta / std::f64::consts::PI;
889    for &(frac, label) in FRACTIONS {
890        if (ratio - frac).abs() < 1e-10 {
891            return label.to_string();
892        }
893    }
894    format!("{:.4}", theta)
895}
896
897impl fmt::Display for Gate {
898    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899        match self {
900            Gate::Id => f.write_str("I"),
901            Gate::X => f.write_str("X"),
902            Gate::Y => f.write_str("Y"),
903            Gate::Z => f.write_str("Z"),
904            Gate::H => f.write_str("H"),
905            Gate::S => f.write_str("S"),
906            Gate::Sdg => f.write_str("Sdg"),
907            Gate::T => f.write_str("T"),
908            Gate::Tdg => f.write_str("Tdg"),
909            Gate::SX => f.write_str("SX"),
910            Gate::SXdg => f.write_str("SXdg"),
911            Gate::Rx(t) => write!(f, "Rx({})", format_angle(*t)),
912            Gate::Ry(t) => write!(f, "Ry({})", format_angle(*t)),
913            Gate::Rz(t) => write!(f, "Rz({})", format_angle(*t)),
914            Gate::P(t) => write!(f, "P({})", format_angle(*t)),
915            Gate::Rzz(t) => write!(f, "Rzz({})", format_angle(*t)),
916            Gate::Cx => f.write_str("CX"),
917            Gate::Cz => f.write_str("CZ"),
918            Gate::Swap => f.write_str("SWAP"),
919            Gate::Cu(_) => f.write_str("CU"),
920            Gate::Mcu(data) => write!(f, "MCU({}ctrl)", data.num_controls),
921            Gate::Fused(_) => f.write_str("U"),
922            Gate::Fused2q(_) => f.write_str("U2"),
923            Gate::MultiFused(data) => write!(f, "MF[{}]", data.gates.len()),
924            Gate::BatchPhase(data) => write!(f, "BP[{}]", data.phases.len()),
925            Gate::QftBlock { start, num } => write!(f, "QFT[{}..{}]", start, start + num),
926            Gate::BatchRzz(data) => write!(f, "BZZ[{}]", data.edges.len()),
927            Gate::DiagonalBatch(data) => write!(f, "BD[{}]", data.entries.len()),
928            Gate::Multi2q(data) => write!(f, "M2[{}]", data.gates.len()),
929        }
930    }
931}
932
933#[cfg(test)]
934mod tests {
935    use super::*;
936
937    #[test]
938    fn format_angle_pi_fractions() {
939        assert_eq!(format_angle(std::f64::consts::PI), "π");
940        assert_eq!(format_angle(std::f64::consts::FRAC_PI_2), "π/2");
941        assert_eq!(format_angle(std::f64::consts::FRAC_PI_4), "π/4");
942        assert_eq!(format_angle(-std::f64::consts::FRAC_PI_4), "-π/4");
943        assert_eq!(format_angle(std::f64::consts::PI / 3.0), "π/3");
944        assert_eq!(format_angle(0.123), "0.1230");
945    }
946
947    #[test]
948    fn display_labels() {
949        assert_eq!(Gate::H.to_string(), "H");
950        assert_eq!(Gate::Cx.to_string(), "CX");
951        assert_eq!(Gate::Rx(std::f64::consts::FRAC_PI_2).to_string(), "Rx(π/2)");
952        assert_eq!(Gate::Rz(0.5).to_string(), "Rz(0.5000)");
953        assert_eq!(Gate::Id.to_string(), "I");
954        assert_eq!(Gate::Swap.to_string(), "SWAP");
955    }
956
957    #[test]
958    fn test_gate_arity() {
959        assert_eq!(Gate::H.num_qubits(), 1);
960        assert_eq!(Gate::Rx(0.5).num_qubits(), 1);
961        assert_eq!(Gate::Cx.num_qubits(), 2);
962        assert_eq!(Gate::Swap.num_qubits(), 2);
963    }
964
965    #[test]
966    fn test_h_matrix_is_unitary() {
967        let m = Gate::H.matrix_2x2();
968        // H * H = I
969        let mut product = [[Complex64::new(0.0, 0.0); 2]; 2];
970        for i in 0..2 {
971            for j in 0..2 {
972                for (k, row) in m.iter().enumerate() {
973                    product[i][j] += m[i][k] * row[j];
974                }
975            }
976        }
977        let eps = 1e-12;
978        assert!((product[0][0].re - 1.0).abs() < eps);
979        assert!(product[0][0].im.abs() < eps);
980        assert!(product[0][1].norm() < eps);
981        assert!(product[1][0].norm() < eps);
982        assert!((product[1][1].re - 1.0).abs() < eps);
983    }
984
985    #[test]
986    fn test_rx_pi_equals_neg_i_x() {
987        let rx = Gate::Rx(std::f64::consts::PI).matrix_2x2();
988        // Rx(π) = -i·X  (up to global phase)
989        // |Rx(π)[0][1]| should be 1
990        assert!((rx[0][1].norm() - 1.0).abs() < 1e-12);
991        assert!((rx[1][0].norm() - 1.0).abs() < 1e-12);
992        assert!(rx[0][0].norm() < 1e-12);
993        assert!(rx[1][1].norm() < 1e-12);
994    }
995
996    #[test]
997    fn test_clifford_classification() {
998        assert!(Gate::H.is_clifford());
999        assert!(Gate::S.is_clifford());
1000        assert!(Gate::Cx.is_clifford());
1001        assert!(!Gate::T.is_clifford());
1002        assert!(!Gate::Rx(0.5).is_clifford());
1003        assert!(!Gate::Cu(Box::new([[Complex64::new(1.0, 0.0); 2]; 2])).is_clifford());
1004    }
1005
1006    #[test]
1007    fn test_preserves_sparsity() {
1008        // Diagonal and permutation gates preserve sparsity
1009        assert!(Gate::Id.preserves_sparsity());
1010        assert!(Gate::X.preserves_sparsity());
1011        assert!(Gate::Y.preserves_sparsity());
1012        assert!(Gate::Z.preserves_sparsity());
1013        assert!(Gate::S.preserves_sparsity());
1014        assert!(Gate::T.preserves_sparsity());
1015        assert!(Gate::Rz(1.0).preserves_sparsity());
1016        assert!(Gate::P(0.5).preserves_sparsity());
1017        assert!(Gate::Cx.preserves_sparsity());
1018        assert!(Gate::Cz.preserves_sparsity());
1019        assert!(Gate::Swap.preserves_sparsity());
1020
1021        // Superposition-creating gates do NOT preserve sparsity
1022        assert!(!Gate::H.preserves_sparsity());
1023        assert!(!Gate::Rx(0.5).preserves_sparsity());
1024        assert!(!Gate::Ry(0.5).preserves_sparsity());
1025        assert!(!Gate::SX.preserves_sparsity());
1026        assert!(!Gate::SXdg.preserves_sparsity());
1027
1028        // Cu with diagonal matrix preserves sparsity
1029        let diag = Box::new([
1030            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
1031            [Complex64::new(0.0, 0.0), Complex64::new(0.0, 1.0)],
1032        ]);
1033        assert!(Gate::Cu(diag).preserves_sparsity());
1034
1035        // Cu with H-like matrix does NOT preserve sparsity
1036        let h_mat = Box::new(Gate::H.matrix_2x2());
1037        assert!(!Gate::Cu(h_mat).preserves_sparsity());
1038    }
1039
1040    #[test]
1041    fn test_cu_arity() {
1042        let mat = Gate::H.matrix_2x2();
1043        assert_eq!(Gate::Cu(Box::new(mat)).num_qubits(), 2);
1044    }
1045
1046    fn assert_mat_close(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2], eps: f64) {
1047        for i in 0..2 {
1048            for j in 0..2 {
1049                assert!(
1050                    (a[i][j] - b[i][j]).norm() < eps,
1051                    "mat[{i}][{j}]: expected {:?}, got {:?}",
1052                    b[i][j],
1053                    a[i][j]
1054                );
1055            }
1056        }
1057    }
1058
1059    #[test]
1060    fn test_inverse_self_inverse() {
1061        assert_eq!(Gate::H.inverse(), Gate::H);
1062        assert_eq!(Gate::X.inverse(), Gate::X);
1063        assert_eq!(Gate::Y.inverse(), Gate::Y);
1064        assert_eq!(Gate::Z.inverse(), Gate::Z);
1065        assert_eq!(Gate::Id.inverse(), Gate::Id);
1066        assert_eq!(Gate::Cx.inverse(), Gate::Cx);
1067        assert_eq!(Gate::Cz.inverse(), Gate::Cz);
1068        assert_eq!(Gate::Swap.inverse(), Gate::Swap);
1069    }
1070
1071    #[test]
1072    fn test_inverse_adjoint_pairs() {
1073        assert_eq!(Gate::S.inverse(), Gate::Sdg);
1074        assert_eq!(Gate::Sdg.inverse(), Gate::S);
1075        assert_eq!(Gate::T.inverse(), Gate::Tdg);
1076        assert_eq!(Gate::Tdg.inverse(), Gate::T);
1077    }
1078
1079    #[test]
1080    fn test_inverse_parametric() {
1081        assert_eq!(Gate::Rx(0.5).inverse(), Gate::Rx(-0.5));
1082        assert_eq!(Gate::Ry(1.0).inverse(), Gate::Ry(-1.0));
1083        assert_eq!(Gate::Rz(PI).inverse(), Gate::Rz(-PI));
1084    }
1085
1086    #[test]
1087    fn test_inverse_fused_is_adjoint() {
1088        let s_mat = Gate::S.matrix_2x2();
1089        let fused = Gate::Fused(Box::new(s_mat));
1090        let inv = fused.inverse();
1091        if let Gate::Fused(inv_mat) = &inv {
1092            assert_mat_close(inv_mat, &Gate::Sdg.matrix_2x2(), 1e-12);
1093        } else {
1094            panic!("expected Fused");
1095        }
1096    }
1097
1098    #[test]
1099    fn test_inverse_cu() {
1100        let rz_mat = Gate::Rz(0.5).matrix_2x2();
1101        let cu = Gate::Cu(Box::new(rz_mat));
1102        let inv = cu.inverse();
1103        if let Gate::Cu(inv_mat) = &inv {
1104            let expected = Gate::Rz(-0.5).matrix_2x2();
1105            assert_mat_close(inv_mat, &expected, 1e-12);
1106        } else {
1107            panic!("expected Cu");
1108        }
1109    }
1110
1111    #[test]
1112    fn test_matrix_power_zero() {
1113        assert_eq!(Gate::X.matrix_power(0), Gate::Id);
1114        assert_eq!(Gate::Rz(0.5).matrix_power(0), Gate::Id);
1115    }
1116
1117    #[test]
1118    fn test_matrix_power_one() {
1119        assert_eq!(Gate::X.matrix_power(1), Gate::X);
1120        assert_eq!(Gate::H.matrix_power(1), Gate::H);
1121    }
1122
1123    #[test]
1124    fn test_matrix_power_x_squared() {
1125        let x2 = Gate::X.matrix_power(2);
1126        if let Gate::Fused(mat) = &x2 {
1127            assert_mat_close(mat, &Gate::Id.matrix_2x2(), 1e-12);
1128        } else {
1129            panic!("expected Fused");
1130        }
1131    }
1132
1133    #[test]
1134    fn test_matrix_power_t_squared_is_s() {
1135        let t2 = Gate::T.matrix_power(2);
1136        if let Gate::Fused(mat) = &t2 {
1137            assert_mat_close(mat, &Gate::S.matrix_2x2(), 1e-12);
1138        } else {
1139            panic!("expected Fused");
1140        }
1141    }
1142
1143    #[test]
1144    fn test_matrix_power_negative() {
1145        let t_inv2 = Gate::T.matrix_power(-2);
1146        if let Gate::Fused(mat) = &t_inv2 {
1147            assert_mat_close(mat, &Gate::Sdg.matrix_2x2(), 1e-12);
1148        } else {
1149            panic!("expected Fused");
1150        }
1151    }
1152
1153    #[test]
1154    fn test_mcu_arity() {
1155        let mat = Gate::H.matrix_2x2();
1156        let mcu2 = Gate::Mcu(Box::new(McuData {
1157            mat,
1158            num_controls: 2,
1159        }));
1160        assert_eq!(mcu2.num_qubits(), 3);
1161        let mcu3 = Gate::Mcu(Box::new(McuData {
1162            mat,
1163            num_controls: 3,
1164        }));
1165        assert_eq!(mcu3.num_qubits(), 4);
1166    }
1167
1168    #[test]
1169    fn test_mcu_not_clifford() {
1170        let mat = Gate::X.matrix_2x2();
1171        let mcu = Gate::Mcu(Box::new(McuData {
1172            mat,
1173            num_controls: 2,
1174        }));
1175        assert!(!mcu.is_clifford());
1176    }
1177
1178    #[test]
1179    fn test_mcu_inverse() {
1180        let rz_mat = Gate::Rz(0.5).matrix_2x2();
1181        let mcu = Gate::Mcu(Box::new(McuData {
1182            mat: rz_mat,
1183            num_controls: 2,
1184        }));
1185        let inv = mcu.inverse();
1186        if let Gate::Mcu(inv_data) = &inv {
1187            let expected = Gate::Rz(-0.5).matrix_2x2();
1188            assert_mat_close(&inv_data.mat, &expected, 1e-12);
1189            assert_eq!(inv_data.num_controls, 2);
1190        } else {
1191            panic!("expected Mcu");
1192        }
1193    }
1194
1195    #[test]
1196    fn test_mcu_name() {
1197        let mat = Gate::H.matrix_2x2();
1198        let mcu = Gate::Mcu(Box::new(McuData {
1199            mat,
1200            num_controls: 2,
1201        }));
1202        assert_eq!(mcu.name(), "mcu");
1203    }
1204
1205    #[test]
1206    fn test_cphase_constructor() {
1207        let g = Gate::cphase(PI / 4.0);
1208        assert_eq!(g.num_qubits(), 2);
1209        assert_eq!(g.name(), "cu");
1210        if let Gate::Cu(mat) = &g {
1211            let one = Complex64::new(1.0, 0.0);
1212            assert!((mat[0][0] - one).norm() < 1e-14);
1213            assert!(mat[0][1].norm() < 1e-14);
1214            assert!(mat[1][0].norm() < 1e-14);
1215            let expected = Complex64::from_polar(1.0, PI / 4.0);
1216            assert!((mat[1][1] - expected).norm() < 1e-14);
1217        } else {
1218            panic!("expected Cu");
1219        }
1220    }
1221
1222    #[test]
1223    fn test_controlled_phase_detection() {
1224        let cp = Gate::cphase(0.5);
1225        assert!(cp.controlled_phase().is_some());
1226        let phase = cp.controlled_phase().unwrap();
1227        let expected = Complex64::from_polar(1.0, 0.5);
1228        assert!((phase - expected).norm() < 1e-14);
1229
1230        // Non-diagonal Cu should not be detected
1231        let h_mat = Gate::H.matrix_2x2();
1232        let cu_h = Gate::Cu(Box::new(h_mat));
1233        assert!(cu_h.controlled_phase().is_none());
1234
1235        // CZ is Cu([[1,0],[0,-1]]), should be detected (phase = -1)
1236        let z_mat = Gate::Z.matrix_2x2();
1237        let cu_z = Gate::Cu(Box::new(z_mat));
1238        assert!(cu_z.controlled_phase().is_some());
1239        let z_phase = cu_z.controlled_phase().unwrap();
1240        assert!((z_phase.re - (-1.0)).abs() < 1e-14);
1241
1242        // Rz-based Cu is diagonal but mat[0][0] != 1, should NOT be detected
1243        let rz_mat = Gate::Rz(0.5).matrix_2x2();
1244        let cu_rz = Gate::Cu(Box::new(rz_mat));
1245        assert!(cu_rz.controlled_phase().is_none());
1246
1247        // Non-Cu gates should return None
1248        assert!(Gate::H.controlled_phase().is_none());
1249        assert!(Gate::Cx.controlled_phase().is_none());
1250    }
1251
1252    #[test]
1253    fn test_controlled_phase_mcu() {
1254        let one = Complex64::new(1.0, 0.0);
1255        let zero = Complex64::new(0.0, 0.0);
1256        let phase = Complex64::from_polar(1.0, 0.7);
1257        let mcu = Gate::Mcu(Box::new(McuData {
1258            mat: [[one, zero], [zero, phase]],
1259            num_controls: 2,
1260        }));
1261        assert!(mcu.controlled_phase().is_some());
1262        assert!((mcu.controlled_phase().unwrap() - phase).norm() < 1e-14);
1263    }
1264
1265    #[test]
1266    fn test_sx_matrix_is_sqrt_x() {
1267        let sx = Gate::SX.matrix_2x2();
1268        let sx2 = mat_mul_2x2(&sx, &sx);
1269        assert_mat_close(&sx2, &Gate::X.matrix_2x2(), 1e-12);
1270    }
1271
1272    #[test]
1273    fn test_sxdg_is_sx_inverse() {
1274        let sx = Gate::SX.matrix_2x2();
1275        let sxdg = Gate::SXdg.matrix_2x2();
1276        let product = mat_mul_2x2(&sx, &sxdg);
1277        assert_mat_close(&product, &Gate::Id.matrix_2x2(), 1e-12);
1278    }
1279
1280    #[test]
1281    fn test_p_gate_matrix() {
1282        let p = Gate::P(PI / 4.0).matrix_2x2();
1283        let t = Gate::T.matrix_2x2();
1284        assert_mat_close(&p, &t, 1e-12);
1285    }
1286
1287    #[test]
1288    fn test_sx_is_clifford() {
1289        assert!(Gate::SX.is_clifford());
1290        assert!(Gate::SXdg.is_clifford());
1291    }
1292
1293    #[test]
1294    fn test_p_inverse() {
1295        assert_eq!(Gate::P(0.5).inverse(), Gate::P(-0.5));
1296    }
1297
1298    #[test]
1299    fn test_sx_inverse_pair() {
1300        assert_eq!(Gate::SX.inverse(), Gate::SXdg);
1301        assert_eq!(Gate::SXdg.inverse(), Gate::SX);
1302    }
1303
1304    #[test]
1305    fn test_is_diagonal_1q() {
1306        assert!(Gate::Id.is_diagonal_1q());
1307        assert!(Gate::Z.is_diagonal_1q());
1308        assert!(Gate::S.is_diagonal_1q());
1309        assert!(Gate::Sdg.is_diagonal_1q());
1310        assert!(Gate::T.is_diagonal_1q());
1311        assert!(Gate::Tdg.is_diagonal_1q());
1312        assert!(Gate::Rz(0.5).is_diagonal_1q());
1313        assert!(Gate::P(0.5).is_diagonal_1q());
1314        assert!(!Gate::H.is_diagonal_1q());
1315        assert!(!Gate::X.is_diagonal_1q());
1316        assert!(!Gate::Y.is_diagonal_1q());
1317        assert!(!Gate::Rx(0.5).is_diagonal_1q());
1318        assert!(!Gate::Ry(0.5).is_diagonal_1q());
1319        assert!(!Gate::SX.is_diagonal_1q());
1320        assert!(!Gate::Cx.is_diagonal_1q());
1321
1322        let diag_fused = Gate::Fused(Box::new(Gate::T.matrix_2x2()));
1323        assert!(diag_fused.is_diagonal_1q());
1324        let nondiag_fused = Gate::Fused(Box::new(Gate::H.matrix_2x2()));
1325        assert!(!nondiag_fused.is_diagonal_1q());
1326    }
1327
1328    #[test]
1329    fn test_is_self_inverse_2q() {
1330        assert!(Gate::Cx.is_self_inverse_2q());
1331        assert!(Gate::Cz.is_self_inverse_2q());
1332        assert!(Gate::Swap.is_self_inverse_2q());
1333        assert!(!Gate::H.is_self_inverse_2q());
1334        assert!(!Gate::T.is_self_inverse_2q());
1335        let mat = Gate::H.matrix_2x2();
1336        assert!(!Gate::Cu(Box::new(mat)).is_self_inverse_2q());
1337    }
1338
1339    #[test]
1340    fn test_gate_enum_size() {
1341        assert_eq!(
1342            std::mem::size_of::<Gate>(),
1343            16,
1344            "Gate enum must stay at 16 bytes"
1345        );
1346    }
1347
1348    #[test]
1349    fn test_recognize_named_gates() {
1350        for gate in &[
1351            Gate::H,
1352            Gate::X,
1353            Gate::Y,
1354            Gate::Z,
1355            Gate::S,
1356            Gate::Sdg,
1357            Gate::T,
1358            Gate::Tdg,
1359            Gate::SX,
1360            Gate::SXdg,
1361        ] {
1362            let mat = gate.matrix_2x2();
1363            let recognized = Gate::recognize_matrix(&mat);
1364            assert_eq!(
1365                recognized.as_ref(),
1366                Some(gate),
1367                "failed to recognize {:?}",
1368                gate.name()
1369            );
1370        }
1371    }
1372
1373    #[test]
1374    fn test_recognize_identity() {
1375        let id = Gate::Id.matrix_2x2();
1376        assert_eq!(Gate::recognize_matrix(&id), Some(Gate::Id));
1377    }
1378
1379    #[test]
1380    fn test_recognize_t_squared_is_s() {
1381        let t = Gate::T.matrix_2x2();
1382        let tt = mat_mul_2x2(&t, &t);
1383        assert_eq!(Gate::recognize_matrix(&tt), Some(Gate::S));
1384    }
1385
1386    #[test]
1387    fn test_recognize_s_squared_is_z() {
1388        let s = Gate::S.matrix_2x2();
1389        let ss = mat_mul_2x2(&s, &s);
1390        assert_eq!(Gate::recognize_matrix(&ss), Some(Gate::Z));
1391    }
1392
1393    #[test]
1394    fn test_recognize_h_squared_is_identity() {
1395        let h = Gate::H.matrix_2x2();
1396        let hh = mat_mul_2x2(&h, &h);
1397        assert_eq!(Gate::recognize_matrix(&hh), Some(Gate::Id));
1398    }
1399
1400    #[test]
1401    fn test_recognize_t_fourth_is_z() {
1402        let t = Gate::T.matrix_2x2();
1403        let t2 = mat_mul_2x2(&t, &t);
1404        let t4 = mat_mul_2x2(&t2, &t2);
1405        assert_eq!(Gate::recognize_matrix(&t4), Some(Gate::Z));
1406    }
1407
1408    #[test]
1409    fn test_recognize_non_clifford_returns_none() {
1410        let rx = Gate::Rx(0.7).matrix_2x2();
1411        assert_eq!(Gate::recognize_matrix(&rx), None);
1412        let ry = Gate::Ry(1.3).matrix_2x2();
1413        assert_eq!(Gate::recognize_matrix(&ry), None);
1414    }
1415
1416    #[test]
1417    fn test_recognize_global_phase_invariance() {
1418        let phase = Complex64::from_polar(1.0, 0.42);
1419        let h = Gate::H.matrix_2x2();
1420        let phased = [
1421            [h[0][0] * phase, h[0][1] * phase],
1422            [h[1][0] * phase, h[1][1] * phase],
1423        ];
1424        assert_eq!(Gate::recognize_matrix(&phased), Some(Gate::H));
1425    }
1426}