quantrs2_core/optimization/
fusion.rs

//! Gate fusion optimization pass
//!
//! This module implements gate fusion, which combines adjacent compatible gates
//! into single operations to reduce circuit depth and improve performance.
5
6use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::gate::{multi::*, single::*, GateOp};
8use crate::matrix_ops::{DenseMatrix, QuantumMatrix};
9use crate::qubit::QubitId;
10use crate::synthesis::{identify_gate, synthesize_unitary};
11use ndarray::{Array1, Array2};
12use num_complex::Complex64;
13use std::collections::HashMap;
14
15use super::{gates_are_disjoint, OptimizationPass};
16
/// Gate fusion optimization pass.
///
/// Combines runs of adjacent, compatible gates into fewer operations. The fields are
/// public so callers can tune the pass with struct-update syntax, e.g.
/// `GateFusion { fuse_two_qubit: false, ..GateFusion::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct GateFusion {
    /// Whether to fuse single-qubit gates.
    pub fuse_single_qubit: bool,
    /// Whether to fuse two-qubit gates.
    pub fuse_two_qubit: bool,
    /// Maximum number of gates considered together in one fused run.
    pub max_fusion_size: usize,
    /// Numerical tolerance used when identifying a fused unitary as a named gate.
    pub tolerance: f64,
}
28
29impl Default for GateFusion {
30    fn default() -> Self {
31        Self {
32            fuse_single_qubit: true,
33            fuse_two_qubit: true,
34            max_fusion_size: 4,
35            tolerance: 1e-10,
36        }
37    }
38}
39
40impl GateFusion {
41    /// Create a new gate fusion pass
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Try to fuse a sequence of single-qubit gates
47    fn fuse_single_qubit_gates(
48        &self,
49        gates: &[Box<dyn GateOp>],
50    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
51        if gates.is_empty() {
52            return Ok(None);
53        }
54
55        // Check all gates act on the same qubit
56        let target_qubit = gates[0].qubits()[0];
57        if !gates
58            .iter()
59            .all(|g| g.qubits().len() == 1 && g.qubits()[0] == target_qubit)
60        {
61            return Ok(None);
62        }
63
64        // Compute the combined unitary matrix
65        let mut combined = Array2::eye(2);
66        for gate in gates {
67            let gate_matrix = gate.matrix()?;
68            let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
69                .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
70            combined = combined.dot(&gate_array);
71        }
72
73        // Try to identify the combined gate
74        if let Some(gate_name) = identify_gate(&combined.view(), self.tolerance) {
75            // Convert identified gate name to actual gate
76            let identified_gate = match gate_name.as_str() {
77                "I" => None, // Identity - remove
78                "X" => Some(Box::new(PauliX {
79                    target: target_qubit,
80                }) as Box<dyn GateOp>),
81                "Y" => Some(Box::new(PauliY {
82                    target: target_qubit,
83                }) as Box<dyn GateOp>),
84                "Z" => Some(Box::new(PauliZ {
85                    target: target_qubit,
86                }) as Box<dyn GateOp>),
87                "H" => Some(Box::new(Hadamard {
88                    target: target_qubit,
89                }) as Box<dyn GateOp>),
90                "S" => Some(Box::new(Phase {
91                    target: target_qubit,
92                }) as Box<dyn GateOp>),
93                "S†" => Some(Box::new(PhaseDagger {
94                    target: target_qubit,
95                }) as Box<dyn GateOp>),
96                "T" => Some(Box::new(T {
97                    target: target_qubit,
98                }) as Box<dyn GateOp>),
99                "T†" => Some(Box::new(TDagger {
100                    target: target_qubit,
101                }) as Box<dyn GateOp>),
102                _ => None,
103            };
104
105            if let Some(gate) = identified_gate {
106                return Ok(Some(gate));
107            }
108        }
109
110        // If we can't identify it, synthesize it
111        let synthesized = synthesize_unitary(&combined.view(), &[target_qubit])?;
112        if synthesized.len() < gates.len() {
113            // Only use synthesis if it reduces gate count
114            if synthesized.len() == 1 {
115                Ok(Some(synthesized.into_iter().next().unwrap()))
116            } else {
117                Ok(None)
118            }
119        } else {
120            Ok(None)
121        }
122    }
123
124    /// Try to fuse CNOT gates
125    fn fuse_cnot_gates(
126        &self,
127        gates: &[Box<dyn GateOp>],
128    ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
129        if gates.len() < 2 {
130            return Ok(None);
131        }
132
133        let mut fused = Vec::new();
134        let mut i = 0;
135
136        while i < gates.len() {
137            if i + 1 < gates.len() {
138                if let (Some(cnot1), Some(cnot2)) = (
139                    gates[i].as_any().downcast_ref::<CNOT>(),
140                    gates[i + 1].as_any().downcast_ref::<CNOT>(),
141                ) {
142                    // Two CNOTs with same control and target cancel
143                    if cnot1.control == cnot2.control && cnot1.target == cnot2.target {
144                        // Skip both gates
145                        i += 2;
146                        continue;
147                    }
148                    // CNOT(a,b) followed by CNOT(b,a) is a SWAP
149                    else if cnot1.control == cnot2.target && cnot1.target == cnot2.control {
150                        fused.push(Box::new(SWAP {
151                            qubit1: cnot1.control,
152                            qubit2: cnot1.target,
153                        }) as Box<dyn GateOp>);
154                        i += 2;
155                        continue;
156                    }
157                }
158            }
159
160            // No fusion possible, keep the gate
161            fused.push(gates[i].clone_gate());
162            i += 1;
163        }
164
165        if fused.len() < gates.len() {
166            Ok(Some(fused))
167        } else {
168            Ok(None)
169        }
170    }
171
172    /// Try to fuse rotation gates
173    fn fuse_rotation_gates(
174        &self,
175        gates: &[Box<dyn GateOp>],
176    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
177        if gates.len() < 2 {
178            return Ok(None);
179        }
180
181        // Check if all gates are rotations around the same axis on the same qubit
182        let first_gate = &gates[0];
183        let target_qubit = first_gate.qubits()[0];
184
185        match first_gate.name() {
186            "RX" => {
187                let mut total_angle = 0.0;
188                for gate in gates {
189                    if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
190                        if rx.target != target_qubit {
191                            return Ok(None);
192                        }
193                        total_angle += rx.theta;
194                    } else {
195                        return Ok(None);
196                    }
197                }
198                Ok(Some(Box::new(RotationX {
199                    target: target_qubit,
200                    theta: total_angle,
201                })))
202            }
203            "RY" => {
204                let mut total_angle = 0.0;
205                for gate in gates {
206                    if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
207                        if ry.target != target_qubit {
208                            return Ok(None);
209                        }
210                        total_angle += ry.theta;
211                    } else {
212                        return Ok(None);
213                    }
214                }
215                Ok(Some(Box::new(RotationY {
216                    target: target_qubit,
217                    theta: total_angle,
218                })))
219            }
220            "RZ" => {
221                let mut total_angle = 0.0;
222                for gate in gates {
223                    if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
224                        if rz.target != target_qubit {
225                            return Ok(None);
226                        }
227                        total_angle += rz.theta;
228                    } else {
229                        return Ok(None);
230                    }
231                }
232                Ok(Some(Box::new(RotationZ {
233                    target: target_qubit,
234                    theta: total_angle,
235                })))
236            }
237            _ => Ok(None),
238        }
239    }
240
241    /// Find fusable gate sequences
242    fn find_fusable_sequences(&self, gates: &[Box<dyn GateOp>]) -> Vec<(usize, usize)> {
243        let mut sequences = Vec::new();
244        let mut i = 0;
245
246        while i < gates.len() {
247            // For single-qubit gates, find consecutive gates on same qubit
248            if gates[i].qubits().len() == 1 {
249                let target_qubit = gates[i].qubits()[0];
250                let mut j = i + 1;
251
252                while j < gates.len() && j - i < self.max_fusion_size {
253                    if gates[j].qubits().len() == 1 && gates[j].qubits()[0] == target_qubit {
254                        j += 1;
255                    } else {
256                        break;
257                    }
258                }
259
260                if j > i + 1 {
261                    sequences.push((i, j));
262                    i = j;
263                    continue;
264                }
265            }
266
267            // For multi-qubit gates, look for specific patterns
268            if gates[i].name() == "CNOT" && i + 1 < gates.len() && gates[i + 1].name() == "CNOT" {
269                sequences.push((i, i + 2));
270                i += 2;
271                continue;
272            }
273
274            i += 1;
275        }
276
277        sequences
278    }
279}
280
281impl OptimizationPass for GateFusion {
282    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
283        let mut optimized = Vec::new();
284        let mut processed = vec![false; gates.len()];
285
286        // Find fusable sequences
287        let sequences = self.find_fusable_sequences(&gates);
288
289        for (start, end) in sequences {
290            let sequence = &gates[start..end];
291
292            // Skip if already processed
293            if processed[start] {
294                continue;
295            }
296
297            // Try different fusion strategies
298            let mut fused = false;
299
300            // Try rotation fusion first (most specific)
301            if let Some(fused_gate) = self.fuse_rotation_gates(sequence)? {
302                optimized.push(fused_gate);
303                fused = true;
304            }
305            // Try CNOT fusion
306            else if sequence.iter().all(|g| g.name() == "CNOT") {
307                if let Some(fused_gates) = self.fuse_cnot_gates(sequence)? {
308                    optimized.extend(fused_gates);
309                    fused = true;
310                }
311            }
312            // Try general single-qubit fusion
313            else if self.fuse_single_qubit && sequence.iter().all(|g| g.qubits().len() == 1) {
314                if let Some(fused_gate) = self.fuse_single_qubit_gates(sequence)? {
315                    optimized.push(fused_gate);
316                    fused = true;
317                }
318            }
319
320            // Mark as processed
321            if fused {
322                for i in start..end {
323                    processed[i] = true;
324                }
325            }
326        }
327
328        // Add unfused gates
329        for (i, gate) in gates.into_iter().enumerate() {
330            if !processed[i] {
331                optimized.push(gate);
332            }
333        }
334
335        Ok(optimized)
336    }
337
338    fn name(&self) -> &str {
339        "Gate Fusion"
340    }
341}
342
343/// Specialized fusion for Clifford gates
344pub struct CliffordFusion {
345    tolerance: f64,
346}
347
348impl CliffordFusion {
349    pub fn new() -> Self {
350        Self { tolerance: 1e-10 }
351    }
352
353    /// Fuse adjacent Clifford gates
354    fn fuse_clifford_pair(
355        &self,
356        gate1: &dyn GateOp,
357        gate2: &dyn GateOp,
358    ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
359        // Only fuse if gates act on same qubit
360        if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
361            return Ok(None);
362        }
363
364        let qubit = gate1.qubits()[0];
365
366        match (gate1.name(), gate2.name()) {
367            // Hadamard is self-inverse
368            ("H", "H") => Ok(None), // Identity - will be removed
369
370            // S gate combinations
371            ("S", "S") => Ok(Some(Box::new(PauliZ { target: qubit }))),
372            ("S†", "S†") => Ok(Some(Box::new(PauliZ { target: qubit }))),
373            ("S", "S†") | ("S†", "S") => Ok(None), // Identity
374
375            // Pauli combinations
376            ("X", "X") | ("Y", "Y") | ("Z", "Z") => Ok(None), // Identity
377            ("X", "Y") => Ok(Some(Box::new(PauliZ { target: qubit }))), // XY = iZ
378            ("Y", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), // YX = -iZ
379            ("X", "Z") => Ok(Some(Box::new(PauliY { target: qubit }))), // XZ = -iY
380            ("Z", "X") => Ok(Some(Box::new(PauliY { target: qubit }))), // ZX = iY
381            ("Y", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), // YZ = iX
382            ("Z", "Y") => Ok(Some(Box::new(PauliX { target: qubit }))), // ZY = -iX
383
384            // H and Pauli combinations
385            ("H", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), // HX = ZH
386            ("H", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), // HZ = XH
387
388            _ => Ok(None),
389        }
390    }
391}
392
393impl OptimizationPass for CliffordFusion {
394    fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
395        let mut optimized = Vec::new();
396        let mut i = 0;
397
398        while i < gates.len() {
399            if i + 1 < gates.len() {
400                if let Some(fused) =
401                    self.fuse_clifford_pair(gates[i].as_ref(), gates[i + 1].as_ref())?
402                {
403                    optimized.push(fused);
404                    i += 2;
405                    continue;
406                } else if gates[i].qubits() == gates[i + 1].qubits() {
407                    // Check if it's identity (would return None from fusion)
408                    let combined_is_identity = match (gates[i].name(), gates[i + 1].name()) {
409                        ("H", "H")
410                        | ("S", "S†")
411                        | ("S†", "S")
412                        | ("X", "X")
413                        | ("Y", "Y")
414                        | ("Z", "Z") => true,
415                        _ => false,
416                    };
417
418                    if combined_is_identity {
419                        i += 2;
420                        continue;
421                    }
422                }
423            }
424
425            optimized.push(gates[i].clone_gate());
426            i += 1;
427        }
428
429        Ok(optimized)
430    }
431
432    fn name(&self) -> &str {
433        "Clifford Fusion"
434    }
435}
436
437#[cfg(test)]
438mod tests {
439    use super::*;
440    use crate::gate::single::{Hadamard, Phase, PhaseDagger, TDagger, T};
441
442    #[test]
443    fn test_rotation_fusion() {
444        let fusion = GateFusion::new();
445        let qubit = QubitId(0);
446
447        let gates: Vec<Box<dyn GateOp>> = vec![
448            Box::new(RotationZ {
449                target: qubit,
450                theta: 0.5,
451            }),
452            Box::new(RotationZ {
453                target: qubit,
454                theta: 0.3,
455            }),
456            Box::new(RotationZ {
457                target: qubit,
458                theta: 0.2,
459            }),
460        ];
461
462        let result = fusion.fuse_rotation_gates(&gates).unwrap();
463        assert!(result.is_some());
464
465        if let Some(rz) = result.unwrap().as_any().downcast_ref::<RotationZ>() {
466            assert!((rz.theta - 1.0).abs() < 1e-10);
467        } else {
468            panic!("Expected RotationZ gate");
469        }
470    }
471
472    #[test]
473    fn test_cnot_cancellation() {
474        let fusion = GateFusion::new();
475        let q0 = QubitId(0);
476        let q1 = QubitId(1);
477
478        let gates: Vec<Box<dyn GateOp>> = vec![
479            Box::new(CNOT {
480                control: q0,
481                target: q1,
482            }),
483            Box::new(CNOT {
484                control: q0,
485                target: q1,
486            }),
487        ];
488
489        let result = fusion.fuse_cnot_gates(&gates).unwrap();
490        assert!(result.is_some());
491        assert_eq!(result.unwrap().len(), 0); // Should cancel
492    }
493
494    #[test]
495    fn test_cnot_to_swap() {
496        let fusion = GateFusion::new();
497        let q0 = QubitId(0);
498        let q1 = QubitId(1);
499
500        let gates: Vec<Box<dyn GateOp>> = vec![
501            Box::new(CNOT {
502                control: q0,
503                target: q1,
504            }),
505            Box::new(CNOT {
506                control: q1,
507                target: q0,
508            }),
509        ];
510
511        let result = fusion.fuse_cnot_gates(&gates).unwrap();
512        assert!(result.is_some());
513        let fused = result.unwrap();
514        assert_eq!(fused.len(), 1);
515        assert_eq!(fused[0].name(), "SWAP");
516    }
517
518    #[test]
519    fn test_clifford_fusion() {
520        let fusion = CliffordFusion::new();
521        let qubit = QubitId(0);
522
523        let gates: Vec<Box<dyn GateOp>> = vec![
524            Box::new(Hadamard { target: qubit }),
525            Box::new(Hadamard { target: qubit }),
526            Box::new(Phase { target: qubit }),
527            Box::new(Phase { target: qubit }),
528        ];
529
530        let result = fusion.optimize(gates).unwrap();
531        // H*H cancels, S*S = Z
532        assert_eq!(result.len(), 1);
533        assert_eq!(result[0].name(), "Z");
534    }
535
536    #[test]
537    fn test_full_optimization() {
538        let mut chain = super::super::OptimizationChain::new();
539        chain = chain
540            .add_pass(Box::new(CliffordFusion::new()))
541            .add_pass(Box::new(GateFusion::new()));
542
543        let q0 = QubitId(0);
544        let q1 = QubitId(1);
545
546        let gates: Vec<Box<dyn GateOp>> = vec![
547            Box::new(Hadamard { target: q0 }),
548            Box::new(Hadamard { target: q0 }),
549            Box::new(CNOT {
550                control: q0,
551                target: q1,
552            }),
553            Box::new(CNOT {
554                control: q0,
555                target: q1,
556            }),
557            Box::new(RotationZ {
558                target: q1,
559                theta: 0.5,
560            }),
561            Box::new(RotationZ {
562                target: q1,
563                theta: 0.5,
564            }),
565        ];
566
567        let result = chain.optimize(gates).unwrap();
568
569        // After CliffordFusion: CNOT, CNOT, RZ, RZ (H*H canceled)
570        // After GateFusion: RZ (CNOT*CNOT canceled, RZ+RZ fused)
571        assert_eq!(result.len(), 1);
572        assert_eq!(result[0].name(), "RZ");
573
574        // Check the fused angle
575        if let Some(rz) = result[0].as_any().downcast_ref::<RotationZ>() {
576            assert!((rz.theta - 1.0).abs() < 1e-10);
577        }
578    }
579}