quantrs2_core/optimization/
peephole.rs

1//! Peephole optimization for quantum circuits
2//!
3//! This module implements peephole optimization, which looks for small patterns
4//! of gates that can be simplified or eliminated.
5
6use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::gate::{multi::*, single::*, GateOp};
8use crate::qubit::QubitId;
9use std::f64::consts::PI;
10
11use super::{gates_can_commute, OptimizationPass};
12
/// Peephole optimization pass.
///
/// Scans short windows (two or three adjacent gates) and replaces known
/// patterns with cheaper equivalents. Each boolean flag enables one family of
/// rewrites.
#[derive(Debug, Clone)]
pub struct PeepholeOptimizer {
    /// Enable merging of adjacent same-axis rotations (RX·RX, RY·RY, RZ·RZ).
    pub merge_rotations: bool,
    /// Enable the final pass that drops rotations whose angle is a full turn.
    pub remove_identities: bool,
    /// Enable reordering of commuting gates (smaller gates are moved earlier).
    pub enable_commutation: bool,
    /// Absolute tolerance for treating a rotation angle (modulo 2π) as zero.
    pub zero_tolerance: f64,
}
24
25impl Default for PeepholeOptimizer {
26    fn default() -> Self {
27        Self {
28            merge_rotations: true,
29            remove_identities: true,
30            enable_commutation: true,
31            zero_tolerance: 1e-10,
32        }
33    }
34}
35
36impl PeepholeOptimizer {
37    /// Create a new peephole optimizer
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Check if a rotation angle is effectively zero
43    fn is_zero_rotation(&self, angle: f64) -> bool {
44        let normalized = angle % (2.0 * PI);
45        normalized.abs() < self.zero_tolerance
46            || (normalized - 2.0 * PI).abs() < self.zero_tolerance
47    }
48
49    /// Try to simplify a window of gates
50    fn simplify_window(
51        &self,
52        window: &[Box<dyn GateOp>],
53    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
54        match window.len() {
55            2 => self.simplify_pair(&window[0], &window[1]),
56            3 => self.simplify_triple(&window[0], &window[1], &window[2]),
57            _ => Ok(None),
58        }
59    }
60
61    /// Simplify a pair of gates
62    fn simplify_pair(
63        &self,
64        gate1: &Box<dyn GateOp>,
65        gate2: &Box<dyn GateOp>,
66    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
67        // Handle rotation merging
68        if self.merge_rotations && gate1.qubits() == gate2.qubits() && gate1.qubits().len() == 1 {
69            let qubit = gate1.qubits()[0];
70
71            match (gate1.name(), gate2.name()) {
72                ("RX", "RX") => {
73                    if let (Some(rx1), Some(rx2)) = (
74                        gate1.as_any().downcast_ref::<RotationX>(),
75                        gate2.as_any().downcast_ref::<RotationX>(),
76                    ) {
77                        let combined_angle = rx1.theta + rx2.theta;
78                        if self.is_zero_rotation(combined_angle) {
79                            return Ok(Some(vec![])); // Remove both
80                        }
81                        return Ok(Some(vec![Box::new(RotationX {
82                            target: qubit,
83                            theta: combined_angle,
84                        })]));
85                    }
86                }
87                ("RY", "RY") => {
88                    if let (Some(ry1), Some(ry2)) = (
89                        gate1.as_any().downcast_ref::<RotationY>(),
90                        gate2.as_any().downcast_ref::<RotationY>(),
91                    ) {
92                        let combined_angle = ry1.theta + ry2.theta;
93                        if self.is_zero_rotation(combined_angle) {
94                            return Ok(Some(vec![])); // Remove both
95                        }
96                        return Ok(Some(vec![Box::new(RotationY {
97                            target: qubit,
98                            theta: combined_angle,
99                        })]));
100                    }
101                }
102                ("RZ", "RZ") => {
103                    if let (Some(rz1), Some(rz2)) = (
104                        gate1.as_any().downcast_ref::<RotationZ>(),
105                        gate2.as_any().downcast_ref::<RotationZ>(),
106                    ) {
107                        let combined_angle = rz1.theta + rz2.theta;
108                        if self.is_zero_rotation(combined_angle) {
109                            return Ok(Some(vec![])); // Remove both
110                        }
111                        return Ok(Some(vec![Box::new(RotationZ {
112                            target: qubit,
113                            theta: combined_angle,
114                        })]));
115                    }
116                }
117                _ => {}
118            }
119        }
120
121        // Handle special patterns
122        if gate1.qubits() == gate2.qubits() {
123            match (gate1.name(), gate2.name()) {
124                // T gate combinations
125                ("T", "T") => {
126                    // T² = S
127                    return Ok(Some(vec![Box::new(Phase {
128                        target: gate1.qubits()[0],
129                    })]));
130                }
131                ("T†", "T†") => {
132                    // T†² = S†
133                    return Ok(Some(vec![Box::new(PhaseDagger {
134                        target: gate1.qubits()[0],
135                    })]));
136                }
137
138                // S and T combinations
139                ("S", "T") | ("T", "S") => {
140                    // S·T = T·S = T³ (since T² = S)
141                    let qubit = gate1.qubits()[0];
142                    return Ok(Some(vec![
143                        Box::new(T { target: qubit }),
144                        Box::new(T { target: qubit }),
145                        Box::new(T { target: qubit }),
146                    ]));
147                }
148
149                _ => {}
150            }
151        }
152
153        // Try commutation if enabled
154        if self.enable_commutation && gates_can_commute(gate1.as_ref(), gate2.as_ref()) {
155            // Return in swapped order if it might help later optimizations
156            // This is a heuristic - we prefer to move single-qubit gates earlier
157            if gate1.qubits().len() > gate2.qubits().len() {
158                return Ok(Some(vec![gate2.clone_gate(), gate1.clone_gate()]));
159            }
160        }
161
162        Ok(None)
163    }
164
165    /// Simplify a triple of gates
166    fn simplify_triple(
167        &self,
168        gate1: &Box<dyn GateOp>,
169        gate2: &Box<dyn GateOp>,
170        gate3: &Box<dyn GateOp>,
171    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
172        // CX-Rz-CX pattern (controlled rotation)
173        if gate1.name() == "CNOT" && gate3.name() == "CNOT" && gate2.name() == "RZ" {
174            if let (Some(cx1), Some(cx2), Some(rz)) = (
175                gate1.as_any().downcast_ref::<CNOT>(),
176                gate3.as_any().downcast_ref::<CNOT>(),
177                gate2.as_any().downcast_ref::<RotationZ>(),
178            ) {
179                // Check if it's the controlled-Rz pattern
180                if cx1.control == cx2.control && cx1.target == cx2.target && rz.target == cx1.target
181                {
182                    // This is equivalent to a CRZ gate
183                    return Ok(Some(vec![Box::new(CRZ {
184                        control: cx1.control,
185                        target: cx1.target,
186                        theta: rz.theta,
187                    })]));
188                }
189            }
190        }
191
192        // H-X-H = Z pattern
193        if gate1.name() == "H" && gate2.name() == "X" && gate3.name() == "H" {
194            if gate1.qubits() == gate2.qubits() && gate2.qubits() == gate3.qubits() {
195                return Ok(Some(vec![Box::new(PauliZ {
196                    target: gate1.qubits()[0],
197                })]));
198            }
199        }
200
201        // H-Z-H = X pattern
202        if gate1.name() == "H" && gate2.name() == "Z" && gate3.name() == "H" {
203            if gate1.qubits() == gate2.qubits() && gate2.qubits() == gate3.qubits() {
204                return Ok(Some(vec![Box::new(PauliX {
205                    target: gate1.qubits()[0],
206                })]));
207            }
208        }
209
210        // X-Y-X = -Y pattern
211        if gate1.name() == "X" && gate2.name() == "Y" && gate3.name() == "X" {
212            if gate1.qubits() == gate2.qubits() && gate2.qubits() == gate3.qubits() {
213                let qubit = gate1.qubits()[0];
214                return Ok(Some(vec![
215                    Box::new(PauliY { target: qubit }),
216                    Box::new(PauliZ { target: qubit }), // Global phase -1
217                ]));
218            }
219        }
220
221        Ok(None)
222    }
223
224    /// Remove identity rotations
225    fn remove_identity_rotations(&self, gates: Vec<Box<dyn GateOp>>) -> Vec<Box<dyn GateOp>> {
226        gates
227            .into_iter()
228            .filter(|gate| match gate.name() {
229                "RX" => {
230                    if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
231                        !self.is_zero_rotation(rx.theta)
232                    } else {
233                        true
234                    }
235                }
236                "RY" => {
237                    if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
238                        !self.is_zero_rotation(ry.theta)
239                    } else {
240                        true
241                    }
242                }
243                "RZ" => {
244                    if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
245                        !self.is_zero_rotation(rz.theta)
246                    } else {
247                        true
248                    }
249                }
250                _ => true,
251            })
252            .collect()
253    }
254}
255
256impl OptimizationPass for PeepholeOptimizer {
257    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
258        let mut current = gates;
259        let mut changed = true;
260        let max_iterations = 10; // Prevent infinite loops
261        let mut iterations = 0;
262
263        while changed && iterations < max_iterations {
264            changed = false;
265            let mut optimized = Vec::new();
266            let mut i = 0;
267
268            while i < current.len() {
269                // Try triple patterns first
270                if i + 2 < current.len() {
271                    if let Some(simplified) =
272                        self.simplify_triple(&current[i], &current[i + 1], &current[i + 2])?
273                    {
274                        optimized.extend(simplified);
275                        i += 3;
276                        changed = true;
277                        continue;
278                    }
279                }
280
281                // Try pair patterns
282                if i + 1 < current.len() {
283                    if let Some(simplified) = self.simplify_pair(&current[i], &current[i + 1])? {
284                        optimized.extend(simplified);
285                        i += 2;
286                        changed = true;
287                        continue;
288                    }
289                }
290
291                // No pattern matched, keep the gate
292                optimized.push(current[i].clone_gate());
293                i += 1;
294            }
295
296            current = optimized;
297            iterations += 1;
298        }
299
300        // Final pass to remove identity rotations
301        if self.remove_identities {
302            current = self.remove_identity_rotations(current);
303        }
304
305        Ok(current)
306    }
307
308    fn name(&self) -> &str {
309        "Peephole Optimization"
310    }
311}
312
313/// Specialized optimizer for T-count reduction
314pub struct TCountOptimizer {
315    /// Maximum search depth for optimization
316    pub max_depth: usize,
317}
318
319impl TCountOptimizer {
320    pub fn new() -> Self {
321        Self { max_depth: 4 }
322    }
323
324    /// Count T gates in a sequence
325    fn count_t_gates(gates: &[Box<dyn GateOp>]) -> usize {
326        gates
327            .iter()
328            .filter(|g| g.name() == "T" || g.name() == "T†")
329            .count()
330    }
331
332    /// Try to reduce T-count by recognizing special patterns
333    fn reduce_t_count(
334        &self,
335        gates: &[Box<dyn GateOp>],
336    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
337        // Pattern: T-S-T = S-T-S (both have T-count 2, but might enable other optimizations)
338        if gates.len() >= 3 {
339            for i in 0..gates.len() - 2 {
340                if gates[i].name() == "T"
341                    && gates[i + 1].name() == "S"
342                    && gates[i + 2].name() == "T"
343                {
344                    if gates[i].qubits() == gates[i + 1].qubits()
345                        && gates[i + 1].qubits() == gates[i + 2].qubits()
346                    {
347                        let qubit = gates[i].qubits()[0];
348                        let mut result = Vec::new();
349
350                        // Copy gates before pattern
351                        for j in 0..i {
352                            result.push(gates[j].clone_gate());
353                        }
354
355                        // Replace pattern
356                        result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
357                        result.push(Box::new(T { target: qubit }) as Box<dyn GateOp>);
358                        result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
359
360                        // Copy gates after pattern
361                        for j in i + 3..gates.len() {
362                            result.push(gates[j].clone_gate());
363                        }
364
365                        return Ok(Some(result));
366                    }
367                }
368            }
369        }
370
371        Ok(None)
372    }
373}
374
375impl OptimizationPass for TCountOptimizer {
376    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
377        let original_t_count = Self::count_t_gates(&gates);
378
379        if let Some(optimized) = self.reduce_t_count(&gates)? {
380            let new_t_count = Self::count_t_gates(&optimized);
381            if new_t_count < original_t_count {
382                return Ok(optimized);
383            }
384        }
385
386        Ok(gates)
387    }
388
389    fn name(&self) -> &str {
390        "T-Count Optimization"
391    }
392}
393
394#[cfg(test)]
395mod tests {
396    use super::*;
397
398    #[test]
399    fn test_rotation_merging() {
400        let optimizer = PeepholeOptimizer::new();
401        let qubit = QubitId(0);
402
403        let gates: Vec<Box<dyn GateOp>> = vec![
404            Box::new(RotationZ {
405                target: qubit,
406                theta: PI / 4.0,
407            }),
408            Box::new(RotationZ {
409                target: qubit,
410                theta: PI / 4.0,
411            }),
412        ];
413
414        let result = optimizer.optimize(gates).unwrap();
415        assert_eq!(result.len(), 1);
416
417        if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
418            assert!((rz.theta - PI / 2.0).abs() < 1e-10);
419        } else {
420            panic!("Expected RotationZ");
421        }
422    }
423
424    #[test]
425    fn test_zero_rotation_removal() {
426        let optimizer = PeepholeOptimizer::new();
427        let qubit = QubitId(0);
428
429        let gates: Vec<Box<dyn GateOp>> = vec![
430            Box::new(RotationX {
431                target: qubit,
432                theta: PI,
433            }),
434            Box::new(RotationX {
435                target: qubit,
436                theta: PI,
437            }),
438        ];
439
440        let result = optimizer.optimize(gates).unwrap();
441        assert_eq!(result.len(), 0); // 2π rotation should be removed
442    }
443
444    #[test]
445    fn test_cnot_rz_pattern() {
446        let optimizer = PeepholeOptimizer::new();
447        let q0 = QubitId(0);
448        let q1 = QubitId(1);
449
450        let gates: Vec<Box<dyn GateOp>> = vec![
451            Box::new(CNOT {
452                control: q0,
453                target: q1,
454            }),
455            Box::new(RotationZ {
456                target: q1,
457                theta: PI / 4.0,
458            }),
459            Box::new(CNOT {
460                control: q0,
461                target: q1,
462            }),
463        ];
464
465        let result = optimizer.optimize(gates).unwrap();
466        assert_eq!(result.len(), 1);
467        assert_eq!(result[0].name(), "CRZ");
468    }
469
470    #[test]
471    fn test_h_x_h_pattern() {
472        let optimizer = PeepholeOptimizer::new();
473        let qubit = QubitId(0);
474
475        let gates: Vec<Box<dyn GateOp>> = vec![
476            Box::new(Hadamard { target: qubit }),
477            Box::new(PauliX { target: qubit }),
478            Box::new(Hadamard { target: qubit }),
479        ];
480
481        let result = optimizer.optimize(gates).unwrap();
482        assert_eq!(result.len(), 1);
483        assert_eq!(result[0].name(), "Z");
484    }
485
486    #[test]
487    fn test_t_gate_combination() {
488        let optimizer = PeepholeOptimizer::new();
489        let qubit = QubitId(0);
490
491        let gates: Vec<Box<dyn GateOp>> =
492            vec![Box::new(T { target: qubit }), Box::new(T { target: qubit })];
493
494        let result = optimizer.optimize(gates).unwrap();
495        assert_eq!(result.len(), 1);
496        assert_eq!(result[0].name(), "S");
497    }
498}