use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
use async_trait::async_trait;
use serde_json::json;

use crate::checkpoint::Checkpointer;
use crate::deferred::DeferredNodeConfig;
use crate::edge::{END, Edge, EdgeTarget, START};
use crate::error::{GraphError, Result};
use crate::graph::{CompiledGraph, StateGraph};
use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
use crate::state::{State, StateSchema};
use crate::stream::{StreamEvent, StreamMode};
use crate::timeout::TimeoutPolicy;
21
22pub type BeforeAgentCallback = Arc<
24 dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
25 + Send
26 + Sync,
27>;
28
29pub type AfterAgentCallback = Arc<
30 dyn Fn(
31 Arc<dyn InvocationContext>,
32 Event,
33 ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
34 + Send
35 + Sync,
36>;
37
38pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
40
41pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
43
44pub struct GraphAgent {
46 name: String,
47 description: String,
48 graph: Arc<CompiledGraph>,
49 input_mapper: InputMapper,
51 output_mapper: OutputMapper,
53 before_callback: Option<BeforeAgentCallback>,
55 after_callback: Option<AfterAgentCallback>,
57}
58
59impl GraphAgent {
60 pub fn builder(name: &str) -> GraphAgentBuilder {
62 GraphAgentBuilder::new(name)
63 }
64
65 pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
67 Self {
68 name: name.to_string(),
69 description: String::new(),
70 graph: Arc::new(graph),
71 input_mapper: Arc::new(default_input_mapper),
72 output_mapper: Arc::new(default_output_mapper),
73 before_callback: None,
74 after_callback: None,
75 }
76 }
77
78 #[cfg(feature = "action")]
83 pub fn from_workflow_schema(
84 name: &str,
85 schema: &crate::workflow::WorkflowSchema,
86 ) -> Result<Self> {
87 schema.build_graph(name)
88 }
89
90 pub fn graph(&self) -> &CompiledGraph {
92 &self.graph
93 }
94
95 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
97 self.graph.invoke(input, config).await
98 }
99
100 pub fn stream(
102 &self,
103 input: State,
104 config: ExecutionConfig,
105 mode: StreamMode,
106 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
107 self.graph.stream(input, config, mode)
108 }
109}
110
111#[async_trait]
112impl Agent for GraphAgent {
113 fn name(&self) -> &str {
114 &self.name
115 }
116
117 fn description(&self) -> &str {
118 &self.description
119 }
120
121 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
122 &[]
123 }
124
125 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
126 if let Some(callback) = &self.before_callback {
128 callback(ctx.clone()).await?;
129 }
130
131 let input = (self.input_mapper)(ctx.as_ref());
133
134 let config = ExecutionConfig::new(ctx.session_id());
136
137 let graph = self.graph.clone();
139 let output_mapper = self.output_mapper.clone();
140 let after_callback = self.after_callback.clone();
141 let ctx_clone = ctx.clone();
142
143 let stream = async_stream::stream! {
144 match graph.invoke(input, config).await {
145 Ok(state) => {
146 let events = output_mapper(&state);
147 for event in events {
148 if let Some(callback) = &after_callback {
150 if let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
151 yield Err(e);
152 return;
153 }
154 }
155 yield Ok(event);
156 }
157 }
158 Err(GraphError::Interrupted(interrupt)) => {
159 let mut event = Event::new("graph_interrupted");
161 event.set_content(Content::new("assistant").with_text(format!(
162 "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
163 interrupt.interrupt,
164 interrupt.thread_id,
165 interrupt.checkpoint_id
166 )));
167 yield Ok(event);
168 }
169 Err(e) => {
170 yield Err(adk_core::AdkError::agent(e.to_string()));
171 }
172 }
173 };
174
175 Ok(Box::pin(stream))
176 }
177}
178
179fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
181 let mut state = State::new();
182
183 let content = ctx.user_content();
185 let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
186
187 if !text.is_empty() {
188 state.insert("input".to_string(), json!(text));
189 state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
190 }
191
192 state.insert("session_id".to_string(), json!(ctx.session_id()));
194
195 state
196}
197
198fn default_output_mapper(state: &State) -> Vec<Event> {
200 let mut events = Vec::new();
201
202 let output_text = state
204 .get("output")
205 .and_then(|v| v.as_str())
206 .or_else(|| state.get("result").and_then(|v| v.as_str()))
207 .or_else(|| {
208 state
209 .get("messages")
210 .and_then(|v| v.as_array())
211 .and_then(|arr| arr.last())
212 .and_then(|msg| msg.get("content"))
213 .and_then(|c| c.as_str())
214 });
215
216 let text = if let Some(text) = output_text {
217 text.to_string()
218 } else {
219 serde_json::to_string_pretty(state).unwrap_or_default()
221 };
222
223 let mut event = Event::new("graph_output");
224 event.set_content(Content::new("assistant").with_text(&text));
225 events.push(event);
226
227 events
228}
229
230pub struct GraphAgentBuilder {
232 name: String,
233 description: String,
234 schema: StateSchema,
235 nodes: Vec<Arc<dyn Node>>,
236 edges: Vec<Edge>,
237 checkpointer: Option<Arc<dyn Checkpointer>>,
238 interrupt_before: Vec<String>,
239 interrupt_after: Vec<String>,
240 recursion_limit: usize,
241 input_mapper: Option<InputMapper>,
242 output_mapper: Option<OutputMapper>,
243 before_callback: Option<BeforeAgentCallback>,
244 after_callback: Option<AfterAgentCallback>,
245 timeout_policies: HashMap<String, TimeoutPolicy>,
246 default_timeout: Option<TimeoutPolicy>,
247 deferred_configs: HashMap<String, DeferredNodeConfig>,
248 #[cfg(feature = "node-cache")]
249 cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
250}
251
252impl GraphAgentBuilder {
253 pub fn new(name: &str) -> Self {
255 Self {
256 name: name.to_string(),
257 description: String::new(),
258 schema: StateSchema::simple(&["input", "output", "messages"]),
259 nodes: vec![],
260 edges: vec![],
261 checkpointer: None,
262 interrupt_before: vec![],
263 interrupt_after: vec![],
264 recursion_limit: 50,
265 input_mapper: None,
266 output_mapper: None,
267 before_callback: None,
268 after_callback: None,
269 timeout_policies: HashMap::new(),
270 default_timeout: None,
271 deferred_configs: HashMap::new(),
272 #[cfg(feature = "node-cache")]
273 cache_policies: HashMap::new(),
274 }
275 }
276
277 pub fn description(mut self, desc: &str) -> Self {
279 self.description = desc.to_string();
280 self
281 }
282
283 pub fn state_schema(mut self, schema: StateSchema) -> Self {
285 self.schema = schema;
286 self
287 }
288
289 pub fn channels(mut self, channels: &[&str]) -> Self {
291 self.schema = StateSchema::simple(channels);
292 self
293 }
294
295 pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
297 self.nodes.push(Arc::new(node));
298 self
299 }
300
301 pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
303 where
304 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
305 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
306 {
307 self.nodes.push(Arc::new(FunctionNode::new(name, func)));
308 self
309 }
310
311 pub fn edge(mut self, source: &str, target: &str) -> Self {
313 let target =
314 if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
315
316 if source == START {
317 let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
318 match entry_idx {
319 Some(idx) => {
320 if let Edge::Entry { targets } = &mut self.edges[idx] {
321 if let EdgeTarget::Node(node) = &target {
322 if !targets.contains(node) {
323 targets.push(node.clone());
324 }
325 }
326 }
327 }
328 None => {
329 if let EdgeTarget::Node(node) = target {
330 self.edges.push(Edge::Entry { targets: vec![node] });
331 }
332 }
333 }
334 } else {
335 self.edges.push(Edge::Direct { source: source.to_string(), target });
336 }
337
338 self
339 }
340
341 pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
343 where
344 F: Fn(&State) -> String + Send + Sync + 'static,
345 I: IntoIterator<Item = (&'static str, &'static str)>,
346 {
347 let targets_map: HashMap<String, EdgeTarget> = targets
348 .into_iter()
349 .map(|(k, v)| {
350 let target =
351 if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
352 (k.to_string(), target)
353 })
354 .collect();
355
356 self.edges.push(Edge::Conditional {
357 source: source.to_string(),
358 router: Arc::new(router),
359 targets: targets_map,
360 });
361
362 self
363 }
364
365 pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
367 self.checkpointer = Some(Arc::new(checkpointer));
368 self
369 }
370
371 pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
373 self.checkpointer = Some(checkpointer);
374 self
375 }
376
377 pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
379 self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
380 self
381 }
382
383 pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
385 self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
386 self
387 }
388
389 pub fn recursion_limit(mut self, limit: usize) -> Self {
391 self.recursion_limit = limit;
392 self
393 }
394
395 pub fn node_timeout(mut self, node_name: &str, policy: TimeoutPolicy) -> Self {
415 self.timeout_policies.insert(node_name.to_string(), policy);
416 self
417 }
418
419 pub fn default_timeout(mut self, policy: TimeoutPolicy) -> Self {
439 self.default_timeout = Some(policy);
440 self
441 }
442
443 pub fn deferred_node<F, Fut>(mut self, name: &str, func: F, config: DeferredNodeConfig) -> Self
473 where
474 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
475 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
476 {
477 self.nodes.push(Arc::new(FunctionNode::new(name, func)));
478 self.deferred_configs.insert(name.to_string(), config);
479 self
480 }
481
482 #[cfg(feature = "node-cache")]
508 pub fn node_cache(mut self, name: &str, policy: crate::cache::NodeCachePolicy) -> Self {
509 self.cache_policies.insert(name.to_string(), policy);
510 self
511 }
512
513 pub fn input_mapper<F>(mut self, mapper: F) -> Self
515 where
516 F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
517 {
518 self.input_mapper = Some(Arc::new(mapper));
519 self
520 }
521
522 pub fn output_mapper<F>(mut self, mapper: F) -> Self
524 where
525 F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
526 {
527 self.output_mapper = Some(Arc::new(mapper));
528 self
529 }
530
531 pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
533 where
534 F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
535 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
536 {
537 self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
538 self
539 }
540
541 pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
545 where
546 F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
547 Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
548 {
549 self.after_callback = Some(Arc::new(move |ctx, event| {
550 let event_clone = event.clone();
551 Box::pin(callback(ctx, event_clone))
552 }));
553 self
554 }
555
556 #[cfg(feature = "action")]
562 pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
563 use crate::action::ActionNodeExecutor;
564
565 if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
567 let conditions = switch_config.conditions.clone();
568 let eval_mode = switch_config.evaluation_mode.clone();
569 let default_branch = switch_config.default_branch.clone();
570 let source = config.standard().id.clone();
571
572 let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
573 for condition in &conditions {
574 targets_map.insert(
575 condition.output_port.clone(),
576 EdgeTarget::Node(condition.output_port.clone()),
577 );
578 }
579 if let Some(ref default) = default_branch {
580 let target = if default == END {
581 EdgeTarget::End
582 } else {
583 EdgeTarget::Node(default.clone())
584 };
585 targets_map.insert(default.clone(), target);
586 }
587 targets_map.insert(END.to_string(), EdgeTarget::End);
588
589 let router = Arc::new(move |state: &State| -> String {
590 match crate::action::switch::evaluate_switch_conditions(
591 &conditions,
592 state,
593 &eval_mode,
594 default_branch.as_deref(),
595 ) {
596 Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
597 Err(_) => END.to_string(),
598 }
599 });
600
601 self.edges.push(Edge::Conditional { source, router, targets: targets_map });
602 }
603
604 let executor = ActionNodeExecutor::new(config);
605 self.nodes.push(Arc::new(executor));
606 self
607 }
608
609 pub fn build(self) -> Result<GraphAgent> {
611 let mut graph = StateGraph::new(self.schema);
613
614 for node in self.nodes {
616 graph.nodes.insert(node.name().to_string(), node);
617 }
618
619 graph.edges = self.edges;
621
622 let mut compiled = graph.compile()?;
624
625 if let Some(cp) = self.checkpointer {
627 compiled.checkpointer = Some(cp);
628 }
629 compiled.interrupt_before = self.interrupt_before.into_iter().collect();
630 compiled.interrupt_after = self.interrupt_after.into_iter().collect();
631 compiled.recursion_limit = self.recursion_limit;
632 compiled.timeout_policies = self.timeout_policies;
633 compiled.default_timeout = self.default_timeout;
634 compiled.deferred_configs = self.deferred_configs;
635
636 #[cfg(feature = "node-cache")]
637 {
638 compiled.cache_policies = self.cache_policies;
639 }
640
641 Ok(GraphAgent {
642 name: self.name,
643 description: self.description,
644 graph: Arc::new(compiled),
645 input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
646 output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
647 before_callback: self.before_callback,
648 after_callback: self.after_callback,
649 })
650 }
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_graph_agent_builder() {
        let agent = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END)
            .build()
            .unwrap();

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        let result = agent.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    /// Several `START` edges collapse into one entry edge without duplicate targets, and
    /// `START -> END` contributes nothing.
    #[test]
    fn test_start_edges_merge_into_single_entry() {
        let builder = GraphAgent::builder("fanout")
            .edge(START, "a")
            .edge(START, "b")
            .edge(START, "a")
            .edge(START, END);

        assert_eq!(builder.edges.len(), 1);
        match &builder.edges[0] {
            Edge::Entry { targets } => {
                assert_eq!(targets, &vec!["a".to_string(), "b".to_string()]);
            }
            _ => panic!("expected a single entry edge"),
        }
    }

    /// Conditional-edge targets resolve the `END` sentinel and keep node names as-is.
    #[test]
    fn test_conditional_edge_resolves_end_sentinel() {
        let builder = GraphAgent::builder("router").conditional_edge(
            "decide",
            |_state: &State| "done".to_string(),
            [("done", END), ("retry", "decide")],
        );

        assert_eq!(builder.edges.len(), 1);
        match &builder.edges[0] {
            Edge::Conditional { source, targets, .. } => {
                assert_eq!(source, "decide");
                assert!(matches!(targets.get("done"), Some(EdgeTarget::End)));
                assert!(matches!(targets.get("retry"), Some(EdgeTarget::Node(n)) if n == "decide"));
            }
            _ => panic!("expected a conditional edge"),
        }
    }
}