quantrs2_core/optimization/
fusion.rs1use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::gate::{multi::*, single::*, GateOp};
8use crate::synthesis::{identify_gate, synthesize_unitary};
9use ndarray::Array2;
10
11use super::OptimizationPass;
12
/// Configuration for the gate-fusion optimization pass.
///
/// The pass scans a gate list for short runs of adjacent gates and tries to
/// replace each run with fewer gates.
#[derive(Debug, Clone, PartialEq)]
pub struct GateFusion {
    /// Enables fusing runs of arbitrary single-qubit gates via their matrix product.
    pub fuse_single_qubit: bool,
    /// Toggle intended for two-qubit (CNOT) fusion.
    pub fuse_two_qubit: bool,
    /// Upper bound on how many consecutive single-qubit gates form one fusable run.
    pub max_fusion_size: usize,
    /// Numerical tolerance handed to gate identification.
    pub tolerance: f64,
}
24
25impl Default for GateFusion {
26 fn default() -> Self {
27 Self {
28 fuse_single_qubit: true,
29 fuse_two_qubit: true,
30 max_fusion_size: 4,
31 tolerance: 1e-10,
32 }
33 }
34}
35
36impl GateFusion {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 fn fuse_single_qubit_gates(
44 &self,
45 gates: &[Box<dyn GateOp>],
46 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
47 if gates.is_empty() {
48 return Ok(None);
49 }
50
51 let target_qubit = gates[0].qubits()[0];
53 if !gates
54 .iter()
55 .all(|g| g.qubits().len() == 1 && g.qubits()[0] == target_qubit)
56 {
57 return Ok(None);
58 }
59
60 let mut combined = Array2::eye(2);
62 for gate in gates {
63 let gate_matrix = gate.matrix()?;
64 let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
65 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
66 combined = combined.dot(&gate_array);
67 }
68
69 if let Some(gate_name) = identify_gate(&combined.view(), self.tolerance) {
71 let identified_gate = match gate_name.as_str() {
73 "I" => None, "X" => Some(Box::new(PauliX {
75 target: target_qubit,
76 }) as Box<dyn GateOp>),
77 "Y" => Some(Box::new(PauliY {
78 target: target_qubit,
79 }) as Box<dyn GateOp>),
80 "Z" => Some(Box::new(PauliZ {
81 target: target_qubit,
82 }) as Box<dyn GateOp>),
83 "H" => Some(Box::new(Hadamard {
84 target: target_qubit,
85 }) as Box<dyn GateOp>),
86 "S" => Some(Box::new(Phase {
87 target: target_qubit,
88 }) as Box<dyn GateOp>),
89 "S†" => Some(Box::new(PhaseDagger {
90 target: target_qubit,
91 }) as Box<dyn GateOp>),
92 "T" => Some(Box::new(T {
93 target: target_qubit,
94 }) as Box<dyn GateOp>),
95 "T†" => Some(Box::new(TDagger {
96 target: target_qubit,
97 }) as Box<dyn GateOp>),
98 _ => None,
99 };
100
101 if let Some(gate) = identified_gate {
102 return Ok(Some(gate));
103 }
104 }
105
106 let synthesized = synthesize_unitary(&combined.view(), &[target_qubit])?;
108 if synthesized.len() < gates.len() {
109 if synthesized.len() == 1 {
111 Ok(Some(synthesized.into_iter().next().unwrap()))
112 } else {
113 Ok(None)
114 }
115 } else {
116 Ok(None)
117 }
118 }
119
120 fn fuse_cnot_gates(
122 &self,
123 gates: &[Box<dyn GateOp>],
124 ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
125 if gates.len() < 2 {
126 return Ok(None);
127 }
128
129 let mut fused = Vec::new();
130 let mut i = 0;
131
132 while i < gates.len() {
133 if i + 1 < gates.len() {
134 if let (Some(cnot1), Some(cnot2)) = (
135 gates[i].as_any().downcast_ref::<CNOT>(),
136 gates[i + 1].as_any().downcast_ref::<CNOT>(),
137 ) {
138 if cnot1.control == cnot2.control && cnot1.target == cnot2.target {
140 i += 2;
142 continue;
143 }
144 else if cnot1.control == cnot2.target && cnot1.target == cnot2.control {
146 fused.push(Box::new(SWAP {
147 qubit1: cnot1.control,
148 qubit2: cnot1.target,
149 }) as Box<dyn GateOp>);
150 i += 2;
151 continue;
152 }
153 }
154 }
155
156 fused.push(gates[i].clone_gate());
158 i += 1;
159 }
160
161 if fused.len() < gates.len() {
162 Ok(Some(fused))
163 } else {
164 Ok(None)
165 }
166 }
167
168 fn fuse_rotation_gates(
170 &self,
171 gates: &[Box<dyn GateOp>],
172 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
173 if gates.len() < 2 {
174 return Ok(None);
175 }
176
177 let first_gate = &gates[0];
179 let target_qubit = first_gate.qubits()[0];
180
181 match first_gate.name() {
182 "RX" => {
183 let mut total_angle = 0.0;
184 for gate in gates {
185 if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
186 if rx.target != target_qubit {
187 return Ok(None);
188 }
189 total_angle += rx.theta;
190 } else {
191 return Ok(None);
192 }
193 }
194 Ok(Some(Box::new(RotationX {
195 target: target_qubit,
196 theta: total_angle,
197 })))
198 }
199 "RY" => {
200 let mut total_angle = 0.0;
201 for gate in gates {
202 if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
203 if ry.target != target_qubit {
204 return Ok(None);
205 }
206 total_angle += ry.theta;
207 } else {
208 return Ok(None);
209 }
210 }
211 Ok(Some(Box::new(RotationY {
212 target: target_qubit,
213 theta: total_angle,
214 })))
215 }
216 "RZ" => {
217 let mut total_angle = 0.0;
218 for gate in gates {
219 if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
220 if rz.target != target_qubit {
221 return Ok(None);
222 }
223 total_angle += rz.theta;
224 } else {
225 return Ok(None);
226 }
227 }
228 Ok(Some(Box::new(RotationZ {
229 target: target_qubit,
230 theta: total_angle,
231 })))
232 }
233 _ => Ok(None),
234 }
235 }
236
237 fn find_fusable_sequences(&self, gates: &[Box<dyn GateOp>]) -> Vec<(usize, usize)> {
239 let mut sequences = Vec::new();
240 let mut i = 0;
241
242 while i < gates.len() {
243 if gates[i].qubits().len() == 1 {
245 let target_qubit = gates[i].qubits()[0];
246 let mut j = i + 1;
247
248 while j < gates.len() && j - i < self.max_fusion_size {
249 if gates[j].qubits().len() == 1 && gates[j].qubits()[0] == target_qubit {
250 j += 1;
251 } else {
252 break;
253 }
254 }
255
256 if j > i + 1 {
257 sequences.push((i, j));
258 i = j;
259 continue;
260 }
261 }
262
263 if gates[i].name() == "CNOT" && i + 1 < gates.len() && gates[i + 1].name() == "CNOT" {
265 sequences.push((i, i + 2));
266 i += 2;
267 continue;
268 }
269
270 i += 1;
271 }
272
273 sequences
274 }
275}
276
277impl OptimizationPass for GateFusion {
278 fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
279 let mut optimized = Vec::new();
280 let mut processed = vec![false; gates.len()];
281
282 let sequences = self.find_fusable_sequences(&gates);
284
285 for (start, end) in sequences {
286 let sequence = &gates[start..end];
287
288 if processed[start] {
290 continue;
291 }
292
293 let mut fused = false;
295
296 if let Some(fused_gate) = self.fuse_rotation_gates(sequence)? {
298 optimized.push(fused_gate);
299 fused = true;
300 }
301 else if sequence.iter().all(|g| g.name() == "CNOT") {
303 if let Some(fused_gates) = self.fuse_cnot_gates(sequence)? {
304 optimized.extend(fused_gates);
305 fused = true;
306 }
307 }
308 else if self.fuse_single_qubit && sequence.iter().all(|g| g.qubits().len() == 1) {
310 if let Some(fused_gate) = self.fuse_single_qubit_gates(sequence)? {
311 optimized.push(fused_gate);
312 fused = true;
313 }
314 }
315
316 if fused {
318 for i in start..end {
319 processed[i] = true;
320 }
321 }
322 }
323
324 for (i, gate) in gates.into_iter().enumerate() {
326 if !processed[i] {
327 optimized.push(gate);
328 }
329 }
330
331 Ok(optimized)
332 }
333
334 fn name(&self) -> &str {
335 "Gate Fusion"
336 }
337}
338
/// Optimization pass that merges or cancels adjacent single-qubit
/// Clifford/Pauli gate pairs by name.
#[derive(Debug, Clone)]
pub struct CliffordFusion {
    // Not read by the name-based rewrite rules in this file; kept so a
    // matrix-based comparison can use it later.
    #[allow(dead_code)]
    tolerance: f64,
}
344
345impl CliffordFusion {
346 pub fn new() -> Self {
347 Self { tolerance: 1e-10 }
348 }
349
350 fn fuse_clifford_pair(
352 &self,
353 gate1: &dyn GateOp,
354 gate2: &dyn GateOp,
355 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
356 if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
358 return Ok(None);
359 }
360
361 let qubit = gate1.qubits()[0];
362
363 match (gate1.name(), gate2.name()) {
364 ("H", "H") => Ok(None), ("S", "S") => Ok(Some(Box::new(PauliZ { target: qubit }))),
369 ("S†", "S†") => Ok(Some(Box::new(PauliZ { target: qubit }))),
370 ("S", "S†") | ("S†", "S") => Ok(None), ("X", "X") | ("Y", "Y") | ("Z", "Z") => Ok(None), ("X", "Y") => Ok(Some(Box::new(PauliZ { target: qubit }))), ("Y", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), ("X", "Z") => Ok(Some(Box::new(PauliY { target: qubit }))), ("Z", "X") => Ok(Some(Box::new(PauliY { target: qubit }))), ("Y", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), ("Z", "Y") => Ok(Some(Box::new(PauliX { target: qubit }))), ("H", "X") => Ok(Some(Box::new(PauliZ { target: qubit }))), ("H", "Z") => Ok(Some(Box::new(PauliX { target: qubit }))), _ => Ok(None),
386 }
387 }
388}
389
390impl OptimizationPass for CliffordFusion {
391 fn optimize(&self, gates: Vec<Box<dyn GateOp>>) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
392 let mut optimized = Vec::new();
393 let mut i = 0;
394
395 while i < gates.len() {
396 if i + 1 < gates.len() {
397 if let Some(fused) =
398 self.fuse_clifford_pair(gates[i].as_ref(), gates[i + 1].as_ref())?
399 {
400 optimized.push(fused);
401 i += 2;
402 continue;
403 } else if gates[i].qubits() == gates[i + 1].qubits() {
404 let combined_is_identity = match (gates[i].name(), gates[i + 1].name()) {
406 ("H", "H")
407 | ("S", "S†")
408 | ("S†", "S")
409 | ("X", "X")
410 | ("Y", "Y")
411 | ("Z", "Z") => true,
412 _ => false,
413 };
414
415 if combined_is_identity {
416 i += 2;
417 continue;
418 }
419 }
420 }
421
422 optimized.push(gates[i].clone_gate());
423 i += 1;
424 }
425
426 Ok(optimized)
427 }
428
429 fn name(&self) -> &str {
430 "Clifford Fusion"
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::{Hadamard, Phase};
    use crate::prelude::QubitId;

    /// Three RZ rotations on one qubit collapse into one RZ with the summed angle.
    #[test]
    fn test_rotation_fusion() {
        let fusion = GateFusion::new();
        let qubit = QubitId(0);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(RotationZ {
                target: qubit,
                theta: 0.5,
            }),
            Box::new(RotationZ {
                target: qubit,
                theta: 0.3,
            }),
            Box::new(RotationZ {
                target: qubit,
                theta: 0.2,
            }),
        ];

        let result = fusion.fuse_rotation_gates(&gates).unwrap();
        assert!(result.is_some());

        let fused = result.unwrap();
        let rz = fused
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("Expected RotationZ gate");
        assert!((rz.theta - 1.0).abs() < 1e-10);
    }

    /// Two identical CNOTs cancel, leaving an empty gate list.
    #[test]
    fn test_cnot_cancellation() {
        let fusion = GateFusion::new();
        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];

        let result = fusion.fuse_cnot_gates(&gates).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().len(), 0);
    }

    /// A SWAP is three alternating CNOTs: CNOT(a,b) CNOT(b,a) CNOT(a,b).
    /// Two reversed CNOTs alone are not a SWAP, so the full pattern is used.
    #[test]
    fn test_cnot_to_swap() {
        let fusion = GateFusion::new();
        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q1,
                target: q0,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
        ];

        let result = fusion.fuse_cnot_gates(&gates).unwrap();
        let fused = result.expect("alternating CNOT run should be shortened");
        assert!(fused.len() < gates.len());
        assert_eq!(fused[0].name(), "SWAP");
    }

    /// H·H cancels and S·S merges into Z, so four gates reduce to one.
    #[test]
    fn test_clifford_fusion() {
        let fusion = CliffordFusion::new();
        let qubit = QubitId(0);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: qubit }),
            Box::new(Hadamard { target: qubit }),
            Box::new(Phase { target: qubit }),
            Box::new(Phase { target: qubit }),
        ];

        let result = fusion.optimize(gates).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "Z");
    }

    /// Both passes chained: H·H and the CNOT pair vanish, the RZ pair merges.
    #[test]
    fn test_full_optimization() {
        let mut chain = super::super::OptimizationChain::new();
        chain = chain
            .add_pass(Box::new(CliffordFusion::new()))
            .add_pass(Box::new(GateFusion::new()));

        let q0 = QubitId(0);
        let q1 = QubitId(1);

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: q0 }),
            Box::new(Hadamard { target: q0 }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(CNOT {
                control: q0,
                target: q1,
            }),
            Box::new(RotationZ {
                target: q1,
                theta: 0.5,
            }),
            Box::new(RotationZ {
                target: q1,
                theta: 0.5,
            }),
        ];

        let result = chain.optimize(gates).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "RZ");

        // Fail loudly instead of silently skipping the angle check.
        let rz = result[0]
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("Expected RotationZ gate");
        assert!((rz.theta - 1.0).abs() < 1e-10);
    }
}