quantrs2_circuit/
dag.rs

//! Directed Acyclic Graph (DAG) representation for quantum circuits.
//!
//! This module provides a DAG representation of quantum circuits that enables
//! advanced optimization techniques such as gate reordering, parallelization,
//! and dependency analysis.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::fmt;
9
10use quantrs2_core::{gate::GateOp, qubit::QubitId};
11
12use std::fmt::Write;
/// A node in the circuit DAG
///
/// A node pairs one gate operation with its adjacency lists and its scheduling
/// depth. Within this module nodes are created by `CircuitDag::add_gate`.
#[derive(Debug, Clone)]
pub struct DagNode {
    /// Unique identifier for this node
    ///
    /// `CircuitDag::add_gate` assigns the node's position in the DAG's node
    /// list, so the id can be used to index the slice returned by
    /// `CircuitDag::nodes`.
    pub id: usize,
    /// The quantum gate operation
    pub gate: Box<dyn GateOp>,
    /// Indices of predecessor nodes (nodes this one depends on)
    pub predecessors: Vec<usize>,
    /// Indices of successor nodes (nodes that depend on this one)
    pub successors: Vec<usize>,
    /// Depth in the DAG (for scheduling)
    ///
    /// As assigned by `CircuitDag::add_gate`: 0 for a node without
    /// predecessors, otherwise one more than the largest predecessor depth.
    pub depth: usize,
}
27
/// Edge type in the DAG
///
/// Describes why the target of a [`DagEdge`] has to wait for its source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Data dependency on same qubit
    ///
    /// The payload is the id of the shared qubit (`CircuitDag::add_gate`
    /// stores `qubit.id()` here).
    QubitDependency(u32),
    /// Classical control dependency
    // NOTE(review): nothing in the visible part of this module constructs this
    // variant; it is only rendered by `to_dot`.
    ClassicalDependency,
    /// Barrier dependency
    // NOTE(review): nothing in the visible part of this module constructs this
    // variant; it is only rendered by `to_dot`.
    BarrierDependency,
}
38
/// An edge in the circuit DAG
///
/// Edges are directed: the `target` node must be scheduled after the `source`
/// node. `source` and `target` are indices into the DAG's node list.
#[derive(Debug, Clone)]
pub struct DagEdge {
    /// Source node index (the earlier gate)
    pub source: usize,
    /// Target node index (the later, dependent gate)
    pub target: usize,
    /// Type of dependency
    pub edge_type: EdgeType,
}
49
/// DAG representation of a quantum circuit
///
/// Gates are appended one at a time with `add_gate`; the structure keeps the
/// node list, the edge list and a few indexes (last node per qubit, current
/// input and output frontiers) up to date as gates are added.
pub struct CircuitDag {
    /// All nodes in the DAG, stored so that a node's `id` is its index here
    nodes: Vec<DagNode>,
    /// All edges in the DAG
    edges: Vec<DagEdge>,
    /// Map from qubit ID to the last node that operated on it
    ///
    /// Consulted by `add_gate` to find the predecessors of a new gate.
    qubit_last_use: HashMap<u32, usize>,
    /// Input nodes (no predecessors)
    input_nodes: Vec<usize>,
    /// Output nodes (no successors)
    output_nodes: Vec<usize>,
}
63
64impl CircuitDag {
65    /// Create a new empty DAG
66    #[must_use]
67    pub fn new() -> Self {
68        Self {
69            nodes: Vec::new(),
70            edges: Vec::new(),
71            qubit_last_use: HashMap::new(),
72            input_nodes: Vec::new(),
73            output_nodes: Vec::new(),
74        }
75    }
76
77    /// Add a gate to the DAG
78    pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> usize {
79        let node_id = self.nodes.len();
80        let qubits = gate.qubits();
81
82        // Find predecessors based on qubit dependencies
83        let mut predecessors = Vec::new();
84        for qubit in &qubits {
85            if let Some(&last_node) = self.qubit_last_use.get(&qubit.id()) {
86                predecessors.push(last_node);
87
88                // Add edge
89                self.edges.push(DagEdge {
90                    source: last_node,
91                    target: node_id,
92                    edge_type: EdgeType::QubitDependency(qubit.id()),
93                });
94
95                // Update successor of predecessor
96                self.nodes[last_node].successors.push(node_id);
97            }
98        }
99
100        // Calculate depth
101        let depth = if predecessors.is_empty() {
102            0
103        } else {
104            predecessors
105                .iter()
106                .map(|&pred| self.nodes[pred].depth)
107                .max()
108                .unwrap_or(0)
109                + 1
110        };
111
112        // Create new node
113        let node = DagNode {
114            id: node_id,
115            gate,
116            predecessors: predecessors.clone(),
117            successors: Vec::new(),
118            depth,
119        };
120
121        // Update qubit last use
122        for qubit in &qubits {
123            self.qubit_last_use.insert(qubit.id(), node_id);
124        }
125
126        // Update input/output nodes
127        if predecessors.is_empty() {
128            self.input_nodes.push(node_id);
129        }
130
131        // Remove predecessors from output nodes
132        for &pred in &predecessors {
133            self.output_nodes.retain(|&x| x != pred);
134        }
135        self.output_nodes.push(node_id);
136
137        self.nodes.push(node);
138        node_id
139    }
140
141    /// Get all nodes in the DAG
142    #[must_use]
143    pub fn nodes(&self) -> &[DagNode] {
144        &self.nodes
145    }
146
147    /// Get all edges in the DAG
148    #[must_use]
149    pub fn edges(&self) -> &[DagEdge] {
150        &self.edges
151    }
152
153    /// Get input nodes (no predecessors)
154    #[must_use]
155    pub fn input_nodes(&self) -> &[usize] {
156        &self.input_nodes
157    }
158
159    /// Get output nodes (no successors)
160    #[must_use]
161    pub fn output_nodes(&self) -> &[usize] {
162        &self.output_nodes
163    }
164
165    /// Get the maximum depth of the DAG
166    #[must_use]
167    pub fn max_depth(&self) -> usize {
168        self.nodes.iter().map(|n| n.depth).max().unwrap_or(0)
169    }
170
171    /// Perform topological sort on the DAG
172    pub fn topological_sort(&self) -> Result<Vec<usize>, String> {
173        let mut in_degree = vec![0; self.nodes.len()];
174        let mut sorted = Vec::new();
175        let mut queue = VecDeque::new();
176
177        // Calculate in-degrees
178        for node in &self.nodes {
179            in_degree[node.id] = node.predecessors.len();
180        }
181
182        // Initialize queue with nodes having no predecessors
183        for (id, &degree) in in_degree.iter().enumerate() {
184            if degree == 0 {
185                queue.push_back(id);
186            }
187        }
188
189        // Process nodes
190        while let Some(node_id) = queue.pop_front() {
191            sorted.push(node_id);
192
193            // Reduce in-degree of successors
194            for &succ in &self.nodes[node_id].successors {
195                in_degree[succ] -= 1;
196                if in_degree[succ] == 0 {
197                    queue.push_back(succ);
198                }
199            }
200        }
201
202        // Check for cycles
203        if sorted.len() != self.nodes.len() {
204            return Err("Circuit DAG contains a cycle".to_string());
205        }
206
207        Ok(sorted)
208    }
209
210    /// Get nodes at a specific depth
211    #[must_use]
212    pub fn nodes_at_depth(&self, depth: usize) -> Vec<usize> {
213        self.nodes
214            .iter()
215            .filter(|n| n.depth == depth)
216            .map(|n| n.id)
217            .collect()
218    }
219
220    /// Find the critical path (longest path) through the DAG
221    #[must_use]
222    pub fn critical_path(&self) -> Vec<usize> {
223        if self.nodes.is_empty() {
224            return Vec::new();
225        }
226
227        // Dynamic programming to find longest path
228        let mut longest_path_to = vec![0; self.nodes.len()];
229        let mut parent = vec![None; self.nodes.len()];
230
231        // Process nodes in topological order
232        if let Ok(topo_order) = self.topological_sort() {
233            for &node_id in &topo_order {
234                for &succ in &self.nodes[node_id].successors {
235                    let new_length = longest_path_to[node_id] + 1;
236                    if new_length > longest_path_to[succ] {
237                        longest_path_to[succ] = new_length;
238                        parent[succ] = Some(node_id);
239                    }
240                }
241            }
242        }
243
244        // Find the end of the longest path
245        let mut end_node = 0;
246        let mut max_length = 0;
247        for (id, &length) in longest_path_to.iter().enumerate() {
248            if length > max_length {
249                max_length = length;
250                end_node = id;
251            }
252        }
253
254        // Reconstruct the path
255        let mut path = Vec::new();
256        let mut current = Some(end_node);
257        while let Some(node) = current {
258            path.push(node);
259            current = parent[node];
260        }
261        path.reverse();
262
263        path
264    }
265
266    /// Get all paths between two nodes
267    #[must_use]
268    pub fn paths_between(&self, start: usize, end: usize) -> Vec<Vec<usize>> {
269        let mut paths = Vec::new();
270        let mut current_path = vec![start];
271        let mut visited = HashSet::new();
272
273        self.find_paths_dfs(start, end, &mut current_path, &mut visited, &mut paths);
274
275        paths
276    }
277
278    fn find_paths_dfs(
279        &self,
280        current: usize,
281        end: usize,
282        current_path: &mut Vec<usize>,
283        visited: &mut HashSet<usize>,
284        paths: &mut Vec<Vec<usize>>,
285    ) {
286        if current == end {
287            paths.push(current_path.clone());
288            return;
289        }
290
291        visited.insert(current);
292
293        for &successor in &self.nodes[current].successors {
294            if !visited.contains(&successor) {
295                current_path.push(successor);
296                self.find_paths_dfs(successor, end, current_path, visited, paths);
297                current_path.pop();
298            }
299        }
300
301        visited.remove(&current);
302    }
303
304    /// Check if two nodes are independent (can be executed in parallel)
305    #[must_use]
306    pub fn are_independent(&self, node1: usize, node2: usize) -> bool {
307        // Two nodes are independent if there's no path between them
308        self.paths_between(node1, node2).is_empty() && self.paths_between(node2, node1).is_empty()
309    }
310
311    /// Get all nodes that can be executed in parallel with a given node
312    #[must_use]
313    pub fn parallel_nodes(&self, node_id: usize) -> Vec<usize> {
314        self.nodes
315            .iter()
316            .filter(|n| n.id != node_id && self.are_independent(node_id, n.id))
317            .map(|n| n.id)
318            .collect()
319    }
320
321    /// Convert the DAG to a DOT format string for visualization
322    #[must_use]
323    pub fn to_dot(&self) -> String {
324        let mut dot = String::from("digraph CircuitDAG {\n");
325        dot.push_str("  rankdir=LR;\n");
326        dot.push_str("  node [shape=box];\n");
327
328        // Add nodes
329        for node in &self.nodes {
330            writeln!(
331                dot,
332                "  {} [label=\"{}: {}\"];",
333                node.id,
334                node.id,
335                node.gate.name()
336            )
337            .expect("writeln! to String cannot fail");
338        }
339
340        // Add edges
341        for edge in &self.edges {
342            let label = match edge.edge_type {
343                EdgeType::QubitDependency(q) => format!("q{q}"),
344                EdgeType::ClassicalDependency => "classical".to_string(),
345                EdgeType::BarrierDependency => "barrier".to_string(),
346            };
347            writeln!(
348                dot,
349                "  {} -> {} [label=\"{}\"];",
350                edge.source, edge.target, label
351            )
352            .expect("writeln! to String cannot fail");
353        }
354
355        dot.push_str("}\n");
356        dot
357    }
358}
359
360impl Default for CircuitDag {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
366impl fmt::Debug for CircuitDag {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        f.debug_struct("CircuitDag")
369            .field("nodes", &self.nodes.len())
370            .field("edges", &self.edges.len())
371            .field("max_depth", &self.max_depth())
372            .finish()
373    }
374}
375
376/// Convert a Circuit into a DAG representation
377#[must_use]
378pub fn circuit_to_dag<const N: usize>(circuit: &crate::builder::Circuit<N>) -> CircuitDag {
379    let mut dag = CircuitDag::new();
380
381    for gate in circuit.gates() {
382        // Convert Arc to Box for DAG compatibility
383        let boxed_gate = gate.clone_gate();
384        dag.add_gate(boxed_gate);
385    }
386
387    dag
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};
    use quantrs2_core::qubit::QubitId;

    /// H(0) and X(1) feeding CNOT(0,1): checks the node and edge counts and
    /// the input/output frontiers.
    #[test]
    fn test_dag_creation() {
        let mut dag = CircuitDag::new();

        let h_id = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let x_id = dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
        let cnot_id = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        // Three gates; the CNOT depends on each single-qubit gate once
        assert_eq!(dag.nodes().len(), 3);
        assert_eq!(dag.edges().len(), 2);
        assert_eq!(dag.input_nodes(), &[h_id, x_id]);
        assert_eq!(dag.output_nodes(), &[cnot_id]);
    }

    /// H(0) -> CNOT(0,1) <- X(1): the CNOT must come last in any valid order.
    #[test]
    fn test_topological_sort() {
        let mut dag = CircuitDag::new();

        let h_id = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let x_id = dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
        let cnot_id = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        let sorted = dag
            .topological_sort()
            .expect("topological_sort should succeed");

        // H and X may appear in either order ahead of the CNOT
        assert_eq!(sorted.len(), 3);
        for id in [h_id, x_id] {
            assert!(sorted.contains(&id));
        }
        assert_eq!(sorted.last().copied(), Some(cnot_id));
    }

    /// Gates acting on disjoint qubits are pairwise independent.
    #[test]
    fn test_parallel_nodes() {
        let mut dag = CircuitDag::new();

        let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        let h1 = dag.add_gate(Box::new(Hadamard { target: QubitId(1) }));
        let h2 = dag.add_gate(Box::new(Hadamard { target: QubitId(2) }));

        for (a, b) in [(h0, h1), (h0, h2), (h1, h2)] {
            assert!(dag.are_independent(a, b));
        }

        let parallel_to_h0 = dag.parallel_nodes(h0);
        for other in [h1, h2] {
            assert!(parallel_to_h0.contains(&other));
        }
    }

    /// Circuit under test:
    ///
    /// ```text
    /// H(0) -> CNOT(0,1) -> X(0)
    /// X(1) ---/
    /// ```
    ///
    /// The critical path is expected to be H(0) -> CNOT -> X(0).
    #[test]
    fn test_critical_path() {
        let mut dag = CircuitDag::new();

        let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
        dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
        let cnot = dag.add_gate(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));
        let x0 = dag.add_gate(Box::new(PauliX { target: QubitId(0) }));

        let path = dag.critical_path();

        assert_eq!(path, vec![h0, cnot, x0]);
    }
}