adk_graph/executor.rs

//! Pregel-based execution engine for graphs
//!
//! Executes graphs using the Pregel model with super-steps.
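//!
//! A minimal end-to-end sketch, mirroring the builder API exercised by the
//! tests at the bottom of this file (the `adk_graph::...` paths assume the
//! module layout used by the `use` statements below; node and channel names
//! are purely illustrative):
//!
//! ```ignore
//! use adk_graph::edge::{END, START};
//! use adk_graph::graph::StateGraph;
//! use adk_graph::node::{ExecutionConfig, NodeOutput};
//! use adk_graph::state::State;
//! use serde_json::json;
//!
//! // Build a one-node graph, compile it, and run it to completion.
//! let graph = StateGraph::with_channels(&["value"])
//!     .add_node_fn("set_value", |_ctx| async {
//!         Ok(NodeOutput::new().with_update("value", json!(42)))
//!     })
//!     .add_edge(START, "set_value")
//!     .add_edge("set_value", END)
//!     .compile()?;
//!
//! let final_state = graph.invoke(State::new(), ExecutionConfig::new("thread-1")).await?;
//! assert_eq!(final_state.get("value"), Some(&json!(42)));
//! ```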

use crate::error::{GraphError, InterruptedExecution, Result};
use crate::graph::CompiledGraph;
use crate::interrupt::Interrupt;
use crate::node::{ExecutionConfig, NodeContext};
use crate::state::{Checkpoint, State};
use crate::stream::{StreamEvent, StreamMode};
use futures::stream::{self, StreamExt};
use std::time::Instant;

/// Result of a super-step execution
#[derive(Default)]
pub struct SuperStepResult {
    /// Nodes that were executed
    pub executed_nodes: Vec<String>,
    /// Interrupt if one occurred
    pub interrupt: Option<Interrupt>,
    /// Stream events generated
    pub events: Vec<StreamEvent>,
}

/// Pregel-based executor for graphs
pub struct PregelExecutor<'a> {
    graph: &'a CompiledGraph,
    config: ExecutionConfig,
    state: State,
    step: usize,
    pending_nodes: Vec<String>,
}

impl<'a> PregelExecutor<'a> {
    /// Create a new executor
    pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
        Self { graph, config, state: State::new(), step: 0, pending_nodes: vec![] }
    }

    /// Run the graph to completion
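    ///
    /// Executes super-steps until no pending nodes remain, saving a checkpoint
    /// after each step (when a checkpointer is configured) and surfacing
    /// interrupts as [`GraphError::Interrupted`].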
    pub async fn run(&mut self, input: State) -> Result<State> {
        // Initialize state
        self.state = self.initialize_state(input).await?;
        self.pending_nodes = self.graph.get_entry_nodes();

        // Main execution loop
        while !self.pending_nodes.is_empty() {
            // Check recursion limit
            if self.step >= self.config.recursion_limit {
                return Err(GraphError::RecursionLimitExceeded(self.step));
            }

            // Execute super-step
            let result = self.execute_super_step().await?;

            // Handle interrupts
            if let Some(interrupt) = result.interrupt {
                let checkpoint_id = self.save_checkpoint().await?;
                return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
                    self.config.thread_id.clone(),
                    checkpoint_id,
                    interrupt,
                    self.state.clone(),
                    self.step,
                ))));
            }

            // Save checkpoint after each step
            self.save_checkpoint().await?;

            // Check if we're done (all paths led to END)
            if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
                let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                if next.is_empty() {
                    break;
                }
            }

            // Determine next nodes
            self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
            self.step += 1;
        }

        Ok(self.state.clone())
    }

    /// Run with streaming
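    ///
    /// Consumes the executor and yields [`StreamEvent`]s as execution
    /// progresses; which events are emitted depends on the requested
    /// [`StreamMode`].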
    pub fn run_stream(
        mut self,
        input: State,
        mode: StreamMode,
    ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
        async_stream::stream! {
            // Initialize state
            match self.initialize_state(input).await {
                Ok(state) => self.state = state,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            }
            self.pending_nodes = self.graph.get_entry_nodes();

            // Stream initial state if requested
            if matches!(mode, StreamMode::Values) {
                yield Ok(StreamEvent::state(self.state.clone(), self.step));
            }

            // Main execution loop
            while !self.pending_nodes.is_empty() {
                // Check recursion limit
                if self.step >= self.config.recursion_limit {
                    yield Err(GraphError::RecursionLimitExceeded(self.step));
                    return;
                }

                // Emit node_start events BEFORE execution (Debug, Custom, and Messages modes)
                if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
                    for node_name in &self.pending_nodes {
                        yield Ok(StreamEvent::node_start(node_name, self.step));
                    }
                }

                // For Messages mode, stream from nodes directly
                if matches!(mode, StreamMode::Messages) {
                    let mut result = SuperStepResult::default();

                    for node_name in &self.pending_nodes {
                        if let Some(node) = self.graph.nodes.get(node_name) {
                            let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
                            let start = Instant::now();

                            let mut node_stream = node.execute_stream(&ctx);
                            let mut collected_events = Vec::new();

                            while let Some(event_result) = node_stream.next().await {
                                match event_result {
                                    Ok(event) => {
                                        // Yield Message events immediately
                                        if matches!(event, StreamEvent::Message { .. }) {
                                            yield Ok(event.clone());
                                        }
                                        collected_events.push(event);
                                    }
                                    Err(e) => {
                                        yield Err(e);
                                        return;
                                    }
                                }
                            }

                            let duration_ms = start.elapsed().as_millis() as u64;
                            result.executed_nodes.push(node_name.clone());
                            result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
                            result.events.extend(collected_events);

                            // Get output from execute for state updates
                            if let Ok(output) = node.execute(&ctx).await {
                                for (key, value) in output.updates {
                                    self.graph.schema.apply_update(&mut self.state, &key, value);
                                }
                            }
                        }
                    }

                    // Yield node_end events
                    for event in &result.events {
                        if matches!(event, StreamEvent::NodeEnd { .. }) {
                            yield Ok(event.clone());
                        }
                    }

                    self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                    self.step += 1;
                    continue;
                }

                // Execute super-step (non-streaming)
                let result = match self.execute_super_step().await {
                    Ok(r) => r,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                };

                // Yield events based on mode (node_end and custom events)
                for event in &result.events {
                    match (&mode, &event) {
                        // Skip node_start since we already emitted it above
                        (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
                        (StreamMode::Custom, _) => yield Ok(event.clone()),
                        (StreamMode::Debug, _) => yield Ok(event.clone()),
                        _ => {}
                    }
                }

                // Yield state/updates
                match mode {
                    StreamMode::Values => {
                        yield Ok(StreamEvent::state(self.state.clone(), self.step));
                    }
                    StreamMode::Updates => {
                        yield Ok(StreamEvent::step_complete(
                            self.step,
                            result.executed_nodes.clone(),
                        ));
                    }
                    _ => {}
                }

                // Handle interrupts
                if let Some(interrupt) = result.interrupt {
                    yield Ok(StreamEvent::interrupted(
                        result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
                        &interrupt.to_string(),
                    ));
                    return;
                }

                // Check if done
                if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
                    let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                    if next.is_empty() {
                        break;
                    }
                }

                self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                self.step += 1;
            }

            yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
        }
    }

    /// Initialize state from input and/or checkpoint
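    ///
    /// Precedence: start from the schema defaults, overlay a checkpoint when
    /// one is available (either the checkpoint named by `resume_from` or the
    /// latest checkpoint for the thread), and finally merge the caller's
    /// input on top through the schema's update logic.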
    async fn initialize_state(&self, input: State) -> Result<State> {
        // Start with schema defaults
        let mut state = self.graph.schema.initialize_state();

        // If resuming from checkpoint, load it
        if let Some(checkpoint_id) = &self.config.resume_from {
            if let Some(cp) = self.graph.checkpointer.as_ref() {
                if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
                    state = checkpoint.state;
                }
            }
        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
            // Try to load latest checkpoint for thread
            if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
                state = checkpoint.state;
            }
        }

        // Merge input into state
        for (key, value) in input {
            self.graph.schema.apply_update(&mut state, &key, value);
        }

        Ok(state)
    }

    /// Execute one super-step (plan -> execute -> update)
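    ///
    /// Phases: check `interrupt_before` for the pending nodes, execute all
    /// pending nodes concurrently, apply the collected updates through the
    /// schema's reducers, then check `interrupt_after` for the executed nodes.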
    async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
        let mut result = SuperStepResult::default();

        // Check for interrupt_before
        for node_name in &self.pending_nodes {
            if self.graph.interrupt_before.contains(node_name) {
                return Ok(SuperStepResult {
                    interrupt: Some(Interrupt::Before(node_name.clone())),
                    ..Default::default()
                });
            }
        }

        // Execute all pending nodes in parallel
        let nodes: Vec<_> = self
            .pending_nodes
            .iter()
            .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
            .collect();

        let futures: Vec<_> = nodes
            .into_iter()
            .map(|(name, node)| {
                let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
                let step = self.step;
                async move {
                    let start = Instant::now();
                    let output = node.execute(&ctx).await;
                    let duration_ms = start.elapsed().as_millis() as u64;
                    (name, output, duration_ms, step)
                }
            })
            .collect();

        let outputs: Vec<_> =
            stream::iter(futures).buffer_unordered(self.pending_nodes.len()).collect().await;

        // Collect all updates and check for errors/interrupts
        let mut all_updates = Vec::new();

        for (node_name, output_result, duration_ms, step) in outputs {
            result.executed_nodes.push(node_name.clone());
            result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));

            match output_result {
                Ok(output) => {
                    // Check for dynamic interrupt
                    if let Some(interrupt) = output.interrupt {
                        return Ok(SuperStepResult {
                            interrupt: Some(interrupt),
                            executed_nodes: result.executed_nodes,
                            events: result.events,
                        });
                    }

                    // Collect custom events
                    result.events.extend(output.events);

                    // Collect updates
                    all_updates.push(output.updates);
                }
                Err(e) => {
                    return Err(GraphError::NodeExecutionFailed {
                        node: node_name,
                        message: e.to_string(),
                    });
                }
            }
        }

        // Apply all updates atomically using reducers
        for updates in all_updates {
            for (key, value) in updates {
                self.graph.schema.apply_update(&mut self.state, &key, value);
            }
        }

        // Check for interrupt_after
        for node_name in &result.executed_nodes {
            if self.graph.interrupt_after.contains(node_name) {
                return Ok(SuperStepResult {
                    interrupt: Some(Interrupt::After(node_name.clone())),
                    ..result
                });
            }
        }

        Ok(result)
    }

    /// Save a checkpoint
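    ///
    /// Returns the id of the newly saved checkpoint, or an empty string when
    /// no checkpointer is configured.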
    async fn save_checkpoint(&self) -> Result<String> {
        if let Some(cp) = &self.graph.checkpointer {
            let checkpoint = Checkpoint::new(
                &self.config.thread_id,
                self.state.clone(),
                self.step,
                self.pending_nodes.clone(),
            );
            return cp.save(&checkpoint).await;
        }
        Ok(String::new())
    }
}

/// Convenience methods for CompiledGraph
impl CompiledGraph {
    /// Execute the graph to completion (non-streaming) and return the final state
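    ///
    /// A minimal sketch (the graph construction itself is shown in the
    /// module-level example and in the tests below; the thread id is
    /// illustrative):
    ///
    /// ```ignore
    /// let final_state = graph.invoke(State::new(), ExecutionConfig::new("thread-1")).await?;
    /// ```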
    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
        let mut executor = PregelExecutor::new(self, config);
        executor.run(input).await
    }

    /// Execute with streaming
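    ///
    /// A sketch of consuming the stream with `futures::StreamExt` (how each
    /// event is handled is up to the caller; the thread id is illustrative):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut events = graph.stream(State::new(), ExecutionConfig::new("thread-1"), StreamMode::Values);
    /// while let Some(event) = events.next().await {
    ///     let event = event?;
    ///     // Inspect the StreamEvent here (state snapshots, node start/end, done, ...)
    /// }
    /// ```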
    pub fn stream(
        &self,
        input: State,
        config: ExecutionConfig,
        mode: StreamMode,
    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
        let executor = PregelExecutor::new(self, config);
        executor.run_stream(input, mode)
    }

    /// Get current state for a thread
    pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
        if let Some(cp) = &self.checkpointer {
            Ok(cp.load(thread_id).await?.map(|c| c.state))
        } else {
            Ok(None)
        }
    }

    /// Update state for a thread (for human-in-the-loop)
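    ///
    /// A sketch of patching a paused thread's state before resuming (the
    /// `"approved"` key is purely illustrative):
    ///
    /// ```ignore
    /// graph
    ///     .update_state("thread-1", [("approved".to_string(), serde_json::json!(true))])
    ///     .await?;
    /// ```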
    pub async fn update_state(
        &self,
        thread_id: &str,
        updates: impl IntoIterator<Item = (String, serde_json::Value)>,
    ) -> Result<()> {
        if let Some(cp) = &self.checkpointer {
            if let Some(checkpoint) = cp.load(thread_id).await? {
                let mut state = checkpoint.state;
                for (key, value) in updates {
                    self.schema.apply_update(&mut state, &key, value);
                }
                let new_checkpoint =
                    Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
                cp.save(&new_checkpoint).await?;
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    #[tokio::test]
    async fn test_simple_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    #[tokio::test]
    async fn test_sequential_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |ctx| async move {
                let current = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(current + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(11)));
    }

    #[tokio::test]
    async fn test_conditional_routing() {
        let graph = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |ctx| async move {
                let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(path)))
            })
            .add_node_fn("path_a", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        // Test path A
        let mut input = State::new();
        input.insert("path".to_string(), json!("a"));
        let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(result.get("result"), Some(&json!("went to A")));

        // Test path B
        let mut input = State::new();
        input.insert("path".to_string(), json!("b"));
        let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(result.get("result"), Some(&json!("went to B")));
    }

    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |state| {
                    let count = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                    if count < 5 { "increment".to_string() } else { END.to_string() }
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("count"), Some(&json!(5)));
    }

    #[tokio::test]
    async fn test_recursion_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop") // Infinite loop
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await;

        // The recursion limit check triggers once step >= limit, so execution errors out at step 10
        assert!(
            matches!(result, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            result
        );
    }
}