1use crate::error::{GraphError, InterruptedExecution, Result};
6use crate::graph::CompiledGraph;
7use crate::interrupt::Interrupt;
8use crate::node::{ExecutionConfig, NodeContext};
9use crate::state::{Checkpoint, State};
10use crate::stream::{StreamEvent, StreamMode};
11use futures::stream::{self, StreamExt};
12use std::time::Instant;
13
14#[derive(Default)]
16pub struct SuperStepResult {
17 pub executed_nodes: Vec<String>,
19 pub interrupt: Option<Interrupt>,
21 pub events: Vec<StreamEvent>,
23}
24
25pub struct PregelExecutor<'a> {
27 graph: &'a CompiledGraph,
28 config: ExecutionConfig,
29 state: State,
30 step: usize,
31 pending_nodes: Vec<String>,
32}
33
34impl<'a> PregelExecutor<'a> {
35 pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
37 Self { graph, config, state: State::new(), step: 0, pending_nodes: vec![] }
38 }
39
40 async fn try_resume_from_checkpoint(&mut self, input: &State) -> Result<bool> {
49 let checkpoint = if let Some(checkpoint_id) = &self.config.resume_from {
50 if let Some(cp) = self.graph.checkpointer.as_ref() {
52 cp.load_by_id(checkpoint_id).await?
53 } else {
54 None
55 }
56 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
57 cp.load(&self.config.thread_id).await?
59 } else {
60 None
61 };
62
63 if let Some(checkpoint) = checkpoint {
64 self.state = checkpoint.state;
66 self.pending_nodes = checkpoint.pending_nodes;
67 self.step = checkpoint.step;
68
69 for (key, value) in input {
71 self.graph.schema.apply_update(&mut self.state, key, value.clone());
72 }
73
74 Ok(true)
75 } else {
76 Ok(false)
77 }
78 }
79
80 pub async fn run(&mut self, input: State) -> Result<State> {
82 let resumed = self.try_resume_from_checkpoint(&input).await?;
84
85 if !resumed {
86 self.state = self.initialize_state(input).await?;
88 self.pending_nodes = self.graph.get_entry_nodes();
89 }
90
91 while !self.pending_nodes.is_empty() {
93 if self.step >= self.config.recursion_limit {
95 return Err(GraphError::RecursionLimitExceeded(self.step));
96 }
97
98 let result = self.execute_super_step().await?;
100
101 if let Some(interrupt) = result.interrupt {
103 let checkpoint_id = self.save_checkpoint().await?;
104 return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
105 self.config.thread_id.clone(),
106 checkpoint_id,
107 interrupt,
108 self.state.clone(),
109 self.step,
110 ))));
111 }
112
113 self.save_checkpoint().await?;
115
116 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
118 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
119 if next.is_empty() {
120 break;
121 }
122 }
123
124 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
126 self.step += 1;
127 }
128
129 Ok(self.state.clone())
130 }
131
132 pub fn run_stream(
134 mut self,
135 input: State,
136 mode: StreamMode,
137 ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
138 async_stream::stream! {
139 let resumed = match self.try_resume_from_checkpoint(&input).await {
141 Ok(r) => r,
142 Err(e) => {
143 yield Err(e);
144 return;
145 }
146 };
147
148 if resumed {
149 yield Ok(StreamEvent::resumed(self.step, self.pending_nodes.clone()));
151 } else {
152 match self.initialize_state(input).await {
154 Ok(state) => self.state = state,
155 Err(e) => {
156 yield Err(e);
157 return;
158 }
159 }
160 self.pending_nodes = self.graph.get_entry_nodes();
161 }
162
163 if matches!(mode, StreamMode::Values) {
165 yield Ok(StreamEvent::state(self.state.clone(), self.step));
166 }
167
168 while !self.pending_nodes.is_empty() {
170 if self.step >= self.config.recursion_limit {
172 yield Err(GraphError::RecursionLimitExceeded(self.step));
173 return;
174 }
175
176 if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
178 for node_name in &self.pending_nodes {
179 yield Ok(StreamEvent::node_start(node_name, self.step));
180 }
181 }
182
183 if matches!(mode, StreamMode::Messages) {
185 let mut result = SuperStepResult::default();
186
187 for node_name in &self.pending_nodes {
188 if let Some(node) = self.graph.nodes.get(node_name) {
189 let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
190 let start = std::time::Instant::now();
191
192 let mut node_stream = node.execute_stream(&ctx);
193 let mut collected_events = Vec::new();
194
195 while let Some(event_result) = node_stream.next().await {
196 match event_result {
197 Ok(event) => {
198 if matches!(event, StreamEvent::Message { .. }) {
200 yield Ok(event.clone());
201 }
202 collected_events.push(event);
203 }
204 Err(e) => {
205 yield Err(e);
206 return;
207 }
208 }
209 }
210
211 let duration_ms = start.elapsed().as_millis() as u64;
212 result.executed_nodes.push(node_name.clone());
213 result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
214 result.events.extend(collected_events);
215
216 if let Ok(output) = node.execute(&ctx).await {
218 for (key, value) in output.updates {
219 self.graph.schema.apply_update(&mut self.state, &key, value);
220 }
221 }
222 }
223 }
224
225 for event in &result.events {
227 if matches!(event, StreamEvent::NodeEnd { .. }) {
228 yield Ok(event.clone());
229 }
230 }
231
232 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
233 self.step += 1;
234 continue;
235 }
236
237 let result = match self.execute_super_step().await {
239 Ok(r) => r,
240 Err(e) => {
241 yield Err(e);
242 return;
243 }
244 };
245
246 for event in &result.events {
248 match (&mode, &event) {
249 (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
251 (StreamMode::Custom, _) => yield Ok(event.clone()),
252 (StreamMode::Debug, _) => yield Ok(event.clone()),
253 _ => {}
254 }
255 }
256
257 match mode {
259 StreamMode::Values => {
260 yield Ok(StreamEvent::state(self.state.clone(), self.step));
261 }
262 StreamMode::Updates => {
263 yield Ok(StreamEvent::step_complete(
264 self.step,
265 result.executed_nodes.clone(),
266 ));
267 }
268 _ => {}
269 }
270
271 if let Some(interrupt) = result.interrupt {
273 yield Ok(StreamEvent::interrupted(
274 result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
275 &interrupt.to_string(),
276 ));
277 return;
278 }
279
280 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
282 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
283 if next.is_empty() {
284 break;
285 }
286 }
287
288 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
289 self.step += 1;
290 }
291
292 yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
293 }
294 }
295
296 async fn initialize_state(&self, input: State) -> Result<State> {
298 let mut state = self.graph.schema.initialize_state();
300
301 if let Some(checkpoint_id) = &self.config.resume_from {
303 if let Some(cp) = self.graph.checkpointer.as_ref() {
304 if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
305 state = checkpoint.state;
306 }
307 }
308 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
309 if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
311 state = checkpoint.state;
312 }
313 }
314
315 for (key, value) in input {
317 self.graph.schema.apply_update(&mut state, &key, value);
318 }
319
320 Ok(state)
321 }
322
323 async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
325 let mut result = SuperStepResult::default();
326
327 for node_name in &self.pending_nodes {
329 if self.graph.interrupt_before.contains(node_name) {
330 return Ok(SuperStepResult {
331 interrupt: Some(Interrupt::Before(node_name.clone())),
332 ..Default::default()
333 });
334 }
335 }
336
337 let nodes: Vec<_> = self
339 .pending_nodes
340 .iter()
341 .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
342 .collect();
343
344 let futures: Vec<_> = nodes
345 .into_iter()
346 .map(|(name, node)| {
347 let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
348 let step = self.step;
349 async move {
350 let start = Instant::now();
351 let output = node.execute(&ctx).await;
352 let duration_ms = start.elapsed().as_millis() as u64;
353 (name, output, duration_ms, step)
354 }
355 })
356 .collect();
357
358 let outputs: Vec<_> =
359 stream::iter(futures).buffer_unordered(self.pending_nodes.len()).collect().await;
360
361 let mut all_updates = Vec::new();
363
364 for (node_name, output_result, duration_ms, step) in outputs {
365 result.executed_nodes.push(node_name.clone());
366 result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
367
368 match output_result {
369 Ok(output) => {
370 if let Some(interrupt) = output.interrupt {
372 return Ok(SuperStepResult {
373 interrupt: Some(interrupt),
374 executed_nodes: result.executed_nodes,
375 events: result.events,
376 });
377 }
378
379 result.events.extend(output.events);
381
382 all_updates.push(output.updates);
384 }
385 Err(e) => {
386 return Err(GraphError::NodeExecutionFailed {
387 node: node_name,
388 message: e.to_string(),
389 });
390 }
391 }
392 }
393
394 for updates in all_updates {
396 for (key, value) in updates {
397 self.graph.schema.apply_update(&mut self.state, &key, value);
398 }
399 }
400
401 for node_name in &result.executed_nodes {
403 if self.graph.interrupt_after.contains(node_name) {
404 return Ok(SuperStepResult {
405 interrupt: Some(Interrupt::After(node_name.clone())),
406 ..result
407 });
408 }
409 }
410
411 Ok(result)
412 }
413
414 async fn save_checkpoint(&self) -> Result<String> {
416 if let Some(cp) = &self.graph.checkpointer {
417 let checkpoint = Checkpoint::new(
418 &self.config.thread_id,
419 self.state.clone(),
420 self.step,
421 self.pending_nodes.clone(),
422 );
423 return cp.save(&checkpoint).await;
424 }
425 Ok(String::new())
426 }
427}
428
429impl CompiledGraph {
431 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
433 let mut executor = PregelExecutor::new(self, config);
434 executor.run(input).await
435 }
436
437 pub fn stream(
439 &self,
440 input: State,
441 config: ExecutionConfig,
442 mode: StreamMode,
443 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
444 tracing::debug!("CompiledGraph::stream called with mode {:?}", mode);
445 let executor = PregelExecutor::new(self, config);
446 executor.run_stream(input, mode)
447 }
448
449 pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
451 if let Some(cp) = &self.checkpointer {
452 Ok(cp.load(thread_id).await?.map(|c| c.state))
453 } else {
454 Ok(None)
455 }
456 }
457
458 pub async fn update_state(
460 &self,
461 thread_id: &str,
462 updates: impl IntoIterator<Item = (String, serde_json::Value)>,
463 ) -> Result<()> {
464 if let Some(cp) = &self.checkpointer {
465 if let Some(checkpoint) = cp.load(thread_id).await? {
466 let mut state = checkpoint.state;
467 for (key, value) in updates {
468 self.schema.apply_update(&mut state, &key, value);
469 }
470 let new_checkpoint =
471 Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
472 cp.save(&new_checkpoint).await?;
473 }
474 }
475 Ok(())
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    #[tokio::test]
    async fn test_simple_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    #[tokio::test]
    async fn test_sequential_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |ctx| async move {
                let current = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(current + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(11)));
    }

    #[tokio::test]
    async fn test_conditional_routing() {
        // `route` is written by `router` and read by the edge condition, so it is declared
        // as a channel alongside the input (`path`) and output (`result`) keys.
        let graph = StateGraph::with_channels(&["path", "route", "result"])
            .add_node_fn("router", |ctx| async move {
                let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(path)))
            })
            .add_node_fn("path_a", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        let mut input = State::new();
        input.insert("path".to_string(), json!("a"));
        let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(result.get("result"), Some(&json!("went to A")));

        let mut input = State::new();
        input.insert("path".to_string(), json!("b"));
        let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(result.get("result"), Some(&json!("went to B")));
    }

    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |state| {
                    let count = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                    if count < 5 { "increment".to_string() } else { END.to_string() }
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("count"), Some(&json!(5)));
    }

    #[tokio::test]
    async fn test_recursion_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop")
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await;

        assert!(
            matches!(result, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_stream_values_completes_without_errors() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let events: Vec<_> = graph
            .stream(State::new(), ExecutionConfig::new("test"), StreamMode::Values)
            .collect()
            .await;

        assert!(!events.is_empty(), "stream yielded no events");
        for event in &events {
            if let Err(e) = event {
                panic!("unexpected stream error: {:?}", e);
            }
        }
    }
}
607}