1#[cfg(feature = "node-cache")]
6use crate::cache::{NodeCache, compute_cache_key};
7use crate::deferred::FanInTracker;
8use crate::error::{GraphError, InterruptedExecution, Result};
9use crate::graph::CompiledGraph;
10use crate::interrupt::Interrupt;
11use crate::node::{ExecutionConfig, NodeContext};
12use crate::state::{Checkpoint, State};
13use crate::stream::{StreamEvent, StreamMode};
14use crate::timeout::{ProgressHandle, execute_with_timeout};
15use futures::stream::{self, StreamExt};
16use std::collections::HashMap;
17use std::time::Instant;
18
19#[derive(Default)]
21pub struct SuperStepResult {
22 pub executed_nodes: Vec<String>,
24 pub interrupt: Option<Interrupt>,
26 pub events: Vec<StreamEvent>,
28}
29
30pub struct PregelExecutor<'a> {
32 graph: &'a CompiledGraph,
33 config: ExecutionConfig,
34 state: State,
35 step: usize,
36 pending_nodes: Vec<String>,
37 pending_deferred: HashMap<String, FanInTracker>,
39 deferred_start_times: HashMap<String, Instant>,
41 #[cfg(feature = "node-cache")]
43 node_caches: HashMap<String, NodeCache>,
44}
45
46impl<'a> PregelExecutor<'a> {
47 pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
49 #[cfg(feature = "node-cache")]
50 let node_caches = graph
51 .cache_policies
52 .iter()
53 .map(|(name, policy)| (name.clone(), NodeCache::from_policy(policy)))
54 .collect();
55
56 Self {
57 graph,
58 config,
59 state: State::new(),
60 step: 0,
61 pending_nodes: vec![],
62 pending_deferred: HashMap::new(),
63 deferred_start_times: HashMap::new(),
64 #[cfg(feature = "node-cache")]
65 node_caches,
66 }
67 }
68
69 async fn try_resume_from_checkpoint(&mut self, input: &State) -> Result<bool> {
78 let checkpoint = if let Some(checkpoint_id) = &self.config.resume_from {
79 if let Some(cp) = self.graph.checkpointer.as_ref() {
81 cp.load_by_id(checkpoint_id).await?
82 } else {
83 None
84 }
85 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
86 cp.load(&self.config.thread_id).await?
88 } else {
89 None
90 };
91
92 if let Some(checkpoint) = checkpoint {
93 self.state = checkpoint.state;
95 self.pending_nodes = checkpoint.pending_nodes;
96 self.step = checkpoint.step;
97
98 for (key, value) in input {
100 self.graph.schema.apply_update(&mut self.state, key, value.clone());
101 }
102
103 Ok(true)
104 } else {
105 Ok(false)
106 }
107 }
108
109 pub async fn run(&mut self, input: State) -> Result<State> {
111 let resumed = self.try_resume_from_checkpoint(&input).await?;
113
114 if !resumed {
115 self.state = self.initialize_state(input).await?;
117 self.pending_nodes = self.graph.get_entry_nodes();
118 }
119
120 while !self.pending_nodes.is_empty() {
122 if self.step >= self.config.recursion_limit {
124 return Err(GraphError::RecursionLimitExceeded(self.step));
125 }
126
127 let result = self.execute_super_step().await?;
129
130 if let Some(interrupt) = result.interrupt {
132 let checkpoint_id = self.save_checkpoint().await?;
133 return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
134 self.config.thread_id.clone(),
135 checkpoint_id,
136 interrupt,
137 self.state.clone(),
138 self.step,
139 ))));
140 }
141
142 self.save_checkpoint().await?;
144
145 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
147 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
148 if next.is_empty() {
149 break;
150 }
151 }
152
153 let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
155 self.pending_nodes =
156 self.filter_deferred_nodes(next_candidates, &result.executed_nodes)?;
157 self.step += 1;
158 }
159
160 Ok(self.state.clone())
161 }
162
163 pub fn run_stream(
165 mut self,
166 input: State,
167 mode: StreamMode,
168 ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
169 async_stream::stream! {
170 let resumed = match self.try_resume_from_checkpoint(&input).await {
172 Ok(r) => r,
173 Err(e) => {
174 yield Err(e);
175 return;
176 }
177 };
178
179 if resumed {
180 yield Ok(StreamEvent::resumed(self.step, self.pending_nodes.clone()));
182 } else {
183 match self.initialize_state(input).await {
185 Ok(state) => self.state = state,
186 Err(e) => {
187 yield Err(e);
188 return;
189 }
190 }
191 self.pending_nodes = self.graph.get_entry_nodes();
192 }
193
194 if matches!(mode, StreamMode::Values) {
196 yield Ok(StreamEvent::state(self.state.clone(), self.step));
197 }
198
199 while !self.pending_nodes.is_empty() {
201 if self.step >= self.config.recursion_limit {
203 yield Err(GraphError::RecursionLimitExceeded(self.step));
204 return;
205 }
206
207 if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
209 for node_name in &self.pending_nodes {
210 yield Ok(StreamEvent::node_start(node_name, self.step));
211 }
212 }
213
214 if matches!(mode, StreamMode::Messages) {
216 let mut result = SuperStepResult::default();
217
218 for node_name in &self.pending_nodes {
219 if let Some(node) = self.graph.nodes.get(node_name) {
220 let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
221
222 let policy = self.graph.timeout_policy_for(node_name).cloned();
224 if let Some(ref p) = policy {
225 if p.idle_timeout.is_some() {
226 ctx.set_progress_handle(ProgressHandle::new());
227 }
228 }
229
230 let start = std::time::Instant::now();
231
232 let mut node_stream = node.execute_stream(&ctx);
233 let mut collected_events = Vec::new();
234
235 while let Some(event_result) = node_stream.next().await {
236 match event_result {
237 Ok(event) => {
238 if matches!(event, StreamEvent::Message { .. }) {
240 yield Ok(event.clone());
241 }
242 collected_events.push(event);
243 }
244 Err(e) => {
245 yield Err(e);
246 return;
247 }
248 }
249 }
250
251 let duration_ms = start.elapsed().as_millis() as u64;
252 result.executed_nodes.push(node_name.clone());
253 result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
254 result.events.extend(collected_events);
255
256 let output_result = match policy {
258 Some(ref timeout_policy) => {
259 execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
260 }
261 None => node.execute(&ctx).await,
262 };
263 if let Ok(output) = output_result {
264 for (key, value) in output.updates {
265 self.graph.schema.apply_update(&mut self.state, &key, value);
266 }
267 }
268 }
269 }
270
271 for event in &result.events {
273 if matches!(event, StreamEvent::NodeEnd { .. }) {
274 yield Ok(event.clone());
275 }
276 }
277
278 self.pending_nodes = {
279 let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
280 match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
281 Ok(nodes) => nodes,
282 Err(e) => {
283 yield Err(e);
284 return;
285 }
286 }
287 };
288 self.step += 1;
289 continue;
290 }
291
292 let result = match self.execute_super_step().await {
294 Ok(r) => r,
295 Err(e) => {
296 yield Err(e);
297 return;
298 }
299 };
300
301 for event in &result.events {
303 match (&mode, &event) {
304 (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
306 (StreamMode::Custom, _) => yield Ok(event.clone()),
307 (StreamMode::Debug, _) => yield Ok(event.clone()),
308 _ => {}
309 }
310 }
311
312 match mode {
314 StreamMode::Values => {
315 yield Ok(StreamEvent::state(self.state.clone(), self.step));
316 }
317 StreamMode::Updates => {
318 yield Ok(StreamEvent::step_complete(
319 self.step,
320 result.executed_nodes.clone(),
321 ));
322 }
323 _ => {}
324 }
325
326 if let Some(interrupt) = result.interrupt {
328 yield Ok(StreamEvent::interrupted(
329 result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
330 &interrupt.to_string(),
331 ));
332 return;
333 }
334
335 if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
337 let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
338 if next.is_empty() {
339 break;
340 }
341 }
342
343 self.pending_nodes = {
344 let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
345 match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
346 Ok(nodes) => nodes,
347 Err(e) => {
348 yield Err(e);
349 return;
350 }
351 }
352 };
353 self.step += 1;
354 }
355
356 yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
357 }
358 }
359
360 fn filter_deferred_nodes(
373 &mut self,
374 candidates: Vec<String>,
375 executed_nodes: &[String],
376 ) -> Result<Vec<String>> {
377 let mut ready_nodes = Vec::new();
378
379 for candidate in candidates {
380 if let Some(config) = self.graph.deferred_configs.get(&candidate) {
381 let upstream = self.graph.get_upstream_nodes(&candidate);
383
384 let tracker = self.pending_deferred.entry(candidate.clone()).or_insert_with(|| {
386 let sources: Vec<&str> = upstream.iter().map(|s| s.as_str()).collect();
387 FanInTracker::new(sources)
388 });
389
390 self.deferred_start_times.entry(candidate.clone()).or_insert_with(Instant::now);
392
393 for executed in executed_nodes {
395 if upstream.contains(executed) {
396 let output = self.state.get(executed).cloned().unwrap_or_else(|| {
399 serde_json::Value::Object(
401 self.state.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
402 )
403 });
404 tracker.record(executed, output);
405 }
406 }
407
408 if tracker.is_ready() {
409 let merged = tracker.merge(&config.merge_strategy);
411 let fan_in_key = format!("{candidate}_fan_in");
412 self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
413
414 self.pending_deferred.remove(&candidate);
416 self.deferred_start_times.remove(&candidate);
417 ready_nodes.push(candidate);
418 } else if let Some(timeout_duration) = config.fan_in_timeout {
419 let start_time = self.deferred_start_times[&candidate];
421 if start_time.elapsed() >= timeout_duration {
422 let received = tracker.received_count();
423 let expected = tracker.expected_count();
424
425 if received > 0 {
426 tracing::warn!(
428 node = %candidate,
429 received,
430 expected,
431 "fan-in timeout expired, proceeding with partial results"
432 );
433 let merged = tracker.merge(&config.merge_strategy);
434 let fan_in_key = format!("{candidate}_fan_in");
435 self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
436
437 self.pending_deferred.remove(&candidate);
439 self.deferred_start_times.remove(&candidate);
440 ready_nodes.push(candidate);
441 } else {
442 self.pending_deferred.remove(&candidate);
444 self.deferred_start_times.remove(&candidate);
445 return Err(GraphError::FanInTimedOut {
446 node: candidate,
447 received,
448 expected,
449 });
450 }
451 }
452 }
453 } else {
456 ready_nodes.push(candidate);
458 }
459 }
460
461 Ok(ready_nodes)
462 }
463
464 async fn initialize_state(&self, input: State) -> Result<State> {
466 let mut state = self.graph.schema.initialize_state();
468
469 if let Some(checkpoint_id) = &self.config.resume_from {
471 if let Some(cp) = self.graph.checkpointer.as_ref() {
472 if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
473 state = checkpoint.state;
474 }
475 }
476 } else if let Some(cp) = self.graph.checkpointer.as_ref() {
477 if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
479 state = checkpoint.state;
480 }
481 }
482
483 for (key, value) in input {
485 self.graph.schema.apply_update(&mut state, &key, value);
486 }
487
488 Ok(state)
489 }
490
491 async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
493 let mut result = SuperStepResult::default();
494
495 for node_name in &self.pending_nodes {
497 if self.graph.interrupt_before.contains(node_name) {
498 return Ok(SuperStepResult {
499 interrupt: Some(Interrupt::Before(node_name.clone())),
500 ..Default::default()
501 });
502 }
503 }
504
505 #[cfg(feature = "node-cache")]
507 let mut cached_results: HashMap<String, serde_json::Value> = HashMap::new();
508 #[cfg(feature = "node-cache")]
509 let mut nodes_to_execute: Vec<String> = Vec::new();
510
511 #[cfg(feature = "node-cache")]
512 {
513 for node_name in &self.pending_nodes {
514 if let Some(cache) = self.node_caches.get(node_name) {
515 let cache_key = compute_cache_key(node_name, &self.state);
516 let cached_value = cache.get(&cache_key).await;
517 tracing::debug!(
518 node = %node_name,
519 cache_hit = cached_value.is_some(),
520 cache_key = %cache_key,
521 "node cache lookup"
522 );
523 if let Some(value) = cached_value {
524 cached_results.insert(node_name.clone(), value);
526 } else {
527 nodes_to_execute.push(node_name.clone());
529 }
530 } else {
531 nodes_to_execute.push(node_name.clone());
533 }
534 }
535 }
536
537 #[cfg(feature = "node-cache")]
539 {
540 for (node_name, cached_value) in &cached_results {
541 result.executed_nodes.push(node_name.clone());
542 result.events.push(StreamEvent::node_end(node_name, self.step, 0));
543
544 if let Some(updates_map) = cached_value.as_object() {
546 for (key, value) in updates_map {
547 self.graph.schema.apply_update(&mut self.state, key, value.clone());
548 }
549 }
550 }
551 }
552
553 #[cfg(feature = "node-cache")]
555 let pending_for_execution = &nodes_to_execute;
556 #[cfg(not(feature = "node-cache"))]
557 let pending_for_execution = &self.pending_nodes;
558
559 let nodes: Vec<_> = pending_for_execution
561 .iter()
562 .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
563 .collect();
564
565 let timeout_policies: Vec<_> =
567 nodes.iter().map(|(name, _)| self.graph.timeout_policy_for(name).cloned()).collect();
568
569 let futures: Vec<_> = nodes
570 .into_iter()
571 .zip(timeout_policies)
572 .map(|((name, node), policy)| {
573 let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
574
575 if let Some(ref p) = policy {
577 if p.idle_timeout.is_some() {
578 ctx.set_progress_handle(ProgressHandle::new());
579 }
580 }
581
582 let step = self.step;
583 async move {
584 let start = Instant::now();
585 let output = match policy {
586 Some(ref timeout_policy) => {
587 execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
588 }
589 None => node.execute(&ctx).await,
590 };
591 let duration_ms = start.elapsed().as_millis() as u64;
592 (name, output, duration_ms, step)
593 }
594 })
595 .collect();
596
597 let outputs: Vec<_> =
598 stream::iter(futures).buffer_unordered(pending_for_execution.len()).collect().await;
599
600 let mut all_updates = Vec::new();
602
603 for (node_name, output_result, duration_ms, step) in outputs {
604 result.executed_nodes.push(node_name.clone());
605 result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
606
607 match output_result {
608 Ok(output) => {
609 if let Some(interrupt) = output.interrupt {
611 return Ok(SuperStepResult {
612 interrupt: Some(interrupt),
613 executed_nodes: result.executed_nodes,
614 events: result.events,
615 });
616 }
617
618 result.events.extend(output.events);
620
621 #[cfg(feature = "node-cache")]
623 {
624 if let Some(cache) = self.node_caches.get(&node_name) {
625 let cache_key = compute_cache_key(&node_name, &self.state);
626 let updates_value = serde_json::to_value(&output.updates)
627 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
628 let ttl = self.graph.cache_policies.get(&node_name).and_then(|p| p.ttl);
629 cache.set(&cache_key, updates_value, ttl).await;
630 }
631 }
632
633 all_updates.push(output.updates);
635 }
636 Err(e) => {
637 return Err(GraphError::NodeExecutionFailed {
638 node: node_name,
639 message: e.to_string(),
640 });
641 }
642 }
643 }
644
645 for updates in all_updates {
647 for (key, value) in updates {
648 self.graph.schema.apply_update(&mut self.state, &key, value);
649 }
650 }
651
652 for node_name in &result.executed_nodes {
654 if self.graph.interrupt_after.contains(node_name) {
655 return Ok(SuperStepResult {
656 interrupt: Some(Interrupt::After(node_name.clone())),
657 ..result
658 });
659 }
660 }
661
662 Ok(result)
663 }
664
665 async fn save_checkpoint(&self) -> Result<String> {
667 if let Some(cp) = &self.graph.checkpointer {
668 let checkpoint = Checkpoint::new(
669 &self.config.thread_id,
670 self.state.clone(),
671 self.step,
672 self.pending_nodes.clone(),
673 );
674 return cp.save(&checkpoint).await;
675 }
676 Ok(String::new())
677 }
678}
679
680impl CompiledGraph {
682 pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
684 let mut executor = PregelExecutor::new(self, config);
685 executor.run(input).await
686 }
687
688 pub fn stream(
690 &self,
691 input: State,
692 config: ExecutionConfig,
693 mode: StreamMode,
694 ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
695 tracing::debug!("CompiledGraph::stream called with mode {:?}", mode);
696 let executor = PregelExecutor::new(self, config);
697 executor.run_stream(input, mode)
698 }
699
700 pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
702 if let Some(cp) = &self.checkpointer {
703 Ok(cp.load(thread_id).await?.map(|c| c.state))
704 } else {
705 Ok(None)
706 }
707 }
708
709 pub async fn update_state(
711 &self,
712 thread_id: &str,
713 updates: impl IntoIterator<Item = (String, serde_json::Value)>,
714 ) -> Result<()> {
715 if let Some(cp) = &self.checkpointer {
716 if let Some(checkpoint) = cp.load(thread_id).await? {
717 let mut state = checkpoint.state;
718 for (key, value) in updates {
719 self.schema.apply_update(&mut state, &key, value);
720 }
721 let new_checkpoint =
722 Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
723 cp.save(&new_checkpoint).await?;
724 }
725 }
726 Ok(())
727 }
728}
729
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    /// A single node's write survives to the final state.
    #[tokio::test]
    async fn test_simple_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let final_state = graph
            .invoke(State::new(), ExecutionConfig::new("test"))
            .await
            .unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }

    /// The second node observes the first node's write.
    #[tokio::test]
    async fn test_sequential_execution() {
        let graph = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |ctx| async move {
                let previous = ctx.get("value").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(previous + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let final_state = graph
            .invoke(State::new(), ExecutionConfig::new("test"))
            .await
            .unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(11)));
    }

    /// Conditional edges route on the value the router wrote.
    #[tokio::test]
    async fn test_conditional_routing() {
        let graph = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |ctx| async move {
                let path = ctx.get("path").and_then(|v| v.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(path)))
            })
            .add_node_fn("path_a", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("route").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        for (path, expected) in [("a", "went to A"), ("b", "went to B")] {
            let mut input = State::new();
            input.insert("path".to_string(), json!(path));
            let final_state = graph.invoke(input, ExecutionConfig::new("test")).await.unwrap();
            assert_eq!(final_state.get("result"), Some(&json!(expected)), "path {path}");
        }
    }

    /// A self-loop terminates once its routing condition sends it to END.
    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let current = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(current + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |state| {
                    let current = state.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                    let target = if current < 5 { "increment" } else { END };
                    target.to_string()
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let final_state = graph
            .invoke(State::new(), ExecutionConfig::new("test"))
            .await
            .unwrap();

        assert_eq!(final_state.get("count"), Some(&json!(5)));
    }

    /// An unconditional self-loop trips the recursion limit.
    #[tokio::test]
    async fn test_recursion_limit() {
        let graph = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let current = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(current + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop")
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let outcome = graph.invoke(State::new(), ExecutionConfig::new("test")).await;

        assert!(
            matches!(outcome, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            outcome
        );
    }
}