use async_trait::async_trait;
use std::sync::Arc;
use tracing::{debug, info};
use crate::error::NodeError;
use crate::graph::{CompiledGraph, NodeExecutor, NodeOutput};
use crate::runner::{GraphRunner, RunnerConfig};
use crate::state::SharedState;
use super::{clone_state, merge_all_context, ResultMerger, StateMapper};
/// A graph node that runs a nested [`CompiledGraph`] as a single step of its parent graph.
///
/// On execution it derives the child's input from the parent state via `state_mapper`,
/// runs the child with a freshly created `GraphRunner`, folds the returned state back into
/// the parent via `result_merger`, and then either continues to `next_node` or finishes.
pub struct SubgraphNode {
    /// Node identifier, returned by `NodeExecutor::id`.
    id: String,
    /// The subgraph to run; cloned into a new `GraphRunner` on every execution.
    graph: Arc<CompiledGraph>,
    /// Runner configuration passed to `GraphRunner::new` for each subgraph run.
    config: RunnerConfig,
    /// Builds the child's initial state from the parent state (called under a read lock).
    state_mapper: StateMapper,
    /// Merges the child's returned state into the parent state (called under a write lock).
    result_merger: ResultMerger,
    /// Successor node; `None` makes `execute` return `NodeOutput::finish()`.
    next_node: Option<String>,
}
impl SubgraphNode {
    /// Creates a node with the given `id` that runs `graph` as a nested subgraph.
    ///
    /// Defaults: `RunnerConfig::default()`, the `clone_state()` state mapper, the
    /// `merge_all_context()` result merger, and no successor, so `execute` returns
    /// `NodeOutput::finish()` unless [`then`](Self::then) is called.
    #[must_use]
    pub fn new(id: impl Into<String>, graph: CompiledGraph) -> Self {
        Self {
            id: id.into(),
            graph: Arc::new(graph),
            config: RunnerConfig::default(),
            state_mapper: clone_state(),
            result_merger: merge_all_context(),
            next_node: None,
        }
    }

    /// Replaces the `RunnerConfig` handed to the `GraphRunner` on every subgraph run.
    #[must_use]
    pub fn with_config(mut self, config: RunnerConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the function that builds the subgraph's initial state from the parent state.
    ///
    /// The mapper is invoked while a read lock on the parent state is held.
    #[must_use]
    pub fn with_state_mapper<F>(mut self, mapper: F) -> Self
    where
        F: Fn(&crate::state::AgentState) -> crate::state::AgentState + Send + Sync + 'static,
    {
        self.state_mapper = Box::new(mapper);
        self
    }

    /// Sets the function that folds the state returned by the subgraph into the parent state.
    ///
    /// The merger is invoked while a write lock on the parent state is held.
    #[must_use]
    pub fn with_result_merger<F>(mut self, merger: F) -> Self
    where
        F: Fn(&mut crate::state::AgentState, crate::state::AgentState) + Send + Sync + 'static,
    {
        self.result_merger = Box::new(merger);
        self
    }

    /// Makes `execute` return `NodeOutput::continue_to(next_node)` once the subgraph completes.
    #[must_use]
    pub fn then(mut self, next_node: impl Into<String>) -> Self {
        self.next_node = Some(next_node.into());
        self
    }

    /// Clears any successor so `execute` returns `NodeOutput::finish()`.
    #[must_use]
    pub fn then_finish(mut self) -> Self {
        self.next_node = None;
        self
    }
}
#[async_trait]
impl NodeExecutor for SubgraphNode {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Runs the subgraph on a state derived from `state`, then merges its result back.
    ///
    /// Poisoned locks and subgraph failures surface as `NodeError::execution_failed`.
    /// Both the read and the write lock are released before/after the subgraph `.await`,
    /// never held across it.
    async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
        // Maps a poisoned-lock error (read or write) onto the node's error type.
        fn poisoned<E: std::fmt::Display>(err: E) -> NodeError {
            NodeError::execution_failed(err.to_string())
        }

        // Derive the child's input under a short-lived read lock.
        let child_input = {
            let parent = state.read().map_err(poisoned)?;
            (self.state_mapper)(&parent)
        };

        info!(
            node_id = %self.id,
            subgraph_name = ?self.graph.name(),
            "Executing subgraph"
        );

        let runner = GraphRunner::new(self.graph.as_ref().clone(), self.config.clone());
        let outcome = runner
            .invoke(child_input)
            .await
            .map_err(|err| {
                NodeError::execution_failed(format!("Subgraph '{}' failed: {}", self.id, err))
            })?;

        debug!(
            node_id = %self.id,
            iterations = outcome.iteration,
            "Subgraph completed"
        );

        // Fold the child's outcome back in under a short-lived write lock.
        {
            let mut parent = state.write().map_err(poisoned)?;
            (self.result_merger)(&mut parent, outcome);
        }

        let output = if let Some(next) = &self.next_node {
            NodeOutput::continue_to(next.clone())
        } else {
            NodeOutput::finish()
        };
        Ok(output)
    }

    fn description(&self) -> Option<&str> {
        Some("Executes a subgraph")
    }
}
/// Step-by-step builder for [`SubgraphNode`]; only the subgraph itself is mandatory.
// NOTE(review): `dead_code` is presumably allowed because the builder has no in-crate
// callers yet — confirm, and drop the attribute once it is used.
#[allow(dead_code)]
pub struct SubgraphNodeBuilder {
    /// Identifier for the built node.
    id: String,
    /// Subgraph to run; `build` fails while this is `None`.
    graph: Option<CompiledGraph>,
    /// Runner configuration; starts as `RunnerConfig::default()`.
    config: RunnerConfig,
    /// Custom state mapper; `None` keeps `SubgraphNode::new`'s default.
    state_mapper: Option<StateMapper>,
    /// Custom result merger; `None` keeps `SubgraphNode::new`'s default.
    result_merger: Option<ResultMerger>,
    /// Successor node; `None` makes the built node terminal.
    next_node: Option<String>,
}
// NOTE(review): `dead_code` is presumably allowed because the builder has no in-crate
// callers yet — confirm, and drop the attribute once it is used.
#[allow(dead_code)]
impl SubgraphNodeBuilder {
    /// Starts a builder for a node with the given `id`; every other part is unset.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            graph: None,
            config: RunnerConfig::default(),
            state_mapper: None,
            result_merger: None,
            next_node: None,
        }
    }

    /// Sets the subgraph to run. Required: [`build`](Self::build) fails without it.
    #[must_use]
    pub fn graph(mut self, graph: CompiledGraph) -> Self {
        self.graph = Some(graph);
        self
    }

    /// Sets the runner configuration; defaults to `RunnerConfig::default()`.
    #[must_use]
    pub fn config(mut self, config: RunnerConfig) -> Self {
        self.config = config;
        self
    }

    /// Overrides the function that builds the subgraph's initial state from the parent state.
    #[must_use]
    pub fn state_mapper<F>(mut self, mapper: F) -> Self
    where
        F: Fn(&crate::state::AgentState) -> crate::state::AgentState + Send + Sync + 'static,
    {
        self.state_mapper = Some(Box::new(mapper));
        self
    }

    /// Overrides the function that folds the subgraph's returned state into the parent state.
    #[must_use]
    pub fn result_merger<F>(mut self, merger: F) -> Self
    where
        F: Fn(&mut crate::state::AgentState, crate::state::AgentState) + Send + Sync + 'static,
    {
        self.result_merger = Some(Box::new(merger));
        self
    }

    /// Makes the built node continue to `next_node` after the subgraph completes.
    #[must_use]
    pub fn then(mut self, next_node: impl Into<String>) -> Self {
        self.next_node = Some(next_node.into());
        self
    }

    /// Clears any successor set with [`then`](Self::then), so the built node finishes.
    ///
    /// Mirrors `SubgraphNode::then_finish`.
    #[must_use]
    pub fn then_finish(mut self) -> Self {
        self.next_node = None;
        self
    }

    /// Builds the [`SubgraphNode`].
    ///
    /// Parts that were never set fall back to the defaults of `SubgraphNode::new`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Graph is required")` if [`graph`](Self::graph) was never called.
    pub fn build(self) -> Result<SubgraphNode, &'static str> {
        let graph = self.graph.ok_or("Graph is required")?;
        let mut node = SubgraphNode::new(self.id, graph).with_config(self.config);
        // Only replace SubgraphNode::new's defaults for parts that were set explicitly.
        if let Some(mapper) = self.state_mapper {
            node.state_mapper = mapper;
        }
        if let Some(merger) = self.result_merger {
            node.result_merger = merger;
        }
        node.next_node = self.next_node;
        Ok(node)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use crate::state::AgentState;
    use std::sync::{Arc, RwLock};

    /// Test node that writes a fixed string `value` under `key` in the context, then finishes.
    struct SetValueNode {
        id: String,
        key: String,
        value: String,
    }

    #[async_trait]
    impl NodeExecutor for SetValueNode {
        fn id(&self) -> &str {
            &self.id
        }

        async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
            {
                let mut guard = state
                    .write()
                    .map_err(|e| NodeError::execution_failed(e.to_string()))?;
                guard.set_context(&self.key, self.value.clone());
            }
            Ok(NodeOutput::finish())
        }
    }

    /// Compiles a one-node subgraph that sets `child_result = "from_child"`.
    fn set_value_graph() -> CompiledGraph {
        GraphBuilder::new()
            .add_node(SetValueNode {
                id: "set".to_string(),
                key: "child_result".to_string(),
                value: "from_child".to_string(),
            })
            .set_entry_point("set")
            .compile()
            .unwrap()
    }

    /// A fresh, empty parent state.
    fn empty_state() -> Arc<RwLock<AgentState>> {
        Arc::new(RwLock::new(AgentState::new()))
    }

    #[tokio::test]
    async fn test_subgraph_node_basic() {
        let node = SubgraphNode::new("subgraph", set_value_graph());
        let state = empty_state();
        let result = node.execute(state.clone()).await.unwrap();
        assert!(result.is_terminal());
        let guard = state.read().unwrap();
        assert_eq!(
            guard.get_context::<String>("child_result"),
            Some("from_child".to_string())
        );
    }

    #[tokio::test]
    async fn test_subgraph_node_then_continues() {
        let node = SubgraphNode::new("subgraph", set_value_graph()).then("next");
        let result = node.execute(empty_state()).await.unwrap();
        assert!(!result.is_terminal());
    }

    #[tokio::test]
    async fn test_subgraph_node_then_finish_overrides_then() {
        let node = SubgraphNode::new("subgraph", set_value_graph())
            .then("next")
            .then_finish();
        let result = node.execute(empty_state()).await.unwrap();
        assert!(result.is_terminal());
    }

    #[test]
    fn test_builder_requires_graph() {
        let result = SubgraphNodeBuilder::new("subgraph").build();
        assert_eq!(result.err(), Some("Graph is required"));
    }

    #[tokio::test]
    async fn test_builder_builds_runnable_node() {
        let node = SubgraphNodeBuilder::new("subgraph")
            .graph(set_value_graph())
            .then("next")
            .build()
            .unwrap();
        assert_eq!(node.id(), "subgraph");
        let state = empty_state();
        let result = node.execute(state.clone()).await.unwrap();
        assert!(!result.is_terminal());
        let guard = state.read().unwrap();
        assert_eq!(
            guard.get_context::<String>("child_result"),
            Some("from_child".to_string())
        );
    }

    #[tokio::test]
    async fn test_subgraph_node_with_mapper() {
        struct ProcessNode;

        #[async_trait]
        impl NodeExecutor for ProcessNode {
            fn id(&self) -> &str {
                "process"
            }

            async fn execute(&self, state: SharedState) -> Result<NodeOutput, NodeError> {
                let input: String = {
                    let guard = state
                        .read()
                        .map_err(|e| NodeError::execution_failed(e.to_string()))?;
                    guard.get_context("input").unwrap_or_default()
                };
                {
                    let mut guard = state
                        .write()
                        .map_err(|e| NodeError::execution_failed(e.to_string()))?;
                    guard.set_context("output", format!("processed: {}", input));
                }
                Ok(NodeOutput::finish())
            }
        }

        let subgraph = GraphBuilder::new()
            .add_node(ProcessNode)
            .set_entry_point("process")
            .compile()
            .unwrap();
        let node = SubgraphNode::new("subgraph", subgraph)
            .with_state_mapper(|parent| {
                let mut child = AgentState::new();
                if let Some(data) = parent.get_context::<String>("data") {
                    child.set_context("input", data);
                }
                child
            })
            .with_result_merger(|parent, child| {
                if let Some(output) = child.get_context::<String>("output") {
                    parent.set_context("result", output);
                }
            });
        let state = empty_state();
        {
            let mut guard = state.write().unwrap();
            guard.set_context("data", "hello".to_string());
        }
        node.execute(state.clone()).await.unwrap();
        let guard = state.read().unwrap();
        assert_eq!(
            guard.get_context::<String>("result"),
            Some("processed: hello".to_string())
        );
    }
}