1use crate::builder::Circuit;
7use scirs2_core::Complex64;
8use quantrs2_core::error::QuantRS2Error;
9use quantrs2_core::qubit::QubitId;
10use std::collections::{HashMap, HashSet, VecDeque};
11
12#[cfg(feature = "scirs")]
14use scirs2_core::sparse::{CsrMatrix, SparseMatrix};
15#[cfg(feature = "scirs")]
16use scirs2_optimize::graph::{
17 find_critical_path, graph_coloring, minimum_feedback_arc_set, topological_sort_weighted,
18};
19
20fn matrix_multiply_2x2(a: &[Vec<Complex64>], b: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
22 vec![
23 vec![
24 a[0][0] * b[0][0] + a[0][1] * b[1][0],
25 a[0][0] * b[0][1] + a[0][1] * b[1][1],
26 ],
27 vec![
28 a[1][0] * b[0][0] + a[1][1] * b[1][0],
29 a[1][0] * b[0][1] + a[1][1] * b[1][1],
30 ],
31 ]
32}
33
34#[derive(Debug, Clone, PartialEq)]
36pub struct GraphGate {
37 pub id: usize,
38 pub gate_type: String,
39 pub qubits: Vec<QubitId>,
40 pub params: Vec<f64>,
41 pub matrix: Option<Vec<Vec<Complex64>>>,
42}
43
/// Kind of relationship recorded between two gates of a `CircuitDAG`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EdgeType {
    /// Consecutive gates on a shared qubit; their relative order must be preserved.
    DataDependency,
    /// Gates that do not commute and are not already linked by a dependency path.
    NonCommuting,
    /// Gates found to be reorderable relative to each other.
    Parallelizable,
}
54
55pub struct CircuitDAG {
57 nodes: Vec<GraphGate>,
58 edges: HashMap<(usize, usize), EdgeType>,
59 qubit_chains: HashMap<u32, Vec<usize>>, }
61
62impl CircuitDAG {
63 pub fn new() -> Self {
65 Self {
66 nodes: Vec::new(),
67 edges: HashMap::new(),
68 qubit_chains: HashMap::new(),
69 }
70 }
71
72 pub fn add_gate(&mut self, gate: GraphGate) -> usize {
74 let gate_id = self.nodes.len();
75
76 for qubit in &gate.qubits {
78 self.qubit_chains
79 .entry(qubit.id())
80 .or_default()
81 .push(gate_id);
82 }
83
84 for qubit in &gate.qubits {
86 if let Some(chain) = self.qubit_chains.get(&qubit.id()) {
87 if chain.len() > 1 {
88 let prev_gate = chain[chain.len() - 2];
89 self.edges
90 .insert((prev_gate, gate_id), EdgeType::DataDependency);
91 }
92 }
93 }
94
95 self.nodes.push(gate);
96 gate_id
97 }
98
99 fn gates_commute(&self, g1: &GraphGate, g2: &GraphGate) -> bool {
101 let qubits1: HashSet<_> = g1.qubits.iter().map(|q| q.id()).collect();
103 let qubits2: HashSet<_> = g2.qubits.iter().map(|q| q.id()).collect();
104
105 if qubits1.is_disjoint(&qubits2) {
106 return true;
107 }
108
109 match (g1.gate_type.as_str(), g2.gate_type.as_str()) {
111 ("z", "z") | ("rz", "rz") | ("z", "rz") | ("rz", "z") => true,
113 ("cnot", "cnot") => {
115 if g1.qubits.len() == 2 && g2.qubits.len() == 2 {
116 let same_control = g1.qubits[0] == g2.qubits[0];
117 let same_target = g1.qubits[1] == g2.qubits[1];
118 same_control && same_target } else {
120 false
121 }
122 }
123 _ => false,
124 }
125 }
126
127 pub fn compute_commutation_edges(&mut self) {
129 #[cfg(feature = "scirs")]
130 {
131 if self.scirs_compute_commutation_edges() {
132 return;
133 }
134 }
135
136 self.standard_compute_commutation_edges();
137 }
138
139 #[cfg(feature = "scirs")]
140 fn scirs_compute_commutation_edges(&mut self) -> bool {
142 let n = self.nodes.len();
143 if n == 0 {
144 return true;
145 }
146
147 let mut interference = vec![vec![false; n]; n];
149
150 for i in 0..n {
151 for j in i + 1..n {
152 let g1 = &self.nodes[i];
153 let g2 = &self.nodes[j];
154
155 if !self.gates_commute(g1, g2) && !self.has_path(i, j) && !self.has_path(j, i) {
157 interference[i][j] = true;
158 interference[j][i] = true;
159 }
160 }
161 }
162
163 if let Ok(coloring) = graph_coloring(&interference) {
165 for i in 0..n {
167 for j in i + 1..n {
168 if coloring[i] == coloring[j] && !interference[i][j] {
169 self.edges.insert((i, j), EdgeType::Parallelizable);
171 }
172 }
173 }
174 return true;
175 }
176
177 false
178 }
179
180 fn standard_compute_commutation_edges(&mut self) {
182 let n = self.nodes.len();
183
184 for i in 0..n {
185 for j in i + 1..n {
186 let g1 = &self.nodes[i];
187 let g2 = &self.nodes[j];
188
189 if self.edges.contains_key(&(i, j)) || self.edges.contains_key(&(j, i)) {
191 continue;
192 }
193
194 if !self.gates_commute(g1, g2) {
196 if !self.has_path(i, j) && !self.has_path(j, i) {
198 self.edges.insert((i, j), EdgeType::NonCommuting);
199 }
200 } else if g1.qubits.iter().any(|q| g2.qubits.contains(q)) {
201 self.edges.insert((i, j), EdgeType::Parallelizable);
203 }
204 }
205 }
206 }
207
208 #[cfg(feature = "scirs")]
209 pub fn optimize_with_feedback_arc_set(&self) -> Option<Vec<usize>> {
211 let n = self.nodes.len();
212 if n == 0 {
213 return Some(vec![]);
214 }
215
216 let mut edges = Vec::new();
218 let mut weights = Vec::new();
219
220 for ((u, v), edge_type) in &self.edges {
221 if *edge_type == EdgeType::NonCommuting {
222 edges.push((*u, *v));
223 let weight = self.gate_weight(*u) * self.gate_weight(*v);
225 weights.push(weight);
226 }
227 }
228
229 if let Ok(feedback_arcs) = minimum_feedback_arc_set(&edges, &weights) {
231 let mut filtered_edges = edges.clone();
233 for &arc_idx in &feedback_arcs {
234 filtered_edges[arc_idx] = (n, n); }
236 filtered_edges.retain(|&(u, v)| u < n && v < n);
237
238 let mut new_dag = CircuitDAG::new();
240 new_dag.nodes = self.nodes.clone();
241 for (u, v) in filtered_edges {
242 new_dag.edges.insert((u, v), EdgeType::DataDependency);
243 }
244
245 return Some(new_dag.optimized_topological_sort());
246 }
247
248 None
249 }
250
251 fn has_path(&self, src: usize, dst: usize) -> bool {
253 let mut visited = vec![false; self.nodes.len()];
254 let mut queue = VecDeque::new();
255
256 queue.push_back(src);
257 visited[src] = true;
258
259 while let Some(node) = queue.pop_front() {
260 if node == dst {
261 return true;
262 }
263
264 for ((u, v), edge_type) in &self.edges {
265 if *u == node && !visited[*v] && *edge_type == EdgeType::DataDependency {
266 visited[*v] = true;
267 queue.push_back(*v);
268 }
269 }
270 }
271
272 false
273 }
274
275 pub fn optimized_topological_sort(&self) -> Vec<usize> {
277 #[cfg(feature = "scirs")]
278 {
279 if let Some(order) = self.scirs_topological_sort() {
281 return order;
282 }
283 }
284
285 self.standard_topological_sort()
287 }
288
289 #[cfg(feature = "scirs")]
290 fn scirs_topological_sort(&self) -> Option<Vec<usize>> {
292 let n = self.nodes.len();
293 if n == 0 {
294 return Some(vec![]);
295 }
296
297 let mut row_indices = Vec::new();
299 let mut col_indices = Vec::new();
300 let mut values = Vec::new();
301
302 for ((u, v), edge_type) in &self.edges {
303 if *edge_type == EdgeType::DataDependency {
304 row_indices.push(*u);
305 col_indices.push(*v);
306 let weight = self.gate_weight(*u) + self.gate_weight(*v);
308 values.push(weight);
309 }
310 }
311
312 let matrix = CsrMatrix::from_triplets(n, n, &row_indices, &col_indices, &values);
314
315 if let Ok(order) = topological_sort_weighted(&matrix, |i| self.gate_priority(i)) {
317 if let Ok(critical) = find_critical_path(&matrix, &order) {
319 return Some(self.optimize_order_by_critical_path(order, critical));
321 }
322 return Some(order);
323 }
324
325 None
326 }
327
328 fn gate_weight(&self, gate_id: usize) -> f64 {
330 let gate = &self.nodes[gate_id];
331 match gate.gate_type.as_str() {
332 "cnot" | "cz" | "swap" => 10.0,
334 "rzz" | "rxx" | "ryy" => 15.0,
335 "rx" | "ry" | "rz" => 2.0,
337 "h" | "s" | "t" | "x" | "y" | "z" => 1.0,
339 _ => 5.0,
341 }
342 }
343
344 fn gate_priority(&self, gate_id: usize) -> f64 {
346 let parallelism_score = self.count_parallel_successors(gate_id) as f64;
348 let weight = self.gate_weight(gate_id);
350 parallelism_score / weight
352 }
353
354 fn count_parallel_successors(&self, gate_id: usize) -> usize {
356 let mut count = 0;
357 for ((u, _v), edge_type) in &self.edges {
358 if *u == gate_id && *edge_type == EdgeType::Parallelizable {
359 count += 1;
360 }
361 }
362 count
363 }
364
365 #[cfg(feature = "scirs")]
366 fn optimize_order_by_critical_path(
368 &self,
369 order: Vec<usize>,
370 critical: Vec<usize>,
371 ) -> Vec<usize> {
372 let mut optimized = Vec::new();
373 let mut scheduled = HashSet::new();
374 let critical_set: HashSet<_> = critical.into_iter().collect();
375
376 for &gate_id in &order {
378 if critical_set.contains(&gate_id) && !scheduled.contains(&gate_id) {
379 optimized.push(gate_id);
380 scheduled.insert(gate_id);
381 }
382 }
383
384 for &gate_id in &order {
386 if !scheduled.contains(&gate_id) {
387 optimized.push(gate_id);
388 scheduled.insert(gate_id);
389 }
390 }
391
392 optimized
393 }
394
395 fn standard_topological_sort(&self) -> Vec<usize> {
397 let n = self.nodes.len();
398 let mut in_degree = vec![0; n];
399 let mut adj_list: HashMap<usize, Vec<usize>> = HashMap::new();
400
401 for ((u, v), edge_type) in &self.edges {
403 if *edge_type == EdgeType::DataDependency {
404 adj_list.entry(*u).or_default().push(*v);
405 in_degree[*v] += 1;
406 }
407 }
408
409 let mut ready: Vec<usize> = Vec::new();
411 for (i, °ree) in in_degree.iter().enumerate() {
412 if degree == 0 {
413 ready.push(i);
414 }
415 }
416
417 let mut result = Vec::new();
418 let mut layer_qubits: HashSet<u32> = HashSet::new();
419
420 while !ready.is_empty() {
421 ready.sort_by_key(|&i| self.nodes[i].qubits.len());
423
424 let mut next_layer = Vec::new();
426 let mut used = vec![false; ready.len()];
427
428 for (idx, &gate_id) in ready.iter().enumerate() {
429 if used[idx] {
430 continue;
431 }
432
433 let gate = &self.nodes[gate_id];
434 let gate_qubits: HashSet<_> = gate.qubits.iter().map(|q| q.id()).collect();
435
436 if gate_qubits.is_disjoint(&layer_qubits) {
438 next_layer.push(gate_id);
439 layer_qubits.extend(&gate_qubits);
440 used[idx] = true;
441 }
442 }
443
444 if next_layer.is_empty() && !ready.is_empty() {
446 next_layer.push(ready[0]);
447 used[0] = true;
448 }
449
450 ready.retain(|&g| !next_layer.contains(&g));
452
453 for &gate_id in &next_layer {
455 result.push(gate_id);
456
457 if let Some(neighbors) = adj_list.get(&gate_id) {
458 for &neighbor in neighbors {
459 in_degree[neighbor] -= 1;
460 if in_degree[neighbor] == 0 {
461 ready.push(neighbor);
462 }
463 }
464 }
465 }
466
467 layer_qubits.clear();
468 }
469
470 result
471 }
472}
473
474impl Default for CircuitDAG {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
/// Optimizes gate sequences by scheduling them through a `CircuitDAG` and merging
/// adjacent single-qubit gates.
#[derive(Debug, Clone)]
pub struct GraphOptimizer {
    /// Absolute tolerance used when matching a merged matrix against known gates.
    merge_threshold: f64,
    /// Look-ahead window size.
    // Stored for configuration but not yet read by any optimization pass, hence the allow.
    #[allow(dead_code)]
    max_lookahead: usize,
}
486
487impl GraphOptimizer {
488 pub fn new() -> Self {
490 Self {
491 merge_threshold: 1e-6,
492 max_lookahead: 10,
493 }
494 }
495
496 pub fn circuit_to_dag<const N: usize>(
498 &self,
499 circuit: &Circuit<N>,
500 ) -> Result<CircuitDAG, QuantRS2Error> {
501 let mut dag = CircuitDAG::new();
502
503 for (gate_id, gate) in circuit.gates().iter().enumerate() {
505 let graph_gate = GraphGate {
506 id: gate_id,
507 gate_type: gate.name().to_string(),
508 qubits: gate.qubits(),
509 params: if gate.is_parameterized() {
510 match gate.name() {
512 "RX" | "RY" | "RZ" => {
513 if let Some(rx_gate) =
515 gate.as_any()
516 .downcast_ref::<quantrs2_core::gate::single::RotationX>()
517 {
518 vec![rx_gate.theta]
519 } else if let Some(ry_gate) =
520 gate.as_any()
521 .downcast_ref::<quantrs2_core::gate::single::RotationY>()
522 {
523 vec![ry_gate.theta]
524 } else if let Some(rz_gate) =
525 gate.as_any()
526 .downcast_ref::<quantrs2_core::gate::single::RotationZ>()
527 {
528 vec![rz_gate.theta]
529 } else {
530 vec![] }
532 }
533 "CRX" | "CRY" | "CRZ" => {
534 if let Some(crx_gate) = gate
536 .as_any()
537 .downcast_ref::<quantrs2_core::gate::multi::CRX>()
538 {
539 vec![crx_gate.theta]
540 } else if let Some(cry_gate) =
541 gate.as_any()
542 .downcast_ref::<quantrs2_core::gate::multi::CRY>()
543 {
544 vec![cry_gate.theta]
545 } else if let Some(crz_gate) =
546 gate.as_any()
547 .downcast_ref::<quantrs2_core::gate::multi::CRZ>()
548 {
549 vec![crz_gate.theta]
550 } else {
551 vec![]
552 }
553 }
554 _ => vec![], }
556 } else {
557 vec![] },
559 matrix: None, };
561
562 dag.add_gate(graph_gate);
563 }
564
565 Ok(dag)
566 }
567
568 pub fn optimize_gate_sequence(&self, gates: Vec<GraphGate>) -> Vec<GraphGate> {
570 let mut dag = CircuitDAG::new();
571
572 for gate in gates {
574 dag.add_gate(gate);
575 }
576
577 dag.compute_commutation_edges();
579
580 #[cfg(feature = "scirs")]
582 {
583 if dag.nodes.len() > 10 {
585 if let Some(optimized_order) = dag.optimize_with_feedback_arc_set() {
586 return optimized_order
587 .iter()
588 .map(|&i| dag.nodes[i].clone())
589 .collect();
590 }
591 }
592 }
593
594 let order = dag.optimized_topological_sort();
596
597 self.merge_gates_in_sequence(order.iter().map(|&i| dag.nodes[i].clone()).collect())
599 }
600
601 fn merge_gates_in_sequence(&self, gates: Vec<GraphGate>) -> Vec<GraphGate> {
603 if gates.is_empty() {
604 return gates;
605 }
606
607 let mut merged = Vec::new();
608 let mut i = 0;
609
610 while i < gates.len() {
611 if i + 1 < gates.len() {
612 if let Some(merged_gate) = self.try_merge_gates(&gates[i], &gates[i + 1]) {
614 merged.push(merged_gate);
615 i += 2; continue;
617 }
618 }
619 merged.push(gates[i].clone());
620 i += 1;
621 }
622
623 merged
624 }
625
626 fn try_merge_gates(&self, g1: &GraphGate, g2: &GraphGate) -> Option<GraphGate> {
628 if g1.qubits.len() == 1 && g2.qubits.len() == 1 && g1.qubits[0] == g2.qubits[0] {
630 self.merge_single_qubit_gates(g1, g2)
631 } else {
632 None
633 }
634 }
635
636 pub fn merge_single_qubit_gates(&self, g1: &GraphGate, g2: &GraphGate) -> Option<GraphGate> {
638 if g1.qubits.len() != 1 || g2.qubits.len() != 1 || g1.qubits[0] != g2.qubits[0] {
640 return None;
641 }
642
643 let m1 = g1.matrix.as_ref()?;
645 let m2 = g2.matrix.as_ref()?;
646
647 let combined = matrix_multiply_2x2(m2, m1);
649
650 if let Some((gate_type, params)) = self.identify_gate(&combined) {
652 Some(GraphGate {
653 id: g1.id, gate_type,
655 qubits: g1.qubits.clone(),
656 params,
657 matrix: Some(combined),
658 })
659 } else {
660 Some(GraphGate {
662 id: g1.id,
663 gate_type: "u".to_string(),
664 qubits: g1.qubits.clone(),
665 params: vec![],
666 matrix: Some(combined),
667 })
668 }
669 }
670
671 fn identify_gate(&self, matrix: &[Vec<Complex64>]) -> Option<(String, Vec<f64>)> {
673 let tolerance = self.merge_threshold;
674
675 if self.is_pauli_x(matrix, tolerance) {
677 return Some(("x".to_string(), vec![]));
678 }
679 if self.is_pauli_y(matrix, tolerance) {
680 return Some(("y".to_string(), vec![]));
681 }
682 if self.is_pauli_z(matrix, tolerance) {
683 return Some(("z".to_string(), vec![]));
684 }
685
686 if self.is_hadamard(matrix, tolerance) {
688 return Some(("h".to_string(), vec![]));
689 }
690
691 if let Some(angle) = self.is_rz(matrix, tolerance) {
693 return Some(("rz".to_string(), vec![angle]));
694 }
695
696 None
697 }
698
699 fn is_pauli_x(&self, matrix: &[Vec<Complex64>], tol: f64) -> bool {
700 matrix.len() == 2
701 && matrix[0].len() == 2
702 && (matrix[0][0].norm() < tol)
703 && (matrix[0][1] - Complex64::new(1.0, 0.0)).norm() < tol
704 && (matrix[1][0] - Complex64::new(1.0, 0.0)).norm() < tol
705 && (matrix[1][1].norm() < tol)
706 }
707
708 fn is_pauli_y(&self, matrix: &[Vec<Complex64>], tol: f64) -> bool {
709 matrix.len() == 2
710 && matrix[0].len() == 2
711 && (matrix[0][0].norm() < tol)
712 && (matrix[0][1] - Complex64::new(0.0, -1.0)).norm() < tol
713 && (matrix[1][0] - Complex64::new(0.0, 1.0)).norm() < tol
714 && (matrix[1][1].norm() < tol)
715 }
716
717 fn is_pauli_z(&self, matrix: &[Vec<Complex64>], tol: f64) -> bool {
718 matrix.len() == 2
719 && matrix[0].len() == 2
720 && (matrix[0][0] - Complex64::new(1.0, 0.0)).norm() < tol
721 && (matrix[0][1].norm() < tol)
722 && (matrix[1][0].norm() < tol)
723 && (matrix[1][1] - Complex64::new(-1.0, 0.0)).norm() < tol
724 }
725
726 fn is_hadamard(&self, matrix: &[Vec<Complex64>], tol: f64) -> bool {
727 let h_val = 1.0 / 2.0_f64.sqrt();
728 matrix.len() == 2
729 && matrix[0].len() == 2
730 && (matrix[0][0] - Complex64::new(h_val, 0.0)).norm() < tol
731 && (matrix[0][1] - Complex64::new(h_val, 0.0)).norm() < tol
732 && (matrix[1][0] - Complex64::new(h_val, 0.0)).norm() < tol
733 && (matrix[1][1] - Complex64::new(-h_val, 0.0)).norm() < tol
734 }
735
736 fn is_rz(&self, matrix: &[Vec<Complex64>], tol: f64) -> Option<f64> {
737 if matrix.len() != 2
738 || matrix[0].len() != 2
739 || matrix[0][1].norm() > tol
740 || matrix[1][0].norm() > tol
741 {
742 return None;
743 }
744
745 let phase1 = matrix[0][0].arg();
746 let phase2 = matrix[1][1].arg();
747
748 if (matrix[0][0].norm() - 1.0).abs() < tol && (matrix[1][1].norm() - 1.0).abs() < tol {
749 let angle = phase2 - phase1;
750 Some(angle)
751 } else {
752 None
753 }
754 }
755}
756
757impl Default for GraphOptimizer {
758 fn default() -> Self {
759 Self::new()
760 }
761}
762
/// Summary of how an optimization pass changed a circuit.
#[derive(Debug, Clone)]
pub struct OptimizationStats {
    /// Gate count before optimization.
    pub original_gate_count: usize,
    /// Gate count after optimization.
    pub optimized_gate_count: usize,
    /// Circuit depth before optimization.
    pub original_depth: usize,
    /// Circuit depth after optimization.
    pub optimized_depth: usize,
    /// Number of gates removed.
    pub gates_removed: usize,
    /// Number of gates merged.
    pub gates_merged: usize,
}

impl OptimizationStats {
    /// Percentage reduction in gate count relative to the original circuit.
    ///
    /// Returns `0.0` when the original circuit is empty. Returns a negative value when
    /// the optimized circuit has more gates than the original.
    pub fn improvement_percentage(&self) -> f64 {
        if self.original_gate_count == 0 {
            return 0.0;
        }

        let original = self.original_gate_count as f64;
        let optimized = self.optimized_gate_count as f64;
        // Subtract in floating point. `usize` subtraction would underflow (panic in
        // debug builds, wrap in release) whenever the pass increased the gate count.
        100.0 * (original - optimized) / original
    }
}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parameter-free, matrix-less `GraphGate` acting on `qubits`.
    fn make_gate(id: usize, gate_type: &str, qubits: Vec<QubitId>) -> GraphGate {
        GraphGate {
            id,
            gate_type: gate_type.to_string(),
            qubits,
            params: vec![],
            matrix: None,
        }
    }

    #[test]
    fn test_dag_construction() {
        let mut dag = CircuitDAG::new();
        dag.add_gate(make_gate(0, "h", vec![QubitId::new(0)]));
        dag.add_gate(make_gate(1, "cnot", vec![QubitId::new(0), QubitId::new(1)]));

        assert_eq!(dag.nodes.len(), 2);
        // Both gates touch qubit 0, so the second depends on the first.
        assert!(dag.edges.contains_key(&(0, 1)));
    }

    #[test]
    fn test_commutation_detection() {
        let first = make_gate(0, "z", vec![QubitId::new(0)]);
        let second = make_gate(1, "z", vec![QubitId::new(0)]);

        assert!(CircuitDAG::new().gates_commute(&first, &second));
    }

    #[test]
    fn test_gate_identification() {
        let optimizer = GraphOptimizer::new();
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let x_matrix = vec![vec![zero, one], vec![one, zero]];

        let (gate_type, _) = optimizer
            .identify_gate(&x_matrix)
            .expect("Failed to identify Pauli X gate");
        assert_eq!(gate_type, "x");
    }
}