use super::{gates_can_commute, OptimizationPass};
use crate::error::QuantRS2Result;
use crate::gate::{multi::*, single::*, GateOp};
use std::f64::consts::PI;
pub struct PeepholeOptimizer {
pub merge_rotations: bool,
pub remove_identities: bool,
pub enable_commutation: bool,
pub zero_tolerance: f64,
}
impl Default for PeepholeOptimizer {
fn default() -> Self {
Self {
merge_rotations: true,
remove_identities: true,
enable_commutation: true,
zero_tolerance: 1e-10,
}
}
}
impl PeepholeOptimizer {
pub fn new() -> Self {
Self::default()
}
fn is_zero_rotation(&self, angle: f64) -> bool {
let normalized = angle % (2.0 * PI);
normalized.abs() < self.zero_tolerance
|| 2.0f64.mul_add(-PI, normalized).abs() < self.zero_tolerance
}
#[allow(dead_code)]
fn simplify_window(
&self,
window: &[Box<dyn GateOp>],
) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
match window.len() {
2 => self.simplify_pair(&window[0], &window[1]),
3 => Self::simplify_triple(&window[0], &window[1], &window[2]),
_ => Ok(None),
}
}
fn simplify_pair(
&self,
gate1: &Box<dyn GateOp>,
gate2: &Box<dyn GateOp>,
) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
if self.merge_rotations && gate1.qubits() == gate2.qubits() && gate1.qubits().len() == 1 {
let qubit = gate1.qubits()[0];
match (gate1.name(), gate2.name()) {
("RX", "RX") => {
if let (Some(rx1), Some(rx2)) = (
gate1.as_any().downcast_ref::<RotationX>(),
gate2.as_any().downcast_ref::<RotationX>(),
) {
let combined_angle = rx1.theta + rx2.theta;
if self.is_zero_rotation(combined_angle) {
return Ok(Some(vec![]));
}
return Ok(Some(vec![Box::new(RotationX {
target: qubit,
theta: combined_angle,
})]));
}
}
("RY", "RY") => {
if let (Some(ry1), Some(ry2)) = (
gate1.as_any().downcast_ref::<RotationY>(),
gate2.as_any().downcast_ref::<RotationY>(),
) {
let combined_angle = ry1.theta + ry2.theta;
if self.is_zero_rotation(combined_angle) {
return Ok(Some(vec![]));
}
return Ok(Some(vec![Box::new(RotationY {
target: qubit,
theta: combined_angle,
})]));
}
}
("RZ", "RZ") => {
if let (Some(rz1), Some(rz2)) = (
gate1.as_any().downcast_ref::<RotationZ>(),
gate2.as_any().downcast_ref::<RotationZ>(),
) {
let combined_angle = rz1.theta + rz2.theta;
if self.is_zero_rotation(combined_angle) {
return Ok(Some(vec![]));
}
return Ok(Some(vec![Box::new(RotationZ {
target: qubit,
theta: combined_angle,
})]));
}
}
_ => {}
}
}
if gate1.qubits() == gate2.qubits() {
match (gate1.name(), gate2.name()) {
("T", "T") => {
return Ok(Some(vec![Box::new(Phase {
target: gate1.qubits()[0],
})]));
}
("T†", "T†") => {
return Ok(Some(vec![Box::new(PhaseDagger {
target: gate1.qubits()[0],
})]));
}
("S", "T") | ("T", "S") => {
let qubit = gate1.qubits()[0];
return Ok(Some(vec![
Box::new(T { target: qubit }),
Box::new(T { target: qubit }),
Box::new(T { target: qubit }),
]));
}
_ => {}
}
}
if self.enable_commutation
&& gates_can_commute(gate1.as_ref(), gate2.as_ref())
&& gate1.qubits().len() > gate2.qubits().len()
{
return Ok(Some(vec![gate2.clone_gate(), gate1.clone_gate()]));
}
Ok(None)
}
fn simplify_triple(
gate1: &Box<dyn GateOp>,
gate2: &Box<dyn GateOp>,
gate3: &Box<dyn GateOp>,
) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
if gate1.name() == "CNOT" && gate3.name() == "CNOT" && gate2.name() == "RZ" {
if let (Some(cx1), Some(cx2), Some(rz)) = (
gate1.as_any().downcast_ref::<CNOT>(),
gate3.as_any().downcast_ref::<CNOT>(),
gate2.as_any().downcast_ref::<RotationZ>(),
) {
if cx1.control == cx2.control && cx1.target == cx2.target && rz.target == cx1.target
{
return Ok(Some(vec![Box::new(CRZ {
control: cx1.control,
target: cx1.target,
theta: rz.theta,
})]));
}
}
}
if gate1.name() == "H"
&& gate2.name() == "X"
&& gate3.name() == "H"
&& gate1.qubits() == gate2.qubits()
&& gate2.qubits() == gate3.qubits()
{
return Ok(Some(vec![Box::new(PauliZ {
target: gate1.qubits()[0],
})]));
}
if gate1.name() == "H"
&& gate2.name() == "Z"
&& gate3.name() == "H"
&& gate1.qubits() == gate2.qubits()
&& gate2.qubits() == gate3.qubits()
{
return Ok(Some(vec![Box::new(PauliX {
target: gate1.qubits()[0],
})]));
}
if gate1.name() == "X"
&& gate2.name() == "Y"
&& gate3.name() == "X"
&& gate1.qubits() == gate2.qubits()
&& gate2.qubits() == gate3.qubits()
{
let qubit = gate1.qubits()[0];
return Ok(Some(vec![
Box::new(PauliY { target: qubit }),
Box::new(PauliZ { target: qubit }),
]));
}
Ok(None)
}
fn remove_identity_rotations(&self, gates: Vec<Box<dyn GateOp>>) -> Vec<Box<dyn GateOp>> {
gates
.into_iter()
.filter(|gate| match gate.name() {
"RX" => {
if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
!self.is_zero_rotation(rx.theta)
} else {
true
}
}
"RY" => {
if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
!self.is_zero_rotation(ry.theta)
} else {
true
}
}
"RZ" => {
if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
!self.is_zero_rotation(rz.theta)
} else {
true
}
}
_ => true,
})
.collect()
}
}
impl OptimizationPass for PeepholeOptimizer {
fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
let mut current = gates;
let mut changed = true;
let max_iterations = 10;
let mut iterations = 0;
while changed && iterations < max_iterations {
changed = false;
let mut optimized = Vec::new();
let mut i = 0;
while i < current.len() {
if i + 2 < current.len() {
if let Some(simplified) =
Self::simplify_triple(¤t[i], ¤t[i + 1], ¤t[i + 2])?
{
optimized.extend(simplified);
i += 3;
changed = true;
continue;
}
}
if i + 1 < current.len() {
if let Some(simplified) = self.simplify_pair(¤t[i], ¤t[i + 1])? {
optimized.extend(simplified);
i += 2;
changed = true;
continue;
}
}
optimized.push(current[i].clone_gate());
i += 1;
}
current = optimized;
iterations += 1;
}
if self.remove_identities {
current = self.remove_identity_rotations(current);
}
Ok(current)
}
fn name(&self) -> &'static str {
"Peephole Optimization"
}
}
pub struct TCountOptimizer {
pub max_depth: usize,
}
impl Default for TCountOptimizer {
fn default() -> Self {
Self::new()
}
}
impl TCountOptimizer {
pub const fn new() -> Self {
Self { max_depth: 4 }
}
fn count_t_gates(gates: &[Box<dyn GateOp>]) -> usize {
gates
.iter()
.filter(|g| g.name() == "T" || g.name() == "T†")
.count()
}
fn reduce_t_count(gates: &[Box<dyn GateOp>]) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
if gates.len() >= 3 {
for i in 0..gates.len() - 2 {
if gates[i].name() == "T"
&& gates[i + 1].name() == "S"
&& gates[i + 2].name() == "T"
&& gates[i].qubits() == gates[i + 1].qubits()
&& gates[i + 1].qubits() == gates[i + 2].qubits()
{
let qubit = gates[i].qubits()[0];
let mut result = Vec::new();
for j in 0..i {
result.push(gates[j].clone_gate());
}
result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
result.push(Box::new(T { target: qubit }) as Box<dyn GateOp>);
result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
for j in i + 3..gates.len() {
result.push(gates[j].clone_gate());
}
return Ok(Some(result));
}
}
}
Ok(None)
}
}
impl OptimizationPass for TCountOptimizer {
fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
let original_t_count = Self::count_t_gates(&gates);
if let Some(optimized) = Self::reduce_t_count(&gates)? {
let new_t_count = Self::count_t_gates(&optimized);
if new_t_count < original_t_count {
return Ok(optimized);
}
}
Ok(gates)
}
fn name(&self) -> &'static str {
"T-Count Optimization"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prelude::QubitId;
#[test]
fn test_rotation_merging() {
let optimizer = PeepholeOptimizer::new();
let qubit = QubitId(0);
let gates: Vec<Box<dyn GateOp>> = vec![
Box::new(RotationZ {
target: qubit,
theta: PI / 4.0,
}),
Box::new(RotationZ {
target: qubit,
theta: PI / 4.0,
}),
];
let result = optimizer
.optimize(gates)
.expect("Failed to optimize gates in test_rotation_merging");
assert_eq!(result.len(), 1);
if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
assert!((rz.theta - PI / 2.0).abs() < 1e-10);
} else {
panic!("Expected RotationZ");
}
}
#[test]
fn test_zero_rotation_removal() {
let optimizer = PeepholeOptimizer::new();
let qubit = QubitId(0);
let gates: Vec<Box<dyn GateOp>> = vec![
Box::new(RotationX {
target: qubit,
theta: PI,
}),
Box::new(RotationX {
target: qubit,
theta: PI,
}),
];
let result = optimizer
.optimize(gates)
.expect("Failed to optimize gates in test_zero_rotation_removal");
assert_eq!(result.len(), 0);
}
#[test]
fn test_cnot_rz_pattern() {
let optimizer = PeepholeOptimizer::new();
let q0 = QubitId(0);
let q1 = QubitId(1);
let gates: Vec<Box<dyn GateOp>> = vec![
Box::new(CNOT {
control: q0,
target: q1,
}),
Box::new(RotationZ {
target: q1,
theta: PI / 4.0,
}),
Box::new(CNOT {
control: q0,
target: q1,
}),
];
let result = optimizer
.optimize(gates)
.expect("Failed to optimize gates in test_cnot_rz_pattern");
assert_eq!(result.len(), 1);
assert_eq!(result[0].name(), "CRZ");
}
#[test]
fn test_h_x_h_pattern() {
let optimizer = PeepholeOptimizer::new();
let qubit = QubitId(0);
let gates: Vec<Box<dyn GateOp>> = vec![
Box::new(Hadamard { target: qubit }),
Box::new(PauliX { target: qubit }),
Box::new(Hadamard { target: qubit }),
];
let result = optimizer
.optimize(gates)
.expect("Failed to optimize gates in test_h_x_h_pattern");
assert_eq!(result.len(), 1);
assert_eq!(result[0].name(), "Z");
}
#[test]
fn test_t_gate_combination() {
let optimizer = PeepholeOptimizer::new();
let qubit = QubitId(0);
let gates: Vec<Box<dyn GateOp>> =
vec![Box::new(T { target: qubit }), Box::new(T { target: qubit })];
let result = optimizer
.optimize(gates)
.expect("Failed to optimize gates in test_t_gate_combination");
assert_eq!(result.len(), 1);
assert_eq!(result[0].name(), "S");
}
}