use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use quantrs2_core::{gate::GateOp, qubit::QubitId};
use std::fmt::Write;
#[derive(Debug, Clone)]
pub struct DagNode {
pub id: usize,
pub gate: Box<dyn GateOp>,
pub predecessors: Vec<usize>,
pub successors: Vec<usize>,
pub depth: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeType {
QubitDependency(u32),
ClassicalDependency,
BarrierDependency,
}
#[derive(Debug, Clone)]
pub struct DagEdge {
pub source: usize,
pub target: usize,
pub edge_type: EdgeType,
}
pub struct CircuitDag {
nodes: Vec<DagNode>,
edges: Vec<DagEdge>,
qubit_last_use: HashMap<u32, usize>,
input_nodes: Vec<usize>,
output_nodes: Vec<usize>,
}
impl CircuitDag {
#[must_use]
pub fn new() -> Self {
Self {
nodes: Vec::new(),
edges: Vec::new(),
qubit_last_use: HashMap::new(),
input_nodes: Vec::new(),
output_nodes: Vec::new(),
}
}
pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> usize {
let node_id = self.nodes.len();
let qubits = gate.qubits();
let mut predecessors = Vec::new();
for qubit in &qubits {
if let Some(&last_node) = self.qubit_last_use.get(&qubit.id()) {
predecessors.push(last_node);
self.edges.push(DagEdge {
source: last_node,
target: node_id,
edge_type: EdgeType::QubitDependency(qubit.id()),
});
self.nodes[last_node].successors.push(node_id);
}
}
let depth = if predecessors.is_empty() {
0
} else {
predecessors
.iter()
.map(|&pred| self.nodes[pred].depth)
.max()
.unwrap_or(0)
+ 1
};
let node = DagNode {
id: node_id,
gate,
predecessors: predecessors.clone(),
successors: Vec::new(),
depth,
};
for qubit in &qubits {
self.qubit_last_use.insert(qubit.id(), node_id);
}
if predecessors.is_empty() {
self.input_nodes.push(node_id);
}
for &pred in &predecessors {
self.output_nodes.retain(|&x| x != pred);
}
self.output_nodes.push(node_id);
self.nodes.push(node);
node_id
}
#[must_use]
pub fn nodes(&self) -> &[DagNode] {
&self.nodes
}
#[must_use]
pub fn edges(&self) -> &[DagEdge] {
&self.edges
}
#[must_use]
pub fn input_nodes(&self) -> &[usize] {
&self.input_nodes
}
#[must_use]
pub fn output_nodes(&self) -> &[usize] {
&self.output_nodes
}
#[must_use]
pub fn max_depth(&self) -> usize {
self.nodes.iter().map(|n| n.depth).max().unwrap_or(0)
}
pub fn topological_sort(&self) -> Result<Vec<usize>, String> {
let mut in_degree = vec![0; self.nodes.len()];
let mut sorted = Vec::new();
let mut queue = VecDeque::new();
for node in &self.nodes {
in_degree[node.id] = node.predecessors.len();
}
for (id, °ree) in in_degree.iter().enumerate() {
if degree == 0 {
queue.push_back(id);
}
}
while let Some(node_id) = queue.pop_front() {
sorted.push(node_id);
for &succ in &self.nodes[node_id].successors {
in_degree[succ] -= 1;
if in_degree[succ] == 0 {
queue.push_back(succ);
}
}
}
if sorted.len() != self.nodes.len() {
return Err("Circuit DAG contains a cycle".to_string());
}
Ok(sorted)
}
#[must_use]
pub fn nodes_at_depth(&self, depth: usize) -> Vec<usize> {
self.nodes
.iter()
.filter(|n| n.depth == depth)
.map(|n| n.id)
.collect()
}
#[must_use]
pub fn critical_path(&self) -> Vec<usize> {
if self.nodes.is_empty() {
return Vec::new();
}
let mut longest_path_to = vec![0; self.nodes.len()];
let mut parent = vec![None; self.nodes.len()];
if let Ok(topo_order) = self.topological_sort() {
for &node_id in &topo_order {
for &succ in &self.nodes[node_id].successors {
let new_length = longest_path_to[node_id] + 1;
if new_length > longest_path_to[succ] {
longest_path_to[succ] = new_length;
parent[succ] = Some(node_id);
}
}
}
}
let mut end_node = 0;
let mut max_length = 0;
for (id, &length) in longest_path_to.iter().enumerate() {
if length > max_length {
max_length = length;
end_node = id;
}
}
let mut path = Vec::new();
let mut current = Some(end_node);
while let Some(node) = current {
path.push(node);
current = parent[node];
}
path.reverse();
path
}
#[must_use]
pub fn paths_between(&self, start: usize, end: usize) -> Vec<Vec<usize>> {
let mut paths = Vec::new();
let mut current_path = vec![start];
let mut visited = HashSet::new();
self.find_paths_dfs(start, end, &mut current_path, &mut visited, &mut paths);
paths
}
fn find_paths_dfs(
&self,
current: usize,
end: usize,
current_path: &mut Vec<usize>,
visited: &mut HashSet<usize>,
paths: &mut Vec<Vec<usize>>,
) {
if current == end {
paths.push(current_path.clone());
return;
}
visited.insert(current);
for &successor in &self.nodes[current].successors {
if !visited.contains(&successor) {
current_path.push(successor);
self.find_paths_dfs(successor, end, current_path, visited, paths);
current_path.pop();
}
}
visited.remove(¤t);
}
#[must_use]
pub fn are_independent(&self, node1: usize, node2: usize) -> bool {
self.paths_between(node1, node2).is_empty() && self.paths_between(node2, node1).is_empty()
}
#[must_use]
pub fn parallel_nodes(&self, node_id: usize) -> Vec<usize> {
self.nodes
.iter()
.filter(|n| n.id != node_id && self.are_independent(node_id, n.id))
.map(|n| n.id)
.collect()
}
#[must_use]
pub fn to_dot(&self) -> String {
let mut dot = String::from("digraph CircuitDAG {\n");
dot.push_str(" rankdir=LR;\n");
dot.push_str(" node [shape=box];\n");
for node in &self.nodes {
writeln!(
dot,
" {} [label=\"{}: {}\"];",
node.id,
node.id,
node.gate.name()
)
.expect("writeln! to String cannot fail");
}
for edge in &self.edges {
let label = match edge.edge_type {
EdgeType::QubitDependency(q) => format!("q{q}"),
EdgeType::ClassicalDependency => "classical".to_string(),
EdgeType::BarrierDependency => "barrier".to_string(),
};
writeln!(
dot,
" {} -> {} [label=\"{}\"];",
edge.source, edge.target, label
)
.expect("writeln! to String cannot fail");
}
dot.push_str("}\n");
dot
}
}
impl Default for CircuitDag {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for CircuitDag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CircuitDag")
.field("nodes", &self.nodes.len())
.field("edges", &self.edges.len())
.field("max_depth", &self.max_depth())
.finish()
}
}
#[must_use]
pub fn circuit_to_dag<const N: usize>(circuit: &crate::builder::Circuit<N>) -> CircuitDag {
let mut dag = CircuitDag::new();
for gate in circuit.gates() {
let boxed_gate = gate.clone_gate();
dag.add_gate(boxed_gate);
}
dag
}
#[cfg(test)]
mod tests {
use super::*;
use quantrs2_core::gate::multi::CNOT;
use quantrs2_core::gate::single::{Hadamard, PauliX};
use quantrs2_core::qubit::QubitId;
#[test]
fn test_dag_creation() {
let mut dag = CircuitDag::new();
let h_gate = Box::new(Hadamard { target: QubitId(0) });
let h_id = dag.add_gate(h_gate);
let x_gate = Box::new(PauliX { target: QubitId(1) });
let x_id = dag.add_gate(x_gate);
let cnot_gate = Box::new(CNOT {
control: QubitId(0),
target: QubitId(1),
});
let cnot_id = dag.add_gate(cnot_gate);
assert_eq!(dag.nodes().len(), 3);
assert_eq!(dag.edges().len(), 2);
assert_eq!(dag.input_nodes(), &[h_id, x_id]);
assert_eq!(dag.output_nodes(), &[cnot_id]);
}
#[test]
fn test_topological_sort() {
let mut dag = CircuitDag::new();
let h_gate = Box::new(Hadamard { target: QubitId(0) });
let h_id = dag.add_gate(h_gate);
let x_gate = Box::new(PauliX { target: QubitId(1) });
let x_id = dag.add_gate(x_gate);
let cnot_gate = Box::new(CNOT {
control: QubitId(0),
target: QubitId(1),
});
let cnot_id = dag.add_gate(cnot_gate);
let sorted = dag
.topological_sort()
.expect("topological_sort should succeed");
assert_eq!(sorted.len(), 3);
assert!(sorted.contains(&h_id));
assert!(sorted.contains(&x_id));
assert_eq!(sorted[2], cnot_id);
}
#[test]
fn test_parallel_nodes() {
let mut dag = CircuitDag::new();
let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
let h1 = dag.add_gate(Box::new(Hadamard { target: QubitId(1) }));
let h2 = dag.add_gate(Box::new(Hadamard { target: QubitId(2) }));
assert!(dag.are_independent(h0, h1));
assert!(dag.are_independent(h0, h2));
assert!(dag.are_independent(h1, h2));
let parallel_to_h0 = dag.parallel_nodes(h0);
assert!(parallel_to_h0.contains(&h1));
assert!(parallel_to_h0.contains(&h2));
}
#[test]
fn test_critical_path() {
let mut dag = CircuitDag::new();
let h0 = dag.add_gate(Box::new(Hadamard { target: QubitId(0) }));
let x1 = dag.add_gate(Box::new(PauliX { target: QubitId(1) }));
let cnot = dag.add_gate(Box::new(CNOT {
control: QubitId(0),
target: QubitId(1),
}));
let x0 = dag.add_gate(Box::new(PauliX { target: QubitId(0) }));
let path = dag.critical_path();
assert_eq!(path.len(), 3);
assert_eq!(path[0], h0);
assert_eq!(path[1], cnot);
assert_eq!(path[2], x0);
}
}