Skip to main content

adk_graph/
agent.rs

1//! GraphAgent - ADK Agent integration for graph workflows
2//!
3//! Provides a builder pattern similar to LlmAgent and RealtimeAgent.
4
5use crate::checkpoint::Checkpointer;
6use crate::edge::{END, Edge, EdgeTarget, START};
7use crate::error::{GraphError, Result};
8use crate::graph::{CompiledGraph, StateGraph};
9use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
10use crate::state::{State, StateSchema};
11use crate::stream::{StreamEvent, StreamMode};
12use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
13use async_trait::async_trait;
14use serde_json::json;
15use std::collections::HashMap;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19
20/// Type alias for callbacks
21pub type BeforeAgentCallback = Arc<
22    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
23        + Send
24        + Sync,
25>;
26
27pub type AfterAgentCallback = Arc<
28    dyn Fn(
29            Arc<dyn InvocationContext>,
30            Event,
31        ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
32        + Send
33        + Sync,
34>;
35
36/// Type alias for input mapper function
37pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
38
39/// Type alias for output mapper function
40pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
41
42/// GraphAgent wraps a CompiledGraph as an ADK Agent
43pub struct GraphAgent {
44    name: String,
45    description: String,
46    graph: Arc<CompiledGraph>,
47    /// Map InvocationContext to graph input state
48    input_mapper: InputMapper,
49    /// Map graph output state to ADK Events
50    output_mapper: OutputMapper,
51    /// Before agent callback
52    before_callback: Option<BeforeAgentCallback>,
53    /// After agent callback
54    after_callback: Option<AfterAgentCallback>,
55}
56
57impl GraphAgent {
58    /// Create a new GraphAgent builder
59    pub fn builder(name: &str) -> GraphAgentBuilder {
60        GraphAgentBuilder::new(name)
61    }
62
63    /// Create directly from a compiled graph
64    pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
65        Self {
66            name: name.to_string(),
67            description: String::new(),
68            graph: Arc::new(graph),
69            input_mapper: Arc::new(default_input_mapper),
70            output_mapper: Arc::new(default_output_mapper),
71            before_callback: None,
72            after_callback: None,
73        }
74    }
75
76    /// Build a `GraphAgent` from a `WorkflowSchema`.
77    ///
78    /// Delegates to `schema.build_graph()` to construct the graph from the
79    /// workflow schema's action nodes, edges, and conditions.
80    #[cfg(feature = "action")]
81    pub fn from_workflow_schema(
82        name: &str,
83        schema: &crate::workflow::WorkflowSchema,
84    ) -> Result<Self> {
85        schema.build_graph(name)
86    }
87
88    /// Get the underlying compiled graph
89    pub fn graph(&self) -> &CompiledGraph {
90        &self.graph
91    }
92
93    /// Execute the graph directly (bypassing Agent trait)
94    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
95        self.graph.invoke(input, config).await
96    }
97
98    /// Stream execution
99    pub fn stream(
100        &self,
101        input: State,
102        config: ExecutionConfig,
103        mode: StreamMode,
104    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
105        self.graph.stream(input, config, mode)
106    }
107}
108
109#[async_trait]
110impl Agent for GraphAgent {
111    fn name(&self) -> &str {
112        &self.name
113    }
114
115    fn description(&self) -> &str {
116        &self.description
117    }
118
119    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
120        &[]
121    }
122
123    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
124        // Call before callback
125        if let Some(callback) = &self.before_callback {
126            callback(ctx.clone()).await?;
127        }
128
129        // Map context to input state
130        let input = (self.input_mapper)(ctx.as_ref());
131
132        // Create execution config from context
133        let config = ExecutionConfig::new(ctx.session_id());
134
135        // Execute graph
136        let graph = self.graph.clone();
137        let output_mapper = self.output_mapper.clone();
138        let after_callback = self.after_callback.clone();
139        let ctx_clone = ctx.clone();
140
141        let stream = async_stream::stream! {
142            match graph.invoke(input, config).await {
143                Ok(state) => {
144                    let events = output_mapper(&state);
145                    for event in events {
146                        // Call after callback for each event
147                        if let Some(callback) = &after_callback {
148                            if let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
149                                yield Err(e);
150                                return;
151                            }
152                        }
153                        yield Ok(event);
154                    }
155                }
156                Err(GraphError::Interrupted(interrupt)) => {
157                    // Create an interrupt event
158                    let mut event = Event::new("graph_interrupted");
159                    event.set_content(Content::new("assistant").with_text(format!(
160                        "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
161                        interrupt.interrupt,
162                        interrupt.thread_id,
163                        interrupt.checkpoint_id
164                    )));
165                    yield Ok(event);
166                }
167                Err(e) => {
168                    yield Err(adk_core::AdkError::agent(e.to_string()));
169                }
170            }
171        };
172
173        Ok(Box::pin(stream))
174    }
175}
176
177/// Default input mapper - extracts content from InvocationContext
178fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
179    let mut state = State::new();
180
181    // Get user content
182    let content = ctx.user_content();
183    let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
184
185    if !text.is_empty() {
186        state.insert("input".to_string(), json!(text));
187        state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
188    }
189
190    // Add session ID
191    state.insert("session_id".to_string(), json!(ctx.session_id()));
192
193    state
194}
195
196/// Default output mapper - creates events from state
197fn default_output_mapper(state: &State) -> Vec<Event> {
198    let mut events = Vec::new();
199
200    // Try to get output from common fields
201    let output_text = state
202        .get("output")
203        .and_then(|v| v.as_str())
204        .or_else(|| state.get("result").and_then(|v| v.as_str()))
205        .or_else(|| {
206            state
207                .get("messages")
208                .and_then(|v| v.as_array())
209                .and_then(|arr| arr.last())
210                .and_then(|msg| msg.get("content"))
211                .and_then(|c| c.as_str())
212        });
213
214    let text = if let Some(text) = output_text {
215        text.to_string()
216    } else {
217        // Return the full state as JSON
218        serde_json::to_string_pretty(state).unwrap_or_default()
219    };
220
221    let mut event = Event::new("graph_output");
222    event.set_content(Content::new("assistant").with_text(&text));
223    events.push(event);
224
225    events
226}
227
228/// Builder for GraphAgent
229pub struct GraphAgentBuilder {
230    name: String,
231    description: String,
232    schema: StateSchema,
233    nodes: Vec<Arc<dyn Node>>,
234    edges: Vec<Edge>,
235    checkpointer: Option<Arc<dyn Checkpointer>>,
236    interrupt_before: Vec<String>,
237    interrupt_after: Vec<String>,
238    recursion_limit: usize,
239    input_mapper: Option<InputMapper>,
240    output_mapper: Option<OutputMapper>,
241    before_callback: Option<BeforeAgentCallback>,
242    after_callback: Option<AfterAgentCallback>,
243}
244
245impl GraphAgentBuilder {
246    /// Create a new builder
247    pub fn new(name: &str) -> Self {
248        Self {
249            name: name.to_string(),
250            description: String::new(),
251            schema: StateSchema::simple(&["input", "output", "messages"]),
252            nodes: vec![],
253            edges: vec![],
254            checkpointer: None,
255            interrupt_before: vec![],
256            interrupt_after: vec![],
257            recursion_limit: 50,
258            input_mapper: None,
259            output_mapper: None,
260            before_callback: None,
261            after_callback: None,
262        }
263    }
264
265    /// Set description
266    pub fn description(mut self, desc: &str) -> Self {
267        self.description = desc.to_string();
268        self
269    }
270
271    /// Set state schema
272    pub fn state_schema(mut self, schema: StateSchema) -> Self {
273        self.schema = schema;
274        self
275    }
276
277    /// Add channels to state schema
278    pub fn channels(mut self, channels: &[&str]) -> Self {
279        self.schema = StateSchema::simple(channels);
280        self
281    }
282
283    /// Add a node
284    pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
285        self.nodes.push(Arc::new(node));
286        self
287    }
288
289    /// Add a function as a node
290    pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
291    where
292        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
293        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
294    {
295        self.nodes.push(Arc::new(FunctionNode::new(name, func)));
296        self
297    }
298
299    /// Add a direct edge
300    pub fn edge(mut self, source: &str, target: &str) -> Self {
301        let target =
302            if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
303
304        if source == START {
305            let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
306            match entry_idx {
307                Some(idx) => {
308                    if let Edge::Entry { targets } = &mut self.edges[idx] {
309                        if let EdgeTarget::Node(node) = &target {
310                            if !targets.contains(node) {
311                                targets.push(node.clone());
312                            }
313                        }
314                    }
315                }
316                None => {
317                    if let EdgeTarget::Node(node) = target {
318                        self.edges.push(Edge::Entry { targets: vec![node] });
319                    }
320                }
321            }
322        } else {
323            self.edges.push(Edge::Direct { source: source.to_string(), target });
324        }
325
326        self
327    }
328
329    /// Add a conditional edge
330    pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
331    where
332        F: Fn(&State) -> String + Send + Sync + 'static,
333        I: IntoIterator<Item = (&'static str, &'static str)>,
334    {
335        let targets_map: HashMap<String, EdgeTarget> = targets
336            .into_iter()
337            .map(|(k, v)| {
338                let target =
339                    if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
340                (k.to_string(), target)
341            })
342            .collect();
343
344        self.edges.push(Edge::Conditional {
345            source: source.to_string(),
346            router: Arc::new(router),
347            targets: targets_map,
348        });
349
350        self
351    }
352
353    /// Set checkpointer
354    pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
355        self.checkpointer = Some(Arc::new(checkpointer));
356        self
357    }
358
359    /// Set checkpointer with Arc
360    pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
361        self.checkpointer = Some(checkpointer);
362        self
363    }
364
365    /// Set nodes to interrupt before
366    pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
367        self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
368        self
369    }
370
371    /// Set nodes to interrupt after
372    pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
373        self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
374        self
375    }
376
377    /// Set recursion limit
378    pub fn recursion_limit(mut self, limit: usize) -> Self {
379        self.recursion_limit = limit;
380        self
381    }
382
383    /// Set custom input mapper
384    pub fn input_mapper<F>(mut self, mapper: F) -> Self
385    where
386        F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
387    {
388        self.input_mapper = Some(Arc::new(mapper));
389        self
390    }
391
392    /// Set custom output mapper
393    pub fn output_mapper<F>(mut self, mapper: F) -> Self
394    where
395        F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
396    {
397        self.output_mapper = Some(Arc::new(mapper));
398        self
399    }
400
401    /// Set before agent callback
402    pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
403    where
404        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
405        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
406    {
407        self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
408        self
409    }
410
411    /// Set after agent callback
412    ///
413    /// Note: The callback receives a cloned Event to avoid lifetime issues.
414    pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
415    where
416        F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
417        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
418    {
419        self.after_callback = Some(Arc::new(move |ctx, event| {
420            let event_clone = event.clone();
421            Box::pin(callback(ctx, event_clone))
422        }));
423        self
424    }
425
426    /// Add an action node to the graph.
427    ///
428    /// Wraps the `ActionNodeConfig` in an `ActionNodeExecutor` and registers it
429    /// as a node. If the config is a `SwitchNodeConfig`, conditional edges are
430    /// also auto-registered from the switch conditions.
431    #[cfg(feature = "action")]
432    pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
433        use crate::action::ActionNodeExecutor;
434
435        // If this is a Switch node, register conditional edges
436        if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
437            let conditions = switch_config.conditions.clone();
438            let eval_mode = switch_config.evaluation_mode.clone();
439            let default_branch = switch_config.default_branch.clone();
440            let source = config.standard().id.clone();
441
442            let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
443            for condition in &conditions {
444                targets_map.insert(
445                    condition.output_port.clone(),
446                    EdgeTarget::Node(condition.output_port.clone()),
447                );
448            }
449            if let Some(ref default) = default_branch {
450                let target = if default == END {
451                    EdgeTarget::End
452                } else {
453                    EdgeTarget::Node(default.clone())
454                };
455                targets_map.insert(default.clone(), target);
456            }
457            targets_map.insert(END.to_string(), EdgeTarget::End);
458
459            let router = Arc::new(move |state: &State| -> String {
460                match crate::action::switch::evaluate_switch_conditions(
461                    &conditions,
462                    state,
463                    &eval_mode,
464                    default_branch.as_deref(),
465                ) {
466                    Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
467                    Err(_) => END.to_string(),
468                }
469            });
470
471            self.edges.push(Edge::Conditional { source, router, targets: targets_map });
472        }
473
474        let executor = ActionNodeExecutor::new(config);
475        self.nodes.push(Arc::new(executor));
476        self
477    }
478
479    /// Build the GraphAgent
480    pub fn build(self) -> Result<GraphAgent> {
481        // Build the graph
482        let mut graph = StateGraph::new(self.schema);
483
484        // Add nodes
485        for node in self.nodes {
486            graph.nodes.insert(node.name().to_string(), node);
487        }
488
489        // Add edges
490        graph.edges = self.edges;
491
492        // Compile
493        let mut compiled = graph.compile()?;
494
495        // Configure
496        if let Some(cp) = self.checkpointer {
497            compiled.checkpointer = Some(cp);
498        }
499        compiled.interrupt_before = self.interrupt_before.into_iter().collect();
500        compiled.interrupt_after = self.interrupt_after.into_iter().collect();
501        compiled.recursion_limit = self.recursion_limit;
502
503        Ok(GraphAgent {
504            name: self.name,
505            description: self.description,
506            graph: Arc::new(compiled),
507            input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
508            output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
509            before_callback: self.before_callback,
510            after_callback: self.after_callback,
511        })
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a single-node graph (START -> "set" -> END) and checks the agent
    /// metadata plus the state returned by a direct invocation.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let builder = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END);
        let agent = builder.build().expect("single-node graph should compile");

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        let final_state = agent
            .invoke(State::new(), ExecutionConfig::new("test"))
            .await
            .expect("graph invocation should succeed");
        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }
}