quantrs2_core/optimization/fusion.rs

1//! Gate fusion optimization pass
2//!
3//! This module implements gate fusion, which combines adjacent compatible gates
4//! into single operations to reduce circuit depth and improve performance.
5
6use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::gate::{multi::*, single::*, GateOp};
8use crate::synthesis::{identify_gate, synthesize_unitary};
9use ndarray::Array2;
10
11use super::OptimizationPass;
12
/// Gate fusion optimization pass.
///
/// Merges runs of adjacent gates into fewer equivalent gates: same-axis rotations on one
/// qubit are summed, other single-qubit runs are multiplied out and either identified as a
/// named gate or re-synthesised, and adjacent CNOT gates are simplified.
#[derive(Debug, Clone)]
pub struct GateFusion {
    /// Whether to fuse single-qubit gates.
    pub fuse_single_qubit: bool,
    /// Whether to fuse two-qubit gates.
    pub fuse_two_qubit: bool,
    /// Maximum number of gates to fuse together in one run.
    pub max_fusion_size: usize,
    /// Numerical tolerance used when identifying a fused unitary as a named gate.
    pub tolerance: f64,
}
24
25impl Default for GateFusion {
26    fn default() -> Self {
27        Self {
28            fuse_single_qubit: true,
29            fuse_two_qubit: true,
30            max_fusion_size: 4,
31            tolerance: 1e-10,
32        }
33    }
34}
35
36impl GateFusion {
37    /// Create a new gate fusion pass
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Try to fuse a sequence of single-qubit gates
43    fn fuse_single_qubit_gates(
44        &self,
45        gates: &[Box<dyn GateOp>],
46    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
47        if gates.is_empty() {
48            return Ok(None);
49        }
50
51        // Check all gates act on the same qubit
52        let target_qubit = gates[0].qubits()[0];
53        if !gates
54            .iter()
55            .all(|g| g.qubits().len() == 1 && g.qubits()[0] == target_qubit)
56        {
57            return Ok(None);
58        }
59
60        // Compute the combined unitary matrix
61        let mut combined = Array2::eye(2);
62        for gate in gates {
63            let gate_matrix = gate.matrix()?;
64            let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
65                .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
66            combined = combined.dot(&gate_array);
67        }
68
69        // Try to identify the combined gate
70        if let Some(gate_name) = identify_gate(&combined.view(), self.tolerance) {
71            // Convert identified gate name to actual gate
72            let identified_gate = match gate_name.as_str() {
73                "I" => None, // Identity - remove
74                "X" => Some(Box::new(PauliX {
75                    target: target_qubit,
76                }) as Box<dyn GateOp>),
77                "Y" => Some(Box::new(PauliY {
78                    target: target_qubit,
79                }) as Box<dyn GateOp>),
80                "Z" => Some(Box::new(PauliZ {
81                    target: target_qubit,
82                }) as Box<dyn GateOp>),
83                "H" => Some(Box::new(Hadamard {
84                    target: target_qubit,
85                }) as Box<dyn GateOp>),
86                "S" => Some(Box::new(Phase {
87                    target: target_qubit,
88                }) as Box<dyn GateOp>),
89                "S†" => Some(Box::new(PhaseDagger {
90                    target: target_qubit,
91                }) as Box<dyn GateOp>),
92                "T" => Some(Box::new(T {
93                    target: target_qubit,
94                }) as Box<dyn GateOp>),
95                "T†" => Some(Box::new(TDagger {
96                    target: target_qubit,
97                }) as Box<dyn GateOp>),
98                _ => None,
99            };
100
101            if let Some(gate) = identified_gate {
102                return Ok(Some(gate));
103            }
104        }
105
106        // If we can't identify it, synthesize it
107        let synthesized = synthesize_unitary(&combined.view(), &[target_qubit])?;
108        if synthesized.len() < gates.len() {
109            // Only use synthesis if it reduces gate count
110            if synthesized.len() == 1 {
111                Ok(Some(synthesized.into_iter().next().unwrap()))
112            } else {
113                Ok(None)
114            }
115        } else {
116            Ok(None)
117        }
118    }
119
120    /// Try to fuse CNOT gates
121    fn fuse_cnot_gates(
122        &self,
123        gates: &[Box<dyn GateOp>],
124    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
125        if gates.len() < 2 {
126            return Ok(None);
127        }
128
129        let mut fused = Vec::new();
130        let mut i = 0;
131
132        while i < gates.len() {
133            if i + 1 < gates.len() {
134                if let (Some(cnot1), Some(cnot2)) = (
135                    gates[i].as_any().downcast_ref::<CNOT>(),
136                    gates[i + 1].as_any().downcast_ref::<CNOT>(),
137                ) {
138                    // Two CNOTs with same control and target cancel
139                    if cnot1.control == cnot2.control && cnot1.target == cnot2.target {
140                        // Skip both gates
141                        i += 2;
142                        continue;
143                    }
144                    // CNOT(a,b) followed by CNOT(b,a) is a SWAP
145                    else if cnot1.control == cnot2.target && cnot1.target == cnot2.control {
146                        fused.push(Box::new(SWAP {
147                            qubit1: cnot1.control,
148                            qubit2: cnot1.target,
149                        }) as Box<dyn GateOp>);
150                        i += 2;
151                        continue;
152                    }
153                }
154            }
155
156            // No fusion possible, keep the gate
157            fused.push(gates[i].clone_gate());
158            i += 1;
159        }
160
161        if fused.len() < gates.len() {
162            Ok(Some(fused))
163        } else {
164            Ok(None)
165        }
166    }
167
168    /// Try to fuse rotation gates
169    fn fuse_rotation_gates(
170        &self,
171        gates: &[Box<dyn GateOp>],
172    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
173        if gates.len() < 2 {
174            return Ok(None);
175        }
176
177        // Check if all gates are rotations around the same axis on the same qubit
178        let first_gate = &gates[0];
179        let target_qubit = first_gate.qubits()[0];
180
181        match first_gate.name() {
182            "RX" => {
183                let mut total_angle = 0.0;
184                for gate in gates {
185                    if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
186                        if rx.target != target_qubit {
187                            return Ok(None);
188                        }
189                        total_angle += rx.theta;
190                    } else {
191                        return Ok(None);
192                    }
193                }
194                Ok(Some(Box::new(RotationX {
195                    target: target_qubit,
196                    theta: total_angle,
197                })))
198            }
199            "RY" => {
200                let mut total_angle = 0.0;
201                for gate in gates {
202                    if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
203                        if ry.target != target_qubit {
204                            return Ok(None);
205                        }
206                        total_angle += ry.theta;
207                    } else {
208                        return Ok(None);
209                    }
210                }
211                Ok(Some(Box::new(RotationY {
212                    target: target_qubit,
213                    theta: total_angle,
214                })))
215            }
216            "RZ" => {
217                let mut total_angle = 0.0;
218                for gate in gates {
219                    if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
220                        if rz.target != target_qubit {
221                            return Ok(None);
222                        }
223                        total_angle += rz.theta;
224                    } else {
225                        return Ok(None);
226                    }
227                }
228                Ok(Some(Box::new(RotationZ {
229                    target: target_qubit,
230                    theta: total_angle,
231                })))
232            }
233            _ => Ok(None),
234        }
235    }
236
237    /// Find fusable gate sequences
238    fn find_fusable_sequences(&self, gates: &[Box<dyn GateOp>]) -> Vec<(usize, usize)> {
239        let mut sequences = Vec::new();
240        let mut i = 0;
241
242        while i < gates.len() {
243            // For single-qubit gates, find consecutive gates on same qubit
244            if gates[i].qubits().len() == 1 {
245                let target_qubit = gates[i].qubits()[0];
246                let mut j = i + 1;
247
248                while j < gates.len() && j - i < self.max_fusion_size {
249                    if gates[j].qubits().len() == 1 && gates[j].qubits()[0] == target_qubit {
250                        j += 1;
251                    } else {
252                        break;
253                    }
254                }
255
256                if j > i + 1 {
257                    sequences.push((i, j));
258                    i = j;
259                    continue;
260                }
261            }
262
263            // For multi-qubit gates, look for specific patterns
264            if gates[i].name() == "CNOT" && i + 1 < gates.len() && gates[i + 1].name() == "CNOT" {
265                sequences.push((i, i + 2));
266                i += 2;
267                continue;
268            }
269
270            i += 1;
271        }
272
273        sequences
274    }
275}
276
277impl OptimizationPass for GateFusion {
278    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
279        let mut optimized = Vec::new();
280        let mut processed = vec![false; gates.len()];
281
282        // Find fusable sequences
283        let sequences = self.find_fusable_sequences(&gates);
284
285        for (start, end) in sequences {
286            let sequence = &gates[start..end];
287
288            // Skip if already processed
289            if processed[start] {
290                continue;
291            }
292
293            // Try different fusion strategies
294            let mut fused = false;
295
296            // Try rotation fusion first (most specific)
297            if let Some(fused_gate) = self.fuse_rotation_gates(sequence)? {
298                optimized.push(fused_gate);
299                fused = true;
300            }
301            // Try CNOT fusion
302            else if sequence.iter().all(|g| g.name() == "CNOT") {
303                if let Some(fused_gates) = self.fuse_cnot_gates(sequence)? {
304                    optimized.extend(fused_gates);
305                    fused = true;
306                }
307            }
308            // Try general single-qubit fusion
309            else if self.fuse_single_qubit && sequence.iter().all(|g| g.qubits().len() == 1) {
310                if let Some(fused_gate) = self.fuse_single_qubit_gates(sequence)? {
311                    optimized.push(fused_gate);
312                    fused = true;
313                }
314            }
315
316            // Mark as processed
317            if fused {
318                for i in start..end {
319                    processed[i] = true;
320                }
321            }
322        }
323
324        // Add unfused gates
325        for (i, gate) in gates.into_iter().enumerate() {
326            if !processed[i] {
327                optimized.push(gate);
328            }
329        }
330
331        Ok(optimized)
332    }
333
334    fn name(&self) -> &str {
335        "Gate Fusion"
336    }
337}
338
339/// Specialized fusion for Clifford gates
340pub struct CliffordFusion {
341    #[allow(dead_code)]
342    tolerance: f64,
343}
344
345impl CliffordFusion {
346    pub fn new() -> Self {
347        Self { tolerance: 1e-10 }
348    }
349
350    /// Fuse adjacent Clifford gates
351    fn fuse_clifford_pair(
352        &self,
353        gate1: &dyn GateOp,
354        gate2: &dyn GateOp,
355    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
356        // Only fuse if gates act on same qubit
357        if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
358            return Ok(None);
359        }
360
361        let qubit = gate1.qubits()[0];
362
363        match (gate1.name(), gate2.name()) {
364            // Hadamard is self-inverse
365            ("H", "H") => Ok(None), // Identity - will be removed
366
367            // S gate combinations
368            ("S", "S") => Ok(Some(Box::new(PauliZ { target: qubit }))),
369            ("S†", "S†") => Ok(Some(Box::new(PauliZ { target: qubit }))),
370            ("S", "S†") | ("S†", "S") => Ok(None), // Identity
371
372            // Pauli combinations
373            ("X", "X") | ("Y", "Y") | ("Z", "Z") => Ok(None), // Identity
374            ("X", "Y") => Ok(Some(Box::new(PauliZ { target: qubit }))), // XY = iZ
375            ("Y", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), // YX = -iZ
376            ("X", "Z") => Ok(Some(Box::new(PauliY { target: qubit }))), // XZ = -iY
377            ("Z", "X") => Ok(Some(Box::new(PauliY { target: qubit }))), // ZX = iY
378            ("Y", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), // YZ = iX
379            ("Z", "Y") => Ok(Some(Box::new(PauliX { target: qubit }))), // ZY = -iX
380
381            // H and Pauli combinations
382            ("H", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), // HX = ZH
383            ("H", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), // HZ = XH
384
385            _ => Ok(None),
386        }
387    }
388}
389
390impl OptimizationPass for CliffordFusion {
391    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
392        let mut optimized = Vec::new();
393        let mut i = 0;
394
395        while i < gates.len() {
396            if i + 1 < gates.len() {
397                if let Some(fused) =
398                    self.fuse_clifford_pair(gates[i].as_ref(), gates[i + 1].as_ref())?
399                {
400                    optimized.push(fused);
401                    i += 2;
402                    continue;
403                } else if gates[i].qubits() == gates[i + 1].qubits() {
404                    // Check if it's identity (would return None from fusion)
405                    let combined_is_identity = match (gates[i].name(), gates[i + 1].name()) {
406                        ("H", "H")
407                        | ("S", "S†")
408                        | ("S†", "S")
409                        | ("X", "X")
410                        | ("Y", "Y")
411                        | ("Z", "Z") => true,
412                        _ => false,
413                    };
414
415                    if combined_is_identity {
416                        i += 2;
417                        continue;
418                    }
419                }
420            }
421
422            optimized.push(gates[i].clone_gate());
423            i += 1;
424        }
425
426        Ok(optimized)
427    }
428
429    fn name(&self) -> &str {
430        "Clifford Fusion"
431    }
432}
433
434#[cfg(test)]
435mod tests {
436    use super::*;
437    use crate::gate::single::{Hadamard, Phase};
438    use crate::prelude::QubitId;
439
440    #[test]
441    fn test_rotation_fusion() {
442        let fusion = GateFusion::new();
443        let qubit = QubitId(0);
444
445        let gates: Vec<Box<dyn GateOp>> = vec![
446            Box::new(RotationZ {
447                target: qubit,
448                theta: 0.5,
449            }),
450            Box::new(RotationZ {
451                target: qubit,
452                theta: 0.3,
453            }),
454            Box::new(RotationZ {
455                target: qubit,
456                theta: 0.2,
457            }),
458        ];
459
460        let result = fusion.fuse_rotation_gates(&gates).unwrap();
461        assert!(result.is_some());
462
463        if let Some(rz) = result.unwrap().as_any().downcast_ref::<RotationZ>() {
464            assert!((rz.theta - 1.0).abs() < 1e-10);
465        } else {
466            panic!("Expected RotationZ gate");
467        }
468    }
469
470    #[test]
471    fn test_cnot_cancellation() {
472        let fusion = GateFusion::new();
473        let q0 = QubitId(0);
474        let q1 = QubitId(1);
475
476        let gates: Vec<Box<dyn GateOp>> = vec![
477            Box::new(CNOT {
478                control: q0,
479                target: q1,
480            }),
481            Box::new(CNOT {
482                control: q0,
483                target: q1,
484            }),
485        ];
486
487        let result = fusion.fuse_cnot_gates(&gates).unwrap();
488        assert!(result.is_some());
489        assert_eq!(result.unwrap().len(), 0); // Should cancel
490    }
491
492    #[test]
493    fn test_cnot_to_swap() {
494        let fusion = GateFusion::new();
495        let q0 = QubitId(0);
496        let q1 = QubitId(1);
497
498        let gates: Vec<Box<dyn GateOp>> = vec![
499            Box::new(CNOT {
500                control: q0,
501                target: q1,
502            }),
503            Box::new(CNOT {
504                control: q1,
505                target: q0,
506            }),
507        ];
508
509        let result = fusion.fuse_cnot_gates(&gates).unwrap();
510        assert!(result.is_some());
511        let fused = result.unwrap();
512        assert_eq!(fused.len(), 1);
513        assert_eq!(fused[0].name(), "SWAP");
514    }
515
516    #[test]
517    fn test_clifford_fusion() {
518        let fusion = CliffordFusion::new();
519        let qubit = QubitId(0);
520
521        let gates: Vec<Box<dyn GateOp>> = vec![
522            Box::new(Hadamard { target: qubit }),
523            Box::new(Hadamard { target: qubit }),
524            Box::new(Phase { target: qubit }),
525            Box::new(Phase { target: qubit }),
526        ];
527
528        let result = fusion.optimize(gates).unwrap();
529        // H*H cancels, S*S = Z
530        assert_eq!(result.len(), 1);
531        assert_eq!(result[0].name(), "Z");
532    }
533
534    #[test]
535    fn test_full_optimization() {
536        let mut chain = super::super::OptimizationChain::new();
537        chain = chain
538            .add_pass(Box::new(CliffordFusion::new()))
539            .add_pass(Box::new(GateFusion::new()));
540
541        let q0 = QubitId(0);
542        let q1 = QubitId(1);
543
544        let gates: Vec<Box<dyn GateOp>> = vec![
545            Box::new(Hadamard { target: q0 }),
546            Box::new(Hadamard { target: q0 }),
547            Box::new(CNOT {
548                control: q0,
549                target: q1,
550            }),
551            Box::new(CNOT {
552                control: q0,
553                target: q1,
554            }),
555            Box::new(RotationZ {
556                target: q1,
557                theta: 0.5,
558            }),
559            Box::new(RotationZ {
560                target: q1,
561                theta: 0.5,
562            }),
563        ];
564
565        let result = chain.optimize(gates).unwrap();
566
567        // After CliffordFusion: CNOT, CNOT, RZ, RZ (H*H canceled)
568        // After GateFusion: RZ (CNOT*CNOT canceled, RZ+RZ fused)
569        assert_eq!(result.len(), 1);
570        assert_eq!(result[0].name(), "RZ");
571
572        // Check the fused angle
573        if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
574            assert!((rz.theta - 1.0).abs() < 1e-10);
575        }
576    }
577}