adk_graph/
executor.rs

1//! Pregel-based execution engine for graphs
2//!
3//! Executes graphs using the Pregel model with super-steps.
4
5use crate::error::{GraphError, InterruptedExecution, Result};
6use crate::graph::CompiledGraph;
7use crate::interrupt::Interrupt;
8use crate::node::{ExecutionConfig, NodeContext};
9use crate::state::{Checkpoint, State};
10use crate::stream::{StreamEvent, StreamMode};
11use futures::stream::{self, StreamExt};
12use std::time::Instant;
13
14/// Result of a super-step execution
15#[derive(Default)]
16pub struct SuperStepResult {
17    /// Nodes that were executed
18    pub executed_nodes: Vec<String>,
19    /// Interrupt if one occurred
20    pub interrupt: Option<Interrupt>,
21    /// Stream events generated
22    pub events: Vec<StreamEvent>,
23}
24
25/// Pregel-based executor for graphs
26pub struct PregelExecutor<'a> {
27    graph: &'a CompiledGraph,
28    config: ExecutionConfig,
29    state: State,
30    step: usize,
31    pending_nodes: Vec<String>,
32}
33
34impl<'a> PregelExecutor<'a> {
35    /// Create a new executor
36    pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
37        Self { graph, config, state: State::new(), step: 0, pending_nodes: vec![] }
38    }
39
40    /// Run the graph to completion
41    pub async fn run(&mut self, input: State) -> Result<State> {
42        // Initialize state
43        self.state = self.initialize_state(input).await?;
44        self.pending_nodes = self.graph.get_entry_nodes();
45
46        // Main execution loop
47        while !self.pending_nodes.is_empty() {
48            // Check recursion limit
49            if self.step >= self.config.recursion_limit {
50                return Err(GraphError::RecursionLimitExceeded(self.step));
51            }
52
53            // Execute super-step
54            let result = self.execute_super_step().await?;
55
56            // Handle interrupts
57            if let Some(interrupt) = result.interrupt {
58                let checkpoint_id = self.save_checkpoint().await?;
59                return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
60                    self.config.thread_id.clone(),
61                    checkpoint_id,
62                    interrupt,
63                    self.state.clone(),
64                    self.step,
65                ))));
66            }
67
68            // Save checkpoint after each step
69            self.save_checkpoint().await?;
70
71            // Check if we're done (all paths led to END)
72            if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
73                let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
74                if next.is_empty() {
75                    break;
76                }
77            }
78
79            // Determine next nodes
80            self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
81            self.step += 1;
82        }
83
84        Ok(self.state.clone())
85    }
86
87    /// Run with streaming
88    pub fn run_stream(
89        mut self,
90        input: State,
91        mode: StreamMode,
92    ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
93        async_stream::stream! {
94            // Initialize state
95            match self.initialize_state(input).await {
96                Ok(state) => self.state = state,
97                Err(e) => {
98                    yield Err(e);
99                    return;
100                }
101            }
102            self.pending_nodes = self.graph.get_entry_nodes();
103
104            // Stream initial state if requested
105            if matches!(mode, StreamMode::Values) {
106                yield Ok(StreamEvent::state(self.state.clone(), self.step));
107            }
108
109            // Main execution loop
110            while !self.pending_nodes.is_empty() {
111                // Check recursion limit
112                if self.step >= self.config.recursion_limit {
113                    yield Err(GraphError::RecursionLimitExceeded(self.step));
114                    return;
115                }
116
117                // Execute super-step
118                let result = match self.execute_super_step().await {
119                    Ok(r) => r,
120                    Err(e) => {
121                        yield Err(e);
122                        return;
123                    }
124                };
125
126                // Yield events based on mode
127                for event in &result.events {
128                    match mode {
129                        StreamMode::Custom => yield Ok(event.clone()),
130                        StreamMode::Debug => yield Ok(event.clone()),
131                        _ => {}
132                    }
133                }
134
135                // Yield state/updates
136                match mode {
137                    StreamMode::Values => {
138                        yield Ok(StreamEvent::state(self.state.clone(), self.step));
139                    }
140                    StreamMode::Updates => {
141                        yield Ok(StreamEvent::step_complete(
142                            self.step,
143                            result.executed_nodes.clone(),
144                        ));
145                    }
146                    _ => {}
147                }
148
149                // Handle interrupts
150                if let Some(interrupt) = result.interrupt {
151                    yield Ok(StreamEvent::interrupted(
152                        result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
153                        &interrupt.to_string(),
154                    ));
155                    return;
156                }
157
158                // Check if done
159                if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
160                    let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
161                    if next.is_empty() {
162                        break;
163                    }
164                }
165
166                self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
167                self.step += 1;
168            }
169
170            yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
171        }
172    }
173
174    /// Initialize state from input and/or checkpoint
175    async fn initialize_state(&self, input: State) -> Result<State> {
176        // Start with schema defaults
177        let mut state = self.graph.schema.initialize_state();
178
179        // If resuming from checkpoint, load it
180        if let Some(checkpoint_id) = &self.config.resume_from {
181            if let Some(cp) = self.graph.checkpointer.as_ref() {
182                if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
183                    state = checkpoint.state;
184                }
185            }
186        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
187            // Try to load latest checkpoint for thread
188            if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
189                state = checkpoint.state;
190            }
191        }
192
193        // Merge input into state
194        for (key, value) in input {
195            self.graph.schema.apply_update(&mut state, &key, value);
196        }
197
198        Ok(state)
199    }
200
201    /// Execute one super-step (plan -> execute -> update)
202    async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
203        let mut result = SuperStepResult::default();
204
205        // Check for interrupt_before
206        for node_name in &self.pending_nodes {
207            if self.graph.interrupt_before.contains(node_name) {
208                return Ok(SuperStepResult {
209                    interrupt: Some(Interrupt::Before(node_name.clone())),
210                    ..Default::default()
211                });
212            }
213        }
214
215        // Execute all pending nodes in parallel
216        let nodes: Vec<_> = self
217            .pending_nodes
218            .iter()
219            .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
220            .collect();
221
222        let futures: Vec<_> = nodes
223            .into_iter()
224            .map(|(name, node)| {
225                let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
226                let step = self.step;
227                async move {
228                    let start = Instant::now();
229                    let output = node.execute(&ctx).await;
230                    let duration_ms = start.elapsed().as_millis() as u64;
231                    (name, output, duration_ms, step)
232                }
233            })
234            .collect();
235
236        let outputs: Vec<_> =
237            stream::iter(futures).buffer_unordered(self.pending_nodes.len()).collect().await;
238
239        // Collect all updates and check for errors/interrupts
240        let mut all_updates = Vec::new();
241
242        for (node_name, output_result, duration_ms, step) in outputs {
243            result.executed_nodes.push(node_name.clone());
244            result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
245
246            match output_result {
247                Ok(output) => {
248                    // Check for dynamic interrupt
249                    if let Some(interrupt) = output.interrupt {
250                        return Ok(SuperStepResult {
251                            interrupt: Some(interrupt),
252                            executed_nodes: result.executed_nodes,
253                            events: result.events,
254                        });
255                    }
256
257                    // Collect custom events
258                    result.events.extend(output.events);
259
260                    // Collect updates
261                    all_updates.push(output.updates);
262                }
263                Err(e) => {
264                    return Err(GraphError::NodeExecutionFailed {
265                        node: node_name,
266                        message: e.to_string(),
267                    });
268                }
269            }
270        }
271
272        // Apply all updates atomically using reducers
273        for updates in all_updates {
274            for (key, value) in updates {
275                self.graph.schema.apply_update(&mut self.state, &key, value);
276            }
277        }
278
279        // Check for interrupt_after
280        for node_name in &result.executed_nodes {
281            if self.graph.interrupt_after.contains(node_name) {
282                return Ok(SuperStepResult {
283                    interrupt: Some(Interrupt::After(node_name.clone())),
284                    ..result
285                });
286            }
287        }
288
289        Ok(result)
290    }
291
292    /// Save a checkpoint
293    async fn save_checkpoint(&self) -> Result<String> {
294        if let Some(cp) = &self.graph.checkpointer {
295            let checkpoint = Checkpoint::new(
296                &self.config.thread_id,
297                self.state.clone(),
298                self.step,
299                self.pending_nodes.clone(),
300            );
301            return cp.save(&checkpoint).await;
302        }
303        Ok(String::new())
304    }
305}
306
307/// Convenience methods for CompiledGraph
308impl CompiledGraph {
309    /// Execute the graph synchronously
310    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
311        let mut executor = PregelExecutor::new(self, config);
312        executor.run(input).await
313    }
314
315    /// Execute with streaming
316    pub fn stream(
317        &self,
318        input: State,
319        config: ExecutionConfig,
320        mode: StreamMode,
321    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
322        let executor = PregelExecutor::new(self, config);
323        executor.run_stream(input, mode)
324    }
325
326    /// Get current state for a thread
327    pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
328        if let Some(cp) = &self.checkpointer {
329            Ok(cp.load(thread_id).await?.map(|c| c.state))
330        } else {
331            Ok(None)
332        }
333    }
334
335    /// Update state for a thread (for human-in-the-loop)
336    pub async fn update_state(
337        &self,
338        thread_id: &str,
339        updates: impl IntoIterator<Item = (String, serde_json::Value)>,
340    ) -> Result<()> {
341        if let Some(cp) = &self.checkpointer {
342            if let Some(checkpoint) = cp.load(thread_id).await? {
343                let mut state = checkpoint.state;
344                for (key, value) in updates {
345                    self.schema.apply_update(&mut state, &key, value);
346                }
347                let new_checkpoint =
348                    Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
349                cp.save(&new_checkpoint).await?;
350            }
351        }
352        Ok(())
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    /// Execution config shared by every test; all runs use the "test" thread id.
    fn test_config() -> ExecutionConfig {
        ExecutionConfig::new("test")
    }

    #[tokio::test]
    async fn test_simple_execution() {
        let compiled = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let final_state = compiled.invoke(State::new(), test_config()).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }

    #[tokio::test]
    async fn test_sequential_execution() {
        let compiled = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |ctx| async move {
                // step2 must observe the value written by step1 in the previous super-step.
                let previous = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(previous + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let final_state = compiled.invoke(State::new(), test_config()).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(11)));
    }

    #[tokio::test]
    async fn test_conditional_routing() {
        let compiled = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |ctx| async move {
                let chosen = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(chosen)))
            })
            .add_node_fn("path_a", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        // Each requested path must land in its own branch.
        for (path, expected) in [("a", "went to A"), ("b", "went to B")] {
            let mut input = State::new();
            input.insert("path".to_string(), json!(path));

            let final_state = compiled.invoke(input, test_config()).await.unwrap();

            assert_eq!(final_state.get("result"), Some(&json!(expected)));
        }
    }

    #[tokio::test]
    async fn test_cycle_with_limit() {
        let compiled = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let current = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(current + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |state| {
                    let current = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                    let target = if current < 5 { "increment" } else { END };
                    target.to_string()
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let final_state = compiled.invoke(State::new(), test_config()).await.unwrap();

        assert_eq!(final_state.get("count"), Some(&json!(5)));
    }

    #[tokio::test]
    async fn test_recursion_limit() {
        let compiled = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let current = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(current + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop") // never reaches END
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let outcome = compiled.invoke(State::new(), test_config()).await;

        // The limit is checked as `step >= limit`, so the 11th super-step is rejected.
        assert!(
            matches!(outcome, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            outcome
        );
    }
}