use std::collections::{HashMap, HashSet, VecDeque};
use super::parser::{ParsedEdge, ParsedTopology, ReactionCriteria};
/// Intermediate representation of a computation graph, built from a parsed
/// topology by `GraphIR::from_parsed`.
#[derive(Debug)]
pub struct GraphIR {
/// Reaction criteria, carried over unchanged from the parsed topology.
pub react: ReactionCriteria,
/// Names of every node in topological order (predecessors before successors).
pub sorted_nodes: Vec<String>,
/// All nodes of the graph, keyed by node name.
pub nodes: HashMap<String, GraphNode>,
}
/// A single node of the computation graph together with its edges.
#[derive(Debug, Clone)]
pub struct GraphNode {
/// Node name as written in the topology.
pub name: String,
/// Inputs listed on the first edge leaving this node that declares a non-empty
/// input list; empty if no such edge exists. NOTE(review): presumably names of
/// accumulators/caches the node reads — confirm against codegen.
pub cache_inputs: Vec<String>,
/// Outgoing edges, in the order they appear in the topology.
pub edges_out: Vec<GraphEdge>,
/// Incoming edges, in the order they appear in the topology (one entry per
/// routing variant that targets this node).
pub edges_in: Vec<IncomingEdge>,
/// `true` when the node has no outgoing edges.
pub is_terminal: bool,
}
/// An outgoing edge of a [`GraphNode`].
#[derive(Debug, Clone)]
pub enum GraphEdge {
/// Edge to exactly one `target` node.
Linear { target: String },
/// Edge built from a parsed routing edge: one target node per named variant.
Routing { variants: Vec<GraphRoutingVariant> },
}
/// One `variant -> target` arm of a [`GraphEdge::Routing`] edge.
///
/// Derives equality and hashing so edges can be compared in tests and
/// collected into sets/maps by callers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GraphRoutingVariant {
    /// Variant name as written in the topology (for example `Signal`).
    pub variant_name: String,
    /// Name of the node this variant routes to.
    pub target: String,
}
/// An edge arriving at a node, as seen from the receiving node.
///
/// Derives equality and hashing so incoming edges can be compared directly
/// (for example `assert_eq!` against an expected edge) and used as set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IncomingEdge {
    /// Name of the node the edge comes from.
    pub from: String,
    /// `Some(variant name)` when the edge comes from a routing edge,
    /// `None` for a linear edge.
    pub variant: Option<String>,
}
/// Errors that can occur while building or validating a [`GraphIR`].
#[derive(Debug, thiserror::Error)]
pub enum GraphIRError {
/// The edges contain a cycle; the payload lists the nodes that could not be
/// placed in topological order.
#[error("cycle detected in graph: {0}")]
Cycle(String),
/// A node name does not correspond to a function in the module.
/// NOTE(review): not constructed by the graph-building code in this file;
/// presumably raised by a caller that knows the module's functions — confirm.
#[error("node '{0}' referenced in graph but not defined as a function in the module")]
DanglingReference(String),
/// A node has more than one edge to the same target.
/// NOTE(review): not constructed in this file — `GraphIR::from_parsed` records
/// duplicate edges instead of rejecting them; confirm where this is checked.
#[error("duplicate edge: node '{from}' has multiple edges to '{to}'")]
DuplicateEdge { from: String, to: String },
}
impl GraphIR {
    /// Lowers a parsed topology into a graph IR.
    ///
    /// Every node named anywhere in `parsed.edges` (as a source or as a target)
    /// becomes a [`GraphNode`]. For each node:
    ///
    /// * `cache_inputs` come from the first edge leaving that node whose
    ///   `from_inputs` list is non-empty; later edges' inputs are ignored;
    /// * `edges_out` and `edges_in` are recorded in the order the edges appear
    ///   in the topology. A routing edge adds one incoming edge per variant,
    ///   tagged with that variant's name;
    /// * `is_terminal` is `true` exactly when the node has no outgoing edges.
    ///
    /// Duplicate edges are not rejected here; every occurrence is recorded.
    /// The nodes are then ordered topologically into `sorted_nodes`.
    ///
    /// # Errors
    ///
    /// Returns [`GraphIRError::Cycle`] if the edges contain a cycle.
    pub fn from_parsed(parsed: ParsedTopology) -> Result<Self, GraphIRError> {
        let mut nodes: HashMap<String, GraphNode> = HashMap::new();
        for edge in &parsed.edges {
            let from_name = edge.from_name().to_string();

            let source = node_entry(&mut nodes, &from_name);
            // Only the first edge that lists inputs for a node defines them.
            if source.cache_inputs.is_empty() {
                source.cache_inputs = edge.from_inputs().iter().map(|i| i.to_string()).collect();
            }

            match edge {
                ParsedEdge::Linear { to, .. } => {
                    let to_name = to.to_string();
                    node_entry(&mut nodes, &to_name).edges_in.push(IncomingEdge {
                        from: from_name.clone(),
                        variant: None,
                    });
                    node_entry(&mut nodes, &from_name)
                        .edges_out
                        .push(GraphEdge::Linear { target: to_name });
                }
                ParsedEdge::Routing { variants, .. } => {
                    let graph_variants: Vec<GraphRoutingVariant> = variants
                        .iter()
                        .map(|v| GraphRoutingVariant {
                            variant_name: v.variant_name.to_string(),
                            target: v.target.to_string(),
                        })
                        .collect();
                    for v in &graph_variants {
                        node_entry(&mut nodes, &v.target).edges_in.push(IncomingEdge {
                            from: from_name.clone(),
                            variant: Some(v.variant_name.clone()),
                        });
                    }
                    node_entry(&mut nodes, &from_name)
                        .edges_out
                        .push(GraphEdge::Routing { variants: graph_variants });
                }
            }
        }

        for node in nodes.values_mut() {
            node.is_terminal = node.edges_out.is_empty();
        }

        let sorted_nodes = topological_sort(&nodes)?;
        Ok(GraphIR {
            react: parsed.react,
            sorted_nodes,
            nodes,
        })
    }

    /// Returns the nodes with no outgoing edges, sorted by name.
    ///
    /// Sorting makes the result independent of `HashMap` iteration order, so
    /// code generated from it is reproducible.
    pub fn terminal_nodes(&self) -> Vec<&GraphNode> {
        sorted_by_name(self.nodes.values().filter(|n| n.is_terminal).collect())
    }

    /// Returns the nodes with no incoming edges, sorted by name.
    ///
    /// Sorting makes the result independent of `HashMap` iteration order.
    pub fn entry_nodes(&self) -> Vec<&GraphNode> {
        sorted_by_name(self.nodes.values().filter(|n| n.edges_in.is_empty()).collect())
    }

    /// Looks up a node by name.
    pub fn get_node(&self, name: &str) -> Option<&GraphNode> {
        self.nodes.get(name)
    }

    /// Returns the incoming edges of node `name`, or an empty list when no such
    /// node exists.
    pub fn incoming_sources(&self, name: &str) -> Vec<&IncomingEdge> {
        self.nodes
            .get(name)
            .map(|n| n.edges_in.iter().collect())
            .unwrap_or_default()
    }
}

/// Returns the node called `name`, first inserting an edge-less, non-terminal
/// node with no cache inputs if it is not present yet.
fn node_entry<'a>(nodes: &'a mut HashMap<String, GraphNode>, name: &str) -> &'a mut GraphNode {
    nodes.entry(name.to_string()).or_insert_with(|| GraphNode {
        name: name.to_string(),
        cache_inputs: Vec::new(),
        edges_out: Vec::new(),
        edges_in: Vec::new(),
        is_terminal: false,
    })
}

/// Sorts node references by node name (names are unique map keys).
fn sorted_by_name(mut nodes: Vec<&GraphNode>) -> Vec<&GraphNode> {
    nodes.sort_unstable_by(|a, b| a.name.cmp(&b.name));
    nodes
}
fn topological_sort(nodes: &HashMap<String, GraphNode>) -> Result<Vec<String>, GraphIRError> {
let mut in_degree: HashMap<String, usize> = HashMap::new();
for node in nodes.values() {
in_degree.entry(node.name.clone()).or_insert(0);
for edge in &node.edges_out {
match edge {
GraphEdge::Linear { target } => {
*in_degree.entry(target.clone()).or_insert(0) += 1;
}
GraphEdge::Routing { variants } => {
for v in variants {
*in_degree.entry(v.target.clone()).or_insert(0) += 1;
}
}
}
}
}
let mut queue: VecDeque<String> = in_degree
.iter()
.filter(|(_, °)| deg == 0)
.map(|(name, _)| name.clone())
.collect();
let mut sorted_queue: Vec<String> = queue.drain(..).collect();
sorted_queue.sort();
queue.extend(sorted_queue);
let mut sorted = Vec::new();
while let Some(name) = queue.pop_front() {
sorted.push(name.clone());
if let Some(node) = nodes.get(&name) {
let mut next_candidates = Vec::new();
for edge in &node.edges_out {
match edge {
GraphEdge::Linear { target } => {
if let Some(deg) = in_degree.get_mut(target) {
*deg -= 1;
if *deg == 0 {
next_candidates.push(target.clone());
}
}
}
GraphEdge::Routing { variants } => {
for v in variants {
if let Some(deg) = in_degree.get_mut(&v.target) {
*deg -= 1;
if *deg == 0 {
next_candidates.push(v.target.clone());
}
}
}
}
}
}
next_candidates.sort();
queue.extend(next_candidates);
}
}
if sorted.len() != nodes.len() {
let remaining: Vec<String> = nodes
.keys()
.filter(|k| !sorted.contains(k))
.cloned()
.collect();
return Err(GraphIRError::Cycle(format!(
"nodes involved in cycle: {}",
remaining.join(", ")
)));
}
Ok(sorted)
}
// Unit tests for building a `GraphIR` from hand-constructed `ParsedTopology` values.
#[cfg(test)]
mod tests {
use super::*;
use crate::computation_graph::parser::{ReactionMode, RoutingVariant};
use syn::Ident;
// Builds a `syn::Ident` with a call-site span for use in test fixtures.
fn ident(name: &str) -> Ident {
Ident::new(name, proc_macro2::Span::call_site())
}
// Wraps `edges` in a topology whose reaction criteria are `WhenAny` over the
// single accumulator `alpha`.
fn make_topology(edges: Vec<ParsedEdge>) -> ParsedTopology {
ParsedTopology {
react: ReactionCriteria {
mode: ReactionMode::WhenAny,
accumulators: vec![ident("alpha")],
},
edges,
}
}
// a -> b -> c: sorted order follows the chain and only `c` is terminal.
#[test]
fn test_linear_chain() {
let topology = make_topology(vec![
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![ident("alpha")],
to: ident("b"),
},
ParsedEdge::Linear {
from: ident("b"),
from_inputs: vec![],
to: ident("c"),
},
]);
let ir = GraphIR::from_parsed(topology).unwrap();
assert_eq!(ir.sorted_nodes, vec!["a", "b", "c"]);
assert!(ir.get_node("c").unwrap().is_terminal);
assert!(!ir.get_node("a").unwrap().is_terminal);
assert!(!ir.get_node("b").unwrap().is_terminal);
}
// A routing edge with two variants: the router sorts first and both targets are terminal.
#[test]
fn test_routing() {
let topology = make_topology(vec![ParsedEdge::Routing {
from: ident("decision"),
from_inputs: vec![ident("alpha")],
variants: vec![
RoutingVariant {
variant_name: ident("Signal"),
target: ident("handler_a"),
},
RoutingVariant {
variant_name: ident("NoAction"),
target: ident("handler_b"),
},
],
}]);
let ir = GraphIR::from_parsed(topology).unwrap();
assert_eq!(ir.sorted_nodes[0], "decision");
assert!(ir.get_node("handler_a").unwrap().is_terminal);
assert!(ir.get_node("handler_b").unwrap().is_terminal);
assert_eq!(ir.terminal_nodes().len(), 2);
assert_eq!(ir.entry_nodes().len(), 1);
}
// a -> {b, c} -> d: `d` sorts after both branches and has two incoming edges.
#[test]
fn test_diamond_graph() {
let topology = make_topology(vec![
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![ident("alpha")],
to: ident("b"),
},
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![],
to: ident("c"),
},
ParsedEdge::Linear {
from: ident("b"),
from_inputs: vec![],
to: ident("d"),
},
ParsedEdge::Linear {
from: ident("c"),
from_inputs: vec![],
to: ident("d"),
},
]);
let ir = GraphIR::from_parsed(topology).unwrap();
assert_eq!(ir.sorted_nodes[0], "a");
assert_eq!(*ir.sorted_nodes.last().unwrap(), "d");
let b_pos = ir.sorted_nodes.iter().position(|n| n == "b").unwrap();
let c_pos = ir.sorted_nodes.iter().position(|n| n == "c").unwrap();
let d_pos = ir.sorted_nodes.iter().position(|n| n == "d").unwrap();
assert!(b_pos < d_pos);
assert!(c_pos < d_pos);
assert_eq!(ir.get_node("d").unwrap().edges_in.len(), 2);
assert!(ir.get_node("d").unwrap().is_terminal);
}
// a -> b -> a must be rejected with an error mentioning "cycle".
#[test]
fn test_cycle_detection() {
let topology = make_topology(vec![
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![ident("alpha")],
to: ident("b"),
},
ParsedEdge::Linear {
from: ident("b"),
from_inputs: vec![],
to: ident("a"),
},
]);
let result = GraphIR::from_parsed(topology);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("cycle"), "got: {}", err);
}
// a -> b and a -> c: both leaves are reported as terminal nodes.
#[test]
fn test_terminal_nodes() {
let topology = make_topology(vec![
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![ident("alpha")],
to: ident("b"),
},
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![],
to: ident("c"),
},
]);
let ir = GraphIR::from_parsed(topology).unwrap();
let terminals: HashSet<String> =
ir.terminal_nodes().iter().map(|n| n.name.clone()).collect();
assert_eq!(terminals.len(), 2);
assert!(terminals.contains("b"));
assert!(terminals.contains("c"));
}
// a -> c and b -> c: both sources are reported as entry nodes.
#[test]
fn test_entry_nodes() {
let topology = make_topology(vec![
ParsedEdge::Linear {
from: ident("a"),
from_inputs: vec![ident("alpha")],
to: ident("c"),
},
ParsedEdge::Linear {
from: ident("b"),
from_inputs: vec![],
to: ident("c"),
},
]);
let ir = GraphIR::from_parsed(topology).unwrap();
let entries: HashSet<String> = ir.entry_nodes().iter().map(|n| n.name.clone()).collect();
assert_eq!(entries.len(), 2);
assert!(entries.contains("a"));
assert!(entries.contains("b"));
}
// The `from_inputs` of an edge are stored, in order, as the source node's cache inputs.
#[test]
fn test_cache_inputs_preserved() {
let topology = make_topology(vec![ParsedEdge::Linear {
from: ident("entry"),
from_inputs: vec![ident("alpha"), ident("beta"), ident("gamma")],
to: ident("output"),
}]);
let ir = GraphIR::from_parsed(topology).unwrap();
let entry = ir.get_node("entry").unwrap();
assert_eq!(entry.cache_inputs, vec!["alpha", "beta", "gamma"]);
}
// A routing target's incoming edge records both the source node and the variant name.
#[test]
fn test_incoming_edges_with_variants() {
let topology = make_topology(vec![ParsedEdge::Routing {
from: ident("decision"),
from_inputs: vec![ident("alpha")],
variants: vec![RoutingVariant {
variant_name: ident("Signal"),
target: ident("handler"),
}],
}]);
let ir = GraphIR::from_parsed(topology).unwrap();
let handler = ir.get_node("handler").unwrap();
assert_eq!(handler.edges_in.len(), 1);
assert_eq!(handler.edges_in[0].from, "decision");
assert_eq!(handler.edges_in[0].variant.as_deref(), Some("Signal"));
}
// A routing edge followed by a linear edge from one of its targets: terminal flags
// reflect the combined topology.
#[test]
fn test_mixed_routing_and_linear() {
let topology = make_topology(vec![
ParsedEdge::Routing {
from: ident("decision"),
from_inputs: vec![ident("alpha")],
variants: vec![
RoutingVariant {
variant_name: ident("Signal"),
target: ident("risk_check"),
},
RoutingVariant {
variant_name: ident("NoAction"),
target: ident("audit"),
},
],
},
ParsedEdge::Linear {
from: ident("risk_check"),
from_inputs: vec![],
to: ident("output"),
},
]);
let ir = GraphIR::from_parsed(topology).unwrap();
assert_eq!(ir.sorted_nodes[0], "decision");
assert!(ir.get_node("audit").unwrap().is_terminal);
assert!(ir.get_node("output").unwrap().is_terminal);
assert!(!ir.get_node("risk_check").unwrap().is_terminal);
}
}