1use crate::optimization::cost_model::CostModel;
4use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
5use quantrs2_core::gate::{
6 multi,
7 single::{self},
8 GateOp,
9};
10use quantrs2_core::qubit::QubitId;
11use std::collections::{HashMap, HashSet};
12use std::f64::consts::PI;
13
14use super::OptimizationPass;
15
/// Peephole optimization pass: rewrites short runs of consecutive gates using
/// a table of [`PeepholePattern`]s.
pub struct PeepholeOptimization {
    /// Largest window length (in gates) examined at each position.
    window_size: usize,
    /// Rewrite patterns, consulted in order.
    patterns: Vec<PeepholePattern>,
}
21
/// A single peephole rewrite: a labelled matcher over a window of gates.
#[derive(Clone)]
pub struct PeepholePattern {
    /// Human-readable label for the pattern.
    name: String,
    /// Number of gates the pattern needs in its window.
    window_size: usize,
    /// Returns the replacement gates when the window matches, `None` otherwise.
    matcher: fn(&[Box<dyn GateOp>]) -> Option<Vec<Box<dyn GateOp>>>,
}
28
29impl PeepholeOptimization {
30 #[must_use]
31 pub fn new(window_size: usize) -> Self {
32 let patterns = vec![
33 PeepholePattern {
35 name: "X-Y-X to -Y".to_string(),
36 window_size: 3,
37 matcher: |gates| {
38 if gates.len() >= 3 {
39 let g0 = &gates[0];
40 let g1 = &gates[1];
41 let g2 = &gates[2];
42
43 if g0.name() == "X"
44 && g2.name() == "X"
45 && g1.name() == "Y"
46 && g0.qubits() == g1.qubits()
47 && g1.qubits() == g2.qubits()
48 {
49 return Some(vec![g1.clone()]);
51 }
52 }
53 None
54 },
55 },
56 PeepholePattern {
58 name: "H-S-H simplification".to_string(),
59 window_size: 3,
60 matcher: |gates| {
61 if gates.len() >= 3 {
62 let g0 = &gates[0];
63 let g1 = &gates[1];
64 let g2 = &gates[2];
65
66 if g0.name() == "H"
67 && g2.name() == "H"
68 && g1.name() == "S"
69 && g0.qubits() == g1.qubits()
70 && g1.qubits() == g2.qubits()
71 {
72 let target = g0.qubits()[0];
73 return Some(vec![
74 Box::new(single::PauliX { target }) as Box<dyn GateOp>,
75 Box::new(single::RotationZ {
76 target,
77 theta: PI / 2.0,
78 }) as Box<dyn GateOp>,
79 Box::new(single::PauliX { target }) as Box<dyn GateOp>,
80 ]);
81 }
82 }
83 None
84 },
85 },
86 PeepholePattern {
88 name: "Euler angle optimization".to_string(),
89 window_size: 3,
90 matcher: |gates| {
91 if gates.len() >= 3 {
92 let g0 = &gates[0];
93 let g1 = &gates[1];
94 let g2 = &gates[2];
95
96 if g0.name() == "RZ"
97 && g1.name() == "RX"
98 && g2.name() == "RZ"
99 && g0.qubits() == g1.qubits()
100 && g1.qubits() == g2.qubits()
101 {
102 if let (Some(rz1), Some(rx), Some(rz2)) = (
103 g0.as_any().downcast_ref::<single::RotationZ>(),
104 g1.as_any().downcast_ref::<single::RotationX>(),
105 g2.as_any().downcast_ref::<single::RotationZ>(),
106 ) {
107 if rx.theta.abs() < 1e-10 {
109 let combined_angle = rz1.theta + rz2.theta;
110 if combined_angle.abs() < 1e-10 {
111 return Some(vec![]); }
113 return Some(vec![Box::new(single::RotationZ {
114 target: rz1.target,
115 theta: combined_angle,
116 })
117 as Box<dyn GateOp>]);
118 }
119 }
120 }
121 }
122 None
123 },
124 },
125 PeepholePattern {
127 name: "Phase gadget optimization".to_string(),
128 window_size: 3,
129 matcher: |gates| {
130 if gates.len() >= 3 {
131 let g0 = &gates[0];
132 let g1 = &gates[1];
133 let g2 = &gates[2];
134
135 if g0.name() == "CNOT" && g2.name() == "CNOT" && g1.name() == "RZ" {
136 if let (Some(cnot1), Some(rz), Some(cnot2)) = (
137 g0.as_any().downcast_ref::<multi::CNOT>(),
138 g1.as_any().downcast_ref::<single::RotationZ>(),
139 g2.as_any().downcast_ref::<multi::CNOT>(),
140 ) {
141 if cnot1.control == cnot2.control
142 && cnot1.target == cnot2.target
143 && rz.target == cnot1.target
144 {
145 return None;
147 }
148 }
149 }
150 }
151 None
152 },
153 },
154 PeepholePattern {
156 name: "Hadamard ladder".to_string(),
157 window_size: 4,
158 matcher: |gates| {
159 if gates.len() >= 4 {
160 if gates[0].name() == "H"
162 && gates[1].name() == "CNOT"
163 && gates[2].name() == "H"
164 && gates[3].name() == "CNOT"
165 {
166 let h1_target = gates[0].qubits()[0];
167 let h2_target = gates[2].qubits()[0];
168
169 if let (Some(cnot1), Some(cnot2)) = (
170 gates[1].as_any().downcast_ref::<multi::CNOT>(),
171 gates[3].as_any().downcast_ref::<multi::CNOT>(),
172 ) {
173 if h1_target == cnot1.control
174 && h2_target == cnot2.control
175 && cnot1.target == cnot2.target
176 {
177 return None; }
179 }
180 }
181 }
182 None
183 },
184 },
185 ];
186
187 Self {
188 window_size,
189 patterns,
190 }
191 }
192
193 fn apply_patterns(&self, window: &[Box<dyn GateOp>]) -> Option<Vec<Box<dyn GateOp>>> {
195 for pattern in &self.patterns {
196 if window.len() >= pattern.window_size {
197 if let Some(replacement) = (pattern.matcher)(window) {
198 return Some(replacement);
199 }
200 }
201 }
202 None
203 }
204}
205
206impl OptimizationPass for PeepholeOptimization {
207 fn name(&self) -> &'static str {
208 "Peephole Optimization"
209 }
210
211 fn apply_to_gates(
212 &self,
213 gates: Vec<Box<dyn GateOp>>,
214 _cost_model: &dyn CostModel,
215 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
216 let mut optimized = Vec::new();
217 let mut i = 0;
218
219 while i < gates.len() {
220 let mut matched = false;
221
222 for window_size in (2..=self.window_size).rev() {
223 if i + window_size <= gates.len() {
224 let window = &gates[i..i + window_size];
225
226 if let Some(replacement) = self.apply_patterns(window) {
227 optimized.extend(replacement);
228 i += window_size;
229 matched = true;
230 break;
231 }
232 }
233 }
234
235 if !matched {
236 optimized.push(gates[i].clone());
237 i += 1;
238 }
239 }
240
241 Ok(optimized)
242 }
243}
244
/// Optimization pass that replaces runs of gates matching a
/// [`CircuitTemplate`] with that template's replacement gates.
pub struct TemplateMatching {
    /// Templates, tried in order.
    templates: Vec<CircuitTemplate>,
}
249
/// A name-based rewrite template: a sequence of gate names to look for and
/// the sequence of gate names to emit in its place.
#[derive(Clone)]
pub struct CircuitTemplate {
    /// Human-readable label for the template.
    name: String,
    /// Gate names that must appear consecutively, in this order.
    pattern: Vec<String>,
    /// Gate names emitted in place of a matched run (may be empty).
    replacement: Vec<String>,
    /// Nominal saving of the rewrite; not read by the code visible in this file.
    cost_reduction: f64,
}
257
258impl TemplateMatching {
259 #[must_use]
260 pub fn new() -> Self {
261 let templates = vec![
262 CircuitTemplate {
263 name: "H-Z-H to X".to_string(),
264 pattern: vec!["H".to_string(), "Z".to_string(), "H".to_string()],
265 replacement: vec!["X".to_string()],
266 cost_reduction: 2.0,
267 },
268 CircuitTemplate {
269 name: "H-X-H to Z".to_string(),
270 pattern: vec!["H".to_string(), "X".to_string(), "H".to_string()],
271 replacement: vec!["Z".to_string()],
272 cost_reduction: 2.0,
273 },
274 CircuitTemplate {
275 name: "CNOT-H-CNOT to CZ".to_string(),
276 pattern: vec!["CNOT".to_string(), "H".to_string(), "CNOT".to_string()],
277 replacement: vec!["CZ".to_string()],
278 cost_reduction: 1.5,
279 },
280 CircuitTemplate {
281 name: "Double CNOT elimination".to_string(),
282 pattern: vec!["CNOT".to_string(), "CNOT".to_string()],
283 replacement: vec![],
284 cost_reduction: 2.0,
285 },
286 CircuitTemplate {
287 name: "S-S to Z".to_string(),
288 pattern: vec!["S".to_string(), "S".to_string()],
289 replacement: vec!["Z".to_string()],
290 cost_reduction: 1.0,
291 },
292 ];
293
294 Self { templates }
295 }
296
297 #[must_use]
298 pub const fn with_templates(templates: Vec<CircuitTemplate>) -> Self {
299 Self { templates }
300 }
301
302 #[must_use]
304 pub fn with_advanced_templates() -> Self {
305 let templates = vec![
306 CircuitTemplate {
307 name: "H-Z-H to X".to_string(),
308 pattern: vec!["H".to_string(), "Z".to_string(), "H".to_string()],
309 replacement: vec!["X".to_string()],
310 cost_reduction: 2.0,
311 },
312 CircuitTemplate {
313 name: "H-X-H to Z".to_string(),
314 pattern: vec!["H".to_string(), "X".to_string(), "H".to_string()],
315 replacement: vec!["Z".to_string()],
316 cost_reduction: 2.0,
317 },
318 CircuitTemplate {
319 name: "CNOT-CNOT elimination".to_string(),
320 pattern: vec!["CNOT".to_string(), "CNOT".to_string()],
321 replacement: vec![],
322 cost_reduction: 2.0,
323 },
324 CircuitTemplate {
325 name: "S-S to Z".to_string(),
326 pattern: vec!["S".to_string(), "S".to_string()],
327 replacement: vec!["Z".to_string()],
328 cost_reduction: 1.0,
329 },
330 CircuitTemplate {
331 name: "T-T-T-T to Identity".to_string(),
332 pattern: vec![
333 "T".to_string(),
334 "T".to_string(),
335 "T".to_string(),
336 "T".to_string(),
337 ],
338 replacement: vec![],
339 cost_reduction: 4.0,
340 },
341 CircuitTemplate {
342 name: "CNOT-H-CNOT to CZ".to_string(),
343 pattern: vec!["CNOT".to_string(), "H".to_string(), "CNOT".to_string()],
344 replacement: vec!["CZ".to_string()],
345 cost_reduction: 1.0,
346 },
347 CircuitTemplate {
348 name: "SWAP via 3 CNOTs".to_string(),
349 pattern: vec!["CNOT".to_string(), "CNOT".to_string(), "CNOT".to_string()],
350 replacement: vec!["SWAP".to_string()],
351 cost_reduction: 0.5,
352 },
353 ];
354
355 Self { templates }
356 }
357
358 #[must_use]
360 pub fn for_hardware(hardware: &str) -> Self {
361 let templates = match hardware {
362 "ibm" => vec![
363 CircuitTemplate {
364 name: "H-Z-H to X".to_string(),
365 pattern: vec!["H".to_string(), "Z".to_string(), "H".to_string()],
366 replacement: vec!["X".to_string()],
367 cost_reduction: 2.0,
368 },
369 CircuitTemplate {
370 name: "CNOT-CNOT elimination".to_string(),
371 pattern: vec!["CNOT".to_string(), "CNOT".to_string()],
372 replacement: vec![],
373 cost_reduction: 2.0,
374 },
375 ],
376 "google" => vec![CircuitTemplate {
377 name: "CNOT to CZ with Hadamards".to_string(),
378 pattern: vec!["CNOT".to_string()],
379 replacement: vec!["H".to_string(), "CZ".to_string(), "H".to_string()],
380 cost_reduction: -0.5,
381 }],
382 _ => Self::new().templates,
383 };
384
385 Self { templates }
386 }
387}
388
389impl OptimizationPass for TemplateMatching {
390 fn name(&self) -> &'static str {
391 "Template Matching"
392 }
393
394 fn apply_to_gates(
395 &self,
396 gates: Vec<Box<dyn GateOp>>,
397 cost_model: &dyn CostModel,
398 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
399 let mut optimized = gates;
400 let mut changed = true;
401
402 while changed {
403 changed = false;
404 let original_cost = cost_model.gates_cost(&optimized);
405
406 for template in &self.templates {
407 let result = self.apply_template(template, optimized.clone())?;
408 let new_cost = cost_model.gates_cost(&result);
409
410 if new_cost < original_cost {
411 optimized = result;
412 changed = true;
413 break;
414 }
415 }
416 }
417
418 Ok(optimized)
419 }
420}
421
422impl TemplateMatching {
423 fn apply_template(
424 &self,
425 template: &CircuitTemplate,
426 gates: Vec<Box<dyn GateOp>>,
427 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
428 let mut result = Vec::new();
429 let mut i = 0;
430
431 while i < gates.len() {
432 if let Some(replacement) = self.match_pattern_at_position(template, &gates, i)? {
433 result.extend(replacement);
434 i += template.pattern.len();
435 } else {
436 result.push(gates[i].clone());
437 i += 1;
438 }
439 }
440
441 Ok(result)
442 }
443
444 fn match_pattern_at_position(
445 &self,
446 template: &CircuitTemplate,
447 gates: &[Box<dyn GateOp>],
448 start: usize,
449 ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
450 if start + template.pattern.len() > gates.len() {
451 return Ok(None);
452 }
453
454 let mut qubit_mapping = HashMap::new();
455 let mut all_qubits = Vec::new();
456 let mut is_match = true;
457
458 for (i, pattern_gate) in template.pattern.iter().enumerate() {
459 let gate = &gates[start + i];
460
461 if !self.gate_matches_pattern(gate.as_ref(), pattern_gate, &qubit_mapping) {
462 is_match = false;
463 break;
464 }
465
466 for qubit in gate.qubits() {
467 if !all_qubits.contains(&qubit) {
468 all_qubits.push(qubit);
469 }
470 }
471 }
472
473 if !is_match {
474 return Ok(None);
475 }
476
477 if template
478 .pattern
479 .iter()
480 .all(|p| p == "H" || p == "X" || p == "Y" || p == "Z" || p == "S" || p == "T")
481 {
482 let first_qubit = gates[start].qubits();
483 if first_qubit.len() != 1 {
484 return Ok(None);
485 }
486
487 for i in 1..template.pattern.len() {
488 let gate_qubits = gates[start + i].qubits();
489 if gate_qubits != first_qubit {
490 return Ok(None);
491 }
492 }
493 }
494
495 qubit_mapping.insert("qubits".to_string(), all_qubits);
496 self.generate_replacement_gates(template, &qubit_mapping)
497 }
498
499 fn gate_matches_pattern(
500 &self,
501 gate: &dyn GateOp,
502 pattern: &str,
503 _qubit_mapping: &HashMap<String, Vec<QubitId>>,
504 ) -> bool {
505 gate.name() == pattern
506 }
507
508 fn generate_replacement_gates(
509 &self,
510 template: &CircuitTemplate,
511 qubit_mapping: &HashMap<String, Vec<QubitId>>,
512 ) -> QuantRS2Result<Option<Vec<Box<dyn GateOp>>>> {
513 let mut replacement_gates = Vec::new();
514
515 let qubits: Vec<QubitId> = qubit_mapping
516 .values()
517 .flat_map(|v| v.iter().copied())
518 .collect();
519 let mut unique_qubits: Vec<QubitId> = Vec::new();
520 for qubit in qubits {
521 if !unique_qubits.contains(&qubit) {
522 unique_qubits.push(qubit);
523 }
524 }
525
526 for replacement_pattern in &template.replacement {
527 if let Some(gate) = self.create_simple_gate(replacement_pattern, &unique_qubits)? {
528 replacement_gates.push(gate);
529 }
530 }
531
532 Ok(Some(replacement_gates))
533 }
534
535 fn create_simple_gate(
536 &self,
537 pattern: &str,
538 qubits: &[QubitId],
539 ) -> QuantRS2Result<Option<Box<dyn GateOp>>> {
540 if qubits.is_empty() {
541 return Ok(None);
542 }
543
544 match pattern {
545 "H" => Ok(Some(Box::new(single::Hadamard { target: qubits[0] }))),
546 "X" => Ok(Some(Box::new(single::PauliX { target: qubits[0] }))),
547 "Y" => Ok(Some(Box::new(single::PauliY { target: qubits[0] }))),
548 "Z" => Ok(Some(Box::new(single::PauliZ { target: qubits[0] }))),
549 "S" => Ok(Some(Box::new(single::Phase { target: qubits[0] }))),
550 "T" => Ok(Some(Box::new(single::T { target: qubits[0] }))),
551 "CNOT" if qubits.len() >= 2 => Ok(Some(Box::new(multi::CNOT {
552 control: qubits[0],
553 target: qubits[1],
554 }))),
555 "CZ" if qubits.len() >= 2 => Ok(Some(Box::new(multi::CZ {
556 control: qubits[0],
557 target: qubits[1],
558 }))),
559 "SWAP" if qubits.len() >= 2 => Ok(Some(Box::new(multi::SWAP {
560 qubit1: qubits[0],
561 qubit2: qubits[1],
562 }))),
563 _ => Ok(None),
564 }
565 }
566
567 fn create_gate(&self, gate_name: &str, qubits: &[QubitId]) -> QuantRS2Result<Box<dyn GateOp>> {
569 match (gate_name, qubits.len()) {
570 ("H", 1) => Ok(Box::new(single::Hadamard { target: qubits[0] })),
571 ("X", 1) => Ok(Box::new(single::PauliX { target: qubits[0] })),
572 ("Y", 1) => Ok(Box::new(single::PauliY { target: qubits[0] })),
573 ("Z", 1) => Ok(Box::new(single::PauliZ { target: qubits[0] })),
574 ("S", 1) => Ok(Box::new(single::Phase { target: qubits[0] })),
575 ("T", 1) => Ok(Box::new(single::T { target: qubits[0] })),
576 ("CNOT", 2) => Ok(Box::new(multi::CNOT {
577 control: qubits[0],
578 target: qubits[1],
579 })),
580 ("CZ", 2) => Ok(Box::new(multi::CZ {
581 control: qubits[0],
582 target: qubits[1],
583 })),
584 ("SWAP", 2) => Ok(Box::new(multi::SWAP {
585 qubit1: qubits[0],
586 qubit2: qubits[1],
587 })),
588 _ => Err(QuantRS2Error::UnsupportedOperation(format!(
589 "Cannot create gate {} with {} qubits",
590 gate_name,
591 qubits.len()
592 ))),
593 }
594 }
595}
596
597impl Default for TemplateMatching {
598 fn default() -> Self {
599 Self::new()
600 }
601}
602
603pub fn single_qubit_of(g: &dyn GateOp) -> Option<QubitId> {
607 let qs = g.qubits();
608 if qs.len() == 1 {
609 Some(qs[0])
610 } else {
611 None
612 }
613}
614
615pub fn all_same_single_qubit(gates: &[Box<dyn GateOp>]) -> bool {
617 let first = match single_qubit_of(gates[0].as_ref()) {
618 Some(q) => q,
619 None => return false,
620 };
621 gates[1..]
622 .iter()
623 .all(|g| single_qubit_of(g.as_ref()) == Some(first))
624}
625
626pub fn extract_rz_angle(g: &dyn GateOp) -> Option<f64> {
628 if g.name() != "RZ" {
629 return None;
630 }
631 g.matrix().ok().map(|m| 2.0 * m[0].arg())
632}
633
634pub fn extract_rx_angle(g: &dyn GateOp) -> Option<f64> {
636 if g.name() != "RX" {
637 return None;
638 }
639 g.matrix().ok().map(|m| {
640 let sin_half = -m[1].im;
641 let cos_half = m[0].re;
642 2.0 * sin_half.atan2(cos_half)
643 })
644}
645
646pub fn extract_ry_angle(g: &dyn GateOp) -> Option<f64> {
648 if g.name() != "RY" {
649 return None;
650 }
651 g.matrix().ok().map(|m| {
652 let sin_half = -m[1].re;
653 let cos_half = m[0].re;
654 2.0 * sin_half.atan2(cos_half)
655 })
656}
657
/// Wraps an angle (radians) into the half-open interval (−π, π].
pub fn normalise_angle(theta: f64) -> f64 {
    use std::f64::consts::TAU;

    // `%` keeps the sign of `theta`, so the remainder lies in (−2π, 2π) and a
    // single shift by a full turn is enough.
    match theta % TAU {
        wrapped if wrapped > PI => wrapped - TAU,
        wrapped if wrapped <= -PI => wrapped + TAU,
        wrapped => wrapped,
    }
}
670
671pub fn is_identity_angle(theta: f64, eps: f64) -> bool {
673 normalise_angle(theta).abs() < eps
674}
675
676fn default_rewrite_rules() -> Vec<RewriteRule> {
677 let hxh_to_z = RewriteRule {
681 name: "H-X-H to Z".to_string(),
682 window_size: 3,
683 condition: |w| {
684 w[0].name() == "H"
685 && w[1].name() == "X"
686 && w[2].name() == "H"
687 && all_same_single_qubit(w)
688 },
689 rewrite: |w| {
690 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
691 vec![Box::new(single::PauliZ { target: t }) as Box<dyn GateOp>]
692 },
693 };
694
695 let hzh_to_x = RewriteRule {
697 name: "H-Z-H to X".to_string(),
698 window_size: 3,
699 condition: |w| {
700 w[0].name() == "H"
701 && w[1].name() == "Z"
702 && w[2].name() == "H"
703 && all_same_single_qubit(w)
704 },
705 rewrite: |w| {
706 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
707 vec![Box::new(single::PauliX { target: t }) as Box<dyn GateOp>]
708 },
709 };
710
711 let hyh_to_y = RewriteRule {
713 name: "H-Y-H to Y (global phase)".to_string(),
714 window_size: 3,
715 condition: |w| {
716 w[0].name() == "H"
717 && w[1].name() == "Y"
718 && w[2].name() == "H"
719 && all_same_single_qubit(w)
720 },
721 rewrite: |w| {
722 let t = single_qubit_of(w[1].as_ref()).unwrap_or_else(|| w[1].qubits()[0]);
723 vec![Box::new(single::PauliY { target: t }) as Box<dyn GateOp>]
724 },
725 };
726
727 let xzx_to_z = RewriteRule {
729 name: "X-Z-X to Z (global phase)".to_string(),
730 window_size: 3,
731 condition: |w| {
732 w[0].name() == "X"
733 && w[1].name() == "Z"
734 && w[2].name() == "X"
735 && all_same_single_qubit(w)
736 },
737 rewrite: |w| {
738 let t = single_qubit_of(w[1].as_ref()).unwrap_or_else(|| w[1].qubits()[0]);
739 vec![Box::new(single::PauliZ { target: t }) as Box<dyn GateOp>]
740 },
741 };
742
743 let zxz_to_x = RewriteRule {
745 name: "Z-X-Z to X (global phase)".to_string(),
746 window_size: 3,
747 condition: |w| {
748 w[0].name() == "Z"
749 && w[1].name() == "X"
750 && w[2].name() == "Z"
751 && all_same_single_qubit(w)
752 },
753 rewrite: |w| {
754 let t = single_qubit_of(w[1].as_ref()).unwrap_or_else(|| w[1].qubits()[0]);
755 vec![Box::new(single::PauliX { target: t }) as Box<dyn GateOp>]
756 },
757 };
758
759 let hh = RewriteRule {
761 name: "H-H cancel".to_string(),
762 window_size: 2,
763 condition: |w| w[0].name() == "H" && w[1].name() == "H" && all_same_single_qubit(w),
764 rewrite: |_w| vec![],
765 };
766 let xx = RewriteRule {
767 name: "X-X cancel".to_string(),
768 window_size: 2,
769 condition: |w| w[0].name() == "X" && w[1].name() == "X" && all_same_single_qubit(w),
770 rewrite: |_w| vec![],
771 };
772 let yy = RewriteRule {
773 name: "Y-Y cancel".to_string(),
774 window_size: 2,
775 condition: |w| w[0].name() == "Y" && w[1].name() == "Y" && all_same_single_qubit(w),
776 rewrite: |_w| vec![],
777 };
778 let zz = RewriteRule {
779 name: "Z-Z cancel".to_string(),
780 window_size: 2,
781 condition: |w| w[0].name() == "Z" && w[1].name() == "Z" && all_same_single_qubit(w),
782 rewrite: |_w| vec![],
783 };
784
785 let ss_to_z = RewriteRule {
787 name: "S-S to Z".to_string(),
788 window_size: 2,
789 condition: |w| w[0].name() == "S" && w[1].name() == "S" && all_same_single_qubit(w),
790 rewrite: |w| {
791 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
792 vec![Box::new(single::PauliZ { target: t }) as Box<dyn GateOp>]
793 },
794 };
795
796 let tt_to_s = RewriteRule {
798 name: "T-T to S".to_string(),
799 window_size: 2,
800 condition: |w| w[0].name() == "T" && w[1].name() == "T" && all_same_single_qubit(w),
801 rewrite: |w| {
802 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
803 vec![Box::new(single::Phase { target: t }) as Box<dyn GateOp>]
804 },
805 };
806
807 let cnot_cnot = RewriteRule {
809 name: "CNOT-CNOT cancel".to_string(),
810 window_size: 2,
811 condition: |w| {
812 if w[0].name() != "CNOT" || w[1].name() != "CNOT" {
813 return false;
814 }
815 match (
816 w[0].as_any().downcast_ref::<multi::CNOT>(),
817 w[1].as_any().downcast_ref::<multi::CNOT>(),
818 ) {
819 (Some(c1), Some(c2)) => c1.control == c2.control && c1.target == c2.target,
820 _ => false,
821 }
822 },
823 rewrite: |_w| vec![],
824 };
825
826 let cz_cz = RewriteRule {
828 name: "CZ-CZ cancel".to_string(),
829 window_size: 2,
830 condition: |w| {
831 if w[0].name() != "CZ" || w[1].name() != "CZ" {
832 return false;
833 }
834 let q0 = w[0].qubits();
835 let q1 = w[1].qubits();
836 if q0.len() != 2 || q1.len() != 2 {
837 return false;
838 }
839 (q0[0] == q1[0] && q0[1] == q1[1]) || (q0[0] == q1[1] && q0[1] == q1[0])
840 },
841 rewrite: |_w| vec![],
842 };
843
844 let rx_merge = RewriteRule {
846 name: "RX-RX merge".to_string(),
847 window_size: 2,
848 condition: |w| {
849 w[0].name() == "RX"
850 && w[1].name() == "RX"
851 && all_same_single_qubit(w)
852 && extract_rx_angle(w[0].as_ref()).is_some()
853 && extract_rx_angle(w[1].as_ref()).is_some()
854 },
855 rewrite: |w| {
856 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
857 let a = extract_rx_angle(w[0].as_ref()).unwrap_or(0.0);
858 let b = extract_rx_angle(w[1].as_ref()).unwrap_or(0.0);
859 let sum = normalise_angle(a + b);
860 if is_identity_angle(sum, 1e-10) {
861 vec![]
862 } else {
863 vec![Box::new(single::RotationX {
864 target: t,
865 theta: sum,
866 }) as Box<dyn GateOp>]
867 }
868 },
869 };
870
871 let ry_merge = RewriteRule {
872 name: "RY-RY merge".to_string(),
873 window_size: 2,
874 condition: |w| {
875 w[0].name() == "RY"
876 && w[1].name() == "RY"
877 && all_same_single_qubit(w)
878 && extract_ry_angle(w[0].as_ref()).is_some()
879 && extract_ry_angle(w[1].as_ref()).is_some()
880 },
881 rewrite: |w| {
882 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
883 let a = extract_ry_angle(w[0].as_ref()).unwrap_or(0.0);
884 let b = extract_ry_angle(w[1].as_ref()).unwrap_or(0.0);
885 let sum = normalise_angle(a + b);
886 if is_identity_angle(sum, 1e-10) {
887 vec![]
888 } else {
889 vec![Box::new(single::RotationY {
890 target: t,
891 theta: sum,
892 }) as Box<dyn GateOp>]
893 }
894 },
895 };
896
897 let rz_merge = RewriteRule {
898 name: "RZ-RZ merge".to_string(),
899 window_size: 2,
900 condition: |w| {
901 w[0].name() == "RZ"
902 && w[1].name() == "RZ"
903 && all_same_single_qubit(w)
904 && extract_rz_angle(w[0].as_ref()).is_some()
905 && extract_rz_angle(w[1].as_ref()).is_some()
906 },
907 rewrite: |w| {
908 let t = single_qubit_of(w[0].as_ref()).unwrap_or_else(|| w[0].qubits()[0]);
909 let a = extract_rz_angle(w[0].as_ref()).unwrap_or(0.0);
910 let b = extract_rz_angle(w[1].as_ref()).unwrap_or(0.0);
911 let sum = normalise_angle(a + b);
912 if is_identity_angle(sum, 1e-10) {
913 vec![]
914 } else {
915 vec![Box::new(single::RotationZ {
916 target: t,
917 theta: sum,
918 }) as Box<dyn GateOp>]
919 }
920 },
921 };
922
923 vec![
924 hxh_to_z, hzh_to_x, hyh_to_y, xzx_to_z, zxz_to_x, hh, xx, yy, zz, ss_to_z, tt_to_s,
925 cnot_cnot, cz_cz, rx_merge, ry_merge, rz_merge,
926 ]
927}
928
/// Optimization pass that repeatedly applies local [`RewriteRule`]s to the
/// gate list.
pub struct CircuitRewriting {
    /// Rules, tried in order at each position.
    rules: Vec<RewriteRule>,
    /// Upper bound on the number of full scans over the gate list.
    max_rewrites: usize,
}
934
/// A local rewrite: a predicate over a fixed-size window of gates plus the
/// function producing the replacement gates.
#[derive(Clone)]
pub struct RewriteRule {
    /// Human-readable label for the rule.
    name: String,
    /// Number of consecutive gates the rule inspects and replaces.
    window_size: usize,
    /// Decides whether the rule applies to a window of `window_size` gates.
    condition: fn(&[Box<dyn GateOp>]) -> bool,
    /// Produces the gates that replace a matching window (may be empty).
    rewrite: fn(&[Box<dyn GateOp>]) -> Vec<Box<dyn GateOp>>,
}
947
948impl CircuitRewriting {
949 #[must_use]
951 pub fn new(max_rewrites: usize) -> Self {
952 Self {
953 rules: default_rewrite_rules(),
954 max_rewrites,
955 }
956 }
957
958 #[must_use]
960 pub fn with_rules(rules: Vec<RewriteRule>, max_rewrites: usize) -> Self {
961 Self {
962 rules,
963 max_rewrites,
964 }
965 }
966
967 fn try_apply_at(
969 &self,
970 gates: &[Box<dyn GateOp>],
971 pos: usize,
972 ) -> Option<(usize, Vec<Box<dyn GateOp>>)> {
973 for rule in &self.rules {
974 let end = pos + rule.window_size;
975 if end > gates.len() {
976 continue;
977 }
978 let window = &gates[pos..end];
979 if (rule.condition)(window) {
980 return Some((rule.window_size, (rule.rewrite)(window)));
981 }
982 }
983 None
984 }
985
986 fn scan_once(&self, gates: Vec<Box<dyn GateOp>>) -> (Vec<Box<dyn GateOp>>, bool) {
988 let mut out: Vec<Box<dyn GateOp>> = Vec::with_capacity(gates.len());
989 let mut i = 0;
990 let mut fired = false;
991
992 while i < gates.len() {
993 if let Some((consumed, replacement)) = self.try_apply_at(&gates, i) {
994 out.extend(replacement);
995 i += consumed;
996 fired = true;
997 } else {
998 out.push(gates[i].clone());
999 i += 1;
1000 }
1001 }
1002
1003 (out, fired)
1004 }
1005}
1006
1007impl OptimizationPass for CircuitRewriting {
1008 fn name(&self) -> &'static str {
1009 "Circuit Rewriting"
1010 }
1011
1012 fn apply_to_gates(
1013 &self,
1014 gates: Vec<Box<dyn GateOp>>,
1015 _cost_model: &dyn CostModel,
1016 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
1017 let mut current = gates;
1018 let mut passes_done = 0;
1019
1020 loop {
1021 if passes_done >= self.max_rewrites {
1022 break;
1023 }
1024 let (next, fired) = self.scan_once(current);
1025 current = next;
1026 if !fired {
1027 break;
1028 }
1029 passes_done += 1;
1030 }
1031
1032 Ok(current)
1033 }
1034}
1035
1036pub struct ParallelizationPass;
1042
1043impl ParallelizationPass {
1044 #[must_use]
1046 pub const fn new() -> Self {
1047 Self
1048 }
1049}
1050
1051impl Default for ParallelizationPass {
1052 fn default() -> Self {
1053 Self::new()
1054 }
1055}
1056
1057impl OptimizationPass for ParallelizationPass {
1058 fn name(&self) -> &'static str {
1059 "Parallelization (ASAP)"
1060 }
1061
1062 fn apply_to_gates(
1063 &self,
1064 gates: Vec<Box<dyn GateOp>>,
1065 _cost_model: &dyn CostModel,
1066 ) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
1067 Ok(parallelize_gates(gates))
1068 }
1069}
1070
1071#[must_use]
1078pub fn parallelize_gates(gates: Vec<Box<dyn GateOp>>) -> Vec<Box<dyn GateOp>> {
1079 let n = gates.len();
1080 if n == 0 {
1081 return gates;
1082 }
1083
1084 let qubit_sets: Vec<HashSet<u32>> = gates
1086 .iter()
1087 .map(|g| g.qubits().into_iter().map(|q| q.id()).collect())
1088 .collect();
1089
1090 let mut in_degree = vec![0usize; n];
1091 let mut predecessors: Vec<HashSet<usize>> = vec![HashSet::new(); n];
1092 let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
1093
1094 let mut last_gate_on_qubit: HashMap<u32, usize> = HashMap::new();
1095
1096 for j in 0..n {
1097 for &qid in &qubit_sets[j] {
1098 if let Some(&i) = last_gate_on_qubit.get(&qid) {
1099 if predecessors[j].insert(i) {
1100 successors[i].push(j);
1101 in_degree[j] += 1;
1102 }
1103 }
1104 }
1105 for &qid in &qubit_sets[j] {
1106 last_gate_on_qubit.insert(qid, j);
1107 }
1108 }
1109
1110 let mut result: Vec<Box<dyn GateOp>> = Vec::with_capacity(n);
1112 let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
1113
1114 while !ready.is_empty() {
1115 ready.sort_unstable();
1116 let layer = std::mem::take(&mut ready);
1117 for &idx in &layer {
1118 result.push(gates[idx].clone());
1119 for &succ in &successors[idx] {
1120 in_degree[succ] -= 1;
1121 if in_degree[succ] == 0 {
1122 ready.push(succ);
1123 }
1124 }
1125 }
1126 }
1127
1128 if result.len() != n {
1131 gates
1132 } else {
1133 result
1134 }
1135}
1136
1137pub mod utils {
1139 use super::{GateOp, HashMap};
1140 use crate::optimization::gate_properties::get_gate_properties;
1141
1142 pub fn gates_cancel(gate1: &dyn GateOp, gate2: &dyn GateOp) -> bool {
1144 if gate1.name() != gate2.name() || gate1.qubits() != gate2.qubits() {
1145 return false;
1146 }
1147
1148 let props = get_gate_properties(gate1);
1149 props.is_self_inverse
1150 }
1151
1152 pub fn is_identity_gate(gate: &dyn GateOp, tolerance: f64) -> bool {
1154 match gate.name() {
1155 "RX" | "RY" | "RZ" => {
1156 if let Ok(matrix) = gate.matrix() {
1157 (matrix[0].re - 1.0).abs() < tolerance && matrix[0].im.abs() < tolerance
1158 } else {
1159 false
1160 }
1161 }
1162 _ => false,
1163 }
1164 }
1165
1166 #[must_use]
1168 pub fn calculate_depth(gates: &[Box<dyn GateOp>]) -> usize {
1169 let mut qubit_depths: HashMap<u32, usize> = HashMap::new();
1170 let mut max_depth = 0;
1171
1172 for gate in gates {
1173 let gate_qubits = gate.qubits();
1174 let current_depth = gate_qubits
1175 .iter()
1176 .map(|q| qubit_depths.get(&q.id()).copied().unwrap_or(0))
1177 .max()
1178 .unwrap_or(0);
1179
1180 let new_depth = current_depth + 1;
1181 for qubit in gate_qubits {
1182 qubit_depths.insert(qubit.id(), new_depth);
1183 }
1184
1185 max_depth = max_depth.max(new_depth);
1186 }
1187
1188 max_depth
1189 }
1190}
1191
/// Tests for the `CircuitRewriting` pass and its built-in rules.
#[cfg(test)]
mod rewriting_tests {
    use super::*;
    use crate::optimization::cost_model::{AbstractCostModel, CostWeights};
    use quantrs2_core::gate::single::{
        Hadamard, PauliX, PauliY, PauliZ, Phase, RotationX, RotationY, RotationZ, T,
    };
    use quantrs2_core::qubit::QubitId;

    /// Shorthand for a qubit id.
    fn q(id: u32) -> QubitId {
        QubitId::new(id)
    }

    /// Cost model with default weights (required by the pass interface).
    fn cost() -> AbstractCostModel {
        AbstractCostModel::new(CostWeights::default())
    }

    /// Boxes a concrete gate as a trait object.
    fn boxed<G: GateOp + 'static>(gate: G) -> Box<dyn GateOp> {
        Box::new(gate)
    }

    /// Runs a fresh `CircuitRewriting::new(10)` over `gates`.
    fn rewrite(gates: Vec<Box<dyn GateOp>>) -> Vec<Box<dyn GateOp>> {
        CircuitRewriting::new(10)
            .apply_to_gates(gates, &cost())
            .expect("apply failed")
    }

    #[test]
    fn test_hh_cancels() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(Hadamard { target: q0 }),
            boxed(Hadamard { target: q0 }),
        ]);
        assert!(result.is_empty(), "H-H should cancel to identity");
    }

    #[test]
    fn test_hh_different_qubits_no_cancel() {
        let result = rewrite(vec![
            boxed(Hadamard { target: q(0) }),
            boxed(Hadamard { target: q(1) }),
        ]);
        assert_eq!(result.len(), 2, "H on different qubits must not cancel");
    }

    #[test]
    fn test_ss_to_z() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(Phase { target: q0 }),
            boxed(Phase { target: q0 }),
        ]);
        assert_eq!(result.len(), 1, "S-S should produce one gate");
        assert_eq!(result[0].name(), "Z", "S-S should produce Z");
    }

    #[test]
    fn test_tt_to_s() {
        let q0 = q(0);
        let result = rewrite(vec![boxed(T { target: q0 }), boxed(T { target: q0 })]);
        assert_eq!(result.len(), 1, "T-T should produce one gate");
        assert_eq!(result[0].name(), "S", "T-T should produce S");
    }

    #[test]
    fn test_rz_rz_merge() {
        let q0 = q(0);
        let theta1 = PI / 4.0;
        let theta2 = PI / 4.0;
        let result = rewrite(vec![
            boxed(RotationZ {
                target: q0,
                theta: theta1,
            }),
            boxed(RotationZ {
                target: q0,
                theta: theta2,
            }),
        ]);
        assert_eq!(result.len(), 1, "RZ-RZ should merge to one gate");
        assert_eq!(result[0].name(), "RZ");
        let merged = result[0]
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("should be RotationZ");
        let expected = normalise_angle(theta1 + theta2);
        assert!(
            (merged.theta - expected).abs() < 1e-9,
            "merged angle {:.6} != expected {:.6}",
            merged.theta,
            expected
        );
    }

    #[test]
    fn test_rz_rz_cancel() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(RotationZ {
                target: q0,
                theta: PI,
            }),
            boxed(RotationZ {
                target: q0,
                theta: PI,
            }),
        ]);
        assert!(result.is_empty(), "RZ(π)+RZ(π) should cancel");
    }

    #[test]
    fn test_rx_rx_merge() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(RotationX {
                target: q0,
                theta: PI / 3.0,
            }),
            boxed(RotationX {
                target: q0,
                theta: PI / 6.0,
            }),
        ]);
        assert_eq!(result.len(), 1, "RX-RX should merge");
        assert_eq!(result[0].name(), "RX");
        let merged = result[0]
            .as_any()
            .downcast_ref::<RotationX>()
            .expect("RotationX");
        let expected = normalise_angle(PI / 3.0 + PI / 6.0);
        assert!((merged.theta - expected).abs() < 1e-9);
    }

    #[test]
    fn test_ry_ry_cancel() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(RotationY {
                target: q0,
                theta: PI / 2.0,
            }),
            boxed(RotationY {
                target: q0,
                theta: -PI / 2.0,
            }),
        ]);
        assert!(result.is_empty(), "RY(π/2)+RY(-π/2) should cancel");
    }

    #[test]
    fn test_hxh_to_z() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(Hadamard { target: q0 }),
            boxed(PauliX { target: q0 }),
            boxed(Hadamard { target: q0 }),
        ]);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "Z", "H-X-H → Z");
    }

    #[test]
    fn test_hzh_to_x() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(Hadamard { target: q0 }),
            boxed(PauliZ { target: q0 }),
            boxed(Hadamard { target: q0 }),
        ]);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "X", "H-Z-H → X");
    }

    #[test]
    fn test_hyh_to_y() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(Hadamard { target: q0 }),
            boxed(PauliY { target: q0 }),
            boxed(Hadamard { target: q0 }),
        ]);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name(), "Y", "H-Y-H → Y");
    }

    #[test]
    fn test_hhh_converges_to_h() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(Hadamard { target: q0 }),
            boxed(Hadamard { target: q0 }),
            boxed(Hadamard { target: q0 }),
        ]);
        assert_eq!(result.len(), 1, "H-H-H should converge to one H");
        assert_eq!(result[0].name(), "H");
    }

    #[test]
    fn test_xxyy_cancel() {
        let q0 = q(0);
        let result = rewrite(vec![
            boxed(PauliX { target: q0 }),
            boxed(PauliX { target: q0 }),
            boxed(PauliY { target: q0 }),
            boxed(PauliY { target: q0 }),
        ]);
        assert!(result.is_empty(), "X-X-Y-Y should cancel");
    }
}