quantrs2_core/optimization/
fusion.rs1use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::gate::{multi::*, single::*, GateOp};
8use crate::matrix_ops::{DenseMatrix, QuantumMatrix};
9use crate::qubit::QubitId;
10use crate::synthesis::{identify_gate, synthesize_unitary};
11use ndarray::{Array1, Array2};
12use num_complex::Complex64;
13use std::collections::HashMap;
14
15use super::{gates_are_disjoint, OptimizationPass};
16
/// Configuration for the gate-fusion optimization pass.
///
/// The pass groups adjacent gates into candidate runs and tries to replace
/// each run with fewer gates.
#[derive(Debug, Clone)]
pub struct GateFusion {
    /// Enables the general matrix-based fusion of single-qubit gate runs.
    pub fuse_single_qubit: bool,
    /// Intended to toggle fusion of two-qubit gates.
    /// NOTE(review): no code in this file reads this flag — confirm whether
    /// it is consulted elsewhere or should gate the CNOT fusion.
    pub fuse_two_qubit: bool,
    /// Caps how many consecutive gates are grouped into one fusion candidate
    /// (single-qubit runs in particular).
    pub max_fusion_size: usize,
    /// Numerical tolerance used when matching a fused matrix against known
    /// named gates.
    pub tolerance: f64,
}
28
29impl Default for GateFusion {
30 fn default() -> Self {
31 Self {
32 fuse_single_qubit: true,
33 fuse_two_qubit: true,
34 max_fusion_size: 4,
35 tolerance: 1e-10,
36 }
37 }
38}
39
40impl GateFusion {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 fn fuse_single_qubit_gates(
48 &self,
49 gates: &[Box<dyn GateOp>],
50 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
51 if gates.is_empty() {
52 return Ok(None);
53 }
54
55 let target_qubit = gates[0].qubits()[0];
57 if !gates
58 .iter()
59 .all(|g| g.qubits().len() == 1 && g.qubits()[0] == target_qubit)
60 {
61 return Ok(None);
62 }
63
64 let mut combined = Array2::eye(2);
66 for gate in gates {
67 let gate_matrix = gate.matrix()?;
68 let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
69 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
70 combined = combined.dot(&gate_array);
71 }
72
73 if let Some(gate_name) = identify_gate(&combined.view(), self.tolerance) {
75 let identified_gate = match gate_name.as_str() {
77 "I" => None, "X" => Some(Box::new(PauliX {
79 target: target_qubit,
80 }) as Box<dyn GateOp>),
81 "Y" => Some(Box::new(PauliY {
82 target: target_qubit,
83 }) as Box<dyn GateOp>),
84 "Z" => Some(Box::new(PauliZ {
85 target: target_qubit,
86 }) as Box<dyn GateOp>),
87 "H" => Some(Box::new(Hadamard {
88 target: target_qubit,
89 }) as Box<dyn GateOp>),
90 "S" => Some(Box::new(Phase {
91 target: target_qubit,
92 }) as Box<dyn GateOp>),
93 "S†" => Some(Box::new(PhaseDagger {
94 target: target_qubit,
95 }) as Box<dyn GateOp>),
96 "T" => Some(Box::new(T {
97 target: target_qubit,
98 }) as Box<dyn GateOp>),
99 "T†" => Some(Box::new(TDagger {
100 target: target_qubit,
101 }) as Box<dyn GateOp>),
102 _ => None,
103 };
104
105 if let Some(gate) = identified_gate {
106 return Ok(Some(gate));
107 }
108 }
109
110 let synthesized = synthesize_unitary(&combined.view(), &[target_qubit])?;
112 if synthesized.len() < gates.len() {
113 if synthesized.len() == 1 {
115 Ok(Some(synthesized.into_iter().next().unwrap()))
116 } else {
117 Ok(None)
118 }
119 } else {
120 Ok(None)
121 }
122 }
123
124 fn fuse_cnot_gates(
126 &self,
127 gates: &[Box<dyn GateOp>],
128 ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
129 if gates.len() < 2 {
130 return Ok(None);
131 }
132
133 let mut fused = Vec::new();
134 let mut i = 0;
135
136 while i < gates.len() {
137 if i + 1 < gates.len() {
138 if let (Some(cnot1), Some(cnot2)) = (
139 gates[i].as_any().downcast_ref::<CNOT>(),
140 gates[i + 1].as_any().downcast_ref::<CNOT>(),
141 ) {
142 if cnot1.control == cnot2.control && cnot1.target == cnot2.target {
144 i += 2;
146 continue;
147 }
148 else if cnot1.control == cnot2.target && cnot1.target == cnot2.control {
150 fused.push(Box::new(SWAP {
151 qubit1: cnot1.control,
152 qubit2: cnot1.target,
153 }) as Box<dyn GateOp>);
154 i += 2;
155 continue;
156 }
157 }
158 }
159
160 fused.push(gates[i].clone_gate());
162 i += 1;
163 }
164
165 if fused.len() < gates.len() {
166 Ok(Some(fused))
167 } else {
168 Ok(None)
169 }
170 }
171
172 fn fuse_rotation_gates(
174 &self,
175 gates: &[Box<dyn GateOp>],
176 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
177 if gates.len() < 2 {
178 return Ok(None);
179 }
180
181 let first_gate = &gates[0];
183 let target_qubit = first_gate.qubits()[0];
184
185 match first_gate.name() {
186 "RX" => {
187 let mut total_angle = 0.0;
188 for gate in gates {
189 if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
190 if rx.target != target_qubit {
191 return Ok(None);
192 }
193 total_angle += rx.theta;
194 } else {
195 return Ok(None);
196 }
197 }
198 Ok(Some(Box::new(RotationX {
199 target: target_qubit,
200 theta: total_angle,
201 })))
202 }
203 "RY" => {
204 let mut total_angle = 0.0;
205 for gate in gates {
206 if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
207 if ry.target != target_qubit {
208 return Ok(None);
209 }
210 total_angle += ry.theta;
211 } else {
212 return Ok(None);
213 }
214 }
215 Ok(Some(Box::new(RotationY {
216 target: target_qubit,
217 theta: total_angle,
218 })))
219 }
220 "RZ" => {
221 let mut total_angle = 0.0;
222 for gate in gates {
223 if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
224 if rz.target != target_qubit {
225 return Ok(None);
226 }
227 total_angle += rz.theta;
228 } else {
229 return Ok(None);
230 }
231 }
232 Ok(Some(Box::new(RotationZ {
233 target: target_qubit,
234 theta: total_angle,
235 })))
236 }
237 _ => Ok(None),
238 }
239 }
240
241 fn find_fusable_sequences(&self, gates: &[Box<dyn GateOp>]) -> Vec<(usize, usize)> {
243 let mut sequences = Vec::new();
244 let mut i = 0;
245
246 while i < gates.len() {
247 if gates[i].qubits().len() == 1 {
249 let target_qubit = gates[i].qubits()[0];
250 let mut j = i + 1;
251
252 while j < gates.len() && j - i < self.max_fusion_size {
253 if gates[j].qubits().len() == 1 && gates[j].qubits()[0] == target_qubit {
254 j += 1;
255 } else {
256 break;
257 }
258 }
259
260 if j > i + 1 {
261 sequences.push((i, j));
262 i = j;
263 continue;
264 }
265 }
266
267 if gates[i].name() == "CNOT" && i + 1 < gates.len() && gates[i + 1].name() == "CNOT" {
269 sequences.push((i, i + 2));
270 i += 2;
271 continue;
272 }
273
274 i += 1;
275 }
276
277 sequences
278 }
279}
280
281impl OptimizationPass for GateFusion {
282 fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
283 let mut optimized = Vec::new();
284 let mut processed = vec![false; gates.len()];
285
286 let sequences = self.find_fusable_sequences(&gates);
288
289 for (start, end) in sequences {
290 let sequence = &gates[start..end];
291
292 if processed[start] {
294 continue;
295 }
296
297 let mut fused = false;
299
300 if let Some(fused_gate) = self.fuse_rotation_gates(sequence)? {
302 optimized.push(fused_gate);
303 fused = true;
304 }
305 else if sequence.iter().all(|g| g.name() == "CNOT") {
307 if let Some(fused_gates) = self.fuse_cnot_gates(sequence)? {
308 optimized.extend(fused_gates);
309 fused = true;
310 }
311 }
312 else if self.fuse_single_qubit && sequence.iter().all(|g| g.qubits().len() == 1) {
314 if let Some(fused_gate) = self.fuse_single_qubit_gates(sequence)? {
315 optimized.push(fused_gate);
316 fused = true;
317 }
318 }
319
320 if fused {
322 for i in start..end {
323 processed[i] = true;
324 }
325 }
326 }
327
328 for (i, gate) in gates.into_iter().enumerate() {
330 if !processed[i] {
331 optimized.push(gate);
332 }
333 }
334
335 Ok(optimized)
336 }
337
338 fn name(&self) -> &str {
339 "Gate Fusion"
340 }
341}
342
/// Optimization pass that simplifies adjacent pairs of named single-qubit
/// Clifford and Pauli gates.
#[derive(Debug, Clone)]
pub struct CliffordFusion {
    /// Numerical tolerance stored with the pass.
    /// NOTE(review): not read by any code visible in this file.
    tolerance: f64,
}
347
348impl CliffordFusion {
349 pub fn new() -> Self {
350 Self { tolerance: 1e-10 }
351 }
352
353 fn fuse_clifford_pair(
355 &self,
356 gate1: &dyn GateOp,
357 gate2: &dyn GateOp,
358 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
359 if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
361 return Ok(None);
362 }
363
364 let qubit = gate1.qubits()[0];
365
366 match (gate1.name(), gate2.name()) {
367 ("H", "H") => Ok(None), ("S", "S") => Ok(Some(Box::new(PauliZ { target: qubit }))),
372 ("S†", "S†") => Ok(Some(Box::new(PauliZ { target: qubit }))),
373 ("S", "S†") | ("S†", "S") => Ok(None), ("X", "X") | ("Y", "Y") | ("Z", "Z") => Ok(None), ("X", "Y") => Ok(Some(Box::new(PauliZ { target: qubit }))), ("Y", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), ("X", "Z") => Ok(Some(Box::new(PauliY { target: qubit }))), ("Z", "X") => Ok(Some(Box::new(PauliY { target: qubit }))), ("Y", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), ("Z", "Y") => Ok(Some(Box::new(PauliX { target: qubit }))), ("H", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), ("H", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), _ => Ok(None),
389 }
390 }
391}
392
393impl OptimizationPass for CliffordFusion {
394 fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
395 let mut optimized = Vec::new();
396 let mut i = 0;
397
398 while i < gates.len() {
399 if i + 1 < gates.len() {
400 if let Some(fused) =
401 self.fuse_clifford_pair(gates[i].as_ref(), gates[i + 1].as_ref())?
402 {
403 optimized.push(fused);
404 i += 2;
405 continue;
406 } else if gates[i].qubits() == gates[i + 1].qubits() {
407 let combined_is_identity = match (gates[i].name(), gates[i + 1].name()) {
409 ("H", "H")
410 | ("S", "S†")
411 | ("S†", "S")
412 | ("X", "X")
413 | ("Y", "Y")
414 | ("Z", "Z") => true,
415 _ => false,
416 };
417
418 if combined_is_identity {
419 i += 2;
420 continue;
421 }
422 }
423 }
424
425 optimized.push(gates[i].clone_gate());
426 i += 1;
427 }
428
429 Ok(optimized)
430 }
431
432 fn name(&self) -> &str {
433 "Clifford Fusion"
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::{Hadamard, Phase};

    /// Three RZ rotations on one qubit merge into one RZ with the summed angle.
    #[test]
    fn test_rotation_fusion() {
        let fusion = GateFusion::new();
        let qubit = QubitId(0);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(RotationZ {
                target: qubit,
                theta: 0.5,
            }),
            Box::new(RotationZ {
                target: qubit,
                theta: 0.3,
            }),
            Box::new(RotationZ {
                target: qubit,
                theta: 0.2,
            }),
        ];

        let fused = fusion
            .fuse_rotation_gates(&gates)
            .unwrap()
            .expect("rotations on one qubit should fuse");

        let rz = fused
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("Expected RotationZ gate");
        assert!((rz.theta - 1.0).abs() < 1e-10);
    }

    /// Two identical CNOTs cancel to an empty sequence.
    #[test]
    fn test_cnot_cancellation() {
        let fusion = GateFusion::new();
        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];

        let result = fusion.fuse_cnot_gates(&gates).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().len(), 0);
    }

    /// SWAP(a, b) equals CNOT(a,b) CNOT(b,a) CNOT(a,b): all three are needed.
    #[test]
    fn test_cnot_to_swap() {
        let fusion = GateFusion::new();
        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q1,
                target: q0,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];

        let result = fusion.fuse_cnot_gates(&gates).unwrap();
        assert!(result.is_some());
        let fused = result.unwrap();
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].name(), "SWAP");
    }

    /// Two opposite CNOTs alone are not a SWAP and must be left untouched.
    #[test]
    fn test_reversed_cnot_pair_is_not_swap() {
        let fusion = GateFusion::new();
        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q1,
                target: q0,
            }),
        ];

        assert!(fusion.fuse_cnot_gates(&gates).unwrap().is_none());
    }

    /// H·H disappears and S·S becomes Z.
    #[test]
    fn test_clifford_fusion() {
        let fusion = CliffordFusion::new();
        let qubit = QubitId(0);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: qubit }),
            Box::new(Hadamard { target: qubit }),
            Box::new(Phase { target: qubit }),
            Box::new(Phase { target: qubit }),
        ];

        let result = fusion.optimize(gates).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "Z");
    }

    /// Chaining both passes: H·H and CNOT·CNOT vanish, RZ·RZ merges.
    #[test]
    fn test_full_optimization() {
        let mut chain = super::super::OptimizationChain::new();
        chain = chain
            .add_pass(Box::new(CliffordFusion::new()))
            .add_pass(Box::new(GateFusion::new()));

        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: q0 }),
            Box::new(Hadamard { target: q0 }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(RotationZ {
                target: q1,
                theta: 0.5,
            }),
            Box::new(RotationZ {
                target: q1,
                theta: 0.5,
            }),
        ];

        let result = chain.optimize(gates).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "RZ");

        // Fail loudly if the surviving gate is not the expected type.
        let rz = result[0]
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("Expected RotationZ gate");
        assert!((rz.theta - 1.0).abs() < 1e-10);
    }
}