1use crate::checkpoint::Checkpointer;
6use crate::edge::{Edge, EdgeTarget, END, START};
7use crate::error::{GraphError, Result};
8use crate::graph::{CompiledGraph, StateGraph};
9use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
10use crate::state::{State, StateSchema};
11use crate::stream::{StreamEvent, StreamMode};
12use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
13use async_trait::async_trait;
14use serde_json::json;
15use std::collections::HashMap;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19
20pub type BeforeAgentCallback = Arc<
22 dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
23 + Send
24 + Sync,
25>;
26
27pub type AfterAgentCallback = Arc<
28 dyn Fn(
29 Arc<dyn InvocationContext>,
30 Event,
31 ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
32 + Send
33 + Sync,
34>;
35
36pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
38
39pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
41
/// An [`Agent`] whose behavior is defined by a compiled state graph.
///
/// When run, the input mapper turns the invocation context into the graph's
/// initial [`State`], the graph is invoked, and the output mapper turns the
/// final state into events. Optional hooks run before the graph is invoked
/// and on each event produced from the final state.
pub struct GraphAgent {
    /// Name reported through `Agent::name`.
    name: String,
    /// Description reported through `Agent::description`; empty unless set via the builder.
    description: String,
    /// Compiled graph, held in an `Arc` so `run` can move a handle into its event stream.
    graph: Arc<CompiledGraph>,
    /// Builds the initial graph state from the invocation context.
    input_mapper: InputMapper,
    /// Converts the final graph state into output events.
    output_mapper: OutputMapper,
    /// Optional hook awaited before the graph is invoked.
    before_callback: Option<BeforeAgentCallback>,
    /// Optional hook awaited for each output event before it is yielded.
    after_callback: Option<AfterAgentCallback>,
}
56
57impl GraphAgent {
58 pub fn builder(name: &str) -> GraphAgentBuilder {
60 GraphAgentBuilder::new(name)
61 }
62
63 pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
65 Self {
66 name: name.to_string(),
67 description: String::new(),
68 graph: Arc::new(graph),
69 input_mapper: Arc::new(default_input_mapper),
70 output_mapper: Arc::new(default_output_mapper),
71 before_callback: None,
72 after_callback: None,
73 }
74 }
75
76 pub fn graph(&self) -> &CompiledGraph {
78 &self.graph
79 }
80
81 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
83 self.graph.invoke(input, config).await
84 }
85
86 pub fn stream(
88 &self,
89 input: State,
90 config: ExecutionConfig,
91 mode: StreamMode,
92 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
93 self.graph.stream(input, config, mode)
94 }
95}
96
97#[async_trait]
98impl Agent for GraphAgent {
99 fn name(&self) -> &str {
100 &self.name
101 }
102
103 fn description(&self) -> &str {
104 &self.description
105 }
106
107 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
108 &[]
109 }
110
111 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
112 if let Some(callback) = &self.before_callback {
114 callback(ctx.clone()).await?;
115 }
116
117 let input = (self.input_mapper)(ctx.as_ref());
119
120 let config = ExecutionConfig::new(ctx.session_id());
122
123 let graph = self.graph.clone();
125 let output_mapper = self.output_mapper.clone();
126 let after_callback = self.after_callback.clone();
127 let ctx_clone = ctx.clone();
128
129 let stream = async_stream::stream! {
130 match graph.invoke(input, config).await {
131 Ok(state) => {
132 let events = output_mapper(&state);
133 for event in events {
134 if let Some(callback) = &after_callback {
136 if let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
137 yield Err(e);
138 return;
139 }
140 }
141 yield Ok(event);
142 }
143 }
144 Err(GraphError::Interrupted(interrupt)) => {
145 let mut event = Event::new("graph_interrupted");
147 event.set_content(Content::new("assistant").with_text(format!(
148 "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
149 interrupt.interrupt,
150 interrupt.thread_id,
151 interrupt.checkpoint_id
152 )));
153 yield Ok(event);
154 }
155 Err(e) => {
156 yield Err(adk_core::AdkError::Agent(e.to_string()));
157 }
158 }
159 };
160
161 Ok(Box::pin(stream))
162 }
163}
164
165fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
167 let mut state = State::new();
168
169 let content = ctx.user_content();
171 let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
172
173 if !text.is_empty() {
174 state.insert("input".to_string(), json!(text));
175 state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
176 }
177
178 state.insert("session_id".to_string(), json!(ctx.session_id()));
180
181 state
182}
183
184fn default_output_mapper(state: &State) -> Vec<Event> {
186 let mut events = Vec::new();
187
188 let output_text = state
190 .get("output")
191 .and_then(|v| v.as_str())
192 .or_else(|| state.get("result").and_then(|v| v.as_str()))
193 .or_else(|| {
194 state
195 .get("messages")
196 .and_then(|v| v.as_array())
197 .and_then(|arr| arr.last())
198 .and_then(|msg| msg.get("content"))
199 .and_then(|c| c.as_str())
200 });
201
202 let text = if let Some(text) = output_text {
203 text.to_string()
204 } else {
205 serde_json::to_string_pretty(state).unwrap_or_default()
207 };
208
209 let mut event = Event::new("graph_output");
210 event.set_content(Content::new("assistant").with_text(&text));
211 events.push(event);
212
213 events
214}
215
/// Fluent builder that assembles a [`StateGraph`], compiles it, and wraps the
/// result in a [`GraphAgent`]; see `GraphAgentBuilder::build`.
pub struct GraphAgentBuilder {
    /// Agent name.
    name: String,
    /// Agent description (empty by default).
    description: String,
    /// State schema passed to `StateGraph::new`.
    schema: StateSchema,
    /// Nodes registered so far, in insertion order.
    nodes: Vec<Arc<dyn Node>>,
    /// Edges registered so far; all `START` edges share a single `Edge::Entry`.
    edges: Vec<Edge>,
    /// Optional checkpointer installed on the compiled graph.
    checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Node names copied into the compiled graph's `interrupt_before`.
    interrupt_before: Vec<String>,
    /// Node names copied into the compiled graph's `interrupt_after`.
    interrupt_after: Vec<String>,
    /// Value copied into the compiled graph's `recursion_limit`.
    recursion_limit: usize,
    /// Custom input mapper; `None` means the default mapper is used.
    input_mapper: Option<InputMapper>,
    /// Custom output mapper; `None` means the default mapper is used.
    output_mapper: Option<OutputMapper>,
    /// Optional hook awaited before the graph runs.
    before_callback: Option<BeforeAgentCallback>,
    /// Optional hook awaited for each output event.
    after_callback: Option<AfterAgentCallback>,
}
232
233impl GraphAgentBuilder {
234 pub fn new(name: &str) -> Self {
236 Self {
237 name: name.to_string(),
238 description: String::new(),
239 schema: StateSchema::simple(&["input", "output", "messages"]),
240 nodes: vec![],
241 edges: vec![],
242 checkpointer: None,
243 interrupt_before: vec![],
244 interrupt_after: vec![],
245 recursion_limit: 50,
246 input_mapper: None,
247 output_mapper: None,
248 before_callback: None,
249 after_callback: None,
250 }
251 }
252
253 pub fn description(mut self, desc: &str) -> Self {
255 self.description = desc.to_string();
256 self
257 }
258
259 pub fn state_schema(mut self, schema: StateSchema) -> Self {
261 self.schema = schema;
262 self
263 }
264
265 pub fn channels(mut self, channels: &[&str]) -> Self {
267 self.schema = StateSchema::simple(channels);
268 self
269 }
270
271 pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
273 self.nodes.push(Arc::new(node));
274 self
275 }
276
277 pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
279 where
280 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
281 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
282 {
283 self.nodes.push(Arc::new(FunctionNode::new(name, func)));
284 self
285 }
286
287 pub fn edge(mut self, source: &str, target: &str) -> Self {
289 let target =
290 if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
291
292 if source == START {
293 let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
294 match entry_idx {
295 Some(idx) => {
296 if let Edge::Entry { targets } = &mut self.edges[idx] {
297 if let EdgeTarget::Node(node) = &target {
298 if !targets.contains(node) {
299 targets.push(node.clone());
300 }
301 }
302 }
303 }
304 None => {
305 if let EdgeTarget::Node(node) = target {
306 self.edges.push(Edge::Entry { targets: vec![node] });
307 }
308 }
309 }
310 } else {
311 self.edges.push(Edge::Direct { source: source.to_string(), target });
312 }
313
314 self
315 }
316
317 pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
319 where
320 F: Fn(&State) -> String + Send + Sync + 'static,
321 I: IntoIterator<Item = (&'static str, &'static str)>,
322 {
323 let targets_map: HashMap<String, EdgeTarget> = targets
324 .into_iter()
325 .map(|(k, v)| {
326 let target =
327 if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
328 (k.to_string(), target)
329 })
330 .collect();
331
332 self.edges.push(Edge::Conditional {
333 source: source.to_string(),
334 router: Arc::new(router),
335 targets: targets_map,
336 });
337
338 self
339 }
340
341 pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
343 self.checkpointer = Some(Arc::new(checkpointer));
344 self
345 }
346
347 pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
349 self.checkpointer = Some(checkpointer);
350 self
351 }
352
353 pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
355 self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
356 self
357 }
358
359 pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
361 self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
362 self
363 }
364
365 pub fn recursion_limit(mut self, limit: usize) -> Self {
367 self.recursion_limit = limit;
368 self
369 }
370
371 pub fn input_mapper<F>(mut self, mapper: F) -> Self
373 where
374 F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
375 {
376 self.input_mapper = Some(Arc::new(mapper));
377 self
378 }
379
380 pub fn output_mapper<F>(mut self, mapper: F) -> Self
382 where
383 F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
384 {
385 self.output_mapper = Some(Arc::new(mapper));
386 self
387 }
388
389 pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
391 where
392 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
393 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
394 {
395 self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
396 self
397 }
398
399 pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
403 where
404 F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
405 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
406 {
407 self.after_callback = Some(Arc::new(move |ctx, event| {
408 let event_clone = event.clone();
409 Box::pin(callback(ctx, event_clone))
410 }));
411 self
412 }
413
414 pub fn build(self) -> Result<GraphAgent> {
416 let mut graph = StateGraph::new(self.schema);
418
419 for node in self.nodes {
421 graph.nodes.insert(node.name().to_string(), node);
422 }
423
424 graph.edges = self.edges;
426
427 let mut compiled = graph.compile()?;
429
430 if let Some(cp) = self.checkpointer {
432 compiled.checkpointer = Some(cp);
433 }
434 compiled.interrupt_before = self.interrupt_before.into_iter().collect();
435 compiled.interrupt_after = self.interrupt_after.into_iter().collect();
436 compiled.recursion_limit = self.recursion_limit;
437
438 Ok(GraphAgent {
439 name: self.name,
440 description: self.description,
441 graph: Arc::new(compiled),
442 input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
443 output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
444 before_callback: self.before_callback,
445 after_callback: self.after_callback,
446 })
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a one-node graph agent (START -> "set" -> END) and checks its
    /// metadata and the value written by the node.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let builder = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END);
        let agent = builder.build().unwrap();

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        let config = ExecutionConfig::new("test");
        let final_state = agent.invoke(State::new(), config).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }
}