Skip to main content

quantrs2_circuit/
template_matching.rs

1//! Template Matching Pass for quantum circuit gate-count reduction.
2//!
3//! Implements a convergent rewriting pass that applies precomputed equivalence
4//! patterns (templates) to reduce the number of gates in a circuit.  Each
5//! template maps a longer gate sequence to a shorter (or empty) equivalent one.
6//!
7//! Unlike the existing [`crate::optimization::TemplateMatching`] which is
8//! coupled to the abstract cost-model framework, this pass works directly on
9//! `Arc<dyn GateOp + Send + Sync>` slices and supports parametric templates
10//! (e.g. RZ(a)·RZ(b) → RZ(a+b)).
11//!
12//! # Usage
13//!
14//! ```rust
15//! use quantrs2_circuit::template_matching::TemplateMatchingPass;
16//! use quantrs2_core::gate::single::{Hadamard, RotationZ};
17//! use quantrs2_core::qubit::QubitId;
18//! use std::sync::Arc;
19//! use quantrs2_core::gate::GateOp;
20//!
21//! let q = QubitId::new(0);
22//! let gates: Vec<Arc<dyn GateOp + Send + Sync>> = vec![
23//!     Arc::new(Hadamard { target: q }),
24//!     Arc::new(Hadamard { target: q }),
25//! ];
26//!
27//! let pass = TemplateMatchingPass::with_standard_templates();
28//! let result = pass.run(&gates);
29//! assert!(result.is_empty(), "H·H should cancel to identity");
30//! ```
31
32use quantrs2_core::gate::{
33    multi::{CNOT, CZ, SWAP},
34    single::{
35        Hadamard, PauliX, PauliY, PauliZ, Phase, PhaseDagger, RotationX, RotationY, RotationZ,
36        TDagger, T,
37    },
38    GateOp,
39};
40use quantrs2_core::qubit::QubitId;
41use std::sync::Arc;
42
43// ─── Template types ──────────────────────────────────────────────────────────
44
/// One gate in a template's pattern or replacement list.
///
/// Qubit indices are *relative* to the matched window: `0` names the first
/// qubit a match binds, `1` the second distinct qubit, and so on.
#[derive(Debug, Clone, PartialEq)]
pub struct TemplateGate {
    /// Name the concrete gate must report via `GateOp::name()` (e.g. "H", "CNOT", "RZ").
    pub gate_name: String,
    /// Relative qubit operands, listed in the gate's own operand order.
    pub qubits: Vec<usize>,
    /// Numeric parameters such as a rotation angle; empty for non-parametric gates.
    pub params: Vec<f64>,
}

impl TemplateGate {
    /// Parameter-free template gate on the given relative qubits.
    fn new(name: impl Into<String>, qubits: Vec<usize>) -> Self {
        Self::with_params(name, qubits, Vec::new())
    }

    /// Template gate carrying explicit numeric parameters.
    fn with_params(name: impl Into<String>, qubits: Vec<usize>, params: Vec<f64>) -> Self {
        let gate_name = name.into();
        Self {
            gate_name,
            qubits,
            params,
        }
    }
}
76
77// ─── Pattern kind ────────────────────────────────────────────────────────────
78
79/// How a template matches and what it produces.
80#[derive(Clone)]
81enum TemplateKind {
82    /// Fixed gate list → fixed (shorter) gate list.  Qubits are mapped
83    /// positionally using `TemplateGate::qubits`.
84    Fixed {
85        pattern: Vec<TemplateGate>,
86        replacement: Vec<TemplateGate>,
87    },
88    /// Merging two rotation gates of the same type on the same qubit.
89    /// E.g. RZ(a)·RZ(b) → RZ(a+b).
90    RotationMerge { gate_name: &'static str },
91}
92
93// ─── Template ────────────────────────────────────────────────────────────────
94
95/// A named rewrite rule (pattern → replacement).
96#[derive(Clone)]
97pub struct GateTemplate {
98    /// Human-readable name for debugging.
99    pub name: &'static str,
100    kind: TemplateKind,
101}
102
103impl GateTemplate {
104    /// Create a fixed-reduction template.
105    pub fn fixed(
106        name: &'static str,
107        pattern: Vec<TemplateGate>,
108        replacement: Vec<TemplateGate>,
109    ) -> Self {
110        Self {
111            name,
112            kind: TemplateKind::Fixed {
113                pattern,
114                replacement,
115            },
116        }
117    }
118
119    /// Create a rotation-merging template.
120    pub fn rotation_merge(name: &'static str, gate_name: &'static str) -> Self {
121        Self {
122            name,
123            kind: TemplateKind::RotationMerge { gate_name },
124        }
125    }
126}
127
128// ─── Standard template library ───────────────────────────────────────────────
129
130/// Build the standard set of 20+ gate-rewrite templates.
131///
132/// Separated from `TemplateMatchingPass::with_standard_templates` so that the
133/// Vec can be constructed as a single `vec![…]` literal, avoiding the
134/// `clippy::vec_init_then_push` pattern.
135fn standard_templates() -> Vec<GateTemplate> {
136    vec![
137        // ── Single-qubit self-inverse cancellations ──────────────────────────
138        // H·H = I
139        GateTemplate::fixed(
140            "H·H = I",
141            vec![
142                TemplateGate::new("H", vec![0]),
143                TemplateGate::new("H", vec![0]),
144            ],
145            vec![],
146        ),
147        // X·X = I
148        GateTemplate::fixed(
149            "X·X = I",
150            vec![
151                TemplateGate::new("X", vec![0]),
152                TemplateGate::new("X", vec![0]),
153            ],
154            vec![],
155        ),
156        // Y·Y = I
157        GateTemplate::fixed(
158            "Y·Y = I",
159            vec![
160                TemplateGate::new("Y", vec![0]),
161                TemplateGate::new("Y", vec![0]),
162            ],
163            vec![],
164        ),
165        // Z·Z = I
166        GateTemplate::fixed(
167            "Z·Z = I",
168            vec![
169                TemplateGate::new("Z", vec![0]),
170                TemplateGate::new("Z", vec![0]),
171            ],
172            vec![],
173        ),
174        // ── Two-qubit self-inverse cancellations ─────────────────────────────
175        // CNOT(c,t)·CNOT(c,t) = I
176        GateTemplate::fixed(
177            "CNOT·CNOT = I",
178            vec![
179                TemplateGate::new("CNOT", vec![0, 1]),
180                TemplateGate::new("CNOT", vec![0, 1]),
181            ],
182            vec![],
183        ),
184        // CZ(c,t)·CZ(c,t) = I
185        GateTemplate::fixed(
186            "CZ·CZ = I",
187            vec![
188                TemplateGate::new("CZ", vec![0, 1]),
189                TemplateGate::new("CZ", vec![0, 1]),
190            ],
191            vec![],
192        ),
193        // SWAP·SWAP = I
194        GateTemplate::fixed(
195            "SWAP·SWAP = I",
196            vec![
197                TemplateGate::new("SWAP", vec![0, 1]),
198                TemplateGate::new("SWAP", vec![0, 1]),
199            ],
200            vec![],
201        ),
202        // ── Gate composition rules ────────────────────────────────────────────
203        // S·S = Z
204        GateTemplate::fixed(
205            "S·S = Z",
206            vec![
207                TemplateGate::new("S", vec![0]),
208                TemplateGate::new("S", vec![0]),
209            ],
210            vec![TemplateGate::new("Z", vec![0])],
211        ),
212        // T·T = S
213        GateTemplate::fixed(
214            "T·T = S",
215            vec![
216                TemplateGate::new("T", vec![0]),
217                TemplateGate::new("T", vec![0]),
218            ],
219            vec![TemplateGate::new("S", vec![0])],
220        ),
221        // T†·T† = S†
222        GateTemplate::fixed(
223            "T†·T† = S†",
224            vec![
225                TemplateGate::new("T†", vec![0]),
226                TemplateGate::new("T†", vec![0]),
227            ],
228            vec![TemplateGate::new("S†", vec![0])],
229        ),
230        // S†·S† = Z  (since (S†)² = Z† = Z for self-adjoint Z)
231        GateTemplate::fixed(
232            "S†·S† = Z",
233            vec![
234                TemplateGate::new("S†", vec![0]),
235                TemplateGate::new("S†", vec![0]),
236            ],
237            vec![TemplateGate::new("Z", vec![0])],
238        ),
239        // S·S† = I
240        GateTemplate::fixed(
241            "S·S† = I",
242            vec![
243                TemplateGate::new("S", vec![0]),
244                TemplateGate::new("S†", vec![0]),
245            ],
246            vec![],
247        ),
248        // S†·S = I
249        GateTemplate::fixed(
250            "S†·S = I",
251            vec![
252                TemplateGate::new("S†", vec![0]),
253                TemplateGate::new("S", vec![0]),
254            ],
255            vec![],
256        ),
257        // T·T† = I
258        GateTemplate::fixed(
259            "T·T† = I",
260            vec![
261                TemplateGate::new("T", vec![0]),
262                TemplateGate::new("T†", vec![0]),
263            ],
264            vec![],
265        ),
266        // T†·T = I
267        GateTemplate::fixed(
268            "T†·T = I",
269            vec![
270                TemplateGate::new("T†", vec![0]),
271                TemplateGate::new("T", vec![0]),
272            ],
273            vec![],
274        ),
275        // T·T·T·T = Z  (since T^8 = I, T^4 = Z up to global phase)
276        GateTemplate::fixed(
277            "T⁴ = Z",
278            vec![
279                TemplateGate::new("T", vec![0]),
280                TemplateGate::new("T", vec![0]),
281                TemplateGate::new("T", vec![0]),
282                TemplateGate::new("T", vec![0]),
283            ],
284            vec![TemplateGate::new("Z", vec![0])],
285        ),
286        // T⁸ = I
287        GateTemplate::fixed(
288            "T⁸ = I",
289            vec![
290                TemplateGate::new("T", vec![0]),
291                TemplateGate::new("T", vec![0]),
292                TemplateGate::new("T", vec![0]),
293                TemplateGate::new("T", vec![0]),
294                TemplateGate::new("T", vec![0]),
295                TemplateGate::new("T", vec![0]),
296                TemplateGate::new("T", vec![0]),
297                TemplateGate::new("T", vec![0]),
298            ],
299            vec![],
300        ),
301        // ── Conjugation identities (single-qubit) ────────────────────────────
302        // H·X·H = Z
303        GateTemplate::fixed(
304            "H·X·H = Z",
305            vec![
306                TemplateGate::new("H", vec![0]),
307                TemplateGate::new("X", vec![0]),
308                TemplateGate::new("H", vec![0]),
309            ],
310            vec![TemplateGate::new("Z", vec![0])],
311        ),
312        // H·Z·H = X
313        GateTemplate::fixed(
314            "H·Z·H = X",
315            vec![
316                TemplateGate::new("H", vec![0]),
317                TemplateGate::new("Z", vec![0]),
318                TemplateGate::new("H", vec![0]),
319            ],
320            vec![TemplateGate::new("X", vec![0])],
321        ),
322        // H·Y·H = Y  (up to global phase: H·Y·H = −Y, phase ignored → Y)
323        GateTemplate::fixed(
324            "H·Y·H = Y (global phase)",
325            vec![
326                TemplateGate::new("H", vec![0]),
327                TemplateGate::new("Y", vec![0]),
328                TemplateGate::new("H", vec![0]),
329            ],
330            vec![TemplateGate::new("Y", vec![0])],
331        ),
332        // X·Z·X = Z  (up to global phase)
333        GateTemplate::fixed(
334            "X·Z·X = Z (global phase)",
335            vec![
336                TemplateGate::new("X", vec![0]),
337                TemplateGate::new("Z", vec![0]),
338                TemplateGate::new("X", vec![0]),
339            ],
340            vec![TemplateGate::new("Z", vec![0])],
341        ),
342        // Z·X·Z = X  (up to global phase)
343        GateTemplate::fixed(
344            "Z·X·Z = X (global phase)",
345            vec![
346                TemplateGate::new("Z", vec![0]),
347                TemplateGate::new("X", vec![0]),
348                TemplateGate::new("Z", vec![0]),
349            ],
350            vec![TemplateGate::new("X", vec![0])],
351        ),
352        // ── Rotation merging ──────────────────────────────────────────────────
353        // RZ(a)·RZ(b) = RZ(a+b)
354        GateTemplate::rotation_merge("RZ·RZ = RZ(a+b)", "RZ"),
355        // RX(a)·RX(b) = RX(a+b)
356        GateTemplate::rotation_merge("RX·RX = RX(a+b)", "RX"),
357        // RY(a)·RY(b) = RY(a+b)
358        GateTemplate::rotation_merge("RY·RY = RY(a+b)", "RY"),
359    ]
360}
361
362// ─── TemplateMatchingPass ────────────────────────────────────────────────────
363
364/// Gate-reduction pass using precomputed equivalence templates.
365///
366/// Iterates through the gate list, tries to match each template at each
367/// position, replaces with the shorter equivalent, and repeats until
368/// convergence (no further reductions found).
369pub struct TemplateMatchingPass {
370    templates: Vec<GateTemplate>,
371}
372
373impl TemplateMatchingPass {
374    /// Create a pass with the provided set of templates.
375    pub fn new(templates: Vec<GateTemplate>) -> Self {
376        Self { templates }
377    }
378
379    /// Create a pass with the standard library of 20+ templates.
380    ///
381    /// Includes:
382    /// - Self-inverse cancellations: H·H, X·X, Y·Y, Z·Z, CNOT·CNOT, CZ·CZ, SWAP·SWAP
383    /// - Gate composition rules: S·S→Z, T·T→S, T†·T†→S†, Z·Z→I
384    /// - Conjugation identities: H·X·H→Z, H·Z·H→X, H·Y·H→Y (global phase)
385    /// - Rotation merging: RZ(a)·RZ(b)→RZ(a+b), RX(a)·RX(b)→RX(a+b), RY(a)·RY(b)→RY(a+b)
386    pub fn with_standard_templates() -> Self {
387        Self {
388            templates: standard_templates(),
389        }
390    }
391
392    /// Run the pass on a gate list, returning the reduced gate list.
393    ///
394    /// Iterates until no further reductions are found (convergence).
395    pub fn run(
396        &self,
397        gates: &[Arc<dyn GateOp + Send + Sync>],
398    ) -> Vec<Arc<dyn GateOp + Send + Sync>> {
399        let mut current: Vec<Arc<dyn GateOp + Send + Sync>> = gates.to_vec();
400
401        loop {
402            let reduced = self.single_pass(&current);
403            if reduced.len() == current.len() {
404                // No reduction occurred — converged.
405                break;
406            }
407            current = reduced;
408        }
409
410        current
411    }
412
413    /// Apply one pass over the gate list, trying all templates.
414    fn single_pass(
415        &self,
416        gates: &[Arc<dyn GateOp + Send + Sync>],
417    ) -> Vec<Arc<dyn GateOp + Send + Sync>> {
418        let mut result: Vec<Arc<dyn GateOp + Send + Sync>> = Vec::with_capacity(gates.len());
419        let mut i = 0;
420
421        'outer: while i < gates.len() {
422            // Try every template at position i
423            for template in &self.templates {
424                if let Some((replacement, consumed)) = self.try_apply_template(template, gates, i) {
425                    result.extend(replacement);
426                    i += consumed;
427                    continue 'outer;
428                }
429            }
430            // No template matched — keep this gate
431            result.push(gates[i].clone());
432            i += 1;
433        }
434
435        result
436    }
437
438    /// Try to apply `template` starting at index `start` in `gates`.
439    ///
440    /// Returns `Some((replacement_gates, gates_consumed))` on a match, or `None`.
441    fn try_apply_template(
442        &self,
443        template: &GateTemplate,
444        gates: &[Arc<dyn GateOp + Send + Sync>],
445        start: usize,
446    ) -> Option<(Vec<Arc<dyn GateOp + Send + Sync>>, usize)> {
447        match &template.kind {
448            TemplateKind::Fixed {
449                pattern,
450                replacement,
451            } => self.try_match_fixed(pattern, replacement, gates, start),
452            TemplateKind::RotationMerge { gate_name } => {
453                self.try_merge_rotation(gate_name, gates, start)
454            }
455        }
456    }
457
458    /// Match a fixed pattern and produce a fixed replacement.
459    fn try_match_fixed(
460        &self,
461        pattern: &[TemplateGate],
462        replacement: &[TemplateGate],
463        gates: &[Arc<dyn GateOp + Send + Sync>],
464        start: usize,
465    ) -> Option<(Vec<Arc<dyn GateOp + Send + Sync>>, usize)> {
466        if start + pattern.len() > gates.len() {
467            return None;
468        }
469
470        // Build relative-qubit → concrete-QubitId mapping
471        let mut qubit_map: Vec<Option<QubitId>> = Vec::new();
472
473        for (pat_gate, real_gate) in pattern.iter().zip(gates[start..].iter()) {
474            // Check gate name
475            if real_gate.name() != pat_gate.gate_name {
476                return None;
477            }
478
479            let real_qubits = real_gate.qubits();
480
481            // Check arity
482            if real_qubits.len() != pat_gate.qubits.len() {
483                return None;
484            }
485
486            // Extend qubit_map if needed and verify consistency
487            for (rel_idx, &concrete) in pat_gate.qubits.iter().zip(real_qubits.iter()) {
488                // Grow mapping
489                while qubit_map.len() <= *rel_idx {
490                    qubit_map.push(None);
491                }
492
493                match qubit_map[*rel_idx] {
494                    None => qubit_map[*rel_idx] = Some(concrete),
495                    Some(existing) => {
496                        if existing != concrete {
497                            return None; // Inconsistent qubit mapping
498                        }
499                    }
500                }
501            }
502
503            // For two-qubit gates with two distinct relative indices, ensure
504            // the two concrete qubits are actually distinct.
505            if pat_gate.qubits.len() == 2 {
506                let r0 = pat_gate.qubits[0];
507                let r1 = pat_gate.qubits[1];
508                if r0 != r1 {
509                    // Already stored: verify they're distinct concrete qubits
510                    if qubit_map.get(r0).copied().flatten() == qubit_map.get(r1).copied().flatten()
511                    {
512                        return None;
513                    }
514                }
515            }
516        }
517
518        // All gates in pattern must act on the same concrete qubit set —
519        // for single-qubit patterns this is naturally enforced above.
520        // For the two-qubit patterns (CNOT, CZ, SWAP) the mapping already
521        // captures control/target ordering correctly.
522
523        // Build replacement gates
524        let mut result: Vec<Arc<dyn GateOp + Send + Sync>> = Vec::new();
525        for rep_gate in replacement {
526            let concrete_qubits: Vec<QubitId> = rep_gate
527                .qubits
528                .iter()
529                .filter_map(|&rel| qubit_map.get(rel).copied().flatten())
530                .collect();
531
532            if concrete_qubits.len() != rep_gate.qubits.len() {
533                return None; // Qubit not in mapping (shouldn't happen with well-formed templates)
534            }
535
536            let gate_arc = make_gate(&rep_gate.gate_name, &concrete_qubits, &rep_gate.params)?;
537            result.push(gate_arc);
538        }
539
540        Some((result, pattern.len()))
541    }
542
543    /// Try to merge two consecutive rotation gates of the same type on the same qubit.
544    ///
545    /// Produces a single merged rotation gate (or identity if the total angle ≈ 0).
546    fn try_merge_rotation(
547        &self,
548        gate_name: &'static str,
549        gates: &[Arc<dyn GateOp + Send + Sync>],
550        start: usize,
551    ) -> Option<(Vec<Arc<dyn GateOp + Send + Sync>>, usize)> {
552        if start + 1 >= gates.len() {
553            return None;
554        }
555
556        let g0 = &gates[start];
557        let g1 = &gates[start + 1];
558
559        // Both must have the same gate name
560        if g0.name() != gate_name || g1.name() != gate_name {
561            return None;
562        }
563
564        // Both must act on the same single qubit
565        let q0 = g0.qubits();
566        let q1 = g1.qubits();
567        if q0.len() != 1 || q1.len() != 1 || q0[0] != q1[0] {
568            return None;
569        }
570
571        // Extract rotation angles via downcast
572        let theta0 = extract_rotation_angle(g0.as_ref(), gate_name)?;
573        let theta1 = extract_rotation_angle(g1.as_ref(), gate_name)?;
574        let combined = theta0 + theta1;
575
576        let qubit = q0[0];
577
578        // If the combined angle is effectively zero (mod 2π), produce identity
579        let angle_mod = combined.rem_euclid(2.0 * std::f64::consts::PI);
580        if angle_mod < 1e-9 || (2.0 * std::f64::consts::PI - angle_mod) < 1e-9 {
581            return Some((vec![], 2));
582        }
583
584        // Otherwise produce merged rotation
585        let merged = make_gate(gate_name, &[qubit], &[combined])?;
586        Some((vec![merged], 2))
587    }
588}
589
590// ─── Helpers ─────────────────────────────────────────────────────────────────
591
592/// Extract the rotation angle from a gate known to be one of "RX", "RY", "RZ".
593fn extract_rotation_angle(gate: &dyn GateOp, gate_name: &str) -> Option<f64> {
594    match gate_name {
595        "RX" => gate.as_any().downcast_ref::<RotationX>().map(|g| g.theta),
596        "RY" => gate.as_any().downcast_ref::<RotationY>().map(|g| g.theta),
597        "RZ" => gate.as_any().downcast_ref::<RotationZ>().map(|g| g.theta),
598        _ => None,
599    }
600}
601
602/// Construct an `Arc<dyn GateOp + Send + Sync>` from a name, concrete qubits, and params.
603///
604/// Returns `None` if the gate name is unrecognised or arity is wrong.
605fn make_gate(
606    name: &str,
607    qubits: &[QubitId],
608    params: &[f64],
609) -> Option<Arc<dyn GateOp + Send + Sync>> {
610    match (name, qubits.len()) {
611        ("H", 1) => Some(Arc::new(Hadamard { target: qubits[0] })),
612        ("X", 1) => Some(Arc::new(PauliX { target: qubits[0] })),
613        ("Y", 1) => Some(Arc::new(PauliY { target: qubits[0] })),
614        ("Z", 1) => Some(Arc::new(PauliZ { target: qubits[0] })),
615        ("S", 1) => Some(Arc::new(Phase { target: qubits[0] })),
616        ("S†", 1) => Some(Arc::new(PhaseDagger { target: qubits[0] })),
617        ("T", 1) => Some(Arc::new(T { target: qubits[0] })),
618        ("T†", 1) => Some(Arc::new(TDagger { target: qubits[0] })),
619        ("CNOT", 2) => Some(Arc::new(CNOT {
620            control: qubits[0],
621            target: qubits[1],
622        })),
623        ("CZ", 2) => Some(Arc::new(CZ {
624            control: qubits[0],
625            target: qubits[1],
626        })),
627        ("SWAP", 2) => Some(Arc::new(SWAP {
628            qubit1: qubits[0],
629            qubit2: qubits[1],
630        })),
631        ("RX", 1) if !params.is_empty() => Some(Arc::new(RotationX {
632            target: qubits[0],
633            theta: params[0],
634        })),
635        ("RY", 1) if !params.is_empty() => Some(Arc::new(RotationY {
636            target: qubits[0],
637            theta: params[0],
638        })),
639        ("RZ", 1) if !params.is_empty() => Some(Arc::new(RotationZ {
640            target: qubits[0],
641            theta: params[0],
642        })),
643        _ => None,
644    }
645}
646
647// ─── TemplateGate constructors (public API) ──────────────────────────────────
648
649impl TemplateGate {
650    /// Single-qubit gate without parameters.
651    pub fn single(gate_name: impl Into<String>) -> Self {
652        Self::new(gate_name, vec![0])
653    }
654
655    /// Two-qubit gate (control=0, target=1 by default).
656    pub fn two_qubit(gate_name: impl Into<String>) -> Self {
657        Self::new(gate_name, vec![0, 1])
658    }
659
660    /// Rotation gate with an explicit angle.
661    pub fn rotation(gate_name: impl Into<String>, angle: f64) -> Self {
662        Self::with_params(gate_name, vec![0], vec![angle])
663    }
664}
665
666// ─── Tests ────────────────────────────────────────────────────────────────────
667
668#[cfg(test)]
669mod tests {
670    use super::*;
671    use quantrs2_core::gate::{
672        multi::{CNOT, CZ, SWAP},
673        single::{Hadamard, PauliX, PauliY, PauliZ, Phase, RotationX, RotationY, RotationZ, T},
674        GateOp,
675    };
676    use quantrs2_core::qubit::QubitId;
677    use std::sync::Arc;
678
679    fn q(id: u32) -> QubitId {
680        QubitId::new(id)
681    }
682
683    fn arc<G: GateOp + Send + Sync + 'static>(g: G) -> Arc<dyn GateOp + Send + Sync> {
684        Arc::new(g)
685    }
686
687    fn pass() -> TemplateMatchingPass {
688        TemplateMatchingPass::with_standard_templates()
689    }
690
691    // ── Cancellation tests ───────────────────────────────────────────────────
692
693    #[test]
694    fn test_hh_cancellation() {
695        let q0 = q(0);
696        let gates = vec![arc(Hadamard { target: q0 }), arc(Hadamard { target: q0 })];
697        let result = pass().run(&gates);
698        assert!(
699            result.is_empty(),
700            "H·H should cancel to identity (0 gates), got {}",
701            result.len()
702        );
703    }
704
705    #[test]
706    fn test_xx_cancellation() {
707        let q0 = q(0);
708        let gates = vec![arc(PauliX { target: q0 }), arc(PauliX { target: q0 })];
709        let result = pass().run(&gates);
710        assert!(result.is_empty(), "X·X should cancel");
711    }
712
713    #[test]
714    fn test_yy_cancellation() {
715        let q0 = q(0);
716        let gates = vec![arc(PauliY { target: q0 }), arc(PauliY { target: q0 })];
717        let result = pass().run(&gates);
718        assert!(result.is_empty(), "Y·Y should cancel");
719    }
720
721    #[test]
722    fn test_zz_cancellation() {
723        let q0 = q(0);
724        let gates = vec![arc(PauliZ { target: q0 }), arc(PauliZ { target: q0 })];
725        let result = pass().run(&gates);
726        assert!(result.is_empty(), "Z·Z should cancel");
727    }
728
729    #[test]
730    fn test_cnot_cancellation() {
731        let (c, t) = (q(0), q(1));
732        let gates = vec![
733            arc(CNOT {
734                control: c,
735                target: t,
736            }),
737            arc(CNOT {
738                control: c,
739                target: t,
740            }),
741        ];
742        let result = pass().run(&gates);
743        assert!(
744            result.is_empty(),
745            "CNOT·CNOT should cancel to identity, got {} gates",
746            result.len()
747        );
748    }
749
750    #[test]
751    fn test_cz_cancellation() {
752        let (c, t) = (q(0), q(1));
753        let gates = vec![
754            arc(CZ {
755                control: c,
756                target: t,
757            }),
758            arc(CZ {
759                control: c,
760                target: t,
761            }),
762        ];
763        let result = pass().run(&gates);
764        assert!(result.is_empty(), "CZ·CZ should cancel");
765    }
766
767    #[test]
768    fn test_swap_cancellation() {
769        let (a, b) = (q(0), q(1));
770        let gates = vec![
771            arc(SWAP {
772                qubit1: a,
773                qubit2: b,
774            }),
775            arc(SWAP {
776                qubit1: a,
777                qubit2: b,
778            }),
779        ];
780        let result = pass().run(&gates);
781        assert!(result.is_empty(), "SWAP·SWAP should cancel");
782    }
783
784    // ── Composition tests ────────────────────────────────────────────────────
785
786    #[test]
787    fn test_ss_to_z() {
788        let q0 = q(0);
789        let gates = vec![arc(Phase { target: q0 }), arc(Phase { target: q0 })];
790        let result = pass().run(&gates);
791        assert_eq!(result.len(), 1, "S·S should produce one gate");
792        assert_eq!(result[0].name(), "Z", "S·S should produce Z");
793    }
794
795    #[test]
796    fn test_tt_to_s() {
797        let q0 = q(0);
798        let gates = vec![arc(T { target: q0 }), arc(T { target: q0 })];
799        let result = pass().run(&gates);
800        assert_eq!(result.len(), 1, "T·T should produce one gate");
801        assert_eq!(result[0].name(), "S", "T·T should produce S");
802    }
803
804    // ── Rotation merging tests ───────────────────────────────────────────────
805
806    #[test]
807    fn test_rz_merging() {
808        let q0 = q(0);
809        let gates = vec![
810            arc(RotationZ {
811                target: q0,
812                theta: 0.3,
813            }),
814            arc(RotationZ {
815                target: q0,
816                theta: 0.7,
817            }),
818        ];
819        let result = pass().run(&gates);
820        assert_eq!(result.len(), 1, "RZ(0.3)·RZ(0.7) should merge to one gate");
821        assert_eq!(result[0].name(), "RZ");
822        let merged = result[0]
823            .as_any()
824            .downcast_ref::<RotationZ>()
825            .expect("should downcast to RotationZ");
826        assert!(
827            (merged.theta - 1.0).abs() < 1e-9,
828            "merged angle should be 1.0, got {}",
829            merged.theta
830        );
831    }
832
833    #[test]
834    fn test_rx_merging() {
835        let q0 = q(0);
836        let gates = vec![
837            arc(RotationX {
838                target: q0,
839                theta: 0.5,
840            }),
841            arc(RotationX {
842                target: q0,
843                theta: 0.5,
844            }),
845        ];
846        let result = pass().run(&gates);
847        assert_eq!(result.len(), 1, "RX(0.5)·RX(0.5) should merge");
848        let merged = result[0]
849            .as_any()
850            .downcast_ref::<RotationX>()
851            .expect("should downcast to RotationX");
852        assert!(
853            (merged.theta - 1.0).abs() < 1e-9,
854            "merged angle should be 1.0"
855        );
856    }
857
858    #[test]
859    fn test_ry_merging() {
860        let q0 = q(0);
861        let gates = vec![
862            arc(RotationY {
863                target: q0,
864                theta: 0.2,
865            }),
866            arc(RotationY {
867                target: q0,
868                theta: 0.8,
869            }),
870        ];
871        let result = pass().run(&gates);
872        assert_eq!(result.len(), 1, "RY(0.2)·RY(0.8) should merge");
873    }
874
875    // ── No false reduction tests ─────────────────────────────────────────────
876
877    #[test]
878    fn test_no_false_reduction_different_qubits() {
879        // H on q0 and H on q1 — must NOT cancel (different qubits)
880        let gates = vec![
881            arc(Hadamard { target: q(0) }),
882            arc(Hadamard { target: q(1) }),
883        ];
884        let result = pass().run(&gates);
885        assert_eq!(
886            result.len(),
887            2,
888            "H q[0]; H q[1]; must stay (different qubits)"
889        );
890    }
891
892    #[test]
893    fn test_no_false_reduction_different_gates() {
894        // H then X — must stay
895        let q0 = q(0);
896        let gates = vec![arc(Hadamard { target: q0 }), arc(PauliX { target: q0 })];
897        let result = pass().run(&gates);
898        assert_eq!(result.len(), 2, "H·X must not reduce");
899    }
900
901    #[test]
902    fn test_cnot_different_controls_no_cancel() {
903        // CNOT(0,2) and CNOT(1,2) — different control qubits, must NOT cancel
904        let gates = vec![
905            arc(CNOT {
906                control: q(0),
907                target: q(2),
908            }),
909            arc(CNOT {
910                control: q(1),
911                target: q(2),
912            }),
913        ];
914        let result = pass().run(&gates);
915        assert_eq!(
916            result.len(),
917            2,
918            "CNOT with different controls must not cancel"
919        );
920    }
921
922    #[test]
923    fn test_cnot_different_targets_no_cancel() {
924        // CNOT(0,1) and CNOT(0,2) — different target qubits, must NOT cancel
925        let gates = vec![
926            arc(CNOT {
927                control: q(0),
928                target: q(1),
929            }),
930            arc(CNOT {
931                control: q(0),
932                target: q(2),
933            }),
934        ];
935        let result = pass().run(&gates);
936        assert_eq!(
937            result.len(),
938            2,
939            "CNOT with different targets must not cancel"
940        );
941    }
942
943    // ── Conjugation identity tests ───────────────────────────────────────────
944
945    #[test]
946    fn test_hxh_to_z() {
947        let q0 = q(0);
948        let gates = vec![
949            arc(Hadamard { target: q0 }),
950            arc(PauliX { target: q0 }),
951            arc(Hadamard { target: q0 }),
952        ];
953        let result = pass().run(&gates);
954        assert_eq!(result.len(), 1, "H·X·H should reduce to one gate");
955        assert_eq!(result[0].name(), "Z", "H·X·H = Z");
956    }
957
958    #[test]
959    fn test_hzh_to_x() {
960        let q0 = q(0);
961        let gates = vec![
962            arc(Hadamard { target: q0 }),
963            arc(PauliZ { target: q0 }),
964            arc(Hadamard { target: q0 }),
965        ];
966        let result = pass().run(&gates);
967        assert_eq!(result.len(), 1, "H·Z·H should reduce to one gate");
968        assert_eq!(result[0].name(), "X", "H·Z·H = X");
969    }
970
971    // ── Convergence test ─────────────────────────────────────────────────────
972
973    #[test]
974    fn test_multi_pass_convergence() {
975        // H·H·H·H = I (4 H gates cancel in two passes)
976        let q0 = q(0);
977        let gates = vec![
978            arc(Hadamard { target: q0 }),
979            arc(Hadamard { target: q0 }),
980            arc(Hadamard { target: q0 }),
981            arc(Hadamard { target: q0 }),
982        ];
983        let result = pass().run(&gates);
984        assert!(result.is_empty(), "H⁴ should converge to identity");
985    }
986
987    // ── RZ identity (zero angle) ─────────────────────────────────────────────
988
989    #[test]
990    fn test_rz_cancels_to_identity_when_total_is_2pi() {
991        let q0 = q(0);
992        let two_pi = 2.0 * std::f64::consts::PI;
993        let gates = vec![
994            arc(RotationZ {
995                target: q0,
996                theta: two_pi * 0.6,
997            }),
998            arc(RotationZ {
999                target: q0,
1000                theta: two_pi * 0.4,
1001            }),
1002        ];
1003        let result = pass().run(&gates);
1004        // Combined angle = 2π, which should cancel to identity
1005        assert!(
1006            result.is_empty(),
1007            "RZ(0.6·2π)·RZ(0.4·2π) should cancel to identity, got {} gates",
1008            result.len()
1009        );
1010    }
1011
1012    // ── Rz merging across multiple pairs ─────────────────────────────────────
1013
1014    #[test]
1015    fn test_rz_merging_three_gates() {
1016        // RZ(0.3)·RZ(0.3)·RZ(0.3) — first two merge, then we get RZ(0.6)·RZ(0.3) = RZ(0.9)
1017        let q0 = q(0);
1018        let gates = vec![
1019            arc(RotationZ {
1020                target: q0,
1021                theta: 0.3,
1022            }),
1023            arc(RotationZ {
1024                target: q0,
1025                theta: 0.3,
1026            }),
1027            arc(RotationZ {
1028                target: q0,
1029                theta: 0.3,
1030            }),
1031        ];
1032        let result = pass().run(&gates);
1033        assert_eq!(result.len(), 1, "RZ·RZ·RZ should merge to one gate");
1034        let merged = result[0]
1035            .as_any()
1036            .downcast_ref::<RotationZ>()
1037            .expect("should be RZ");
1038        assert!(
1039            (merged.theta - 0.9).abs() < 1e-9,
1040            "merged angle should be 0.9, got {}",
1041            merged.theta
1042        );
1043    }
1044}