Skip to main content

quantrs2_circuit/optimization/passes/
mod.rs

1//! Individual optimization passes
2//!
3//! This module implements various optimization passes that can be applied to quantum circuits.
4//! It re-exports passes from the sub-modules:
5//! - `basic_passes`: GateCancellation, GateCommutation, GateMerging, RotationMerging
6//! - `advanced_passes`: DecompositionOptimization, CostBasedOptimization, TwoQubitOptimization
7//! - `rewriting_passes`: PeepholeOptimization, TemplateMatching, CircuitRewriting, ParallelizationPass
8
9pub mod advanced_passes;
10pub mod basic_passes;
11pub mod rewriting_passes;
12
13use crate::builder::Circuit;
14use crate::optimization::cost_model::CostModel;
15use quantrs2_core::error::QuantRS2Result;
16use quantrs2_core::gate::GateOp;
17
18// Re-export all public pass types and functions
19pub use advanced_passes::{
20    CostBasedOptimization, CostTarget, DecompositionOptimization, TwoQubitOptimization,
21};
22pub use basic_passes::{GateCancellation, GateCommutation, GateMerging, RotationMerging};
23pub use rewriting_passes::{
24    all_same_single_qubit, extract_rx_angle, extract_ry_angle, extract_rz_angle, is_identity_angle,
25    normalise_angle, parallelize_gates, single_qubit_of, utils, CircuitRewriting, CircuitTemplate,
26    ParallelizationPass, PeepholeOptimization, PeepholePattern, RewriteRule, TemplateMatching,
27};
28
29/// Trait for optimization passes (object-safe version)
30pub trait OptimizationPass: Send + Sync {
31    /// Name of the optimization pass
32    fn name(&self) -> &str;
33
34    /// Apply the optimization pass to a gate list
35    fn apply_to_gates(
36        &self,
37        gates: Vec<Box<dyn GateOp>>,
38        cost_model: &dyn CostModel,
39    ) -> QuantRS2Result<Vec<Box<dyn GateOp>>>;
40
41    /// Check if this pass should be applied
42    fn should_apply(&self) -> bool {
43        true
44    }
45}
46
47/// Extension trait for circuit operations
48pub trait OptimizationPassExt<const N: usize> {
49    fn apply(&self, circuit: &Circuit<N>, cost_model: &dyn CostModel)
50        -> QuantRS2Result<Circuit<N>>;
51    fn should_apply_to_circuit(&self, circuit: &Circuit<N>) -> bool;
52}
53
54impl<T: OptimizationPass + ?Sized, const N: usize> OptimizationPassExt<N> for T {
55    fn apply(
56        &self,
57        circuit: &Circuit<N>,
58        cost_model: &dyn CostModel,
59    ) -> QuantRS2Result<Circuit<N>> {
60        // Extract gates from the circuit as owned boxes.
61        let gates: Vec<Box<dyn GateOp>> = circuit.gates_as_boxes();
62
63        // Run the optimisation pass on the gate list.
64        let optimized_gates = self.apply_to_gates(gates, cost_model)?;
65
66        // Reconstruct a new circuit from the optimised gate list.
67        Circuit::<N>::from_gates(optimized_gates)
68    }
69
70    fn should_apply_to_circuit(&self, _circuit: &Circuit<N>) -> bool {
71        self.should_apply()
72    }
73}
74
#[cfg(test)]
mod tests {
    // `Circuit`, `GateOp` and all re-exported passes/helpers are reachable
    // through the parent module via this glob, so no extra import is needed.
    use super::*;
    use crate::optimization::cost_model::AbstractCostModel;
    use quantrs2_core::gate::single::{Hadamard, PauliX};
    use quantrs2_core::qubit::QubitId;

    /// `GateCancellation` is a pass-through if nothing cancels.
    #[test]
    fn test_gate_cancellation_no_op() {
        let pass = GateCancellation::new(false);
        let cost = AbstractCostModel::default();

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard {
                target: QubitId::new(0),
            })
            .expect("add gate");

        let result = pass.apply(&circuit, &cost).expect("apply pass");
        assert_eq!(result.num_gates(), 1, "single H should not be removed");
        assert_eq!(
            result.gates()[0].name(),
            "H",
            "the surviving gate must be the original H"
        );
    }

    /// Two consecutive X gates on the same qubit should cancel (X is self-inverse).
    #[test]
    fn test_xx_cancellation_reduces_gate_count() {
        let pass = GateCancellation::new(false);
        let cost = AbstractCostModel::default();

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(PauliX {
                target: QubitId::new(0),
            })
            .expect("add gate 1");
        circuit
            .add_gate(PauliX {
                target: QubitId::new(0),
            })
            .expect("add gate 2");

        assert_eq!(
            circuit.num_gates(),
            2,
            "circuit should have 2 gates before optimization"
        );

        let result = pass.apply(&circuit, &cost).expect("apply pass");
        assert_eq!(
            result.num_gates(),
            0,
            "X-X on same qubit should cancel to empty"
        );
    }

    /// Two consecutive H gates on the same qubit should cancel (H is self-inverse).
    #[test]
    fn test_hh_cancellation_reduces_gate_count() {
        let pass = GateCancellation::new(false);
        let cost = AbstractCostModel::default();

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard {
                target: QubitId::new(1),
            })
            .expect("add gate 1");
        circuit
            .add_gate(Hadamard {
                target: QubitId::new(1),
            })
            .expect("add gate 2");

        assert_eq!(
            circuit.num_gates(),
            2,
            "circuit should have 2 gates before optimization"
        );

        let result = pass.apply(&circuit, &cost).expect("apply pass");
        assert_eq!(
            result.num_gates(),
            0,
            "H-H on same qubit should cancel to empty"
        );
    }

    /// X-X on *different* qubits must NOT cancel.
    #[test]
    fn test_xx_different_qubits_no_cancellation() {
        let pass = GateCancellation::new(false);
        let cost = AbstractCostModel::default();

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(PauliX {
                target: QubitId::new(0),
            })
            .expect("add gate 1");
        circuit
            .add_gate(PauliX {
                target: QubitId::new(1),
            })
            .expect("add gate 2");

        let result = pass.apply(&circuit, &cost).expect("apply pass");
        assert_eq!(
            result.num_gates(),
            2,
            "X on qubit 0 and X on qubit 1 should not cancel"
        );
    }

    /// Verify that `OptimizationPassExt::apply` actually reconstructs the circuit
    /// (gate count changes, not just a clone) and leaves its input untouched.
    #[test]
    fn test_apply_ext_returns_optimized_circuit() {
        let pass = GateCancellation::new(false);
        let cost = AbstractCostModel::default();

        // Three gates: X(0), X(0), H(1)  →  X(0) and X(0) cancel, leaving H(1).
        let mut circuit = Circuit::<2>::new();
        circuit.x(QubitId::new(0)).expect("x 0");
        circuit.x(QubitId::new(0)).expect("x 0 again");
        circuit.h(QubitId::new(1)).expect("h 1");

        let result = pass.apply(&circuit, &cost).expect("apply");
        assert_eq!(result.num_gates(), 1, "only H on qubit 1 should remain");
        assert_eq!(result.gates()[0].name(), "H");
        assert_eq!(
            circuit.num_gates(),
            3,
            "apply must build a new circuit, not modify its input"
        );
    }

    /// `should_apply_to_circuit` must simply forward to `should_apply`.
    #[test]
    fn test_should_apply_to_circuit_delegates_to_should_apply() {
        let pass = GateCancellation::new(false);
        let circuit = Circuit::<2>::new();
        assert_eq!(
            pass.should_apply_to_circuit(&circuit),
            pass.should_apply(),
            "circuit-level check must agree with the pass-level check"
        );
    }

    /// CNOT(0,1) followed immediately by CNOT(0,1) must cancel to empty.
    #[test]
    fn test_two_qubit_cnot_cancellation() {
        use quantrs2_core::gate::multi::CNOT;
        let pass = TwoQubitOptimization::new(false, true);
        let cost = AbstractCostModel::default();
        let q0 = QubitId::new(0);
        let q1 = QubitId::new(1);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];
        let result = pass.apply_to_gates(gates, &cost).expect("apply");
        assert_eq!(result.len(), 0, "CNOT(0,1)+CNOT(0,1) must cancel");
    }

    /// CNOT(a,b), CNOT(b,a), CNOT(a,b) must become SWAP(a,b).
    #[test]
    fn test_two_qubit_swap_detection() {
        use quantrs2_core::gate::multi::CNOT;
        let pass = TwoQubitOptimization::new(false, true);
        let cost = AbstractCostModel::default();
        let q0 = QubitId::new(0);
        let q1 = QubitId::new(1);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q1,
                target: q0,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];
        let result = pass.apply_to_gates(gates, &cost).expect("apply");
        assert_eq!(
            result.len(),
            1,
            "3-CNOT SWAP pattern must become 1 SWAP gate"
        );
        assert_eq!(result[0].name(), "SWAP");
    }

    /// H(q0) and H(q1) are independent → both appear in the first layer (depth 1).
    /// CNOT(q0,q1) depends on both → appears in the second layer (depth 2).
    #[test]
    fn test_parallelize_independent_then_cnot() {
        use quantrs2_core::gate::multi::CNOT;
        let q0 = QubitId::new(0);
        let q1 = QubitId::new(1);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: q0 }),
            Box::new(Hadamard { target: q1 }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];
        let result = parallelize_gates(gates);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].name(), "H");
        assert_eq!(result[0].qubits()[0], q0);
        assert_eq!(result[1].name(), "H");
        assert_eq!(result[1].qubits()[0], q1);
        assert_eq!(result[2].name(), "CNOT");
        // Order-agnostic check: the CNOT must still act on both qubits.
        let cnot_qubits = result[2].qubits();
        assert!(
            cnot_qubits.contains(&q0) && cnot_qubits.contains(&q1),
            "CNOT must still act on both q0 and q1 after parallelization"
        );
    }
}