1use crate::error::{GraphError, InterruptedExecution, Result};
6use crate::graph::CompiledGraph;
7use crate::interrupt::Interrupt;
8use crate::node::{ExecutionConfig, NodeContext};
9use crate::state::{Checkpoint, State};
10use crate::stream::{StreamEvent, StreamMode};
11use futures::stream::{self, StreamExt};
12use std::time::Instant;
13
/// Outcome of executing one super-step (one round of pending nodes).
#[derive(Default)]
pub struct SuperStepResult {
    /// Names of the nodes that ran in this super-step, in the order their
    /// results were processed.
    pub executed_nodes: Vec<String>,
    /// Set when the super-step stopped because of an interrupt (before a node,
    /// after a node, or raised by a node's output); `None` otherwise.
    pub interrupt: Option<Interrupt>,
    /// Events produced during the super-step: a node-end event per executed
    /// node followed by any events emitted by that node.
    pub events: Vec<StreamEvent>,
}
24
/// Executes a [`CompiledGraph`] as a sequence of super-steps.
///
/// In each super-step every pending node runs against the same state, the
/// resulting updates are merged into the state, and the next set of pending
/// nodes is derived from the graph's edges.
pub struct PregelExecutor<'a> {
    /// Graph being executed; borrowed for the executor's lifetime.
    graph: &'a CompiledGraph,
    /// Per-run configuration (thread id, recursion limit, resume target, ...).
    config: ExecutionConfig,
    /// Current graph state, updated after each super-step.
    state: State,
    /// Index of the current super-step, starting at 0.
    step: usize,
    /// Nodes scheduled to run in the current super-step.
    pending_nodes: Vec<String>,
}
33
34impl<'a> PregelExecutor<'a> {
35 pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
37 Self { graph, config, state: State::new(), step: 0, pending_nodes: vec![] }
38 }
39
40 pub async fn run(&mut self, input: State) -> Result<State> {
42 self.state = self.initialize_state(input).await?;
44 self.pending_nodes = self.graph.get_entry_nodes();
45
46 while !self.pending_nodes.is_empty() {
48 if self.step >= self.config.recursion_limit {
50 return Err(GraphError::RecursionLimitExceeded(self.step));
51 }
52
53 let result = self.execute_super_step().await?;
55
56 if let Some(interrupt) = result.interrupt {
58 let checkpoint_id = self.save_checkpoint().await?;
59 return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
60 self.config.thread_id.clone(),
61 checkpoint_id,
62 interrupt,
63 self.state.clone(),
64 self.step,
65 ))));
66 }
67
68 self.save_checkpoint().await?;
70
71 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
73 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
74 if next.is_empty() {
75 break;
76 }
77 }
78
79 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
81 self.step += 1;
82 }
83
84 Ok(self.state.clone())
85 }
86
87 pub fn run_stream(
89 mut self,
90 input: State,
91 mode: StreamMode,
92 ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
93 async_stream::stream! {
94 match self.initialize_state(input).await {
96 Ok(state) => self.state = state,
97 Err(e) => {
98 yield Err(e);
99 return;
100 }
101 }
102 self.pending_nodes = self.graph.get_entry_nodes();
103
104 if matches!(mode, StreamMode::Values) {
106 yield Ok(StreamEvent::state(self.state.clone(), self.step));
107 }
108
109 while !self.pending_nodes.is_empty() {
111 if self.step >= self.config.recursion_limit {
113 yield Err(GraphError::RecursionLimitExceeded(self.step));
114 return;
115 }
116
117 if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
119 for node_name in &self.pending_nodes {
120 yield Ok(StreamEvent::node_start(node_name, self.step));
121 }
122 }
123
124 if matches!(mode, StreamMode::Messages) {
126 let mut result = SuperStepResult::default();
127
128 for node_name in &self.pending_nodes {
129 if let Some(node) = self.graph.nodes.get(node_name) {
130 let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
131 let start = std::time::Instant::now();
132
133 let mut node_stream = node.execute_stream(&ctx);
134 let mut collected_events = Vec::new();
135
136 while let Some(event_result) = node_stream.next().await {
137 match event_result {
138 Ok(event) => {
139 if matches!(event, StreamEvent::Message { .. }) {
141 yield Ok(event.clone());
142 }
143 collected_events.push(event);
144 }
145 Err(e) => {
146 yield Err(e);
147 return;
148 }
149 }
150 }
151
152 let duration_ms = start.elapsed().as_millis() as u64;
153 result.executed_nodes.push(node_name.clone());
154 result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
155 result.events.extend(collected_events);
156
157 if let Ok(output) = node.execute(&ctx).await {
159 for (key, value) in output.updates {
160 self.graph.schema.apply_update(&mut self.state, &key, value);
161 }
162 }
163 }
164 }
165
166 for event in &result.events {
168 if matches!(event, StreamEvent::NodeEnd { .. }) {
169 yield Ok(event.clone());
170 }
171 }
172
173 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
174 self.step += 1;
175 continue;
176 }
177
178 let result = match self.execute_super_step().await {
180 Ok(r) => r,
181 Err(e) => {
182 yield Err(e);
183 return;
184 }
185 };
186
187 for event in &result.events {
189 match (&mode, &event) {
190 (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
192 (StreamMode::Custom, _) => yield Ok(event.clone()),
193 (StreamMode::Debug, _) => yield Ok(event.clone()),
194 _ => {}
195 }
196 }
197
198 match mode {
200 StreamMode::Values => {
201 yield Ok(StreamEvent::state(self.state.clone(), self.step));
202 }
203 StreamMode::Updates => {
204 yield Ok(StreamEvent::step_complete(
205 self.step,
206 result.executed_nodes.clone(),
207 ));
208 }
209 _ => {}
210 }
211
212 if let Some(interrupt) = result.interrupt {
214 yield Ok(StreamEvent::interrupted(
215 result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
216 &interrupt.to_string(),
217 ));
218 return;
219 }
220
221 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
223 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
224 if next.is_empty() {
225 break;
226 }
227 }
228
229 self.pending_nodes = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
230 self.step += 1;
231 }
232
233 yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
234 }
235 }
236
237 async fn initialize_state(&self, input: State) -> Result<State> {
239 let mut state = self.graph.schema.initialize_state();
241
242 if let Some(checkpoint_id) = &self.config.resume_from {
244 if let Some(cp) = self.graph.checkpointer.as_ref() {
245 if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
246 state = checkpoint.state;
247 }
248 }
249 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
250 if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
252 state = checkpoint.state;
253 }
254 }
255
256 for (key, value) in input {
258 self.graph.schema.apply_update(&mut state, &key, value);
259 }
260
261 Ok(state)
262 }
263
264 async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
266 let mut result = SuperStepResult::default();
267
268 for node_name in &self.pending_nodes {
270 if self.graph.interrupt_before.contains(node_name) {
271 return Ok(SuperStepResult {
272 interrupt: Some(Interrupt::Before(node_name.clone())),
273 ..Default::default()
274 });
275 }
276 }
277
278 let nodes: Vec<_> = self
280 .pending_nodes
281 .iter()
282 .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
283 .collect();
284
285 let futures: Vec<_> = nodes
286 .into_iter()
287 .map(|(name, node)| {
288 let ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
289 let step = self.step;
290 async move {
291 let start = Instant::now();
292 let output = node.execute(&ctx).await;
293 let duration_ms = start.elapsed().as_millis() as u64;
294 (name, output, duration_ms, step)
295 }
296 })
297 .collect();
298
299 let outputs: Vec<_> =
300 stream::iter(futures).buffer_unordered(self.pending_nodes.len()).collect().await;
301
302 let mut all_updates = Vec::new();
304
305 for (node_name, output_result, duration_ms, step) in outputs {
306 result.executed_nodes.push(node_name.clone());
307 result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
308
309 match output_result {
310 Ok(output) => {
311 if let Some(interrupt) = output.interrupt {
313 return Ok(SuperStepResult {
314 interrupt: Some(interrupt),
315 executed_nodes: result.executed_nodes,
316 events: result.events,
317 });
318 }
319
320 result.events.extend(output.events);
322
323 all_updates.push(output.updates);
325 }
326 Err(e) => {
327 return Err(GraphError::NodeExecutionFailed {
328 node: node_name,
329 message: e.to_string(),
330 });
331 }
332 }
333 }
334
335 for updates in all_updates {
337 for (key, value) in updates {
338 self.graph.schema.apply_update(&mut self.state, &key, value);
339 }
340 }
341
342 for node_name in &result.executed_nodes {
344 if self.graph.interrupt_after.contains(node_name) {
345 return Ok(SuperStepResult {
346 interrupt: Some(Interrupt::After(node_name.clone())),
347 ..result
348 });
349 }
350 }
351
352 Ok(result)
353 }
354
355 async fn save_checkpoint(&self) -> Result<String> {
357 if let Some(cp) = &self.graph.checkpointer {
358 let checkpoint = Checkpoint::new(
359 &self.config.thread_id,
360 self.state.clone(),
361 self.step,
362 self.pending_nodes.clone(),
363 );
364 return cp.save(&checkpoint).await;
365 }
366 Ok(String::new())
367 }
368}
369
370impl CompiledGraph {
372 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
374 let mut executor = PregelExecutor::new(self, config);
375 executor.run(input).await
376 }
377
378 pub fn stream(
380 &self,
381 input: State,
382 config: ExecutionConfig,
383 mode: StreamMode,
384 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
385 eprintln!("DEBUG: CompiledGraph::stream called with mode {:?}", mode);
386 let executor = PregelExecutor::new(self, config);
387 executor.run_stream(input, mode)
388 }
389
390 pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
392 if let Some(cp) = &self.checkpointer {
393 Ok(cp.load(thread_id).await?.map(|c| c.state))
394 } else {
395 Ok(None)
396 }
397 }
398
399 pub async fn update_state(
401 &self,
402 thread_id: &str,
403 updates: impl IntoIterator<Item = (String, serde_json::Value)>,
404 ) -> Result<()> {
405 if let Some(cp) = &self.checkpointer {
406 if let Some(checkpoint) = cp.load(thread_id).await? {
407 let mut state = checkpoint.state;
408 for (key, value) in updates {
409 self.schema.apply_update(&mut state, &key, value);
410 }
411 let new_checkpoint =
412 Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
413 cp.save(&new_checkpoint).await?;
414 }
415 }
416 Ok(())
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    // Single node START -> set_value -> END; the update must reach the final state.
    #[tokio::test]
    async fn test_simple_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    // step2 runs in the super-step after step1, so it reads step1's update (1 + 10 = 11).
    #[tokio::test]
    async fn test_sequential_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |ctx| async move {
                let current = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(current + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("value"), Some(&json!(11)));
    }

    // The router copies the input "path" into "route"; the conditional edge maps
    // "a"/"b" to the matching node, defaulting to END when "route" is missing.
    // NOTE(review): "route" is not among the declared channels ("path", "result"),
    // so this relies on `apply_update` accepting undeclared keys - TODO confirm.
    #[tokio::test]
    async fn test_conditional_routing() {
        let graph = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |ctx| async move {
                let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(path)))
            })
            .add_node_fn("path_a", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        // Route "a".
        let mut input = State::new();
        input.insert("path".to_string(), json!("a"));
        let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(result.get("result"), Some(&json!("went to A")));

        // Route "b".
        let mut input = State::new();
        input.insert("path".to_string(), json!("b"));
        let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
        assert_eq!(result.get("result"), Some(&json!("went to B")));
    }

    // "increment" loops back to itself while count < 5, then routes to END,
    // so the run terminates with count == 5 (well under the recursion limit).
    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |state| {
                    let count = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                    if count < 5 { "increment".to_string() } else { END.to_string() }
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(result.get("count"), Some(&json!(5)));
    }

    // An unconditional self-loop never terminates on its own; with a limit of
    // 10 super-steps the run must fail with RecursionLimitExceeded.
    #[tokio::test]
    async fn test_recursion_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(count + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop")
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await;

        assert!(
            matches!(result, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            result
        );
    }
}