1use crate::error::{GraphError, InterruptedExecution, Result};
6use crate::graph::CompiledGraph;
7use crate::interrupt::Interrupt;
8use crate::node::{ExecutionConfig, NodeContext};
9use crate::state::{Checkpoint, State};
10use crate::stream::{StreamEvent, StreamMode};
11use futures::stream::{self, StreamExt};
12use std::time::Instant;
13
14#[derive(Default)]
16pub struct SuperStepResult {
17 pub executed_nodes: Vec<String>,
19 pub interrupt: Option<Interrupt>,
21 pub events: Vec<StreamEvent>,
23}
24
25pub struct PregelExecutor<'a> {
27 graph: &'a CompiledGraph,
28 config: ExecutionConfig,
29 state: State,
30 step: usize,
31 pending_nodes: Vec<String>,
32}
33
34impl<'a> PregelExecutor<'a> {
35 pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
37 Self { graph, config, state: State::new(), step: 0, pending_nodes: vec![] }
38 }
39
40 pub async fn run(&mut self, input: State) -> Result<State> {
42 self.state = self.initialize_state(input).await?;
44 self.pending_nodes = self.graph.get_entry_nodes();
45
46 while !self.pending_nodes.is_empty() {
48 if self.step >= self.config.recursion_limit {
50 return Err(GraphError::RecursionLimitExceeded(self.step));
51 }
52
53 let result = self.execute_super_step().await?;
55
56 if let Some(interrupt) = result.interrupt {
58 let checkpoint_id = self.save_checkpoint().await?;
59 return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
60 self.config.thread_id.clone(),
61 checkpoint_id,
62 interrupt,
63 self.state.clone(),
64 self.step,
65 ))));
66 }
67
68 self.save_checkpoint().await?;
70
71 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
73 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
74 if next.is_empty() {
75 break;
76 }
77 }
78
79 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
81 self.step += 1;
82 }
83
84 Ok(self.state.clone())
85 }
86
87 pub fn run_stream(
89 mut self,
90 input: State,
91 mode: StreamMode,
92 ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
93 async_stream::stream! {
94 match self.initialize_state(input).await {
96 Ok(state) => self.state = state,
97 Err(e) => {
98 yield Err(e);
99 return;
100 }
101 }
102 self.pending_nodes = self.graph.get_entry_nodes();
103
104 if matches!(mode, StreamMode::Values) {
106 yield Ok(StreamEvent::state(self.state.clone(), self.step));
107 }
108
109 while !self.pending_nodes.is_empty() {
111 if self.step >= self.config.recursion_limit {
113 yield Err(GraphError::RecursionLimitExceeded(self.step));
114 return;
115 }
116
117 let result = match self.execute_super_step().await {
119 Ok(r) => r,
120 Err(e) => {
121 yield Err(e);
122 return;
123 }
124 };
125
126 for event in &result.events {
128 match mode {
129 StreamMode::Custom => yield Ok(event.clone()),
130 StreamMode::Debug => yield Ok(event.clone()),
131 _ => {}
132 }
133 }
134
135 match mode {
137 StreamMode::Values => {
138 yield Ok(StreamEvent::state(self.state.clone(), self.step));
139 }
140 StreamMode::Updates => {
141 yield Ok(StreamEvent::step_complete(
142 self.step,
143 result.executed_nodes.clone(),
144 ));
145 }
146 _ => {}
147 }
148
149 if let Some(interrupt) = result.interrupt {
151 yield Ok(StreamEvent::interrupted(
152 result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
153 &interrupt.to_string(),
154 ));
155 return;
156 }
157
158 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
160 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
161 if next.is_empty() {
162 break;
163 }
164 }
165
166 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
167 self.step += 1;
168 }
169
170 yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
171 }
172 }
173
174 async fn initialize_state(&self, input: State) -> Result<State> {
176 let mut state = self.graph.schema.initialize_state();
178
179 if let Some(checkpoint_id) = &self.config.resume_from {
181 if let Some(cp) = self.graph.checkpointer.as_ref() {
182 if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
183 state = checkpoint.state;
184 }
185 }
186 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
187 if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
189 state = checkpoint.state;
190 }
191 }
192
193 for (key, value) in input {
195 self.graph.schema.apply_update(&mut state, &key, value);
196 }
197
198 Ok(state)
199 }
200
201 async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
203 let mut result = SuperStepResult::default();
204
205 for node_name in &self.pending_nodes {
207 if self.graph.interrupt_before.contains(node_name) {
208 return Ok(SuperStepResult {
209 interrupt: Some(Interrupt::Before(node_name.clone())),
210 ..Default::default()
211 });
212 }
213 }
214
215 let nodes: Vec<_> = self
217 .pending_nodes
218 .iter()
219 .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
220 .collect();
221
222 let futures: Vec<_> = nodes
223 .into_iter()
224 .map(|(name, node)| {
225 let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
226 let step = self.step;
227 async move {
228 let start = Instant::now();
229 let output = node.execute(&ctx).await;
230 let duration_ms = start.elapsed().as_millis() as u64;
231 (name, output, duration_ms, step)
232 }
233 })
234 .collect();
235
236 let outputs: Vec<_> =
237 stream::iter(futures).buffer_unordered(self.pending_nodes.len()).collect().await;
238
239 let mut all_updates = Vec::new();
241
242 for (node_name, output_result, duration_ms, step) in outputs {
243 result.executed_nodes.push(node_name.clone());
244 result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
245
246 match output_result {
247 Ok(output) => {
248 if let Some(interrupt) = output.interrupt {
250 return Ok(SuperStepResult {
251 interrupt: Some(interrupt),
252 executed_nodes: result.executed_nodes,
253 events: result.events,
254 });
255 }
256
257 result.events.extend(output.events);
259
260 all_updates.push(output.updates);
262 }
263 Err(e) => {
264 return Err(GraphError::NodeExecutionFailed {
265 node: node_name,
266 message: e.to_string(),
267 });
268 }
269 }
270 }
271
272 for updates in all_updates {
274 for (key, value) in updates {
275 self.graph.schema.apply_update(&mut self.state, &key, value);
276 }
277 }
278
279 for node_name in &result.executed_nodes {
281 if self.graph.interrupt_after.contains(node_name) {
282 return Ok(SuperStepResult {
283 interrupt: Some(Interrupt::After(node_name.clone())),
284 ..result
285 });
286 }
287 }
288
289 Ok(result)
290 }
291
292 async fn save_checkpoint(&self) -> Result<String> {
294 if let Some(cp) = &self.graph.checkpointer {
295 let checkpoint = Checkpoint::new(
296 &self.config.thread_id,
297 self.state.clone(),
298 self.step,
299 self.pending_nodes.clone(),
300 );
301 return cp.save(&checkpoint).await;
302 }
303 Ok(String::new())
304 }
305}
306
307impl CompiledGraph {
309 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
311 let mut executor = PregelExecutor::new(self, config);
312 executor.run(input).await
313 }
314
315 pub fn stream(
317 &self,
318 input: State,
319 config: ExecutionConfig,
320 mode: StreamMode,
321 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
322 let executor = PregelExecutor::new(self, config);
323 executor.run_stream(input, mode)
324 }
325
326 pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
328 if let Some(cp) = &self.checkpointer {
329 Ok(cp.load(thread_id).await?.map(|c| c.state))
330 } else {
331 Ok(None)
332 }
333 }
334
335 pub async fn update_state(
337 &self,
338 thread_id: &str,
339 updates: impl IntoIterator<Item = (String, serde_json::Value)>,
340 ) -> Result<()> {
341 if let Some(cp) = &self.checkpointer {
342 if let Some(checkpoint) = cp.load(thread_id).await? {
343 let mut state = checkpoint.state;
344 for (key, value) in updates {
345 self.schema.apply_update(&mut state, &key, value);
346 }
347 let new_checkpoint =
348 Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
349 cp.save(&new_checkpoint).await?;
350 }
351 }
352 Ok(())
353 }
354}
355
356#[cfg(test)]
357mod tests {
358 use super::*;
359 use crate::edge::{END, START};
360 use crate::graph::StateGraph;
361 use crate::node::NodeOutput;
362 use serde_json::json;
363
364 #[tokio::test]
365 async fn test_simple_execution() {
366 let graph = StateGraph::with_channels(&["value"])
367 .add_node_fn("set_value", |_ctx| async {
368 Ok(NodeOutput::new().with_update("value", json!(42)))
369 })
370 .add_edge(START, "set_value")
371 .add_edge("set_value", END)
372 .compile()
373 .unwrap();
374
375 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();
376
377 assert_eq!(result.get("value"), Some(&json!(42)));
378 }
379
380 #[tokio::test]
381 async fn test_sequential_execution() {
382 let graph = StateGraph::with_channels(&["value"])
383 .add_node_fn("step1", |_ctx| async {
384 Ok(NodeOutput::new().with_update("value", json!(1)))
385 })
386 .add_node_fn("step2", |ctx| async move {
387 let current = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
388 Ok(NodeOutput::new().with_update("value", json!(current + 10)))
389 })
390 .add_edge(START, "step1")
391 .add_edge("step1", "step2")
392 .add_edge("step2", END)
393 .compile()
394 .unwrap();
395
396 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();
397
398 assert_eq!(result.get("value"), Some(&json!(11)));
399 }
400
401 #[tokio::test]
402 async fn test_conditional_routing() {
403 let graph = StateGraph::with_channels(&["path", "result"])
404 .add_node_fn("router", |ctx| async move {
405 let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
406 Ok(NodeOutput::new().with_update("route", json!(path)))
407 })
408 .add_node_fn("path_a", |_ctx| async {
409 Ok(NodeOutput::new().with_update("result", json!("went to A")))
410 })
411 .add_node_fn("path_b", |_ctx| async {
412 Ok(NodeOutput::new().with_update("result", json!("went to B")))
413 })
414 .add_edge(START, "router")
415 .add_conditional_edges(
416 "router",
417 |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
418 [("a", "path_a"), ("b", "path_b"), (END, END)],
419 )
420 .add_edge("path_a", END)
421 .add_edge("path_b", END)
422 .compile()
423 .unwrap();
424
425 let mut input = State::new();
427 input.insert("path".to_string(), json!("a"));
428 let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
429 assert_eq!(result.get("result"), Some(&json!("went to A")));
430
431 let mut input = State::new();
433 input.insert("path".to_string(), json!("b"));
434 let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
435 assert_eq!(result.get("result"), Some(&json!("went to B")));
436 }
437
438 #[tokio::test]
439 async fn test_cycle_with_limit() {
440 let graph = StateGraph::with_channels(&["count"])
441 .add_node_fn("increment", |ctx| async move {
442 let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
443 Ok(NodeOutput::new().with_update("count", json!(count + 1)))
444 })
445 .add_edge(START, "increment")
446 .add_conditional_edges(
447 "increment",
448 |state| {
449 let count = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
450 if count < 5 {
451 "increment".to_string()
452 } else {
453 END.to_string()
454 }
455 },
456 [("increment", "increment"), (END, END)],
457 )
458 .compile()
459 .unwrap();
460
461 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();
462
463 assert_eq!(result.get("count"), Some(&json!(5)));
464 }
465
466 #[tokio::test]
467 async fn test_recursion_limit() {
468 let graph = StateGraph::with_channels(&["count"])
469 .add_node_fn("loop", |ctx| async move {
470 let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
471 Ok(NodeOutput::new().with_update("count", json!(count + 1)))
472 })
473 .add_edge(START, "loop")
474 .add_edge("loop", "loop") .compile()
476 .unwrap()
477 .with_recursion_limit(10);
478
479 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await;
480
481 assert!(
483 matches!(result, Err(GraphError::RecursionLimitExceeded(_))),
484 "Expected RecursionLimitExceeded error, got: {:?}",
485 result
486 );
487 }
488}