// quantrs2_core/optimization/peephole.rs

//! Peephole optimization for quantum circuits.
//!
//! This module implements peephole optimization, which scans short windows of
//! adjacent gates for patterns that can be simplified or eliminated.
5use super::{gates_can_commute, OptimizationPass};
6use crate::error::QuantRS2Result;
7use crate::gate::{multi::*, single::*, GateOp};
8use std::f64::consts::PI;
/// Peephole optimization pass.
///
/// Looks at short windows of adjacent gates and replaces recognized patterns with
/// equivalent, shorter sequences. Individual rewrites can be toggled with the flags
/// below.
#[derive(Debug, Clone)]
pub struct PeepholeOptimizer {
    /// Merge adjacent same-axis rotations (`RX`/`RY`/`RZ`) acting on the same qubit.
    pub merge_rotations: bool,
    /// Drop rotation gates whose angle is effectively zero (modulo 2π) once the
    /// rewrite sweeps have finished.
    pub remove_identities: bool,
    /// Allow an adjacent commuting pair to be swapped so that the gate acting on
    /// fewer qubits comes first.
    pub enable_commutation: bool,
    /// Absolute tolerance used when deciding whether a rotation angle is effectively
    /// zero (modulo 2π).
    pub zero_tolerance: f64,
}
20impl Default for PeepholeOptimizer {
21    fn default() -> Self {
22        Self {
23            merge_rotations: true,
24            remove_identities: true,
25            enable_commutation: true,
26            zero_tolerance: 1e-10,
27        }
28    }
29}
30impl PeepholeOptimizer {
31    /// Create a new peephole optimizer
32    pub fn new() -> Self {
33        Self::default()
34    }
35    /// Check if a rotation angle is effectively zero
36    fn is_zero_rotation(&self, angle: f64) -> bool {
37        let normalized = angle % (2.0 * PI);
38        normalized.abs() < self.zero_tolerance
39            || 2.0f64.mul_add(-PI, normalized).abs() < self.zero_tolerance
40    }
41    /// Try to simplify a window of gates
42    #[allow(dead_code)]
43    fn simplify_window(
44        &self,
45        window: &[Box<dyn GateOp>],
46    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
47        match window.len() {
48            2 => self.simplify_pair(&window[0], &window[1]),
49            3 => Self::simplify_triple(&window[0], &window[1], &window[2]),
50            _ => Ok(None),
51        }
52    }
53    /// Simplify a pair of gates
54    fn simplify_pair(
55        &self,
56        gate1: &Box<dyn GateOp>,
57        gate2: &Box<dyn GateOp>,
58    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
59        if self.merge_rotations && gate1.qubits() == gate2.qubits() && gate1.qubits().len() == 1 {
60            let qubit = gate1.qubits()[0];
61            match (gate1.name(), gate2.name()) {
62                ("RX", "RX") => {
63                    if let (Some(rx1), Some(rx2)) = (
64                        gate1.as_any().downcast_ref::<RotationX>(),
65                        gate2.as_any().downcast_ref::<RotationX>(),
66                    ) {
67                        let combined_angle = rx1.theta + rx2.theta;
68                        if self.is_zero_rotation(combined_angle) {
69                            return Ok(Some(vec![]));
70                        }
71                        return Ok(Some(vec![Box::new(RotationX {
72                            target: qubit,
73                            theta: combined_angle,
74                        })]));
75                    }
76                }
77                ("RY", "RY") => {
78                    if let (Some(ry1), Some(ry2)) = (
79                        gate1.as_any().downcast_ref::<RotationY>(),
80                        gate2.as_any().downcast_ref::<RotationY>(),
81                    ) {
82                        let combined_angle = ry1.theta + ry2.theta;
83                        if self.is_zero_rotation(combined_angle) {
84                            return Ok(Some(vec![]));
85                        }
86                        return Ok(Some(vec![Box::new(RotationY {
87                            target: qubit,
88                            theta: combined_angle,
89                        })]));
90                    }
91                }
92                ("RZ", "RZ") => {
93                    if let (Some(rz1), Some(rz2)) = (
94                        gate1.as_any().downcast_ref::<RotationZ>(),
95                        gate2.as_any().downcast_ref::<RotationZ>(),
96                    ) {
97                        let combined_angle = rz1.theta + rz2.theta;
98                        if self.is_zero_rotation(combined_angle) {
99                            return Ok(Some(vec![]));
100                        }
101                        return Ok(Some(vec![Box::new(RotationZ {
102                            target: qubit,
103                            theta: combined_angle,
104                        })]));
105                    }
106                }
107                _ => {}
108            }
109        }
110        if gate1.qubits() == gate2.qubits() {
111            match (gate1.name(), gate2.name()) {
112                ("T", "T") => {
113                    return Ok(Some(vec![Box::new(Phase {
114                        target: gate1.qubits()[0],
115                    })]));
116                }
117                ("T†", "T†") => {
118                    return Ok(Some(vec![Box::new(PhaseDagger {
119                        target: gate1.qubits()[0],
120                    })]));
121                }
122                ("S", "T") | ("T", "S") => {
123                    let qubit = gate1.qubits()[0];
124                    return Ok(Some(vec![
125                        Box::new(T { target: qubit }),
126                        Box::new(T { target: qubit }),
127                        Box::new(T { target: qubit }),
128                    ]));
129                }
130                _ => {}
131            }
132        }
133        if self.enable_commutation
134            && gates_can_commute(gate1.as_ref(), gate2.as_ref())
135            && gate1.qubits().len() > gate2.qubits().len()
136        {
137            return Ok(Some(vec![gate2.clone_gate(), gate1.clone_gate()]));
138        }
139        Ok(None)
140    }
141    /// Simplify a triple of gates
142    fn simplify_triple(
143        gate1: &Box<dyn GateOp>,
144        gate2: &Box<dyn GateOp>,
145        gate3: &Box<dyn GateOp>,
146    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
147        if gate1.name() == "CNOT" && gate3.name() == "CNOT" && gate2.name() == "RZ" {
148            if let (Some(cx1), Some(cx2), Some(rz)) = (
149                gate1.as_any().downcast_ref::<CNOT>(),
150                gate3.as_any().downcast_ref::<CNOT>(),
151                gate2.as_any().downcast_ref::<RotationZ>(),
152            ) {
153                if cx1.control == cx2.control && cx1.target == cx2.target && rz.target == cx1.target
154                {
155                    return Ok(Some(vec![Box::new(CRZ {
156                        control: cx1.control,
157                        target: cx1.target,
158                        theta: rz.theta,
159                    })]));
160                }
161            }
162        }
163        if gate1.name() == "H"
164            && gate2.name() == "X"
165            && gate3.name() == "H"
166            && gate1.qubits() == gate2.qubits()
167            && gate2.qubits() == gate3.qubits()
168        {
169            return Ok(Some(vec![Box::new(PauliZ {
170                target: gate1.qubits()[0],
171            })]));
172        }
173        if gate1.name() == "H"
174            && gate2.name() == "Z"
175            && gate3.name() == "H"
176            && gate1.qubits() == gate2.qubits()
177            && gate2.qubits() == gate3.qubits()
178        {
179            return Ok(Some(vec![Box::new(PauliX {
180                target: gate1.qubits()[0],
181            })]));
182        }
183        if gate1.name() == "X"
184            && gate2.name() == "Y"
185            && gate3.name() == "X"
186            && gate1.qubits() == gate2.qubits()
187            && gate2.qubits() == gate3.qubits()
188        {
189            let qubit = gate1.qubits()[0];
190            return Ok(Some(vec![
191                Box::new(PauliY { target: qubit }),
192                Box::new(PauliZ { target: qubit }),
193            ]));
194        }
195        Ok(None)
196    }
197    /// Remove identity rotations
198    fn remove_identity_rotations(&self, gates: Vec<Box<dyn GateOp>>) -> Vec<Box<dyn GateOp>> {
199        gates
200            .into_iter()
201            .filter(|gate| match gate.name() {
202                "RX" => {
203                    if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
204                        !self.is_zero_rotation(rx.theta)
205                    } else {
206                        true
207                    }
208                }
209                "RY" => {
210                    if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
211                        !self.is_zero_rotation(ry.theta)
212                    } else {
213                        true
214                    }
215                }
216                "RZ" => {
217                    if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
218                        !self.is_zero_rotation(rz.theta)
219                    } else {
220                        true
221                    }
222                }
223                _ => true,
224            })
225            .collect()
226    }
227}
228impl OptimizationPass for PeepholeOptimizer {
229    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
230        let mut current = gates;
231        let mut changed = true;
232        let max_iterations = 10;
233        let mut iterations = 0;
234        while changed && iterations < max_iterations {
235            changed = false;
236            let mut optimized = Vec::new();
237            let mut i = 0;
238            while i < current.len() {
239                if i + 2 < current.len() {
240                    if let Some(simplified) =
241                        Self::simplify_triple(&current[i], &current[i + 1], &current[i + 2])?
242                    {
243                        optimized.extend(simplified);
244                        i += 3;
245                        changed = true;
246                        continue;
247                    }
248                }
249                if i + 1 < current.len() {
250                    if let Some(simplified) = self.simplify_pair(&current[i], &current[i + 1])? {
251                        optimized.extend(simplified);
252                        i += 2;
253                        changed = true;
254                        continue;
255                    }
256                }
257                optimized.push(current[i].clone_gate());
258                i += 1;
259            }
260            current = optimized;
261            iterations += 1;
262        }
263        if self.remove_identities {
264            current = self.remove_identity_rotations(current);
265        }
266        Ok(current)
267    }
268    fn name(&self) -> &'static str {
269        "Peephole Optimization"
270    }
271}
272/// Specialized optimizer for T-count reduction
273pub struct TCountOptimizer {
274    /// Maximum search depth for optimization
275    pub max_depth: usize,
276}
277impl Default for TCountOptimizer {
278    fn default() -> Self {
279        Self::new()
280    }
281}
282impl TCountOptimizer {
283    pub const fn new() -> Self {
284        Self { max_depth: 4 }
285    }
286    /// Count T gates in a sequence
287    fn count_t_gates(gates: &[Box<dyn GateOp>]) -> usize {
288        gates
289            .iter()
290            .filter(|g| g.name() == "T" || g.name() == "T†")
291            .count()
292    }
293    /// Try to reduce T-count by recognizing special patterns
294    fn reduce_t_count(gates: &[Box<dyn GateOp>]) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
295        if gates.len() >= 3 {
296            for i in 0..gates.len() - 2 {
297                if gates[i].name() == "T"
298                    && gates[i + 1].name() == "S"
299                    && gates[i + 2].name() == "T"
300                    && gates[i].qubits() == gates[i + 1].qubits()
301                    && gates[i + 1].qubits() == gates[i + 2].qubits()
302                {
303                    let qubit = gates[i].qubits()[0];
304                    let mut result = Vec::new();
305                    for j in 0..i {
306                        result.push(gates[j].clone_gate());
307                    }
308                    result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
309                    result.push(Box::new(T { target: qubit }) as Box<dyn GateOp>);
310                    result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
311                    for j in i + 3..gates.len() {
312                        result.push(gates[j].clone_gate());
313                    }
314                    return Ok(Some(result));
315                }
316            }
317        }
318        Ok(None)
319    }
320}
321impl OptimizationPass for TCountOptimizer {
322    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
323        let original_t_count = Self::count_t_gates(&gates);
324        if let Some(optimized) = Self::reduce_t_count(&gates)? {
325            let new_t_count = Self::count_t_gates(&optimized);
326            if new_t_count < original_t_count {
327                return Ok(optimized);
328            }
329        }
330        Ok(gates)
331    }
332    fn name(&self) -> &'static str {
333        "T-Count Optimization"
334    }
335}
336#[cfg(test)]
337mod tests {
338    use super::*;
339    use crate::prelude::QubitId;
340    #[test]
341    fn test_rotation_merging() {
342        let optimizer = PeepholeOptimizer::new();
343        let qubit = QubitId(0);
344        let gates: Vec<Box<dyn GateOp>> = vec![
345            Box::new(RotationZ {
346                target: qubit,
347                theta: PI / 4.0,
348            }),
349            Box::new(RotationZ {
350                target: qubit,
351                theta: PI / 4.0,
352            }),
353        ];
354        let result = optimizer
355            .optimize(gates)
356            .expect("Failed to optimize gates in test_rotation_merging");
357        assert_eq!(result.len(), 1);
358        if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
359            assert!((rz.theta - PI / 2.0).abs() < 1e-10);
360        } else {
361            panic!("Expected RotationZ");
362        }
363    }
364    #[test]
365    fn test_zero_rotation_removal() {
366        let optimizer = PeepholeOptimizer::new();
367        let qubit = QubitId(0);
368        let gates: Vec<Box<dyn GateOp>> = vec![
369            Box::new(RotationX {
370                target: qubit,
371                theta: PI,
372            }),
373            Box::new(RotationX {
374                target: qubit,
375                theta: PI,
376            }),
377        ];
378        let result = optimizer
379            .optimize(gates)
380            .expect("Failed to optimize gates in test_zero_rotation_removal");
381        assert_eq!(result.len(), 0);
382    }
383    #[test]
384    fn test_cnot_rz_pattern() {
385        let optimizer = PeepholeOptimizer::new();
386        let q0 = QubitId(0);
387        let q1 = QubitId(1);
388        let gates: Vec<Box<dyn GateOp>> = vec![
389            Box::new(CNOT {
390                control: q0,
391                target: q1,
392            }),
393            Box::new(RotationZ {
394                target: q1,
395                theta: PI / 4.0,
396            }),
397            Box::new(CNOT {
398                control: q0,
399                target: q1,
400            }),
401        ];
402        let result = optimizer
403            .optimize(gates)
404            .expect("Failed to optimize gates in test_cnot_rz_pattern");
405        assert_eq!(result.len(), 1);
406        assert_eq!(result[0].name(), "CRZ");
407    }
408    #[test]
409    fn test_h_x_h_pattern() {
410        let optimizer = PeepholeOptimizer::new();
411        let qubit = QubitId(0);
412        let gates: Vec<Box<dyn GateOp>> = vec![
413            Box::new(Hadamard { target: qubit }),
414            Box::new(PauliX { target: qubit }),
415            Box::new(Hadamard { target: qubit }),
416        ];
417        let result = optimizer
418            .optimize(gates)
419            .expect("Failed to optimize gates in test_h_x_h_pattern");
420        assert_eq!(result.len(), 1);
421        assert_eq!(result[0].name(), "Z");
422    }
423    #[test]
424    fn test_t_gate_combination() {
425        let optimizer = PeepholeOptimizer::new();
426        let qubit = QubitId(0);
427        let gates: Vec<Box<dyn GateOp>> =
428            vec![Box::new(T { target: qubit }), Box::new(T { target: qubit })];
429        let result = optimizer
430            .optimize(gates)
431            .expect("Failed to optimize gates in test_t_gate_combination");
432        assert_eq!(result.len(), 1);
433        assert_eq!(result[0].name(), "S");
434    }
435}