Skip to main content

adk_graph/
agent.rs

//! GraphAgent - ADK Agent integration for graph workflows.
//!
//! Wraps a compiled state graph as an ADK `Agent` and provides a builder API
//! similar to `LlmAgent` and `RealtimeAgent`.

5use crate::checkpoint::Checkpointer;
6use crate::deferred::DeferredNodeConfig;
7use crate::edge::{END, Edge, EdgeTarget, START};
8use crate::error::{GraphError, Result};
9use crate::graph::{CompiledGraph, StateGraph};
10use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
11use crate::state::{State, StateSchema};
12use crate::stream::{StreamEvent, StreamMode};
13use crate::timeout::TimeoutPolicy;
14use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
15use async_trait::async_trait;
16use serde_json::json;
17use std::collections::HashMap;
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::Arc;
21
22/// Type alias for callbacks
23pub type BeforeAgentCallback = Arc<
24    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
25        + Send
26        + Sync,
27>;
28
29pub type AfterAgentCallback = Arc<
30    dyn Fn(
31            Arc<dyn InvocationContext>,
32            Event,
33        ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
34        + Send
35        + Sync,
36>;
37
38/// Type alias for input mapper function
39pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
40
41/// Type alias for output mapper function
42pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
43
44/// GraphAgent wraps a CompiledGraph as an ADK Agent
45pub struct GraphAgent {
46    name: String,
47    description: String,
48    graph: Arc<CompiledGraph>,
49    /// Map InvocationContext to graph input state
50    input_mapper: InputMapper,
51    /// Map graph output state to ADK Events
52    output_mapper: OutputMapper,
53    /// Before agent callback
54    before_callback: Option<BeforeAgentCallback>,
55    /// After agent callback
56    after_callback: Option<AfterAgentCallback>,
57}
58
59impl GraphAgent {
60    /// Create a new GraphAgent builder
61    pub fn builder(name: &str) -> GraphAgentBuilder {
62        GraphAgentBuilder::new(name)
63    }
64
65    /// Create directly from a compiled graph
66    pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
67        Self {
68            name: name.to_string(),
69            description: String::new(),
70            graph: Arc::new(graph),
71            input_mapper: Arc::new(default_input_mapper),
72            output_mapper: Arc::new(default_output_mapper),
73            before_callback: None,
74            after_callback: None,
75        }
76    }
77
78    /// Build a `GraphAgent` from a `WorkflowSchema`.
79    ///
80    /// Delegates to `schema.build_graph()` to construct the graph from the
81    /// workflow schema's action nodes, edges, and conditions.
82    #[cfg(feature = "action")]
83    pub fn from_workflow_schema(
84        name: &str,
85        schema: &crate::workflow::WorkflowSchema,
86    ) -> Result<Self> {
87        schema.build_graph(name)
88    }
89
90    /// Get the underlying compiled graph
91    pub fn graph(&self) -> &CompiledGraph {
92        &self.graph
93    }
94
95    /// Execute the graph directly (bypassing Agent trait)
96    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
97        self.graph.invoke(input, config).await
98    }
99
100    /// Stream execution
101    pub fn stream(
102        &self,
103        input: State,
104        config: ExecutionConfig,
105        mode: StreamMode,
106    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
107        self.graph.stream(input, config, mode)
108    }
109}
110
111#[async_trait]
112impl Agent for GraphAgent {
113    fn name(&self) -> &str {
114        &self.name
115    }
116
117    fn description(&self) -> &str {
118        &self.description
119    }
120
121    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
122        &[]
123    }
124
125    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
126        // Call before callback
127        if let Some(callback) = &self.before_callback {
128            callback(ctx.clone()).await?;
129        }
130
131        // Map context to input state
132        let input = (self.input_mapper)(ctx.as_ref());
133
134        // Create execution config from context
135        let config = ExecutionConfig::new(ctx.session_id());
136
137        // Execute graph
138        let graph = self.graph.clone();
139        let output_mapper = self.output_mapper.clone();
140        let after_callback = self.after_callback.clone();
141        let ctx_clone = ctx.clone();
142
143        let stream = async_stream::stream! {
144            match graph.invoke(input, config).await {
145                Ok(state) => {
146                    let events = output_mapper(&state);
147                    for event in events {
148                        // Call after callback for each event
149                        if let Some(callback) = &after_callback {
150                            if let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
151                                yield Err(e);
152                                return;
153                            }
154                        }
155                        yield Ok(event);
156                    }
157                }
158                Err(GraphError::Interrupted(interrupt)) => {
159                    // Create an interrupt event
160                    let mut event = Event::new("graph_interrupted");
161                    event.set_content(Content::new("assistant").with_text(format!(
162                        "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
163                        interrupt.interrupt,
164                        interrupt.thread_id,
165                        interrupt.checkpoint_id
166                    )));
167                    yield Ok(event);
168                }
169                Err(e) => {
170                    yield Err(adk_core::AdkError::agent(e.to_string()));
171                }
172            }
173        };
174
175        Ok(Box::pin(stream))
176    }
177}
178
179/// Default input mapper - extracts content from InvocationContext
180fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
181    let mut state = State::new();
182
183    // Get user content
184    let content = ctx.user_content();
185    let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
186
187    if !text.is_empty() {
188        state.insert("input".to_string(), json!(text));
189        state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
190    }
191
192    // Add session ID
193    state.insert("session_id".to_string(), json!(ctx.session_id()));
194
195    state
196}
197
198/// Default output mapper - creates events from state
199fn default_output_mapper(state: &State) -> Vec<Event> {
200    let mut events = Vec::new();
201
202    // Try to get output from common fields
203    let output_text = state
204        .get("output")
205        .and_then(|v| v.as_str())
206        .or_else(|| state.get("result").and_then(|v| v.as_str()))
207        .or_else(|| {
208            state
209                .get("messages")
210                .and_then(|v| v.as_array())
211                .and_then(|arr| arr.last())
212                .and_then(|msg| msg.get("content"))
213                .and_then(|c| c.as_str())
214        });
215
216    let text = if let Some(text) = output_text {
217        text.to_string()
218    } else {
219        // Return the full state as JSON
220        serde_json::to_string_pretty(state).unwrap_or_default()
221    };
222
223    let mut event = Event::new("graph_output");
224    event.set_content(Content::new("assistant").with_text(&text));
225    events.push(event);
226
227    events
228}
229
230/// Builder for GraphAgent
231pub struct GraphAgentBuilder {
232    name: String,
233    description: String,
234    schema: StateSchema,
235    nodes: Vec<Arc<dyn Node>>,
236    edges: Vec<Edge>,
237    checkpointer: Option<Arc<dyn Checkpointer>>,
238    interrupt_before: Vec<String>,
239    interrupt_after: Vec<String>,
240    recursion_limit: usize,
241    input_mapper: Option<InputMapper>,
242    output_mapper: Option<OutputMapper>,
243    before_callback: Option<BeforeAgentCallback>,
244    after_callback: Option<AfterAgentCallback>,
245    timeout_policies: HashMap<String, TimeoutPolicy>,
246    default_timeout: Option<TimeoutPolicy>,
247    deferred_configs: HashMap<String, DeferredNodeConfig>,
248    #[cfg(feature = "node-cache")]
249    cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
250}
251
252impl GraphAgentBuilder {
253    /// Create a new builder
254    pub fn new(name: &str) -> Self {
255        Self {
256            name: name.to_string(),
257            description: String::new(),
258            schema: StateSchema::simple(&["input", "output", "messages"]),
259            nodes: vec![],
260            edges: vec![],
261            checkpointer: None,
262            interrupt_before: vec![],
263            interrupt_after: vec![],
264            recursion_limit: 50,
265            input_mapper: None,
266            output_mapper: None,
267            before_callback: None,
268            after_callback: None,
269            timeout_policies: HashMap::new(),
270            default_timeout: None,
271            deferred_configs: HashMap::new(),
272            #[cfg(feature = "node-cache")]
273            cache_policies: HashMap::new(),
274        }
275    }
276
277    /// Set description
278    pub fn description(mut self, desc: &str) -> Self {
279        self.description = desc.to_string();
280        self
281    }
282
283    /// Set state schema
284    pub fn state_schema(mut self, schema: StateSchema) -> Self {
285        self.schema = schema;
286        self
287    }
288
289    /// Add channels to state schema
290    pub fn channels(mut self, channels: &[&str]) -> Self {
291        self.schema = StateSchema::simple(channels);
292        self
293    }
294
295    /// Add a node
296    pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
297        self.nodes.push(Arc::new(node));
298        self
299    }
300
301    /// Add a function as a node
302    pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
303    where
304        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
305        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
306    {
307        self.nodes.push(Arc::new(FunctionNode::new(name, func)));
308        self
309    }
310
311    /// Add a direct edge
312    pub fn edge(mut self, source: &str, target: &str) -> Self {
313        let target =
314            if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
315
316        if source == START {
317            let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
318            match entry_idx {
319                Some(idx) => {
320                    if let Edge::Entry { targets } = &mut self.edges[idx] {
321                        if let EdgeTarget::Node(node) = &target {
322                            if !targets.contains(node) {
323                                targets.push(node.clone());
324                            }
325                        }
326                    }
327                }
328                None => {
329                    if let EdgeTarget::Node(node) = target {
330                        self.edges.push(Edge::Entry { targets: vec![node] });
331                    }
332                }
333            }
334        } else {
335            self.edges.push(Edge::Direct { source: source.to_string(), target });
336        }
337
338        self
339    }
340
341    /// Add a conditional edge
342    pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
343    where
344        F: Fn(&State) -> String + Send + Sync + 'static,
345        I: IntoIterator<Item = (&'static str, &'static str)>,
346    {
347        let targets_map: HashMap<String, EdgeTarget> = targets
348            .into_iter()
349            .map(|(k, v)| {
350                let target =
351                    if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
352                (k.to_string(), target)
353            })
354            .collect();
355
356        self.edges.push(Edge::Conditional {
357            source: source.to_string(),
358            router: Arc::new(router),
359            targets: targets_map,
360        });
361
362        self
363    }
364
365    /// Set checkpointer
366    pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
367        self.checkpointer = Some(Arc::new(checkpointer));
368        self
369    }
370
371    /// Set checkpointer with Arc
372    pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
373        self.checkpointer = Some(checkpointer);
374        self
375    }
376
377    /// Set nodes to interrupt before
378    pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
379        self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
380        self
381    }
382
383    /// Set nodes to interrupt after
384    pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
385        self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
386        self
387    }
388
389    /// Set recursion limit
390    pub fn recursion_limit(mut self, limit: usize) -> Self {
391        self.recursion_limit = limit;
392        self
393    }
394
395    /// Set a timeout policy for a specific node.
396    ///
397    /// The policy is applied when the named node executes, enforcing
398    /// wall-clock and/or idle timeouts with the configured recovery action.
399    ///
400    /// # Example
401    ///
402    /// ```rust,ignore
403    /// use std::time::Duration;
404    /// use adk_graph::timeout::{TimeoutPolicy, OnTimeout};
405    ///
406    /// let agent = GraphAgent::builder("my_graph")
407    ///     .node_timeout("slow_node", TimeoutPolicy {
408    ///         run_timeout: Some(Duration::from_secs(10)),
409    ///         idle_timeout: None,
410    ///         on_timeout: OnTimeout::Fail,
411    ///     })
412    ///     .build()?;
413    /// ```
414    pub fn node_timeout(mut self, node_name: &str, policy: TimeoutPolicy) -> Self {
415        self.timeout_policies.insert(node_name.to_string(), policy);
416        self
417    }
418
419    /// Set a default timeout policy applied to all nodes without an explicit override.
420    ///
421    /// Nodes that have a per-node policy set via [`node_timeout`](Self::node_timeout)
422    /// will use their specific policy instead of this default.
423    ///
424    /// # Example
425    ///
426    /// ```rust,ignore
427    /// use std::time::Duration;
428    /// use adk_graph::timeout::{TimeoutPolicy, OnTimeout};
429    ///
430    /// let agent = GraphAgent::builder("my_graph")
431    ///     .default_timeout(TimeoutPolicy {
432    ///         run_timeout: Some(Duration::from_secs(30)),
433    ///         idle_timeout: Some(Duration::from_secs(5)),
434    ///         on_timeout: OnTimeout::Skip,
435    ///     })
436    ///     .build()?;
437    /// ```
438    pub fn default_timeout(mut self, policy: TimeoutPolicy) -> Self {
439        self.default_timeout = Some(policy);
440        self
441    }
442
443    /// Add a deferred (fan-in barrier) node to the graph.
444    ///
445    /// A deferred node waits for all upstream parallel paths to complete before
446    /// executing. The provided function is wrapped as a [`FunctionNode`] and the
447    /// [`DeferredNodeConfig`] controls how upstream outputs are merged and how
448    /// long the node waits for all paths.
449    ///
450    /// # Arguments
451    ///
452    /// * `name` - The name of the deferred node.
453    /// * `func` - The async function to execute once all upstream paths complete.
454    /// * `config` - Configuration controlling merge strategy and fan-in timeout.
455    ///
456    /// # Example
457    ///
458    /// ```rust,ignore
459    /// use std::time::Duration;
460    /// use adk_graph::deferred::{DeferredNodeConfig, MergeStrategy};
461    /// use adk_graph::node::NodeOutput;
462    ///
463    /// let agent = GraphAgent::builder("scatter_gather")
464    ///     .deferred_node("aggregator", |_ctx| async {
465    ///         Ok(NodeOutput::new().with_update("status", serde_json::json!("merged")))
466    ///     }, DeferredNodeConfig {
467    ///         merge_strategy: MergeStrategy::Collect,
468    ///         fan_in_timeout: Some(Duration::from_secs(30)),
469    ///     })
470    ///     .build()?;
471    /// ```
472    pub fn deferred_node<F, Fut>(mut self, name: &str, func: F, config: DeferredNodeConfig) -> Self
473    where
474        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
475        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
476    {
477        self.nodes.push(Arc::new(FunctionNode::new(name, func)));
478        self.deferred_configs.insert(name.to_string(), config);
479        self
480    }
481
482    /// Set a cache policy for a specific node.
483    ///
484    /// When a node has a cache policy, its execution results are cached keyed
485    /// by a blake3 hash of the node name and input state. Subsequent executions
486    /// with identical inputs return the cached result without re-executing the
487    /// node.
488    ///
489    /// # Arguments
490    ///
491    /// * `name` — the name of the node to cache
492    /// * `policy` — the cache policy specifying backend and TTL
493    ///
494    /// # Example
495    ///
496    /// ```rust,ignore
497    /// use std::time::Duration;
498    /// use adk_graph::cache::{CacheBackend, NodeCachePolicy};
499    ///
500    /// let agent = GraphAgent::builder("cached_graph")
501    ///     .node_cache("expensive_node", NodeCachePolicy {
502    ///         backend: CacheBackend::InMemory { max_entries: 128 },
503    ///         ttl: Some(Duration::from_secs(300)),
504    ///     })
505    ///     .build()?;
506    /// ```
507    #[cfg(feature = "node-cache")]
508    pub fn node_cache(mut self, name: &str, policy: crate::cache::NodeCachePolicy) -> Self {
509        self.cache_policies.insert(name.to_string(), policy);
510        self
511    }
512
513    /// Set custom input mapper
514    pub fn input_mapper<F>(mut self, mapper: F) -> Self
515    where
516        F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
517    {
518        self.input_mapper = Some(Arc::new(mapper));
519        self
520    }
521
522    /// Set custom output mapper
523    pub fn output_mapper<F>(mut self, mapper: F) -> Self
524    where
525        F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
526    {
527        self.output_mapper = Some(Arc::new(mapper));
528        self
529    }
530
531    /// Set before agent callback
532    pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
533    where
534        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
535        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
536    {
537        self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
538        self
539    }
540
541    /// Set after agent callback
542    ///
543    /// Note: The callback receives a cloned Event to avoid lifetime issues.
544    pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
545    where
546        F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
547        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
548    {
549        self.after_callback = Some(Arc::new(move |ctx, event| {
550            let event_clone = event.clone();
551            Box::pin(callback(ctx, event_clone))
552        }));
553        self
554    }
555
556    /// Add an action node to the graph.
557    ///
558    /// Wraps the `ActionNodeConfig` in an `ActionNodeExecutor` and registers it
559    /// as a node. If the config is a `SwitchNodeConfig`, conditional edges are
560    /// also auto-registered from the switch conditions.
561    #[cfg(feature = "action")]
562    pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
563        use crate::action::ActionNodeExecutor;
564
565        // If this is a Switch node, register conditional edges
566        if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
567            let conditions = switch_config.conditions.clone();
568            let eval_mode = switch_config.evaluation_mode.clone();
569            let default_branch = switch_config.default_branch.clone();
570            let source = config.standard().id.clone();
571
572            let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
573            for condition in &conditions {
574                targets_map.insert(
575                    condition.output_port.clone(),
576                    EdgeTarget::Node(condition.output_port.clone()),
577                );
578            }
579            if let Some(ref default) = default_branch {
580                let target = if default == END {
581                    EdgeTarget::End
582                } else {
583                    EdgeTarget::Node(default.clone())
584                };
585                targets_map.insert(default.clone(), target);
586            }
587            targets_map.insert(END.to_string(), EdgeTarget::End);
588
589            let router = Arc::new(move |state: &State| -> String {
590                match crate::action::switch::evaluate_switch_conditions(
591                    &conditions,
592                    state,
593                    &eval_mode,
594                    default_branch.as_deref(),
595                ) {
596                    Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
597                    Err(_) => END.to_string(),
598                }
599            });
600
601            self.edges.push(Edge::Conditional { source, router, targets: targets_map });
602        }
603
604        let executor = ActionNodeExecutor::new(config);
605        self.nodes.push(Arc::new(executor));
606        self
607    }
608
609    /// Build the GraphAgent
610    pub fn build(self) -> Result<GraphAgent> {
611        // Build the graph
612        let mut graph = StateGraph::new(self.schema);
613
614        // Add nodes
615        for node in self.nodes {
616            graph.nodes.insert(node.name().to_string(), node);
617        }
618
619        // Add edges
620        graph.edges = self.edges;
621
622        // Compile
623        let mut compiled = graph.compile()?;
624
625        // Configure
626        if let Some(cp) = self.checkpointer {
627            compiled.checkpointer = Some(cp);
628        }
629        compiled.interrupt_before = self.interrupt_before.into_iter().collect();
630        compiled.interrupt_after = self.interrupt_after.into_iter().collect();
631        compiled.recursion_limit = self.recursion_limit;
632        compiled.timeout_policies = self.timeout_policies;
633        compiled.default_timeout = self.default_timeout;
634        compiled.deferred_configs = self.deferred_configs;
635
636        #[cfg(feature = "node-cache")]
637        {
638            compiled.cache_policies = self.cache_policies;
639        }
640
641        Ok(GraphAgent {
642            name: self.name,
643            description: self.description,
644            graph: Arc::new(compiled),
645            input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
646            output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
647            before_callback: self.before_callback,
648            after_callback: self.after_callback,
649        })
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a one-node graph through the builder, then checks the agent
    /// metadata and the state produced by a direct `invoke`.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let agent = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .edge(START, "set")
            .edge("set", END)
            .build()
            .expect("builder should produce a valid graph");

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        let final_state = agent
            .invoke(State::new(), ExecutionConfig::new("test"))
            .await
            .expect("direct invocation should succeed");

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }
}