adk_graph/
node.rs

1//! Node types for graph execution
2//!
3//! Nodes are the computational units in a graph. They receive state and return updates.
4
5use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16/// Configuration passed to nodes during execution
17#[derive(Clone)]
18pub struct ExecutionConfig {
19    /// Thread identifier for checkpointing
20    pub thread_id: String,
21    /// Resume from a specific checkpoint
22    pub resume_from: Option<String>,
23    /// Recursion limit for cycles
24    pub recursion_limit: usize,
25    /// Additional configuration
26    pub metadata: HashMap<String, Value>,
27}
28
29impl ExecutionConfig {
30    /// Create a new config with the given thread ID
31    pub fn new(thread_id: &str) -> Self {
32        Self {
33            thread_id: thread_id.to_string(),
34            resume_from: None,
35            recursion_limit: 50,
36            metadata: HashMap::new(),
37        }
38    }
39
40    /// Set the recursion limit
41    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
42        self.recursion_limit = limit;
43        self
44    }
45
46    /// Resume from a specific checkpoint
47    pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
48        self.resume_from = Some(checkpoint_id.to_string());
49        self
50    }
51
52    /// Add metadata
53    pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
54        self.metadata.insert(key.to_string(), value);
55        self
56    }
57}
58
59impl Default for ExecutionConfig {
60    fn default() -> Self {
61        Self::new(&uuid::Uuid::new_v4().to_string())
62    }
63}
64
65/// Context passed to nodes during execution
66pub struct NodeContext {
67    /// Current graph state (read-only view)
68    pub state: State,
69    /// Configuration for this execution
70    pub config: ExecutionConfig,
71    /// Current step number
72    pub step: usize,
73}
74
75impl NodeContext {
76    /// Create a new node context
77    pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
78        Self { state, config, step }
79    }
80
81    /// Get a value from state
82    pub fn get(&self, key: &str) -> Option<&Value> {
83        self.state.get(key)
84    }
85
86    /// Get a value from state as a specific type
87    pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
88        self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
89    }
90}
91
92/// Output from a node execution
93#[derive(Default)]
94pub struct NodeOutput {
95    /// State updates to apply
96    pub updates: HashMap<String, Value>,
97    /// Optional interrupt request
98    pub interrupt: Option<Interrupt>,
99    /// Custom stream events
100    pub events: Vec<StreamEvent>,
101}
102
103impl NodeOutput {
104    /// Create a new empty output
105    pub fn new() -> Self {
106        Self::default()
107    }
108
109    /// Add a state update
110    pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
111        self.updates.insert(key.to_string(), value.into());
112        self
113    }
114
115    /// Add multiple state updates
116    pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
117        self.updates.extend(updates);
118        self
119    }
120
121    /// Set an interrupt
122    pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
123        self.interrupt = Some(interrupt);
124        self
125    }
126
127    /// Add a custom stream event
128    pub fn with_event(mut self, event: StreamEvent) -> Self {
129        self.events.push(event);
130        self
131    }
132
133    /// Create output that triggers a dynamic interrupt
134    pub fn interrupt(message: &str) -> Self {
135        Self::new().with_interrupt(crate::interrupt::interrupt(message))
136    }
137
138    /// Create output that triggers a dynamic interrupt with data
139    pub fn interrupt_with_data(message: &str, data: Value) -> Self {
140        Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
141    }
142}
143
144/// A node in the graph
145#[async_trait]
146pub trait Node: Send + Sync {
147    /// Node identifier
148    fn name(&self) -> &str;
149
150    /// Execute the node and return state updates
151    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
152
153    /// Stream execution events (default: wraps execute)
154    fn execute_stream<'a>(
155        &'a self,
156        ctx: &'a NodeContext,
157    ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
158        let _name = self.name().to_string();
159        Box::pin(async_stream::stream! {
160            match self.execute(ctx).await {
161                Ok(output) => {
162                    for event in output.events {
163                        yield Ok(event);
164                    }
165                }
166                Err(e) => yield Err(e),
167            }
168        })
169    }
170}
171
172/// Type alias for boxed node
173pub type BoxedNode = Box<dyn Node>;
174
175/// Type alias for async function signature
176pub type AsyncNodeFn = Box<
177    dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
178>;
179
180/// Function node - wraps an async function as a node
181pub struct FunctionNode {
182    name: String,
183    func: AsyncNodeFn,
184}
185
186impl FunctionNode {
187    /// Create a new function node
188    pub fn new<F, Fut>(name: &str, func: F) -> Self
189    where
190        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
191        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
192    {
193        Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
194    }
195}
196
197#[async_trait]
198impl Node for FunctionNode {
199    fn name(&self) -> &str {
200        &self.name
201    }
202
203    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
204        let ctx_owned =
205            NodeContext { state: ctx.state.clone(), config: ctx.config.clone(), step: ctx.step };
206        (self.func)(ctx_owned).await
207    }
208}
209
210/// Passthrough node - just passes state through unchanged
211pub struct PassthroughNode {
212    name: String,
213}
214
215impl PassthroughNode {
216    /// Create a new passthrough node
217    pub fn new(name: &str) -> Self {
218        Self { name: name.to_string() }
219    }
220}
221
222#[async_trait]
223impl Node for PassthroughNode {
224    fn name(&self) -> &str {
225        &self.name
226    }
227
228    async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
229        Ok(NodeOutput::new())
230    }
231}
232
233/// Type alias for agent node input mapper
234pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
235
236/// Type alias for agent node output mapper
237pub type AgentOutputMapper =
238    Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
239
240/// Wrapper to use an existing ADK Agent as a graph node
241pub struct AgentNode {
242    name: String,
243    #[allow(dead_code)]
244    agent: Arc<dyn adk_core::Agent>,
245    /// Map state to agent input content
246    input_mapper: AgentInputMapper,
247    /// Map agent events to state updates
248    output_mapper: AgentOutputMapper,
249}
250
251impl AgentNode {
252    /// Create a new agent node
253    pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
254        let name = agent.name().to_string();
255        Self {
256            name,
257            agent,
258            input_mapper: Box::new(default_input_mapper),
259            output_mapper: Box::new(default_output_mapper),
260        }
261    }
262
263    /// Set custom input mapper
264    pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
265    where
266        F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
267    {
268        self.input_mapper = Box::new(mapper);
269        self
270    }
271
272    /// Set custom output mapper
273    pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
274    where
275        F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
276    {
277        self.output_mapper = Box::new(mapper);
278        self
279    }
280}
281
282/// Default input mapper - looks for "messages" or "input" in state
283fn default_input_mapper(state: &State) -> adk_core::Content {
284    // Try to get messages first
285    if let Some(messages) = state.get("messages") {
286        if let Some(arr) = messages.as_array() {
287            if let Some(last) = arr.last() {
288                if let Some(content) = last.get("content").and_then(|c| c.as_str()) {
289                    return adk_core::Content::new("user").with_text(content);
290                }
291            }
292        }
293    }
294
295    // Try input field
296    if let Some(input) = state.get("input") {
297        if let Some(text) = input.as_str() {
298            return adk_core::Content::new("user").with_text(text);
299        }
300    }
301
302    adk_core::Content::new("user")
303}
304
305/// Default output mapper - extracts text content to "messages"
306fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
307    let mut updates = HashMap::new();
308
309    // Collect text from events
310    let mut messages = Vec::new();
311    for event in events {
312        if let Some(content) = event.content() {
313            let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
314
315            if !text.is_empty() {
316                messages.push(serde_json::json!({
317                    "role": "assistant",
318                    "content": text
319                }));
320            }
321        }
322    }
323
324    if !messages.is_empty() {
325        updates.insert("messages".to_string(), serde_json::json!(messages));
326    }
327
328    updates
329}
330
331#[async_trait]
332impl Node for AgentNode {
333    fn name(&self) -> &str {
334        &self.name
335    }
336
337    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
338        use futures::StreamExt;
339
340        // Map state to input content
341        let content = (self.input_mapper)(&ctx.state);
342
343        // Create a graph invocation context with the agent
344        let invocation_ctx = Arc::new(GraphInvocationContext::new(
345            ctx.config.thread_id.clone(),
346            content,
347            self.agent.clone(),
348        ));
349
350        // Run the agent and collect events
351        let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
352            crate::error::GraphError::NodeExecutionFailed {
353                node: self.name.clone(),
354                message: e.to_string(),
355            }
356        })?;
357
358        let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
359
360        // Map events to state updates
361        let updates = (self.output_mapper)(&events);
362
363        // Convert agent events to stream events for tracing
364        let mut output = NodeOutput::new().with_updates(updates);
365        for event in &events {
366            if let Ok(json) = serde_json::to_value(event) {
367                output = output.with_event(StreamEvent::custom(&self.name, "agent_event", json));
368            }
369        }
370
371        Ok(output)
372    }
373
374    fn execute_stream<'a>(
375        &'a self,
376        ctx: &'a NodeContext,
377    ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
378        use futures::StreamExt;
379        let name = self.name.clone();
380        let agent = self.agent.clone();
381        let input_mapper = &self.input_mapper;
382        let thread_id = ctx.config.thread_id.clone();
383        let content = (input_mapper)(&ctx.state);
384
385        Box::pin(async_stream::stream! {
386            eprintln!("DEBUG: AgentNode::execute_stream called for {}", name);
387            let invocation_ctx = Arc::new(GraphInvocationContext::new(
388                thread_id,
389                content,
390                agent.clone(),
391            ));
392
393            let stream = match agent.run(invocation_ctx).await {
394                Ok(s) => s,
395                Err(e) => {
396                    yield Err(crate::error::GraphError::NodeExecutionFailed {
397                        node: name.clone(),
398                        message: e.to_string(),
399                    });
400                    return;
401                }
402            };
403
404            tokio::pin!(stream);
405            let mut all_events = Vec::new();
406
407            while let Some(result) = stream.next().await {
408                match result {
409                    Ok(event) => {
410                        // Emit streaming event immediately
411                        if let Some(content) = event.content() {
412                            let text: String = content.parts.iter().filter_map(|p| p.text()).collect();
413                            if !text.is_empty() {
414                                yield Ok(StreamEvent::Message {
415                                    node: name.clone(),
416                                    content: text,
417                                    is_final: false,
418                                });
419                            }
420                        }
421                        all_events.push(event);
422                    }
423                    Err(e) => {
424                        yield Err(crate::error::GraphError::NodeExecutionFailed {
425                            node: name.clone(),
426                            message: e.to_string(),
427                        });
428                        return;
429                    }
430                }
431            }
432
433            // Emit final events
434            for event in &all_events {
435                if let Ok(json) = serde_json::to_value(event) {
436                    yield Ok(StreamEvent::custom(&name, "agent_event", json));
437                }
438            }
439        })
440    }
441}
442
443/// Full InvocationContext implementation for running agents within graph nodes
444struct GraphInvocationContext {
445    invocation_id: String,
446    user_content: adk_core::Content,
447    agent: Arc<dyn adk_core::Agent>,
448    session: Arc<GraphSession>,
449    run_config: adk_core::RunConfig,
450    ended: std::sync::atomic::AtomicBool,
451}
452
453impl GraphInvocationContext {
454    fn new(
455        session_id: String,
456        user_content: adk_core::Content,
457        agent: Arc<dyn adk_core::Agent>,
458    ) -> Self {
459        let invocation_id = uuid::Uuid::new_v4().to_string();
460        let session = Arc::new(GraphSession::new(session_id));
461        // Add user content to history
462        session.append_content(user_content.clone());
463        Self {
464            invocation_id,
465            user_content,
466            agent,
467            session,
468            run_config: adk_core::RunConfig::default(),
469            ended: std::sync::atomic::AtomicBool::new(false),
470        }
471    }
472}
473
474// Implement ReadonlyContext (required by CallbackContext)
475impl adk_core::ReadonlyContext for GraphInvocationContext {
476    fn invocation_id(&self) -> &str {
477        &self.invocation_id
478    }
479
480    fn agent_name(&self) -> &str {
481        self.agent.name()
482    }
483
484    fn user_id(&self) -> &str {
485        "graph_user"
486    }
487
488    fn app_name(&self) -> &str {
489        "graph_app"
490    }
491
492    fn session_id(&self) -> &str {
493        &self.session.id
494    }
495
496    fn branch(&self) -> &str {
497        "main"
498    }
499
500    fn user_content(&self) -> &adk_core::Content {
501        &self.user_content
502    }
503}
504
505// Implement CallbackContext (required by InvocationContext)
506#[async_trait]
507impl adk_core::CallbackContext for GraphInvocationContext {
508    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
509        None
510    }
511}
512
513// Implement InvocationContext
514#[async_trait]
515impl adk_core::InvocationContext for GraphInvocationContext {
516    fn agent(&self) -> Arc<dyn adk_core::Agent> {
517        self.agent.clone()
518    }
519
520    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
521        None
522    }
523
524    fn session(&self) -> &dyn adk_core::Session {
525        self.session.as_ref()
526    }
527
528    fn run_config(&self) -> &adk_core::RunConfig {
529        &self.run_config
530    }
531
532    fn end_invocation(&self) {
533        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
534    }
535
536    fn ended(&self) -> bool {
537        self.ended.load(std::sync::atomic::Ordering::SeqCst)
538    }
539}
540
541/// Minimal Session implementation for graph execution
542struct GraphSession {
543    id: String,
544    state: GraphState,
545    history: std::sync::RwLock<Vec<adk_core::Content>>,
546}
547
548impl GraphSession {
549    fn new(id: String) -> Self {
550        Self { id, state: GraphState::new(), history: std::sync::RwLock::new(Vec::new()) }
551    }
552
553    fn append_content(&self, content: adk_core::Content) {
554        if let Ok(mut h) = self.history.write() {
555            h.push(content);
556        }
557    }
558}
559
560impl adk_core::Session for GraphSession {
561    fn id(&self) -> &str {
562        &self.id
563    }
564
565    fn app_name(&self) -> &str {
566        "graph_app"
567    }
568
569    fn user_id(&self) -> &str {
570        "graph_user"
571    }
572
573    fn state(&self) -> &dyn adk_core::State {
574        &self.state
575    }
576
577    fn conversation_history(&self) -> Vec<adk_core::Content> {
578        self.history.read().ok().map(|h| h.clone()).unwrap_or_default()
579    }
580
581    fn append_to_history(&self, content: adk_core::Content) {
582        self.append_content(content);
583    }
584}
585
586/// Minimal State implementation for graph execution
587struct GraphState {
588    data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
589}
590
591impl GraphState {
592    fn new() -> Self {
593        Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
594    }
595}
596
597impl adk_core::State for GraphState {
598    fn get(&self, key: &str) -> Option<serde_json::Value> {
599        self.data.read().ok()?.get(key).cloned()
600    }
601
602    fn set(&mut self, key: String, value: serde_json::Value) {
603        if let Ok(mut data) = self.data.write() {
604            data.insert(key, value);
605        }
606    }
607
608    fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
609        self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
610    }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Context over empty state with a default config at step 0.
    fn empty_context() -> NodeContext {
        NodeContext::new(State::new(), ExecutionConfig::default(), 0)
    }

    /// A function node reports its name and returns the callback's updates.
    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            let output = NodeOutput::new().with_update("result", serde_json::json!("success"));
            Ok(output)
        });
        assert_eq!(node.name(), "test");

        let ctx = empty_context();
        let output = node.execute(&ctx).await.unwrap();

        let expected = serde_json::json!("success");
        assert_eq!(output.updates.get("result"), Some(&expected));
    }

    /// A passthrough node produces no updates and no interrupt.
    #[tokio::test]
    async fn test_passthrough_node() {
        let ctx = empty_context();
        let node = PassthroughNode::new("pass");

        let output = node.execute(&ctx).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    /// Builder updates accept anything convertible into a JSON value.
    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        let expected_a = serde_json::json!(1);
        let expected_b = serde_json::json!("hello");
        assert_eq!(output.updates.get("a"), Some(&expected_a));
        assert_eq!(output.updates.get("b"), Some(&expected_b));
    }
}