1use crate::checkpoint::Checkpointer;
6use crate::deferred::DeferredNodeConfig;
7use crate::edge::{END, Edge, EdgeTarget, START};
8use crate::error::{GraphError, Result};
9use crate::graph::{CompiledGraph, StateGraph};
10use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
11use crate::state::{State, StateSchema};
12use crate::stream::{StreamEvent, StreamMode};
13use crate::timeout::TimeoutPolicy;
14use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
15use async_trait::async_trait;
16use serde_json::json;
17use std::collections::HashMap;
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::Arc;
21
22pub type BeforeAgentCallback = Arc<
24 dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
25 + Send
26 + Sync,
27>;
28
29pub type AfterAgentCallback = Arc<
30 dyn Fn(
31 Arc<dyn InvocationContext>,
32 Event,
33 ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
34 + Send
35 + Sync,
36>;
37
38pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
40
41pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
43
/// An [`Agent`] that executes a compiled state graph.
///
/// Each run builds an initial [`State`] from the invocation context with
/// `input_mapper`, invokes `graph`, and converts the final state into events
/// with `output_mapper`. Optional callbacks run before the graph and for every
/// emitted output event.
pub struct GraphAgent {
    /// Name returned by [`Agent::name`].
    name: String,
    /// Description returned by [`Agent::description`].
    description: String,
    /// Compiled graph; held in an `Arc` so each run's event stream can own a
    /// handle independent of `&self`.
    graph: Arc<CompiledGraph>,
    /// Builds the graph's initial state from the invocation context.
    input_mapper: InputMapper,
    /// Turns the graph's final state into the events the agent emits.
    output_mapper: OutputMapper,
    /// Awaited at the start of each run; an error aborts the run.
    before_callback: Option<BeforeAgentCallback>,
    /// Awaited for each output event before it is yielded; an error ends the stream.
    after_callback: Option<AfterAgentCallback>,
}
58
59impl GraphAgent {
60 pub fn builder(name: &str) -> GraphAgentBuilder {
62 GraphAgentBuilder::new(name)
63 }
64
65 pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
67 Self {
68 name: name.to_string(),
69 description: String::new(),
70 graph: Arc::new(graph),
71 input_mapper: Arc::new(default_input_mapper),
72 output_mapper: Arc::new(default_output_mapper),
73 before_callback: None,
74 after_callback: None,
75 }
76 }
77
78 #[cfg(feature = "action")]
83 pub fn from_workflow_schema(
84 name: &str,
85 schema: &crate::workflow::WorkflowSchema,
86 ) -> Result<Self> {
87 schema.build_graph(name)
88 }
89
90 pub fn graph(&self) -> &CompiledGraph {
92 &self.graph
93 }
94
95 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
97 self.graph.invoke(input, config).await
98 }
99
100 pub fn stream(
102 &self,
103 input: State,
104 config: ExecutionConfig,
105 mode: StreamMode,
106 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
107 self.graph.stream(input, config, mode)
108 }
109}
110
111#[async_trait]
112impl Agent for GraphAgent {
113 fn name(&self) -> &str {
114 &self.name
115 }
116
117 fn description(&self) -> &str {
118 &self.description
119 }
120
121 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
122 &[]
123 }
124
125 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
126 if let Some(callback) = &self.before_callback {
128 callback(ctx.clone()).await?;
129 }
130
131 let input = (self.input_mapper)(ctx.as_ref());
133
134 let config = ExecutionConfig::new(ctx.session_id());
136
137 let graph = self.graph.clone();
139 let output_mapper = self.output_mapper.clone();
140 let after_callback = self.after_callback.clone();
141 let ctx_clone = ctx.clone();
142
143 let stream = async_stream::stream! {
144 match graph.invoke(input, config).await {
145 Ok(state) => {
146 let events = output_mapper(&state);
147 for event in events {
148 if let Some(callback) = &after_callback
150 && let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
151 yield Err(e);
152 return;
153 }
154 yield Ok(event);
155 }
156 }
157 Err(GraphError::Interrupted(interrupt)) => {
158 let mut event = Event::new("graph_interrupted");
160 event.set_content(Content::new("assistant").with_text(format!(
161 "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
162 interrupt.interrupt,
163 interrupt.thread_id,
164 interrupt.checkpoint_id
165 )));
166 yield Ok(event);
167 }
168 Err(e) => {
169 yield Err(adk_core::AdkError::agent(e.to_string()));
170 }
171 }
172 };
173
174 Ok(Box::pin(stream))
175 }
176}
177
178fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
180 let mut state = State::new();
181
182 let content = ctx.user_content();
184 let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
185
186 if !text.is_empty() {
187 state.insert("input".to_string(), json!(text));
188 state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
189 }
190
191 state.insert("session_id".to_string(), json!(ctx.session_id()));
193
194 state
195}
196
197fn default_output_mapper(state: &State) -> Vec<Event> {
199 let mut events = Vec::new();
200
201 let output_text = state
203 .get("output")
204 .and_then(|v| v.as_str())
205 .or_else(|| state.get("result").and_then(|v| v.as_str()))
206 .or_else(|| {
207 state
208 .get("messages")
209 .and_then(|v| v.as_array())
210 .and_then(|arr| arr.last())
211 .and_then(|msg| msg.get("content"))
212 .and_then(|c| c.as_str())
213 });
214
215 let text = if let Some(text) = output_text {
216 text.to_string()
217 } else {
218 serde_json::to_string_pretty(state).unwrap_or_default()
220 };
221
222 let mut event = Event::new("graph_output");
223 event.set_content(Content::new("assistant").with_text(&text));
224 events.push(event);
225
226 events
227}
228
/// Fluent builder that assembles a [`StateGraph`], compiles it, and wraps the
/// result in a [`GraphAgent`].
pub struct GraphAgentBuilder {
    /// Agent name.
    name: String,
    /// Agent description; empty unless set.
    description: String,
    /// State schema passed to `StateGraph::new`.
    schema: StateSchema,
    /// Nodes to register, keyed by `Node::name` at build time.
    nodes: Vec<Arc<dyn Node>>,
    /// Edges copied verbatim into the graph at build time.
    edges: Vec<Edge>,
    /// Optional checkpointer installed on the compiled graph.
    checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Node names copied into the compiled graph's `interrupt_before`.
    interrupt_before: Vec<String>,
    /// Node names copied into the compiled graph's `interrupt_after`.
    interrupt_after: Vec<String>,
    /// Copied into the compiled graph's `recursion_limit`.
    recursion_limit: usize,
    /// Custom input mapper; the default mapper is used when `None`.
    input_mapper: Option<InputMapper>,
    /// Custom output mapper; the default mapper is used when `None`.
    output_mapper: Option<OutputMapper>,
    /// Optional hook awaited before each run.
    before_callback: Option<BeforeAgentCallback>,
    /// Optional hook awaited for each output event.
    after_callback: Option<AfterAgentCallback>,
    /// Per-node timeout policies, keyed by node name.
    timeout_policies: HashMap<String, TimeoutPolicy>,
    /// Timeout policy stored as the compiled graph's `default_timeout`.
    default_timeout: Option<TimeoutPolicy>,
    /// Deferred-execution settings, keyed by node name.
    deferred_configs: HashMap<String, DeferredNodeConfig>,
    /// Per-node cache policies, keyed by node name.
    #[cfg(feature = "node-cache")]
    cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
}
250
251impl GraphAgentBuilder {
252 pub fn new(name: &str) -> Self {
254 Self {
255 name: name.to_string(),
256 description: String::new(),
257 schema: StateSchema::simple(&["input", "output", "messages"]),
258 nodes: vec![],
259 edges: vec![],
260 checkpointer: None,
261 interrupt_before: vec![],
262 interrupt_after: vec![],
263 recursion_limit: 50,
264 input_mapper: None,
265 output_mapper: None,
266 before_callback: None,
267 after_callback: None,
268 timeout_policies: HashMap::new(),
269 default_timeout: None,
270 deferred_configs: HashMap::new(),
271 #[cfg(feature = "node-cache")]
272 cache_policies: HashMap::new(),
273 }
274 }
275
276 pub fn description(mut self, desc: &str) -> Self {
278 self.description = desc.to_string();
279 self
280 }
281
282 pub fn state_schema(mut self, schema: StateSchema) -> Self {
284 self.schema = schema;
285 self
286 }
287
288 pub fn channels(mut self, channels: &[&str]) -> Self {
290 self.schema = StateSchema::simple(channels);
291 self
292 }
293
294 pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
296 self.nodes.push(Arc::new(node));
297 self
298 }
299
300 pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
302 where
303 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
304 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
305 {
306 self.nodes.push(Arc::new(FunctionNode::new(name, func)));
307 self
308 }
309
310 pub fn edge(mut self, source: &str, target: &str) -> Self {
312 let target =
313 if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
314
315 if source == START {
316 let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
317 match entry_idx {
318 Some(idx) => {
319 if let Edge::Entry { targets } = &mut self.edges[idx]
320 && let EdgeTarget::Node(node) = &target
321 && !targets.contains(node)
322 {
323 targets.push(node.clone());
324 }
325 }
326 None => {
327 if let EdgeTarget::Node(node) = target {
328 self.edges.push(Edge::Entry { targets: vec![node] });
329 }
330 }
331 }
332 } else {
333 self.edges.push(Edge::Direct { source: source.to_string(), target });
334 }
335
336 self
337 }
338
339 pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
341 where
342 F: Fn(&State) -> String + Send + Sync + 'static,
343 I: IntoIterator<Item = (&'static str, &'static str)>,
344 {
345 let targets_map: HashMap<String, EdgeTarget> = targets
346 .into_iter()
347 .map(|(k, v)| {
348 let target =
349 if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
350 (k.to_string(), target)
351 })
352 .collect();
353
354 self.edges.push(Edge::Conditional {
355 source: source.to_string(),
356 router: Arc::new(router),
357 targets: targets_map,
358 });
359
360 self
361 }
362
363 pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
365 self.checkpointer = Some(Arc::new(checkpointer));
366 self
367 }
368
369 pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
371 self.checkpointer = Some(checkpointer);
372 self
373 }
374
375 pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
377 self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
378 self
379 }
380
381 pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
383 self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
384 self
385 }
386
387 pub fn recursion_limit(mut self, limit: usize) -> Self {
389 self.recursion_limit = limit;
390 self
391 }
392
393 pub fn node_timeout(mut self, node_name: &str, policy: TimeoutPolicy) -> Self {
413 self.timeout_policies.insert(node_name.to_string(), policy);
414 self
415 }
416
417 pub fn default_timeout(mut self, policy: TimeoutPolicy) -> Self {
437 self.default_timeout = Some(policy);
438 self
439 }
440
441 pub fn deferred_node<F, Fut>(mut self, name: &str, func: F, config: DeferredNodeConfig) -> Self
471 where
472 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
473 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
474 {
475 self.nodes.push(Arc::new(FunctionNode::new(name, func)));
476 self.deferred_configs.insert(name.to_string(), config);
477 self
478 }
479
480 #[cfg(feature = "node-cache")]
506 pub fn node_cache(mut self, name: &str, policy: crate::cache::NodeCachePolicy) -> Self {
507 self.cache_policies.insert(name.to_string(), policy);
508 self
509 }
510
511 pub fn input_mapper<F>(mut self, mapper: F) -> Self
513 where
514 F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
515 {
516 self.input_mapper = Some(Arc::new(mapper));
517 self
518 }
519
520 pub fn output_mapper<F>(mut self, mapper: F) -> Self
522 where
523 F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
524 {
525 self.output_mapper = Some(Arc::new(mapper));
526 self
527 }
528
529 pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
531 where
532 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
533 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
534 {
535 self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
536 self
537 }
538
539 pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
543 where
544 F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
545 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
546 {
547 self.after_callback = Some(Arc::new(move |ctx, event| {
548 let event_clone = event.clone();
549 Box::pin(callback(ctx, event_clone))
550 }));
551 self
552 }
553
554 #[cfg(feature = "action")]
560 pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
561 use crate::action::ActionNodeExecutor;
562
563 if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
565 let conditions = switch_config.conditions.clone();
566 let eval_mode = switch_config.evaluation_mode.clone();
567 let default_branch = switch_config.default_branch.clone();
568 let source = config.standard().id.clone();
569
570 let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
571 for condition in &conditions {
572 targets_map.insert(
573 condition.output_port.clone(),
574 EdgeTarget::Node(condition.output_port.clone()),
575 );
576 }
577 if let Some(ref default) = default_branch {
578 let target = if default == END {
579 EdgeTarget::End
580 } else {
581 EdgeTarget::Node(default.clone())
582 };
583 targets_map.insert(default.clone(), target);
584 }
585 targets_map.insert(END.to_string(), EdgeTarget::End);
586
587 let router = Arc::new(move |state: &State| -> String {
588 match crate::action::switch::evaluate_switch_conditions(
589 &conditions,
590 state,
591 &eval_mode,
592 default_branch.as_deref(),
593 ) {
594 Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
595 Err(_) => END.to_string(),
596 }
597 });
598
599 self.edges.push(Edge::Conditional { source, router, targets: targets_map });
600 }
601
602 let executor = ActionNodeExecutor::new(config);
603 self.nodes.push(Arc::new(executor));
604 self
605 }
606
607 pub fn build(self) -> Result<GraphAgent> {
609 let mut graph = StateGraph::new(self.schema);
611
612 for node in self.nodes {
614 graph.nodes.insert(node.name().to_string(), node);
615 }
616
617 graph.edges = self.edges;
619
620 let mut compiled = graph.compile()?;
622
623 if let Some(cp) = self.checkpointer {
625 compiled.checkpointer = Some(cp);
626 }
627 compiled.interrupt_before = self.interrupt_before.into_iter().collect();
628 compiled.interrupt_after = self.interrupt_after.into_iter().collect();
629 compiled.recursion_limit = self.recursion_limit;
630 compiled.timeout_policies = self.timeout_policies;
631 compiled.default_timeout = self.default_timeout;
632 compiled.deferred_configs = self.deferred_configs;
633
634 #[cfg(feature = "node-cache")]
635 {
636 compiled.cache_policies = self.cache_policies;
637 }
638
639 Ok(GraphAgent {
640 name: self.name,
641 description: self.description,
642 graph: Arc::new(compiled),
643 input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
644 output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
645 before_callback: self.before_callback,
646 after_callback: self.after_callback,
647 })
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a one-node graph (START -> set -> END) and checks the agent's
    /// metadata and the state produced by a direct invocation.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let graph_agent = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END)
            .build()
            .unwrap();

        assert_eq!(graph_agent.name(), "test");
        assert_eq!(graph_agent.description(), "Test agent");

        let config = ExecutionConfig::new("test");
        let final_state = graph_agent.invoke(State::new(), config).await.unwrap();

        let expected = json!(42);
        assert_eq!(final_state.get("value"), Some(&expected));
    }
}