use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub use daggy;
pub use petgraph;
use daggy::petgraph::visit::Topo;
use daggy::{Dag, NodeIndex, Walker};
/// A single node stored in an [`ExecutionGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionNode {
/// Identifier used to look the node up; `ExecutionGraph::add_node` rejects duplicates.
pub id: String,
/// Free-form kind tag; `GraphBuilder` writes `"thinktool"` or `"transform"` here.
pub node_type: String,
/// Arbitrary JSON configuration attached to the node.
pub config: serde_json::Value,
/// Current lifecycle state of the node.
pub state: NodeState,
}
/// Lifecycle state of an [`ExecutionNode`].
///
/// `ExecutionGraph::is_complete` treats `Completed`, `Skipped` and `Failed` as
/// finished, while `ExecutionGraph::ready_nodes` only reports `Pending` nodes
/// whose parents are all `Completed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
/// Not started yet; the state `GraphBuilder` assigns to every node it creates.
Pending,
/// NOTE(review): presumably set by an executor while the node is in progress;
/// nothing visible in this file sets or reads it — confirm against the caller.
Running,
/// Done; the only parent state that lets a child be reported by `ready_nodes`.
Completed,
/// Counted as finished by `is_complete`, but does not unblock children in `ready_nodes`.
Failed,
/// Counted as finished by `is_complete`, but does not unblock children in `ready_nodes`.
Skipped,
}
/// Edge weight of an [`ExecutionGraph`]: names which output of the source node
/// is connected to which input of the target node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataFlow {
/// Name of the output on the source node (`GraphBuilder::edge` uses `"output"`).
pub source_output: String,
/// Name of the input on the target node (`GraphBuilder::edge` uses `"input"`).
pub target_input: String,
}
/// Directed acyclic graph of [`ExecutionNode`]s connected by [`DataFlow`] edges.
///
/// Nodes are addressed by their string `id`; `node_map` translates an ID into
/// the index of the same node inside `dag`. Both fields are private so the two
/// can only be changed together through the methods of this type.
#[derive(Debug)]
pub struct ExecutionGraph {
    /// Underlying DAG holding node and edge weights.
    dag: Dag<ExecutionNode, DataFlow>,
    /// Node ID -> index of that node in `dag`.
    node_map: HashMap<String, NodeIndex>,
}
impl ExecutionGraph {
pub fn new() -> Self {
Self {
dag: Dag::new(),
node_map: HashMap::new(),
}
}
pub fn add_node(&mut self, node: ExecutionNode) -> Result<NodeIndex> {
let id = node.id.clone();
if self.node_map.contains_key(&id) {
anyhow::bail!("Node with ID '{}' already exists", id);
}
let idx = self.dag.add_node(node);
self.node_map.insert(id, idx);
Ok(idx)
}
pub fn add_edge(&mut self, from_id: &str, to_id: &str, data_flow: DataFlow) -> Result<()> {
let from_idx = self
.node_map
.get(from_id)
.ok_or_else(|| anyhow::anyhow!("Source node '{}' not found", from_id))?;
let to_idx = self
.node_map
.get(to_id)
.ok_or_else(|| anyhow::anyhow!("Target node '{}' not found", to_id))?;
self.dag
.add_edge(*from_idx, *to_idx, data_flow)
.map_err(|_| anyhow::anyhow!("Adding edge would create a cycle"))?;
Ok(())
}
pub fn execution_order(&self) -> Vec<&ExecutionNode> {
let mut order = Vec::new();
let mut topo = Topo::new(self.dag.graph());
while let Some(idx) = topo.next(self.dag.graph()) {
if let Some(node) = self.dag.node_weight(idx) {
order.push(node);
}
}
order
}
pub fn ready_nodes(&self) -> Vec<&ExecutionNode> {
let mut ready = Vec::new();
for (idx, node) in self
.dag
.graph()
.node_indices()
.zip(self.dag.raw_nodes().iter())
{
if node.weight.state != NodeState::Pending {
continue;
}
let all_parents_done = self
.dag
.parents(idx)
.iter(&self.dag)
.all(|(_, parent_idx)| {
self.dag
.node_weight(parent_idx)
.map(|p| p.state == NodeState::Completed)
.unwrap_or(false)
});
if all_parents_done {
ready.push(&node.weight);
}
}
ready
}
pub fn set_state(&mut self, node_id: &str, state: NodeState) -> Result<()> {
let idx = self
.node_map
.get(node_id)
.ok_or_else(|| anyhow::anyhow!("Node '{}' not found", node_id))?;
if let Some(node) = self.dag.node_weight_mut(*idx) {
node.state = state;
}
Ok(())
}
pub fn get_node(&self, node_id: &str) -> Option<&ExecutionNode> {
self.node_map
.get(node_id)
.and_then(|idx| self.dag.node_weight(*idx))
}
pub fn get_node_mut(&mut self, node_id: &str) -> Option<&mut ExecutionNode> {
self.node_map
.get(node_id)
.copied()
.and_then(move |idx| self.dag.node_weight_mut(idx))
}
pub fn dependencies(&self, node_id: &str) -> Vec<&ExecutionNode> {
let Some(idx) = self.node_map.get(node_id) else {
return Vec::new();
};
self.dag
.parents(*idx)
.iter(&self.dag)
.filter_map(|(_, parent_idx)| self.dag.node_weight(parent_idx))
.collect()
}
pub fn dependents(&self, node_id: &str) -> Vec<&ExecutionNode> {
let Some(idx) = self.node_map.get(node_id) else {
return Vec::new();
};
self.dag
.children(*idx)
.iter(&self.dag)
.filter_map(|(_, child_idx)| self.dag.node_weight(child_idx))
.collect()
}
pub fn node_count(&self) -> usize {
self.dag.node_count()
}
pub fn edge_count(&self) -> usize {
self.dag.edge_count()
}
pub fn is_empty(&self) -> bool {
self.dag.node_count() == 0
}
pub fn is_complete(&self) -> bool {
self.dag.raw_nodes().iter().all(|n| {
matches!(
n.weight.state,
NodeState::Completed | NodeState::Skipped | NodeState::Failed
)
})
}
}
impl Default for ExecutionGraph {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for an [`ExecutionGraph`].
///
/// Builder methods never fail directly so that calls can be chained. Every
/// operation the graph rejects (duplicate node ID, unknown edge endpoint, edge
/// that would create a cycle) is skipped and its error message is recorded
/// instead of being dropped. Inspect the messages with
/// [`GraphBuilder::errors`], or finish with [`GraphBuilder::try_build`] to
/// turn any recorded failure into an `Err`.
pub struct GraphBuilder {
    /// Graph assembled so far; contains only the operations that succeeded.
    graph: ExecutionGraph,
    /// Messages of the rejected operations, in call order.
    errors: Vec<String>,
}

impl GraphBuilder {
    /// Starts a builder with an empty graph and no recorded errors.
    pub fn new() -> Self {
        Self {
            graph: ExecutionGraph::new(),
            errors: Vec::new(),
        }
    }

    /// Adds a pending node of type `"thinktool"` whose config is
    /// `{ "tool": tool_name }`.
    ///
    /// A duplicate `id` is recorded as an error and the node is not added.
    pub fn thinktool(self, id: impl Into<String>, tool_name: impl Into<String>) -> Self {
        self.push_node(
            id.into(),
            "thinktool",
            serde_json::json!({ "tool": tool_name.into() }),
        )
    }

    /// Adds a pending node of type `"transform"` whose config is
    /// `{ "type": transform_type }`.
    ///
    /// A duplicate `id` is recorded as an error and the node is not added.
    pub fn transform(self, id: impl Into<String>, transform_type: impl Into<String>) -> Self {
        self.push_node(
            id.into(),
            "transform",
            serde_json::json!({ "type": transform_type.into() }),
        )
    }

    /// Connects `from` to `to` using the default names `"output"` and `"input"`.
    ///
    /// Failures are recorded exactly as for [`GraphBuilder::edge_named`].
    pub fn edge(self, from: &str, to: &str) -> Self {
        self.edge_named(from, "output", to, "input")
    }

    /// Connects output `output` of node `from` to input `input` of node `to`.
    ///
    /// If either node is unknown or the edge would create a cycle, the edge is
    /// not added and the error message is recorded.
    pub fn edge_named(mut self, from: &str, output: &str, to: &str, input: &str) -> Self {
        let flow = DataFlow {
            source_output: output.to_string(),
            target_input: input.to_string(),
        };
        if let Err(err) = self.graph.add_edge(from, to, flow) {
            self.errors.push(err.to_string());
        }
        self
    }

    /// Messages of all operations rejected so far, in call order.
    pub fn errors(&self) -> &[String] {
        &self.errors
    }

    /// Returns the graph built from every operation that succeeded.
    ///
    /// Recorded errors are discarded; use [`GraphBuilder::try_build`] when a
    /// rejected operation should abort construction.
    pub fn build(self) -> ExecutionGraph {
        self.graph
    }

    /// Returns the graph only if no operation was rejected.
    ///
    /// # Errors
    ///
    /// Returns the recorded error messages, in call order, if at least one
    /// node or edge could not be added.
    pub fn try_build(self) -> std::result::Result<ExecutionGraph, Vec<String>> {
        if self.errors.is_empty() {
            Ok(self.graph)
        } else {
            Err(self.errors)
        }
    }

    /// Adds a pending node with the given type tag and config, recording the
    /// error message if the graph rejects it.
    fn push_node(mut self, id: String, node_type: &str, config: serde_json::Value) -> Self {
        let added = self.graph.add_node(ExecutionNode {
            id,
            node_type: node_type.to_string(),
            config,
            state: NodeState::Pending,
        });
        if let Err(err) = added {
            self.errors.push(err.to_string());
        }
        self
    }
}
impl Default for GraphBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fan-in graph `a -> c <- b` made of three thinktool nodes.
    fn fan_in_graph() -> ExecutionGraph {
        GraphBuilder::new()
            .thinktool("a", "ToolA")
            .thinktool("b", "ToolB")
            .thinktool("c", "ToolC")
            .edge("a", "c")
            .edge("b", "c")
            .build()
    }

    /// Pending node of type `"test"` with an empty config.
    fn test_node(id: &str) -> ExecutionNode {
        ExecutionNode {
            id: id.to_string(),
            node_type: "test".to_string(),
            config: serde_json::json!({}),
            state: NodeState::Pending,
        }
    }

    /// Edge weight connecting `"out"` to `"in"`.
    fn out_to_in() -> DataFlow {
        DataFlow {
            source_output: "out".to_string(),
            target_input: "in".to_string(),
        }
    }

    /// IDs of the nodes currently reported as ready, in the order returned.
    fn ready_ids(graph: &ExecutionGraph) -> Vec<&str> {
        graph
            .ready_nodes()
            .into_iter()
            .map(|node| node.id.as_str())
            .collect()
    }

    #[test]
    fn test_graph_construction() {
        let builder = GraphBuilder::new()
            .thinktool("gigathink", "GigaThink")
            .thinktool("laserlogic", "LaserLogic")
            .transform("merge", "combine");
        let graph = builder
            .edge("gigathink", "merge")
            .edge("laserlogic", "merge")
            .build();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_execution_order() {
        let graph = fan_in_graph();
        let order = graph.execution_order();
        assert_eq!(order.len(), 3);

        // Position of a node in the topological order; panics if it is missing.
        let position_of = |id: &str| order.iter().position(|n| n.id == id).unwrap();
        let c_pos = position_of("c");
        assert!(c_pos > position_of("a"));
        assert!(c_pos > position_of("b"));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = ExecutionGraph::new();
        for id in ["a", "b"] {
            graph.add_node(test_node(id)).unwrap();
        }
        graph.add_edge("a", "b", out_to_in()).unwrap();

        // Closing the loop b -> a must be rejected.
        let result = graph.add_edge("b", "a", out_to_in());
        assert!(result.is_err());
    }

    #[test]
    fn test_ready_nodes() {
        let mut graph = fan_in_graph();

        // Both roots are ready; `c` still waits for its parents.
        assert_eq!(ready_ids(&graph).len(), 2);

        graph.set_state("a", NodeState::Completed).unwrap();
        assert_eq!(ready_ids(&graph), ["b"]);

        graph.set_state("b", NodeState::Completed).unwrap();
        assert_eq!(ready_ids(&graph), ["c"]);
    }
}