Skip to main content

adk_graph/
node.rs

1//! Node types for graph execution
2//!
3//! Nodes are the computational units in a graph. They receive state and return updates.
4
5use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use crate::timeout::ProgressHandle;
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16
17/// Configuration passed to nodes during execution
18#[derive(Clone)]
19pub struct ExecutionConfig {
20    /// Thread identifier for checkpointing
21    pub thread_id: String,
22    /// Resume from a specific checkpoint
23    pub resume_from: Option<String>,
24    /// Recursion limit for cycles
25    pub recursion_limit: usize,
26    /// Additional configuration
27    pub metadata: HashMap<String, Value>,
28}
29
30impl ExecutionConfig {
31    /// Create a new config with the given thread ID
32    pub fn new(thread_id: &str) -> Self {
33        Self {
34            thread_id: thread_id.to_string(),
35            resume_from: None,
36            recursion_limit: 50,
37            metadata: HashMap::new(),
38        }
39    }
40
41    /// Set the recursion limit
42    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
43        self.recursion_limit = limit;
44        self
45    }
46
47    /// Resume from a specific checkpoint
48    pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
49        self.resume_from = Some(checkpoint_id.to_string());
50        self
51    }
52
53    /// Add metadata
54    pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
55        self.metadata.insert(key.to_string(), value);
56        self
57    }
58}
59
60impl Default for ExecutionConfig {
61    fn default() -> Self {
62        Self::new(&uuid::Uuid::new_v4().to_string())
63    }
64}
65
66/// Context passed to nodes during execution
67pub struct NodeContext {
68    /// Current graph state (read-only view)
69    pub state: State,
70    /// Configuration for this execution
71    pub config: ExecutionConfig,
72    /// Current step number
73    pub step: usize,
74    /// Optional progress handle for idle timeout tracking.
75    /// When present, calling [`report_progress()`](Self::report_progress) resets the idle timeout counter.
76    progress_handle: Option<ProgressHandle>,
77}
78
79impl NodeContext {
80    /// Create a new node context
81    pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
82        Self { state, config, step, progress_handle: None }
83    }
84
85    /// Get a value from state
86    pub fn get(&self, key: &str) -> Option<&Value> {
87        self.state.get(key)
88    }
89
90    /// Get a value from state as a specific type
91    pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
92        self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
93    }
94
95    /// Report progress, resetting the idle timeout counter.
96    ///
97    /// Nodes performing long-running work should call this periodically to
98    /// prevent the idle timeout from firing. If no progress handle is attached
99    /// (e.g., when no idle timeout is configured), this is a no-op.
100    ///
101    /// # Example
102    ///
103    /// ```rust,ignore
104    /// async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
105    ///     for chunk in large_dataset.chunks(100) {
106    ///         process(chunk).await;
107    ///         ctx.report_progress(); // reset idle timeout
108    ///     }
109    ///     Ok(NodeOutput::new())
110    /// }
111    /// ```
112    pub fn report_progress(&self) {
113        if let Some(handle) = &self.progress_handle {
114            handle.report_progress();
115        }
116    }
117
118    /// Attach a progress handle for idle timeout tracking.
119    ///
120    /// This is called by the executor before running a node with an idle timeout
121    /// policy. Nodes do not need to call this directly.
122    pub fn set_progress_handle(&mut self, handle: ProgressHandle) {
123        self.progress_handle = Some(handle);
124    }
125
126    /// Get a reference to the attached progress handle, if any.
127    pub fn progress_handle(&self) -> Option<&ProgressHandle> {
128        self.progress_handle.as_ref()
129    }
130}
131
132/// Output from a node execution
133#[derive(Default)]
134pub struct NodeOutput {
135    /// State updates to apply
136    pub updates: HashMap<String, Value>,
137    /// Optional interrupt request
138    pub interrupt: Option<Interrupt>,
139    /// Custom stream events
140    pub events: Vec<StreamEvent>,
141}
142
143impl NodeOutput {
144    /// Create a new empty output
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// Add a state update
150    pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
151        self.updates.insert(key.to_string(), value.into());
152        self
153    }
154
155    /// Add multiple state updates
156    pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
157        self.updates.extend(updates);
158        self
159    }
160
161    /// Set an interrupt
162    pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
163        self.interrupt = Some(interrupt);
164        self
165    }
166
167    /// Add a custom stream event
168    pub fn with_event(mut self, event: StreamEvent) -> Self {
169        self.events.push(event);
170        self
171    }
172
173    /// Create output that triggers a dynamic interrupt
174    pub fn interrupt(message: &str) -> Self {
175        Self::new().with_interrupt(crate::interrupt::interrupt(message))
176    }
177
178    /// Create output that triggers a dynamic interrupt with data
179    pub fn interrupt_with_data(message: &str, data: Value) -> Self {
180        Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
181    }
182}
183
184/// A node in the graph
185#[async_trait]
186pub trait Node: Send + Sync {
187    /// Node identifier
188    fn name(&self) -> &str;
189
190    /// Execute the node and return state updates
191    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
192
193    /// Stream execution events (default: wraps execute)
194    fn execute_stream<'a>(
195        &'a self,
196        ctx: &'a NodeContext,
197    ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
198        let _name = self.name().to_string();
199        Box::pin(async_stream::stream! {
200            match self.execute(ctx).await {
201                Ok(output) => {
202                    for event in output.events {
203                        yield Ok(event);
204                    }
205                }
206                Err(e) => yield Err(e),
207            }
208        })
209    }
210}
211
212/// Type alias for boxed node
213pub type BoxedNode = Box<dyn Node>;
214
215/// Type alias for async function signature
216pub type AsyncNodeFn = Box<
217    dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
218>;
219
220/// Function node - wraps an async function as a node
221pub struct FunctionNode {
222    name: String,
223    func: AsyncNodeFn,
224}
225
226impl FunctionNode {
227    /// Create a new function node
228    pub fn new<F, Fut>(name: &str, func: F) -> Self
229    where
230        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
231        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
232    {
233        Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
234    }
235}
236
237#[async_trait]
238impl Node for FunctionNode {
239    fn name(&self) -> &str {
240        &self.name
241    }
242
243    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
244        let mut ctx_owned = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
245        if let Some(handle) = ctx.progress_handle() {
246            ctx_owned.set_progress_handle(handle.clone());
247        }
248        (self.func)(ctx_owned).await
249    }
250}
251
252/// Passthrough node - just passes state through unchanged
253pub struct PassthroughNode {
254    name: String,
255}
256
257impl PassthroughNode {
258    /// Create a new passthrough node
259    pub fn new(name: &str) -> Self {
260        Self { name: name.to_string() }
261    }
262}
263
264#[async_trait]
265impl Node for PassthroughNode {
266    fn name(&self) -> &str {
267        &self.name
268    }
269
270    async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
271        Ok(NodeOutput::new())
272    }
273}
274
275/// Type alias for agent node input mapper
276pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
277
278/// Type alias for agent node output mapper
279pub type AgentOutputMapper =
280    Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
281
282/// Wrapper to use an existing ADK Agent as a graph node
283pub struct AgentNode {
284    name: String,
285    #[allow(dead_code)]
286    agent: Arc<dyn adk_core::Agent>,
287    /// Map state to agent input content
288    input_mapper: AgentInputMapper,
289    /// Map agent events to state updates
290    output_mapper: AgentOutputMapper,
291}
292
293impl AgentNode {
294    /// Create a new agent node
295    pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
296        let name = agent.name().to_string();
297        Self {
298            name,
299            agent,
300            input_mapper: Box::new(default_input_mapper),
301            output_mapper: Box::new(default_output_mapper),
302        }
303    }
304
305    /// Set custom input mapper
306    pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
307    where
308        F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
309    {
310        self.input_mapper = Box::new(mapper);
311        self
312    }
313
314    /// Set custom output mapper
315    pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
316    where
317        F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
318    {
319        self.output_mapper = Box::new(mapper);
320        self
321    }
322}
323
324/// Default input mapper - looks for "messages" or "input" in state
325fn default_input_mapper(state: &State) -> adk_core::Content {
326    // Try to get messages first
327    if let Some(messages) = state.get("messages") {
328        if let Some(arr) = messages.as_array() {
329            if let Some(last) = arr.last() {
330                if let Some(content) = last.get("content").and_then(|c| c.as_str()) {
331                    return adk_core::Content::new("user").with_text(content);
332                }
333            }
334        }
335    }
336
337    // Try input field
338    if let Some(input) = state.get("input") {
339        if let Some(text) = input.as_str() {
340            return adk_core::Content::new("user").with_text(text);
341        }
342    }
343
344    adk_core::Content::new("user")
345}
346
347/// Default output mapper - extracts text content to "messages"
348fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
349    let mut updates = HashMap::new();
350
351    // Collect text from events
352    let mut messages = Vec::new();
353    for event in events {
354        if let Some(content) = event.content() {
355            let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
356
357            if !text.is_empty() {
358                messages.push(serde_json::json!({
359                    "role": "assistant",
360                    "content": text
361                }));
362            }
363        }
364    }
365
366    if !messages.is_empty() {
367        updates.insert("messages".to_string(), serde_json::json!(messages));
368    }
369
370    updates
371}
372
373#[async_trait]
374impl Node for AgentNode {
375    fn name(&self) -> &str {
376        &self.name
377    }
378
379    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
380        use futures::StreamExt;
381
382        // Map state to input content
383        let content = (self.input_mapper)(&ctx.state);
384
385        // Create a graph invocation context with the agent
386        let invocation_ctx = Arc::new(GraphInvocationContext::new(
387            ctx.config.thread_id.clone(),
388            content,
389            self.agent.clone(),
390        ));
391
392        // Run the agent and collect events
393        let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
394            crate::error::GraphError::NodeExecutionFailed {
395                node: self.name.clone(),
396                message: e.to_string(),
397            }
398        })?;
399
400        let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
401
402        // Map events to state updates
403        let updates = (self.output_mapper)(&events);
404
405        // Convert agent events to stream events for tracing
406        let mut output = NodeOutput::new().with_updates(updates);
407        for event in &events {
408            if let Ok(json) = serde_json::to_value(event) {
409                output = output.with_event(StreamEvent::custom(&self.name, "agent_event", json));
410            }
411        }
412
413        Ok(output)
414    }
415
416    fn execute_stream<'a>(
417        &'a self,
418        ctx: &'a NodeContext,
419    ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
420        use futures::StreamExt;
421        let name = self.name.clone();
422        let agent = self.agent.clone();
423        let input_mapper = &self.input_mapper;
424        let thread_id = ctx.config.thread_id.clone();
425        let content = (input_mapper)(&ctx.state);
426
427        Box::pin(async_stream::stream! {
428            tracing::debug!("AgentNode::execute_stream called for {}", name);
429            let invocation_ctx = Arc::new(GraphInvocationContext::new(
430                thread_id,
431                content,
432                agent.clone(),
433            ));
434
435            let stream = match agent.run(invocation_ctx).await {
436                Ok(s) => s,
437                Err(e) => {
438                    yield Err(crate::error::GraphError::NodeExecutionFailed {
439                        node: name.clone(),
440                        message: e.to_string(),
441                    });
442                    return;
443                }
444            };
445
446            tokio::pin!(stream);
447            let mut all_events = Vec::new();
448
449            while let Some(result) = stream.next().await {
450                match result {
451                    Ok(event) => {
452                        // Emit streaming event immediately
453                        if let Some(content) = event.content() {
454                            let text: String = content.parts.iter().filter_map(|p| p.text()).collect();
455                            if !text.is_empty() {
456                                yield Ok(StreamEvent::Message {
457                                    node: name.clone(),
458                                    content: text,
459                                    is_final: false,
460                                });
461                            }
462                        }
463                        all_events.push(event);
464                    }
465                    Err(e) => {
466                        yield Err(crate::error::GraphError::NodeExecutionFailed {
467                            node: name.clone(),
468                            message: e.to_string(),
469                        });
470                        return;
471                    }
472                }
473            }
474
475            // Emit final events
476            for event in &all_events {
477                if let Ok(json) = serde_json::to_value(event) {
478                    yield Ok(StreamEvent::custom(&name, "agent_event", json));
479                }
480            }
481        })
482    }
483}
484
485/// Full InvocationContext implementation for running agents within graph nodes
486struct GraphInvocationContext {
487    invocation_id: String,
488    user_content: adk_core::Content,
489    agent: Arc<dyn adk_core::Agent>,
490    session: Arc<GraphSession>,
491    run_config: adk_core::RunConfig,
492    ended: std::sync::atomic::AtomicBool,
493}
494
495impl GraphInvocationContext {
496    fn new(
497        session_id: String,
498        user_content: adk_core::Content,
499        agent: Arc<dyn adk_core::Agent>,
500    ) -> Self {
501        let invocation_id = uuid::Uuid::new_v4().to_string();
502        let session = Arc::new(GraphSession::new(session_id));
503        // Add user content to history
504        session.append_content(user_content.clone());
505        Self {
506            invocation_id,
507            user_content,
508            agent,
509            session,
510            run_config: adk_core::RunConfig::default(),
511            ended: std::sync::atomic::AtomicBool::new(false),
512        }
513    }
514}
515
516// Implement ReadonlyContext (required by CallbackContext)
517impl adk_core::ReadonlyContext for GraphInvocationContext {
518    fn invocation_id(&self) -> &str {
519        &self.invocation_id
520    }
521
522    fn agent_name(&self) -> &str {
523        self.agent.name()
524    }
525
526    fn user_id(&self) -> &str {
527        "graph_user"
528    }
529
530    fn app_name(&self) -> &str {
531        "graph_app"
532    }
533
534    fn session_id(&self) -> &str {
535        &self.session.id
536    }
537
538    fn branch(&self) -> &str {
539        "main"
540    }
541
542    fn user_content(&self) -> &adk_core::Content {
543        &self.user_content
544    }
545}
546
547// Implement CallbackContext (required by InvocationContext)
548#[async_trait]
549impl adk_core::CallbackContext for GraphInvocationContext {
550    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
551        None
552    }
553}
554
555// Implement InvocationContext
556#[async_trait]
557impl adk_core::InvocationContext for GraphInvocationContext {
558    fn agent(&self) -> Arc<dyn adk_core::Agent> {
559        self.agent.clone()
560    }
561
562    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
563        None
564    }
565
566    fn session(&self) -> &dyn adk_core::Session {
567        self.session.as_ref()
568    }
569
570    fn run_config(&self) -> &adk_core::RunConfig {
571        &self.run_config
572    }
573
574    fn end_invocation(&self) {
575        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
576    }
577
578    fn ended(&self) -> bool {
579        self.ended.load(std::sync::atomic::Ordering::SeqCst)
580    }
581}
582
583/// Minimal Session implementation for graph execution
584struct GraphSession {
585    id: String,
586    state: GraphState,
587    history: std::sync::RwLock<Vec<adk_core::Content>>,
588}
589
590impl GraphSession {
591    fn new(id: String) -> Self {
592        Self { id, state: GraphState::new(), history: std::sync::RwLock::new(Vec::new()) }
593    }
594
595    fn append_content(&self, content: adk_core::Content) {
596        if let Ok(mut h) = self.history.write() {
597            h.push(content);
598        }
599    }
600}
601
602impl adk_core::Session for GraphSession {
603    fn id(&self) -> &str {
604        &self.id
605    }
606
607    fn app_name(&self) -> &str {
608        "graph_app"
609    }
610
611    fn user_id(&self) -> &str {
612        "graph_user"
613    }
614
615    fn state(&self) -> &dyn adk_core::State {
616        &self.state
617    }
618
619    fn conversation_history(&self) -> Vec<adk_core::Content> {
620        self.history.read().ok().map(|h| h.clone()).unwrap_or_default()
621    }
622
623    fn append_to_history(&self, content: adk_core::Content) {
624        self.append_content(content);
625    }
626}
627
628/// Minimal State implementation for graph execution
629struct GraphState {
630    data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
631}
632
633impl GraphState {
634    fn new() -> Self {
635        Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
636    }
637}
638
639impl adk_core::State for GraphState {
640    fn get(&self, key: &str) -> Option<serde_json::Value> {
641        self.data.read().ok()?.get(key).cloned()
642    }
643
644    fn set(&mut self, key: String, value: serde_json::Value) {
645        if let Err(msg) = adk_core::validate_state_key(&key) {
646            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
647            return;
648        }
649        if let Ok(mut data) = self.data.write() {
650            data.insert(key, value);
651        }
652    }
653
654    fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
655        self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", serde_json::json!("success")))
        });

        assert_eq!(node.name(), "test");

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let output = node.execute(&ctx).await.unwrap();

        assert_eq!(output.updates.get("result"), Some(&serde_json::json!("success")));
    }

    #[tokio::test]
    async fn test_passthrough_node() {
        let node = PassthroughNode::new("pass");
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let output = node.execute(&ctx).await.unwrap();

        assert!(output.updates.is_empty());
        assert!(output.interrupt.is_none());
    }

    #[test]
    fn test_node_output_builder() {
        let output = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        assert_eq!(output.updates.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(output.updates.get("b"), Some(&serde_json::json!("hello")));
    }

    #[test]
    fn test_node_output_with_updates_overrides_existing_keys() {
        let mut batch = HashMap::new();
        batch.insert("a".to_string(), serde_json::json!(2));
        batch.insert("c".to_string(), serde_json::json!(true));

        let output = NodeOutput::new().with_update("a", 1).with_updates(batch);

        assert_eq!(output.updates.len(), 2);
        assert_eq!(output.updates.get("a"), Some(&serde_json::json!(2)));
        assert_eq!(output.updates.get("c"), Some(&serde_json::json!(true)));
    }

    #[test]
    fn test_node_output_interrupt_helpers() {
        let plain = NodeOutput::interrupt("approve?");
        assert!(plain.interrupt.is_some());
        assert!(plain.updates.is_empty());

        let with_data = NodeOutput::interrupt_with_data("approve?", serde_json::json!({"id": 1}));
        assert!(with_data.interrupt.is_some());
        assert!(with_data.events.is_empty());
    }

    #[test]
    fn test_execution_config_builder() {
        let config = ExecutionConfig::new("thread-1")
            .with_recursion_limit(7)
            .with_resume_from("checkpoint-9")
            .with_metadata("user", serde_json::json!("alice"));

        assert_eq!(config.thread_id, "thread-1");
        assert_eq!(config.recursion_limit, 7);
        assert_eq!(config.resume_from.as_deref(), Some("checkpoint-9"));
        assert_eq!(config.metadata.get("user"), Some(&serde_json::json!("alice")));
    }

    #[test]
    fn test_execution_config_defaults() {
        let first = ExecutionConfig::default();
        let second = ExecutionConfig::default();

        assert_eq!(first.recursion_limit, 50);
        assert!(first.resume_from.is_none());
        assert!(first.metadata.is_empty());
        // Each default config gets its own random thread ID.
        assert_ne!(first.thread_id, second.thread_id);
    }

    #[test]
    fn test_report_progress_without_handle_is_noop() {
        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 3);

        assert!(ctx.progress_handle().is_none());
        ctx.report_progress();
        assert_eq!(ctx.step, 3);
    }
}