1use crate::builder::Circuit;
8use crate::dag::{circuit_to_dag, CircuitDag, DagNode};
9use quantrs2_core::{
10 error::{QuantRS2Error, QuantRS2Result},
11 gate::GateOp,
12 qubit::QubitId,
13};
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::f64::consts::PI;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub enum ZXNode {
22 ZSpider {
24 id: usize,
25 phase: f64,
26 arity: usize,
28 },
29 XSpider {
31 id: usize,
32 phase: f64,
33 arity: usize,
34 },
35 Hadamard {
37 id: usize,
38 },
39 Input {
41 id: usize,
42 qubit: u32,
43 },
44 Output {
45 id: usize,
46 qubit: u32,
47 },
48}
49
50impl ZXNode {
51 #[must_use]
52 pub const fn id(&self) -> usize {
53 match self {
54 Self::ZSpider { id, .. } => *id,
55 Self::XSpider { id, .. } => *id,
56 Self::Hadamard { id } => *id,
57 Self::Input { id, .. } => *id,
58 Self::Output { id, .. } => *id,
59 }
60 }
61
62 #[must_use]
63 pub const fn phase(&self) -> f64 {
64 match self {
65 Self::ZSpider { phase, .. } | Self::XSpider { phase, .. } => *phase,
66 _ => 0.0,
67 }
68 }
69
70 pub const fn set_phase(&mut self, new_phase: f64) {
71 match self {
72 Self::ZSpider { phase, .. } | Self::XSpider { phase, .. } => *phase = new_phase,
73 _ => {}
74 }
75 }
76}
77
/// A connection between two nodes of a ZX-diagram.
///
/// `ZXDiagram::add_edge` records the connection in the adjacency lists of
/// both endpoints, so `source`/`target` only name the two ends.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZXEdge {
    /// Identifier of the first endpoint.
    pub source: usize,
    /// Identifier of the second endpoint.
    pub target: usize,
    /// `true` for a Hadamard edge, `false` for a plain wire.
    pub is_hadamard: bool,
}
86
87#[derive(Debug, Clone)]
89pub struct ZXDiagram {
90 pub nodes: HashMap<usize, ZXNode>,
92 pub edges: Vec<ZXEdge>,
94 pub adjacency: HashMap<usize, Vec<usize>>,
96 pub inputs: HashMap<u32, usize>,
98 pub outputs: HashMap<u32, usize>,
100 next_id: usize,
102}
103
104impl Default for ZXDiagram {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl ZXDiagram {
111 #[must_use]
113 pub fn new() -> Self {
114 Self {
115 nodes: HashMap::new(),
116 edges: Vec::new(),
117 adjacency: HashMap::new(),
118 inputs: HashMap::new(),
119 outputs: HashMap::new(),
120 next_id: 0,
121 }
122 }
123
124 pub fn add_node(&mut self, node: ZXNode) -> usize {
126 let id = self.next_id;
127 self.next_id += 1;
128
129 let node_with_id = match node {
130 ZXNode::ZSpider { phase, arity, .. } => ZXNode::ZSpider { id, phase, arity },
131 ZXNode::XSpider { phase, arity, .. } => ZXNode::XSpider { id, phase, arity },
132 ZXNode::Hadamard { .. } => ZXNode::Hadamard { id },
133 ZXNode::Input { qubit, .. } => ZXNode::Input { id, qubit },
134 ZXNode::Output { qubit, .. } => ZXNode::Output { id, qubit },
135 };
136
137 self.nodes.insert(id, node_with_id);
138 self.adjacency.insert(id, Vec::new());
139 id
140 }
141
142 pub fn add_edge(&mut self, source: usize, target: usize, is_hadamard: bool) {
144 let edge = ZXEdge {
145 source,
146 target,
147 is_hadamard,
148 };
149 self.edges.push(edge);
150
151 self.adjacency.entry(source).or_default().push(target);
153 self.adjacency.entry(target).or_default().push(source);
154 }
155
156 pub fn initialize_boundaries(&mut self, num_qubits: usize) {
158 for i in 0..num_qubits {
159 let qubit = i as u32;
160
161 let input_id = self.add_node(ZXNode::Input { id: 0, qubit });
162 let output_id = self.add_node(ZXNode::Output { id: 0, qubit });
163
164 self.inputs.insert(qubit, input_id);
165 self.outputs.insert(qubit, output_id);
166 }
167 }
168
169 #[must_use]
171 pub fn neighbors(&self, node_id: usize) -> &[usize] {
172 self.adjacency
173 .get(&node_id)
174 .map_or(&[], std::vec::Vec::as_slice)
175 }
176
177 pub fn spider_fusion(&mut self) -> bool {
180 let mut changed = false;
181 let mut to_remove = Vec::new();
182 let mut to_update = Vec::new();
183
184 for edge in &self.edges {
185 if !edge.is_hadamard {
186 if let (Some(node1), Some(node2)) =
187 (self.nodes.get(&edge.source), self.nodes.get(&edge.target))
188 {
189 match (node1, node2) {
191 (
192 ZXNode::ZSpider {
193 id: id1,
194 phase: phase1,
195 ..
196 },
197 ZXNode::ZSpider {
198 id: id2,
199 phase: phase2,
200 ..
201 },
202 )
203 | (
204 ZXNode::XSpider {
205 id: id1,
206 phase: phase1,
207 ..
208 },
209 ZXNode::XSpider {
210 id: id2,
211 phase: phase2,
212 ..
213 },
214 ) => {
215 let new_phase = (phase1 + phase2) % (2.0 * PI);
217 to_update.push((*id1, new_phase));
218 to_remove.push(*id2);
219 changed = true;
220 }
221 _ => {}
222 }
223 }
224 }
225 }
226
227 for (id, new_phase) in to_update {
229 if let Some(node) = self.nodes.get_mut(&id) {
230 node.set_phase(new_phase);
231 }
232 }
233
234 for id in to_remove {
236 self.remove_node(id);
237 }
238
239 changed
240 }
241
242 pub fn identity_removal(&mut self) -> bool {
245 let mut changed = false;
246 let mut to_remove = Vec::new();
247
248 for (id, node) in &self.nodes {
249 match node {
250 ZXNode::ZSpider { phase, arity, .. } | ZXNode::XSpider { phase, arity, .. }
251 if *arity == 2 && phase.abs() < 1e-10 =>
252 {
253 to_remove.push(*id);
254 }
255 _ => {}
256 }
257 }
258
259 for id in to_remove {
260 let neighbors: Vec<_> = self.neighbors(id).to_vec();
262 if neighbors.len() == 2 {
263 self.add_edge(neighbors[0], neighbors[1], false);
264 changed = true;
265 }
266 self.remove_node(id);
267 }
268
269 changed
270 }
271
272 pub const fn pi_commutation(&mut self) -> bool {
275 false
278 }
279
280 pub fn hadamard_cancellation(&mut self) -> bool {
283 let mut changed = false;
284 let mut to_remove = Vec::new();
285
286 for edge in &self.edges {
288 if let (Some(ZXNode::Hadamard { id: id1 }), Some(ZXNode::Hadamard { id: id2 })) =
289 (self.nodes.get(&edge.source), self.nodes.get(&edge.target))
290 {
291 to_remove.push(*id1);
293 to_remove.push(*id2);
294 changed = true;
295 }
296 }
297
298 for id in to_remove {
299 self.remove_node(id);
300 }
301
302 changed
303 }
304
305 fn remove_node(&mut self, node_id: usize) {
307 self.nodes.remove(&node_id);
309
310 self.adjacency.remove(&node_id);
312
313 for adj_list in self.adjacency.values_mut() {
315 adj_list.retain(|&id| id != node_id);
316 }
317
318 self.edges
320 .retain(|edge| edge.source != node_id && edge.target != node_id);
321 }
322
323 #[must_use]
325 pub fn t_count(&self) -> usize {
326 self.nodes
327 .values()
328 .filter(|node| {
329 let phase = node.phase();
330 (phase - PI / 4.0).abs() < 1e-10
331 || (phase - 3.0 * PI / 4.0).abs() < 1e-10
332 || (phase - 5.0 * PI / 4.0).abs() < 1e-10
333 || (phase - 7.0 * PI / 4.0).abs() < 1e-10
334 })
335 .count()
336 }
337
338 pub fn optimize(&mut self) -> ZXOptimizationResult {
340 let initial_node_count = self.nodes.len();
341 let initial_t_count = self.t_count();
342
343 let mut iterations = 0;
344 let max_iterations = 100;
345
346 while iterations < max_iterations {
347 let mut changed = false;
348
349 changed |= self.spider_fusion();
351 changed |= self.identity_removal();
352 changed |= self.hadamard_cancellation();
353 changed |= self.pi_commutation();
354
355 if !changed {
356 break;
357 }
358 iterations += 1;
359 }
360
361 let final_node_count = self.nodes.len();
362 let final_t_count = self.t_count();
363
364 ZXOptimizationResult {
365 iterations,
366 initial_node_count,
367 final_node_count,
368 initial_t_count,
369 final_t_count,
370 converged: iterations < max_iterations,
371 }
372 }
373}
374
/// Statistics reported by one run of `ZXDiagram::optimize`.
#[derive(Debug, Clone)]
pub struct ZXOptimizationResult {
    /// Number of rewrite rounds that changed the diagram.
    pub iterations: usize,
    /// Node count before optimisation.
    pub initial_node_count: usize,
    /// Node count after optimisation.
    pub final_node_count: usize,
    /// T-count (see `ZXDiagram::t_count`) before optimisation.
    pub initial_t_count: usize,
    /// T-count (see `ZXDiagram::t_count`) after optimisation.
    pub final_t_count: usize,
    /// `true` when the run stopped before reaching the iteration cap.
    pub converged: bool,
}
385
/// Configuration of the ZX-calculus optimisation pass.
///
/// `Default` enables every rule and allows 100 rewrite rounds.
#[derive(Debug, Clone)]
pub struct ZXOptimizer {
    /// Upper bound on the number of rewrite rounds.
    pub max_iterations: usize,
    /// Whether the spider-fusion rule may be applied.
    pub enable_spider_fusion: bool,
    /// Whether the identity-removal rule may be applied.
    pub enable_identity_removal: bool,
    /// Whether the π-commutation rule may be applied.
    pub enable_pi_commutation: bool,
    /// Whether the Hadamard-cancellation rule may be applied.
    pub enable_hadamard_cancellation: bool,
}
396
397impl Default for ZXOptimizer {
398 fn default() -> Self {
399 Self {
400 max_iterations: 100,
401 enable_spider_fusion: true,
402 enable_identity_removal: true,
403 enable_pi_commutation: true,
404 enable_hadamard_cancellation: true,
405 }
406 }
407}
408
409impl ZXOptimizer {
410 #[must_use]
412 pub fn new() -> Self {
413 Self::default()
414 }
415
416 pub fn circuit_to_zx<const N: usize>(&self, circuit: &Circuit<N>) -> QuantRS2Result<ZXDiagram> {
418 let mut diagram = ZXDiagram::new();
419 diagram.initialize_boundaries(N);
420
421 let mut qubit_wires = HashMap::new();
423 for i in 0..N {
424 let qubit = i as u32;
425 if let Some(&input_id) = diagram.inputs.get(&qubit) {
426 qubit_wires.insert(qubit, input_id);
427 }
428 }
429
430 for gate in circuit.gates() {
432 self.gate_to_zx(gate.as_ref(), &mut diagram, &mut qubit_wires)?;
433 }
434
435 for i in 0..N {
437 let qubit = i as u32;
438 if let (Some(&last_node), Some(&output_id)) =
439 (qubit_wires.get(&qubit), diagram.outputs.get(&qubit))
440 {
441 diagram.add_edge(last_node, output_id, false);
442 }
443 }
444
445 Ok(diagram)
446 }
447
448 fn gate_to_zx(
450 &self,
451 gate: &dyn GateOp,
452 diagram: &mut ZXDiagram,
453 qubit_wires: &mut HashMap<u32, usize>,
454 ) -> QuantRS2Result<()> {
455 let gate_name = gate.name();
456 let qubits = gate.qubits();
457
458 match gate_name {
459 "H" => {
460 let qubit = qubits[0].id();
462 let h_node = diagram.add_node(ZXNode::Hadamard { id: 0 });
463
464 if let Some(&prev_node) = qubit_wires.get(&qubit) {
465 diagram.add_edge(prev_node, h_node, false);
466 }
467 qubit_wires.insert(qubit, h_node);
468 }
469 "X" => {
470 let qubit = qubits[0].id();
472 let x_node = diagram.add_node(ZXNode::ZSpider {
473 id: 0,
474 phase: PI,
475 arity: 2,
476 });
477
478 if let Some(&prev_node) = qubit_wires.get(&qubit) {
479 diagram.add_edge(prev_node, x_node, false);
480 }
481 qubit_wires.insert(qubit, x_node);
482 }
483 "Y" => {
484 let qubit = qubits[0].id();
486 let y_node = diagram.add_node(ZXNode::ZSpider {
487 id: 0,
488 phase: PI,
489 arity: 2,
490 });
491
492 if let Some(&prev_node) = qubit_wires.get(&qubit) {
493 diagram.add_edge(prev_node, y_node, false);
494 }
495 qubit_wires.insert(qubit, y_node);
496 }
497 "Z" => {
498 let qubit = qubits[0].id();
500 let z_node = diagram.add_node(ZXNode::ZSpider {
501 id: 0,
502 phase: PI,
503 arity: 2,
504 });
505
506 if let Some(&prev_node) = qubit_wires.get(&qubit) {
507 diagram.add_edge(prev_node, z_node, false);
508 }
509 qubit_wires.insert(qubit, z_node);
510 }
511 "RZ" => {
512 let qubit = qubits[0].id();
514
515 let angle = self.extract_rotation_angle(gate);
517 let rz_node = diagram.add_node(ZXNode::ZSpider {
518 id: 0,
519 phase: angle,
520 arity: 2,
521 });
522
523 if let Some(&prev_node) = qubit_wires.get(&qubit) {
524 diagram.add_edge(prev_node, rz_node, false);
525 }
526 qubit_wires.insert(qubit, rz_node);
527 }
528 "CNOT" => {
529 let control_qubit = qubits[0].id();
531 let target_qubit = qubits[1].id();
532
533 let control_spider = diagram.add_node(ZXNode::ZSpider {
534 id: 0,
535 phase: 0.0,
536 arity: 3,
537 });
538 let target_spider = diagram.add_node(ZXNode::XSpider {
539 id: 0,
540 phase: 0.0,
541 arity: 3,
542 });
543
544 if let Some(&prev_control) = qubit_wires.get(&control_qubit) {
546 diagram.add_edge(prev_control, control_spider, false);
547 }
548
549 if let Some(&prev_target) = qubit_wires.get(&target_qubit) {
551 diagram.add_edge(prev_target, target_spider, false);
552 }
553
554 diagram.add_edge(control_spider, target_spider, false);
556
557 qubit_wires.insert(control_qubit, control_spider);
558 qubit_wires.insert(target_qubit, target_spider);
559 }
560 _ => {
561 for qubit_id in qubits {
563 let qubit = qubit_id.id();
564 let identity_node = diagram.add_node(ZXNode::ZSpider {
565 id: 0,
566 phase: 0.0,
567 arity: 2,
568 });
569
570 if let Some(&prev_node) = qubit_wires.get(&qubit) {
571 diagram.add_edge(prev_node, identity_node, false);
572 }
573 qubit_wires.insert(qubit, identity_node);
574 }
575 }
576 }
577
578 Ok(())
579 }
580
581 fn extract_rotation_angle(&self, gate: &dyn GateOp) -> f64 {
583 PI / 4.0 }
587
588 pub fn optimize_circuit<const N: usize>(
590 &self,
591 circuit: &Circuit<N>,
592 ) -> QuantRS2Result<OptimizedZXResult<N>> {
593 let mut diagram = self.circuit_to_zx(circuit)?;
595
596 let optimization_result = diagram.optimize();
598
599 let optimized_circuit = self.zx_to_circuit(&diagram)?;
601
602 Ok(OptimizedZXResult {
603 original_circuit: circuit.clone(),
604 optimized_circuit,
605 diagram,
606 optimization_stats: optimization_result,
607 })
608 }
609
610 fn zx_to_circuit<const N: usize>(&self, diagram: &ZXDiagram) -> QuantRS2Result<Circuit<N>> {
612 let mut circuit = Circuit::<N>::new();
620
621 for i in 0..N {
623 }
625
626 Ok(circuit)
627 }
628}
629
630#[derive(Debug)]
632pub struct OptimizedZXResult<const N: usize> {
633 pub original_circuit: Circuit<N>,
634 pub optimized_circuit: Circuit<N>,
635 pub diagram: ZXDiagram,
636 pub optimization_stats: ZXOptimizationResult,
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// Boundary initialisation registers one input and one output per qubit.
    #[test]
    fn test_zx_diagram_creation() {
        let mut zx = ZXDiagram::new();
        zx.initialize_boundaries(2);

        assert_eq!(zx.inputs.len(), 2);
        assert_eq!(zx.outputs.len(), 2);
    }

    /// Two Z-spiders joined by a plain wire fuse into one with summed phase.
    #[test]
    fn test_spider_fusion() {
        let mut zx = ZXDiagram::new();

        let first = zx.add_node(ZXNode::ZSpider {
            id: 0,
            phase: PI / 4.0,
            arity: 2,
        });
        let second = zx.add_node(ZXNode::ZSpider {
            id: 0,
            phase: PI / 8.0,
            arity: 2,
        });
        zx.add_edge(first, second, false);

        assert!(zx.spider_fusion());
        assert_eq!(zx.nodes.len(), 1);

        let survivor = zx
            .nodes
            .values()
            .next()
            .expect("Expected at least one remaining node after fusion");
        assert!((survivor.phase() - (PI / 4.0 + PI / 8.0)).abs() < 1e-10);
    }

    /// A phase-free arity-2 spider between two other nodes is removed.
    #[test]
    fn test_identity_removal() {
        let mut zx = ZXDiagram::new();

        let middle = zx.add_node(ZXNode::ZSpider {
            id: 0,
            phase: 0.0,
            arity: 2,
        });
        let left = zx.add_node(ZXNode::ZSpider {
            id: 0,
            phase: PI / 4.0,
            arity: 2,
        });
        let right = zx.add_node(ZXNode::ZSpider {
            id: 0,
            phase: PI / 2.0,
            arity: 2,
        });

        zx.add_edge(left, middle, false);
        zx.add_edge(middle, right, false);

        let nodes_before = zx.nodes.len();
        let changed = zx.identity_removal();

        assert!(changed);
        assert_eq!(zx.nodes.len(), nodes_before - 1);
    }

    /// Converting H followed by CNOT yields a non-trivial diagram.
    #[test]
    fn test_circuit_to_zx_conversion() {
        let optimizer = ZXOptimizer::new();

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("Failed to add Hadamard gate");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("Failed to add CNOT gate");

        let zx = optimizer
            .circuit_to_zx(&circuit)
            .expect("Failed to convert circuit to ZX diagram");

        assert!(zx.nodes.len() >= 4);
        assert!(!zx.edges.is_empty());
    }

    /// Optimising H·H never increases the node count.
    #[test]
    fn test_zx_optimization() {
        let optimizer = ZXOptimizer::new();

        let mut circuit = Circuit::<1>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("Failed to add first Hadamard gate");
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("Failed to add second Hadamard gate");

        let outcome = optimizer
            .optimize_circuit(&circuit)
            .expect("Failed to optimize circuit");

        let stats = &outcome.optimization_stats;
        assert!(stats.final_node_count <= stats.initial_node_count);
    }
}