quantrs2_core/optimization/
fusion.rs

1//! Gate fusion optimization pass
2//!
3//! This module implements gate fusion, which combines adjacent compatible gates
4//! into single operations to reduce circuit depth and improve performance.
5
6use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::gate::{multi::*, single::*, GateOp};
8use crate::synthesis::{identify_gate, synthesize_unitary};
9use scirs2_core::ndarray::Array2;
10
11use super::OptimizationPass;
12
/// Optimization pass that merges adjacent compatible gates into fewer operations.
///
/// Every knob is a public field, so a caller can start from [`GateFusion::new`]
/// (or `Default`) and adjust individual settings afterwards.
pub struct GateFusion {
    /// Enables fusion of single-qubit gates.
    pub fuse_single_qubit: bool,
    /// Enables fusion of two-qubit gates.
    pub fuse_two_qubit: bool,
    /// Upper bound on how many gates are considered for one fusion.
    pub max_fusion_size: usize,
    /// Numerical tolerance used when identifying the gate a fusion produced.
    pub tolerance: f64,
}

impl Default for GateFusion {
    /// Both fusion kinds enabled, at most four gates per fusion, tolerance `1e-10`.
    fn default() -> Self {
        let tolerance = 1e-10;
        let max_fusion_size = 4;
        GateFusion {
            tolerance,
            max_fusion_size,
            fuse_two_qubit: true,
            fuse_single_qubit: true,
        }
    }
}
35
36impl GateFusion {
37    /// Create a new gate fusion pass
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Try to fuse a sequence of single-qubit gates
43    fn fuse_single_qubit_gates(
44        &self,
45        gates: &[Box<dyn GateOp>],
46    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
47        if gates.is_empty() {
48            return Ok(None);
49        }
50
51        // Check all gates act on the same qubit
52        let target_qubit = gates[0].qubits()[0];
53        if !gates
54            .iter()
55            .all(|g| g.qubits().len() == 1 && g.qubits()[0] == target_qubit)
56        {
57            return Ok(None);
58        }
59
60        // Compute the combined unitary matrix
61        let mut combined = Array2::eye(2);
62        for gate in gates {
63            let gate_matrix = gate.matrix()?;
64            let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
65                .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
66            combined = combined.dot(&gate_array);
67        }
68
69        // Try to identify the combined gate
70        if let Some(gate_name) = identify_gate(&combined.view(), self.tolerance) {
71            // Convert identified gate name to actual gate
72            let identified_gate = match gate_name.as_str() {
73                "X" => Some(Box::new(PauliX {
74                    target: target_qubit,
75                }) as Box<dyn GateOp>),
76                "Y" => Some(Box::new(PauliY {
77                    target: target_qubit,
78                }) as Box<dyn GateOp>),
79                "Z" => Some(Box::new(PauliZ {
80                    target: target_qubit,
81                }) as Box<dyn GateOp>),
82                "H" => Some(Box::new(Hadamard {
83                    target: target_qubit,
84                }) as Box<dyn GateOp>),
85                "S" => Some(Box::new(Phase {
86                    target: target_qubit,
87                }) as Box<dyn GateOp>),
88                "S†" => Some(Box::new(PhaseDagger {
89                    target: target_qubit,
90                }) as Box<dyn GateOp>),
91                "T" => Some(Box::new(T {
92                    target: target_qubit,
93                }) as Box<dyn GateOp>),
94                "T†" => Some(Box::new(TDagger {
95                    target: target_qubit,
96                }) as Box<dyn GateOp>),
97                "I" | _ => None, // Identity or unknown
98            };
99
100            if let Some(gate) = identified_gate {
101                return Ok(Some(gate));
102            }
103        }
104
105        // If we can't identify it, synthesize it
106        let synthesized = synthesize_unitary(&combined.view(), &[target_qubit])?;
107        if synthesized.len() < gates.len() {
108            // Only use synthesis if it reduces gate count
109            if synthesized.len() == 1 {
110                Ok(synthesized.into_iter().next())
111            } else {
112                Ok(None)
113            }
114        } else {
115            Ok(None)
116        }
117    }
118
119    /// Try to fuse CNOT gates
120    fn fuse_cnot_gates(
121        &self,
122        gates: &[Box<dyn GateOp>],
123    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
124        if gates.len() < 2 {
125            return Ok(None);
126        }
127
128        let mut fused = Vec::new();
129        let mut i = 0;
130
131        while i < gates.len() {
132            if i + 1 < gates.len() {
133                if let (Some(cnot1), Some(cnot2)) = (
134                    gates[i].as_any().downcast_ref::<CNOT>(),
135                    gates[i + 1].as_any().downcast_ref::<CNOT>(),
136                ) {
137                    // Two CNOTs with same control and target cancel
138                    if cnot1.control == cnot2.control && cnot1.target == cnot2.target {
139                        // Skip both gates
140                        i += 2;
141                        continue;
142                    }
143                    // CNOT(a,b) followed by CNOT(b,a) is a SWAP
144                    else if cnot1.control == cnot2.target && cnot1.target == cnot2.control {
145                        fused.push(Box::new(SWAP {
146                            qubit1: cnot1.control,
147                            qubit2: cnot1.target,
148                        }) as Box<dyn GateOp>);
149                        i += 2;
150                        continue;
151                    }
152                }
153            }
154
155            // No fusion possible, keep the gate
156            fused.push(gates[i].clone_gate());
157            i += 1;
158        }
159
160        if fused.len() < gates.len() {
161            Ok(Some(fused))
162        } else {
163            Ok(None)
164        }
165    }
166
167    /// Try to fuse rotation gates
168    fn fuse_rotation_gates(
169        &self,
170        gates: &[Box<dyn GateOp>],
171    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
172        if gates.len() < 2 {
173            return Ok(None);
174        }
175
176        // Check if all gates are rotations around the same axis on the same qubit
177        let first_gate = &gates[0];
178        let target_qubit = first_gate.qubits()[0];
179
180        match first_gate.name() {
181            "RX" => {
182                let mut total_angle = 0.0;
183                for gate in gates {
184                    if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
185                        if rx.target != target_qubit {
186                            return Ok(None);
187                        }
188                        total_angle += rx.theta;
189                    } else {
190                        return Ok(None);
191                    }
192                }
193                Ok(Some(Box::new(RotationX {
194                    target: target_qubit,
195                    theta: total_angle,
196                })))
197            }
198            "RY" => {
199                let mut total_angle = 0.0;
200                for gate in gates {
201                    if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
202                        if ry.target != target_qubit {
203                            return Ok(None);
204                        }
205                        total_angle += ry.theta;
206                    } else {
207                        return Ok(None);
208                    }
209                }
210                Ok(Some(Box::new(RotationY {
211                    target: target_qubit,
212                    theta: total_angle,
213                })))
214            }
215            "RZ" => {
216                let mut total_angle = 0.0;
217                for gate in gates {
218                    if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
219                        if rz.target != target_qubit {
220                            return Ok(None);
221                        }
222                        total_angle += rz.theta;
223                    } else {
224                        return Ok(None);
225                    }
226                }
227                Ok(Some(Box::new(RotationZ {
228                    target: target_qubit,
229                    theta: total_angle,
230                })))
231            }
232            _ => Ok(None),
233        }
234    }
235
236    /// Find fusable gate sequences
237    fn find_fusable_sequences(&self, gates: &[Box<dyn GateOp>]) -> Vec<(usize, usize)> {
238        let mut sequences = Vec::new();
239        let mut i = 0;
240
241        while i < gates.len() {
242            // For single-qubit gates, find consecutive gates on same qubit
243            if gates[i].qubits().len() == 1 {
244                let target_qubit = gates[i].qubits()[0];
245                let mut j = i + 1;
246
247                while j < gates.len() && j - i < self.max_fusion_size {
248                    if gates[j].qubits().len() == 1 && gates[j].qubits()[0] == target_qubit {
249                        j += 1;
250                    } else {
251                        break;
252                    }
253                }
254
255                if j > i + 1 {
256                    sequences.push((i, j));
257                    i = j;
258                    continue;
259                }
260            }
261
262            // For multi-qubit gates, look for specific patterns
263            if gates[i].name() == "CNOT" && i + 1 < gates.len() && gates[i + 1].name() == "CNOT" {
264                sequences.push((i, i + 2));
265                i += 2;
266                continue;
267            }
268
269            i += 1;
270        }
271
272        sequences
273    }
274}
275
276impl OptimizationPass for GateFusion {
277    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
278        let mut optimized = Vec::new();
279        let mut processed = vec![false; gates.len()];
280
281        // Find fusable sequences
282        let sequences = self.find_fusable_sequences(&gates);
283
284        for (start, end) in sequences {
285            let sequence = &gates[start..end];
286
287            // Skip if already processed
288            if processed[start] {
289                continue;
290            }
291
292            // Try different fusion strategies
293            let mut fused = false;
294
295            // Try rotation fusion first (most specific)
296            if let Some(fused_gate) = self.fuse_rotation_gates(sequence)? {
297                optimized.push(fused_gate);
298                fused = true;
299            }
300            // Try CNOT fusion
301            else if sequence.iter().all(|g| g.name() == "CNOT") {
302                if let Some(fused_gates) = self.fuse_cnot_gates(sequence)? {
303                    optimized.extend(fused_gates);
304                    fused = true;
305                }
306            }
307            // Try general single-qubit fusion
308            else if self.fuse_single_qubit && sequence.iter().all(|g| g.qubits().len() == 1) {
309                if let Some(fused_gate) = self.fuse_single_qubit_gates(sequence)? {
310                    optimized.push(fused_gate);
311                    fused = true;
312                }
313            }
314
315            // Mark as processed
316            if fused {
317                for i in start..end {
318                    processed[i] = true;
319                }
320            }
321        }
322
323        // Add unfused gates
324        for (i, gate) in gates.into_iter().enumerate() {
325            if !processed[i] {
326                optimized.push(gate);
327            }
328        }
329
330        Ok(optimized)
331    }
332
333    fn name(&self) -> &'static str {
334        "Gate Fusion"
335    }
336}
337
338/// Specialized fusion for Clifford gates
339pub struct CliffordFusion {
340    #[allow(dead_code)]
341    tolerance: f64,
342}
343
344impl CliffordFusion {
345    pub const fn new() -> Self {
346        Self { tolerance: 1e-10 }
347    }
348
349    /// Fuse adjacent Clifford gates
350    fn fuse_clifford_pair(
351        &self,
352        gate1: &dyn GateOp,
353        gate2: &dyn GateOp,
354    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
355        // Only fuse if gates act on same qubit
356        if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
357            return Ok(None);
358        }
359
360        let qubit = gate1.qubits()[0];
361
362        match (gate1.name(), gate2.name()) {
363            // Self-inverse gates (H, X, Y, Z, S†S, SS†)
364            ("H", "H") | ("X", "X") | ("Y", "Y") | ("Z", "Z") | ("S", "S†") | ("S†", "S") => {
365                Ok(None) // Identity - will be removed
366            }
367
368            // S gate combinations & Pauli combinations resulting in Z
369            ("S", "S") | ("S†", "S†") | ("X", "Y") | ("Y" | "H", "X") => {
370                Ok(Some(Box::new(PauliZ { target: qubit }))) // SS/S†S†/XY/YX/HX → Z
371            }
372
373            // Pauli combinations
374            ("X", "Z") | ("Z", "X") => Ok(Some(Box::new(PauliY { target: qubit }))), // XZ = -iY, ZX = iY
375            ("Y" | "H", "Z") | ("Z", "Y") => {
376                Ok(Some(Box::new(PauliX { target: qubit }))) // YZ/ZY/HZ → X
377            }
378
379            _ => Ok(None),
380        }
381    }
382}
383
384impl OptimizationPass for CliffordFusion {
385    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
386        let mut optimized = Vec::new();
387        let mut i = 0;
388
389        while i < gates.len() {
390            if i + 1 < gates.len() {
391                if let Some(fused) =
392                    self.fuse_clifford_pair(gates[i].as_ref(), gates[i + 1].as_ref())?
393                {
394                    optimized.push(fused);
395                    i += 2;
396                    continue;
397                } else if gates[i].qubits() == gates[i + 1].qubits() {
398                    // Check if it's identity (would return None from fusion)
399                    let combined_is_identity = match (gates[i].name(), gates[i + 1].name()) {
400                        ("H", "H")
401                        | ("S", "S†")
402                        | ("S†", "S")
403                        | ("X", "X")
404                        | ("Y", "Y")
405                        | ("Z", "Z") => true,
406                        _ => false,
407                    };
408
409                    if combined_is_identity {
410                        i += 2;
411                        continue;
412                    }
413                }
414            }
415
416            optimized.push(gates[i].clone_gate());
417            i += 1;
418        }
419
420        Ok(optimized)
421    }
422
423    fn name(&self) -> &'static str {
424        "Clifford Fusion"
425    }
426}
427
428#[cfg(test)]
429mod tests {
430    use super::*;
431    use crate::gate::single::{Hadamard, Phase};
432    use crate::prelude::QubitId;
433
434    #[test]
435    fn test_rotation_fusion() {
436        let fusion = GateFusion::new();
437        let qubit = QubitId(0);
438
439        let gates: Vec<Box<dyn GateOp>> = vec![
440            Box::new(RotationZ {
441                target: qubit,
442                theta: 0.5,
443            }),
444            Box::new(RotationZ {
445                target: qubit,
446                theta: 0.3,
447            }),
448            Box::new(RotationZ {
449                target: qubit,
450                theta: 0.2,
451            }),
452        ];
453
454        let result = fusion
455            .fuse_rotation_gates(&gates)
456            .expect("Failed to fuse rotation gates");
457        assert!(result.is_some());
458
459        if let Some(rz) = result
460            .expect("Expected a fused gate")
461            .as_any()
462            .downcast_ref::<RotationZ>()
463        {
464            assert!((rz.theta - 1.0).abs() < 1e-10);
465        } else {
466            panic!("Expected RotationZ gate");
467        }
468    }
469
470    #[test]
471    fn test_cnot_cancellation() {
472        let fusion = GateFusion::new();
473        let q0 = QubitId(0);
474        let q1 = QubitId(1);
475
476        let gates: Vec<Box<dyn GateOp>> = vec![
477            Box::new(CNOT {
478                control: q0,
479                target: q1,
480            }),
481            Box::new(CNOT {
482                control: q0,
483                target: q1,
484            }),
485        ];
486
487        let result = fusion
488            .fuse_cnot_gates(&gates)
489            .expect("Failed to fuse CNOT gates");
490        assert!(result.is_some());
491        assert_eq!(result.expect("Expected fused gate list").len(), 0); // Should cancel
492    }
493
494    #[test]
495    fn test_cnot_to_swap() {
496        let fusion = GateFusion::new();
497        let q0 = QubitId(0);
498        let q1 = QubitId(1);
499
500        let gates: Vec<Box<dyn GateOp>> = vec![
501            Box::new(CNOT {
502                control: q0,
503                target: q1,
504            }),
505            Box::new(CNOT {
506                control: q1,
507                target: q0,
508            }),
509        ];
510
511        let result = fusion
512            .fuse_cnot_gates(&gates)
513            .expect("Failed to fuse CNOT gates");
514        assert!(result.is_some());
515        let fused = result.expect("Expected fused gate list");
516        assert_eq!(fused.len(), 1);
517        assert_eq!(fused[0].name(), "SWAP");
518    }
519
520    #[test]
521    fn test_clifford_fusion() {
522        let fusion = CliffordFusion::new();
523        let qubit = QubitId(0);
524
525        let gates: Vec<Box<dyn GateOp>> = vec![
526            Box::new(Hadamard { target: qubit }),
527            Box::new(Hadamard { target: qubit }),
528            Box::new(Phase { target: qubit }),
529            Box::new(Phase { target: qubit }),
530        ];
531
532        let result = fusion
533            .optimize(gates)
534            .expect("Failed to optimize Clifford gates");
535        // H*H cancels, S*S = Z
536        assert_eq!(result.len(), 1);
537        assert_eq!(result[0].name(), "Z");
538    }
539
540    #[test]
541    fn test_full_optimization() {
542        let mut chain = super::super::OptimizationChain::new();
543        chain = chain
544            .add_pass(Box::new(CliffordFusion::new()))
545            .add_pass(Box::new(GateFusion::new()));
546
547        let q0 = QubitId(0);
548        let q1 = QubitId(1);
549
550        let gates: Vec<Box<dyn GateOp>> = vec![
551            Box::new(Hadamard { target: q0 }),
552            Box::new(Hadamard { target: q0 }),
553            Box::new(CNOT {
554                control: q0,
555                target: q1,
556            }),
557            Box::new(CNOT {
558                control: q0,
559                target: q1,
560            }),
561            Box::new(RotationZ {
562                target: q1,
563                theta: 0.5,
564            }),
565            Box::new(RotationZ {
566                target: q1,
567                theta: 0.5,
568            }),
569        ];
570
571        let result = chain
572            .optimize(gates)
573            .expect("Failed to optimize gate chain");
574
575        // After CliffordFusion: CNOT, CNOT, RZ, RZ (H*H canceled)
576        // After GateFusion: RZ (CNOT*CNOT canceled, RZ+RZ fused)
577        assert_eq!(result.len(), 1);
578        assert_eq!(result[0].name(), "RZ");
579
580        // Check the fused angle
581        if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
582            assert!((rz.theta - 1.0).abs() < 1e-10);
583        }
584    }
585}