adk_graph/
agent.rs

1//! GraphAgent - ADK Agent integration for graph workflows
2//!
3//! Provides a builder pattern similar to LlmAgent and RealtimeAgent.
4
5use crate::checkpoint::Checkpointer;
6use crate::edge::{Edge, EdgeTarget, END, START};
7use crate::error::{GraphError, Result};
8use crate::graph::{CompiledGraph, StateGraph};
9use crate::node::{ExecutionConfig, FunctionNode, Node, NodeContext, NodeOutput};
10use crate::state::{State, StateSchema};
11use crate::stream::{StreamEvent, StreamMode};
12use adk_core::{Agent, Content, Event, EventStream, InvocationContext};
13use async_trait::async_trait;
14use serde_json::json;
15use std::collections::HashMap;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19
20/// Type alias for callbacks
21pub type BeforeAgentCallback = Arc<
22    dyn Fn(Arc<dyn InvocationContext>) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
23        + Send
24        + Sync,
25>;
26
27pub type AfterAgentCallback = Arc<
28    dyn Fn(
29            Arc<dyn InvocationContext>,
30            Event,
31        ) -> Pin<Box<dyn Future<Output = adk_core::Result<()>> + Send>>
32        + Send
33        + Sync,
34>;
35
36/// Type alias for input mapper function
37pub type InputMapper = Arc<dyn Fn(&dyn InvocationContext) -> State + Send + Sync>;
38
39/// Type alias for output mapper function
40pub type OutputMapper = Arc<dyn Fn(&State) -> Vec<Event> + Send + Sync>;
41
42/// GraphAgent wraps a CompiledGraph as an ADK Agent
43pub struct GraphAgent {
44    name: String,
45    description: String,
46    graph: Arc<CompiledGraph>,
47    /// Map InvocationContext to graph input state
48    input_mapper: InputMapper,
49    /// Map graph output state to ADK Events
50    output_mapper: OutputMapper,
51    /// Before agent callback
52    before_callback: Option<BeforeAgentCallback>,
53    /// After agent callback
54    after_callback: Option<AfterAgentCallback>,
55}
56
57impl GraphAgent {
58    /// Create a new GraphAgent builder
59    pub fn builder(name: &str) -> GraphAgentBuilder {
60        GraphAgentBuilder::new(name)
61    }
62
63    /// Create directly from a compiled graph
64    pub fn from_graph(name: &str, graph: CompiledGraph) -> Self {
65        Self {
66            name: name.to_string(),
67            description: String::new(),
68            graph: Arc::new(graph),
69            input_mapper: Arc::new(default_input_mapper),
70            output_mapper: Arc::new(default_output_mapper),
71            before_callback: None,
72            after_callback: None,
73        }
74    }
75
76    /// Get the underlying compiled graph
77    pub fn graph(&self) -> &CompiledGraph {
78        &self.graph
79    }
80
81    /// Execute the graph directly (bypassing Agent trait)
82    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
83        self.graph.invoke(input, config).await
84    }
85
86    /// Stream execution
87    pub fn stream(
88        &self,
89        input: State,
90        config: ExecutionConfig,
91        mode: StreamMode,
92    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
93        self.graph.stream(input, config, mode)
94    }
95}
96
97#[async_trait]
98impl Agent for GraphAgent {
99    fn name(&self) -> &str {
100        &self.name
101    }
102
103    fn description(&self) -> &str {
104        &self.description
105    }
106
107    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
108        &[]
109    }
110
111    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> adk_core::Result<EventStream> {
112        // Call before callback
113        if let Some(callback) = &self.before_callback {
114            callback(ctx.clone()).await?;
115        }
116
117        // Map context to input state
118        let input = (self.input_mapper)(ctx.as_ref());
119
120        // Create execution config from context
121        let config = ExecutionConfig::new(ctx.session_id());
122
123        // Execute graph
124        let graph = self.graph.clone();
125        let output_mapper = self.output_mapper.clone();
126        let after_callback = self.after_callback.clone();
127        let ctx_clone = ctx.clone();
128
129        let stream = async_stream::stream! {
130            match graph.invoke(input, config).await {
131                Ok(state) => {
132                    let events = output_mapper(&state);
133                    for event in events {
134                        // Call after callback for each event
135                        if let Some(callback) = &after_callback {
136                            if let Err(e) = callback(ctx_clone.clone(), event.clone()).await {
137                                yield Err(e);
138                                return;
139                            }
140                        }
141                        yield Ok(event);
142                    }
143                }
144                Err(GraphError::Interrupted(interrupt)) => {
145                    // Create an interrupt event
146                    let mut event = Event::new("graph_interrupted");
147                    event.set_content(Content::new("assistant").with_text(format!(
148                        "Graph interrupted: {:?}\nThread: {}\nCheckpoint: {}",
149                        interrupt.interrupt,
150                        interrupt.thread_id,
151                        interrupt.checkpoint_id
152                    )));
153                    yield Ok(event);
154                }
155                Err(e) => {
156                    yield Err(adk_core::AdkError::Agent(e.to_string()));
157                }
158            }
159        };
160
161        Ok(Box::pin(stream))
162    }
163}
164
165/// Default input mapper - extracts content from InvocationContext
166fn default_input_mapper(ctx: &dyn InvocationContext) -> State {
167    let mut state = State::new();
168
169    // Get user content
170    let content = ctx.user_content();
171    let text: String = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("\n");
172
173    if !text.is_empty() {
174        state.insert("input".to_string(), json!(text));
175        state.insert("messages".to_string(), json!([{"role": "user", "content": text}]));
176    }
177
178    // Add session ID
179    state.insert("session_id".to_string(), json!(ctx.session_id()));
180
181    state
182}
183
184/// Default output mapper - creates events from state
185fn default_output_mapper(state: &State) -> Vec<Event> {
186    let mut events = Vec::new();
187
188    // Try to get output from common fields
189    let output_text = state
190        .get("output")
191        .and_then(|v| v.as_str())
192        .or_else(|| state.get("result").and_then(|v| v.as_str()))
193        .or_else(|| {
194            state
195                .get("messages")
196                .and_then(|v| v.as_array())
197                .and_then(|arr| arr.last())
198                .and_then(|msg| msg.get("content"))
199                .and_then(|c| c.as_str())
200        });
201
202    let text = if let Some(text) = output_text {
203        text.to_string()
204    } else {
205        // Return the full state as JSON
206        serde_json::to_string_pretty(state).unwrap_or_default()
207    };
208
209    let mut event = Event::new("graph_output");
210    event.set_content(Content::new("assistant").with_text(&text));
211    events.push(event);
212
213    events
214}
215
216/// Builder for GraphAgent
217pub struct GraphAgentBuilder {
218    name: String,
219    description: String,
220    schema: StateSchema,
221    nodes: Vec<Arc<dyn Node>>,
222    edges: Vec<Edge>,
223    checkpointer: Option<Arc<dyn Checkpointer>>,
224    interrupt_before: Vec<String>,
225    interrupt_after: Vec<String>,
226    recursion_limit: usize,
227    input_mapper: Option<InputMapper>,
228    output_mapper: Option<OutputMapper>,
229    before_callback: Option<BeforeAgentCallback>,
230    after_callback: Option<AfterAgentCallback>,
231}
232
233impl GraphAgentBuilder {
234    /// Create a new builder
235    pub fn new(name: &str) -> Self {
236        Self {
237            name: name.to_string(),
238            description: String::new(),
239            schema: StateSchema::simple(&["input", "output", "messages"]),
240            nodes: vec![],
241            edges: vec![],
242            checkpointer: None,
243            interrupt_before: vec![],
244            interrupt_after: vec![],
245            recursion_limit: 50,
246            input_mapper: None,
247            output_mapper: None,
248            before_callback: None,
249            after_callback: None,
250        }
251    }
252
253    /// Set description
254    pub fn description(mut self, desc: &str) -> Self {
255        self.description = desc.to_string();
256        self
257    }
258
259    /// Set state schema
260    pub fn state_schema(mut self, schema: StateSchema) -> Self {
261        self.schema = schema;
262        self
263    }
264
265    /// Add channels to state schema
266    pub fn channels(mut self, channels: &[&str]) -> Self {
267        self.schema = StateSchema::simple(channels);
268        self
269    }
270
271    /// Add a node
272    pub fn node<N: Node + 'static>(mut self, node: N) -> Self {
273        self.nodes.push(Arc::new(node));
274        self
275    }
276
277    /// Add a function as a node
278    pub fn node_fn<F, Fut>(mut self, name: &str, func: F) -> Self
279    where
280        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
281        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
282    {
283        self.nodes.push(Arc::new(FunctionNode::new(name, func)));
284        self
285    }
286
287    /// Add a direct edge
288    pub fn edge(mut self, source: &str, target: &str) -> Self {
289        let target =
290            if target == END { EdgeTarget::End } else { EdgeTarget::Node(target.to_string()) };
291
292        if source == START {
293            let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
294            match entry_idx {
295                Some(idx) => {
296                    if let Edge::Entry { targets } = &mut self.edges[idx] {
297                        if let EdgeTarget::Node(node) = &target {
298                            if !targets.contains(node) {
299                                targets.push(node.clone());
300                            }
301                        }
302                    }
303                }
304                None => {
305                    if let EdgeTarget::Node(node) = target {
306                        self.edges.push(Edge::Entry { targets: vec![node] });
307                    }
308                }
309            }
310        } else {
311            self.edges.push(Edge::Direct { source: source.to_string(), target });
312        }
313
314        self
315    }
316
317    /// Add a conditional edge
318    pub fn conditional_edge<F, I>(mut self, source: &str, router: F, targets: I) -> Self
319    where
320        F: Fn(&State) -> String + Send + Sync + 'static,
321        I: IntoIterator<Item = (&'static str, &'static str)>,
322    {
323        let targets_map: HashMap<String, EdgeTarget> = targets
324            .into_iter()
325            .map(|(k, v)| {
326                let target =
327                    if v == END { EdgeTarget::End } else { EdgeTarget::Node(v.to_string()) };
328                (k.to_string(), target)
329            })
330            .collect();
331
332        self.edges.push(Edge::Conditional {
333            source: source.to_string(),
334            router: Arc::new(router),
335            targets: targets_map,
336        });
337
338        self
339    }
340
341    /// Set checkpointer
342    pub fn checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
343        self.checkpointer = Some(Arc::new(checkpointer));
344        self
345    }
346
347    /// Set checkpointer with Arc
348    pub fn checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
349        self.checkpointer = Some(checkpointer);
350        self
351    }
352
353    /// Set nodes to interrupt before
354    pub fn interrupt_before(mut self, nodes: &[&str]) -> Self {
355        self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
356        self
357    }
358
359    /// Set nodes to interrupt after
360    pub fn interrupt_after(mut self, nodes: &[&str]) -> Self {
361        self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
362        self
363    }
364
365    /// Set recursion limit
366    pub fn recursion_limit(mut self, limit: usize) -> Self {
367        self.recursion_limit = limit;
368        self
369    }
370
371    /// Set custom input mapper
372    pub fn input_mapper<F>(mut self, mapper: F) -> Self
373    where
374        F: Fn(&dyn InvocationContext) -> State + Send + Sync + 'static,
375    {
376        self.input_mapper = Some(Arc::new(mapper));
377        self
378    }
379
380    /// Set custom output mapper
381    pub fn output_mapper<F>(mut self, mapper: F) -> Self
382    where
383        F: Fn(&State) -> Vec<Event> + Send + Sync + 'static,
384    {
385        self.output_mapper = Some(Arc::new(mapper));
386        self
387    }
388
389    /// Set before agent callback
390    pub fn before_agent_callback<F, Fut>(mut self, callback: F) -> Self
391    where
392        F: Fn(Arc<dyn InvocationContext>) -> Fut + Send + Sync + 'static,
393        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
394    {
395        self.before_callback = Some(Arc::new(move |ctx| Box::pin(callback(ctx))));
396        self
397    }
398
399    /// Set after agent callback
400    ///
401    /// Note: The callback receives a cloned Event to avoid lifetime issues.
402    pub fn after_agent_callback<F, Fut>(mut self, callback: F) -> Self
403    where
404        F: Fn(Arc<dyn InvocationContext>, Event) -> Fut + Send + Sync + 'static,
405        Fut: Future<Output = adk_core::Result<()>> + Send + 'static,
406    {
407        self.after_callback = Some(Arc::new(move |ctx, event| {
408            let event_clone = event.clone();
409            Box::pin(callback(ctx, event_clone))
410        }));
411        self
412    }
413
414    /// Build the GraphAgent
415    pub fn build(self) -> Result<GraphAgent> {
416        // Build the graph
417        let mut graph = StateGraph::new(self.schema);
418
419        // Add nodes
420        for node in self.nodes {
421            graph.nodes.insert(node.name().to_string(), node);
422        }
423
424        // Add edges
425        graph.edges = self.edges;
426
427        // Compile
428        let mut compiled = graph.compile()?;
429
430        // Configure
431        if let Some(cp) = self.checkpointer {
432            compiled.checkpointer = Some(cp);
433        }
434        compiled.interrupt_before = self.interrupt_before.into_iter().collect();
435        compiled.interrupt_after = self.interrupt_after.into_iter().collect();
436        compiled.recursion_limit = self.recursion_limit;
437
438        Ok(GraphAgent {
439            name: self.name,
440            description: self.description,
441            graph: Arc::new(compiled),
442            input_mapper: self.input_mapper.unwrap_or(Arc::new(default_input_mapper)),
443            output_mapper: self.output_mapper.unwrap_or(Arc::new(default_output_mapper)),
444            before_callback: self.before_callback,
445            after_callback: self.after_callback,
446        })
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A single node writing `value = 42`, wired START -> set -> END, must build, expose the
    /// configured name/description, and leave `value` in the state after a direct invocation.
    #[tokio::test]
    async fn test_graph_agent_builder() {
        let builder = GraphAgent::builder("test")
            .description("Test agent")
            .channels(&["value"])
            .node_fn("set", |_ctx| async { Ok(NodeOutput::new().with_update("value", json!(42))) })
            .edge(START, "set")
            .edge("set", END);
        let agent = builder.build().expect("graph should compile");

        assert_eq!(agent.name(), "test");
        assert_eq!(agent.description(), "Test agent");

        let final_state = agent
            .invoke(State::new(), ExecutionConfig::new("test"))
            .await
            .expect("graph should run to completion");

        let expected = json!(42);
        assert_eq!(final_state.get("value"), Some(&expected));
    }
}