use crate::agent::react::builder::ReactAgentBuilder;
use crate::error::Result;
use crate::workflow::Graph;
use crate::workflow::GraphBuilder;
use crate::workflow::SharedState;
use futures::future::BoxFuture;
/// Deferred agent configuration: receives a fresh `ReactAgentBuilder` and
/// returns it configured.
///
/// Unlike the other callback aliases below, this one carries no `Send`/`Sync`
/// bound, so `StateGraph` (which stores these) is not `Send` either.
type AgentConfigFn = Box<dyn FnOnce(ReactAgentBuilder) -> ReactAgentBuilder>;

/// Async body of a function node; the returned future may borrow the shared
/// state for as long as it runs.
type FunctionNodeFn =
    Box<dyn for<'s> Fn(&'s SharedState) -> BoxFuture<'s, Result<()>> + Send + Sync>;

/// Async routing callback of a conditional edge; resolves to a `String`
/// naming the node to continue with.
type ConditionFn =
    Box<dyn for<'s> Fn(&'s SharedState) -> BoxFuture<'s, String> + Send + Sync>;
/// A node recorded by `StateGraph` but not yet lowered into the `GraphBuilder`.
enum PendingNode {
    /// A ReAct agent node whose agent is built lazily at compile time.
    Agent {
        /// Node name used by edges, entry and finish declarations.
        name: String,
        /// Configures a fresh `ReactAgentBuilder`; invoked once during compile.
        config_fn: AgentConfigFn,
        /// Input key forwarded to `GraphBuilder::add_agent_node`.
        input_key: String,
        /// Output key forwarded to `GraphBuilder::add_agent_node`.
        output_key: String,
    },
    /// A node backed by a user-supplied async function.
    Function {
        /// Node name used by edges, entry and finish declarations.
        name: String,
        /// The async body run against the shared state.
        f: FunctionNodeFn,
    },
    /// A router node; it carries no behaviour of its own at this level.
    Router { name: String },
}
/// An edge recorded by `StateGraph` but not yet lowered into the `GraphBuilder`.
enum PendingEdge {
    /// An unconditional transition from one node to another.
    Fixed { from: String, to: String },
    /// A transition whose target is chosen at run time by `f`.
    Conditional {
        /// Source node.
        from: String,
        /// Routing callback resolving to the target node name.
        f: ConditionFn,
    },
    /// A parallel edge: source node, the targets, and the node named by `then`,
    /// forwarded unchanged to `GraphBuilder::add_parallel_edge`.
    Parallel {
        /// Source node.
        from: String,
        /// Nodes listed as parallel targets.
        targets: Vec<String>,
        /// Node named as the continuation after the targets.
        then: String,
    },
}
/// A declarative graph definition: nodes, edges, an entry point and finish
/// nodes are recorded first and only turned into a `Graph` by `compile`.
pub struct StateGraph {
    /// Name handed to `GraphBuilder::new`.
    name: String,
    /// Nodes in registration order.
    pending_nodes: Vec<PendingNode>,
    /// Edges in registration order.
    edges: Vec<PendingEdge>,
    /// Entry node, if one was declared (the last `entry` call wins).
    entry: Option<String>,
    /// Finish nodes, one per `finish` call.
    finish_nodes: Vec<String>,
}
impl StateGraph {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
pending_nodes: Vec::new(),
edges: Vec::new(),
entry: None,
finish_nodes: Vec::new(),
}
}
pub fn add_react_node<F>(&mut self, name: impl Into<String>, configure: F) -> &mut Self
where
F: FnOnce(ReactAgentBuilder) -> ReactAgentBuilder + 'static,
{
self.pending_nodes.push(PendingNode::Agent {
name: name.into(),
config_fn: Box::new(configure),
input_key: "task".to_string(),
output_key: "result".to_string(),
});
self
}
pub fn with_input(&mut self, key: impl Into<String>) -> &mut Self {
if let Some(PendingNode::Agent { input_key, .. }) = self.pending_nodes.last_mut() {
*input_key = key.into();
}
self
}
pub fn with_output(&mut self, key: impl Into<String>) -> &mut Self {
if let Some(PendingNode::Agent { output_key, .. }) = self.pending_nodes.last_mut() {
*output_key = key.into();
}
self
}
pub fn add_function_node<F>(&mut self, name: impl Into<String>, f: F) -> &mut Self
where
F: for<'a> Fn(&'a SharedState) -> BoxFuture<'a, Result<()>> + Send + Sync + 'static,
{
self.pending_nodes.push(PendingNode::Function {
name: name.into(),
f: Box::new(f),
});
self
}
pub fn add_router(&mut self, name: impl Into<String>) -> &mut Self {
self.pending_nodes
.push(PendingNode::Router { name: name.into() });
self
}
pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
self.edges.push(PendingEdge::Fixed {
from: from.into(),
to: to.into(),
});
self
}
pub fn add_conditional_edge<F>(&mut self, from: impl Into<String>, f: F) -> &mut Self
where
F: for<'a> Fn(&'a SharedState) -> BoxFuture<'a, String> + Send + Sync + 'static,
{
self.edges.push(PendingEdge::Conditional {
from: from.into(),
f: Box::new(f),
});
self
}
pub fn add_parallel_edge(
&mut self,
from: impl Into<String>,
targets: Vec<String>,
then: impl Into<String>,
) -> &mut Self {
self.edges.push(PendingEdge::Parallel {
from: from.into(),
targets,
then: then.into(),
});
self
}
pub fn entry(&mut self, name: impl Into<String>) -> &mut Self {
self.entry = Some(name.into());
self
}
pub fn finish(&mut self, name: impl Into<String>) -> &mut Self {
self.finish_nodes.push(name.into());
self
}
pub fn compile(self) -> Result<Graph> {
let mut builder = GraphBuilder::new(&self.name);
for node in self.pending_nodes {
match node {
PendingNode::Agent {
name,
config_fn,
input_key,
output_key,
} => {
let base = ReactAgentBuilder::new();
let configured = config_fn(base);
let agent = configured.build()?;
builder = builder.add_agent_node(&name, agent, &input_key, &output_key);
}
PendingNode::Function { name, f } => {
builder = builder.add_function_node(&name, f);
}
PendingNode::Router { name } => {
builder = builder.add_router_node(&name);
}
}
}
for edge in self.edges {
match edge {
PendingEdge::Fixed { from, to } => {
builder = builder.add_edge(&from, &to);
}
PendingEdge::Conditional { from, f } => {
builder = builder.add_conditional_edge(&from, f);
}
PendingEdge::Parallel {
from,
targets,
then,
} => {
builder = builder.add_parallel_edge(&from, targets, &then);
}
}
}
if let Some(entry) = self.entry {
builder = builder.set_entry(&entry);
}
for finish in self.finish_nodes {
builder = builder.set_finish(&finish);
}
builder.build()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_state_graph_function_nodes() {
        let mut sg = StateGraph::new("test_flow");
        sg.add_function_node("step1", |state| {
            Box::pin(async move {
                let _ = state.set("x", "hello");
                Ok(())
            })
        })
        .add_function_node("step2", |state| {
            Box::pin(async move {
                let x: String = state.get("x").unwrap_or_default();
                let _ = state.set("y", format!("{} world", x));
                Ok(())
            })
        })
        .add_edge("step1", "step2")
        .add_edge("step2", "__end__")
        .entry("step1")
        .finish("step2");
        let graph = sg.compile().unwrap();
        let state = SharedState::new();
        let result = graph.run(state).await.unwrap();
        let y: String = result.state.get("y").unwrap_or_default();
        assert_eq!(y, "hello world");
    }

    #[tokio::test]
    async fn test_state_graph_run() {
        let mut sg = StateGraph::new("hello_flow");
        sg.add_function_node("greet", |state| {
            Box::pin(async move {
                let _ = state.set("msg", "hello");
                Ok(())
            })
        })
        .add_function_node("shout", |state| {
            Box::pin(async move {
                let msg: String = state.get("msg").unwrap_or_default();
                let _ = state.set("msg", format!("{}!", msg.to_uppercase()));
                Ok(())
            })
        })
        .add_edge("greet", "shout")
        .add_edge("shout", "__end__")
        .entry("greet")
        .finish("shout");
        let graph = sg.compile().unwrap();
        let state = SharedState::new();
        let result = graph.run(state).await.unwrap();
        let msg: String = result.state.get("msg").unwrap_or_default();
        assert_eq!(msg, "HELLO!");
    }

    /// Builds a `check -> (pass | fail)` graph. `check` stores `status` in the
    /// shared state and the conditional edge routes to `pass` only when it is
    /// `"ok"`; each branch writes its outcome under the `result` key.
    fn conditional_graph(status: &'static str) -> Graph {
        let mut sg = StateGraph::new("cond_flow");
        sg.add_function_node("check", move |state| {
            Box::pin(async move {
                let _ = state.set("status", status);
                Ok(())
            })
        })
        .add_function_node("pass", |state| {
            Box::pin(async move {
                let _ = state.set("result", "passed");
                Ok(())
            })
        })
        .add_function_node("fail", |state| {
            Box::pin(async move {
                let _ = state.set("result", "failed");
                Ok(())
            })
        })
        // `async move` copies the `&SharedState` into the future instead of
        // borrowing the closure's local parameter.
        .add_conditional_edge("check", |state| {
            Box::pin(async move {
                let s: String = state.get("status").unwrap_or_default();
                if s == "ok" { "pass" } else { "fail" }.to_string()
            })
        })
        .add_edge("pass", "__end__")
        .add_edge("fail", "__end__")
        .entry("check")
        .finish("pass")
        .finish("fail");
        sg.compile().unwrap()
    }

    #[tokio::test]
    async fn test_state_graph_conditional() {
        let graph = conditional_graph("ok");
        let result = graph.run(SharedState::new()).await.unwrap();
        assert_eq!(result.steps, 2);
        // The step count alone cannot tell which branch ran; pin the outcome.
        let outcome: String = result.state.get("result").unwrap_or_default();
        assert_eq!(outcome, "passed");
    }

    #[tokio::test]
    async fn test_state_graph_conditional_fail_branch() {
        let graph = conditional_graph("error");
        let result = graph.run(SharedState::new()).await.unwrap();
        assert_eq!(result.steps, 2);
        let outcome: String = result.state.get("result").unwrap_or_default();
        assert_eq!(outcome, "failed");
    }
}