Skip to main content

adk_graph/
node.rs

//! Node types for graph execution
//!
//! Nodes are the computational units in a graph. They receive state and return updates.
4
5use crate::error::Result;
6use crate::interrupt::Interrupt;
7use crate::state::State;
8use crate::stream::StreamEvent;
9use crate::timeout::ProgressHandle;
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16
17/// Configuration passed to nodes during execution
18#[derive(Clone)]
19pub struct ExecutionConfig {
20    /// Thread identifier for checkpointing
21    pub thread_id: String,
22    /// Resume from a specific checkpoint
23    pub resume_from: Option<String>,
24    /// Recursion limit for cycles
25    pub recursion_limit: usize,
26    /// Additional configuration
27    pub metadata: HashMap<String, Value>,
28}
29
30impl ExecutionConfig {
31    /// Create a new config with the given thread ID
32    pub fn new(thread_id: &str) -> Self {
33        Self {
34            thread_id: thread_id.to_string(),
35            resume_from: None,
36            recursion_limit: 50,
37            metadata: HashMap::new(),
38        }
39    }
40
41    /// Set the recursion limit
42    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
43        self.recursion_limit = limit;
44        self
45    }
46
47    /// Resume from a specific checkpoint
48    pub fn with_resume_from(mut self, checkpoint_id: &str) -> Self {
49        self.resume_from = Some(checkpoint_id.to_string());
50        self
51    }
52
53    /// Add metadata
54    pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
55        self.metadata.insert(key.to_string(), value);
56        self
57    }
58}
59
60impl Default for ExecutionConfig {
61    fn default() -> Self {
62        Self::new(&uuid::Uuid::new_v4().to_string())
63    }
64}
65
66/// Context passed to nodes during execution
67pub struct NodeContext {
68    /// Current graph state (read-only view)
69    pub state: State,
70    /// Configuration for this execution
71    pub config: ExecutionConfig,
72    /// Current step number
73    pub step: usize,
74    /// Optional progress handle for idle timeout tracking.
75    /// When present, calling [`report_progress()`](Self::report_progress) resets the idle timeout counter.
76    progress_handle: Option<ProgressHandle>,
77}
78
79impl NodeContext {
80    /// Create a new node context
81    pub fn new(state: State, config: ExecutionConfig, step: usize) -> Self {
82        Self { state, config, step, progress_handle: None }
83    }
84
85    /// Get a value from state
86    pub fn get(&self, key: &str) -> Option<&Value> {
87        self.state.get(key)
88    }
89
90    /// Get a value from state as a specific type
91    pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
92        self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
93    }
94
95    /// Report progress, resetting the idle timeout counter.
96    ///
97    /// Nodes performing long-running work should call this periodically to
98    /// prevent the idle timeout from firing. If no progress handle is attached
99    /// (e.g., when no idle timeout is configured), this is a no-op.
100    ///
101    /// # Example
102    ///
103    /// ```rust,ignore
104    /// async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
105    ///     for chunk in large_dataset.chunks(100) {
106    ///         process(chunk).await;
107    ///         ctx.report_progress(); // reset idle timeout
108    ///     }
109    ///     Ok(NodeOutput::new())
110    /// }
111    /// ```
112    pub fn report_progress(&self) {
113        if let Some(handle) = &self.progress_handle {
114            handle.report_progress();
115        }
116    }
117
118    /// Attach a progress handle for idle timeout tracking.
119    ///
120    /// This is called by the executor before running a node with an idle timeout
121    /// policy. Nodes do not need to call this directly.
122    pub fn set_progress_handle(&mut self, handle: ProgressHandle) {
123        self.progress_handle = Some(handle);
124    }
125
126    /// Get a reference to the attached progress handle, if any.
127    pub fn progress_handle(&self) -> Option<&ProgressHandle> {
128        self.progress_handle.as_ref()
129    }
130}
131
132/// Output from a node execution
133#[derive(Default)]
134pub struct NodeOutput {
135    /// State updates to apply
136    pub updates: HashMap<String, Value>,
137    /// Optional interrupt request
138    pub interrupt: Option<Interrupt>,
139    /// Custom stream events
140    pub events: Vec<StreamEvent>,
141}
142
143impl NodeOutput {
144    /// Create a new empty output
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// Add a state update
150    pub fn with_update(mut self, key: &str, value: impl Into<Value>) -> Self {
151        self.updates.insert(key.to_string(), value.into());
152        self
153    }
154
155    /// Add multiple state updates
156    pub fn with_updates(mut self, updates: HashMap<String, Value>) -> Self {
157        self.updates.extend(updates);
158        self
159    }
160
161    /// Set an interrupt
162    pub fn with_interrupt(mut self, interrupt: Interrupt) -> Self {
163        self.interrupt = Some(interrupt);
164        self
165    }
166
167    /// Add a custom stream event
168    pub fn with_event(mut self, event: StreamEvent) -> Self {
169        self.events.push(event);
170        self
171    }
172
173    /// Create output that triggers a dynamic interrupt
174    pub fn interrupt(message: &str) -> Self {
175        Self::new().with_interrupt(crate::interrupt::interrupt(message))
176    }
177
178    /// Create output that triggers a dynamic interrupt with data
179    pub fn interrupt_with_data(message: &str, data: Value) -> Self {
180        Self::new().with_interrupt(crate::interrupt::interrupt_with_data(message, data))
181    }
182}
183
184/// A node in the graph
185#[async_trait]
186pub trait Node: Send + Sync {
187    /// Node identifier
188    fn name(&self) -> &str;
189
190    /// Execute the node and return state updates
191    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput>;
192
193    /// Stream execution events (default: wraps execute)
194    fn execute_stream<'a>(
195        &'a self,
196        ctx: &'a NodeContext,
197    ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
198        let _name = self.name().to_string();
199        Box::pin(async_stream::stream! {
200            match self.execute(ctx).await {
201                Ok(output) => {
202                    for event in output.events {
203                        yield Ok(event);
204                    }
205                }
206                Err(e) => yield Err(e),
207            }
208        })
209    }
210}
211
212/// Type alias for boxed node
213pub type BoxedNode = Box<dyn Node>;
214
215/// Type alias for async function signature
216pub type AsyncNodeFn = Box<
217    dyn Fn(NodeContext) -> Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send>> + Send + Sync,
218>;
219
220/// Function node - wraps an async function as a node
221pub struct FunctionNode {
222    name: String,
223    func: AsyncNodeFn,
224}
225
226impl FunctionNode {
227    /// Create a new function node
228    pub fn new<F, Fut>(name: &str, func: F) -> Self
229    where
230        F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
231        Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
232    {
233        Self { name: name.to_string(), func: Box::new(move |ctx| Box::pin(func(ctx))) }
234    }
235}
236
237#[async_trait]
238impl Node for FunctionNode {
239    fn name(&self) -> &str {
240        &self.name
241    }
242
243    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
244        let mut ctx_owned = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
245        if let Some(handle) = ctx.progress_handle() {
246            ctx_owned.set_progress_handle(handle.clone());
247        }
248        (self.func)(ctx_owned).await
249    }
250}
251
/// Node that produces no output of its own, leaving state untouched.
pub struct PassthroughNode {
    /// Identifier reported by `Node::name`
    name: String,
}

impl PassthroughNode {
    /// Build a passthrough node called `name`.
    pub fn new(name: &str) -> Self {
        let name = name.to_owned();
        Self { name }
    }
}
263
264#[async_trait]
265impl Node for PassthroughNode {
266    fn name(&self) -> &str {
267        &self.name
268    }
269
270    async fn execute(&self, _ctx: &NodeContext) -> Result<NodeOutput> {
271        Ok(NodeOutput::new())
272    }
273}
274
275/// Type alias for agent node input mapper
276pub type AgentInputMapper = Box<dyn Fn(&State) -> adk_core::Content + Send + Sync>;
277
278/// Type alias for agent node output mapper
279pub type AgentOutputMapper =
280    Box<dyn Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync>;
281
282/// Wrapper to use an existing ADK Agent as a graph node
283pub struct AgentNode {
284    name: String,
285    #[allow(dead_code)]
286    agent: Arc<dyn adk_core::Agent>,
287    /// Map state to agent input content
288    input_mapper: AgentInputMapper,
289    /// Map agent events to state updates
290    output_mapper: AgentOutputMapper,
291}
292
293impl AgentNode {
294    /// Create a new agent node
295    pub fn new(agent: Arc<dyn adk_core::Agent>) -> Self {
296        let name = agent.name().to_string();
297        Self {
298            name,
299            agent,
300            input_mapper: Box::new(default_input_mapper),
301            output_mapper: Box::new(default_output_mapper),
302        }
303    }
304
305    /// Set custom input mapper
306    pub fn with_input_mapper<F>(mut self, mapper: F) -> Self
307    where
308        F: Fn(&State) -> adk_core::Content + Send + Sync + 'static,
309    {
310        self.input_mapper = Box::new(mapper);
311        self
312    }
313
314    /// Set custom output mapper
315    pub fn with_output_mapper<F>(mut self, mapper: F) -> Self
316    where
317        F: Fn(&[adk_core::Event]) -> HashMap<String, Value> + Send + Sync + 'static,
318    {
319        self.output_mapper = Box::new(mapper);
320        self
321    }
322}
323
324/// Default input mapper - looks for "messages" or "input" in state
325fn default_input_mapper(state: &State) -> adk_core::Content {
326    // Try to get messages first
327    if let Some(messages) = state.get("messages")
328        && let Some(arr) = messages.as_array()
329        && let Some(last) = arr.last()
330        && let Some(content) = last.get("content").and_then(|c| c.as_str())
331    {
332        return adk_core::Content::new("user").with_text(content);
333    }
334
335    // Try input field
336    if let Some(input) = state.get("input")
337        && let Some(text) = input.as_str()
338    {
339        return adk_core::Content::new("user").with_text(text);
340    }
341
342    adk_core::Content::new("user")
343}
344
345/// Default output mapper - extracts text content to "messages"
346fn default_output_mapper(events: &[adk_core::Event]) -> HashMap<String, Value> {
347    let mut updates = HashMap::new();
348
349    // Collect text from events
350    let mut messages = Vec::new();
351    for event in events {
352        if let Some(content) = event.content() {
353            let text = content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join("");
354
355            if !text.is_empty() {
356                messages.push(serde_json::json!({
357                    "role": "assistant",
358                    "content": text
359                }));
360            }
361        }
362    }
363
364    if !messages.is_empty() {
365        updates.insert("messages".to_string(), serde_json::json!(messages));
366    }
367
368    updates
369}
370
371#[async_trait]
372impl Node for AgentNode {
373    fn name(&self) -> &str {
374        &self.name
375    }
376
377    async fn execute(&self, ctx: &NodeContext) -> Result<NodeOutput> {
378        use futures::StreamExt;
379
380        // Map state to input content
381        let content = (self.input_mapper)(&ctx.state);
382
383        // Create a graph invocation context with the agent
384        let invocation_ctx = Arc::new(GraphInvocationContext::new(
385            ctx.config.thread_id.clone(),
386            content,
387            self.agent.clone(),
388        ));
389
390        // Run the agent and collect events
391        let stream = self.agent.run(invocation_ctx).await.map_err(|e| {
392            crate::error::GraphError::NodeExecutionFailed {
393                node: self.name.clone(),
394                message: e.to_string(),
395            }
396        })?;
397
398        let events: Vec<adk_core::Event> = stream.filter_map(|r| async { r.ok() }).collect().await;
399
400        // Map events to state updates
401        let updates = (self.output_mapper)(&events);
402
403        // Convert agent events to stream events for tracing
404        let mut output = NodeOutput::new().with_updates(updates);
405        for event in &events {
406            if let Ok(json) = serde_json::to_value(event) {
407                output = output.with_event(StreamEvent::custom(&self.name, "agent_event", json));
408            }
409        }
410
411        Ok(output)
412    }
413
414    fn execute_stream<'a>(
415        &'a self,
416        ctx: &'a NodeContext,
417    ) -> Pin<Box<dyn futures::Stream<Item = Result<StreamEvent>> + Send + 'a>> {
418        use futures::StreamExt;
419        let name = self.name.clone();
420        let agent = self.agent.clone();
421        let input_mapper = &self.input_mapper;
422        let thread_id = ctx.config.thread_id.clone();
423        let content = (input_mapper)(&ctx.state);
424
425        Box::pin(async_stream::stream! {
426            tracing::debug!("AgentNode::execute_stream called for {}", name);
427            let invocation_ctx = Arc::new(GraphInvocationContext::new(
428                thread_id,
429                content,
430                agent.clone(),
431            ));
432
433            let stream = match agent.run(invocation_ctx).await {
434                Ok(s) => s,
435                Err(e) => {
436                    yield Err(crate::error::GraphError::NodeExecutionFailed {
437                        node: name.clone(),
438                        message: e.to_string(),
439                    });
440                    return;
441                }
442            };
443
444            tokio::pin!(stream);
445            let mut all_events = Vec::new();
446
447            while let Some(result) = stream.next().await {
448                match result {
449                    Ok(event) => {
450                        // Emit streaming event immediately
451                        if let Some(content) = event.content() {
452                            let text: String = content.parts.iter().filter_map(|p| p.text()).collect();
453                            if !text.is_empty() {
454                                yield Ok(StreamEvent::Message {
455                                    node: name.clone(),
456                                    content: text,
457                                    is_final: false,
458                                });
459                            }
460                        }
461                        all_events.push(event);
462                    }
463                    Err(e) => {
464                        yield Err(crate::error::GraphError::NodeExecutionFailed {
465                            node: name.clone(),
466                            message: e.to_string(),
467                        });
468                        return;
469                    }
470                }
471            }
472
473            // Emit final events
474            for event in &all_events {
475                if let Ok(json) = serde_json::to_value(event) {
476                    yield Ok(StreamEvent::custom(&name, "agent_event", json));
477                }
478            }
479        })
480    }
481}
482
483/// Full InvocationContext implementation for running agents within graph nodes
484struct GraphInvocationContext {
485    invocation_id: String,
486    user_content: adk_core::Content,
487    agent: Arc<dyn adk_core::Agent>,
488    session: Arc<GraphSession>,
489    run_config: adk_core::RunConfig,
490    ended: std::sync::atomic::AtomicBool,
491}
492
493impl GraphInvocationContext {
494    fn new(
495        session_id: String,
496        user_content: adk_core::Content,
497        agent: Arc<dyn adk_core::Agent>,
498    ) -> Self {
499        let invocation_id = uuid::Uuid::new_v4().to_string();
500        let session = Arc::new(GraphSession::new(session_id));
501        // Add user content to history
502        session.append_content(user_content.clone());
503        Self {
504            invocation_id,
505            user_content,
506            agent,
507            session,
508            run_config: adk_core::RunConfig::default(),
509            ended: std::sync::atomic::AtomicBool::new(false),
510        }
511    }
512}
513
514// Implement ReadonlyContext (required by CallbackContext)
515impl adk_core::ReadonlyContext for GraphInvocationContext {
516    fn invocation_id(&self) -> &str {
517        &self.invocation_id
518    }
519
520    fn agent_name(&self) -> &str {
521        self.agent.name()
522    }
523
524    fn user_id(&self) -> &str {
525        "graph_user"
526    }
527
528    fn app_name(&self) -> &str {
529        "graph_app"
530    }
531
532    fn session_id(&self) -> &str {
533        &self.session.id
534    }
535
536    fn branch(&self) -> &str {
537        "main"
538    }
539
540    fn user_content(&self) -> &adk_core::Content {
541        &self.user_content
542    }
543}
544
545// Implement CallbackContext (required by InvocationContext)
546#[async_trait]
547impl adk_core::CallbackContext for GraphInvocationContext {
548    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
549        None
550    }
551}
552
553// Implement InvocationContext
554#[async_trait]
555impl adk_core::InvocationContext for GraphInvocationContext {
556    fn agent(&self) -> Arc<dyn adk_core::Agent> {
557        self.agent.clone()
558    }
559
560    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
561        None
562    }
563
564    fn session(&self) -> &dyn adk_core::Session {
565        self.session.as_ref()
566    }
567
568    fn run_config(&self) -> &adk_core::RunConfig {
569        &self.run_config
570    }
571
572    fn end_invocation(&self) {
573        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
574    }
575
576    fn ended(&self) -> bool {
577        self.ended.load(std::sync::atomic::Ordering::SeqCst)
578    }
579}
580
581/// Minimal Session implementation for graph execution
582struct GraphSession {
583    id: String,
584    state: GraphState,
585    history: std::sync::RwLock<Vec<adk_core::Content>>,
586}
587
588impl GraphSession {
589    fn new(id: String) -> Self {
590        Self { id, state: GraphState::new(), history: std::sync::RwLock::new(Vec::new()) }
591    }
592
593    fn append_content(&self, content: adk_core::Content) {
594        if let Ok(mut h) = self.history.write() {
595            h.push(content);
596        }
597    }
598}
599
600impl adk_core::Session for GraphSession {
601    fn id(&self) -> &str {
602        &self.id
603    }
604
605    fn app_name(&self) -> &str {
606        "graph_app"
607    }
608
609    fn user_id(&self) -> &str {
610        "graph_user"
611    }
612
613    fn state(&self) -> &dyn adk_core::State {
614        &self.state
615    }
616
617    fn conversation_history(&self) -> Vec<adk_core::Content> {
618        self.history.read().ok().map(|h| h.clone()).unwrap_or_default()
619    }
620
621    fn append_to_history(&self, content: adk_core::Content) {
622        self.append_content(content);
623    }
624}
625
626/// Minimal State implementation for graph execution
627struct GraphState {
628    data: std::sync::RwLock<std::collections::HashMap<String, serde_json::Value>>,
629}
630
631impl GraphState {
632    fn new() -> Self {
633        Self { data: std::sync::RwLock::new(std::collections::HashMap::new()) }
634    }
635}
636
637impl adk_core::State for GraphState {
638    fn get(&self, key: &str) -> Option<serde_json::Value> {
639        self.data.read().ok()?.get(key).cloned()
640    }
641
642    fn set(&mut self, key: String, value: serde_json::Value) {
643        if let Err(msg) = adk_core::validate_state_key(&key) {
644            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
645            return;
646        }
647        if let Ok(mut data) = self.data.write() {
648            data.insert(key, value);
649        }
650    }
651
652    fn all(&self) -> std::collections::HashMap<String, serde_json::Value> {
653        self.data.read().ok().map(|d| d.clone()).unwrap_or_default()
654    }
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Context over empty state, with a default config, at step 0.
    fn empty_context() -> NodeContext {
        NodeContext::new(State::new(), ExecutionConfig::default(), 0)
    }

    /// A function node reports its name and returns the callback's updates.
    #[tokio::test]
    async fn test_function_node() {
        let node = FunctionNode::new("test", |_ctx| async {
            Ok(NodeOutput::new().with_update("result", json!("success")))
        });
        assert_eq!(node.name(), "test");

        let context = empty_context();
        let produced = node.execute(&context).await.unwrap();

        let expected = json!("success");
        assert_eq!(produced.updates.get("result"), Some(&expected));
    }

    /// A passthrough node yields neither updates nor an interrupt.
    #[tokio::test]
    async fn test_passthrough_node() {
        let context = empty_context();
        let produced = PassthroughNode::new("pass").execute(&context).await.unwrap();

        assert!(produced.updates.is_empty());
        assert!(produced.interrupt.is_none());
    }

    /// `with_update` accepts anything convertible into a JSON value.
    #[test]
    fn test_node_output_builder() {
        let built = NodeOutput::new().with_update("a", 1).with_update("b", "hello");

        for (key, expected) in [("a", json!(1)), ("b", json!("hello"))] {
            assert_eq!(built.updates.get(key), Some(&expected));
        }
    }
}