Skip to main content

adk_graph/
agent.rs

1//! GraphAgent - ADK Agent integration for graph workflows
2//!
3//! Provides a builder pattern similar to LlmAgent and RealtimeAgent.
4
5use crate::checkpoint::Checkpointer;
6use crate::deferred::DeferredNodeConfig;
7use crate::edge::{END, Edge, EdgeTarget, START};
8use crate::error::{GraphError, Result};
9use crate::graph::{CompiledGraph, StateGraph};
10use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
11use crate::state::{State, StateSchema};
12use crate::stream::{StreamEvent, StreamMode};
13use crate::timeout::TimeoutPolicy;
14use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
15use async_trait::async_trait;
16use serde_json::json;
17use std::collections::HashMap;
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::Arc;
21
22/// Type alias for callbacks
23pub type BeforeAgentCallback = Arc<
24    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
25        + Send
26        + Sync,
27>;
28
29pub type AfterAgentCallback = Arc<
30    dyn Fn(
31            Arc<dyn InvocationContext>,
32            Event,
33        ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
34        + Send
35        + Sync,
36>;
37
38/// Type alias for input mapper function
39pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
40
41/// Type alias for output mapper function
42pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
43
44/// GraphAgent wraps a CompiledGraph as an ADK Agent
45pub struct GraphAgent {
46    name: String,
47    description: String,
48    graph: Arc<CompiledGraph>,
49    /// Map InvocationContext to graph input state
50    input_mapper: InputMapper,
51    /// Map graph output state to ADK Events
52    output_mapper: OutputMapper,
53    /// Before agent callback
54    before_callback: Option<BeforeAgentCallback>,
55    /// After agent callback
56    after_callback: Option<AfterAgentCallback>,
57}
58
59impl GraphAgent {
60    /// Create a new GraphAgent builder
61    pub fn builder(name: &str) -> GraphAgentBuilder {
62        GraphAgentBuilder::new(name)
63    }
64
65    /// Create directly from a compiled graph
66    pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
67        Self {
68            name: name.to_string(),
69            description: String::new(),
70            graph: Arc::new(graph),
71            input_mapper: Arc::new(default_input_mapper),
72            output_mapper: Arc::new(default_output_mapper),
73            before_callback: None,
74            after_callback: None,
75        }
76    }
77
78    /// Build a `GraphAgent` from a `WorkflowSchema`.
79    ///
80    /// Delegates to `schema.build_graph()` to construct the graph from the
81    /// workflow schema's action nodes, edges, and conditions.
82    #[cfg(feature = "action")]
83    pub fn from_workflow_schema(
84        name: &str,
85        schema: &crate::workflow::WorkflowSchema,
86    ) -> Result<Self> {
87        schema.build_graph(name)
88    }
89
90    /// Get the underlying compiled graph
91    pub fn graph(&self) -> &CompiledGraph {
92        &self.graph
93    }
94
95    /// Execute the graph directly (bypassing Agent trait)
96    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
97        self.graph.invoke(input, config).await
98    }
99
100    /// Stream execution
101    pub fn stream(
102        &self,
103        input: State,
104        config: ExecutionConfig,
105        mode: StreamMode,
106    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
107        self.graph.stream(input, config, mode)
108    }
109}
110
111#[async_trait]
112impl Agent for GraphAgent {
113    fn name(&self) -> &str {
114        &self.name
115    }
116
117    fn description(&self) -> &str {
118        &self.description
119    }
120
121    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
122        &[]
123    }
124
125    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
126        // Call before callback
127        if let Some(callback) = &self.before_callback {
128            callback(ctx.clone()).await?;
129        }
130
131        // Map context to input state
132        let input = (self.input_mapper)(ctx.as_ref());
133
134        // Create execution config from context
135        let config = ExecutionConfig::new(ctx.session_id());
136
137        // Execute graph
138        let graph = self.graph.clone();
139        let output_mapper = self.output_mapper.clone();
140        let after_callback = self.after_callback.clone();
141        let ctx_clone = ctx.clone();
142
143        let stream = async_stream::stream! {
144            match graph.invoke(input, config).await {
145                Ok(state) => {
146                    let events = output_mapper(&state);
147                    for event in events {
148                        // Call after callback for each event
149                        if let Some(callback) = &after_callback
150                            && let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
151                                yield Err(e);
152                                return;
153                            }
154                        yield Ok(event);
155                    }
156                }
157                Err(GraphError::Interrupted(interrupt)) => {
158                    // Create an interrupt event
159                    let mut event = Event::new("graph_interrupted");
160                    event.set_content(Content::new("assistant").with_text(format!(
161                        "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
162                        interrupt.interrupt,
163                        interrupt.thread_id,
164                        interrupt.checkpoint_id
165                    )));
166                    yield Ok(event);
167                }
168                Err(e) => {
169                    yield Err(adk_core::AdkError::agent(e.to_string()));
170                }
171            }
172        };
173
174        Ok(Box::pin(stream))
175    }
176}
177
178/// Default input mapper - extracts content from InvocationContext
179fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
180    let mut state = State::new();
181
182    // Get user content
183    let content = ctx.user_content();
184    let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
185
186    if !text.is_empty() {
187        state.insert("input".to_string(), json!(text));
188        state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
189    }
190
191    // Add session ID
192    state.insert("session_id".to_string(), json!(ctx.session_id()));
193
194    state
195}
196
197/// Default output mapper - creates events from state
198fn default_output_mapper(state: &State) -> Vec<Event> {
199    let mut events = Vec::new();
200
201    // Try to get output from common fields
202    let output_text = state
203        .get("output")
204        .and_then(|v| v.as_str())
205        .or_else(|| state.get("result").and_then(|v| v.as_str()))
206        .or_else(|| {
207            state
208                .get("messages")
209                .and_then(|v| v.as_array())
210                .and_then(|arr| arr.last())
211                .and_then(|msg| msg.get("content"))
212                .and_then(|c| c.as_str())
213        });
214
215    let text = if let Some(text) = output_text {
216        text.to_string()
217    } else {
218        // Return the full state as JSON
219        serde_json::to_string_pretty(state).unwrap_or_default()
220    };
221
222    let mut event = Event::new("graph_output");
223    event.set_content(Content::new("assistant").with_text(&text));
224    events.push(event);
225
226    events
227}
228
229/// Builder for GraphAgent
230pub struct GraphAgentBuilder {
231    name: String,
232    description: String,
233    schema: StateSchema,
234    nodes: Vec<Arc<dyn Node>>,
235    edges: Vec<Edge>,
236    checkpointer: Option<Arc<dyn Checkpointer>>,
237    interrupt_before: Vec<String>,
238    interrupt_after: Vec<String>,
239    recursion_limit: usize,
240    input_mapper: Option<InputMapper>,
241    output_mapper: Option<OutputMapper>,
242    before_callback: Option<BeforeAgentCallback>,
243    after_callback: Option<AfterAgentCallback>,
244    timeout_policies: HashMap<String, TimeoutPolicy>,
245    default_timeout: Option<TimeoutPolicy>,
246    deferred_configs: HashMap<String, DeferredNodeConfig>,
247    #[cfg(feature = "node-cache")]
248    cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
249}
250
251impl GraphAgentBuilder {
252    /// Create a new builder
253    pub fn new(name: &str) -> Self {
254        Self {
255            name: name.to_string(),
256            description: String::new(),
257            schema: StateSchema::simple(&["input", "output", "messages"]),
258            nodes: vec![],
259            edges: vec![],
260            checkpointer: None,
261            interrupt_before: vec![],
262            interrupt_after: vec![],
263            recursion_limit: 50,
264            input_mapper: None,
265            output_mapper: None,
266            before_callback: None,
267            after_callback: None,
268            timeout_policies: HashMap::new(),
269            default_timeout: None,
270            deferred_configs: HashMap::new(),
271            #[cfg(feature = "node-cache")]
272            cache_policies: HashMap::new(),
273        }
274    }
275
276    /// Set description
277    pub fn description(mut self, desc: &str) -> Self {
278        self.description = desc.to_string();
279        self
280    }
281
282    /// Set state schema
283    pub fn state_schema(mut self, schema: StateSchema) -> Self {
284        self.schema = schema;
285        self
286    }
287
288    /// Add channels to state schema
289    pub fn channels(mut self, channels: &[&str]) -> Self {
290        self.schema = StateSchema::simple(channels);
291        self
292    }
293
294    /// Add a node
295    pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
296        self.nodes.push(Arc::new(node));
297        self
298    }
299
300    /// Add a function as a node
301    pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
302    where
303        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
304        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
305    {
306        self.nodes.push(Arc::new(FunctionNode::new(name, func)));
307        self
308    }
309
310    /// Add a direct edge
311    pub fn edge(mut self, source: &str, target: &str) -> Self {
312        let target =
313            if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
314
315        if source == START {
316            let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
317            match entry_idx {
318                Some(idx) => {
319                    if let Edge::Entry { targets } = &mut self.edges[idx]
320                        && let EdgeTarget::Node(node) = &target
321                        && !targets.contains(node)
322                    {
323                        targets.push(node.clone());
324                    }
325                }
326                None => {
327                    if let EdgeTarget::Node(node) = target {
328                        self.edges.push(Edge::Entry { targets: vec![node] });
329                    }
330                }
331            }
332        } else {
333            self.edges.push(Edge::Direct { source: source.to_string(), target });
334        }
335
336        self
337    }
338
339    /// Add a conditional edge
340    pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
341    where
342        F: Fn(&State) -> String + Send + Sync + 'static,
343        I: IntoIterator<Item = (&'static str, &'static str)>,
344    {
345        let targets_map: HashMap<String, EdgeTarget> = targets
346            .into_iter()
347            .map(|(k, v)| {
348                let target =
349                    if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
350                (k.to_string(), target)
351            })
352            .collect();
353
354        self.edges.push(Edge::Conditional {
355            source: source.to_string(),
356            router: Arc::new(router),
357            targets: targets_map,
358        });
359
360        self
361    }
362
363    /// Set checkpointer
364    pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
365        self.checkpointer = Some(Arc::new(checkpointer));
366        self
367    }
368
369    /// Set checkpointer with Arc
370    pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
371        self.checkpointer = Some(checkpointer);
372        self
373    }
374
375    /// Set nodes to interrupt before
376    pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
377        self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
378        self
379    }
380
381    /// Set nodes to interrupt after
382    pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
383        self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
384        self
385    }
386
387    /// Set recursion limit
388    pub fn recursion_limit(mut self, limit: usize) -> Self {
389        self.recursion_limit = limit;
390        self
391    }
392
393    /// Set a timeout policy for a specific node.
394    ///
395    /// The policy is applied when the named node executes, enforcing
396    /// wall-clock and/or idle timeouts with the configured recovery action.
397    ///
398    /// # Example
399    ///
400    /// ```rust,ignore
401    /// use std::time::Duration;
402    /// use adk_graph::timeout::{TimeoutPolicy, OnTimeout};
403    ///
404    /// let agent = GraphAgent::builder("my_graph")
405    ///     .node_timeout("slow_node", TimeoutPolicy {
406    ///         run_timeout: Some(Duration::from_secs(10)),
407    ///         idle_timeout: None,
408    ///         on_timeout: OnTimeout::Fail,
409    ///     })
410    ///     .build()?;
411    /// ```
412    pub fn node_timeout(mut self, node_name: &str, policy: TimeoutPolicy) -> Self {
413        self.timeout_policies.insert(node_name.to_string(), policy);
414        self
415    }
416
417    /// Set a default timeout policy applied to all nodes without an explicit override.
418    ///
419    /// Nodes that have a per-node policy set via [`node_timeout`](Self::node_timeout)
420    /// will use their specific policy instead of this default.
421    ///
422    /// # Example
423    ///
424    /// ```rust,ignore
425    /// use std::time::Duration;
426    /// use adk_graph::timeout::{TimeoutPolicy, OnTimeout};
427    ///
428    /// let agent = GraphAgent::builder("my_graph")
429    ///     .default_timeout(TimeoutPolicy {
430    ///         run_timeout: Some(Duration::from_secs(30)),
431    ///         idle_timeout: Some(Duration::from_secs(5)),
432    ///         on_timeout: OnTimeout::Skip,
433    ///     })
434    ///     .build()?;
435    /// ```
436    pub fn default_timeout(mut self, policy: TimeoutPolicy) -> Self {
437        self.default_timeout = Some(policy);
438        self
439    }
440
441    /// Add a deferred (fan-in barrier) node to the graph.
442    ///
443    /// A deferred node waits for all upstream parallel paths to complete before
444    /// executing. The provided function is wrapped as a [`FunctionNode`] and the
445    /// [`DeferredNodeConfig`] controls how upstream outputs are merged and how
446    /// long the node waits for all paths.
447    ///
448    /// # Arguments
449    ///
450    /// * `name` - The name of the deferred node.
451    /// * `func` - The async function to execute once all upstream paths complete.
452    /// * `config` - Configuration controlling merge strategy and fan-in timeout.
453    ///
454    /// # Example
455    ///
456    /// ```rust,ignore
457    /// use std::time::Duration;
458    /// use adk_graph::deferred::{DeferredNodeConfig, MergeStrategy};
459    /// use adk_graph::node::NodeOutput;
460    ///
461    /// let agent = GraphAgent::builder("scatter_gather")
462    ///     .deferred_node("aggregator", |_ctx| async {
463    ///         Ok(NodeOutput::new().with_update("status", serde_json::json!("merged")))
464    ///     }, DeferredNodeConfig {
465    ///         merge_strategy: MergeStrategy::Collect,
466    ///         fan_in_timeout: Some(Duration::from_secs(30)),
467    ///     })
468    ///     .build()?;
469    /// ```
470    pub fn deferred_node<F, Fut>(mut self, name: &str, func: F, config: DeferredNodeConfig) -> Self
471    where
472        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
473        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
474    {
475        self.nodes.push(Arc::new(FunctionNode::new(name, func)));
476        self.deferred_configs.insert(name.to_string(), config);
477        self
478    }
479
480    /// Set a cache policy for a specific node.
481    ///
482    /// When a node has a cache policy, its execution results are cached keyed
483    /// by a blake3 hash of the node name and input state. Subsequent executions
484    /// with identical inputs return the cached result without re-executing the
485    /// node.
486    ///
487    /// # Arguments
488    ///
489    /// * `name` — the name of the node to cache
490    /// * `policy` — the cache policy specifying backend and TTL
491    ///
492    /// # Example
493    ///
494    /// ```rust,ignore
495    /// use std::time::Duration;
496    /// use adk_graph::cache::{CacheBackend, NodeCachePolicy};
497    ///
498    /// let agent = GraphAgent::builder("cached_graph")
499    ///     .node_cache("expensive_node", NodeCachePolicy {
500    ///         backend: CacheBackend::InMemory { max_entries: 128 },
501    ///         ttl: Some(Duration::from_secs(300)),
502    ///     })
503    ///     .build()?;
504    /// ```
505    #[cfg(feature = "node-cache")]
506    pub fn node_cache(mut self, name: &str, policy: crate::cache::NodeCachePolicy) -> Self {
507        self.cache_policies.insert(name.to_string(), policy);
508        self
509    }
510
511    /// Set custom input mapper
512    pub fn input_mapper<F>(mut self, mapper: F) -> Self
513    where
514        F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
515    {
516        self.input_mapper = Some(Arc::new(mapper));
517        self
518    }
519
520    /// Set custom output mapper
521    pub fn output_mapper<F>(mut self, mapper: F) -> Self
522    where
523        F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
524    {
525        self.output_mapper = Some(Arc::new(mapper));
526        self
527    }
528
529    /// Set before agent callback
530    pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
531    where
532        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
533        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
534    {
535        self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
536        self
537    }
538
539    /// Set after agent callback
540    ///
541    /// Note: The callback receives a cloned Event to avoid lifetime issues.
542    pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
543    where
544        F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
545        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
546    {
547        self.after_callback = Some(Arc::new(move |ctx, event| {
548            let event_clone = event.clone();
549            Box::pin(callback(ctx, event_clone))
550        }));
551        self
552    }
553
554    /// Add an action node to the graph.
555    ///
556    /// Wraps the `ActionNodeConfig` in an `ActionNodeExecutor` and registers it
557    /// as a node. If the config is a `SwitchNodeConfig`, conditional edges are
558    /// also auto-registered from the switch conditions.
559    #[cfg(feature = "action")]
560    pub fn action_node(mut self, config: adk_action::ActionNodeConfig) -> Self {
561        use crate::action::ActionNodeExecutor;
562
563        // If this is a Switch node, register conditional edges
564        if let adk_action::ActionNodeConfig::Switch(ref switch_config) = config {
565            let conditions = switch_config.conditions.clone();
566            let eval_mode = switch_config.evaluation_mode.clone();
567            let default_branch = switch_config.default_branch.clone();
568            let source = config.standard().id.clone();
569
570            let mut targets_map: HashMap<String, EdgeTarget> = HashMap::new();
571            for condition in &conditions {
572                targets_map.insert(
573                    condition.output_port.clone(),
574                    EdgeTarget::Node(condition.output_port.clone()),
575                );
576            }
577            if let Some(ref default) = default_branch {
578                let target = if default == END {
579                    EdgeTarget::End
580                } else {
581                    EdgeTarget::Node(default.clone())
582                };
583                targets_map.insert(default.clone(), target);
584            }
585            targets_map.insert(END.to_string(), EdgeTarget::End);
586
587            let router = Arc::new(move |state: &State| -> String {
588                match crate::action::switch::evaluate_switch_conditions(
589                    &conditions,
590                    state,
591                    &eval_mode,
592                    default_branch.as_deref(),
593                ) {
594                    Ok(ports) => ports.into_iter().next().unwrap_or_else(|| END.to_string()),
595                    Err(_) => END.to_string(),
596                }
597            });
598
599            self.edges.push(Edge::Conditional { source, router, targets: targets_map });
600        }
601
602        let executor = ActionNodeExecutor::new(config);
603        self.nodes.push(Arc::new(executor));
604        self
605    }
606
607    /// Build the GraphAgent
608    pub fn build(self) -> Result<GraphAgent> {
609        // Build the graph
610        let mut graph = StateGraph::new(self.schema);
611
612        // Add nodes
613        for node in self.nodes {
614            graph.nodes.insert(node.name().to_string(), node);
615        }
616
617        // Add edges
618        graph.edges = self.edges;
619
620        // Compile
621        let mut compiled = graph.compile()?;
622
623        // Configure
624        if let Some(cp) = self.checkpointer {
625            compiled.checkpointer = Some(cp);
626        }
627        compiled.interrupt_before = self.interrupt_before.into_iter().collect();
628        compiled.interrupt_after = self.interrupt_after.into_iter().collect();
629        compiled.recursion_limit = self.recursion_limit;
630        compiled.timeout_policies = self.timeout_policies;
631        compiled.default_timeout = self.default_timeout;
632        compiled.deferred_configs = self.deferred_configs;
633
634        #[cfg(feature = "node-cache")]
635        {
636            compiled.cache_policies = self.cache_policies;
637        }
638
639        Ok(GraphAgent {
640            name: self.name,
641            description: self.description,
642            graph: Arc::new(compiled),
643            input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
644            output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
645            before_callback: self.before_callback,
646            after_callback: self.after_callback,
647        })
648    }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a one-node graph through the builder, then checks the agent
    /// metadata and the state returned by a direct `invoke`.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let built = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END)
            .build();
        let agent = built.unwrap();

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        // Direct invocation bypasses the ADK mappers and callbacks.
        let config = ExecutionConfig::new("test");
        let final_state = agent.invoke(State::new(), config).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }
}