Skip to main content

quantrs2_circuit/optimization/passes/
basic_passes.rs

1//! Basic optimization passes: gate cancellation, commutation, merging, and rotation merging.
2
3use crate::optimization::cost_model::CostModel;
4use crate::optimization::gate_properties::CommutationTable;
5use quantrs2_core::error::QuantRS2Result;
6use quantrs2_core::gate::{
7    multi,
8    single::{self, RotationX, RotationY, RotationZ},
9    GateOp,
10};
11use quantrs2_core::qubit::QubitId;
12use std::collections::HashSet;
13use std::f64::consts::PI;
14
15use super::OptimizationPass;
16
/// Optimization pass that deletes adjacent gate pairs which undo each other.
///
/// The `aggressive` flag asks for a more thorough cancellation search; what it
/// enables is decided by this type's `OptimizationPass::apply_to_gates`.
pub struct GateCancellation {
    aggressive: bool,
}

impl GateCancellation {
    /// Builds the pass; pass `true` for `aggressive` to request the more
    /// thorough cancellation search.
    #[must_use]
    pub const fn new(aggressive: bool) -> Self {
        GateCancellation { aggressive }
    }
}
28
29impl OptimizationPass for GateCancellation {
30    fn name(&self) -> &'static str {
31        "Gate Cancellation"
32    }
33
34    fn apply_to_gates(
35        &self,
36        gates: Vec<Box<dyn GateOp>>,
37        _cost_model: &dyn CostModel,
38    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
39        let mut optimized = Vec::new();
40        let mut i = 0;
41
42        while i < gates.len() {
43            if i + 1 < gates.len() {
44                let gate1 = &gates[i];
45                let gate2 = &gates[i + 1];
46
47                // Check if gates act on the same qubits
48                if gate1.qubits() == gate2.qubits() && gate1.name() == gate2.name() {
49                    // Check for self-inverse gates (H, X, Y, Z)
50                    match gate1.name() {
51                        "H" | "X" | "Y" | "Z" => {
52                            // These gates cancel when applied twice - skip both
53                            i += 2;
54                            continue;
55                        }
56                        "RX" | "RY" | "RZ" => {
57                            // Check if rotations cancel
58                            if let (Some(rx1), Some(rx2)) = (
59                                gate1.as_any().downcast_ref::<single::RotationX>(),
60                                gate2.as_any().downcast_ref::<single::RotationX>(),
61                            ) {
62                                let combined_angle = rx1.theta + rx2.theta;
63                                // Check if the combined rotation is effectively zero
64                                if (combined_angle % (2.0 * PI)).abs() < 1e-10 {
65                                    i += 2;
66                                    continue;
67                                }
68                            } else if let (Some(ry1), Some(ry2)) = (
69                                gate1.as_any().downcast_ref::<single::RotationY>(),
70                                gate2.as_any().downcast_ref::<single::RotationY>(),
71                            ) {
72                                let combined_angle = ry1.theta + ry2.theta;
73                                if (combined_angle % (2.0 * PI)).abs() < 1e-10 {
74                                    i += 2;
75                                    continue;
76                                }
77                            } else if let (Some(rz1), Some(rz2)) = (
78                                gate1.as_any().downcast_ref::<single::RotationZ>(),
79                                gate2.as_any().downcast_ref::<single::RotationZ>(),
80                            ) {
81                                let combined_angle = rz1.theta + rz2.theta;
82                                if (combined_angle % (2.0 * PI)).abs() < 1e-10 {
83                                    i += 2;
84                                    continue;
85                                }
86                            }
87                        }
88                        "CNOT" => {
89                            // CNOT is self-inverse
90                            if let (Some(cnot1), Some(cnot2)) = (
91                                gate1.as_any().downcast_ref::<multi::CNOT>(),
92                                gate2.as_any().downcast_ref::<multi::CNOT>(),
93                            ) {
94                                if cnot1.control == cnot2.control && cnot1.target == cnot2.target {
95                                    i += 2;
96                                    continue;
97                                }
98                            }
99                        }
100                        _ => {}
101                    }
102                }
103
104                // Look for more complex cancellations if aggressive mode is enabled
105                if self.aggressive && i + 2 < gates.len() {
106                    // Check for patterns like X-Y-X-Y or Z-H-Z-H
107                    let gate3 = &gates[i + 2];
108                    if gate1.qubits() == gate3.qubits()
109                        && gate1.name() == gate3.name()
110                        && i + 3 < gates.len()
111                    {
112                        let gate4 = &gates[i + 3];
113                        if gate2.qubits() == gate4.qubits()
114                            && gate2.name() == gate4.name()
115                            && gate1.qubits() == gate2.qubits()
116                        {
117                            // Pattern detected, check if it simplifies
118                            match (gate1.name(), gate2.name()) {
119                                ("X", "Y") | ("Y", "X") | ("Z", "H") | ("H", "Z") => {
120                                    // These patterns can sometimes simplify
121                                    // For now, we'll keep them as they might not always cancel
122                                }
123                                _ => {}
124                            }
125                        }
126                    }
127                }
128            }
129
130            // If we didn't skip, add the gate to optimized list
131            optimized.push(gates[i].clone());
132            i += 1;
133        }
134
135        Ok(optimized)
136    }
137}
138
139/// Gate commutation pass - reorders gates to enable other optimizations
140pub struct GateCommutation {
141    max_lookahead: usize,
142    commutation_table: CommutationTable,
143}
144
145impl GateCommutation {
146    #[must_use]
147    pub fn new(max_lookahead: usize) -> Self {
148        Self {
149            max_lookahead,
150            commutation_table: CommutationTable::new(),
151        }
152    }
153}
154
155impl GateCommutation {
156    /// Check if two gates commute based on commutation rules
157    fn gates_commute(&self, gate1: &dyn GateOp, gate2: &dyn GateOp) -> bool {
158        // Use commutation table if available
159        if self.commutation_table.commutes(gate1.name(), gate2.name()) {
160            return true;
161        }
162
163        // Additional commutation rules
164        match (gate1.name(), gate2.name()) {
165            // Pauli gates commutation
166            ("X", "X") | ("Y", "Y") | ("Z", "Z") => true,
167            ("I", _) | (_, "I") => true,
168
169            // Phase/T gates commute with Z
170            ("S" | "T", "Z") | ("Z", "S" | "T") => true,
171
172            // Same-axis rotations commute
173            ("RX", "RX") | ("RY", "RY") | ("RZ", "RZ") => true,
174
175            // RZ commutes with Z-like gates
176            ("RZ", "Z" | "S" | "T") | ("Z" | "S" | "T", "RZ") => true,
177
178            _ => false,
179        }
180    }
181
182    /// Check if swapping gates at position i would enable optimizations
183    fn would_benefit_from_swap(&self, gates: &[Box<dyn GateOp>], i: usize) -> bool {
184        if i + 2 >= gates.len() {
185            return false;
186        }
187
188        let gate1 = &gates[i];
189        let gate2 = &gates[i + 1];
190        let gate3 = &gates[i + 2];
191
192        // Check if swapping would create cancellation opportunities
193        if gate1.name() == gate3.name() && gate1.qubits() == gate3.qubits() {
194            // After swap, gate2 and gate3 (originally gate1) would be adjacent
195            match gate3.name() {
196                "H" | "X" | "Y" | "Z" => return true,
197                _ => {}
198            }
199        }
200
201        // Check if swapping would enable rotation merging
202        if gate2.name() == gate3.name() && gate2.qubits() == gate3.qubits() {
203            match gate2.name() {
204                "RX" | "RY" | "RZ" => return true,
205                _ => {}
206            }
207        }
208
209        false
210    }
211}
212
213impl OptimizationPass for GateCommutation {
214    fn name(&self) -> &'static str {
215        "Gate Commutation"
216    }
217
218    fn apply_to_gates(
219        &self,
220        gates: Vec<Box<dyn GateOp>>,
221        _cost_model: &dyn CostModel,
222    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
223        if gates.len() < 2 {
224            return Ok(gates);
225        }
226
227        let mut optimized = gates;
228        // Bound the number of outer iterations to prevent oscillation.
229        // Each pass does at most one forward scan; repeated passes let reordering
230        // propagate, but the bound ensures we always terminate.
231        let max_outer = self.max_lookahead * 2 + 1;
232        let mut outer_iter = 0;
233        let mut changed = true;
234
235        // Keep trying to commute gates until no more changes or the iteration
236        // bound is reached.
237        while changed && outer_iter < max_outer {
238            changed = false;
239            outer_iter += 1;
240            let mut i = 0;
241
242            while i < optimized.len().saturating_sub(1) {
243                let can_swap = {
244                    let gate1 = &optimized[i];
245                    let gate2 = &optimized[i + 1];
246
247                    // Check if gates act on different qubits (always commute)
248                    let qubits1: HashSet<_> = gate1.qubits().into_iter().collect();
249                    let qubits2: HashSet<_> = gate2.qubits().into_iter().collect();
250
251                    if qubits1.is_disjoint(&qubits2) {
252                        // Gates on disjoint qubits: only swap when it would enable
253                        // further optimisations (not just because they commute).
254                        self.would_benefit_from_swap(&optimized, i)
255                    } else if qubits1 == qubits2 {
256                        // Same qubit set: only swap when a downstream gate of the
257                        // same type exists that could later cancel or merge.
258                        // Swapping two identical same-qubit gates is always a no-op,
259                        // so guard against that first.
260                        if gate1.name() == gate2.name() {
261                            // Identical gate names on same qubits: swapping achieves
262                            // nothing useful — skip to avoid oscillation.
263                            false
264                        } else {
265                            self.gates_commute(gate1.as_ref(), gate2.as_ref())
266                        }
267                    } else {
268                        // Overlapping but not identical qubit sets
269                        false
270                    }
271                };
272
273                if can_swap {
274                    optimized.swap(i, i + 1);
275                    changed = true;
276                }
277                // Always advance forward to avoid cycling on the same pair.
278                i += 1;
279
280                // Limit lookahead to prevent excessive computation
281                if i >= self.max_lookahead {
282                    break;
283                }
284            }
285        }
286
287        Ok(optimized)
288    }
289}
290
/// Gate merging pass: replaces short runs of adjacent gates with a single
/// equivalent gate.
pub struct GateMerging {
    /// Whether to attempt fusing adjacent rotation gates.
    merge_rotations: bool,
    /// Threshold for cross-axis rotation fusion; that step is only considered
    /// when this is positive.
    merge_threshold: f64,
}

impl GateMerging {
    /// Creates the pass with the given rotation-fusion switch and threshold.
    #[must_use]
    pub const fn new(merge_rotations: bool, merge_threshold: f64) -> Self {
        GateMerging {
            merge_threshold,
            merge_rotations,
        }
    }
}
306
307impl OptimizationPass for GateMerging {
308    fn name(&self) -> &'static str {
309        "Gate Merging"
310    }
311
312    fn apply_to_gates(
313        &self,
314        gates: Vec<Box<dyn GateOp>>,
315        _cost_model: &dyn CostModel,
316    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
317        let mut optimized = Vec::new();
318        let mut i = 0;
319
320        while i < gates.len() {
321            if i + 1 < gates.len() && self.merge_rotations {
322                let gate1 = &gates[i];
323                let gate2 = &gates[i + 1];
324
325                // Try to merge rotation gates
326                if gate1.qubits() == gate2.qubits() {
327                    let merged = match (gate1.name(), gate2.name()) {
328                        // Same-axis rotations can be directly merged
329                        ("RX", "RX") | ("RY", "RY") | ("RZ", "RZ") => {
330                            // Already handled by RotationMerging pass, skip here
331                            None
332                        }
333                        // Different axis rotations might be mergeable using Euler decomposition
334                        ("RZ" | "RY", "RX") | ("RX" | "RY", "RZ") | ("RX" | "RZ", "RY")
335                            if self.merge_threshold > 0.0 =>
336                        {
337                            // Complex merging would require matrix multiplication
338                            // For now, skip this advanced optimization
339                            None
340                        }
341                        // Phase gates (S, T) can sometimes be merged with RZ
342                        ("S" | "T", "RZ") | ("RZ", "S" | "T") => {
343                            // S = RZ(π/2), T = RZ(π/4)
344                            // These could be merged but need special handling
345                            None
346                        }
347                        _ => None,
348                    };
349
350                    if let Some(merged_gate) = merged {
351                        optimized.push(merged_gate);
352                        i += 2;
353                        continue;
354                    }
355                }
356            }
357
358            // Check for special merging patterns
359            if i + 1 < gates.len() {
360                let gate1 = &gates[i];
361                let gate2 = &gates[i + 1];
362
363                // H-Z-H = X, H-X-H = Z (basis change)
364                if i + 2 < gates.len() {
365                    let gate3 = &gates[i + 2];
366                    if gate1.name() == "H"
367                        && gate3.name() == "H"
368                        && gate1.qubits() == gate2.qubits()
369                        && gate2.qubits() == gate3.qubits()
370                    {
371                        match gate2.name() {
372                            "Z" => {
373                                // H-Z-H = X
374                                optimized.push(Box::new(single::PauliX {
375                                    target: gate1.qubits()[0],
376                                })
377                                    as Box<dyn GateOp>);
378                                i += 3;
379                                continue;
380                            }
381                            "X" => {
382                                // H-X-H = Z
383                                optimized.push(Box::new(single::PauliZ {
384                                    target: gate1.qubits()[0],
385                                })
386                                    as Box<dyn GateOp>);
387                                i += 3;
388                                continue;
389                            }
390                            _ => {}
391                        }
392                    }
393                }
394            }
395
396            // If no merging happened, keep the original gate
397            optimized.push(gates[i].clone());
398            i += 1;
399        }
400
401        Ok(optimized)
402    }
403}
404
/// Rotation merging pass - specifically merges rotation gates
pub struct RotationMerging {
    /// Angular tolerance (radians) below which a rotation is treated as the
    /// identity by `is_zero_rotation`.
    tolerance: f64,
}

impl RotationMerging {
    /// Creates the pass with the given zero-rotation tolerance.
    #[must_use]
    pub const fn new(tolerance: f64) -> Self {
        Self { tolerance }
    }

    /// Returns `true` if `angle` is within `tolerance` of a whole multiple of
    /// 2π, approached from either side and for either sign.
    ///
    /// `rem_euclid` maps the angle into `[0, 2π]`. Checking the distance to
    /// both ends of that interval catches, e.g., `-2π + ε` as well as
    /// `2π - ε`. The previous `%`-based check missed the former.
    fn is_zero_rotation(&self, angle: f64) -> bool {
        let full_turn = 2.0 * PI;
        let wrapped = angle.rem_euclid(full_turn);
        wrapped < self.tolerance || full_turn - wrapped < self.tolerance
    }

    /// Adds two rotation angles and normalises the sum into `[-π, π]`.
    ///
    /// `%` keeps the sign of the dividend, so the remainder lies in
    /// `(-2π, 2π)`; one further shift by 2π brings it into `[-π, π]`.
    fn merge_angles(&self, angle1: f64, angle2: f64) -> f64 {
        let merged = angle1 + angle2;
        let normalized = merged % (2.0 * PI);
        if normalized > PI {
            2.0f64.mul_add(-PI, normalized)
        } else if normalized < -PI {
            2.0f64.mul_add(PI, normalized)
        } else {
            normalized
        }
    }
}
435
436impl OptimizationPass for RotationMerging {
437    fn name(&self) -> &'static str {
438        "Rotation Merging"
439    }
440
441    fn apply_to_gates(
442        &self,
443        gates: Vec<Box<dyn GateOp>>,
444        _cost_model: &dyn CostModel,
445    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
446        let mut optimized = Vec::new();
447        let mut i = 0;
448
449        while i < gates.len() {
450            if i + 1 < gates.len() {
451                let gate1 = &gates[i];
452                let gate2 = &gates[i + 1];
453
454                // Check if both gates are rotations on the same qubit and axis
455                if gate1.qubits() == gate2.qubits() && gate1.name() == gate2.name() {
456                    match gate1.name() {
457                        "RX" => {
458                            if let (Some(rx1), Some(rx2)) = (
459                                gate1.as_any().downcast_ref::<single::RotationX>(),
460                                gate2.as_any().downcast_ref::<single::RotationX>(),
461                            ) {
462                                let merged_angle = self.merge_angles(rx1.theta, rx2.theta);
463                                if self.is_zero_rotation(merged_angle) {
464                                    // Skip both gates if the merged rotation is effectively zero
465                                    i += 2;
466                                    continue;
467                                }
468                                // Create a new merged rotation gate
469                                optimized.push(Box::new(single::RotationX {
470                                    target: rx1.target,
471                                    theta: merged_angle,
472                                })
473                                    as Box<dyn GateOp>);
474                                i += 2;
475                                continue;
476                            }
477                        }
478                        "RY" => {
479                            if let (Some(ry1), Some(ry2)) = (
480                                gate1.as_any().downcast_ref::<single::RotationY>(),
481                                gate2.as_any().downcast_ref::<single::RotationY>(),
482                            ) {
483                                let merged_angle = self.merge_angles(ry1.theta, ry2.theta);
484                                if self.is_zero_rotation(merged_angle) {
485                                    i += 2;
486                                    continue;
487                                }
488                                optimized.push(Box::new(single::RotationY {
489                                    target: ry1.target,
490                                    theta: merged_angle,
491                                })
492                                    as Box<dyn GateOp>);
493                                i += 2;
494                                continue;
495                            }
496                        }
497                        "RZ" => {
498                            if let (Some(rz1), Some(rz2)) = (
499                                gate1.as_any().downcast_ref::<single::RotationZ>(),
500                                gate2.as_any().downcast_ref::<single::RotationZ>(),
501                            ) {
502                                let merged_angle = self.merge_angles(rz1.theta, rz2.theta);
503                                if self.is_zero_rotation(merged_angle) {
504                                    i += 2;
505                                    continue;
506                                }
507                                optimized.push(Box::new(single::RotationZ {
508                                    target: rz1.target,
509                                    theta: merged_angle,
510                                })
511                                    as Box<dyn GateOp>);
512                                i += 2;
513                                continue;
514                            }
515                        }
516                        _ => {}
517                    }
518                }
519            }
520
521            // If we didn't merge, keep the original gate
522            optimized.push(gates[i].clone());
523            i += 1;
524        }
525
526        Ok(optimized)
527    }
528}