use std::collections::HashMap;
use std::sync::Arc;
use adk_action::{ActionNodeConfig, SwitchNodeConfig, TriggerNodeConfig};
use serde::{Deserialize, Serialize};
use crate::action::ActionNodeExecutor;
use crate::action::switch::evaluate_switch_conditions;
use crate::edge::{END, EdgeTarget};
use crate::error::{GraphError, Result};
use crate::graph::StateGraph;
use crate::node::Node;
use crate::state::StateSchema;
/// A directed edge between two nodes in a serialized workflow.
///
/// Fields are (de)serialized in camelCase (e.g. `fromPort`, `toPort`); the optional fields
/// default to `None` when absent and are omitted from the output when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowEdge {
/// ID of the node this edge leaves.
pub from: String,
/// ID of the node this edge enters; may be the graph `END` marker.
pub to: String,
/// Optional ID of a [`WorkflowCondition`] gating this edge. Edges that carry a condition are
/// grouped per source node into a single conditional router instead of being added as plain
/// edges.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub condition: Option<String>,
/// Optional output port on the source node.
/// NOTE(review): not read when building the graph in this file — confirm intended usage.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub from_port: Option<String>,
/// Optional input port on the target node.
/// NOTE(review): not read when building the graph in this file — confirm intended usage.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub to_port: Option<String>,
}
/// A named routing condition that [`WorkflowEdge::condition`] can reference by `id`.
///
/// Serialized in camelCase; `description` is omitted from output when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowCondition {
/// Identifier that edges use to refer to this condition.
pub id: String,
/// Template expression; at routing time its variables are interpolated from graph state and
/// the result counts as true unless it is empty, `"false"`, or `"0"` (case-insensitive,
/// whitespace-trimmed).
pub expression: String,
/// Optional human-readable description; not used for routing.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// Serializable description of a whole workflow graph: its edges, named conditions, and
/// action-node configurations.
///
/// Serialized in camelCase (e.g. `actionNodes`, `agentNodes`). Only `edges` is required when
/// deserializing; the other fields fall back to empty collections.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowSchema {
/// All edges of the workflow, conditional and unconditional.
pub edges: Vec<WorkflowEdge>,
/// Named conditions referenced by `WorkflowEdge::condition`.
#[serde(default)]
pub conditions: Vec<WorkflowCondition>,
/// Action-node configurations keyed by node ID.
#[serde(default)]
pub action_nodes: HashMap<String, ActionNodeConfig>,
/// Agent node IDs.
/// NOTE(review): not read by the graph builder in this file — confirm where these are wired in.
#[serde(default)]
pub agent_nodes: Vec<String>,
}
impl WorkflowSchema {
    /// Parses a workflow definition from a JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::InvalidGraph`] when `json` is not valid JSON or does not match the
    /// workflow shape.
    pub fn from_json(json: &str) -> Result<Self> {
        serde_json::from_str(json)
            .map_err(|e| GraphError::InvalidGraph(format!("invalid workflow JSON: {e}")))
    }

    /// Builds and compiles a [`GraphAgent`](crate::agent::GraphAgent) named `name` from this
    /// workflow.
    ///
    /// The graph state schema has the fixed keys `input`, `output`, and `messages`. Construction:
    /// 1. every entry in `action_nodes` is registered as an [`ActionNodeExecutor`] under its key;
    /// 2. edges without a `condition` are added as plain edges;
    /// 3. each switch node gets a conditional edge driven by its own switch configuration
    ///    (conditional workflow edges leaving a switch node are ignored);
    /// 4. for every other source node, its conditional edges are merged into one router that
    ///    follows the first edge (in declaration order) whose interpolated condition expression
    ///    is truthy — non-empty and not `"false"`/`"0"` after trimming and lower-casing — and
    ///    routes to [`END`] when none match.
    ///
    /// NOTE(review): unconditional edges that leave a switch node are still added as plain edges
    /// alongside the switch router — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::InvalidGraph`] if a conditional edge (from a non-switch node)
    /// references a condition ID not declared in `conditions`, and propagates any error from
    /// compiling the graph.
    pub fn build_graph(&self, name: &str) -> Result<crate::agent::GraphAgent> {
        let schema = StateSchema::simple(&["input", "output", "messages"]);
        let mut graph = StateGraph::new(schema);

        for (node_id, config) in &self.action_nodes {
            let executor = ActionNodeExecutor::new(config.clone());
            // Edges refer to nodes by map key, so the key must agree with the executor's name.
            debug_assert_eq!(
                executor.name(),
                node_id,
                "ActionNodeExecutor name must match node ID"
            );
            graph.nodes.insert(node_id.clone(), Arc::new(executor));
        }

        let condition_map: HashMap<&str, &WorkflowCondition> =
            self.conditions.iter().map(|c| (c.id.as_str(), c)).collect();

        // Plain edges go straight into the graph; conditional ones are grouped per source node
        // because each source may carry only one combined router.
        let mut conditional_groups: HashMap<&str, Vec<&WorkflowEdge>> = HashMap::new();
        for edge in &self.edges {
            if edge.condition.is_some() {
                conditional_groups.entry(edge.from.as_str()).or_default().push(edge);
            } else {
                graph = graph.add_edge(&edge.from, &edge.to);
            }
        }

        for (node_id, config) in &self.action_nodes {
            if let ActionNodeConfig::Switch(switch_config) = config {
                graph = register_switch_conditional_edges(graph, node_id, switch_config);
            }
        }

        for (&source, edges) in &conditional_groups {
            // Switch nodes were already given a router derived from their own config.
            if matches!(self.action_nodes.get(source), Some(ActionNodeConfig::Switch(_))) {
                continue;
            }

            let mut targets: HashMap<String, EdgeTarget> = HashMap::new();
            // The router falls back to END when nothing matches, so END must be a valid target.
            targets.insert(END.to_string(), EdgeTarget::End);
            let mut routes: Vec<(String, String)> = Vec::with_capacity(edges.len());

            for edge in edges {
                let Some(cond_id) = edge.condition.as_deref() else {
                    continue;
                };
                // An unknown condition ID would leave this edge permanently unreachable;
                // surface the misconfiguration instead of silently dropping the route.
                let cond = condition_map.get(cond_id).ok_or_else(|| {
                    GraphError::InvalidGraph(format!(
                        "edge '{}' -> '{}' references unknown condition '{cond_id}'",
                        edge.from, edge.to
                    ))
                })?;
                let target = if edge.to == END {
                    EdgeTarget::End
                } else {
                    EdgeTarget::Node(edge.to.clone())
                };
                targets.insert(edge.to.clone(), target);
                routes.push((cond.expression.clone(), edge.to.clone()));
            }

            let router = Arc::new(move |state: &crate::state::State| -> String {
                for (expr, target) in &routes {
                    let resolved = adk_action::interpolate_variables(expr, state);
                    let verdict = resolved.trim().to_lowercase();
                    if !verdict.is_empty() && verdict != "false" && verdict != "0" {
                        return target.clone();
                    }
                }
                END.to_string()
            });
            graph.edges.push(crate::edge::Edge::Conditional {
                source: source.to_string(),
                router,
                targets,
            });
        }

        let compiled = graph.compile()?;
        Ok(crate::agent::GraphAgent::from_graph(name, compiled))
    }

    /// Returns clones of every trigger action-node configuration.
    ///
    /// The order follows `action_nodes` (a `HashMap`) iteration and is therefore unspecified.
    pub fn trigger_configs(&self) -> Vec<TriggerNodeConfig> {
        self.action_nodes
            .values()
            .filter_map(|config| match config {
                ActionNodeConfig::Trigger(trigger) => Some(trigger.clone()),
                _ => None,
            })
            .collect()
    }

    /// Creates a trigger runtime for `graph` from this workflow's trigger nodes.
    ///
    /// Only available with the `action-trigger` feature.
    #[cfg(feature = "action-trigger")]
    pub fn build_trigger_runtime(
        &self,
        graph: Arc<crate::agent::GraphAgent>,
    ) -> crate::action::trigger_runtime::TriggerRuntime {
        let triggers = self.trigger_configs();
        crate::action::trigger_runtime::TriggerRuntime::new(graph, triggers)
    }
}
/// Attaches to switch node `node_id` a conditional edge that routes according to `switch_config`.
///
/// Each condition's `output_port` is used directly as a routing target. A port is treated as a
/// node ID, or as the graph end when it equals [`END`]; the optional default branch is handled
/// the same way. At runtime the router calls `evaluate_switch_conditions` and follows the first
/// returned port. It routes to [`END`] when no port is returned or when evaluation fails
/// (deliberate best-effort fallback). `END` is always registered as a target so this fallback is
/// valid.
fn register_switch_conditional_edges(
    mut graph: StateGraph,
    node_id: &str,
    switch_config: &SwitchNodeConfig,
) -> StateGraph {
    /// Maps a port / node identifier to its edge target, treating `END` as the graph end.
    fn target_for(id: &str) -> EdgeTarget {
        if id == END { EdgeTarget::End } else { EdgeTarget::Node(id.to_string()) }
    }

    // Owned copies are moved into the router closure, which must outlive `switch_config`.
    let conditions = switch_config.conditions.clone();
    let eval_mode = switch_config.evaluation_mode.clone();
    let default_branch = switch_config.default_branch.clone();

    let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
    for condition in &conditions {
        targets_map.insert(condition.output_port.clone(), target_for(&condition.output_port));
    }
    if let Some(default) = default_branch.as_deref() {
        targets_map.insert(default.to_string(), target_for(default));
    }
    targets_map.insert(END.to_string(), EdgeTarget::End);

    let router = Arc::new(move |state: &crate::state::State| -> String {
        match evaluate_switch_conditions(&conditions, state, &eval_mode, default_branch.as_deref())
        {
            Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
            // Evaluation failures end the run rather than aborting routing.
            Err(_) => END.to_string(),
        }
    });
    graph.edges.push(crate::edge::Edge::Conditional {
        source: node_id.to_string(),
        router,
        targets: targets_map,
    });
    graph
}