1#[cfg(feature = "node-cache")]
6use crate::cache::{NodeCache, compute_cache_key};
7use crate::deferred::FanInTracker;
8use crate::error::{GraphError, InterruptedExecution, Result};
9use crate::graph::CompiledGraph;
10use crate::interrupt::Interrupt;
11use crate::node::{ExecutionConfig, NodeContext};
12use crate::state::{Checkpoint, State};
13use crate::stream::{StreamEvent, StreamMode};
14use crate::timeout::{ProgressHandle, execute_with_timeout};
15use futures::stream::{self, StreamExt};
16use std::collections::HashMap;
17use std::time::Instant;
18
19#[derive(Default)]
21pub struct SuperStepResult {
22 pub executed_nodes: Vec<String>,
24 pub interrupt: Option<Interrupt>,
26 pub events: Vec<StreamEvent>,
28}
29
30pub struct PregelExecutor<'a> {
32 graph: &'a CompiledGraph,
33 config: ExecutionConfig,
34 state: State,
35 step: usize,
36 pending_nodes: Vec<String>,
37 pending_deferred: HashMap<String, FanInTracker>,
39 deferred_start_times: HashMap<String, Instant>,
41 #[cfg(feature = "node-cache")]
43 node_caches: HashMap<String, NodeCache>,
44}
45
46impl<'a> PregelExecutor<'a> {
47 pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
49 #[cfg(feature = "node-cache")]
50 let node_caches = graph
51 .cache_policies
52 .iter()
53 .map(|(name, policy)| (name.clone(), NodeCache::from_policy(policy)))
54 .collect();
55
56 Self {
57 graph,
58 config,
59 state: State::new(),
60 step: 0,
61 pending_nodes: vec![],
62 pending_deferred: HashMap::new(),
63 deferred_start_times: HashMap::new(),
64 #[cfg(feature = "node-cache")]
65 node_caches,
66 }
67 }
68
69 async fn try_resume_from_checkpoint(&mut self, input: &State) -> Result<bool> {
78 let checkpoint = if let Some(checkpoint_id) = &self.config.resume_from {
79 if let Some(cp) = self.graph.checkpointer.as_ref() {
81 cp.load_by_id(checkpoint_id).await?
82 } else {
83 None
84 }
85 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
86 cp.load(&self.config.thread_id).await?
88 } else {
89 None
90 };
91
92 if let Some(checkpoint) = checkpoint {
93 self.state = checkpoint.state;
95 self.pending_nodes = checkpoint.pending_nodes;
96 self.step = checkpoint.step;
97
98 for (key, value) in input {
100 self.graph.schema.apply_update(&mut self.state, key, value.clone());
101 }
102
103 Ok(true)
104 } else {
105 Ok(false)
106 }
107 }
108
109 pub async fn run(&mut self, input: State) -> Result<State> {
111 let resumed = self.try_resume_from_checkpoint(&input).await?;
113
114 if !resumed {
115 self.state = self.initialize_state(input).await?;
117 self.pending_nodes = self.graph.get_entry_nodes();
118 }
119
120 while !self.pending_nodes.is_empty() {
122 if self.step >= self.config.recursion_limit {
124 return Err(GraphError::RecursionLimitExceeded(self.step));
125 }
126
127 let result = self.execute_super_step().await?;
129
130 if let Some(interrupt) = result.interrupt {
132 let checkpoint_id = self.save_checkpoint().await?;
133 return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
134 self.config.thread_id.clone(),
135 checkpoint_id,
136 interrupt,
137 self.state.clone(),
138 self.step,
139 ))));
140 }
141
142 self.save_checkpoint().await?;
144
145 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
147 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
148 if next.is_empty() {
149 break;
150 }
151 }
152
153 let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
155 self.pending_nodes =
156 self.filter_deferred_nodes(next_candidates, &result.executed_nodes)?;
157 self.step += 1;
158 }
159
160 Ok(self.state.clone())
161 }
162
163 pub fn run_stream(
165 mut self,
166 input: State,
167 mode: StreamMode,
168 ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
169 async_stream::stream! {
170 let resumed = match self.try_resume_from_checkpoint(&input).await {
172 Ok(r) => r,
173 Err(e) => {
174 yield Err(e);
175 return;
176 }
177 };
178
179 if resumed {
180 yield Ok(StreamEvent::resumed(self.step, self.pending_nodes.clone()));
182 } else {
183 match self.initialize_state(input).await {
185 Ok(state) => self.state = state,
186 Err(e) => {
187 yield Err(e);
188 return;
189 }
190 }
191 self.pending_nodes = self.graph.get_entry_nodes();
192 }
193
194 if matches!(mode, StreamMode::Values) {
196 yield Ok(StreamEvent::state(self.state.clone(), self.step));
197 }
198
199 while !self.pending_nodes.is_empty() {
201 if self.step >= self.config.recursion_limit {
203 yield Err(GraphError::RecursionLimitExceeded(self.step));
204 return;
205 }
206
207 if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
209 for node_name in &self.pending_nodes {
210 yield Ok(StreamEvent::node_start(node_name, self.step));
211 }
212 }
213
214 if matches!(mode, StreamMode::Messages) {
216 let mut result = SuperStepResult::default();
217
218 for node_name in &self.pending_nodes {
219 if let Some(node) = self.graph.nodes.get(node_name) {
220 let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
221
222 let policy = self.graph.timeout_policy_for(node_name).cloned();
224 if let Some(ref p) = policy
225 && p.idle_timeout.is_some() {
226 ctx.set_progress_handle(ProgressHandle::new());
227 }
228
229 let start = std::time::Instant::now();
230
231 let mut node_stream = node.execute_stream(&ctx);
232 let mut collected_events = Vec::new();
233
234 while let Some(event_result) = node_stream.next().await {
235 match event_result {
236 Ok(event) => {
237 if matches!(event, StreamEvent::Message { .. }) {
239 yield Ok(event.clone());
240 }
241 collected_events.push(event);
242 }
243 Err(e) => {
244 yield Err(e);
245 return;
246 }
247 }
248 }
249
250 let duration_ms = start.elapsed().as_millis() as u64;
251 result.executed_nodes.push(node_name.clone());
252 result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
253 result.events.extend(collected_events);
254
255 let output_result = match policy {
257 Some(ref timeout_policy) => {
258 execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
259 }
260 None => node.execute(&ctx).await,
261 };
262 if let Ok(output) = output_result {
263 for (key, value) in output.updates {
264 self.graph.schema.apply_update(&mut self.state, &key, value);
265 }
266 }
267 }
268 }
269
270 for event in &result.events {
272 if matches!(event, StreamEvent::NodeEnd { .. }) {
273 yield Ok(event.clone());
274 }
275 }
276
277 self.pending_nodes = {
278 let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
279 match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
280 Ok(nodes) => nodes,
281 Err(e) => {
282 yield Err(e);
283 return;
284 }
285 }
286 };
287 self.step += 1;
288 continue;
289 }
290
291 let result = match self.execute_super_step().await {
293 Ok(r) => r,
294 Err(e) => {
295 yield Err(e);
296 return;
297 }
298 };
299
300 for event in &result.events {
302 match (&mode, &event) {
303 (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
305 (StreamMode::Custom, _) => yield Ok(event.clone()),
306 (StreamMode::Debug, _) => yield Ok(event.clone()),
307 _ => {}
308 }
309 }
310
311 match mode {
313 StreamMode::Values => {
314 yield Ok(StreamEvent::state(self.state.clone(), self.step));
315 }
316 StreamMode::Updates => {
317 yield Ok(StreamEvent::step_complete(
318 self.step,
319 result.executed_nodes.clone(),
320 ));
321 }
322 _ => {}
323 }
324
325 if let Some(interrupt) = result.interrupt {
327 yield Ok(StreamEvent::interrupted(
328 result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
329 &interrupt.to_string(),
330 ));
331 return;
332 }
333
334 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
336 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
337 if next.is_empty() {
338 break;
339 }
340 }
341
342 self.pending_nodes = {
343 let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
344 match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
345 Ok(nodes) => nodes,
346 Err(e) => {
347 yield Err(e);
348 return;
349 }
350 }
351 };
352 self.step += 1;
353 }
354
355 yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
356 }
357 }
358
359 fn filter_deferred_nodes(
372 &mut self,
373 candidates: Vec<String>,
374 executed_nodes: &[String],
375 ) -> Result<Vec<String>> {
376 let mut ready_nodes = Vec::new();
377
378 for candidate in candidates {
379 if let Some(config) = self.graph.deferred_configs.get(&candidate) {
380 let upstream = self.graph.get_upstream_nodes(&candidate);
382
383 let tracker = self.pending_deferred.entry(candidate.clone()).or_insert_with(|| {
385 let sources: Vec<&str> = upstream.iter().map(|s| s.as_str()).collect();
386 FanInTracker::new(sources)
387 });
388
389 self.deferred_start_times.entry(candidate.clone()).or_insert_with(Instant::now);
391
392 for executed in executed_nodes {
394 if upstream.contains(executed) {
395 let output = self.state.get(executed).cloned().unwrap_or_else(|| {
398 serde_json::Value::Object(
400 self.state.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
401 )
402 });
403 tracker.record(executed, output);
404 }
405 }
406
407 if tracker.is_ready() {
408 let merged = tracker.merge(&config.merge_strategy);
410 let fan_in_key = format!("{candidate}_fan_in");
411 self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
412
413 self.pending_deferred.remove(&candidate);
415 self.deferred_start_times.remove(&candidate);
416 ready_nodes.push(candidate);
417 } else if let Some(timeout_duration) = config.fan_in_timeout {
418 let start_time = self.deferred_start_times[&candidate];
420 if start_time.elapsed() >= timeout_duration {
421 let received = tracker.received_count();
422 let expected = tracker.expected_count();
423
424 if received > 0 {
425 tracing::warn!(
427 node = %candidate,
428 received,
429 expected,
430 "fan-in timeout expired, proceeding with partial results"
431 );
432 let merged = tracker.merge(&config.merge_strategy);
433 let fan_in_key = format!("{candidate}_fan_in");
434 self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
435
436 self.pending_deferred.remove(&candidate);
438 self.deferred_start_times.remove(&candidate);
439 ready_nodes.push(candidate);
440 } else {
441 self.pending_deferred.remove(&candidate);
443 self.deferred_start_times.remove(&candidate);
444 return Err(GraphError::FanInTimedOut {
445 node: candidate,
446 received,
447 expected,
448 });
449 }
450 }
451 }
452 } else {
455 ready_nodes.push(candidate);
457 }
458 }
459
460 Ok(ready_nodes)
461 }
462
463 async fn initialize_state(&self, input: State) -> Result<State> {
465 let mut state = self.graph.schema.initialize_state();
467
468 if let Some(checkpoint_id) = &self.config.resume_from {
470 if let Some(cp) = self.graph.checkpointer.as_ref()
471 && let Some(checkpoint) = cp.load_by_id(checkpoint_id).await?
472 {
473 state = checkpoint.state;
474 }
475 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
476 if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
478 state = checkpoint.state;
479 }
480 }
481
482 for (key, value) in input {
484 self.graph.schema.apply_update(&mut state, &key, value);
485 }
486
487 Ok(state)
488 }
489
490 async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
492 let mut result = SuperStepResult::default();
493
494 for node_name in &self.pending_nodes {
496 if self.graph.interrupt_before.contains(node_name) {
497 return Ok(SuperStepResult {
498 interrupt: Some(Interrupt::Before(node_name.clone())),
499 ..Default::default()
500 });
501 }
502 }
503
504 #[cfg(feature = "node-cache")]
506 let mut cached_results: HashMap<String, serde_json::Value> = HashMap::new();
507 #[cfg(feature = "node-cache")]
508 let mut nodes_to_execute: Vec<String> = Vec::new();
509
510 #[cfg(feature = "node-cache")]
511 {
512 for node_name in &self.pending_nodes {
513 if let Some(cache) = self.node_caches.get(node_name) {
514 let cache_key = compute_cache_key(node_name, &self.state);
515 let cached_value = cache.get(&cache_key).await;
516 tracing::debug!(
517 node = %node_name,
518 cache_hit = cached_value.is_some(),
519 cache_key = %cache_key,
520 "node cache lookup"
521 );
522 if let Some(value) = cached_value {
523 cached_results.insert(node_name.clone(), value);
525 } else {
526 nodes_to_execute.push(node_name.clone());
528 }
529 } else {
530 nodes_to_execute.push(node_name.clone());
532 }
533 }
534 }
535
536 #[cfg(feature = "node-cache")]
538 {
539 for (node_name, cached_value) in &cached_results {
540 result.executed_nodes.push(node_name.clone());
541 result.events.push(StreamEvent::node_end(node_name, self.step, 0));
542
543 if let Some(updates_map) = cached_value.as_object() {
545 for (key, value) in updates_map {
546 self.graph.schema.apply_update(&mut self.state, key, value.clone());
547 }
548 }
549 }
550 }
551
552 #[cfg(feature = "node-cache")]
554 let pending_for_execution = &nodes_to_execute;
555 #[cfg(not(feature = "node-cache"))]
556 let pending_for_execution = &self.pending_nodes;
557
558 let nodes: Vec<_> = pending_for_execution
560 .iter()
561 .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
562 .collect();
563
564 let timeout_policies: Vec<_> =
566 nodes.iter().map(|(name, _)| self.graph.timeout_policy_for(name).cloned()).collect();
567
568 let futures: Vec<_> = nodes
569 .into_iter()
570 .zip(timeout_policies)
571 .map(|((name, node), policy)| {
572 let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
573
574 if let Some(ref p) = policy
576 && p.idle_timeout.is_some()
577 {
578 ctx.set_progress_handle(ProgressHandle::new());
579 }
580
581 let step = self.step;
582 async move {
583 let start = Instant::now();
584 let output = match policy {
585 Some(ref timeout_policy) => {
586 execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
587 }
588 None => node.execute(&ctx).await,
589 };
590 let duration_ms = start.elapsed().as_millis() as u64;
591 (name, output, duration_ms, step)
592 }
593 })
594 .collect();
595
596 let outputs: Vec<_> =
597 stream::iter(futures).buffer_unordered(pending_for_execution.len()).collect().await;
598
599 let mut all_updates = Vec::new();
601
602 for (node_name, output_result, duration_ms, step) in outputs {
603 result.executed_nodes.push(node_name.clone());
604 result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
605
606 match output_result {
607 Ok(output) => {
608 if let Some(interrupt) = output.interrupt {
610 return Ok(SuperStepResult {
611 interrupt: Some(interrupt),
612 executed_nodes: result.executed_nodes,
613 events: result.events,
614 });
615 }
616
617 result.events.extend(output.events);
619
620 #[cfg(feature = "node-cache")]
622 {
623 if let Some(cache) = self.node_caches.get(&node_name) {
624 let cache_key = compute_cache_key(&node_name, &self.state);
625 let updates_value = serde_json::to_value(&output.updates)
626 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
627 let ttl = self.graph.cache_policies.get(&node_name).and_then(|p| p.ttl);
628 cache.set(&cache_key, updates_value, ttl).await;
629 }
630 }
631
632 all_updates.push(output.updates);
634 }
635 Err(e) => {
636 return Err(GraphError::NodeExecutionFailed {
637 node: node_name,
638 message: e.to_string(),
639 });
640 }
641 }
642 }
643
644 for updates in all_updates {
646 for (key, value) in updates {
647 self.graph.schema.apply_update(&mut self.state, &key, value);
648 }
649 }
650
651 for node_name in &result.executed_nodes {
653 if self.graph.interrupt_after.contains(node_name) {
654 return Ok(SuperStepResult {
655 interrupt: Some(Interrupt::After(node_name.clone())),
656 ..result
657 });
658 }
659 }
660
661 Ok(result)
662 }
663
664 async fn save_checkpoint(&self) -> Result<String> {
666 if let Some(cp) = &self.graph.checkpointer {
667 let checkpoint = Checkpoint::new(
668 &self.config.thread_id,
669 self.state.clone(),
670 self.step,
671 self.pending_nodes.clone(),
672 );
673 return cp.save(&checkpoint).await;
674 }
675 Ok(String::new())
676 }
677}
678
679impl CompiledGraph {
681 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
683 let mut executor = PregelExecutor::new(self, config);
684 executor.run(input).await
685 }
686
687 pub fn stream(
689 &self,
690 input: State,
691 config: ExecutionConfig,
692 mode: StreamMode,
693 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
694 tracing::debug!("CompiledGraph::stream called with mode {:?}", mode);
695 let executor = PregelExecutor::new(self, config);
696 executor.run_stream(input, mode)
697 }
698
699 pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
701 if let Some(cp) = &self.checkpointer {
702 Ok(cp.load(thread_id).await?.map(|c| c.state))
703 } else {
704 Ok(None)
705 }
706 }
707
708 pub async fn update_state(
710 &self,
711 thread_id: &str,
712 updates: impl IntoIterator<Item = (String, serde_json::Value)>,
713 ) -> Result<()> {
714 if let Some(cp) = &self.checkpointer
715 && let Some(checkpoint) = cp.load(thread_id).await?
716 {
717 let mut state = checkpoint.state;
718 for (key, value) in updates {
719 self.schema.apply_update(&mut state, &key, value);
720 }
721 let new_checkpoint =
722 Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
723 cp.save(&new_checkpoint).await?;
724 }
725 Ok(())
726 }
727}
728
729#[cfg(test)]
730mod tests {
731 use super::*;
732 use crate::edge::{END, START};
733 use crate::graph::StateGraph;
734 use crate::node::NodeOutput;
735 use serde_json::json;
736
737 #[tokio::test]
738 async fn test_simple_execution() {
739 let graph = StateGraph::with_channels(&["value"])
740 .add_node_fn("set_value", |_ctx| async {
741 Ok(NodeOutput::new().with_update("value", json!(42)))
742 })
743 .add_edge(START, "set_value")
744 .add_edge("set_value", END)
745 .compile()
746 .unwrap();
747
748 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();
749
750 assert_eq!(result.get("value"), Some(&json!(42)));
751 }
752
753 #[tokio::test]
754 async fn test_sequential_execution() {
755 let graph = StateGraph::with_channels(&["value"])
756 .add_node_fn("step1", |_ctx| async {
757 Ok(NodeOutput::new().with_update("value", json!(1)))
758 })
759 .add_node_fn("step2", |ctx| async move {
760 let current = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
761 Ok(NodeOutput::new().with_update("value", json!(current + 10)))
762 })
763 .add_edge(START, "step1")
764 .add_edge("step1", "step2")
765 .add_edge("step2", END)
766 .compile()
767 .unwrap();
768
769 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();
770
771 assert_eq!(result.get("value"), Some(&json!(11)));
772 }
773
774 #[tokio::test]
775 async fn test_conditional_routing() {
776 let graph = StateGraph::with_channels(&["path", "result"])
777 .add_node_fn("router", |ctx| async move {
778 let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
779 Ok(NodeOutput::new().with_update("route", json!(path)))
780 })
781 .add_node_fn("path_a", |_ctx| async {
782 Ok(NodeOutput::new().with_update("result", json!("went to A")))
783 })
784 .add_node_fn("path_b", |_ctx| async {
785 Ok(NodeOutput::new().with_update("result", json!("went to B")))
786 })
787 .add_edge(START, "router")
788 .add_conditional_edges(
789 "router",
790 |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
791 [("a", "path_a"), ("b", "path_b"), (END, END)],
792 )
793 .add_edge("path_a", END)
794 .add_edge("path_b", END)
795 .compile()
796 .unwrap();
797
798 let mut input = State::new();
800 input.insert("path".to_string(), json!("a"));
801 let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
802 assert_eq!(result.get("result"), Some(&json!("went to A")));
803
804 let mut input = State::new();
806 input.insert("path".to_string(), json!("b"));
807 let result = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
808 assert_eq!(result.get("result"), Some(&json!("went to B")));
809 }
810
811 #[tokio::test]
812 async fn test_cycle_with_limit() {
813 let graph = StateGraph::with_channels(&["count"])
814 .add_node_fn("increment", |ctx| async move {
815 let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
816 Ok(NodeOutput::new().with_update("count", json!(count + 1)))
817 })
818 .add_edge(START, "increment")
819 .add_conditional_edges(
820 "increment",
821 |state| {
822 let count = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
823 if count < 5 { "increment".to_string() } else { END.to_string() }
824 },
825 [("increment", "increment"), (END, END)],
826 )
827 .compile()
828 .unwrap();
829
830 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();
831
832 assert_eq!(result.get("count"), Some(&json!(5)));
833 }
834
835 #[tokio::test]
836 async fn test_recursion_limit() {
837 let graph = StateGraph::with_channels(&["count"])
838 .add_node_fn("loop", |ctx| async move {
839 let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
840 Ok(NodeOutput::new().with_update("count", json!(count + 1)))
841 })
842 .add_edge(START, "loop")
843 .add_edge("loop", "loop") .compile()
845 .unwrap()
846 .with_recursion_limit(10);
847
848 let result = graph.invoke(State::new(), ExecutionConfig::new("test")).await;
849
850 assert!(
852 matches!(result, Err(GraphError::RecursionLimitExceeded(_))),
853 "Expected RecursionLimitExceeded error, got: {:?}",
854 result
855 );
856 }
857}