//! Pregel-based execution engine for graphs
//!
//! Executes graphs using the Pregel model with super-steps.

use crate::error::{GraphError, InterruptedExecution, Result};
use crate::graph::CompiledGraph;
use crate::interrupt::Interrupt;
use crate::node::{ExecutionConfig, NodeContext};
use crate::state::{Checkpoint, State};
use crate::stream::{StreamEvent, StreamMode};
use futures::stream::{self, StreamExt};
use std::time::Instant;

/// Result of a single super-step execution.
///
/// Aggregates everything produced by one round of node execution: which nodes
/// ran, any interrupt that was raised (before, after, or during execution),
/// and the stream events to emit to a streaming consumer.
#[derive(Default)]
pub struct SuperStepResult {
    /// Names of the nodes executed in this super-step.
    pub executed_nodes: Vec<String>,
    /// Interrupt if one occurred; when set, execution of the graph stops.
    pub interrupt: Option<Interrupt>,
    /// Stream events generated during this super-step (node lifecycle + custom).
    pub events: Vec<StreamEvent>,
}

/// Pregel-based executor for graphs.
///
/// Drives a [`CompiledGraph`] through repeated super-steps, maintaining the
/// accumulated state, the current step counter, and the frontier of nodes
/// scheduled to run next.
pub struct PregelExecutor<'a> {
    /// Graph being executed (provides nodes, edges, schema, checkpointer).
    graph: &'a CompiledGraph,
    /// Per-run configuration (thread ID, recursion limit, resume target, ...).
    config: ExecutionConfig,
    /// Current accumulated state, updated after each super-step.
    state: State,
    /// Super-step counter; compared against `config.recursion_limit`.
    step: usize,
    /// Nodes scheduled to execute in the next super-step.
    pending_nodes: Vec<String>,
}

impl<'a> PregelExecutor<'a> {
35    /// Create a new executor
36    pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
37        Self { graph, config, state: State::new(), step: 0, pending_nodes: vec![] }
38    }
39
40    /// Attempt to resume from an existing checkpoint.
41    ///
42    /// If a checkpoint is found (either by explicit `resume_from` ID or by latest
43    /// checkpoint for the thread), restores state, pending_nodes, and step from it,
44    /// then merges the provided input on top. Returns `true` if resumed.
45    ///
46    /// If no checkpoint is found, returns `false` so the caller can proceed with
47    /// fresh-start logic.
48    async fn try_resume_from_checkpoint(&mut self, input: &State) -> Result<bool> {
49        let checkpoint = if let Some(checkpoint_id) = &self.config.resume_from {
50            // Resume from a specific checkpoint by ID
51            if let Some(cp) = self.graph.checkpointer.as_ref() {
52                cp.load_by_id(checkpoint_id).await?
53            } else {
54                None
55            }
56        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
57            // Try to load the latest checkpoint for this thread
58            cp.load(&self.config.thread_id).await?
59        } else {
60            None
61        };
62
63        if let Some(checkpoint) = checkpoint {
64            // Restore state from checkpoint
65            self.state = checkpoint.state;
66            self.pending_nodes = checkpoint.pending_nodes;
67            self.step = checkpoint.step;
68
69            // Merge input on top of restored state
70            for (key, value) in input {
71                self.graph.schema.apply_update(&mut self.state, key, value.clone());
72            }
73
74            Ok(true)
75        } else {
76            Ok(false)
77        }
78    }
79
80    /// Run the graph to completion
81    pub async fn run(&mut self, input: State) -> Result<State> {
82        // Check for existing checkpoint to resume from
83        let resumed = self.try_resume_from_checkpoint(&input).await?;
84
85        if !resumed {
86            // No checkpoint found — fresh start
87            self.state = self.initialize_state(input).await?;
88            self.pending_nodes = self.graph.get_entry_nodes();
89        }
90
91        // Main execution loop
92        while !self.pending_nodes.is_empty() {
93            // Check recursion limit
94            if self.step >= self.config.recursion_limit {
95                return Err(GraphError::RecursionLimitExceeded(self.step));
96            }
97
98            // Execute super-step
99            let result = self.execute_super_step().await?;
100
101            // Handle interrupts
102            if let Some(interrupt) = result.interrupt {
103                let checkpoint_id = self.save_checkpoint().await?;
104                return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
105                    self.config.thread_id.clone(),
106                    checkpoint_id,
107                    interrupt,
108                    self.state.clone(),
109                    self.step,
110                ))));
111            }
112
113            // Save checkpoint after each step
114            self.save_checkpoint().await?;
115
116            // Check if we're done (all paths led to END)
117            if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
118                let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
119                if next.is_empty() {
120                    break;
121                }
122            }
123
124            // Determine next nodes
125            self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
126            self.step += 1;
127        }
128
129        Ok(self.state.clone())
130    }
131
    /// Run the graph with streaming, yielding [`StreamEvent`]s as execution
    /// progresses.
    ///
    /// Consumes the executor. `mode` controls which events are emitted:
    /// `Values` streams full state snapshots, `Updates` streams step summaries,
    /// `Debug`/`Custom` stream node lifecycle and custom events, and `Messages`
    /// forwards message events directly from nodes as they are produced.
    ///
    /// NOTE(review): unlike [`Self::run`], this path never saves checkpoints
    /// after a step, and the `Messages` branch performs no interrupt handling —
    /// confirm whether that asymmetry is intended.
    pub fn run_stream(
        mut self,
        input: State,
        mode: StreamMode,
    ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
        async_stream::stream! {
            // Check for an existing checkpoint to resume from.
            let resumed = match self.try_resume_from_checkpoint(&input).await {
                Ok(r) => r,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            if resumed {
                // Emit a resumed event indicating execution was restored from checkpoint.
                yield Ok(StreamEvent::resumed(self.step, self.pending_nodes.clone()));
            } else {
                // No checkpoint found — fresh start.
                match self.initialize_state(input).await {
                    Ok(state) => self.state = state,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
                self.pending_nodes = self.graph.get_entry_nodes();
            }

            // Stream the initial state snapshot if requested.
            if matches!(mode, StreamMode::Values) {
                yield Ok(StreamEvent::state(self.state.clone(), self.step));
            }

            // Main execution loop: one iteration per super-step.
            while !self.pending_nodes.is_empty() {
                // Check the recursion limit before running another step.
                if self.step >= self.config.recursion_limit {
                    yield Err(GraphError::RecursionLimitExceeded(self.step));
                    return;
                }

                // Emit node_start events BEFORE execution (Debug/Custom/Messages modes).
                if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
                    for node_name in &self.pending_nodes {
                        yield Ok(StreamEvent::node_start(node_name, self.step));
                    }
                }

                // For Messages mode, stream from nodes directly instead of going
                // through execute_super_step. Nodes run sequentially here (not in
                // parallel) so their events can be forwarded live.
                if matches!(mode, StreamMode::Messages) {
                    let mut result = SuperStepResult::default();

                    for node_name in &self.pending_nodes {
                        if let Some(node) = self.graph.nodes.get(node_name) {
                            let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
                            let start = std::time::Instant::now();

                            let mut node_stream = node.execute_stream(&ctx);
                            let mut collected_events = Vec::new();

                            while let Some(event_result) = node_stream.next().await {
                                match event_result {
                                    Ok(event) => {
                                        // Yield Message events immediately; everything
                                        // else is collected and emitted after the step.
                                        if matches!(event, StreamEvent::Message { .. }) {
                                            yield Ok(event.clone());
                                        }
                                        collected_events.push(event);
                                    }
                                    Err(e) => {
                                        yield Err(e);
                                        return;
                                    }
                                }
                            }

                            let duration_ms = start.elapsed().as_millis() as u64;
                            result.executed_nodes.push(node_name.clone());
                            result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
                            result.events.extend(collected_events);

                            // Get output from execute for state updates.
                            // NOTE(review): this runs the node a SECOND time (after
                            // execute_stream above) just to obtain its updates, so a
                            // side-effectful node executes twice in Messages mode, and
                            // an error here is silently ignored. Confirm intended.
                            if let Ok(output) = node.execute(&ctx).await {
                                for (key, value) in output.updates {
                                    self.graph.schema.apply_update(&mut self.state, &key, value);
                                }
                            }
                        }
                    }

                    // Yield the node_end events recorded above.
                    for event in &result.events {
                        if matches!(event, StreamEvent::NodeEnd { .. }) {
                            yield Ok(event.clone());
                        }
                    }

                    self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                    self.step += 1;
                    continue;
                }

                // Execute super-step (non-streaming execution of all pending nodes).
                let result = match self.execute_super_step().await {
                    Ok(r) => r,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                };

                // Yield events based on mode (node_end and custom events).
                for event in &result.events {
                    match (&mode, &event) {
                        // Skip node_start since we already emitted it above.
                        (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
                        (StreamMode::Custom, _) => yield Ok(event.clone()),
                        (StreamMode::Debug, _) => yield Ok(event.clone()),
                        _ => {}
                    }
                }

                // Yield a state snapshot (Values) or a step summary (Updates).
                match mode {
                    StreamMode::Values => {
                        yield Ok(StreamEvent::state(self.state.clone(), self.step));
                    }
                    StreamMode::Updates => {
                        yield Ok(StreamEvent::step_complete(
                            self.step,
                            result.executed_nodes.clone(),
                        ));
                    }
                    _ => {}
                }

                // Handle interrupts: emit an interrupted event and stop. No
                // checkpoint is saved here (unlike `run`).
                if let Some(interrupt) = result.interrupt {
                    yield Ok(StreamEvent::interrupted(
                        result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
                        &interrupt.to_string(),
                    ));
                    return;
                }

                // Check if done (all paths led to END and nothing left to schedule).
                if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
                    let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                    if next.is_empty() {
                        break;
                    }
                }

                self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                self.step += 1;
            }

            // Final event carries the terminal state; `step + 1` reports the
            // number of completed steps rather than the last step index.
            yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
        }
    }

296    /// Initialize state from input and/or checkpoint
297    async fn initialize_state(&self, input: State) -> Result<State> {
298        // Start with schema defaults
299        let mut state = self.graph.schema.initialize_state();
300
301        // If resuming from checkpoint, load it
302        if let Some(checkpoint_id) = &self.config.resume_from {
303            if let Some(cp) = self.graph.checkpointer.as_ref() {
304                if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
305                    state = checkpoint.state;
306                }
307            }
308        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
309            // Try to load latest checkpoint for thread
310            if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
311                state = checkpoint.state;
312            }
313        }
314
315        // Merge input into state
316        for (key, value) in input {
317            self.graph.schema.apply_update(&mut state, &key, value);
318        }
319
320        Ok(state)
321    }
322
    /// Execute one super-step (plan -> execute -> update).
    ///
    /// Runs every pending node concurrently against a snapshot of the current
    /// state, then applies all collected updates through the schema reducers.
    /// Returns the executed node names, any interrupt, and the events produced.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::NodeExecutionFailed`] if any node errors; in that
    /// case none of this super-step's state updates are applied (updates are
    /// only applied after every node has completed successfully).
    async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
        let mut result = SuperStepResult::default();

        // Check for interrupt_before: abort before running anything if any
        // pending node is marked (no nodes execute in this case).
        for node_name in &self.pending_nodes {
            if self.graph.interrupt_before.contains(node_name) {
                return Ok(SuperStepResult {
                    interrupt: Some(Interrupt::Before(node_name.clone())),
                    ..Default::default()
                });
            }
        }

        // Resolve pending node names to handles; unknown names are silently skipped.
        let nodes: Vec<_> = self
            .pending_nodes
            .iter()
            .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
            .collect();

        // Each node gets a context built from a snapshot of the current state,
        // so nodes within one super-step cannot observe each other's updates.
        let futures: Vec<_> = nodes
            .into_iter()
            .map(|(name, node)| {
                let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
                let step = self.step;
                async move {
                    let start = Instant::now();
                    let output = node.execute(&ctx).await;
                    let duration_ms = start.elapsed().as_millis() as u64;
                    (name, output, duration_ms, step)
                }
            })
            .collect();

        // Run all node futures concurrently; results arrive in completion order.
        let outputs: Vec<_> =
            stream::iter(futures).buffer_unordered(self.pending_nodes.len()).collect().await;

        // Collect all updates and check for errors/interrupts.
        let mut all_updates = Vec::new();

        for (node_name, output_result, duration_ms, step) in outputs {
            result.executed_nodes.push(node_name.clone());
            result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));

            match output_result {
                Ok(output) => {
                    // A dynamic interrupt short-circuits: updates collected so
                    // far (including this node's) are discarded, not applied.
                    if let Some(interrupt) = output.interrupt {
                        return Ok(SuperStepResult {
                            interrupt: Some(interrupt),
                            executed_nodes: result.executed_nodes,
                            events: result.events,
                        });
                    }

                    // Collect custom events emitted by the node.
                    result.events.extend(output.events);

                    // Defer updates so they are applied atomically once every
                    // node in the step has succeeded.
                    all_updates.push(output.updates);
                }
                Err(e) => {
                    return Err(GraphError::NodeExecutionFailed {
                        node: node_name,
                        message: e.to_string(),
                    });
                }
            }
        }

        // Apply all updates atomically using the schema's reducers.
        for updates in all_updates {
            for (key, value) in updates {
                self.graph.schema.apply_update(&mut self.state, &key, value);
            }
        }

        // Check for interrupt_after: surfaced together with the step's results
        // (the updates above have already been applied at this point).
        for node_name in &result.executed_nodes {
            if self.graph.interrupt_after.contains(node_name) {
                return Ok(SuperStepResult {
                    interrupt: Some(Interrupt::After(node_name.clone())),
                    ..result
                });
            }
        }

        Ok(result)
    }

414    /// Save a checkpoint
415    async fn save_checkpoint(&self) -> Result<String> {
416        if let Some(cp) = &self.graph.checkpointer {
417            let checkpoint = Checkpoint::new(
418                &self.config.thread_id,
419                self.state.clone(),
420                self.step,
421                self.pending_nodes.clone(),
422            );
423            return cp.save(&checkpoint).await;
424        }
425        Ok(String::new())
426    }
}

429/// Convenience methods for CompiledGraph
430impl CompiledGraph {
431    /// Execute the graph synchronously
432    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
433        let mut executor = PregelExecutor::new(self, config);
434        executor.run(input).await
435    }
436
437    /// Execute with streaming
438    pub fn stream(
439        &self,
440        input: State,
441        config: ExecutionConfig,
442        mode: StreamMode,
443    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
444        tracing::debug!("CompiledGraph::stream called with mode {:?}", mode);
445        let executor = PregelExecutor::new(self, config);
446        executor.run_stream(input, mode)
447    }
448
449    /// Get current state for a thread
450    pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
451        if let Some(cp) = &self.checkpointer {
452            Ok(cp.load(thread_id).await?.map(|c| c.state))
453        } else {
454            Ok(None)
455        }
456    }
457
458    /// Update state for a thread (for human-in-the-loop)
459    pub async fn update_state(
460        &self,
461        thread_id: &str,
462        updates: impl IntoIterator<Item = (String, serde_json::Value)>,
463    ) -> Result<()> {
464        if let Some(cp) = &self.checkpointer {
465            if let Some(checkpoint) = cp.load(thread_id).await? {
466                let mut state = checkpoint.state;
467                for (key, value) in updates {
468                    self.schema.apply_update(&mut state, &key, value);
469                }
470                let new_checkpoint =
471                    Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
472                cp.save(&new_checkpoint).await?;
473            }
474        }
475        Ok(())
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    /// A single node writes a constant value and the graph terminates.
    #[tokio::test]
    async fn test_simple_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let final_state =
            graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }

    /// Two nodes run in sequence; the second reads the first's output.
    #[tokio::test]
    async fn test_sequential_execution() {
        let builder = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |ctx| async move {
                let current = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(current + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END);
        let graph = builder.compile().unwrap();

        let final_state =
            graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(11)));
    }

    /// Conditional edges route execution to different nodes based on state.
    #[tokio::test]
    async fn test_conditional_routing() {
        let graph = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |ctx| async move {
                let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(path)))
            })
            .add_node_fn("path_a", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |s| s.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        // Route A.
        let mut state_in = State::new();
        state_in.insert("path".to_string(), json!("a"));
        let out = graph.invoke(state_in, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(out.get("result"), Some(&json!("went to A")));

        // Route B.
        let mut state_in = State::new();
        state_in.insert("path".to_string(), json!("b"));
        let out = graph.invoke(state_in, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(out.get("result"), Some(&json!("went to B")));
    }

    /// A self-loop guarded by a conditional edge stops once the counter hits 5.
    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |s| {
                    let count = s.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                    if count < 5 { "increment".to_string() } else { END.to_string() }
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let final_state =
            graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(final_state.get("count"), Some(&json!(5)));
    }

    /// An unconditional self-loop must trip the configured recursion limit.
    #[tokio::test]
    async fn test_recursion_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop") // Infinite loop
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let outcome = graph.invoke(State::new(), ExecutionConfig::new("test")).await;

        // The recursion limit check fires when step >= limit.
        assert!(
            matches!(outcome, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            outcome
        );
    }
}