1use std::collections::{HashMap, HashSet, VecDeque};
8use std::fmt;
9
10use quantrs2_core::{gate::GateOp, qubit::QubitId};
11
12use std::fmt::Write;
/// A single gate in a [`CircuitDag`] together with its dependency links.
#[derive(Debug, Clone)]
pub struct DagNode {
    /// Id of this node; equal to its index in [`CircuitDag::nodes`].
    pub id: usize,
    /// The gate represented by this node.
    pub gate: Box<dyn GateOp>,
    /// Ids of the nodes that last used one of this gate's qubits when it was
    /// added, i.e. the gates that must execute before this one.
    pub predecessors: Vec<usize>,
    /// Ids of the nodes that directly depend on this one.
    pub successors: Vec<usize>,
    /// Length (in edges) of the longest dependency chain ending at this node;
    /// nodes without predecessors have depth 0.
    pub depth: usize,
}
27
/// Reason a dependency edge exists between two [`DagNode`]s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Both gates act on the qubit with this id, which fixes their order.
    QubitDependency(u32),
    /// Classical-data dependency (not created by `CircuitDag::add_gate`).
    ClassicalDependency,
    /// Barrier-imposed ordering (not created by `CircuitDag::add_gate`).
    BarrierDependency,
}
38
/// A directed dependency edge from `source` to `target`.
#[derive(Debug, Clone)]
pub struct DagEdge {
    /// Id of the node that must execute first.
    pub source: usize,
    /// Id of the dependent node.
    pub target: usize,
    /// Why the dependency exists.
    pub edge_type: EdgeType,
}
49
/// Directed acyclic graph of gate dependencies for a quantum circuit.
///
/// Gates are appended in program order with [`CircuitDag::add_gate`]; a node's
/// id is its index in the node list, and edges always point from an earlier
/// node to a later one.
pub struct CircuitDag {
    /// All nodes, indexed by node id.
    nodes: Vec<DagNode>,
    /// All dependency edges, in insertion order.
    edges: Vec<DagEdge>,
    /// Qubit id -> id of the most recently added node acting on that qubit.
    qubit_last_use: HashMap<u32, usize>,
    /// Ids of nodes that had no predecessors when they were added.
    input_nodes: Vec<usize>,
    /// Ids of nodes that currently have no successors.
    output_nodes: Vec<usize>,
}
63
64impl CircuitDag {
65 #[must_use]
67 pub fn new() -> Self {
68 Self {
69 nodes: Vec::new(),
70 edges: Vec::new(),
71 qubit_last_use: HashMap::new(),
72 input_nodes: Vec::new(),
73 output_nodes: Vec::new(),
74 }
75 }
76
77 pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> usize {
79 let node_id = self.nodes.len();
80 let qubits = gate.qubits();
81
82 let mut predecessors = Vec::new();
84 for qubit in &qubits {
85 if let Some(&last_node) = self.qubit_last_use.get(&qubit.id()) {
86 predecessors.push(last_node);
87
88 self.edges.push(DagEdge {
90 source: last_node,
91 target: node_id,
92 edge_type: EdgeType::QubitDependency(qubit.id()),
93 });
94
95 self.nodes[last_node].successors.push(node_id);
97 }
98 }
99
100 let depth = if predecessors.is_empty() {
102 0
103 } else {
104 predecessors
105 .iter()
106 .map(|&pred| self.nodes[pred].depth)
107 .max()
108 .unwrap_or(0)
109 + 1
110 };
111
112 let node = DagNode {
114 id: node_id,
115 gate,
116 predecessors: predecessors.clone(),
117 successors: Vec::new(),
118 depth,
119 };
120
121 for qubit in &qubits {
123 self.qubit_last_use.insert(qubit.id(), node_id);
124 }
125
126 if predecessors.is_empty() {
128 self.input_nodes.push(node_id);
129 }
130
131 for &pred in &predecessors {
133 self.output_nodes.retain(|&x| x != pred);
134 }
135 self.output_nodes.push(node_id);
136
137 self.nodes.push(node);
138 node_id
139 }
140
141 #[must_use]
143 pub fn nodes(&self) -> &[DagNode] {
144 &self.nodes
145 }
146
147 #[must_use]
149 pub fn edges(&self) -> &[DagEdge] {
150 &self.edges
151 }
152
153 #[must_use]
155 pub fn input_nodes(&self) -> &[usize] {
156 &self.input_nodes
157 }
158
159 #[must_use]
161 pub fn output_nodes(&self) -> &[usize] {
162 &self.output_nodes
163 }
164
165 #[must_use]
167 pub fn max_depth(&self) -> usize {
168 self.nodes.iter().map(|n| n.depth).max().unwrap_or(0)
169 }
170
171 pub fn topological_sort(&self) -> Result<Vec<usize>, String> {
173 let mut in_degree = vec![0; self.nodes.len()];
174 let mut sorted = Vec::new();
175 let mut queue = VecDeque::new();
176
177 for node in &self.nodes {
179 in_degree[node.id] = node.predecessors.len();
180 }
181
182 for (id, °ree) in in_degree.iter().enumerate() {
184 if degree == 0 {
185 queue.push_back(id);
186 }
187 }
188
189 while let Some(node_id) = queue.pop_front() {
191 sorted.push(node_id);
192
193 for &succ in &self.nodes[node_id].successors {
195 in_degree[succ] -= 1;
196 if in_degree[succ] == 0 {
197 queue.push_back(succ);
198 }
199 }
200 }
201
202 if sorted.len() != self.nodes.len() {
204 return Err("Circuit DAG contains a cycle".to_string());
205 }
206
207 Ok(sorted)
208 }
209
210 #[must_use]
212 pub fn nodes_at_depth(&self, depth: usize) -> Vec<usize> {
213 self.nodes
214 .iter()
215 .filter(|n| n.depth == depth)
216 .map(|n| n.id)
217 .collect()
218 }
219
220 #[must_use]
222 pub fn critical_path(&self) -> Vec<usize> {
223 if self.nodes.is_empty() {
224 return Vec::new();
225 }
226
227 let mut longest_path_to = vec![0; self.nodes.len()];
229 let mut parent = vec![None; self.nodes.len()];
230
231 if let Ok(topo_order) = self.topological_sort() {
233 for &node_id in &topo_order {
234 for &succ in &self.nodes[node_id].successors {
235 let new_length = longest_path_to[node_id] + 1;
236 if new_length > longest_path_to[succ] {
237 longest_path_to[succ] = new_length;
238 parent[succ] = Some(node_id);
239 }
240 }
241 }
242 }
243
244 let mut end_node = 0;
246 let mut max_length = 0;
247 for (id, &length) in longest_path_to.iter().enumerate() {
248 if length > max_length {
249 max_length = length;
250 end_node = id;
251 }
252 }
253
254 let mut path = Vec::new();
256 let mut current = Some(end_node);
257 while let Some(node) = current {
258 path.push(node);
259 current = parent[node];
260 }
261 path.reverse();
262
263 path
264 }
265
266 #[must_use]
268 pub fn paths_between(&self, start: usize, end: usize) -> Vec<Vec<usize>> {
269 let mut paths = Vec::new();
270 let mut current_path = vec![start];
271 let mut visited = HashSet::new();
272
273 self.find_paths_dfs(start, end, &mut current_path, &mut visited, &mut paths);
274
275 paths
276 }
277
278 fn find_paths_dfs(
279 &self,
280 current: usize,
281 end: usize,
282 current_path: &mut Vec<usize>,
283 visited: &mut HashSet<usize>,
284 paths: &mut Vec<Vec<usize>>,
285 ) {
286 if current == end {
287 paths.push(current_path.clone());
288 return;
289 }
290
291 visited.insert(current);
292
293 for &successor in &self.nodes[current].successors {
294 if !visited.contains(&successor) {
295 current_path.push(successor);
296 self.find_paths_dfs(successor, end, current_path, visited, paths);
297 current_path.pop();
298 }
299 }
300
301 visited.remove(¤t);
302 }
303
304 #[must_use]
306 pub fn are_independent(&self, node1: usize, node2: usize) -> bool {
307 self.paths_between(node1, node2).is_empty() && self.paths_between(node2, node1).is_empty()
309 }
310
311 #[must_use]
313 pub fn parallel_nodes(&self, node_id: usize) -> Vec<usize> {
314 self.nodes
315 .iter()
316 .filter(|n| n.id != node_id && self.are_independent(node_id, n.id))
317 .map(|n| n.id)
318 .collect()
319 }
320
321 #[must_use]
323 pub fn to_dot(&self) -> String {
324 let mut dot = String::from("digraph CircuitDAG {\n");
325 dot.push_str(" rankdir=LR;\n");
326 dot.push_str(" node [shape=box];\n");
327
328 for node in &self.nodes {
330 writeln!(
331 dot,
332 " {} [label=\"{}: {}\"];",
333 node.id,
334 node.id,
335 node.gate.name()
336 )
337 .expect("writeln! to String cannot fail");
338 }
339
340 for edge in &self.edges {
342 let label = match edge.edge_type {
343 EdgeType::QubitDependency(q) => format!("q{q}"),
344 EdgeType::ClassicalDependency => "classical".to_string(),
345 EdgeType::BarrierDependency => "barrier".to_string(),
346 };
347 writeln!(
348 dot,
349 " {} -> {} [label=\"{}\"];",
350 edge.source, edge.target, label
351 )
352 .expect("writeln! to String cannot fail");
353 }
354
355 dot.push_str("}\n");
356 dot
357 }
358}
359
360impl Default for CircuitDag {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
366impl fmt::Debug for CircuitDag {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368 f.debug_struct("CircuitDag")
369 .field("nodes", &self.nodes.len())
370 .field("edges", &self.edges.len())
371 .field("max_depth", &self.max_depth())
372 .finish()
373 }
374}
375
376#[must_use]
378pub fn circuit_to_dag<const N: usize>(circuit: &crate::builder::Circuit<N>) -> CircuitDag {
379 let mut dag = CircuitDag::new();
380
381 for gate in circuit.gates() {
382 let boxed_gate = gate.clone_gate();
384 dag.add_gate(boxed_gate);
385 }
386
387 dag
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};
    use quantrs2_core::qubit::QubitId;

    #[test]
    fn test_dag_creation() {
        let mut dag = CircuitDag::new();

        let h_id = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let x_id = dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
        let cnot_id = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        // H and X are independent inputs; the CNOT depends on both.
        assert_eq!(dag.nodes().len(), 3);
        assert_eq!(dag.edges().len(), 2);
        assert_eq!(dag.input_nodes(), &[h_id, x_id]);
        assert_eq!(dag.output_nodes(), &[cnot_id]);
    }

    #[test]
    fn test_topological_sort() {
        let mut dag = CircuitDag::new();

        let h_id = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let x_id = dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
        let cnot_id = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        let sorted = dag
            .topological_sort()
            .expect("topological_sort should succeed");

        assert_eq!(sorted.len(), 3);
        assert!(sorted.contains(&h_id));
        assert!(sorted.contains(&x_id));
        assert_eq!(sorted[2], cnot_id);
    }

    #[test]
    fn test_parallel_nodes() {
        let mut dag = CircuitDag::new();

        let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let h1 = dag.add_gate(Box::new(Hadamard { target: QubitId(1) }));
        let h2 = dag.add_gate(Box::new(Hadamard { target: QubitId(2) }));

        assert!(dag.are_independent(h0, h1));
        assert!(dag.are_independent(h0, h2));
        assert!(dag.are_independent(h1, h2));

        let parallel_to_h0 = dag.parallel_nodes(h0);
        assert!(parallel_to_h0.contains(&h1));
        assert!(parallel_to_h0.contains(&h2));
    }

    #[test]
    fn test_dependent_nodes_are_not_independent() {
        let mut dag = CircuitDag::new();

        let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let h2 = dag.add_gate(Box::new(Hadamard { target: QubitId(2) }));
        let cnot = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        assert!(!dag.are_independent(h0, cnot));
        assert!(!dag.are_independent(cnot, h0));
        assert!(!dag.are_independent(h0, h0));
        assert!(dag.are_independent(h2, cnot));
        assert_eq!(dag.parallel_nodes(h0), vec![h2]);
        assert_eq!(dag.parallel_nodes(cnot), vec![h2]);
    }

    #[test]
    fn test_paths_between() {
        let mut dag = CircuitDag::new();

        let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let cnot = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));
        let x0 = dag.add_gate(Box::new(PauliX { target: QubitId(0) }));

        assert_eq!(dag.paths_between(h0, x0), vec![vec![h0, cnot, x0]]);
        assert!(dag.paths_between(x0, h0).is_empty());
        assert_eq!(dag.paths_between(h0, h0), vec![vec![h0]]);
    }

    #[test]
    fn test_critical_path() {
        let mut dag = CircuitDag::new();

        let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let x1 = dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
        let cnot = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));
        let x0 = dag.add_gate(Box::new(PauliX { target: QubitId(0) }));

        let path = dag.critical_path();

        // H(q0) -> CNOT -> X(q0); X(q1) ties with H(q0) but is not chosen.
        assert_eq!(path, vec![h0, cnot, x0]);
        assert!(!path.contains(&x1));
    }

    #[test]
    fn test_to_dot_structure() {
        let mut dag = CircuitDag::new();

        dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        let dot = dag.to_dot();
        assert!(dot.starts_with("digraph CircuitDAG {\n"));
        assert!(dot.ends_with("}\n"));
        assert!(dot.contains("0 -> 1 [label=\"q0\"];"));
    }
}