use crate::agent::react::builder::ReactAgentBuilder;
use crate::error::{AgentError, ReactError, Result};
use crate::llm::config::LlmConfig;
use crate::workflow::Graph;
use crate::workflow::GraphBuilder;
use crate::workflow::SharedState;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Declarative description of a workflow graph, deserializable from YAML or JSON.
///
/// Turned into a runnable [`Graph`] by [`WorkflowDefinition::build_graph`] or
/// [`WorkflowDefinition::build_graph_with_llm_config`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WorkflowDefinition {
    /// Name given to the resulting graph (passed to `GraphBuilder::new`).
    pub name: String,
    /// Nodes to add to the graph, in declaration order.
    pub nodes: Vec<NodeDefinition>,
    /// Edges connecting the nodes, in declaration order.
    pub edges: Vec<EdgeDefinition>,
    /// Name of the node execution starts from (passed to `GraphBuilder::set_entry`).
    pub entry: String,
    /// Nodes registered as finish nodes via `GraphBuilder::set_finish`; empty when omitted.
    #[serde(default)]
    pub finish: Vec<String>,
    /// Optional step limit applied with `Graph::set_max_steps`; when omitted the graph's
    /// own limit (whatever it defaults to) is left untouched.
    #[serde(default)]
    pub max_steps: Option<usize>,
}
/// A single node in a [`WorkflowDefinition`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct NodeDefinition {
    /// Node name; also used as the agent name for `agent` nodes.
    pub name: String,
    /// Node kind, written as `type` in YAML/JSON. When building the graph, `"agent"`
    /// and `"router"` are accepted; `"function"` and any other value are rejected.
    #[serde(rename = "type")]
    pub node_type: String,
    /// Model for `agent` nodes; the graph builder uses `"qwen3-max"` when absent.
    #[serde(default)]
    pub model: Option<String>,
    /// System prompt for `agent` nodes; the graph builder uses
    /// `"You are a helpful assistant"` when absent.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// State key an `agent` node reads its input from (defaults to `"input"`).
    /// Not consulted for other node types.
    #[serde(default = "default_input_key")]
    pub input_key: String,
    /// State key an `agent` node writes its output to (defaults to `"output"`).
    /// Not consulted for other node types.
    #[serde(default = "default_output_key")]
    pub output_key: String,
}
/// Serde default for [`NodeDefinition::input_key`]: the `"input"` state key.
fn default_input_key() -> String {
    String::from("input")
}
/// Serde default for [`NodeDefinition::output_key`]: the `"output"` state key.
fn default_output_key() -> String {
    String::from("output")
}
/// An edge leaving node `from` in a [`WorkflowDefinition`].
///
/// `to`, `condition` and `parallel` are alternative ways of describing where control
/// goes next; set exactly one of them per edge.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EdgeDefinition {
    /// Source node name.
    pub from: String,
    /// Target of an unconditional edge.
    #[serde(default)]
    pub to: Option<String>,
    /// Routing rule for a conditional edge.
    #[serde(default)]
    pub condition: Option<ConditionDefinition>,
    /// Target nodes of a parallel (fan-out) edge.
    #[serde(default)]
    pub parallel: Option<Vec<String>>,
    /// Node to continue with after a `parallel` fan-out; only read for parallel edges.
    /// The graph builder substitutes `"__end__"` (presumably the graph's terminal
    /// sentinel) when absent.
    #[serde(default)]
    pub then: Option<String>,
}
/// Equality-based routing rule for a conditional edge.
///
/// At runtime the raw state value stored under `key` is compared with `equals`;
/// control goes to `then` on a match and to `else_node` otherwise (presumably
/// including when `key` is not present in the state).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ConditionDefinition {
    /// State key whose raw JSON value is inspected.
    pub key: String,
    /// JSON value the state entry must equal for the `then` branch to be taken.
    pub equals: serde_json::Value,
    /// Node to route to when the value matches.
    pub then: String,
    /// Node to route to when the value does not match; written as `else` in YAML/JSON.
    #[serde(rename = "else")]
    pub else_node: String,
}
impl WorkflowDefinition {
pub fn from_yaml(path: impl AsRef<Path>) -> Result<Self> {
let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
ReactError::Agent(AgentError::InitializationFailed(format!(
"Failed to read workflow YAML file: {e}"
)))
})?;
Self::from_yaml_str(&content)
}
pub fn from_yaml_str(yaml: &str) -> Result<Self> {
serde_yaml::from_str(yaml).map_err(|e| {
ReactError::Agent(AgentError::InitializationFailed(format!(
"Failed to parse workflow YAML: {e}"
)))
})
}
pub fn from_json(path: impl AsRef<Path>) -> Result<Self> {
let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
ReactError::Agent(AgentError::InitializationFailed(format!(
"Failed to read workflow JSON file: {e}"
)))
})?;
Self::from_json_str(&content)
}
pub fn from_json_str(json: &str) -> Result<Self> {
serde_json::from_str(json).map_err(|e| {
ReactError::Agent(AgentError::InitializationFailed(format!(
"Failed to parse workflow JSON: {e}"
)))
})
}
pub fn build_graph(self) -> Result<Graph> {
self.build_graph_with_llm_config(None)
}
pub fn build_graph_with_llm_config(self, llm_config: Option<&LlmConfig>) -> Result<Graph> {
let mut builder = GraphBuilder::new(&self.name);
for node_def in &self.nodes {
match node_def.node_type.as_str() {
"agent" => {
let model = node_def.model.as_deref().unwrap_or("qwen3-max");
let prompt = node_def
.system_prompt
.as_deref()
.unwrap_or("You are a helpful assistant");
let mut agent_builder = ReactAgentBuilder::new()
.name(&node_def.name)
.model(model)
.system_prompt(prompt);
if let Some(config) = llm_config {
agent_builder = agent_builder.llm_config(config.clone());
}
let agent = agent_builder.build()?;
builder = builder.add_agent_node(
&node_def.name,
agent,
&node_def.input_key,
&node_def.output_key,
);
}
"router" => {
builder = builder.add_router_node(&node_def.name);
}
"function" => {
return Err(ReactError::Agent(AgentError::InitializationFailed(
format!(
"Node type 'function' is not yet supported for node '{}'. \
Use 'agent' or 'router' instead, or register a function node manually.",
node_def.name
),
)));
}
other => {
return Err(ReactError::Agent(AgentError::InitializationFailed(
format!("Unknown node type '{}' for node '{}'", other, node_def.name),
)));
}
}
}
for edge_def in &self.edges {
if let Some(ref to) = edge_def.to {
builder = builder.add_edge(&edge_def.from, to);
} else if let Some(ref cond) = edge_def.condition {
let key = cond.key.clone();
let expected = cond.equals.clone();
let then = cond.then.clone();
let else_node = cond.else_node.clone();
builder =
builder.add_conditional_edge(&edge_def.from, move |state: &SharedState| {
let key = key.clone();
let expected = expected.clone();
let then = then.clone();
let else_node = else_node.clone();
Box::pin(async move {
let actual = state.get_raw(&key);
if actual.as_ref() == Some(&expected) {
then
} else {
else_node
}
})
});
} else if let Some(ref targets) = edge_def.parallel {
let then = edge_def
.then
.clone()
.unwrap_or_else(|| "__end__".to_string());
builder = builder.add_parallel_edge(&edge_def.from, targets.clone(), then);
}
}
builder = builder.set_entry(&self.entry);
for finish in &self.finish {
builder = builder.set_finish(finish);
}
let mut graph = builder.build()?;
if let Some(max) = self.max_steps {
graph.set_max_steps(max);
}
Ok(graph)
}
}
pub fn load_graph_from_yaml(path: impl AsRef<Path>) -> Result<crate::workflow::Graph> {
WorkflowDefinition::from_yaml(path)?.build_graph()
}
pub fn load_graph_from_json(path: impl AsRef<Path>) -> Result<crate::workflow::Graph> {
WorkflowDefinition::from_json(path)?.build_graph()
}
pub fn load_graph_from_yaml_str(yaml: &str) -> Result<crate::workflow::Graph> {
WorkflowDefinition::from_yaml_str(yaml)?.build_graph()
}
pub fn load_graph_from_json_str(json: &str) -> Result<crate::workflow::Graph> {
WorkflowDefinition::from_json_str(json)?.build_graph()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Parses a YAML definition with a router and a fully specified agent node.
    #[test]
    fn test_parse_yaml_definition() {
        let yaml = r#"
name: test_workflow
nodes:
  - name: start
    type: router
  - name: worker
    type: agent
    model: qwen3-max
    system_prompt: "You are an assistant"
    input_key: task
    output_key: result
edges:
  - from: start
    to: worker
  - from: worker
    to: __end__
entry: start
finish: []
max_steps: 50
"#;
        let parsed = WorkflowDefinition::from_yaml_str(yaml).unwrap();

        assert_eq!(parsed.name, "test_workflow");
        assert_eq!(parsed.nodes.len(), 2);
        assert_eq!(parsed.edges.len(), 2);
        assert_eq!(parsed.entry, "start");
        assert_eq!(parsed.max_steps, Some(50));
    }

    // The JSON loader accepts the same schema as the YAML one.
    #[test]
    fn test_parse_json_definition() {
        let json = r#"{
    "name": "json_flow",
    "nodes": [
        { "name": "n1", "type": "router", "input_key": "input", "output_key": "output" },
        { "name": "n2", "type": "router", "input_key": "input", "output_key": "output" }
    ],
    "edges": [
        { "from": "n1", "to": "n2" }
    ],
    "entry": "n1",
    "finish": ["n2"]
}"#;
        let parsed = WorkflowDefinition::from_json_str(json).unwrap();

        assert_eq!(parsed.name, "json_flow");
        assert_eq!(parsed.nodes.len(), 2);
    }

    // A router-only graph with a single unconditional edge builds successfully.
    #[test]
    fn test_build_graph_from_yaml() {
        let yaml = r#"
name: simple_graph
nodes:
  - name: hub
    type: router
  - name: end_node
    type: router
edges:
  - from: hub
    to: end_node
entry: hub
finish:
  - end_node
"#;
        let built = load_graph_from_yaml_str(yaml).unwrap();

        assert_eq!(built.name, "simple_graph");
    }

    // A conditional edge with `then`/`else` targets builds successfully.
    #[test]
    fn test_conditional_edge_definition() {
        let yaml = r#"
name: cond_flow
nodes:
  - name: check
    type: router
  - name: yes_path
    type: router
  - name: no_path
    type: router
edges:
  - from: check
    condition:
      key: approved
      equals: true
      then: yes_path
      else: no_path
entry: check
finish:
  - yes_path
  - no_path
"#;
        let built = load_graph_from_yaml_str(yaml).unwrap();

        assert_eq!(built.name, "cond_flow");
    }

    // A parallel fan-out that joins at `merge` builds successfully.
    #[test]
    fn test_parallel_edge_definition() {
        let yaml = r#"
name: parallel_flow
nodes:
  - name: start
    type: router
  - name: branch_a
    type: router
  - name: branch_b
    type: router
  - name: merge
    type: router
edges:
  - from: start
    parallel:
      - branch_a
      - branch_b
    then: merge
entry: start
finish:
  - merge
"#;
        let built = load_graph_from_yaml_str(yaml).unwrap();

        assert_eq!(built.name, "parallel_flow");
    }
}