1use crate::checkpoint::Checkpointer;
6use crate::edge::{END, Edge, EdgeTarget, START};
7use crate::error::{GraphError, Result};
8use crate::graph::{CompiledGraph, StateGraph};
9use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
10use crate::state::{State, StateSchema};
11use crate::stream::{StreamEvent, StreamMode};
12use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
13use async_trait::async_trait;
14use serde_json::json;
15use std::collections::HashMap;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19
/// Async hook that receives the invocation context and resolves to `adk_core::Result<()>`.
///
/// `GraphAgent::run` awaits this hook before mapping the input and executing the graph;
/// an error returned by the hook is propagated from `run`.
pub type BeforeAgentCallback = Arc<
    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
        + Send
        + Sync,
>;
26
/// Async hook that receives the invocation context together with an output [`Event`].
///
/// `GraphAgent::run` awaits this hook once per event produced by the output mapper,
/// right before the event is yielded; an error is yielded instead of the event and
/// ends the stream.
pub type AfterAgentCallback = Arc<
    dyn Fn(
            Arc<dyn InvocationContext>,
            Event,
        ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
        + Send
        + Sync,
>;
35
/// Function that builds the initial graph [`State`] from an invocation context.
pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
38
/// Function that turns the graph's final [`State`] into the events emitted by the agent.
pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
41
/// An [`Agent`] whose behavior is defined by a compiled graph.
///
/// Built either through [`GraphAgent::builder`] or by wrapping an existing
/// [`CompiledGraph`] with [`GraphAgent::from_graph`].
pub struct GraphAgent {
    /// Agent name, returned by `Agent::name`.
    name: String,
    /// Agent description, returned by `Agent::description`.
    description: String,
    /// The compiled graph that is executed on every run.
    graph: Arc<CompiledGraph>,
    /// Converts the invocation context into the graph's initial state.
    input_mapper: InputMapper,
    /// Converts the graph's final state into output events.
    output_mapper: OutputMapper,
    /// Optional hook awaited before the graph runs.
    before_callback: Option<BeforeAgentCallback>,
    /// Optional hook awaited for each output event before it is yielded.
    after_callback: Option<AfterAgentCallback>,
}
56
57impl GraphAgent {
58 pub fn builder(name: &str) -> GraphAgentBuilder {
60 GraphAgentBuilder::new(name)
61 }
62
63 pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
65 Self {
66 name: name.to_string(),
67 description: String::new(),
68 graph: Arc::new(graph),
69 input_mapper: Arc::new(default_input_mapper),
70 output_mapper: Arc::new(default_output_mapper),
71 before_callback: None,
72 after_callback: None,
73 }
74 }
75
76 #[cfg(feature = "action")]
81 pub fn from_workflow_schema(
82 name: &str,
83 schema: &crate::workflow::WorkflowSchema,
84 ) -> Result<Self> {
85 schema.build_graph(name)
86 }
87
88 pub fn graph(&self) -> &CompiledGraph {
90 &self.graph
91 }
92
93 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
95 self.graph.invoke(input, config).await
96 }
97
98 pub fn stream(
100 &self,
101 input: State,
102 config: ExecutionConfig,
103 mode: StreamMode,
104 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
105 self.graph.stream(input, config, mode)
106 }
107}
108
109#[async_trait]
110impl Agent for GraphAgent {
111 fn name(&self) -> &str {
112 &self.name
113 }
114
115 fn description(&self) -> &str {
116 &self.description
117 }
118
119 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
120 &[]
121 }
122
123 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
124 if let Some(callback) = &self.before_callback {
126 callback(ctx.clone()).await?;
127 }
128
129 let input = (self.input_mapper)(ctx.as_ref());
131
132 let config = ExecutionConfig::new(ctx.session_id());
134
135 let graph = self.graph.clone();
137 let output_mapper = self.output_mapper.clone();
138 let after_callback = self.after_callback.clone();
139 let ctx_clone = ctx.clone();
140
141 let stream = async_stream::stream! {
142 match graph.invoke(input, config).await {
143 Ok(state) => {
144 let events = output_mapper(&state);
145 for event in events {
146 if let Some(callback) = &after_callback {
148 if let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
149 yield Err(e);
150 return;
151 }
152 }
153 yield Ok(event);
154 }
155 }
156 Err(GraphError::Interrupted(interrupt)) => {
157 let mut event = Event::new("graph_interrupted");
159 event.set_content(Content::new("assistant").with_text(format!(
160 "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
161 interrupt.interrupt,
162 interrupt.thread_id,
163 interrupt.checkpoint_id
164 )));
165 yield Ok(event);
166 }
167 Err(e) => {
168 yield Err(adk_core::AdkError::agent(e.to_string()));
169 }
170 }
171 };
172
173 Ok(Box::pin(stream))
174 }
175}
176
177fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
179 let mut state = State::new();
180
181 let content = ctx.user_content();
183 let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
184
185 if !text.is_empty() {
186 state.insert("input".to_string(), json!(text));
187 state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
188 }
189
190 state.insert("session_id".to_string(), json!(ctx.session_id()));
192
193 state
194}
195
196fn default_output_mapper(state: &State) -> Vec<Event> {
198 let mut events = Vec::new();
199
200 let output_text = state
202 .get("output")
203 .and_then(|v| v.as_str())
204 .or_else(|| state.get("result").and_then(|v| v.as_str()))
205 .or_else(|| {
206 state
207 .get("messages")
208 .and_then(|v| v.as_array())
209 .and_then(|arr| arr.last())
210 .and_then(|msg| msg.get("content"))
211 .and_then(|c| c.as_str())
212 });
213
214 let text = if let Some(text) = output_text {
215 text.to_string()
216 } else {
217 serde_json::to_string_pretty(state).unwrap_or_default()
219 };
220
221 let mut event = Event::new("graph_output");
222 event.set_content(Content::new("assistant").with_text(&text));
223 events.push(event);
224
225 events
226}
227
/// Builder that assembles a graph (schema, nodes, edges, execution options) and the
/// surrounding agent configuration, then produces a [`GraphAgent`] via `build`.
pub struct GraphAgentBuilder {
    /// Agent name.
    name: String,
    /// Agent description.
    description: String,
    /// State schema the graph is created with.
    schema: StateSchema,
    /// Nodes to register in the graph.
    nodes: Vec<Arc<dyn Node>>,
    /// Edges to install in the graph.
    edges: Vec<Edge>,
    /// Optional checkpointer applied to the compiled graph.
    checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Node names copied to the compiled graph's `interrupt_before`.
    interrupt_before: Vec<String>,
    /// Node names copied to the compiled graph's `interrupt_after`.
    interrupt_after: Vec<String>,
    /// Recursion limit copied to the compiled graph.
    recursion_limit: usize,
    /// Custom input mapper; the default mapper is used when `None`.
    input_mapper: Option<InputMapper>,
    /// Custom output mapper; the default mapper is used when `None`.
    output_mapper: Option<OutputMapper>,
    /// Optional hook awaited before the graph runs.
    before_callback: Option<BeforeAgentCallback>,
    /// Optional hook awaited for each output event.
    after_callback: Option<AfterAgentCallback>,
}
244
245impl GraphAgentBuilder {
246 pub fn new(name: &str) -> Self {
248 Self {
249 name: name.to_string(),
250 description: String::new(),
251 schema: StateSchema::simple(&["input", "output", "messages"]),
252 nodes: vec![],
253 edges: vec![],
254 checkpointer: None,
255 interrupt_before: vec![],
256 interrupt_after: vec![],
257 recursion_limit: 50,
258 input_mapper: None,
259 output_mapper: None,
260 before_callback: None,
261 after_callback: None,
262 }
263 }
264
265 pub fn description(mut self, desc: &str) -> Self {
267 self.description = desc.to_string();
268 self
269 }
270
271 pub fn state_schema(mut self, schema: StateSchema) -> Self {
273 self.schema = schema;
274 self
275 }
276
277 pub fn channels(mut self, channels: &[&str]) -> Self {
279 self.schema = StateSchema::simple(channels);
280 self
281 }
282
283 pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
285 self.nodes.push(Arc::new(node));
286 self
287 }
288
289 pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
291 where
292 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
293 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
294 {
295 self.nodes.push(Arc::new(FunctionNode::new(name, func)));
296 self
297 }
298
299 pub fn edge(mut self, source: &str, target: &str) -> Self {
301 let target =
302 if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
303
304 if source == START {
305 let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
306 match entry_idx {
307 Some(idx) => {
308 if let Edge::Entry { targets } = &mut self.edges[idx] {
309 if let EdgeTarget::Node(node) = &target {
310 if !targets.contains(node) {
311 targets.push(node.clone());
312 }
313 }
314 }
315 }
316 None => {
317 if let EdgeTarget::Node(node) = target {
318 self.edges.push(Edge::Entry { targets: vec![node] });
319 }
320 }
321 }
322 } else {
323 self.edges.push(Edge::Direct { source: source.to_string(), target });
324 }
325
326 self
327 }
328
329 pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
331 where
332 F: Fn(&State) -> String + Send + Sync + 'static,
333 I: IntoIterator<Item = (&'static str, &'static str)>,
334 {
335 let targets_map: HashMap<String, EdgeTarget> = targets
336 .into_iter()
337 .map(|(k, v)| {
338 let target =
339 if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
340 (k.to_string(), target)
341 })
342 .collect();
343
344 self.edges.push(Edge::Conditional {
345 source: source.to_string(),
346 router: Arc::new(router),
347 targets: targets_map,
348 });
349
350 self
351 }
352
353 pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
355 self.checkpointer = Some(Arc::new(checkpointer));
356 self
357 }
358
359 pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
361 self.checkpointer = Some(checkpointer);
362 self
363 }
364
365 pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
367 self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
368 self
369 }
370
371 pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
373 self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
374 self
375 }
376
377 pub fn recursion_limit(mut self, limit: usize) -> Self {
379 self.recursion_limit = limit;
380 self
381 }
382
383 pub fn input_mapper<F>(mut self, mapper: F) -> Self
385 where
386 F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
387 {
388 self.input_mapper = Some(Arc::new(mapper));
389 self
390 }
391
392 pub fn output_mapper<F>(mut self, mapper: F) -> Self
394 where
395 F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
396 {
397 self.output_mapper = Some(Arc::new(mapper));
398 self
399 }
400
401 pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
403 where
404 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
405 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
406 {
407 self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
408 self
409 }
410
411 pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
415 where
416 F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
417 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
418 {
419 self.after_callback = Some(Arc::new(move |ctx, event| {
420 let event_clone = event.clone();
421 Box::pin(callback(ctx, event_clone))
422 }));
423 self
424 }
425
426 #[cfg(feature = "action")]
432 pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
433 use crate::action::ActionNodeExecutor;
434
435 if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
437 let conditions = switch_config.conditions.clone();
438 let eval_mode = switch_config.evaluation_mode.clone();
439 let default_branch = switch_config.default_branch.clone();
440 let source = config.standard().id.clone();
441
442 let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
443 for condition in &conditions {
444 targets_map.insert(
445 condition.output_port.clone(),
446 EdgeTarget::Node(condition.output_port.clone()),
447 );
448 }
449 if let Some(ref default) = default_branch {
450 let target = if default == END {
451 EdgeTarget::End
452 } else {
453 EdgeTarget::Node(default.clone())
454 };
455 targets_map.insert(default.clone(), target);
456 }
457 targets_map.insert(END.to_string(), EdgeTarget::End);
458
459 let router = Arc::new(move |state: &State| -> String {
460 match crate::action::switch::evaluate_switch_conditions(
461 &conditions,
462 state,
463 &eval_mode,
464 default_branch.as_deref(),
465 ) {
466 Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
467 Err(_) => END.to_string(),
468 }
469 });
470
471 self.edges.push(Edge::Conditional { source, router, targets: targets_map });
472 }
473
474 let executor = ActionNodeExecutor::new(config);
475 self.nodes.push(Arc::new(executor));
476 self
477 }
478
479 pub fn build(self) -> Result<GraphAgent> {
481 let mut graph = StateGraph::new(self.schema);
483
484 for node in self.nodes {
486 graph.nodes.insert(node.name().to_string(), node);
487 }
488
489 graph.edges = self.edges;
491
492 let mut compiled = graph.compile()?;
494
495 if let Some(cp) = self.checkpointer {
497 compiled.checkpointer = Some(cp);
498 }
499 compiled.interrupt_before = self.interrupt_before.into_iter().collect();
500 compiled.interrupt_after = self.interrupt_after.into_iter().collect();
501 compiled.recursion_limit = self.recursion_limit;
502
503 Ok(GraphAgent {
504 name: self.name,
505 description: self.description,
506 graph: Arc::new(compiled),
507 input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
508 output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
509 before_callback: self.before_callback,
510 after_callback: self.after_callback,
511 })
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a one-node agent and checks its metadata and the state produced by `invoke`.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let builder = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END);
        let graph_agent = builder.build().unwrap();

        assert_eq!(graph_agent.name(), "test");
        assert_eq!(graph_agent.description(), "Test agent");

        let run_config = ExecutionConfig::new("test");
        let final_state = graph_agent.invoke(State::new(), run_config).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }
}