use crate::checkpoint::Checkpointer;
use crate::edge::{END, Edge, EdgeTarget, START};
use crate::error::{GraphError, Result};
use crate::graph::{CompiledGraph, StateGraph};
use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
use crate::state::{State, StateSchema};
use crate::stream::{StreamEvent, StreamMode};
use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Boxed, `Send` future returned by the agent lifecycle callbacks.
type AgentCallbackFuture = Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>;

/// Async hook that receives the invocation context.
///
/// [`GraphAgent`] awaits it once before running its graph; an `Err` aborts the run.
pub type BeforeAgentCallback =
    Arc<dyn Fn(Arc<dyn InvocationContext>) -> AgentCallbackFuture + Send + Sync>;

/// Async hook that receives the invocation context and one output event.
///
/// [`GraphAgent`] awaits it for each event produced from the final graph state,
/// before yielding that event; an `Err` is yielded instead and ends the stream.
pub type AfterAgentCallback =
    Arc<dyn Fn(Arc<dyn InvocationContext>, Event) -> AgentCallbackFuture + Send + Sync>;

/// Builds the graph's initial [`State`] from an invocation context.
pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;

/// Converts a final graph [`State`] into the events an agent emits.
pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
/// An [`Agent`] backed by a compiled state graph.
///
/// Each [`Agent::run`] turns the invocation context into an initial [`State`]
/// with `input_mapper`, invokes `graph`, and turns the final state into the
/// emitted events with `output_mapper`. The optional callbacks wrap that run.
pub struct GraphAgent {
/// Name returned by [`Agent::name`].
name: String,
/// Description returned by [`Agent::description`].
description: String,
/// Graph executed on every run. It is held in an `Arc` so the returned event
/// stream can own its own handle independent of `&self`.
graph: Arc<CompiledGraph>,
/// Builds the graph's initial state from the invocation context.
input_mapper: InputMapper,
/// Turns the final graph state into the events yielded to the caller.
output_mapper: OutputMapper,
/// Optional hook awaited before the graph runs; an error aborts the run.
before_callback: Option<BeforeAgentCallback>,
/// Optional hook awaited for each output event before that event is yielded.
after_callback: Option<AfterAgentCallback>,
}
impl GraphAgent {
pub fn builder(name: &str) -> GraphAgentBuilder {
GraphAgentBuilder::new(name)
}
pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
Self {
name: name.to_string(),
description: String::new(),
graph: Arc::new(graph),
input_mapper: Arc::new(default_input_mapper),
output_mapper: Arc::new(default_output_mapper),
before_callback: None,
after_callback: None,
}
}
#[cfg(feature = "action")]
pub fn from_workflow_schema(
name: &str,
schema: &crate::workflow::WorkflowSchema,
) -> Result<Self> {
schema.build_graph(name)
}
pub fn graph(&self) -> &CompiledGraph {
&self.graph
}
pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
self.graph.invoke(input, config).await
}
pub fn stream(
&self,
input: State,
config: ExecutionConfig,
mode: StreamMode,
) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
self.graph.stream(input, config, mode)
}
}
#[async_trait]
impl Agent for GraphAgent {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// A graph agent has no sub-agents; its work is done by graph nodes.
    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
        &[]
    }

    /// Runs the compiled graph for one invocation and streams the resulting events.
    ///
    /// 1. The before-callback, if any, is awaited first; its error is returned
    ///    immediately and the graph is not started.
    /// 2. The input mapper builds the initial state; the session id is passed to
    ///    `ExecutionConfig::new`.
    /// 3. Inside the returned stream the graph is invoked:
    ///    - on success, each event from the output mapper is passed to the
    ///      after-callback (if any) and then yielded; a callback error is yielded
    ///      instead and ends the stream;
    ///    - on `GraphError::Interrupted`, one `"graph_interrupted"` event describing
    ///      the interrupt is yielded (the after-callback is not called for it);
    ///    - any other graph error is yielded as an agent error.
    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
        if let Some(before) = self.before_callback.as_ref() {
            before(Arc::clone(&ctx)).await?;
        }

        let initial_state = (self.input_mapper)(&*ctx);
        let exec_config = ExecutionConfig::new(ctx.session_id());

        // The stream must not borrow `self`, so it gets its own handles.
        let graph = Arc::clone(&self.graph);
        let map_output = Arc::clone(&self.output_mapper);
        let after = self.after_callback.clone();
        let stream_ctx = Arc::clone(&ctx);

        let events = async_stream::stream! {
            let outcome = graph.invoke(initial_state, exec_config).await;
            match outcome {
                Err(GraphError::Interrupted(info)) => {
                    let mut interrupted = Event::new("graph_interrupted");
                    interrupted.set_content(Content::new("assistant").with_text(format!(
                        "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
                        info.interrupt,
                        info.thread_id,
                        info.checkpoint_id
                    )));
                    yield Ok(interrupted);
                }
                Err(other) => {
                    yield Err(adk_core::AdkError::agent(other.to_string()));
                }
                Ok(final_state) => {
                    for event in map_output(&final_state) {
                        if let Some(hook) = after.as_ref() {
                            if let Err(err) = hook(Arc::clone(&stream_ctx), event.clone()).await {
                                yield Err(err);
                                return;
                            }
                        }
                        yield Ok(event);
                    }
                }
            }
        };

        Ok(Box::pin(events))
    }
}
/// Default [`InputMapper`]: seeds the graph state from the user's message.
///
/// The text parts of the user content are joined with `"\n"`. When that text is
/// non-empty, it is stored under `"input"` and as a single `{"role": "user"}`
/// entry in `"messages"`. `"session_id"` is always set.
fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
    let user_text = ctx
        .user_content()
        .parts
        .iter()
        .filter_map(|part| part.text())
        .collect::<Vec<_>>()
        .join("\n");

    let mut state = State::new();
    if !user_text.is_empty() {
        state.insert(String::from("input"), json!(user_text));
        state.insert(
            String::from("messages"),
            json!([{ "role": "user", "content": user_text }]),
        );
    }
    state.insert(String::from("session_id"), json!(ctx.session_id()));
    state
}
/// Default [`OutputMapper`]: emits exactly one `"graph_output"` event.
///
/// The event text is the first JSON string found among, in order:
/// `"output"`, `"result"`, and the `"content"` of the last entry of the
/// `"messages"` array. If none of these is a string, the whole state is
/// pretty-printed as JSON (an empty text if serialization fails).
fn default_output_mapper(state: &State) -> Vec<Event> {
    /// The value under `key`, if it exists and is a JSON string.
    fn string_field<'a>(state: &'a State, key: &str) -> Option<&'a str> {
        state.get(key).and_then(|value| value.as_str())
    }

    /// The string `"content"` of the last element of the `"messages"` array.
    fn last_message_content(state: &State) -> Option<&str> {
        state
            .get("messages")
            .and_then(|messages| messages.as_array())
            .and_then(|messages| messages.last())
            .and_then(|message| message.get("content"))
            .and_then(|content| content.as_str())
    }

    let found = string_field(state, "output")
        .or_else(|| string_field(state, "result"))
        .or_else(|| last_message_content(state));

    let text = match found {
        Some(found) => found.to_owned(),
        None => serde_json::to_string_pretty(state).unwrap_or_default(),
    };

    let mut event = Event::new("graph_output");
    event.set_content(Content::new("assistant").with_text(&text));
    vec![event]
}
/// Fluent builder that assembles a [`StateGraph`], compiles it and wraps it in a
/// [`GraphAgent`]. Obtain one with [`GraphAgent::builder`].
pub struct GraphAgentBuilder {
/// Agent name.
name: String,
/// Agent description (empty unless set).
description: String,
/// State schema passed to `StateGraph::new` at build time.
schema: StateSchema,
/// Nodes added so far, in insertion order.
nodes: Vec<Arc<dyn Node>>,
/// Edges added so far; `START` edges are merged into one `Edge::Entry`.
edges: Vec<Edge>,
/// Checkpointer installed on the compiled graph, if any.
checkpointer: Option<Arc<dyn Checkpointer>>,
/// Node names copied into the compiled graph's `interrupt_before`.
interrupt_before: Vec<String>,
/// Node names copied into the compiled graph's `interrupt_after`.
interrupt_after: Vec<String>,
/// Value copied into the compiled graph's `recursion_limit`.
recursion_limit: usize,
/// Custom input mapper; `None` means the default input mapper is used.
input_mapper: Option<InputMapper>,
/// Custom output mapper; `None` means the default output mapper is used.
output_mapper: Option<OutputMapper>,
/// Optional hook awaited before each run.
before_callback: Option<BeforeAgentCallback>,
/// Optional hook awaited for each output event.
after_callback: Option<AfterAgentCallback>,
}
impl GraphAgentBuilder {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
description: String::new(),
schema: StateSchema::simple(&["input", "output", "messages"]),
nodes: vec![],
edges: vec![],
checkpointer: None,
interrupt_before: vec![],
interrupt_after: vec![],
recursion_limit: 50,
input_mapper: None,
output_mapper: None,
before_callback: None,
after_callback: None,
}
}
pub fn description(mut self, desc: &str) -> Self {
self.description = desc.to_string();
self
}
pub fn state_schema(mut self, schema: StateSchema) -> Self {
self.schema = schema;
self
}
pub fn channels(mut self, channels: &[&str]) -> Self {
self.schema = StateSchema::simple(channels);
self
}
pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
self.nodes.push(Arc::new(node));
self
}
pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
where
F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
{
self.nodes.push(Arc::new(FunctionNode::new(name, func)));
self
}
pub fn edge(mut self, source: &str, target: &str) -> Self {
let target =
if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
if source == START {
let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
match entry_idx {
Some(idx) => {
if let Edge::Entry { targets } = &mut self.edges[idx] {
if let EdgeTarget::Node(node) = &target {
if !targets.contains(node) {
targets.push(node.clone());
}
}
}
}
None => {
if let EdgeTarget::Node(node) = target {
self.edges.push(Edge::Entry { targets: vec![node] });
}
}
}
} else {
self.edges.push(Edge::Direct { source: source.to_string(), target });
}
self
}
pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
where
F: Fn(&State) -> String + Send + Sync + 'static,
I: IntoIterator<Item = (&'static str, &'static str)>,
{
let targets_map: HashMap<String, EdgeTarget> = targets
.into_iter()
.map(|(k, v)| {
let target =
if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
(k.to_string(), target)
})
.collect();
self.edges.push(Edge::Conditional {
source: source.to_string(),
router: Arc::new(router),
targets: targets_map,
});
self
}
pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
self.checkpointer = Some(Arc::new(checkpointer));
self
}
pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
self.checkpointer = Some(checkpointer);
self
}
pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
self
}
pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
self
}
pub fn recursion_limit(mut self, limit: usize) -> Self {
self.recursion_limit = limit;
self
}
pub fn input_mapper<F>(mut self, mapper: F) -> Self
where
F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
{
self.input_mapper = Some(Arc::new(mapper));
self
}
pub fn output_mapper<F>(mut self, mapper: F) -> Self
where
F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
{
self.output_mapper = Some(Arc::new(mapper));
self
}
pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
where
F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
{
self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
self
}
pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
where
F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
{
self.after_callback = Some(Arc::new(move |ctx, event| {
let event_clone = event.clone();
Box::pin(callback(ctx, event_clone))
}));
self
}
#[cfg(feature = "action")]
pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
use crate::action::ActionNodeExecutor;
if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
let conditions = switch_config.conditions.clone();
let eval_mode = switch_config.evaluation_mode.clone();
let default_branch = switch_config.default_branch.clone();
let source = config.standard().id.clone();
let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
for condition in &conditions {
targets_map.insert(
condition.output_port.clone(),
EdgeTarget::Node(condition.output_port.clone()),
);
}
if let Some(ref default) = default_branch {
let target = if default == END {
EdgeTarget::End
} else {
EdgeTarget::Node(default.clone())
};
targets_map.insert(default.clone(), target);
}
targets_map.insert(END.to_string(), EdgeTarget::End);
let router = Arc::new(move |state: &State| -> String {
match crate::action::switch::evaluate_switch_conditions(
&conditions,
state,
&eval_mode,
default_branch.as_deref(),
) {
Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
Err(_) => END.to_string(),
}
});
self.edges.push(Edge::Conditional { source, router, targets: targets_map });
}
let executor = ActionNodeExecutor::new(config);
self.nodes.push(Arc::new(executor));
self
}
pub fn build(self) -> Result<GraphAgent> {
let mut graph = StateGraph::new(self.schema);
for node in self.nodes {
graph.nodes.insert(node.name().to_string(), node);
}
graph.edges = self.edges;
let mut compiled = graph.compile()?;
if let Some(cp) = self.checkpointer {
compiled.checkpointer = Some(cp);
}
compiled.interrupt_before = self.interrupt_before.into_iter().collect();
compiled.interrupt_after = self.interrupt_after.into_iter().collect();
compiled.recursion_limit = self.recursion_limit;
Ok(GraphAgent {
name: self.name,
description: self.description,
graph: Arc::new(compiled),
input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
before_callback: self.before_callback,
after_callback: self.after_callback,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wires a single `set` node as START -> set -> END.
    ///
    /// Checks that the builder's name and description are exposed through the
    /// `Agent` trait, and that a direct invoke writes `value = 42`.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let agent = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async {
                let update = NodeOutput::new().with_update("value", json!(42));
                Ok(update)
            })
            .edge(START, "set")
            .edge("set", END)
            .build()
            .unwrap();

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        let config = ExecutionConfig::new("test");
        let final_state = agent.invoke(State::new(), config).await.unwrap();
        let expected = json!(42);
        assert_eq!(final_state.get("value"), Some(&expected));
    }
}