use async_trait::async_trait;
use petgraph::stable_graph::{NodeIndex, StableGraph};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use crate::error::{GraphError, NodeError};
use crate::state::{AgentState, SharedState};
/// Reserved transition / pseudo-node names used when wiring graphs.
///
/// `END` is also the id of the synthetic terminal node that
/// `GraphBuilder::compile` adds to every compiled graph.
pub mod transitions {
    /// Name for the graph's entry pseudo-node.
    pub const START: &str = "__start__";
    /// Name of the terminal node that every compiled graph contains.
    pub const END: &str = "__end__";
    /// Transition name meaning "keep going".
    pub const CONTINUE: &str = "__continue__";
    /// Transition name meaning "stop".
    pub const FINISH: &str = "__finish__";
}
/// What a node reports back after it executes.
#[derive(Clone, Debug)]
pub enum NodeOutput {
    /// Keep going. `Some(id)` carries an explicit next-node id (exposed via
    /// [`NodeOutput::target`]); `None` carries no target.
    Continue(Option<String>),
    /// Terminal outcome; the only variant for which `is_terminal` is `true`.
    Finish,
    /// Go to the named node (exposed via [`NodeOutput::target`]).
    Route(String),
}

impl NodeOutput {
    /// `Continue` with no explicit target.
    pub fn cont() -> Self {
        Self::Continue(None)
    }

    /// `Continue` naming the next node explicitly.
    pub fn continue_to(next: impl Into<String>) -> Self {
        let next_id: String = next.into();
        Self::Continue(Some(next_id))
    }

    /// The terminal outcome.
    pub fn finish() -> Self {
        Self::Finish
    }

    /// Alias of [`NodeOutput::finish`].
    pub fn end() -> Self {
        Self::finish()
    }

    /// `Route` to the given node id.
    pub fn route(target: impl Into<String>) -> Self {
        let target_id: String = target.into();
        Self::Route(target_id)
    }

    /// Returns `true` only for [`NodeOutput::Finish`].
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Finish => true,
            Self::Continue(_) | Self::Route(_) => false,
        }
    }

    /// The explicit destination carried by this output, if any.
    ///
    /// `Continue(None)` and `Finish` have no destination and yield `None`.
    pub fn target(&self) -> Option<&str> {
        match self {
            Self::Continue(next) => next.as_deref(),
            Self::Route(dest) => Some(dest.as_str()),
            Self::Finish => None,
        }
    }
}
/// A unit of work that can be registered as a graph node.
///
/// Implementations must be `Send + Sync`; graphs hold them as
/// `Arc<dyn NodeExecutor>` (see `BoxedNodeExecutor`).
#[async_trait]
pub trait NodeExecutor: Send + Sync {
    /// Identifier of this node; `GraphNode::new` copies it into the node's id.
    fn id(&self) -> &str;

    /// Optional human-readable description. Defaults to `None`.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Runs the node against the shared state and reports what should happen next.
    async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError>;
}
/// Shared, type-erased handle to a node implementation.
pub type BoxedNodeExecutor = Arc<dyn NodeExecutor>;

/// A node registered in a graph: its id plus the executor that runs it.
#[derive(Clone)]
pub struct GraphNode {
    /// Identifier used for lookups and edge wiring.
    pub id: String,
    /// Implementation invoked when the node runs.
    pub executor: BoxedNodeExecutor,
}

impl GraphNode {
    /// Wraps `executor`, taking the node id from [`NodeExecutor::id`].
    pub fn new(executor: impl NodeExecutor + 'static) -> Self {
        let executor: BoxedNodeExecutor = Arc::new(executor);
        let id = executor.id().to_owned();
        Self { id, executor }
    }
}

impl fmt::Debug for GraphNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The executor is a trait object without `Debug`; show only the id.
        let mut out = f.debug_struct("GraphNode");
        out.field("id", &self.id);
        out.finish()
    }
}
/// How an edge chooses its destination.
#[derive(Clone)]
pub enum EdgeType {
    /// Always leads to the edge's fixed `to` node.
    Direct,
    /// Calls the router with the current state; the returned string is the next node id.
    Conditional(Arc<dyn Fn(&AgentState) -> String + Send + Sync>),
}

impl fmt::Debug for EdgeType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Router closures are not `Debug`, so a placeholder stands in for them.
        let label = match self {
            Self::Direct => "Direct",
            Self::Conditional(_) => "Conditional(<fn>)",
        };
        f.write_str(label)
    }
}
/// An edge leaving the node `from`.
///
/// Direct edges name their destination in `to`. Conditional edges leave `to`
/// empty and choose the destination at run time through the router stored in
/// `edge_type`.
#[derive(Clone)]
pub struct GraphEdge {
    pub from: String,
    pub to: String,
    pub edge_type: EdgeType,
}

impl GraphEdge {
    /// Edge that always leads from `from` to `to`.
    pub fn direct(from: impl Into<String>, to: impl Into<String>) -> Self {
        let from: String = from.into();
        let to: String = to.into();
        Self {
            from,
            to,
            edge_type: EdgeType::Direct,
        }
    }

    /// Edge whose destination is computed by `router` from the current state.
    pub fn conditional<F>(from: impl Into<String>, router: F) -> Self
    where
        F: Fn(&AgentState) -> String + Send + Sync + 'static,
    {
        let from: String = from.into();
        Self {
            from,
            // No static destination: the router decides at run time.
            to: String::new(),
            edge_type: EdgeType::Conditional(Arc::new(router)),
        }
    }
}

impl fmt::Debug for GraphEdge {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The router closure is not `Debug`; report only the kind of edge.
        let kind: &str = match &self.edge_type {
            EdgeType::Direct => "Direct",
            EdgeType::Conditional(_) => "Conditional",
        };
        f.debug_struct("GraphEdge")
            .field("from", &self.from)
            .field("to", &self.to)
            .field("edge_type", &kind)
            .finish()
    }
}
/// Fluent builder for [`CompiledGraph`].
///
/// Every method consumes the builder and returns it, so calls chain; finish
/// with [`GraphBuilder::compile`].
#[must_use = "GraphBuilder methods consume the builder; call `compile()` to obtain a graph"]
pub struct GraphBuilder {
    /// Registered nodes keyed by id; re-adding an id replaces the earlier node.
    nodes: HashMap<String, GraphNode>,
    /// Edges in registration order (the order `CompiledGraph::get_next_node` scans them).
    edges: Vec<GraphEdge>,
    /// Node where execution begins; required by `compile`.
    entry_point: Option<String>,
    /// Optional graph name, carried into the compiled graph.
    name: Option<String>,
    /// Optional graph description, carried into the compiled graph.
    description: Option<String>,
}

impl Default for GraphBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphBuilder {
    /// Creates an empty builder with no nodes, edges, or entry point.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: Vec::new(),
            entry_point: None,
            name: None,
            description: None,
        }
    }

    /// Sets the graph's name.
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the graph's description.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Registers a node under the id reported by `executor.id()`.
    ///
    /// A later node with the same id silently replaces an earlier one.
    pub fn add_node(mut self, executor: impl NodeExecutor + 'static) -> Self {
        let node = GraphNode::new(executor);
        self.nodes.insert(node.id.clone(), node);
        self
    }

    /// Sets the node execution starts from. It must be registered by `compile` time.
    pub fn set_entry_point(mut self, node_id: impl Into<String>) -> Self {
        self.entry_point = Some(node_id.into());
        self
    }

    /// Adds a direct edge `from -> to`. Both ids are validated by `compile`.
    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
        self.edges.push(GraphEdge::direct(from, to));
        self
    }

    /// Adds a direct edge from `from` to the terminal `transitions::END` node.
    pub fn add_edge_to_end(mut self, from: impl Into<String>) -> Self {
        self.edges.push(GraphEdge::direct(from, transitions::END));
        self
    }

    /// Adds an edge from `from` whose target is computed by `router` at run time.
    pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
    where
        F: Fn(&AgentState) -> String + Send + Sync + 'static,
    {
        self.edges.push(GraphEdge::conditional(from, router));
        self
    }

    /// Validates the builder and produces a [`CompiledGraph`].
    ///
    /// A synthetic `transitions::END` node is always added so edges may target it.
    /// Nodes are inserted in sorted id order, so `NodeIndex` assignment is the same
    /// on every run. `HashMap` iteration order alone is randomized per process.
    ///
    /// # Errors
    ///
    /// - `GraphError::NoEntryPoint` if `set_entry_point` was never called.
    /// - `GraphError::NodeNotFound` if the entry point, the source of any edge,
    ///   or the target of any direct edge is not a registered node (or `END`).
    ///
    /// Conditional edge targets are chosen at run time and are not validated here.
    ///
    /// NOTE(review): a user node whose id equals `transitions::END` is shadowed by
    /// the synthetic END node in the id lookup table; consider rejecting it.
    pub fn compile(self) -> Result<CompiledGraph, GraphError> {
        let entry = self.entry_point.ok_or(GraphError::NoEntryPoint)?;
        if !self.nodes.contains_key(&entry) {
            return Err(GraphError::NodeNotFound(entry));
        }

        let mut graph = StableGraph::new();
        let mut node_indices: HashMap<String, NodeIndex> =
            HashMap::with_capacity(self.nodes.len() + 1);

        // Deterministic insertion order -> reproducible NodeIndex values.
        let mut entries: Vec<(&String, &GraphNode)> = self.nodes.iter().collect();
        entries.sort_unstable_by_key(|&(id, _)| id);
        for (id, node) in entries {
            let idx = graph.add_node(node.clone());
            node_indices.insert(id.clone(), idx);
        }

        let end_node = GraphNode {
            id: transitions::END.to_string(),
            executor: Arc::new(EndNode),
        };
        let end_idx = graph.add_node(end_node);
        node_indices.insert(transitions::END.to_string(), end_idx);

        for edge in &self.edges {
            let from_idx = node_indices
                .get(&edge.from)
                .ok_or_else(|| GraphError::NodeNotFound(edge.from.clone()))?;
            // Conditional edges have no static target, so only direct edges are
            // materialized in the petgraph structure (and have their target checked).
            if let EdgeType::Direct = &edge.edge_type {
                let to_idx = node_indices
                    .get(&edge.to)
                    .ok_or_else(|| GraphError::NodeNotFound(edge.to.clone()))?;
                graph.add_edge(*from_idx, *to_idx, edge.clone());
            }
        }

        Ok(CompiledGraph {
            graph,
            node_indices,
            edges: self.edges,
            entry_point: entry,
            name: self.name,
            description: self.description,
        })
    }
}
/// Validated graph produced by `GraphBuilder::compile`.
#[derive(Clone)]
pub struct CompiledGraph {
    /// petgraph storage: every node plus only the *direct* edges.
    pub(crate) graph: StableGraph<GraphNode, GraphEdge>,
    /// Node id -> index into `graph`, including the synthetic END node.
    pub(crate) node_indices: HashMap<String, NodeIndex>,
    /// All edges (direct and conditional) in registration order.
    pub(crate) edges: Vec<GraphEdge>,
    /// Id of the node execution starts from.
    pub(crate) entry_point: String,
    /// Optional graph name.
    pub(crate) name: Option<String>,
    /// Optional graph description.
    pub(crate) description: Option<String>,
}
impl CompiledGraph {
    /// Id of the node execution starts from.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }

    /// Looks up a node by id, including the synthetic `transitions::END` node.
    pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
        self.node_indices
            .get(id)
            .and_then(|idx| self.graph.node_weight(*idx))
    }

    /// Resolves the successor of `from` for the given `state`.
    ///
    /// Only the first edge registered with source `from` is consulted. A direct
    /// edge yields its fixed target, and a conditional edge calls its router with
    /// `state`. Returns `None` when `from` has no outgoing edge. A router's result
    /// is returned as-is and is not checked against the registered nodes.
    pub fn get_next_node(&self, from: &str, state: &AgentState) -> Option<String> {
        let edge = self.edges.iter().find(|edge| edge.from == from)?;
        Some(match &edge.edge_type {
            EdgeType::Direct => edge.to.clone(),
            EdgeType::Conditional(router) => router(state),
        })
    }

    /// The graph's name, if one was set.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// The graph's description, if one was set.
    pub fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }

    /// Renders the graph as a Mermaid `graph TD` flowchart.
    ///
    /// Node declarations are emitted in sorted id order, with the terminal END
    /// node last, so the output is identical across runs. Hyphens in ids become
    /// underscores in Mermaid identifiers, while the original id is kept as the
    /// label. Conditional edges have no static target and are drawn as a dotted
    /// `conditional` arrow.
    pub fn to_mermaid(&self) -> String {
        /// Maps a node id to the identifier used in the Mermaid source.
        fn mermaid_id(id: &str) -> String {
            id.replace('-', "_")
        }

        let mut output = String::from("graph TD\n");

        // `node_indices` is a HashMap whose iteration order is randomized per
        // process; sort so the diagram is reproducible (diffable, snapshot-testable).
        let mut ids: Vec<&str> = self
            .node_indices
            .keys()
            .map(String::as_str)
            .filter(|id| *id != transitions::END)
            .collect();
        ids.sort_unstable();

        for id in ids {
            output.push_str(&format!(" {}[{}]\n", mermaid_id(id), id));
        }
        if self.node_indices.contains_key(transitions::END) {
            output.push_str(&format!(
                " {}(({}))\n",
                mermaid_id(transitions::END),
                "END"
            ));
        }

        output.push_str(&format!(
            " __start__([Start]) --> {}\n",
            mermaid_id(&self.entry_point)
        ));

        for edge in &self.edges {
            let from = mermaid_id(&edge.from);
            match &edge.edge_type {
                EdgeType::Direct => {
                    output.push_str(&format!(" {} --> {}\n", from, mermaid_id(&edge.to)));
                }
                EdgeType::Conditional(_) => {
                    output.push_str(&format!(" {} -.->|conditional| ...\n", from));
                }
            }
        }
        output
    }
}
/// Executor behind the synthetic `transitions::END` node that
/// `GraphBuilder::compile` adds to every graph.
struct EndNode;

#[async_trait]
impl NodeExecutor for EndNode {
    fn id(&self) -> &str {
        transitions::END
    }

    fn description(&self) -> Option<&str> {
        Some("Terminal node that ends graph execution")
    }

    /// Ignores the state and always reports `NodeOutput::Finish`.
    async fn execute(&self, _state: SharedState) -> Result<NodeOutput, NodeError> {
        let outcome = NodeOutput::Finish;
        Ok(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor stub that always continues without an explicit target.
    struct TestNode {
        id: String,
    }

    /// Shorthand for constructing a `TestNode` with the given id.
    fn node(id: &str) -> TestNode {
        TestNode { id: id.to_string() }
    }

    #[async_trait]
    impl NodeExecutor for TestNode {
        fn id(&self) -> &str {
            &self.id
        }

        async fn execute(&self, _state: SharedState) -> Result<NodeOutput, NodeError> {
            Ok(NodeOutput::cont())
        }
    }

    #[test]
    fn test_node_output() {
        let bare = NodeOutput::cont();
        assert!(!bare.is_terminal());
        assert!(bare.target().is_none());

        assert_eq!(NodeOutput::continue_to("next").target(), Some("next"));

        for terminal in &[NodeOutput::finish(), NodeOutput::end()] {
            assert!(terminal.is_terminal());
        }

        assert_eq!(NodeOutput::route("target").target(), Some("target"));
    }

    #[test]
    fn test_graph_builder() {
        let compiled = GraphBuilder::new()
            .name("test_graph")
            .add_node(node("node1"))
            .add_node(node("node2"))
            .set_entry_point("node1")
            .add_edge("node1", "node2")
            .add_edge_to_end("node2")
            .compile()
            .unwrap();

        assert_eq!(compiled.entry_point(), "node1");
        for id in &["node1", "node2", transitions::END] {
            assert!(compiled.get_node(id).is_some());
        }
    }

    #[test]
    fn test_graph_builder_no_entry_point() {
        let outcome = GraphBuilder::new().add_node(node("node1")).compile();
        assert!(matches!(outcome, Err(GraphError::NoEntryPoint)));
    }

    #[test]
    fn test_graph_builder_missing_node() {
        let outcome = GraphBuilder::new().set_entry_point("nonexistent").compile();
        assert!(matches!(outcome, Err(GraphError::NodeNotFound(_))));
    }

    #[test]
    fn test_conditional_edge() {
        let compiled = GraphBuilder::new()
            .add_node(node("start"))
            .add_node(node("branch_a"))
            .add_node(node("branch_b"))
            .set_entry_point("start")
            .add_conditional_edge("start", |state: &AgentState| {
                let target = if state.is_complete {
                    transitions::END
                } else {
                    "branch_a"
                };
                target.to_string()
            })
            .compile()
            .unwrap();

        let fresh = AgentState::new();
        assert_eq!(
            compiled.get_next_node("start", &fresh),
            Some("branch_a".to_string())
        );

        let mut finished = AgentState::new();
        finished.is_complete = true;
        assert_eq!(
            compiled.get_next_node("start", &finished),
            Some(transitions::END.to_string())
        );
    }

    #[test]
    fn test_mermaid_output() {
        let compiled = GraphBuilder::new()
            .add_node(node("node1"))
            .add_node(node("node2"))
            .set_entry_point("node1")
            .add_edge("node1", "node2")
            .add_edge_to_end("node2")
            .compile()
            .unwrap();

        let diagram = compiled.to_mermaid();
        for needle in &["graph TD", "node1", "node2"] {
            assert!(diagram.contains(needle));
        }
    }
}