use quantrs2_core::gate::{
multi::{CNOT, CZ, SWAP},
single::{
Hadamard, PauliX, PauliY, PauliZ, Phase, PhaseDagger, RotationX, RotationY, RotationZ,
TDagger, T,
},
GateOp,
};
use quantrs2_core::qubit::QubitId;
use std::sync::Arc;
/// One gate inside a template pattern or replacement.
///
/// `qubits` holds *relative* qubit indices (0, 1, ...) that are bound to
/// concrete circuit qubits while a template is being matched; `params` holds
/// optional gate parameters such as a rotation angle.
#[derive(Debug, Clone, PartialEq)]
pub struct TemplateGate {
    /// Gate name as reported by `GateOp::name`, e.g. `"H"` or `"CNOT"`.
    pub gate_name: String,
    /// Relative qubit indices, in the gate's operand order.
    pub qubits: Vec<usize>,
    /// Gate parameters (empty for parameter-free gates).
    pub params: Vec<f64>,
}

impl TemplateGate {
    /// Parameter-free gate named `name` acting on the relative `qubits`.
    fn new(name: impl Into<String>, qubits: Vec<usize>) -> Self {
        Self::with_params(name, qubits, Vec::new())
    }

    /// Gate named `name` acting on the relative `qubits` with explicit `params`.
    fn with_params(name: impl Into<String>, qubits: Vec<usize>, params: Vec<f64>) -> Self {
        let gate_name: String = name.into();
        Self {
            gate_name,
            qubits,
            params,
        }
    }
}
/// The kind of rewrite rule stored in a `GateTemplate`.
#[derive(Clone)]
enum TemplateKind {
    /// Replace an exact run of consecutive gates (`pattern`) with `replacement`.
    /// Qubit indices in both lists are relative and are bound to concrete
    /// qubits while the pattern is matched.
    Fixed {
        pattern: Vec<TemplateGate>,
        replacement: Vec<TemplateGate>,
    },
    /// Fuse two adjacent rotations named `gate_name` on the same qubit into a
    /// single rotation by the summed angle (dropped entirely when the sum is a
    /// multiple of 2π).
    RotationMerge { gate_name: &'static str },
}
/// A named rewrite rule applied by `TemplateMatchingPass`.
#[derive(Clone)]
pub struct GateTemplate {
    /// Human-readable label for the rule, e.g. `"H·H = I"`.
    pub name: &'static str,
    /// What the rule matches and what it produces.
    kind: TemplateKind,
}

impl GateTemplate {
    /// Builds a rule that replaces the exact gate sequence `pattern` with
    /// `replacement`.
    ///
    /// Qubit indices in both sequences are relative: during matching each
    /// index is bound to the concrete qubit it lines up with, and the
    /// replacement gates are emitted on those bound qubits.
    pub fn fixed(
        name: &'static str,
        pattern: Vec<TemplateGate>,
        replacement: Vec<TemplateGate>,
    ) -> Self {
        let kind = TemplateKind::Fixed {
            pattern,
            replacement,
        };
        Self { name, kind }
    }

    /// Builds a rule that fuses two adjacent `gate_name` rotations acting on
    /// the same qubit into one rotation by the summed angle.
    pub fn rotation_merge(name: &'static str, gate_name: &'static str) -> Self {
        let kind = TemplateKind::RotationMerge { gate_name };
        Self { name, kind }
    }
}
fn standard_templates() -> Vec<GateTemplate> {
vec![
GateTemplate::fixed(
"H·H = I",
vec![
TemplateGate::new("H", vec![0]),
TemplateGate::new("H", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"X·X = I",
vec![
TemplateGate::new("X", vec![0]),
TemplateGate::new("X", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"Y·Y = I",
vec![
TemplateGate::new("Y", vec![0]),
TemplateGate::new("Y", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"Z·Z = I",
vec![
TemplateGate::new("Z", vec![0]),
TemplateGate::new("Z", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"CNOT·CNOT = I",
vec![
TemplateGate::new("CNOT", vec![0, 1]),
TemplateGate::new("CNOT", vec![0, 1]),
],
vec![],
),
GateTemplate::fixed(
"CZ·CZ = I",
vec![
TemplateGate::new("CZ", vec![0, 1]),
TemplateGate::new("CZ", vec![0, 1]),
],
vec![],
),
GateTemplate::fixed(
"SWAP·SWAP = I",
vec![
TemplateGate::new("SWAP", vec![0, 1]),
TemplateGate::new("SWAP", vec![0, 1]),
],
vec![],
),
GateTemplate::fixed(
"S·S = Z",
vec![
TemplateGate::new("S", vec![0]),
TemplateGate::new("S", vec![0]),
],
vec![TemplateGate::new("Z", vec![0])],
),
GateTemplate::fixed(
"T·T = S",
vec![
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
],
vec![TemplateGate::new("S", vec![0])],
),
GateTemplate::fixed(
"T†·T† = S†",
vec![
TemplateGate::new("T†", vec![0]),
TemplateGate::new("T†", vec![0]),
],
vec![TemplateGate::new("S†", vec![0])],
),
GateTemplate::fixed(
"S†·S† = Z",
vec![
TemplateGate::new("S†", vec![0]),
TemplateGate::new("S†", vec![0]),
],
vec![TemplateGate::new("Z", vec![0])],
),
GateTemplate::fixed(
"S·S† = I",
vec![
TemplateGate::new("S", vec![0]),
TemplateGate::new("S†", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"S†·S = I",
vec![
TemplateGate::new("S†", vec![0]),
TemplateGate::new("S", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"T·T† = I",
vec![
TemplateGate::new("T", vec![0]),
TemplateGate::new("T†", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"T†·T = I",
vec![
TemplateGate::new("T†", vec![0]),
TemplateGate::new("T", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"T⁴ = Z",
vec![
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
],
vec![TemplateGate::new("Z", vec![0])],
),
GateTemplate::fixed(
"T⁸ = I",
vec![
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
TemplateGate::new("T", vec![0]),
],
vec![],
),
GateTemplate::fixed(
"H·X·H = Z",
vec![
TemplateGate::new("H", vec![0]),
TemplateGate::new("X", vec![0]),
TemplateGate::new("H", vec![0]),
],
vec![TemplateGate::new("Z", vec![0])],
),
GateTemplate::fixed(
"H·Z·H = X",
vec![
TemplateGate::new("H", vec![0]),
TemplateGate::new("Z", vec![0]),
TemplateGate::new("H", vec![0]),
],
vec![TemplateGate::new("X", vec![0])],
),
GateTemplate::fixed(
"H·Y·H = Y (global phase)",
vec![
TemplateGate::new("H", vec![0]),
TemplateGate::new("Y", vec![0]),
TemplateGate::new("H", vec![0]),
],
vec![TemplateGate::new("Y", vec![0])],
),
GateTemplate::fixed(
"X·Z·X = Z (global phase)",
vec![
TemplateGate::new("X", vec![0]),
TemplateGate::new("Z", vec![0]),
TemplateGate::new("X", vec![0]),
],
vec![TemplateGate::new("Z", vec![0])],
),
GateTemplate::fixed(
"Z·X·Z = X (global phase)",
vec![
TemplateGate::new("Z", vec![0]),
TemplateGate::new("X", vec![0]),
TemplateGate::new("Z", vec![0]),
],
vec![TemplateGate::new("X", vec![0])],
),
GateTemplate::rotation_merge("RZ·RZ = RZ(a+b)", "RZ"),
GateTemplate::rotation_merge("RX·RX = RX(a+b)", "RX"),
GateTemplate::rotation_merge("RY·RY = RY(a+b)", "RY"),
]
}
/// Peephole optimiser that simplifies a flat gate list by repeatedly applying
/// a list of `GateTemplate` rewrite rules.
pub struct TemplateMatchingPass {
    /// Rules tried in order at every position; the first one that matches wins.
    templates: Vec<GateTemplate>,
}
impl TemplateMatchingPass {
    /// Absolute tolerance (radians) used whenever rotation angles are compared.
    const ANGLE_TOLERANCE: f64 = 1e-9;

    /// Creates a pass that applies `templates` in the given priority order.
    ///
    /// At each position the first matching template wins, so longer or more
    /// specific patterns should be listed before shorter ones sharing a prefix.
    pub fn new(templates: Vec<GateTemplate>) -> Self {
        Self { templates }
    }

    /// Creates a pass pre-loaded with the built-in identity rules.
    pub fn with_standard_templates() -> Self {
        Self::new(standard_templates())
    }

    /// Repeatedly applies single left-to-right rewrite passes until the
    /// circuit stops shrinking, and returns the simplified gate list.
    ///
    /// A pass is kept only if it strictly reduces the gate count. Because the
    /// count is a natural number, this guarantees termination. It holds even
    /// for custom templates whose replacement is as long as, or longer than,
    /// their pattern (such rules could otherwise rewrite forever). A pass
    /// whose net effect does not shrink the circuit is discarded.
    pub fn run(
        &self,
        gates: &[Arc<dyn GateOp + Send + Sync>],
    ) -> Vec<Arc<dyn GateOp + Send + Sync>> {
        let mut current: Vec<Arc<dyn GateOp + Send + Sync>> = gates.to_vec();
        loop {
            let reduced = self.single_pass(&current);
            if reduced.len() >= current.len() {
                return current;
            }
            current = reduced;
        }
    }

    /// Performs one greedy left-to-right scan.
    ///
    /// At each index the first applicable template replaces the gates it
    /// consumed; otherwise the gate is copied through unchanged.
    fn single_pass(
        &self,
        gates: &[Arc<dyn GateOp + Send + Sync>],
    ) -> Vec<Arc<dyn GateOp + Send + Sync>> {
        let mut result: Vec<Arc<dyn GateOp + Send + Sync>> = Vec::with_capacity(gates.len());
        let mut i = 0;
        while i < gates.len() {
            let applied = self
                .templates
                .iter()
                .find_map(|template| self.try_apply_template(template, gates, i));
            match applied {
                // Every successful match consumes at least one gate (empty
                // patterns are rejected), so `i` always advances.
                Some((replacement, consumed)) => {
                    result.extend(replacement);
                    i += consumed;
                }
                None => {
                    result.push(Arc::clone(&gates[i]));
                    i += 1;
                }
            }
        }
        result
    }

    /// Tries `template` at `gates[start..]`.
    ///
    /// Returns the replacement gates and the number of input gates consumed,
    /// or `None` if the template does not apply at `start`.
    fn try_apply_template(
        &self,
        template: &GateTemplate,
        gates: &[Arc<dyn GateOp + Send + Sync>],
        start: usize,
    ) -> Option<(Vec<Arc<dyn GateOp + Send + Sync>>, usize)> {
        match &template.kind {
            TemplateKind::Fixed {
                pattern,
                replacement,
            } => self.try_match_fixed(pattern, replacement, gates, start),
            TemplateKind::RotationMerge { gate_name } => {
                self.try_merge_rotation(gate_name, gates, start)
            }
        }
    }

    /// Matches the fixed `pattern` against the gates starting at `start`.
    ///
    /// Matching rules:
    /// - Names must be equal and qubit counts must agree.
    /// - Each relative pattern qubit is bound to exactly one concrete qubit,
    ///   and distinct relative qubits must bind to distinct concrete qubits.
    /// - A pattern gate with parameters (only `RX`/`RY`/`RZ` are supported)
    ///   matches only a gate whose angle equals `params[0]` within
    ///   `ANGLE_TOLERANCE`.
    ///
    /// On success, `replacement` is instantiated on the bound qubits. Returns
    /// `None` if there is no match, if `pattern` is empty, or if a replacement
    /// gate refers to an unbound qubit or cannot be constructed.
    fn try_match_fixed(
        &self,
        pattern: &[TemplateGate],
        replacement: &[TemplateGate],
        gates: &[Arc<dyn GateOp + Send + Sync>],
        start: usize,
    ) -> Option<(Vec<Arc<dyn GateOp + Send + Sync>>, usize)> {
        // An empty pattern would "match" while consuming nothing, making
        // `single_pass` spin forever at the same index.
        if pattern.is_empty() || start + pattern.len() > gates.len() {
            return None;
        }
        let window = &gates[start..start + pattern.len()];

        // qubit_map[r] is the concrete qubit bound to relative qubit r, if any.
        let mut qubit_map: Vec<Option<QubitId>> = Vec::new();
        for (pat_gate, real_gate) in pattern.iter().zip(window) {
            if real_gate.name() != pat_gate.gate_name {
                return None;
            }
            if let Some(&expected) = pat_gate.params.first() {
                let actual = extract_rotation_angle(real_gate.as_ref(), &pat_gate.gate_name)?;
                if (actual - expected).abs() > Self::ANGLE_TOLERANCE {
                    return None;
                }
            }
            let real_qubits = real_gate.qubits();
            if real_qubits.len() != pat_gate.qubits.len() {
                return None;
            }
            for (&rel, &concrete) in pat_gate.qubits.iter().zip(real_qubits.iter()) {
                if qubit_map.len() <= rel {
                    qubit_map.resize(rel + 1, None);
                }
                match qubit_map[rel] {
                    Some(bound) if bound != concrete => return None,
                    Some(_) => {}
                    None => {
                        // Injectivity: a concrete qubit may back only one
                        // relative qubit.
                        if qubit_map.contains(&Some(concrete)) {
                            return None;
                        }
                        qubit_map[rel] = Some(concrete);
                    }
                }
            }
        }

        let mut result: Vec<Arc<dyn GateOp + Send + Sync>> = Vec::with_capacity(replacement.len());
        for rep_gate in replacement {
            // Fails if the replacement mentions a relative qubit the pattern never bound.
            let concrete_qubits = rep_gate
                .qubits
                .iter()
                .map(|&rel| qubit_map.get(rel).copied().flatten())
                .collect::<Option<Vec<QubitId>>>()?;
            result.push(make_gate(
                &rep_gate.gate_name,
                &concrete_qubits,
                &rep_gate.params,
            )?);
        }
        Some((result, pattern.len()))
    }

    /// Fuses `gates[start]` and `gates[start + 1]` when both are `gate_name`
    /// rotations on the same single qubit.
    ///
    /// The result is one rotation by the summed angle. It is no gate at all
    /// when the sum is within `ANGLE_TOLERANCE` of a multiple of 2π, since
    /// that rotation is the identity up to a global phase. Consumes 2 gates
    /// on success.
    fn try_merge_rotation(
        &self,
        gate_name: &'static str,
        gates: &[Arc<dyn GateOp + Send + Sync>],
        start: usize,
    ) -> Option<(Vec<Arc<dyn GateOp + Send + Sync>>, usize)> {
        if start + 1 >= gates.len() {
            return None;
        }
        let (g0, g1) = (&gates[start], &gates[start + 1]);
        if g0.name() != gate_name || g1.name() != gate_name {
            return None;
        }
        let (q0, q1) = (g0.qubits(), g1.qubits());
        if q0.len() != 1 || q1.len() != 1 || q0[0] != q1[0] {
            return None;
        }
        let combined = extract_rotation_angle(g0.as_ref(), gate_name)?
            + extract_rotation_angle(g1.as_ref(), gate_name)?;
        let qubit = q0[0];
        let full_turn = std::f64::consts::TAU;
        let angle_mod = combined.rem_euclid(full_turn);
        if angle_mod < Self::ANGLE_TOLERANCE || full_turn - angle_mod < Self::ANGLE_TOLERANCE {
            return Some((vec![], 2));
        }
        let merged = make_gate(gate_name, &[qubit], &[combined])?;
        Some((vec![merged], 2))
    }
}
/// Reads the angle `theta` from `gate`, provided `gate_name` is `"RX"`, `"RY"`
/// or `"RZ"` and `gate` really is the corresponding concrete rotation type.
///
/// Returns `None` for any other name or when the downcast fails.
fn extract_rotation_angle(gate: &dyn GateOp, gate_name: &str) -> Option<f64> {
    let any = gate.as_any();
    let theta = match gate_name {
        "RX" => any.downcast_ref::<RotationX>()?.theta,
        "RY" => any.downcast_ref::<RotationY>()?.theta,
        "RZ" => any.downcast_ref::<RotationZ>()?.theta,
        _ => return None,
    };
    Some(theta)
}
/// Instantiates the concrete gate named `name` on `qubits`.
///
/// Two-qubit gates take their operands in order: (control, target) for
/// `CNOT`/`CZ`, and (qubit1, qubit2) for `SWAP`. Rotations (`RX`/`RY`/`RZ`)
/// use `params[0]` as the angle; every other gate ignores `params`.
///
/// Returns `None` for an unknown name, a qubit count that does not fit the
/// gate, or a rotation requested without a parameter.
fn make_gate(
    name: &str,
    qubits: &[QubitId],
    params: &[f64],
) -> Option<Arc<dyn GateOp + Send + Sync>> {
    let gate: Arc<dyn GateOp + Send + Sync> = match (name, qubits, params.first().copied()) {
        ("H", &[target], _) => Arc::new(Hadamard { target }),
        ("X", &[target], _) => Arc::new(PauliX { target }),
        ("Y", &[target], _) => Arc::new(PauliY { target }),
        ("Z", &[target], _) => Arc::new(PauliZ { target }),
        ("S", &[target], _) => Arc::new(Phase { target }),
        ("S†", &[target], _) => Arc::new(PhaseDagger { target }),
        ("T", &[target], _) => Arc::new(T { target }),
        ("T†", &[target], _) => Arc::new(TDagger { target }),
        ("CNOT", &[control, target], _) => Arc::new(CNOT { control, target }),
        ("CZ", &[control, target], _) => Arc::new(CZ { control, target }),
        ("SWAP", &[qubit1, qubit2], _) => Arc::new(SWAP { qubit1, qubit2 }),
        ("RX", &[target], Some(theta)) => Arc::new(RotationX { target, theta }),
        ("RY", &[target], Some(theta)) => Arc::new(RotationY { target, theta }),
        ("RZ", &[target], Some(theta)) => Arc::new(RotationZ { target, theta }),
        _ => return None,
    };
    Some(gate)
}
/// Public shorthand constructors for building custom template patterns.
impl TemplateGate {
    /// Parameter-free single-qubit gate on relative qubit 0.
    pub fn single(gate_name: impl Into<String>) -> Self {
        Self::with_params(gate_name, vec![0], Vec::new())
    }

    /// Parameter-free two-qubit gate on relative qubits 0 and 1, in that order.
    pub fn two_qubit(gate_name: impl Into<String>) -> Self {
        Self::with_params(gate_name, vec![0, 1], Vec::new())
    }

    /// Single-qubit gate on relative qubit 0 carrying `angle` as its only parameter.
    pub fn rotation(gate_name: impl Into<String>, angle: f64) -> Self {
        Self {
            gate_name: gate_name.into(),
            qubits: vec![0],
            params: vec![angle],
        }
    }
}
/// Unit tests for `TemplateMatchingPass` using the standard template set.
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::{
        multi::{CNOT, CZ, SWAP},
        single::{Hadamard, PauliX, PauliY, PauliZ, Phase, RotationX, RotationY, RotationZ, T},
        GateOp,
    };
    use quantrs2_core::qubit::QubitId;
    use std::sync::Arc;

    /// Shorthand for a qubit id.
    fn q(id: u32) -> QubitId {
        QubitId::new(id)
    }

    /// Boxes a concrete gate into the trait-object form the pass consumes.
    fn arc<G: GateOp + Send + Sync + 'static>(g: G) -> Arc<dyn GateOp + Send + Sync> {
        Arc::new(g)
    }

    /// A pass loaded with the built-in templates.
    fn pass() -> TemplateMatchingPass {
        TemplateMatchingPass::with_standard_templates()
    }

    // --- Self-inverse cancellations ---

    #[test]
    fn test_hh_cancellation() {
        let q0 = q(0);
        let gates = vec![arc(Hadamard { target: q0 }), arc(Hadamard { target: q0 })];
        let result = pass().run(&gates);
        assert!(
            result.is_empty(),
            "H·H should cancel to identity (0 gates), got {}",
            result.len()
        );
    }

    #[test]
    fn test_xx_cancellation() {
        let q0 = q(0);
        let gates = vec![arc(PauliX { target: q0 }), arc(PauliX { target: q0 })];
        let result = pass().run(&gates);
        assert!(result.is_empty(), "X·X should cancel");
    }

    #[test]
    fn test_yy_cancellation() {
        let q0 = q(0);
        let gates = vec![arc(PauliY { target: q0 }), arc(PauliY { target: q0 })];
        let result = pass().run(&gates);
        assert!(result.is_empty(), "Y·Y should cancel");
    }

    #[test]
    fn test_zz_cancellation() {
        let q0 = q(0);
        let gates = vec![arc(PauliZ { target: q0 }), arc(PauliZ { target: q0 })];
        let result = pass().run(&gates);
        assert!(result.is_empty(), "Z·Z should cancel");
    }

    #[test]
    fn test_cnot_cancellation() {
        let (c, t) = (q(0), q(1));
        let gates = vec![
            arc(CNOT {
                control: c,
                target: t,
            }),
            arc(CNOT {
                control: c,
                target: t,
            }),
        ];
        let result = pass().run(&gates);
        assert!(
            result.is_empty(),
            "CNOT·CNOT should cancel to identity, got {} gates",
            result.len()
        );
    }

    #[test]
    fn test_cz_cancellation() {
        let (c, t) = (q(0), q(1));
        let gates = vec![
            arc(CZ {
                control: c,
                target: t,
            }),
            arc(CZ {
                control: c,
                target: t,
            }),
        ];
        let result = pass().run(&gates);
        assert!(result.is_empty(), "CZ·CZ should cancel");
    }

    #[test]
    fn test_swap_cancellation() {
        let (a, b) = (q(0), q(1));
        let gates = vec![
            arc(SWAP {
                qubit1: a,
                qubit2: b,
            }),
            arc(SWAP {
                qubit1: a,
                qubit2: b,
            }),
        ];
        let result = pass().run(&gates);
        assert!(result.is_empty(), "SWAP·SWAP should cancel");
    }

    // --- Phase-gate powers ---

    #[test]
    fn test_ss_to_z() {
        let q0 = q(0);
        let gates = vec![arc(Phase { target: q0 }), arc(Phase { target: q0 })];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "S·S should produce one gate");
        assert_eq!(result[0].name(), "Z", "S·S should produce Z");
    }

    #[test]
    fn test_tt_to_s() {
        let q0 = q(0);
        let gates = vec![arc(T { target: q0 }), arc(T { target: q0 })];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "T·T should produce one gate");
        assert_eq!(result[0].name(), "S", "T·T should produce S");
    }

    // --- Rotation merging ---

    #[test]
    fn test_rz_merging() {
        let q0 = q(0);
        let gates = vec![
            arc(RotationZ {
                target: q0,
                theta: 0.3,
            }),
            arc(RotationZ {
                target: q0,
                theta: 0.7,
            }),
        ];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "RZ(0.3)·RZ(0.7) should merge to one gate");
        assert_eq!(result[0].name(), "RZ");
        let merged = result[0]
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("should downcast to RotationZ");
        assert!(
            (merged.theta - 1.0).abs() < 1e-9,
            "merged angle should be 1.0, got {}",
            merged.theta
        );
    }

    #[test]
    fn test_rx_merging() {
        let q0 = q(0);
        let gates = vec![
            arc(RotationX {
                target: q0,
                theta: 0.5,
            }),
            arc(RotationX {
                target: q0,
                theta: 0.5,
            }),
        ];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "RX(0.5)·RX(0.5) should merge");
        let merged = result[0]
            .as_any()
            .downcast_ref::<RotationX>()
            .expect("should downcast to RotationX");
        assert!(
            (merged.theta - 1.0).abs() < 1e-9,
            "merged angle should be 1.0"
        );
    }

    #[test]
    fn test_ry_merging() {
        let q0 = q(0);
        let gates = vec![
            arc(RotationY {
                target: q0,
                theta: 0.2,
            }),
            arc(RotationY {
                target: q0,
                theta: 0.8,
            }),
        ];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "RY(0.2)·RY(0.8) should merge");
    }

    // --- Negative cases: nothing should be rewritten ---

    #[test]
    fn test_no_false_reduction_different_qubits() {
        let gates = vec![
            arc(Hadamard { target: q(0) }),
            arc(Hadamard { target: q(1) }),
        ];
        let result = pass().run(&gates);
        assert_eq!(
            result.len(),
            2,
            "H q[0]; H q[1]; must stay (different qubits)"
        );
    }

    #[test]
    fn test_no_false_reduction_different_gates() {
        let q0 = q(0);
        let gates = vec![arc(Hadamard { target: q0 }), arc(PauliX { target: q0 })];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 2, "H·X must not reduce");
    }

    #[test]
    fn test_cnot_different_controls_no_cancel() {
        let gates = vec![
            arc(CNOT {
                control: q(0),
                target: q(2),
            }),
            arc(CNOT {
                control: q(1),
                target: q(2),
            }),
        ];
        let result = pass().run(&gates);
        assert_eq!(
            result.len(),
            2,
            "CNOT with different controls must not cancel"
        );
    }

    #[test]
    fn test_cnot_different_targets_no_cancel() {
        let gates = vec![
            arc(CNOT {
                control: q(0),
                target: q(1),
            }),
            arc(CNOT {
                control: q(0),
                target: q(2),
            }),
        ];
        let result = pass().run(&gates);
        assert_eq!(
            result.len(),
            2,
            "CNOT with different targets must not cancel"
        );
    }

    // --- Conjugation templates ---

    #[test]
    fn test_hxh_to_z() {
        let q0 = q(0);
        let gates = vec![
            arc(Hadamard { target: q0 }),
            arc(PauliX { target: q0 }),
            arc(Hadamard { target: q0 }),
        ];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "H·X·H should reduce to one gate");
        assert_eq!(result[0].name(), "Z", "H·X·H = Z");
    }

    #[test]
    fn test_hzh_to_x() {
        let q0 = q(0);
        let gates = vec![
            arc(Hadamard { target: q0 }),
            arc(PauliZ { target: q0 }),
            arc(Hadamard { target: q0 }),
        ];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "H·Z·H should reduce to one gate");
        assert_eq!(result[0].name(), "X", "H·Z·H = X");
    }

    // --- Multi-pass behaviour ---

    #[test]
    fn test_multi_pass_convergence() {
        let q0 = q(0);
        let gates = vec![
            arc(Hadamard { target: q0 }),
            arc(Hadamard { target: q0 }),
            arc(Hadamard { target: q0 }),
            arc(Hadamard { target: q0 }),
        ];
        let result = pass().run(&gates);
        assert!(result.is_empty(), "H⁴ should converge to identity");
    }

    /// A merged rotation totalling 2π is dropped entirely.
    #[test]
    fn test_rz_cancels_to_identity_when_total_is_2pi() {
        let q0 = q(0);
        let two_pi = 2.0 * std::f64::consts::PI;
        let gates = vec![
            arc(RotationZ {
                target: q0,
                theta: two_pi * 0.6,
            }),
            arc(RotationZ {
                target: q0,
                theta: two_pi * 0.4,
            }),
        ];
        let result = pass().run(&gates);
        assert!(
            result.is_empty(),
            "RZ(0.6·2π)·RZ(0.4·2π) should cancel to identity, got {} gates",
            result.len()
        );
    }

    /// Three rotations need two passes: the first merges a pair, the second
    /// merges the result with the remaining gate.
    #[test]
    fn test_rz_merging_three_gates() {
        let q0 = q(0);
        let gates = vec![
            arc(RotationZ {
                target: q0,
                theta: 0.3,
            }),
            arc(RotationZ {
                target: q0,
                theta: 0.3,
            }),
            arc(RotationZ {
                target: q0,
                theta: 0.3,
            }),
        ];
        let result = pass().run(&gates);
        assert_eq!(result.len(), 1, "RZ·RZ·RZ should merge to one gate");
        let merged = result[0]
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("should be RZ");
        assert!(
            (merged.theta - 0.9).abs() < 1e-9,
            "merged angle should be 0.9, got {}",
            merged.theta
        );
    }
}