use crate::ir_inner::model::expr::Ident;
use crate::ir_inner::model::node::Node;
use crate::ir_inner::model::program::Program;
use core::fmt;
/// Reasons a [`NodeGraph`] is rejected when it is lowered back to a [`Program`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum GraphValidateError {
    /// The edge set contains a directed cycle; `path` lists the nodes on it in
    /// traversal order.
    Cycle { path: Vec<u32> },
    /// An edge has an endpoint that does not name an existing node.
    DanglingEdge { from: u32, to: u32 },
    /// The Phi node `node_id` has no predecessors, or lists one that does not exist.
    OrphanPhi { node_id: u32 },
}

impl fmt::Display for GraphValidateError {
    /// Renders a one-sentence diagnosis followed by a `Fix:` hint.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Cycle { ref path } => write!(
                f,
                "graph contains a cycle involving nodes {path:?}. Fix: remove cyclic dependencies so the graph is a valid DAG."
            ),
            Self::DanglingEdge { from, to } => write!(
                f,
                "edge from {from} to {to} references a non-existent node. Fix: ensure all edge endpoints exist in the graph's node list."
            ),
            Self::OrphanPhi { node_id } => write!(
                f,
                "Phi node {node_id} has no valid predecessors. Fix: ensure every Phi node references at least one existing predecessor node."
            ),
        }
    }
}

impl std::error::Error for GraphValidateError {}
/// A vertex of a [`NodeGraph`]: an identifier paired with the payload it carries.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct GraphNode {
    /// Identifier of the node. `NodeGraph::from_program_owned` sets this to the
    /// node's position in `NodeGraph::nodes`.
    pub id: u32,
    /// What the node represents (a statement, a Phi, or a barrier).
    pub kind: DataflowKind,
}

impl GraphNode {
    /// Creates a node with the given identifier and payload.
    #[must_use]
    pub fn new(id: u32, kind: DataflowKind) -> Self {
        GraphNode { kind, id }
    }
}
/// Payload carried by a [`GraphNode`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum DataflowKind {
    /// A top-level program statement; `NodeGraph::try_into_program` emits it unchanged.
    Statement(Node),
    /// A Phi node listing its predecessors. `NodeGraph::try_into_program` requires the
    /// list to be non-empty and every entry to be a valid index into `NodeGraph::nodes`,
    /// and then drops the Phi: it has no statement-IR counterpart.
    Phi(Vec<u32>),
    /// A barrier; `NodeGraph::try_into_program` lowers it to `Node::barrier()`.
    Barrier,
}
/// A directed edge `from -> to` between two nodes of a [`NodeGraph`], tagged with
/// the kind of dependency it records.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct DataEdge {
    /// Source endpoint.
    pub from: u32,
    /// Destination endpoint.
    pub to: u32,
    /// Kind of dependency this edge encodes.
    pub kind: EdgeKind,
}

impl DataEdge {
    /// Creates an edge running from `from` to `to` with the given `kind`.
    #[must_use]
    pub fn new(from: u32, to: u32, kind: EdgeKind) -> Self {
        DataEdge { kind, to, from }
    }
}
/// The reason one node must come after another.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum EdgeKind {
    /// Plain sequencing. `NodeGraph::from_program_owned` links each pair of
    /// consecutive statements with this kind.
    Ordering,
    /// Dependency carrying the identifier `name` (presumably a def-use edge for that
    /// variable — not produced anywhere in this module; confirm with producers).
    Def { name: Ident },
    /// Control dependency (not produced anywhere in this module).
    Control,
}
/// Dataflow-graph view of a [`Program`].
///
/// Holds the nodes, the edges between them, and the program-level metadata
/// (workgroup size and buffer declarations) needed to lower the graph back.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct NodeGraph {
    /// Graph vertices. Edge endpoints and Phi predecessors index into this vector.
    pub nodes: Vec<GraphNode>,
    /// Directed dependencies between nodes.
    pub edges: Vec<DataEdge>,
    /// Workgroup dimensions handed back to the lowered program.
    pub workgroup_size: [u32; 3],
    /// Buffer declarations handed back to the lowered program.
    pub buffers: Vec<crate::ir_inner::model::program::BufferDecl>,
}

impl Default for NodeGraph {
    /// An empty graph with no buffers and a `[1, 1, 1]` workgroup size, the same
    /// values `NodeGraph::new` uses.
    ///
    /// A derived `Default` would produce `[0, 0, 0]`, i.e. a workgroup with zero
    /// invocations, which disagrees with every other constructor of this type.
    fn default() -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
            workgroup_size: [1, 1, 1],
            buffers: Vec::new(),
        }
    }
}
impl NodeGraph {
    /// Builds a graph from explicit nodes and edges.
    ///
    /// The workgroup size is `[1, 1, 1]` and there are no buffer declarations.
    #[must_use]
    pub fn new(nodes: Vec<GraphNode>, edges: Vec<DataEdge>) -> Self {
        Self {
            nodes,
            edges,
            workgroup_size: [1, 1, 1],
            buffers: Vec::new(),
        }
    }

    /// Builds the graph view of `program` without consuming it.
    ///
    /// Clones `program` and delegates to [`NodeGraph::from_program_owned`].
    #[must_use]
    pub fn from_program(program: &Program) -> Self {
        Self::from_program_owned(program.clone())
    }

    /// Builds the graph view of `program`, consuming it.
    ///
    /// - Each top-level entry statement becomes a [`DataflowKind::Statement`] node.
    ///   Its `id` is its position in the entry list.
    /// - Consecutive statements are linked by an [`EdgeKind::Ordering`] edge, so the
    ///   edges form a single chain.
    /// - The workgroup size and buffer declarations are copied over unchanged.
    #[must_use]
    pub fn from_program_owned(program: Program) -> Self {
        let workgroup_size = program.workgroup_size();
        let buffers = program.buffers().to_vec();
        let entry_vec = program.into_entry_vec();
        let mut nodes = Vec::with_capacity(entry_vec.len());
        let mut edges = Vec::with_capacity(entry_vec.len().saturating_sub(1));
        for (i, n) in entry_vec.into_iter().enumerate() {
            // Node ids are `u32` in the public data model. A program with more than
            // `u32::MAX` top-level statements is not expected here.
            #[allow(clippy::cast_possible_truncation)]
            let id = i as u32;
            nodes.push(GraphNode {
                id,
                kind: DataflowKind::Statement(n),
            });
            if id > 0 {
                edges.push(DataEdge {
                    from: id - 1,
                    to: id,
                    kind: EdgeKind::Ordering,
                });
            }
        }
        Self {
            nodes,
            edges,
            workgroup_size,
            buffers,
        }
    }

    /// Validates the graph and lowers it back to a [`Program`], consuming it.
    ///
    /// Edge endpoints and Phi predecessors are treated as indices into
    /// [`NodeGraph::nodes`]; `GraphNode::id` is only used to label
    /// [`GraphValidateError::OrphanPhi`]. Statements are emitted in `nodes` order.
    /// Edges are only validated; they are not used to reorder. Phi nodes are dropped
    /// and barriers become `Node::barrier()`.
    ///
    /// # Errors
    ///
    /// Checks run in this order, and the first failure is returned:
    /// - [`GraphValidateError::DanglingEdge`] for the first edge (in `edges` order)
    ///   with an endpoint that is not a valid node index.
    /// - [`GraphValidateError::OrphanPhi`] for the first Phi node (in `nodes` order)
    ///   that has no predecessors or names an invalid node index.
    /// - [`GraphValidateError::Cycle`] if the edges contain a directed cycle. The path
    ///   is the first cycle found by a depth-first search that starts from each
    ///   unvisited node in index order and follows edges in `edges` order.
    pub fn try_into_program(self) -> Result<Program, GraphValidateError> {
        self.validate_edges()?;
        self.validate_phis()?;
        if let Some(path) = self.find_cycle() {
            return Err(GraphValidateError::Cycle { path });
        }
        let entry: Vec<Node> = self
            .nodes
            .into_iter()
            .filter_map(|gn| match gn.kind {
                DataflowKind::Statement(n) => Some(n),
                DataflowKind::Phi(_) => None,
                DataflowKind::Barrier => Some(Node::barrier()),
            })
            .collect();
        Ok(Program::wrapped(self.buffers, self.workgroup_size, entry))
    }

    /// Rejects the first edge whose endpoint is not a valid index into `nodes`.
    fn validate_edges(&self) -> Result<(), GraphValidateError> {
        // Compare in `usize` so a node count above `u32::MAX` cannot wrap the bound.
        let len = self.nodes.len();
        match self
            .edges
            .iter()
            .find(|edge| edge.from as usize >= len || edge.to as usize >= len)
        {
            Some(edge) => Err(GraphValidateError::DanglingEdge {
                from: edge.from,
                to: edge.to,
            }),
            None => Ok(()),
        }
    }

    /// Rejects the first Phi node that has no predecessors or names an invalid index.
    fn validate_phis(&self) -> Result<(), GraphValidateError> {
        let len = self.nodes.len();
        for node in &self.nodes {
            if let DataflowKind::Phi(predecessors) = &node.kind {
                if predecessors.is_empty() || predecessors.iter().any(|&p| p as usize >= len) {
                    return Err(GraphValidateError::OrphanPhi { node_id: node.id });
                }
            }
        }
        Ok(())
    }

    /// Returns the node indices of a directed cycle in `edges`, if there is one.
    ///
    /// The DFS keeps an explicit stack instead of recursing. `from_program_owned`
    /// produces a chain as long as the program, and recursion that deep can overflow
    /// the thread stack. Must only be called after `validate_edges` has succeeded,
    /// because edge endpoints are used as indices without re-checking.
    fn find_cycle(&self) -> Option<Vec<u32>> {
        #[derive(Clone, Copy, PartialEq, Eq)]
        enum Mark {
            Unvisited,
            OnPath,
            Done,
        }

        let len = self.nodes.len();
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); len];
        for edge in &self.edges {
            adj[edge.from as usize].push(edge.to as usize);
        }

        let mut mark = vec![Mark::Unvisited; len];
        // DFS frames: (node index, position of the next outgoing edge to follow).
        // The frames on the stack are exactly the nodes on the current DFS path.
        let mut stack: Vec<(usize, usize)> = Vec::new();
        for root in 0..len {
            if mark[root] != Mark::Unvisited {
                continue;
            }
            mark[root] = Mark::OnPath;
            stack.push((root, 0));
            while let Some(&(node, cursor)) = stack.last() {
                match adj[node].get(cursor) {
                    None => {
                        // All successors explored: retire the node.
                        mark[node] = Mark::Done;
                        stack.pop();
                    }
                    Some(&next) => {
                        let top = stack.len() - 1;
                        stack[top].1 = cursor + 1;
                        match mark[next] {
                            Mark::OnPath => {
                                // Back edge: the cycle is the path suffix starting at `next`.
                                let start = stack
                                    .iter()
                                    .position(|&(n, _)| n == next)
                                    .unwrap_or(0);
                                // Every node in that suffix was reached as an edge target,
                                // which is a `u32`, so the cast is lossless.
                                #[allow(clippy::cast_possible_truncation)]
                                let cycle: Vec<u32> =
                                    stack[start..].iter().map(|&(n, _)| n as u32).collect();
                                return Some(cycle);
                            }
                            Mark::Unvisited => {
                                mark[next] = Mark::OnPath;
                                stack.push((next, 0));
                            }
                            Mark::Done => {}
                        }
                    }
                }
            }
        }
        None
    }
}
/// Returns the [`NodeGraph`] view of `program`, leaving `program` untouched.
///
/// Free-function form of [`NodeGraph::from_program`]: the program is cloned and
/// converted with [`NodeGraph::from_program_owned`].
#[must_use]
pub fn to_graph(program: &Program) -> NodeGraph {
    NodeGraph::from_program_owned(program.clone())
}
pub fn from_graph(graph: NodeGraph) -> Result<Program, GraphValidateError> {
graph.try_into_program()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Two-statement program: bind `x = 42`, then store `x` into `out[0]`.
    fn trivial() -> Program {
        Program::wrapped(
            vec![BufferDecl::read_write("out", 0, DataType::U32).with_count(1)],
            [1, 1, 1],
            vec![
                Node::let_bind("x", Expr::u32(42)),
                Node::store("out", Expr::u32(0), Expr::var("x")),
            ],
        )
    }

    /// `n` barrier nodes whose ids equal their indices.
    fn barriers(n: u32) -> Vec<GraphNode> {
        (0..n)
            .map(|id| GraphNode::new(id, DataflowKind::Barrier))
            .collect()
    }

    #[test]
    fn graph_view_mirrors_top_level_nodes() {
        let p = trivial();
        let g = to_graph(&p);
        assert_eq!(g.nodes.len(), p.entry().len());
        assert_eq!(g.workgroup_size, p.workgroup_size());
    }

    #[test]
    fn graph_edges_are_ordering_in_sequence() {
        let p = trivial();
        let g = to_graph(&p);
        assert_eq!(g.edges.len(), g.nodes.len() - 1);
        for (i, e) in g.edges.iter().enumerate() {
            assert_eq!(e.from, i as u32);
            assert_eq!(e.to, (i + 1) as u32);
            assert!(matches!(e.kind, EdgeKind::Ordering));
        }
    }

    #[test]
    fn round_trip_is_byte_identical_under_canonicalize() {
        let p = trivial();
        let g = to_graph(&p);
        let p2 = from_graph(g).unwrap();
        let p_c = crate::optimizer::passes::algebraic::canonicalize_engine::run(p);
        let p2_c = crate::optimizer::passes::algebraic::canonicalize_engine::run(p2);
        assert_eq!(p_c.to_wire().unwrap(), p2_c.to_wire().unwrap());
    }

    #[test]
    fn phi_node_dropped_on_lowering() {
        let mut g = NodeGraph {
            workgroup_size: [1, 1, 1],
            ..Default::default()
        };
        g.buffers
            .push(BufferDecl::read_write("out", 0, DataType::U32).with_count(1));
        g.nodes.push(GraphNode {
            id: 0,
            kind: DataflowKind::Statement(Node::store("out", Expr::u32(0), Expr::u32(1))),
        });
        g.nodes.push(GraphNode {
            id: 1,
            kind: DataflowKind::Phi(vec![0]),
        });
        let p = from_graph(g).unwrap();
        assert_eq!(
            p.entry().len(),
            1,
            "Phi must not round-trip to statement-IR"
        );
    }

    #[test]
    fn dangling_edge_is_rejected() {
        let g = NodeGraph::new(barriers(1), vec![DataEdge::new(0, 5, EdgeKind::Ordering)]);
        let err = from_graph(g)
            .err()
            .expect("an edge to a missing node must be rejected");
        assert_eq!(err, GraphValidateError::DanglingEdge { from: 0, to: 5 });
    }

    #[test]
    fn phi_without_predecessors_is_rejected() {
        let mut nodes = barriers(1);
        nodes.push(GraphNode::new(1, DataflowKind::Phi(Vec::new())));
        let err = from_graph(NodeGraph::new(nodes, Vec::new()))
            .err()
            .expect("an empty Phi must be rejected");
        assert_eq!(err, GraphValidateError::OrphanPhi { node_id: 1 });
    }

    #[test]
    fn phi_with_missing_predecessor_is_rejected() {
        let mut nodes = barriers(1);
        nodes.push(GraphNode::new(1, DataflowKind::Phi(vec![0, 9])));
        let err = from_graph(NodeGraph::new(nodes, Vec::new()))
            .err()
            .expect("a Phi naming a missing node must be rejected");
        assert_eq!(err, GraphValidateError::OrphanPhi { node_id: 1 });
    }

    #[test]
    fn cycle_is_reported_with_its_path() {
        let edges = vec![
            DataEdge::new(0, 1, EdgeKind::Control),
            DataEdge::new(1, 2, EdgeKind::Control),
            DataEdge::new(2, 0, EdgeKind::Control),
        ];
        let err = from_graph(NodeGraph::new(barriers(3), edges))
            .err()
            .expect("a cyclic graph must be rejected");
        assert_eq!(err, GraphValidateError::Cycle { path: vec![0, 1, 2] });
    }

    #[test]
    fn self_loop_is_a_cycle() {
        let edges = vec![DataEdge::new(0, 0, EdgeKind::Ordering)];
        let err = from_graph(NodeGraph::new(barriers(1), edges))
            .err()
            .expect("a self-loop must be rejected");
        assert_eq!(err, GraphValidateError::Cycle { path: vec![0] });
    }

    #[test]
    fn acyclic_diamond_is_accepted() {
        let edges = vec![
            DataEdge::new(0, 1, EdgeKind::Ordering),
            DataEdge::new(0, 2, EdgeKind::Ordering),
            DataEdge::new(1, 3, EdgeKind::Ordering),
            DataEdge::new(2, 3, EdgeKind::Ordering),
        ];
        let mut g = NodeGraph::new(barriers(4), edges);
        g.buffers
            .push(BufferDecl::read_write("out", 0, DataType::U32).with_count(1));
        assert!(
            from_graph(g).is_ok(),
            "joining paths without a back edge are not a cycle"
        );
    }
}