// adk_graph/node.rs

//! Node types for graph execution.
//!
//! Nodes are the computational units of a graph: each one receives the current state and
//! returns the updates to apply to it.
4
5use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16/// Configuration passed to nodes during execution
17#[derive(Clone)]
18pub struct ExecutionConfig {
19    /// Thread identifier for checkpointing
20    pub thread_id: String,
21    /// Resume from a specific checkpoint
22    pub resume_from: Option<String>,
23    /// Recursion limit for cycles
24    pub recursion_limit: usize,
25    /// Additional configuration
26    pub metadata: HashMap<String, Value>,
27}
28
29impl ExecutionConfig {
30    /// Create a new config with the given thread ID
31    pub fn new(thread_id: &str) -> Self {
32        Self {
33            thread_id: thread_id.to_string(),
34            resume_from: None,
35            recursion_limit: 50,
36            metadata: HashMap::new(),
37        }
38    }
39
40    /// Set the recursion limit
41    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
42        self.recursion_limit = limit;
43        self
44    }
45
46    /// Resume from a specific checkpoint
47    pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
48        self.resume_from = Some(checkpoint_id.to_string());
49        self
50    }
51
52    /// Add metadata
53    pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
54        self.metadata.insert(key.to_string(), value);
55        self
56    }
57}
58
59impl Default for ExecutionConfig {
60    fn default() -> Self {
61        Self::new(&uuid::Uuid::new_v4().to_string())
62    }
63}
64
65/// Context passed to nodes during execution
66pub struct NodeContext {
67    /// Current graph state (read-only view)
68    pub state: State,
69    /// Configuration for this execution
70    pub config: ExecutionConfig,
71    /// Current step number
72    pub step: usize,
73}
74
75impl NodeContext {
76    /// Create a new node context
77    pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
78        Self { state, config, step }
79    }
80
81    /// Get a value from state
82    pub fn get(&self, key: &str) -> Option<&Value> {
83        self.state.get(key)
84    }
85
86    /// Get a value from state as a specific type
87    pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
88        self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
89    }
90}
91
92/// Output from a node execution
93#[derive(Default)]
94pub struct NodeOutput {
95    /// State updates to apply
96    pub updates: HashMap<String, Value>,
97    /// Optional interrupt request
98    pub interrupt: Option<Interrupt>,
99    /// Custom stream events
100    pub events: Vec<StreamEvent>,
101}
102
103impl NodeOutput {
104    /// Create a new empty output
105    pub fn new() -> Self {
106        Self::default()
107    }
108
109    /// Add a state update
110    pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
111        self.updates.insert(key.to_string(), value.into());
112        self
113    }
114
115    /// Add multiple state updates
116    pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
117        self.updates.extend(updates);
118        self
119    }
120
121    /// Set an interrupt
122    pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
123        self.interrupt = Some(interrupt);
124        self
125    }
126
127    /// Add a custom stream event
128    pub fn with_event(mut self, event: StreamEvent) -> Self {
129        self.events.push(event);
130        self
131    }
132
133    /// Create output that triggers a dynamic interrupt
134    pub fn interrupt(message: &str) -> Self {
135        Self::new().with_interrupt(crate::interrupt::interrupt(message))
136    }
137
138    /// Create output that triggers a dynamic interrupt with data
139    pub fn interrupt_with_data(message: &str, data: Value) -> Self {
140        Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
141    }
142}
143
144/// A node in the graph
145#[async_trait]
146pub trait Node: Send + Sync {
147    /// Node identifier
148    fn name(&self) -> &str;
149
150    /// Execute the node and return state updates
151    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
152}
153
154/// Type alias for boxed node
155pub type BoxedNode = Box<dyn Node>;
156
157/// Type alias for async function signature
158pub type AsyncNodeFn = Box<
159    dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
160>;
161
162/// Function node - wraps an async function as a node
163pub struct FunctionNode {
164    name: String,
165    func: AsyncNodeFn,
166}
167
168impl FunctionNode {
169    /// Create a new function node
170    pub fn new<F, Fut>(name: &str, func: F) -> Self
171    where
172        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
173        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
174    {
175        Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
176    }
177}
178
179#[async_trait]
180impl Node for FunctionNode {
181    fn name(&self) -> &str {
182        &self.name
183    }
184
185    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
186        let ctx_owned =
187            NodeContext { state: ctx.state.clone(), config: ctx.config.clone(), step: ctx.step };
188        (self.func)(ctx_owned).await
189    }
190}
191
192/// Passthrough node - just passes state through unchanged
193pub struct PassthroughNode {
194    name: String,
195}
196
197impl PassthroughNode {
198    /// Create a new passthrough node
199    pub fn new(name: &str) -> Self {
200        Self { name: name.to_string() }
201    }
202}
203
204#[async_trait]
205impl Node for PassthroughNode {
206    fn name(&self) -> &str {
207        &self.name
208    }
209
210    async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
211        Ok(NodeOutput::new())
212    }
213}
214
215/// Type alias for agent node input mapper
216pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
217
218/// Type alias for agent node output mapper
219pub type AgentOutputMapper =
220    Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
221
222/// Wrapper to use an existing ADK Agent as a graph node
223pub struct AgentNode {
224    name: String,
225    #[allow(dead_code)]
226    agent: Arc<dyn adk_core::Agent>,
227    /// Map state to agent input content
228    input_mapper: AgentInputMapper,
229    /// Map agent events to state updates
230    output_mapper: AgentOutputMapper,
231}
232
233impl AgentNode {
234    /// Create a new agent node
235    pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
236        let name = agent.name().to_string();
237        Self {
238            name,
239            agent,
240            input_mapper: Box::new(default_input_mapper),
241            output_mapper: Box::new(default_output_mapper),
242        }
243    }
244
245    /// Set custom input mapper
246    pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
247    where
248        F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
249    {
250        self.input_mapper = Box::new(mapper);
251        self
252    }
253
254    /// Set custom output mapper
255    pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
256    where
257        F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
258    {
259        self.output_mapper = Box::new(mapper);
260        self
261    }
262}
263
264/// Default input mapper - looks for "messages" or "input" in state
265fn default_input_mapper(state: &State) -> adk_core::Content {
266    // Try to get messages first
267    if let Some(messages) = state.get("messages") {
268        if let Some(arr) = messages.as_array() {
269            if let Some(last) = arr.last() {
270                if let Some(content) = last.get("content").and_then(|c| c.as_str()) {
271                    return adk_core::Content::new("user").with_text(content);
272                }
273            }
274        }
275    }
276
277    // Try input field
278    if let Some(input) = state.get("input") {
279        if let Some(text) = input.as_str() {
280            return adk_core::Content::new("user").with_text(text);
281        }
282    }
283
284    adk_core::Content::new("user")
285}
286
287/// Default output mapper - extracts text content to "messages"
288fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
289    let mut updates = HashMap::new();
290
291    // Collect text from events
292    let mut messages = Vec::new();
293    for event in events {
294        if let Some(content) = event.content() {
295            let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
296
297            if !text.is_empty() {
298                messages.push(serde_json::json!({
299                    "role": "assistant",
300                    "content": text
301                }));
302            }
303        }
304    }
305
306    if !messages.is_empty() {
307        updates.insert("messages".to_string(), serde_json::json!(messages));
308    }
309
310    updates
311}
312
313#[async_trait]
314impl Node for AgentNode {
315    fn name(&self) -> &str {
316        &self.name
317    }
318
319    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
320        use futures::StreamExt;
321
322        // Map state to input content
323        let content = (self.input_mapper)(&ctx.state);
324
325        // Create a graph invocation context with the agent
326        let invocation_ctx = Arc::new(GraphInvocationContext::new(
327            ctx.config.thread_id.clone(),
328            content,
329            self.agent.clone(),
330        ));
331
332        // Run the agent and collect events
333        let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
334            crate::error::GraphError::NodeExecutionFailed {
335                node: self.name.clone(),
336                message: e.to_string(),
337            }
338        })?;
339
340        let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
341
342        // Map events to state updates
343        let updates = (self.output_mapper)(&events);
344
345        Ok(NodeOutput::new().with_updates(updates))
346    }
347}
348
349/// Full InvocationContext implementation for running agents within graph nodes
350struct GraphInvocationContext {
351    invocation_id: String,
352    user_content: adk_core::Content,
353    agent: Arc<dyn adk_core::Agent>,
354    session: GraphSession,
355    run_config: adk_core::RunConfig,
356    ended: std::sync::atomic::AtomicBool,
357}
358
359impl GraphInvocationContext {
360    fn new(
361        session_id: String,
362        user_content: adk_core::Content,
363        agent: Arc<dyn adk_core::Agent>,
364    ) -> Self {
365        let invocation_id = uuid::Uuid::new_v4().to_string();
366        Self {
367            invocation_id,
368            user_content,
369            agent,
370            session: GraphSession::new(session_id),
371            run_config: adk_core::RunConfig::default(),
372            ended: std::sync::atomic::AtomicBool::new(false),
373        }
374    }
375}
376
377// Implement ReadonlyContext (required by CallbackContext)
378impl adk_core::ReadonlyContext for GraphInvocationContext {
379    fn invocation_id(&self) -> &str {
380        &self.invocation_id
381    }
382
383    fn agent_name(&self) -> &str {
384        self.agent.name()
385    }
386
387    fn user_id(&self) -> &str {
388        "graph_user"
389    }
390
391    fn app_name(&self) -> &str {
392        "graph_app"
393    }
394
395    fn session_id(&self) -> &str {
396        &self.session.id
397    }
398
399    fn branch(&self) -> &str {
400        "main"
401    }
402
403    fn user_content(&self) -> &adk_core::Content {
404        &self.user_content
405    }
406}
407
408// Implement CallbackContext (required by InvocationContext)
409#[async_trait]
410impl adk_core::CallbackContext for GraphInvocationContext {
411    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
412        None
413    }
414}
415
416// Implement InvocationContext
417#[async_trait]
418impl adk_core::InvocationContext for GraphInvocationContext {
419    fn agent(&self) -> Arc<dyn adk_core::Agent> {
420        self.agent.clone()
421    }
422
423    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
424        None
425    }
426
427    fn session(&self) -> &dyn adk_core::Session {
428        &self.session
429    }
430
431    fn run_config(&self) -> &adk_core::RunConfig {
432        &self.run_config
433    }
434
435    fn end_invocation(&self) {
436        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
437    }
438
439    fn ended(&self) -> bool {
440        self.ended.load(std::sync::atomic::Ordering::SeqCst)
441    }
442}
443
444/// Minimal Session implementation for graph execution
445struct GraphSession {
446    id: String,
447    state: GraphState,
448}
449
450impl GraphSession {
451    fn new(id: String) -> Self {
452        Self { id, state: GraphState::new() }
453    }
454}
455
456impl adk_core::Session for GraphSession {
457    fn id(&self) -> &str {
458        &self.id
459    }
460
461    fn app_name(&self) -> &str {
462        "graph_app"
463    }
464
465    fn user_id(&self) -> &str {
466        "graph_user"
467    }
468
469    fn state(&self) -> &dyn adk_core::State {
470        &self.state
471    }
472
473    fn conversation_history(&self) -> Vec<adk_core::Content> {
474        vec![]
475    }
476}
477
478/// Minimal State implementation for graph execution
479struct GraphState {
480    data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
481}
482
483impl GraphState {
484    fn new() -> Self {
485        Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
486    }
487}
488
489impl adk_core::State for GraphState {
490    fn get(&self, key: &str) -> Option<serde_json::Value> {
491        self.data.read().ok()?.get(key).cloned()
492    }
493
494    fn set(&mut self, key: String, value: serde_json::Value) {
495        if let Ok(mut data) = self.data.write() {
496            data.insert(key, value);
497        }
498    }
499
500    fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
501        self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Context with empty state, a random thread ID, at step 0.
    fn empty_ctx() -> NodeContext {
        NodeContext::new(State::new(), ExecutionConfig::default(), 0)
    }

    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", json!("success")))
        });
        assert_eq!(node.name(), "test");

        let output = node.execute(&empty_ctx()).await.unwrap();
        assert_eq!(output.updates.get("result"), Some(&json!("success")));
    }

    #[tokio::test]
    async fn test_passthrough_node() {
        let output = PassthroughNode::new("pass").execute(&empty_ctx()).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        assert_eq!(output.updates.get("a"), Some(&json!(1)));
        assert_eq!(output.updates.get("b"), Some(&json!("hello")));
    }
}