quantrs2_core/optimization/
peephole.rs

//! Peephole optimization for quantum circuits
//!
//! This module implements peephole optimization, which looks for small patterns
//! of gates that can be simplified or eliminated.

6use crate::error::QuantRS2Result;
7use crate::gate::{multi::*, single::*, GateOp};
8use std::f64::consts::PI;
9
10use super::{gates_can_commute, OptimizationPass};
11
/// Peephole optimization pass.
///
/// Scans a gate list for small windows of two or three adjacent gates that can
/// be replaced by a shorter equivalent sequence. Some rewrite families can be
/// switched on or off through the public flags below.
#[derive(Debug, Clone, PartialEq)]
pub struct PeepholeOptimizer {
    /// Merge adjacent same-axis rotations (RX/RX, RY/RY, RZ/RZ) on one qubit.
    pub merge_rotations: bool,
    /// Drop rotations whose angle is (within `zero_tolerance`) a multiple of 2π.
    pub remove_identities: bool,
    /// Allow swapping adjacent commuting gates so single-qubit gates move earlier.
    pub enable_commutation: bool,
    /// Absolute tolerance used when deciding that a rotation angle is zero.
    pub zero_tolerance: f64,
}

24impl Default for PeepholeOptimizer {
25    fn default() -> Self {
26        Self {
27            merge_rotations: true,
28            remove_identities: true,
29            enable_commutation: true,
30            zero_tolerance: 1e-10,
31        }
32    }
33}
34
35impl PeepholeOptimizer {
36    /// Create a new peephole optimizer
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Check if a rotation angle is effectively zero
42    fn is_zero_rotation(&self, angle: f64) -> bool {
43        let normalized = angle % (2.0 * PI);
44        normalized.abs() < self.zero_tolerance
45            || (normalized - 2.0 * PI).abs() < self.zero_tolerance
46    }
47
48    /// Try to simplify a window of gates
49    #[allow(dead_code)]
50    fn simplify_window(
51        &self,
52        window: &[Box<dyn GateOp>],
53    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
54        match window.len() {
55            2 => self.simplify_pair(&window[0], &window[1]),
56            3 => self.simplify_triple(&window[0], &window[1], &window[2]),
57            _ => Ok(None),
58        }
59    }
60
61    /// Simplify a pair of gates
62    fn simplify_pair(
63        &self,
64        gate1: &Box<dyn GateOp>,
65        gate2: &Box<dyn GateOp>,
66    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
67        // Handle rotation merging
68        if self.merge_rotations && gate1.qubits() == gate2.qubits() && gate1.qubits().len() == 1 {
69            let qubit = gate1.qubits()[0];
70
71            match (gate1.name(), gate2.name()) {
72                ("RX", "RX") => {
73                    if let (Some(rx1), Some(rx2)) = (
74                        gate1.as_any().downcast_ref::<RotationX>(),
75                        gate2.as_any().downcast_ref::<RotationX>(),
76                    ) {
77                        let combined_angle = rx1.theta + rx2.theta;
78                        if self.is_zero_rotation(combined_angle) {
79                            return Ok(Some(vec![])); // Remove both
80                        }
81                        return Ok(Some(vec![Box::new(RotationX {
82                            target: qubit,
83                            theta: combined_angle,
84                        })]));
85                    }
86                }
87                ("RY", "RY") => {
88                    if let (Some(ry1), Some(ry2)) = (
89                        gate1.as_any().downcast_ref::<RotationY>(),
90                        gate2.as_any().downcast_ref::<RotationY>(),
91                    ) {
92                        let combined_angle = ry1.theta + ry2.theta;
93                        if self.is_zero_rotation(combined_angle) {
94                            return Ok(Some(vec![])); // Remove both
95                        }
96                        return Ok(Some(vec![Box::new(RotationY {
97                            target: qubit,
98                            theta: combined_angle,
99                        })]));
100                    }
101                }
102                ("RZ", "RZ") => {
103                    if let (Some(rz1), Some(rz2)) = (
104                        gate1.as_any().downcast_ref::<RotationZ>(),
105                        gate2.as_any().downcast_ref::<RotationZ>(),
106                    ) {
107                        let combined_angle = rz1.theta + rz2.theta;
108                        if self.is_zero_rotation(combined_angle) {
109                            return Ok(Some(vec![])); // Remove both
110                        }
111                        return Ok(Some(vec![Box::new(RotationZ {
112                            target: qubit,
113                            theta: combined_angle,
114                        })]));
115                    }
116                }
117                _ => {}
118            }
119        }
120
121        // Handle special patterns
122        if gate1.qubits() == gate2.qubits() {
123            match (gate1.name(), gate2.name()) {
124                // T gate combinations
125                ("T", "T") => {
126                    // T² = S
127                    return Ok(Some(vec![Box::new(Phase {
128                        target: gate1.qubits()[0],
129                    })]));
130                }
131                ("T†", "T†") => {
132                    // T†² = S†
133                    return Ok(Some(vec![Box::new(PhaseDagger {
134                        target: gate1.qubits()[0],
135                    })]));
136                }
137
138                // S and T combinations
139                ("S", "T") | ("T", "S") => {
140                    // S·T = T·S = T³ (since T² = S)
141                    let qubit = gate1.qubits()[0];
142                    return Ok(Some(vec![
143                        Box::new(T { target: qubit }),
144                        Box::new(T { target: qubit }),
145                        Box::new(T { target: qubit }),
146                    ]));
147                }
148
149                _ => {}
150            }
151        }
152
153        // Try commutation if enabled
154        if self.enable_commutation && gates_can_commute(gate1.as_ref(), gate2.as_ref()) {
155            // Return in swapped order if it might help later optimizations
156            // This is a heuristic - we prefer to move single-qubit gates earlier
157            if gate1.qubits().len() > gate2.qubits().len() {
158                return Ok(Some(vec![gate2.clone_gate(), gate1.clone_gate()]));
159            }
160        }
161
162        Ok(None)
163    }
164
165    /// Simplify a triple of gates
166    fn simplify_triple(
167        &self,
168        gate1: &Box<dyn GateOp>,
169        gate2: &Box<dyn GateOp>,
170        gate3: &Box<dyn GateOp>,
171    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
172        // CX-Rz-CX pattern (controlled rotation)
173        if gate1.name() == "CNOT" && gate3.name() == "CNOT" && gate2.name() == "RZ" {
174            if let (Some(cx1), Some(cx2), Some(rz)) = (
175                gate1.as_any().downcast_ref::<CNOT>(),
176                gate3.as_any().downcast_ref::<CNOT>(),
177                gate2.as_any().downcast_ref::<RotationZ>(),
178            ) {
179                // Check if it's the controlled-Rz pattern
180                if cx1.control == cx2.control && cx1.target == cx2.target && rz.target == cx1.target
181                {
182                    // This is equivalent to a CRZ gate
183                    return Ok(Some(vec![Box::new(CRZ {
184                        control: cx1.control,
185                        target: cx1.target,
186                        theta: rz.theta,
187                    })]));
188                }
189            }
190        }
191
192        // H-X-H = Z pattern
193        if gate1.name() == "H"
194            && gate2.name() == "X"
195            && gate3.name() == "H"
196            && gate1.qubits() == gate2.qubits()
197            && gate2.qubits() == gate3.qubits()
198        {
199            return Ok(Some(vec![Box::new(PauliZ {
200                target: gate1.qubits()[0],
201            })]));
202        }
203
204        // H-Z-H = X pattern
205        if gate1.name() == "H"
206            && gate2.name() == "Z"
207            && gate3.name() == "H"
208            && gate1.qubits() == gate2.qubits()
209            && gate2.qubits() == gate3.qubits()
210        {
211            return Ok(Some(vec![Box::new(PauliX {
212                target: gate1.qubits()[0],
213            })]));
214        }
215
216        // X-Y-X = -Y pattern
217        if gate1.name() == "X"
218            && gate2.name() == "Y"
219            && gate3.name() == "X"
220            && gate1.qubits() == gate2.qubits()
221            && gate2.qubits() == gate3.qubits()
222        {
223            let qubit = gate1.qubits()[0];
224            return Ok(Some(vec![
225                Box::new(PauliY { target: qubit }),
226                Box::new(PauliZ { target: qubit }), // Global phase -1
227            ]));
228        }
229
230        Ok(None)
231    }
232
233    /// Remove identity rotations
234    fn remove_identity_rotations(&self, gates: Vec<Box<dyn GateOp>>) -> Vec<Box<dyn GateOp>> {
235        gates
236            .into_iter()
237            .filter(|gate| match gate.name() {
238                "RX" => {
239                    if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
240                        !self.is_zero_rotation(rx.theta)
241                    } else {
242                        true
243                    }
244                }
245                "RY" => {
246                    if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
247                        !self.is_zero_rotation(ry.theta)
248                    } else {
249                        true
250                    }
251                }
252                "RZ" => {
253                    if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
254                        !self.is_zero_rotation(rz.theta)
255                    } else {
256                        true
257                    }
258                }
259                _ => true,
260            })
261            .collect()
262    }
263}
264
265impl OptimizationPass for PeepholeOptimizer {
266    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
267        let mut current = gates;
268        let mut changed = true;
269        let max_iterations = 10; // Prevent infinite loops
270        let mut iterations = 0;
271
272        while changed && iterations < max_iterations {
273            changed = false;
274            let mut optimized = Vec::new();
275            let mut i = 0;
276
277            while i < current.len() {
278                // Try triple patterns first
279                if i + 2 < current.len() {
280                    if let Some(simplified) =
281                        self.simplify_triple(&current[i], &current[i + 1], &current[i + 2])?
282                    {
283                        optimized.extend(simplified);
284                        i += 3;
285                        changed = true;
286                        continue;
287                    }
288                }
289
290                // Try pair patterns
291                if i + 1 < current.len() {
292                    if let Some(simplified) = self.simplify_pair(&current[i], &current[i + 1])? {
293                        optimized.extend(simplified);
294                        i += 2;
295                        changed = true;
296                        continue;
297                    }
298                }
299
300                // No pattern matched, keep the gate
301                optimized.push(current[i].clone_gate());
302                i += 1;
303            }
304
305            current = optimized;
306            iterations += 1;
307        }
308
309        // Final pass to remove identity rotations
310        if self.remove_identities {
311            current = self.remove_identity_rotations(current);
312        }
313
314        Ok(current)
315    }
316
317    fn name(&self) -> &str {
318        "Peephole Optimization"
319    }
320}
321
322/// Specialized optimizer for T-count reduction
323pub struct TCountOptimizer {
324    /// Maximum search depth for optimization
325    pub max_depth: usize,
326}
327
328impl Default for TCountOptimizer {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334impl TCountOptimizer {
335    pub fn new() -> Self {
336        Self { max_depth: 4 }
337    }
338
339    /// Count T gates in a sequence
340    fn count_t_gates(gates: &[Box<dyn GateOp>]) -> usize {
341        gates
342            .iter()
343            .filter(|g| g.name() == "T" || g.name() == "T†")
344            .count()
345    }
346
347    /// Try to reduce T-count by recognizing special patterns
348    fn reduce_t_count(
349        &self,
350        gates: &[Box<dyn GateOp>],
351    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
352        // Pattern: T-S-T = S-T-S (both have T-count 2, but might enable other optimizations)
353        if gates.len() >= 3 {
354            for i in 0..gates.len() - 2 {
355                if gates[i].name() == "T"
356                    && gates[i + 1].name() == "S"
357                    && gates[i + 2].name() == "T"
358                    && gates[i].qubits() == gates[i + 1].qubits()
359                    && gates[i + 1].qubits() == gates[i + 2].qubits()
360                {
361                    let qubit = gates[i].qubits()[0];
362                    let mut result = Vec::new();
363
364                    // Copy gates before pattern
365                    for j in 0..i {
366                        result.push(gates[j].clone_gate());
367                    }
368
369                    // Replace pattern
370                    result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
371                    result.push(Box::new(T { target: qubit }) as Box<dyn GateOp>);
372                    result.push(Box::new(Phase { target: qubit }) as Box<dyn GateOp>);
373
374                    // Copy gates after pattern
375                    for j in i + 3..gates.len() {
376                        result.push(gates[j].clone_gate());
377                    }
378
379                    return Ok(Some(result));
380                }
381            }
382        }
383
384        Ok(None)
385    }
386}
387
388impl OptimizationPass for TCountOptimizer {
389    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
390        let original_t_count = Self::count_t_gates(&gates);
391
392        if let Some(optimized) = self.reduce_t_count(&gates)? {
393            let new_t_count = Self::count_t_gates(&optimized);
394            if new_t_count < original_t_count {
395                return Ok(optimized);
396            }
397        }
398
399        Ok(gates)
400    }
401
402    fn name(&self) -> &str {
403        "T-Count Optimization"
404    }
405}
406
407#[cfg(test)]
408mod tests {
409    use super::*;
410    use crate::prelude::QubitId;
411
412    #[test]
413    fn test_rotation_merging() {
414        let optimizer = PeepholeOptimizer::new();
415        let qubit = QubitId(0);
416
417        let gates: Vec<Box<dyn GateOp>> = vec![
418            Box::new(RotationZ {
419                target: qubit,
420                theta: PI / 4.0,
421            }),
422            Box::new(RotationZ {
423                target: qubit,
424                theta: PI / 4.0,
425            }),
426        ];
427
428        let result = optimizer.optimize(gates).unwrap();
429        assert_eq!(result.len(), 1);
430
431        if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
432            assert!((rz.theta - PI / 2.0).abs() < 1e-10);
433        } else {
434            panic!("Expected RotationZ");
435        }
436    }
437
438    #[test]
439    fn test_zero_rotation_removal() {
440        let optimizer = PeepholeOptimizer::new();
441        let qubit = QubitId(0);
442
443        let gates: Vec<Box<dyn GateOp>> = vec![
444            Box::new(RotationX {
445                target: qubit,
446                theta: PI,
447            }),
448            Box::new(RotationX {
449                target: qubit,
450                theta: PI,
451            }),
452        ];
453
454        let result = optimizer.optimize(gates).unwrap();
455        assert_eq!(result.len(), 0); // 2π rotation should be removed
456    }
457
458    #[test]
459    fn test_cnot_rz_pattern() {
460        let optimizer = PeepholeOptimizer::new();
461        let q0 = QubitId(0);
462        let q1 = QubitId(1);
463
464        let gates: Vec<Box<dyn GateOp>> = vec![
465            Box::new(CNOT {
466                control: q0,
467                target: q1,
468            }),
469            Box::new(RotationZ {
470                target: q1,
471                theta: PI / 4.0,
472            }),
473            Box::new(CNOT {
474                control: q0,
475                target: q1,
476            }),
477        ];
478
479        let result = optimizer.optimize(gates).unwrap();
480        assert_eq!(result.len(), 1);
481        assert_eq!(result[0].name(), "CRZ");
482    }
483
484    #[test]
485    fn test_h_x_h_pattern() {
486        let optimizer = PeepholeOptimizer::new();
487        let qubit = QubitId(0);
488
489        let gates: Vec<Box<dyn GateOp>> = vec![
490            Box::new(Hadamard { target: qubit }),
491            Box::new(PauliX { target: qubit }),
492            Box::new(Hadamard { target: qubit }),
493        ];
494
495        let result = optimizer.optimize(gates).unwrap();
496        assert_eq!(result.len(), 1);
497        assert_eq!(result[0].name(), "Z");
498    }
499
500    #[test]
501    fn test_t_gate_combination() {
502        let optimizer = PeepholeOptimizer::new();
503        let qubit = QubitId(0);
504
505        let gates: Vec<Box<dyn GateOp>> =
506            vec![Box::new(T { target: qubit }), Box::new(T { target: qubit })];
507
508        let result = optimizer.optimize(gates).unwrap();
509        assert_eq!(result.len(), 1);
510        assert_eq!(result[0].name(), "S");
511    }
512}