Skip to main content

prism_q/gates/
mod.rs

1//! Gate definitions and matrix representations.
2//!
3//! Gates are represented as an enum for fast dispatch without trait-object overhead
4//! in the simulation hot path. Matrix representations use stack-allocated arrays
5//! to avoid heap allocation during gate application.
6//!
7//! # Hot-path design notes
8//! - `Gate` methods take `&self`, the enum is 16 bytes (Box indirection for `Fused`).
9//! - `matrix_2x2` returns `[[Complex64; 2]; 2]` on the stack.
10//! - Two-qubit gates (CX, CZ, SWAP) have dedicated application routines in
11//!   backends rather than materializing a 4×4 matrix.
12
13use num_complex::Complex64;
14use smallvec::SmallVec;
15use std::f64::consts::{FRAC_1_SQRT_2, PI};
16use std::fmt;
17
/// Squared-magnitude cutoff below which a matrix entry is treated as zero.
///
/// Compared against `norm_sqr()` in `preserves_sparsity()` to decide whether
/// the off-diagonal pair (diagonal gate) or the diagonal pair (anti-diagonal /
/// permutation gate) of a 2×2 matrix vanishes. Because it bounds the *squared*
/// norm, it corresponds to an entry magnitude of `1e-12`.
const NEAR_ZERO_NORM_SQ: f64 = 1e-24;
23
/// Magnitude cutoff for "equal to 0 or 1 up to rounding" checks on matrix entries.
///
/// Compared against `norm()` / `abs()` (not the squared norm) in
/// `is_diagonal_1q()` for fused-gate diagonal detection and in
/// `controlled_phase()` for recognising the `[[1,0],[0,e^{iθ}]]` structure.
const IDENTITY_EPS: f64 = 1e-12;
29
/// Quantum gate identifier.
///
/// Covers the v0 supported gate set. Most variants are data-free or carry an `f64`
/// parameter inline. The `Fused` variant uses `Box` to keep the enum at 16 bytes.
///
/// Variants fall into three groups: named single-qubit gates (which have a
/// `matrix_2x2`), two-qubit gates (which have a `matrix_4x4`), and batched /
/// fused variants that carry their own data structures.
#[derive(Debug, Clone, PartialEq)]
pub enum Gate {
    /// Identity.
    Id,
    /// Pauli-X (bit flip).
    X,
    /// Pauli-Y.
    Y,
    /// Pauli-Z (phase flip).
    Z,
    /// Hadamard.
    H,
    /// S gate (√Z). Inverse is `Sdg`.
    S,
    /// S† gate. Inverse is `S`.
    Sdg,
    /// T gate (π/8). Inverse is `Tdg`.
    T,
    /// T† gate. Inverse is `T`.
    Tdg,
    /// √X gate. Inverse is `SXdg`.
    SX,
    /// √X† gate. Inverse is `SX`.
    SXdg,

    /// Rotation about X-axis by angle (radians).
    Rx(f64),
    /// Rotation about Y-axis by angle (radians).
    Ry(f64),
    /// Rotation about Z-axis by angle (radians).
    Rz(f64),
    /// Phase gate `[[1,0],[0,e^{iθ}]]`.
    P(f64),

    /// ZZ rotation: diag(e^{-iθ/2}, e^{iθ/2}, e^{iθ/2}, e^{-iθ/2}).
    /// Qubit order: [q0, q1] (symmetric).
    Rzz(f64),

    /// Controlled-X (CNOT). Qubit order: [control, target].
    Cx,
    /// Controlled-Z. Qubit order: [q0, q1] (symmetric).
    Cz,
    /// SWAP. Qubit order: [q0, q1] (symmetric).
    Swap,

    /// Controlled-unitary. Applies the boxed 2×2 matrix to the target qubit
    /// only when the control qubit is |1⟩. Qubit order: [control, target].
    /// Boxed to keep `Gate` at 16 bytes.
    Cu(Box<[[Complex64; 2]; 2]>),

    /// Multi-controlled unitary. Applies the 2×2 matrix to the target qubit
    /// only when all control qubits are |1⟩. Qubit order:
    /// `[ctrl_0, ctrl_1, ..., ctrl_{k-1}, target]`.
    /// Boxed to keep `Gate` at 16 bytes.
    Mcu(Box<McuData>),

    /// Pre-fused single-qubit unitary (product of consecutive gates on the same target).
    /// Boxed to keep `Gate` at 16 bytes for cache-friendly instruction streams.
    Fused(Box<[[Complex64; 2]; 2]>),

    /// Batched controlled-phase: multiple cphase gates sharing a control qubit,
    /// fused into a single pass over the statevector. Created by the cphase
    /// fusion pass. Targets: `[control]`. The `BatchPhaseData` holds per-target
    /// phases. Boxed to keep `Gate` at 16 bytes.
    ///
    /// `num_qubits()` reports `1 + phases.len()` for this variant.
    BatchPhase(Box<BatchPhaseData>),

    /// Batched ZZ rotations: multiple Rzz gates fused into a single pass.
    /// Created by the batch-Rzz fusion pass. The `BatchRzzData` holds per-edge
    /// angles. Boxed to keep `Gate` at 16 bytes.
    BatchRzz(Box<BatchRzzData>),

    /// Batched diagonal gates: a contiguous run of diagonal 1q and 2q gates
    /// collapsed into a single state-vector sweep with a precomputed phase LUT.
    /// Subsumes BatchPhase and BatchRzz for mixed diagonal runs. Created by the
    /// diagonal batch fusion pass. Boxed to keep `Gate` at 16 bytes.
    DiagonalBatch(Box<DiagonalBatchData>),

    /// Multiple single-qubit gates on distinct qubits, batched for a single
    /// tiled pass over the statevector. Created by the multi-gate fusion pass.
    /// Boxed to keep `Gate` at 16 bytes.
    MultiFused(Box<MultiFusedData>),

    /// Pre-fused two-qubit unitary (4×4 matrix). Created by the 2q fusion pass
    /// which absorbs adjacent single-qubit gates into a two-qubit gate.
    /// Boxed to keep `Gate` at 16 bytes.
    Fused2q(Box<[[Complex64; 4]; 4]>),

    /// Multiple two-qubit gates batched for a single tiled pass over the
    /// statevector. Created by the multi-2q fusion pass. Each entry stores
    /// `(q0, q1, 4×4 matrix)`. Boxed to keep `Gate` at 16 bytes.
    Multi2q(Box<Multi2qData>),

    /// Quantum Fourier Transform on `start..start+num`.
    ///
    /// The CPU statevector backend has a fast whole-state FFT path. Subrange
    /// blocks and non-native backends expand to textbook H, cphase, and swap
    /// gates before execution.
    /// Boxless: `(u8, u8)` fits within the 16-byte enum slot.
    ///
    /// `Gate::inverse()` panics on this variant; the block has to be expanded
    /// into elementary gates before it can be inverted.
    QftBlock { start: u8, num: u8 },
}
134
/// Data for a multi-controlled unitary gate.
///
/// The gate acts on `num_controls + 1` qubits: the controls followed by the
/// target (see `Gate::num_qubits`).
#[derive(Debug, Clone, PartialEq)]
pub struct McuData {
    /// 2×2 unitary applied to the target qubit.
    pub mat: [[Complex64; 2]; 2],
    /// Number of control qubits (≥ 2).
    // NOTE(review): the `≥ 2` bound is not enforced by `Gate::mcu`; callers
    // are presumably expected to use `Gate::cu` for a single control.
    pub num_controls: u8,
}
143
/// Data for a batched controlled-phase gate.
///
/// Multiple cphase gates sharing a control qubit are fused into one pass.
/// Each entry is `(target_qubit, phase)`. The control qubit is stored in the
/// instruction's `targets[0]`.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchPhaseData {
    /// Per-target phase factors as `(target_qubit, phase)` pairs. Inverting the
    /// gate conjugates every phase and keeps the targets unchanged.
    pub phases: SmallVec<[(usize, Complex64); 8]>,
}
153
/// Data for batched ZZ rotations.
///
/// Multiple Rzz gates batched into a single pass over the statevector.
/// Each entry is `(qubit_0, qubit_1, theta)`. All qubits are stored in the
/// instruction's `targets`.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchRzzData {
    /// Rotation edges as `(qubit_0, qubit_1, theta)`. Inverting the gate negates
    /// every `theta` and keeps the edge order.
    pub edges: Vec<(usize, usize, f64)>,
}
163
/// An individual diagonal phase contribution in a [`DiagonalBatchData`].
///
/// Every variant describes a diagonal operator, so inverting an entry amounts
/// to complex-conjugating its phase factors (see `Gate::inverse`).
#[derive(Debug, Clone, PartialEq)]
pub enum DiagEntry {
    /// Diagonal on a single qubit: `state[i] *= d0` when bit 0, `*= d1` when bit 1.
    Phase1q {
        /// Qubit whose bit selects between `d0` and `d1`.
        qubit: usize,
        /// Factor applied when the qubit's bit is 0.
        d0: Complex64,
        /// Factor applied when the qubit's bit is 1.
        d1: Complex64,
    },
    /// Phase on a qubit pair: `state[i] *= phase` when both bits are set (CZ/CPhase).
    Phase2q {
        /// First qubit of the pair.
        q0: usize,
        /// Second qubit of the pair.
        q1: usize,
        /// Factor applied when both bits are 1.
        phase: Complex64,
    },
    /// Parity-dependent phase (Rzz): `state[i] *= same` when parity is even,
    /// `state[i] *= diff` when parity is odd.
    Parity2q {
        /// First qubit of the pair.
        q0: usize,
        /// Second qubit of the pair.
        q1: usize,
        /// Factor applied when the two bits are equal (|00⟩, |11⟩).
        same: Complex64,
        /// Factor applied when the two bits differ (|01⟩, |10⟩).
        diff: Complex64,
    },
}
188
189impl DiagEntry {
190    pub fn as_1q_matrix(&self) -> Option<(usize, [[Complex64; 2]; 2])> {
191        match *self {
192            DiagEntry::Phase1q { qubit, d0, d1 } => {
193                let z = Complex64::new(0.0, 0.0);
194                Some((qubit, [[d0, z], [z, d1]]))
195            }
196            _ => None,
197        }
198    }
199
200    pub fn as_2q_matrix(&self) -> Option<(usize, usize, [[Complex64; 4]; 4])> {
201        let z = Complex64::new(0.0, 0.0);
202        let one = Complex64::new(1.0, 0.0);
203        match *self {
204            DiagEntry::Phase2q { q0, q1, phase } => Some((
205                q0,
206                q1,
207                [
208                    [one, z, z, z],
209                    [z, one, z, z],
210                    [z, z, one, z],
211                    [z, z, z, phase],
212                ],
213            )),
214            DiagEntry::Parity2q {
215                q0, q1, same, diff, ..
216            } => Some((
217                q0,
218                q1,
219                [
220                    [same, z, z, z],
221                    [z, diff, z, z],
222                    [z, z, diff, z],
223                    [z, z, z, same],
224                ],
225            )),
226            _ => None,
227        }
228    }
229}
230
/// Data for a batched diagonal gate pass.
///
/// A contiguous run of diagonal gates collapsed into a precomputed phase LUT.
/// The `entries` describe individual phase contributions; the kernel extracts
/// unique qubits, builds a LUT indexed by their bits, and applies in one sweep.
#[derive(Debug, Clone, PartialEq)]
pub struct DiagonalBatchData {
    /// Individual diagonal contributions. `Gate::num_qubits` counts the distinct
    /// qubits referenced across all entries.
    pub entries: Vec<DiagEntry>,
}
240
/// Data for multi-gate single-pass fusion.
///
/// Batches consecutive single-qubit gates on distinct qubits into one tiled
/// pass over the statevector. Each entry is `(target_qubit, 2×2 matrix)`.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiFusedData {
    /// `(target_qubit, matrix)` pairs; `Gate::num_qubits` reports the length of
    /// this list, relying on the targets being distinct.
    pub gates: Vec<(usize, [[Complex64; 2]; 2])>,
    /// Flag carried alongside the gate list and preserved unchanged by
    /// `Gate::inverse`.
    // NOTE(review): presumably set by the fusion pass when every matrix in
    // `gates` is diagonal — confirm against the pass that builds this struct.
    pub all_diagonal: bool,
}
250
/// Data for multi-2q tiled pass fusion.
///
/// Batches consecutive two-qubit gates into a single cache-tiled pass over the
/// statevector. Each entry is `(q0, q1, 4×4 matrix)`. Gate order is preserved.
#[derive(Debug, Clone, PartialEq)]
pub struct Multi2qData {
    /// `(q0, q1, matrix)` triples in application order. `Gate::inverse` reverses
    /// this list and takes the adjoint of every matrix.
    pub gates: Vec<(usize, usize, [[Complex64; 4]; 4])>,
}
259
260/// Kronecker product of two 2×2 matrices: A ⊗ B → 4×4.
261///
262/// Result indices: `(i*2+j, k*2+l) = A[i][k] * B[j][l]`
263/// where i,k index A (targets\[0\]) and j,l index B (targets\[1\]).
264#[inline]
265pub(crate) fn kron_2x2(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2]) -> [[Complex64; 4]; 4] {
266    let mut result = [[Complex64::new(0.0, 0.0); 4]; 4];
267    for i in 0..2 {
268        for k in 0..2 {
269            let aik = a[i][k];
270            for j in 0..2 {
271                for l in 0..2 {
272                    result[i * 2 + j][k * 2 + l] = aik * b[j][l];
273                }
274            }
275        }
276    }
277    result
278}
279
280/// Product of two 4×4 matrices: A · B.
281#[inline]
282pub(crate) fn mat_mul_4x4(a: &[[Complex64; 4]; 4], b: &[[Complex64; 4]; 4]) -> [[Complex64; 4]; 4] {
283    let zero = Complex64::new(0.0, 0.0);
284    let mut result = [[zero; 4]; 4];
285    for i in 0..4 {
286        for j in 0..4 {
287            let mut sum = zero;
288            for k in 0..4 {
289                sum += a[i][k] * b[k][j];
290            }
291            result[i][j] = sum;
292        }
293    }
294    result
295}
296
297/// Conjugate-transpose of a 4×4 matrix (U†).
298fn adjoint_4x4(m: &[[Complex64; 4]; 4]) -> [[Complex64; 4]; 4] {
299    let mut result = [[Complex64::new(0.0, 0.0); 4]; 4];
300    for i in 0..4 {
301        for j in 0..4 {
302            result[i][j] = m[j][i].conj();
303        }
304    }
305    result
306}
307
308/// Conjugate-transpose of a 2×2 matrix (U†).
309fn adjoint_2x2(m: &[[Complex64; 2]; 2]) -> [[Complex64; 2]; 2] {
310    [
311        [m[0][0].conj(), m[1][0].conj()],
312        [m[0][1].conj(), m[1][1].conj()],
313    ]
314}
315
316#[inline]
317fn count_unique_qubits<I: IntoIterator<Item = usize>>(iter: I) -> usize {
318    let mut seen: SmallVec<[usize; 8]> = SmallVec::new();
319    for q in iter {
320        if !seen.contains(&q) {
321            seen.push(q);
322        }
323    }
324    seen.len()
325}
326
327#[inline]
328fn push_unique_qubit(seen: &mut SmallVec<[usize; 8]>, qubit: usize) {
329    if !seen.contains(&qubit) {
330        seen.push(qubit);
331    }
332}
333
334#[inline]
335fn count_unique_diag_qubits(entries: &[DiagEntry]) -> usize {
336    let mut seen: SmallVec<[usize; 8]> = SmallVec::new();
337    for entry in entries {
338        match entry {
339            DiagEntry::Phase1q { qubit, .. } => push_unique_qubit(&mut seen, *qubit),
340            DiagEntry::Phase2q { q0, q1, .. } | DiagEntry::Parity2q { q0, q1, .. } => {
341                push_unique_qubit(&mut seen, *q0);
342                push_unique_qubit(&mut seen, *q1);
343            }
344        }
345    }
346    seen.len()
347}
348
349/// Product of two 2×2 matrices: A · B.
350#[inline]
351pub(crate) fn mat_mul_2x2(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2]) -> [[Complex64; 2]; 2] {
352    [
353        [
354            a[0][0] * b[0][0] + a[0][1] * b[1][0],
355            a[0][0] * b[0][1] + a[0][1] * b[1][1],
356        ],
357        [
358            a[1][0] * b[0][0] + a[1][1] * b[1][0],
359            a[1][0] * b[0][1] + a[1][1] * b[1][1],
360        ],
361    ]
362}
363
364impl Gate {
365    /// Number of qubits this gate acts on.
366    #[inline]
367    pub fn num_qubits(&self) -> usize {
368        match self {
369            Gate::Rzz(_) | Gate::Cx | Gate::Cz | Gate::Swap | Gate::Cu(_) | Gate::Fused2q(_) => 2,
370            Gate::Mcu(data) => data.num_controls as usize + 1,
371            Gate::BatchPhase(data) => 1 + data.phases.len(),
372            Gate::QftBlock { num, .. } => *num as usize,
373            Gate::BatchRzz(data) => {
374                count_unique_qubits(data.edges.iter().flat_map(|&(q0, q1, _)| [q0, q1]))
375            }
376            Gate::DiagonalBatch(data) => count_unique_diag_qubits(&data.entries),
377            Gate::MultiFused(data) => data.gates.len(),
378            Gate::Multi2q(data) => {
379                count_unique_qubits(data.gates.iter().flat_map(|&(q0, q1, _)| [q0, q1]))
380            }
381            _ => 1,
382        }
383    }
384
385    /// Returns the 2×2 unitary matrix for single-qubit gates.
386    ///
387    /// # Panics
388    /// Panics if called on a multi-qubit or batch gate (`Cx`, `Cz`, `Swap`,
389    /// `Cu`, `Mcu`, `BatchPhase`, `MultiFused`, `Fused2q`, `Multi2q`).
390    #[inline]
391    pub fn matrix_2x2(&self) -> [[Complex64; 2]; 2] {
392        let zero = Complex64::new(0.0, 0.0);
393        let one = Complex64::new(1.0, 0.0);
394        let i = Complex64::new(0.0, 1.0);
395        let neg_i = Complex64::new(0.0, -1.0);
396        let h = Complex64::new(FRAC_1_SQRT_2, 0.0);
397
398        match self {
399            Gate::Id => [[one, zero], [zero, one]],
400            Gate::X => [[zero, one], [one, zero]],
401            Gate::Y => [[zero, neg_i], [i, zero]],
402            Gate::Z => [[one, zero], [zero, -one]],
403            Gate::H => [[h, h], [h, -h]],
404            Gate::S => [[one, zero], [zero, i]],
405            Gate::Sdg => [[one, zero], [zero, neg_i]],
406            Gate::T => {
407                let phase = Complex64::from_polar(1.0, PI / 4.0);
408                [[one, zero], [zero, phase]]
409            }
410            Gate::Tdg => {
411                let phase = Complex64::from_polar(1.0, -PI / 4.0);
412                [[one, zero], [zero, phase]]
413            }
414            Gate::SX => {
415                let half = Complex64::new(0.5, 0.0);
416                let half_i = Complex64::new(0.0, 0.5);
417                [
418                    [half + half_i, half - half_i],
419                    [half - half_i, half + half_i],
420                ]
421            }
422            Gate::SXdg => {
423                let half = Complex64::new(0.5, 0.0);
424                let half_i = Complex64::new(0.0, 0.5);
425                [
426                    [half - half_i, half + half_i],
427                    [half + half_i, half - half_i],
428                ]
429            }
430            Gate::Rx(theta) => {
431                let c = Complex64::new((theta / 2.0).cos(), 0.0);
432                let s = Complex64::new(0.0, -(theta / 2.0).sin());
433                [[c, s], [s, c]]
434            }
435            Gate::Ry(theta) => {
436                let c = Complex64::new((theta / 2.0).cos(), 0.0);
437                let s = Complex64::new((theta / 2.0).sin(), 0.0);
438                [[c, -s], [s, c]]
439            }
440            Gate::Rz(theta) => {
441                let e_neg = Complex64::from_polar(1.0, -theta / 2.0);
442                let e_pos = Complex64::from_polar(1.0, theta / 2.0);
443                [[e_neg, zero], [zero, e_pos]]
444            }
445            Gate::P(theta) => {
446                let phase = Complex64::from_polar(1.0, *theta);
447                [[one, zero], [zero, phase]]
448            }
449            Gate::Fused(mat) => **mat,
450            Gate::Rzz(_)
451            | Gate::Cx
452            | Gate::Cz
453            | Gate::Swap
454            | Gate::Cu(_)
455            | Gate::Mcu(_)
456            | Gate::BatchPhase(_)
457            | Gate::QftBlock { .. }
458            | Gate::BatchRzz(_)
459            | Gate::DiagonalBatch(_)
460            | Gate::MultiFused(_)
461            | Gate::Fused2q(_)
462            | Gate::Multi2q(_) => {
463                panic!(
464                    "matrix_2x2 called on {}-qubit gate `{}`; use dedicated backend routine",
465                    self.num_qubits(),
466                    self.name()
467                )
468            }
469        }
470    }
471
472    /// Returns the 4×4 unitary matrix for two-qubit gates.
473    ///
474    /// Matrix indices follow the convention: row/col `i*2+j` where `i` indexes
475    /// `targets[0]` and `j` indexes `targets[1]`.
476    ///
477    /// # Panics
478    /// Panics on gates other than `Cx`, `Cz`, `Swap`, `Cu`, or `Fused2q`.
479    pub fn matrix_4x4(&self) -> [[Complex64; 4]; 4] {
480        let z = Complex64::new(0.0, 0.0);
481        let o = Complex64::new(1.0, 0.0);
482        let m = Complex64::new(-1.0, 0.0);
483        match self {
484            Gate::Rzz(theta) => {
485                let ps = Complex64::from_polar(1.0, -theta / 2.0);
486                let pd = Complex64::from_polar(1.0, theta / 2.0);
487                [[ps, z, z, z], [z, pd, z, z], [z, z, pd, z], [z, z, z, ps]]
488            }
489            Gate::Cx => [[o, z, z, z], [z, o, z, z], [z, z, z, o], [z, z, o, z]],
490            Gate::Cz => [[o, z, z, z], [z, o, z, z], [z, z, o, z], [z, z, z, m]],
491            Gate::Swap => [[o, z, z, z], [z, z, o, z], [z, o, z, z], [z, z, z, o]],
492            Gate::Cu(mat) => [
493                [o, z, z, z],
494                [z, o, z, z],
495                [z, z, mat[0][0], mat[0][1]],
496                [z, z, mat[1][0], mat[1][1]],
497            ],
498            Gate::Fused2q(mat) => **mat,
499            _ => panic!(
500                "matrix_4x4 called on non-standard-2q gate `{}`",
501                self.name()
502            ),
503        }
504    }
505
506    /// Human-readable gate name (for errors, logs, and OpenQASM round-tripping).
507    #[inline]
508    pub fn name(&self) -> &'static str {
509        match self {
510            Gate::Id => "id",
511            Gate::X => "x",
512            Gate::Y => "y",
513            Gate::Z => "z",
514            Gate::H => "h",
515            Gate::S => "s",
516            Gate::Sdg => "sdg",
517            Gate::T => "t",
518            Gate::Tdg => "tdg",
519            Gate::SX => "sx",
520            Gate::SXdg => "sxdg",
521            Gate::Rx(_) => "rx",
522            Gate::Ry(_) => "ry",
523            Gate::Rz(_) => "rz",
524            Gate::P(_) => "p",
525            Gate::Rzz(_) => "rzz",
526            Gate::Cx => "cx",
527            Gate::Cz => "cz",
528            Gate::Swap => "swap",
529            Gate::Cu(_) => "cu",
530            Gate::Mcu(_) => "mcu",
531            Gate::Fused(_) => "fused",
532            Gate::BatchPhase(_) => "batch_phase",
533            Gate::QftBlock { .. } => "qft_block",
534            Gate::BatchRzz(_) => "batch_rzz",
535            Gate::DiagonalBatch(_) => "diagonal_batch",
536            Gate::MultiFused(_) => "multi_fused",
537            Gate::Fused2q(_) => "fused_2q",
538            Gate::Multi2q(_) => "multi_2q",
539        }
540    }
541
542    /// Compute the inverse (adjoint) of this gate.
543    pub fn inverse(&self) -> Gate {
544        match self {
545            Gate::Id | Gate::X | Gate::Y | Gate::Z | Gate::H => self.clone(),
546            Gate::S => Gate::Sdg,
547            Gate::Sdg => Gate::S,
548            Gate::T => Gate::Tdg,
549            Gate::Tdg => Gate::T,
550            Gate::SX => Gate::SXdg,
551            Gate::SXdg => Gate::SX,
552            Gate::Rx(theta) => Gate::Rx(-theta),
553            Gate::Ry(theta) => Gate::Ry(-theta),
554            Gate::Rz(theta) => Gate::Rz(-theta),
555            Gate::P(theta) => Gate::P(-theta),
556            Gate::Rzz(theta) => Gate::Rzz(-theta),
557            Gate::Cx | Gate::Cz | Gate::Swap => self.clone(),
558            Gate::Cu(mat) => Gate::cu(adjoint_2x2(mat)),
559            Gate::Mcu(data) => Gate::mcu(adjoint_2x2(&data.mat), data.num_controls),
560            Gate::Fused(mat) => Gate::Fused(Box::new(adjoint_2x2(mat))),
561            Gate::BatchPhase(data) => Gate::BatchPhase(Box::new(BatchPhaseData {
562                phases: data.phases.iter().map(|&(q, p)| (q, p.conj())).collect(),
563            })),
564            Gate::QftBlock { .. } => {
565                panic!(
566                    "Gate::QftBlock has no in-place inverse. Run \
567                     circuit::expand_qft_blocks before applying `inv @` or any \
568                     transform that calls Gate::inverse()."
569                )
570            }
571            Gate::BatchRzz(data) => Gate::BatchRzz(Box::new(BatchRzzData {
572                edges: data
573                    .edges
574                    .iter()
575                    .map(|&(q0, q1, theta)| (q0, q1, -theta))
576                    .collect(),
577            })),
578            Gate::DiagonalBatch(data) => Gate::DiagonalBatch(Box::new(DiagonalBatchData {
579                entries: data
580                    .entries
581                    .iter()
582                    .map(|e| match e {
583                        DiagEntry::Phase1q { qubit, d0, d1 } => DiagEntry::Phase1q {
584                            qubit: *qubit,
585                            d0: d0.conj(),
586                            d1: d1.conj(),
587                        },
588                        DiagEntry::Phase2q { q0, q1, phase } => DiagEntry::Phase2q {
589                            q0: *q0,
590                            q1: *q1,
591                            phase: phase.conj(),
592                        },
593                        DiagEntry::Parity2q { q0, q1, same, diff } => DiagEntry::Parity2q {
594                            q0: *q0,
595                            q1: *q1,
596                            same: same.conj(),
597                            diff: diff.conj(),
598                        },
599                    })
600                    .collect(),
601            })),
602            Gate::MultiFused(data) => Gate::MultiFused(Box::new(MultiFusedData {
603                gates: data
604                    .gates
605                    .iter()
606                    .map(|&(target, mat)| (target, adjoint_2x2(&mat)))
607                    .collect(),
608                all_diagonal: data.all_diagonal,
609            })),
610            Gate::Fused2q(mat) => Gate::Fused2q(Box::new(adjoint_4x4(mat))),
611            Gate::Multi2q(data) => Gate::Multi2q(Box::new(Multi2qData {
612                gates: data
613                    .gates
614                    .iter()
615                    .rev()
616                    .map(|&(q0, q1, ref mat)| (q0, q1, adjoint_4x4(mat)))
617                    .collect(),
618            })),
619        }
620    }
621
622    /// Compute integer power of a single-qubit gate.
623    ///
624    /// Returns the gate raised to the `k`-th power. Negative `k` inverts first.
625    /// Only valid for single-qubit gates.
626    pub fn matrix_power(&self, k: i64) -> Gate {
627        debug_assert_eq!(
628            self.num_qubits(),
629            1,
630            "matrix_power only for single-qubit gates"
631        );
632        if k == 0 {
633            return Gate::Id;
634        }
635        if k == 1 {
636            return self.clone();
637        }
638        let base = if k < 0 { self.inverse() } else { self.clone() };
639        let n = k.unsigned_abs() as usize;
640        if n == 1 {
641            return base;
642        }
643        let base_mat = base.matrix_2x2();
644        let mut acc = base_mat;
645        for _ in 1..n {
646            acc = mat_mul_2x2(&base_mat, &acc);
647        }
648        Gate::Fused(Box::new(acc))
649    }
650
651    /// Create a single-controlled unitary gate with the given 2x2 matrix.
652    pub fn cu(mat: [[Complex64; 2]; 2]) -> Gate {
653        Gate::Cu(Box::new(mat))
654    }
655
656    /// Create a multi-controlled unitary gate with `num_controls` control qubits.
657    pub fn mcu(mat: [[Complex64; 2]; 2], num_controls: u8) -> Gate {
658        Gate::Mcu(Box::new(McuData { mat, num_controls }))
659    }
660
661    /// Create a controlled-phase gate CPhase(θ) = Cu(\[\[1,0\],\[0,e^{iθ}\]\]).
662    ///
663    /// Applies phase e^{iθ} to |11⟩ and identity to all other basis states.
664    pub fn cphase(theta: f64) -> Gate {
665        let one = Complex64::new(1.0, 0.0);
666        let zero = Complex64::new(0.0, 0.0);
667        let phase = Complex64::from_polar(1.0, theta);
668        Gate::cu([[one, zero], [zero, phase]])
669    }
670
671    /// Returns the phase if this is a controlled-phase gate (Cu/Mcu with
672    /// diagonal matrix `[[1,0],[0,e^{iθ}]]`).
673    ///
674    /// Used by backends to dispatch to optimized phase-only kernels that
675    /// touch half the memory of the generic controlled-unitary kernel.
676    #[inline]
677    pub fn controlled_phase(&self) -> Option<Complex64> {
678        let mat = match self {
679            Gate::Cu(mat) => &**mat,
680            Gate::Mcu(data) => &data.mat,
681            _ => return None,
682        };
683        if (mat[0][0].re - 1.0).abs() < IDENTITY_EPS
684            && mat[0][0].im.abs() < IDENTITY_EPS
685            && mat[0][1].norm() < IDENTITY_EPS
686            && mat[1][0].norm() < IDENTITY_EPS
687            && (mat[1][1].norm() - 1.0).abs() < IDENTITY_EPS
688        {
689            Some(mat[1][1])
690        } else {
691            None
692        }
693    }
694
695    /// True if this is a diagonal single-qubit gate (matrix is `[[a,0],[0,b]]`).
696    ///
697    /// Diagonal gates commute with CX on the control qubit and with CZ on
698    /// either qubit. Used by the commutation-aware reordering pass.
699    #[inline]
700    pub fn is_diagonal_1q(&self) -> bool {
701        match self {
702            Gate::Id
703            | Gate::Z
704            | Gate::S
705            | Gate::Sdg
706            | Gate::T
707            | Gate::Tdg
708            | Gate::Rz(_)
709            | Gate::P(_) => true,
710            Gate::Fused(m) => m[0][1].norm() < IDENTITY_EPS && m[1][0].norm() < IDENTITY_EPS,
711            _ => false,
712        }
713    }
714
715    /// True if this is a self-inverse two-qubit gate (applying it twice = identity).
716    #[inline]
717    pub fn is_self_inverse_2q(&self) -> bool {
718        matches!(self, Gate::Cx | Gate::Cz | Gate::Swap)
719    }
720
721    /// True if this gate maps computational basis states to computational basis
722    /// states (with at most a phase). Such gates preserve the number of non-zero
723    /// amplitudes, making the sparse backend optimal (O(1) memory for |0...0⟩).
724    ///
725    /// Includes diagonal gates (Z, S, T, Rz, P, CZ) and permutation gates
726    /// (X, Y, CX, SWAP). Excludes superposition-creating gates (H, Rx, Ry, SX).
727    #[inline]
728    pub fn preserves_sparsity(&self) -> bool {
729        match self {
730            Gate::Id | Gate::X | Gate::Y | Gate::Z => true,
731            Gate::S | Gate::Sdg | Gate::T | Gate::Tdg => true,
732            Gate::Rz(_) | Gate::P(_) => true,
733            Gate::Rzz(_) | Gate::Cx | Gate::Cz | Gate::Swap => true,
734            Gate::Cu(mat) | Gate::Fused(mat) => {
735                let is_diag = mat[0][1].norm_sqr() < NEAR_ZERO_NORM_SQ
736                    && mat[1][0].norm_sqr() < NEAR_ZERO_NORM_SQ;
737                let is_antidiag = mat[0][0].norm_sqr() < NEAR_ZERO_NORM_SQ
738                    && mat[1][1].norm_sqr() < NEAR_ZERO_NORM_SQ;
739                is_diag || is_antidiag
740            }
741            Gate::Mcu(data) => {
742                let m = &data.mat;
743                let is_diag = m[0][1].norm_sqr() < NEAR_ZERO_NORM_SQ
744                    && m[1][0].norm_sqr() < NEAR_ZERO_NORM_SQ;
745                let is_antidiag = m[0][0].norm_sqr() < NEAR_ZERO_NORM_SQ
746                    && m[1][1].norm_sqr() < NEAR_ZERO_NORM_SQ;
747                is_diag || is_antidiag
748            }
749            Gate::BatchPhase(_) | Gate::BatchRzz(_) | Gate::DiagonalBatch(_) => true,
750            _ => false,
751        }
752    }
753
754    /// Try to recognize a 2x2 unitary matrix as a named gate (up to global phase).
755    ///
756    /// Used by the fusion pass to emit named gate variants instead of opaque
757    /// `Gate::Fused` matrices, enabling downstream passes (e.g. `clifford_prefix_split`)
758    /// to identify Clifford gates that arose from fusion (e.g. T·T → S).
759    pub fn recognize_matrix(mat: &[[Complex64; 2]; 2]) -> Option<Gate> {
760        const EPS: f64 = 1e-10;
761
762        // Check each candidate gate. For each, compute the global phase ratio
763        // mat[i][j] / ref[i][j] using the first non-zero entry, then verify
764        // all other entries match under that same phase.
765        let candidates: &[Gate] = &[
766            Gate::H,
767            Gate::X,
768            Gate::Y,
769            Gate::Z,
770            Gate::S,
771            Gate::Sdg,
772            Gate::T,
773            Gate::Tdg,
774            Gate::SX,
775            Gate::SXdg,
776        ];
777
778        for candidate in candidates {
779            let ref_mat = candidate.matrix_2x2();
780            if matrices_equal_up_to_phase(mat, &ref_mat, EPS) {
781                return Some(candidate.clone());
782            }
783        }
784
785        // Identity check: all off-diagonal zero, diagonal entries equal
786        if mat[0][1].norm_sqr() < EPS
787            && mat[1][0].norm_sqr() < EPS
788            && (mat[0][0] - mat[1][1]).norm_sqr() < EPS
789            && mat[0][0].norm_sqr() > EPS
790        {
791            return Some(Gate::Id);
792        }
793
794        None
795    }
796
797    /// True if this gate is a Clifford gate (relevant for stabilizer backend).
798    #[inline]
799    pub fn is_clifford(&self) -> bool {
800        matches!(
801            self,
802            Gate::Id
803                | Gate::X
804                | Gate::Y
805                | Gate::Z
806                | Gate::H
807                | Gate::S
808                | Gate::Sdg
809                | Gate::SX
810                | Gate::SXdg
811                | Gate::Cx
812                | Gate::Cz
813                | Gate::Swap
814        )
815    }
816}
817
818/// Check if two 2x2 unitary matrices are equal up to a global phase factor.
819fn matrices_equal_up_to_phase(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2], eps: f64) -> bool {
820    // Find the first non-zero entry in b to determine the phase ratio
821    let mut phase = None;
822    for i in 0..2 {
823        for j in 0..2 {
824            if b[i][j].norm_sqr() > eps {
825                if a[i][j].norm_sqr() < eps {
826                    return false;
827                }
828                phase = Some(a[i][j] / b[i][j]);
829                break;
830            }
831        }
832        if phase.is_some() {
833            break;
834        }
835    }
836
837    let phase = match phase {
838        Some(p) => p,
839        None => return true, // Both are zero matrices
840    };
841
842    // Verify all entries match under the same phase
843    for i in 0..2 {
844        for j in 0..2 {
845            let expected = phase * b[i][j];
846            if (a[i][j] - expected).norm_sqr() > eps {
847                return false;
848            }
849        }
850    }
851    true
852}
853
/// Render an angle given in radians as a short label.
///
/// When `theta / π` lies within `1e-10` of one of the tabulated multiples, the
/// symbolic form is returned (for example `π/2` or `-3π/8`). Any other value is
/// printed as the raw radian value with four decimal places.
fn format_angle(theta: f64) -> String {
    // (multiple of π, label) pairs that get a symbolic rendering.
    const PI_MULTIPLES: &[(f64, &str)] = &[
        (1.0, "π"),
        (-1.0, "-π"),
        (0.5, "π/2"),
        (-0.5, "-π/2"),
        (0.25, "π/4"),
        (-0.25, "-π/4"),
        (1.0 / 3.0, "π/3"),
        (-1.0 / 3.0, "-π/3"),
        (2.0 / 3.0, "2π/3"),
        (-2.0 / 3.0, "-2π/3"),
        (1.0 / 6.0, "π/6"),
        (-1.0 / 6.0, "-π/6"),
        (5.0 / 6.0, "5π/6"),
        (-5.0 / 6.0, "-5π/6"),
        (1.0 / 8.0, "π/8"),
        (-1.0 / 8.0, "-π/8"),
        (3.0 / 8.0, "3π/8"),
        (-3.0 / 8.0, "-3π/8"),
        (1.5, "3π/2"),
        (-1.5, "-3π/2"),
        (2.0, "2π"),
        (-2.0, "-2π"),
    ];
    // Maximum distance, measured in units of π, for a symbolic match.
    const TOLERANCE: f64 = 1e-10;

    let in_units_of_pi = theta / std::f64::consts::PI;
    PI_MULTIPLES
        .iter()
        .find(|(multiple, _)| (in_units_of_pi - *multiple).abs() < TOLERANCE)
        .map_or_else(
            || format!("{:.4}", theta),
            |(_, label)| (*label).to_owned(),
        )
}
887
888impl fmt::Display for Gate {
889    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
890        match self {
891            Gate::Id => f.write_str("I"),
892            Gate::X => f.write_str("X"),
893            Gate::Y => f.write_str("Y"),
894            Gate::Z => f.write_str("Z"),
895            Gate::H => f.write_str("H"),
896            Gate::S => f.write_str("S"),
897            Gate::Sdg => f.write_str("Sdg"),
898            Gate::T => f.write_str("T"),
899            Gate::Tdg => f.write_str("Tdg"),
900            Gate::SX => f.write_str("SX"),
901            Gate::SXdg => f.write_str("SXdg"),
902            Gate::Rx(t) => write!(f, "Rx({})", format_angle(*t)),
903            Gate::Ry(t) => write!(f, "Ry({})", format_angle(*t)),
904            Gate::Rz(t) => write!(f, "Rz({})", format_angle(*t)),
905            Gate::P(t) => write!(f, "P({})", format_angle(*t)),
906            Gate::Rzz(t) => write!(f, "Rzz({})", format_angle(*t)),
907            Gate::Cx => f.write_str("CX"),
908            Gate::Cz => f.write_str("CZ"),
909            Gate::Swap => f.write_str("SWAP"),
910            Gate::Cu(_) => f.write_str("CU"),
911            Gate::Mcu(data) => write!(f, "MCU({}ctrl)", data.num_controls),
912            Gate::Fused(_) => f.write_str("U"),
913            Gate::Fused2q(_) => f.write_str("U2"),
914            Gate::MultiFused(data) => write!(f, "MF[{}]", data.gates.len()),
915            Gate::BatchPhase(data) => write!(f, "BP[{}]", data.phases.len()),
916            Gate::QftBlock { start, num } => write!(f, "QFT[{}..{}]", start, start + num),
917            Gate::BatchRzz(data) => write!(f, "BZZ[{}]", data.edges.len()),
918            Gate::DiagonalBatch(data) => write!(f, "BD[{}]", data.entries.len()),
919            Gate::Multi2q(data) => write!(f, "M2[{}]", data.gates.len()),
920        }
921    }
922}
923
#[cfg(test)]
mod tests {
    //! Unit tests for gate labels, arity, classification, inversion, matrix powers and
    //! matrix recognition.
    use super::*;

    /// Tabulated multiples of π print symbolically; other angles fall back to 4 decimals.
    #[test]
    fn format_angle_pi_fractions() {
        assert_eq!(format_angle(std::f64::consts::PI), "π");
        assert_eq!(format_angle(std::f64::consts::FRAC_PI_2), "π/2");
        assert_eq!(format_angle(std::f64::consts::FRAC_PI_4), "π/4");
        assert_eq!(format_angle(-std::f64::consts::FRAC_PI_4), "-π/4");
        assert_eq!(format_angle(std::f64::consts::PI / 3.0), "π/3");
        assert_eq!(format_angle(0.123), "0.1230");
    }

    /// `Display` output for a sample of fixed and parametric gates.
    #[test]
    fn display_labels() {
        assert_eq!(Gate::H.to_string(), "H");
        assert_eq!(Gate::Cx.to_string(), "CX");
        assert_eq!(Gate::Rx(std::f64::consts::FRAC_PI_2).to_string(), "Rx(π/2)");
        assert_eq!(Gate::Rz(0.5).to_string(), "Rz(0.5000)");
        assert_eq!(Gate::Id.to_string(), "I");
        assert_eq!(Gate::Swap.to_string(), "SWAP");
    }

    /// Single-qubit gates report arity 1, two-qubit gates arity 2.
    #[test]
    fn test_gate_arity() {
        assert_eq!(Gate::H.num_qubits(), 1);
        assert_eq!(Gate::Rx(0.5).num_qubits(), 1);
        assert_eq!(Gate::Cx.num_qubits(), 2);
        assert_eq!(Gate::Swap.num_qubits(), 2);
    }

    /// Batched gates must count each distinct qubit once, including indices >= 64.
    #[test]
    fn batch_gate_arity_counts_qubits_above_word_boundary() {
        let one = Complex64::new(1.0, 0.0);
        // Distinct qubits: 0, 64, 129.
        let batch_rzz = Gate::BatchRzz(Box::new(BatchRzzData {
            edges: vec![(0, 64, 0.25), (64, 129, 0.5)],
        }));
        assert_eq!(batch_rzz.num_qubits(), 3);

        // Distinct qubits: 64, 130.
        let diagonal_batch = Gate::DiagonalBatch(Box::new(DiagonalBatchData {
            entries: vec![
                DiagEntry::Phase1q {
                    qubit: 64,
                    d0: one,
                    d1: -one,
                },
                DiagEntry::Phase2q {
                    q0: 64,
                    q1: 130,
                    phase: -one,
                },
            ],
        }));
        assert_eq!(diagonal_batch.num_qubits(), 2);

        // Distinct qubits: 63, 64, 130.
        let multi_2q = Gate::Multi2q(Box::new(Multi2qData {
            gates: vec![
                (63, 64, Gate::Cx.matrix_4x4()),
                (64, 130, Gate::Cz.matrix_4x4()),
            ],
        }));
        assert_eq!(multi_2q.num_qubits(), 3);
    }

    /// Checks `H * H = I` on every entry, real and imaginary parts alike.
    #[test]
    fn test_h_matrix_is_unitary() {
        let m = Gate::H.matrix_2x2();
        let product = mat_mul_2x2(&m, &m);
        assert_mat_close(&product, &Gate::Id.matrix_2x2(), 1e-12);
    }

    /// `Rx(π)` must be X up to a global phase.
    #[test]
    fn test_rx_pi_equals_neg_i_x() {
        let rx = Gate::Rx(std::f64::consts::PI).matrix_2x2();
        // Rx(π) = -i·X  (up to global phase)
        // |Rx(π)[0][1]| should be 1
        assert!((rx[0][1].norm() - 1.0).abs() < 1e-12);
        assert!((rx[1][0].norm() - 1.0).abs() < 1e-12);
        assert!(rx[0][0].norm() < 1e-12);
        assert!(rx[1][1].norm() < 1e-12);
        // Magnitudes alone cannot tell X from Y (whose off-diagonals have opposite
        // signs), so also require proportionality to X.
        assert!(matrices_equal_up_to_phase(
            &rx,
            &Gate::X.matrix_2x2(),
            1e-12
        ));
    }

    /// Clifford membership for a sample of Clifford and non-Clifford gates.
    #[test]
    fn test_clifford_classification() {
        assert!(Gate::H.is_clifford());
        assert!(Gate::S.is_clifford());
        assert!(Gate::Cx.is_clifford());
        assert!(!Gate::T.is_clifford());
        assert!(!Gate::Rx(0.5).is_clifford());
        assert!(!Gate::Cu(Box::new([[Complex64::new(1.0, 0.0); 2]; 2])).is_clifford());
    }

    /// Diagonal/permutation gates preserve sparsity; superposition-creating gates do not.
    #[test]
    fn test_preserves_sparsity() {
        // Diagonal and permutation gates preserve sparsity
        assert!(Gate::Id.preserves_sparsity());
        assert!(Gate::X.preserves_sparsity());
        assert!(Gate::Y.preserves_sparsity());
        assert!(Gate::Z.preserves_sparsity());
        assert!(Gate::S.preserves_sparsity());
        assert!(Gate::T.preserves_sparsity());
        assert!(Gate::Rz(1.0).preserves_sparsity());
        assert!(Gate::P(0.5).preserves_sparsity());
        assert!(Gate::Cx.preserves_sparsity());
        assert!(Gate::Cz.preserves_sparsity());
        assert!(Gate::Swap.preserves_sparsity());

        // Superposition-creating gates do NOT preserve sparsity
        assert!(!Gate::H.preserves_sparsity());
        assert!(!Gate::Rx(0.5).preserves_sparsity());
        assert!(!Gate::Ry(0.5).preserves_sparsity());
        assert!(!Gate::SX.preserves_sparsity());
        assert!(!Gate::SXdg.preserves_sparsity());

        // Cu with diagonal matrix preserves sparsity
        let diag = Box::new([
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(0.0, 1.0)],
        ]);
        assert!(Gate::Cu(diag).preserves_sparsity());

        // Cu with H-like matrix does NOT preserve sparsity
        let h_mat = Box::new(Gate::H.matrix_2x2());
        assert!(!Gate::Cu(h_mat).preserves_sparsity());
    }

    /// A controlled-U gate acts on two qubits.
    #[test]
    fn test_cu_arity() {
        let mat = Gate::H.matrix_2x2();
        assert_eq!(Gate::Cu(Box::new(mat)).num_qubits(), 2);
    }

    /// Asserts that every entry of `a` is within `eps` (in magnitude) of the matching
    /// entry of `b`; on failure the message names the offending entry.
    fn assert_mat_close(a: &[[Complex64; 2]; 2], b: &[[Complex64; 2]; 2], eps: f64) {
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (a[i][j] - b[i][j]).norm() < eps,
                    "mat[{i}][{j}]: expected {:?}, got {:?}",
                    b[i][j],
                    a[i][j]
                );
            }
        }
    }

    /// Self-inverse gates return themselves from `inverse()`.
    #[test]
    fn test_inverse_self_inverse() {
        assert_eq!(Gate::H.inverse(), Gate::H);
        assert_eq!(Gate::X.inverse(), Gate::X);
        assert_eq!(Gate::Y.inverse(), Gate::Y);
        assert_eq!(Gate::Z.inverse(), Gate::Z);
        assert_eq!(Gate::Id.inverse(), Gate::Id);
        assert_eq!(Gate::Cx.inverse(), Gate::Cx);
        assert_eq!(Gate::Cz.inverse(), Gate::Cz);
        assert_eq!(Gate::Swap.inverse(), Gate::Swap);
    }

    /// `S`/`Sdg` and `T`/`Tdg` invert to each other.
    #[test]
    fn test_inverse_adjoint_pairs() {
        assert_eq!(Gate::S.inverse(), Gate::Sdg);
        assert_eq!(Gate::Sdg.inverse(), Gate::S);
        assert_eq!(Gate::T.inverse(), Gate::Tdg);
        assert_eq!(Gate::Tdg.inverse(), Gate::T);
    }

    /// Rotations invert by negating the angle.
    #[test]
    fn test_inverse_parametric() {
        assert_eq!(Gate::Rx(0.5).inverse(), Gate::Rx(-0.5));
        assert_eq!(Gate::Ry(1.0).inverse(), Gate::Ry(-1.0));
        assert_eq!(Gate::Rz(PI).inverse(), Gate::Rz(-PI));
    }

    /// Inverting a fused S matrix stays `Fused` and yields the Sdg matrix.
    #[test]
    fn test_inverse_fused_is_adjoint() {
        let s_mat = Gate::S.matrix_2x2();
        let fused = Gate::Fused(Box::new(s_mat));
        let inv = fused.inverse();
        if let Gate::Fused(inv_mat) = &inv {
            assert_mat_close(inv_mat, &Gate::Sdg.matrix_2x2(), 1e-12);
        } else {
            panic!("expected Fused");
        }
    }

    /// Inverting `Cu(Rz(θ))` stays `Cu` and carries the `Rz(-θ)` matrix.
    #[test]
    fn test_inverse_cu() {
        let rz_mat = Gate::Rz(0.5).matrix_2x2();
        let cu = Gate::Cu(Box::new(rz_mat));
        let inv = cu.inverse();
        if let Gate::Cu(inv_mat) = &inv {
            let expected = Gate::Rz(-0.5).matrix_2x2();
            assert_mat_close(inv_mat, &expected, 1e-12);
        } else {
            panic!("expected Cu");
        }
    }

    /// Power 0 is the identity gate.
    #[test]
    fn test_matrix_power_zero() {
        assert_eq!(Gate::X.matrix_power(0), Gate::Id);
        assert_eq!(Gate::Rz(0.5).matrix_power(0), Gate::Id);
    }

    /// Power 1 returns the gate unchanged.
    #[test]
    fn test_matrix_power_one() {
        assert_eq!(Gate::X.matrix_power(1), Gate::X);
        assert_eq!(Gate::H.matrix_power(1), Gate::H);
    }

    /// `X²` is returned as a `Fused` gate holding the identity matrix.
    #[test]
    fn test_matrix_power_x_squared() {
        let x2 = Gate::X.matrix_power(2);
        if let Gate::Fused(mat) = &x2 {
            assert_mat_close(mat, &Gate::Id.matrix_2x2(), 1e-12);
        } else {
            panic!("expected Fused");
        }
    }

    /// `T²` is returned as a `Fused` gate holding the S matrix.
    #[test]
    fn test_matrix_power_t_squared_is_s() {
        let t2 = Gate::T.matrix_power(2);
        if let Gate::Fused(mat) = &t2 {
            assert_mat_close(mat, &Gate::S.matrix_2x2(), 1e-12);
        } else {
            panic!("expected Fused");
        }
    }

    /// Negative powers invert first: `T⁻²` holds the Sdg matrix.
    #[test]
    fn test_matrix_power_negative() {
        let t_inv2 = Gate::T.matrix_power(-2);
        if let Gate::Fused(mat) = &t_inv2 {
            assert_mat_close(mat, &Gate::Sdg.matrix_2x2(), 1e-12);
        } else {
            panic!("expected Fused");
        }
    }

    /// Multi-controlled gates act on `num_controls + 1` qubits.
    #[test]
    fn test_mcu_arity() {
        let mat = Gate::H.matrix_2x2();
        let mcu2 = Gate::Mcu(Box::new(McuData {
            mat,
            num_controls: 2,
        }));
        assert_eq!(mcu2.num_qubits(), 3);
        let mcu3 = Gate::Mcu(Box::new(McuData {
            mat,
            num_controls: 3,
        }));
        assert_eq!(mcu3.num_qubits(), 4);
    }

    /// `Mcu` is never classified as Clifford, even with an X target matrix.
    #[test]
    fn test_mcu_not_clifford() {
        let mat = Gate::X.matrix_2x2();
        let mcu = Gate::Mcu(Box::new(McuData {
            mat,
            num_controls: 2,
        }));
        assert!(!mcu.is_clifford());
    }

    /// Inverting an `Mcu` inverts the target matrix and keeps the control count.
    #[test]
    fn test_mcu_inverse() {
        let rz_mat = Gate::Rz(0.5).matrix_2x2();
        let mcu = Gate::Mcu(Box::new(McuData {
            mat: rz_mat,
            num_controls: 2,
        }));
        let inv = mcu.inverse();
        if let Gate::Mcu(inv_data) = &inv {
            let expected = Gate::Rz(-0.5).matrix_2x2();
            assert_mat_close(&inv_data.mat, &expected, 1e-12);
            assert_eq!(inv_data.num_controls, 2);
        } else {
            panic!("expected Mcu");
        }
    }

    /// `name()` of a multi-controlled gate is `"mcu"`.
    #[test]
    fn test_mcu_name() {
        let mat = Gate::H.matrix_2x2();
        let mcu = Gate::Mcu(Box::new(McuData {
            mat,
            num_controls: 2,
        }));
        assert_eq!(mcu.name(), "mcu");
    }

    /// `cphase(θ)` builds a `Cu` whose target matrix is `diag(1, e^{iθ})`.
    #[test]
    fn test_cphase_constructor() {
        let g = Gate::cphase(PI / 4.0);
        assert_eq!(g.num_qubits(), 2);
        assert_eq!(g.name(), "cu");
        if let Gate::Cu(mat) = &g {
            let one = Complex64::new(1.0, 0.0);
            assert!((mat[0][0] - one).norm() < 1e-14);
            assert!(mat[0][1].norm() < 1e-14);
            assert!(mat[1][0].norm() < 1e-14);
            let expected = Complex64::from_polar(1.0, PI / 4.0);
            assert!((mat[1][1] - expected).norm() < 1e-14);
        } else {
            panic!("expected Cu");
        }
    }

    /// `controlled_phase()` detects `Cu(diag(1, phase))` and rejects everything else.
    #[test]
    fn test_controlled_phase_detection() {
        let cp = Gate::cphase(0.5);
        assert!(cp.controlled_phase().is_some());
        let phase = cp.controlled_phase().unwrap();
        let expected = Complex64::from_polar(1.0, 0.5);
        assert!((phase - expected).norm() < 1e-14);

        // Non-diagonal Cu should not be detected
        let h_mat = Gate::H.matrix_2x2();
        let cu_h = Gate::Cu(Box::new(h_mat));
        assert!(cu_h.controlled_phase().is_none());

        // CZ is Cu([[1,0],[0,-1]]), should be detected (phase = -1)
        let z_mat = Gate::Z.matrix_2x2();
        let cu_z = Gate::Cu(Box::new(z_mat));
        assert!(cu_z.controlled_phase().is_some());
        let z_phase = cu_z.controlled_phase().unwrap();
        assert!((z_phase.re - (-1.0)).abs() < 1e-14);

        // Rz-based Cu is diagonal but mat[0][0] != 1, should NOT be detected
        let rz_mat = Gate::Rz(0.5).matrix_2x2();
        let cu_rz = Gate::Cu(Box::new(rz_mat));
        assert!(cu_rz.controlled_phase().is_none());

        // Non-Cu gates should return None
        assert!(Gate::H.controlled_phase().is_none());
        assert!(Gate::Cx.controlled_phase().is_none());
    }

    /// `controlled_phase()` also recognizes `Mcu` with a `diag(1, phase)` target.
    #[test]
    fn test_controlled_phase_mcu() {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let phase = Complex64::from_polar(1.0, 0.7);
        let mcu = Gate::Mcu(Box::new(McuData {
            mat: [[one, zero], [zero, phase]],
            num_controls: 2,
        }));
        assert!(mcu.controlled_phase().is_some());
        assert!((mcu.controlled_phase().unwrap() - phase).norm() < 1e-14);
    }

    /// `SX * SX = X`.
    #[test]
    fn test_sx_matrix_is_sqrt_x() {
        let sx = Gate::SX.matrix_2x2();
        let sx2 = mat_mul_2x2(&sx, &sx);
        assert_mat_close(&sx2, &Gate::X.matrix_2x2(), 1e-12);
    }

    /// `SX * SXdg = I`.
    #[test]
    fn test_sxdg_is_sx_inverse() {
        let sx = Gate::SX.matrix_2x2();
        let sxdg = Gate::SXdg.matrix_2x2();
        let product = mat_mul_2x2(&sx, &sxdg);
        assert_mat_close(&product, &Gate::Id.matrix_2x2(), 1e-12);
    }

    /// `P(π/4)` has the same matrix as `T`.
    #[test]
    fn test_p_gate_matrix() {
        let p = Gate::P(PI / 4.0).matrix_2x2();
        let t = Gate::T.matrix_2x2();
        assert_mat_close(&p, &t, 1e-12);
    }

    /// Both √X variants are Clifford.
    #[test]
    fn test_sx_is_clifford() {
        assert!(Gate::SX.is_clifford());
        assert!(Gate::SXdg.is_clifford());
    }

    /// The phase gate inverts by negating its angle.
    #[test]
    fn test_p_inverse() {
        assert_eq!(Gate::P(0.5).inverse(), Gate::P(-0.5));
    }

    /// `SX` and `SXdg` invert to each other.
    #[test]
    fn test_sx_inverse_pair() {
        assert_eq!(Gate::SX.inverse(), Gate::SXdg);
        assert_eq!(Gate::SXdg.inverse(), Gate::SX);
    }

    /// Diagonal detection for named gates and for `Fused` matrices.
    #[test]
    fn test_is_diagonal_1q() {
        assert!(Gate::Id.is_diagonal_1q());
        assert!(Gate::Z.is_diagonal_1q());
        assert!(Gate::S.is_diagonal_1q());
        assert!(Gate::Sdg.is_diagonal_1q());
        assert!(Gate::T.is_diagonal_1q());
        assert!(Gate::Tdg.is_diagonal_1q());
        assert!(Gate::Rz(0.5).is_diagonal_1q());
        assert!(Gate::P(0.5).is_diagonal_1q());
        assert!(!Gate::H.is_diagonal_1q());
        assert!(!Gate::X.is_diagonal_1q());
        assert!(!Gate::Y.is_diagonal_1q());
        assert!(!Gate::Rx(0.5).is_diagonal_1q());
        assert!(!Gate::Ry(0.5).is_diagonal_1q());
        assert!(!Gate::SX.is_diagonal_1q());
        assert!(!Gate::Cx.is_diagonal_1q());

        let diag_fused = Gate::Fused(Box::new(Gate::T.matrix_2x2()));
        assert!(diag_fused.is_diagonal_1q());
        let nondiag_fused = Gate::Fused(Box::new(Gate::H.matrix_2x2()));
        assert!(!nondiag_fused.is_diagonal_1q());
    }

    /// Only `Cx`, `Cz` and `Swap` are reported as self-inverse two-qubit gates here.
    #[test]
    fn test_is_self_inverse_2q() {
        assert!(Gate::Cx.is_self_inverse_2q());
        assert!(Gate::Cz.is_self_inverse_2q());
        assert!(Gate::Swap.is_self_inverse_2q());
        assert!(!Gate::H.is_self_inverse_2q());
        assert!(!Gate::T.is_self_inverse_2q());
        let mat = Gate::H.matrix_2x2();
        assert!(!Gate::Cu(Box::new(mat)).is_self_inverse_2q());
    }

    /// Guards the 16-byte enum layout the module documentation relies on.
    #[test]
    fn test_gate_enum_size() {
        assert_eq!(
            std::mem::size_of::<Gate>(),
            16,
            "Gate enum must stay at 16 bytes"
        );
    }

    /// Each named single-qubit gate is recognized from its own matrix.
    #[test]
    fn test_recognize_named_gates() {
        for gate in &[
            Gate::H,
            Gate::X,
            Gate::Y,
            Gate::Z,
            Gate::S,
            Gate::Sdg,
            Gate::T,
            Gate::Tdg,
            Gate::SX,
            Gate::SXdg,
        ] {
            let mat = gate.matrix_2x2();
            let recognized = Gate::recognize_matrix(&mat);
            assert_eq!(
                recognized.as_ref(),
                Some(gate),
                "failed to recognize {:?}",
                gate.name()
            );
        }
    }

    /// The identity matrix is recognized as `Id`.
    #[test]
    fn test_recognize_identity() {
        let id = Gate::Id.matrix_2x2();
        assert_eq!(Gate::recognize_matrix(&id), Some(Gate::Id));
    }

    /// `T * T` is recognized as `S`.
    #[test]
    fn test_recognize_t_squared_is_s() {
        let t = Gate::T.matrix_2x2();
        let tt = mat_mul_2x2(&t, &t);
        assert_eq!(Gate::recognize_matrix(&tt), Some(Gate::S));
    }

    /// `S * S` is recognized as `Z`.
    #[test]
    fn test_recognize_s_squared_is_z() {
        let s = Gate::S.matrix_2x2();
        let ss = mat_mul_2x2(&s, &s);
        assert_eq!(Gate::recognize_matrix(&ss), Some(Gate::Z));
    }

    /// `H * H` is recognized as `Id`.
    #[test]
    fn test_recognize_h_squared_is_identity() {
        let h = Gate::H.matrix_2x2();
        let hh = mat_mul_2x2(&h, &h);
        assert_eq!(Gate::recognize_matrix(&hh), Some(Gate::Id));
    }

    /// `T⁴` is recognized as `Z`.
    #[test]
    fn test_recognize_t_fourth_is_z() {
        let t = Gate::T.matrix_2x2();
        let t2 = mat_mul_2x2(&t, &t);
        let t4 = mat_mul_2x2(&t2, &t2);
        assert_eq!(Gate::recognize_matrix(&t4), Some(Gate::Z));
    }

    /// Generic rotations that match no named gate yield `None`.
    #[test]
    fn test_recognize_non_clifford_returns_none() {
        let rx = Gate::Rx(0.7).matrix_2x2();
        assert_eq!(Gate::recognize_matrix(&rx), None);
        let ry = Gate::Ry(1.3).matrix_2x2();
        assert_eq!(Gate::recognize_matrix(&ry), None);
    }

    /// Multiplying H by an arbitrary global phase must not prevent recognition.
    #[test]
    fn test_recognize_global_phase_invariance() {
        let phase = Complex64::from_polar(1.0, 0.42);
        let h = Gate::H.matrix_2x2();
        let phased = [
            [h[0][0] * phase, h[0][1] * phase],
            [h[1][0] * phase, h[1][1] * phase],
        ];
        assert_eq!(Gate::recognize_matrix(&phased), Some(Gate::H));
    }
}