Skip to main content

adk_graph/
executor.rs

1//! Pregel-based execution engine for graphs
2//!
3//! Executes graphs using the Pregel model with super-steps.
4
5#[cfg(feature = "node-cache")]
6use crate::cache::{NodeCache, compute_cache_key};
7use crate::deferred::FanInTracker;
8use crate::error::{GraphError, InterruptedExecution, Result};
9use crate::graph::CompiledGraph;
10use crate::interrupt::Interrupt;
11use crate::node::{ExecutionConfig, NodeContext};
12use crate::state::{Checkpoint, State};
13use crate::stream::{StreamEvent, StreamMode};
14use crate::timeout::{ProgressHandle, execute_with_timeout};
15use futures::stream::{self, StreamExt};
16use std::collections::HashMap;
17use std::time::Instant;
18
/// Outcome of a single super-step, as returned by `PregelExecutor::execute_super_step`.
///
/// `Default` yields an empty result: no executed nodes, no interrupt, no events.
#[derive(Default)]
pub struct SuperStepResult {
    /// Names of the nodes recorded as executed in this step. Nodes served from
    /// the node cache are included as well as nodes that actually ran.
    pub executed_nodes: Vec<String>,
    /// `Some` when the step stopped on an interrupt: a configured
    /// `interrupt_before` / `interrupt_after` node, or an interrupt carried by
    /// a node's output.
    pub interrupt: Option<Interrupt>,
    /// Stream events gathered during the step: one `node_end` per recorded node
    /// plus any events the node outputs carried.
    pub events: Vec<StreamEvent>,
}
29
/// Pregel-based executor for graphs.
///
/// Borrows a compiled graph for `'a` and owns the mutable run state (current
/// state, step counter, scheduled nodes and fan-in bookkeeping) for one
/// execution.
pub struct PregelExecutor<'a> {
    /// The compiled graph being executed (nodes, schema, checkpointer, policies).
    graph: &'a CompiledGraph,
    /// Per-run configuration: thread id, recursion limit, optional `resume_from`.
    config: ExecutionConfig,
    /// Current graph state; node updates are folded in through the graph schema.
    state: State,
    /// Index of the current super-step; compared against `config.recursion_limit`.
    step: usize,
    /// Nodes scheduled for the next super-step. Execution ends when this is empty.
    pending_nodes: Vec<String>,
    /// Tracks deferred nodes waiting for all upstream paths to complete.
    pending_deferred: HashMap<String, FanInTracker>,
    /// Tracks when each deferred node first entered the pending state (for fan-in timeout).
    deferred_start_times: HashMap<String, Instant>,
    /// Per-node caches initialized from `CompiledGraph::cache_policies`.
    #[cfg(feature = "node-cache")]
    node_caches: HashMap<String, NodeCache>,
}
45
46impl<'a> PregelExecutor<'a> {
47    /// Create a new executor
48    pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
49        #[cfg(feature = "node-cache")]
50        let node_caches = graph
51            .cache_policies
52            .iter()
53            .map(|(name, policy)| (name.clone(), NodeCache::from_policy(policy)))
54            .collect();
55
56        Self {
57            graph,
58            config,
59            state: State::new(),
60            step: 0,
61            pending_nodes: vec![],
62            pending_deferred: HashMap::new(),
63            deferred_start_times: HashMap::new(),
64            #[cfg(feature = "node-cache")]
65            node_caches,
66        }
67    }
68
69    /// Attempt to resume from an existing checkpoint.
70    ///
71    /// If a checkpoint is found (either by explicit `resume_from` ID or by latest
72    /// checkpoint for the thread), restores state, pending_nodes, and step from it,
73    /// then merges the provided input on top. Returns `true` if resumed.
74    ///
75    /// If no checkpoint is found, returns `false` so the caller can proceed with
76    /// fresh-start logic.
77    async fn try_resume_from_checkpoint(&mut self, input: &State) -> Result<bool> {
78        let checkpoint = if let Some(checkpoint_id) = &self.config.resume_from {
79            // Resume from a specific checkpoint by ID
80            if let Some(cp) = self.graph.checkpointer.as_ref() {
81                cp.load_by_id(checkpoint_id).await?
82            } else {
83                None
84            }
85        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
86            // Try to load the latest checkpoint for this thread
87            cp.load(&self.config.thread_id).await?
88        } else {
89            None
90        };
91
92        if let Some(checkpoint) = checkpoint {
93            // Restore state from checkpoint
94            self.state = checkpoint.state;
95            self.pending_nodes = checkpoint.pending_nodes;
96            self.step = checkpoint.step;
97
98            // Merge input on top of restored state
99            for (key, value) in input {
100                self.graph.schema.apply_update(&mut self.state, key, value.clone());
101            }
102
103            Ok(true)
104        } else {
105            Ok(false)
106        }
107    }
108
109    /// Run the graph to completion
110    pub async fn run(&mut self, input: State) -> Result<State> {
111        // Check for existing checkpoint to resume from
112        let resumed = self.try_resume_from_checkpoint(&input).await?;
113
114        if !resumed {
115            // No checkpoint found — fresh start
116            self.state = self.initialize_state(input).await?;
117            self.pending_nodes = self.graph.get_entry_nodes();
118        }
119
120        // Main execution loop
121        while !self.pending_nodes.is_empty() {
122            // Check recursion limit
123            if self.step >= self.config.recursion_limit {
124                return Err(GraphError::RecursionLimitExceeded(self.step));
125            }
126
127            // Execute super-step
128            let result = self.execute_super_step().await?;
129
130            // Handle interrupts
131            if let Some(interrupt) = result.interrupt {
132                let checkpoint_id = self.save_checkpoint().await?;
133                return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
134                    self.config.thread_id.clone(),
135                    checkpoint_id,
136                    interrupt,
137                    self.state.clone(),
138                    self.step,
139                ))));
140            }
141
142            // Save checkpoint after each step
143            self.save_checkpoint().await?;
144
145            // Check if we're done (all paths led to END)
146            if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
147                let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
148                if next.is_empty() {
149                    break;
150                }
151            }
152
153            // Determine next nodes and apply deferred node filtering
154            let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
155            self.pending_nodes =
156                self.filter_deferred_nodes(next_candidates, &result.executed_nodes)?;
157            self.step += 1;
158        }
159
160        Ok(self.state.clone())
161    }
162
163    /// Run with streaming
164    pub fn run_stream(
165        mut self,
166        input: State,
167        mode: StreamMode,
168    ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
169        async_stream::stream! {
170            // Check for existing checkpoint to resume from
171            let resumed = match self.try_resume_from_checkpoint(&input).await {
172                Ok(r) => r,
173                Err(e) => {
174                    yield Err(e);
175                    return;
176                }
177            };
178
179            if resumed {
180                // Emit a resumed event indicating execution was restored from checkpoint
181                yield Ok(StreamEvent::resumed(self.step, self.pending_nodes.clone()));
182            } else {
183                // No checkpoint found — fresh start
184                match self.initialize_state(input).await {
185                    Ok(state) => self.state = state,
186                    Err(e) => {
187                        yield Err(e);
188                        return;
189                    }
190                }
191                self.pending_nodes = self.graph.get_entry_nodes();
192            }
193
194            // Stream initial state if requested
195            if matches!(mode, StreamMode::Values) {
196                yield Ok(StreamEvent::state(self.state.clone(), self.step));
197            }
198
199            // Main execution loop
200            while !self.pending_nodes.is_empty() {
201                // Check recursion limit
202                if self.step >= self.config.recursion_limit {
203                    yield Err(GraphError::RecursionLimitExceeded(self.step));
204                    return;
205                }
206
207                // Emit node_start events BEFORE execution (in Debug mode)
208                if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
209                    for node_name in &self.pending_nodes {
210                        yield Ok(StreamEvent::node_start(node_name, self.step));
211                    }
212                }
213
214                // For Messages mode, stream from nodes directly
215                if matches!(mode, StreamMode::Messages) {
216                    let mut result = SuperStepResult::default();
217
218                    for node_name in &self.pending_nodes {
219                        if let Some(node) = self.graph.nodes.get(node_name) {
220                            let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
221
222                            // Attach progress handle if idle timeout is configured
223                            let policy = self.graph.timeout_policy_for(node_name).cloned();
224                            if let Some(ref p) = policy {
225                                if p.idle_timeout.is_some() {
226                                    ctx.set_progress_handle(ProgressHandle::new());
227                                }
228                            }
229
230                            let start = std::time::Instant::now();
231
232                            let mut node_stream = node.execute_stream(&ctx);
233                            let mut collected_events = Vec::new();
234
235                            while let Some(event_result) = node_stream.next().await {
236                                match event_result {
237                                    Ok(event) => {
238                                        // Yield Message events immediately
239                                        if matches!(event, StreamEvent::Message { .. }) {
240                                            yield Ok(event.clone());
241                                        }
242                                        collected_events.push(event);
243                                    }
244                                    Err(e) => {
245                                        yield Err(e);
246                                        return;
247                                    }
248                                }
249                            }
250
251                            let duration_ms = start.elapsed().as_millis() as u64;
252                            result.executed_nodes.push(node_name.clone());
253                            result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
254                            result.events.extend(collected_events);
255
256                            // Get output from execute for state updates, with timeout if configured
257                            let output_result = match policy {
258                                Some(ref timeout_policy) => {
259                                    execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
260                                }
261                                None => node.execute(&ctx).await,
262                            };
263                            if let Ok(output) = output_result {
264                                for (key, value) in output.updates {
265                                    self.graph.schema.apply_update(&mut self.state, &key, value);
266                                }
267                            }
268                        }
269                    }
270
271                    // Yield node_end events
272                    for event in &result.events {
273                        if matches!(event, StreamEvent::NodeEnd { .. }) {
274                            yield Ok(event.clone());
275                        }
276                    }
277
278                    self.pending_nodes = {
279                        let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
280                        match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
281                            Ok(nodes) => nodes,
282                            Err(e) => {
283                                yield Err(e);
284                                return;
285                            }
286                        }
287                    };
288                    self.step += 1;
289                    continue;
290                }
291
292                // Execute super-step (non-streaming)
293                let result = match self.execute_super_step().await {
294                    Ok(r) => r,
295                    Err(e) => {
296                        yield Err(e);
297                        return;
298                    }
299                };
300
301                // Yield events based on mode (node_end and custom events)
302                for event in &result.events {
303                    match (&mode, &event) {
304                        // Skip node_start since we already emitted it above
305                        (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
306                        (StreamMode::Custom, _) => yield Ok(event.clone()),
307                        (StreamMode::Debug, _) => yield Ok(event.clone()),
308                        _ => {}
309                    }
310                }
311
312                // Yield state/updates
313                match mode {
314                    StreamMode::Values => {
315                        yield Ok(StreamEvent::state(self.state.clone(), self.step));
316                    }
317                    StreamMode::Updates => {
318                        yield Ok(StreamEvent::step_complete(
319                            self.step,
320                            result.executed_nodes.clone(),
321                        ));
322                    }
323                    _ => {}
324                }
325
326                // Handle interrupts
327                if let Some(interrupt) = result.interrupt {
328                    yield Ok(StreamEvent::interrupted(
329                        result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
330                        &interrupt.to_string(),
331                    ));
332                    return;
333                }
334
335                // Check if done
336                if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
337                    let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
338                    if next.is_empty() {
339                        break;
340                    }
341                }
342
343                self.pending_nodes = {
344                    let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
345                    match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
346                        Ok(nodes) => nodes,
347                        Err(e) => {
348                            yield Err(e);
349                            return;
350                        }
351                    }
352                };
353                self.step += 1;
354            }
355
356            yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
357        }
358    }
359
360    /// Filter deferred nodes from the next candidates.
361    ///
362    /// For each candidate node that is configured as deferred, check whether all
363    /// upstream paths have completed. If not, hold the node in `pending_deferred`
364    /// and record the outputs from the just-executed nodes. If all upstream paths
365    /// have completed, inject the merged output into state and allow the node to
366    /// proceed.
367    ///
368    /// If a deferred node has a `fan_in_timeout` configured and the timeout has
369    /// elapsed:
370    /// - If at least one upstream path has completed, proceed with partial results.
371    /// - If zero upstream paths have completed, return `GraphError::FanInTimedOut`.
372    fn filter_deferred_nodes(
373        &mut self,
374        candidates: Vec<String>,
375        executed_nodes: &[String],
376    ) -> Result<Vec<String>> {
377        let mut ready_nodes = Vec::new();
378
379        for candidate in candidates {
380            if let Some(config) = self.graph.deferred_configs.get(&candidate) {
381                // This is a deferred node — check if all upstream paths are done
382                let upstream = self.graph.get_upstream_nodes(&candidate);
383
384                // Get or create the tracker for this deferred node
385                let tracker = self.pending_deferred.entry(candidate.clone()).or_insert_with(|| {
386                    let sources: Vec<&str> = upstream.iter().map(|s| s.as_str()).collect();
387                    FanInTracker::new(sources)
388                });
389
390                // Record the start time if this is the first time we see this deferred node
391                self.deferred_start_times.entry(candidate.clone()).or_insert_with(Instant::now);
392
393                // Record outputs from the just-executed nodes that are upstream of this deferred node
394                for executed in executed_nodes {
395                    if upstream.contains(executed) {
396                        // Use the current state as the output representation for this upstream node.
397                        // We capture a snapshot of the state that this upstream node contributed to.
398                        let output = self.state.get(executed).cloned().unwrap_or_else(|| {
399                            // If no state key matches the node name, capture the full state
400                            serde_json::Value::Object(
401                                self.state.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
402                            )
403                        });
404                        tracker.record(executed, output);
405                    }
406                }
407
408                if tracker.is_ready() {
409                    // All upstream paths have completed — merge and inject into state
410                    let merged = tracker.merge(&config.merge_strategy);
411                    let fan_in_key = format!("{candidate}_fan_in");
412                    self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
413
414                    // Remove from pending_deferred and start times since it's now ready
415                    self.pending_deferred.remove(&candidate);
416                    self.deferred_start_times.remove(&candidate);
417                    ready_nodes.push(candidate);
418                } else if let Some(timeout_duration) = config.fan_in_timeout {
419                    // Check if the fan-in timeout has elapsed
420                    let start_time = self.deferred_start_times[&candidate];
421                    if start_time.elapsed() >= timeout_duration {
422                        let received = tracker.received_count();
423                        let expected = tracker.expected_count();
424
425                        if received > 0 {
426                            // Proceed with partial results
427                            tracing::warn!(
428                                node = %candidate,
429                                received,
430                                expected,
431                                "fan-in timeout expired, proceeding with partial results"
432                            );
433                            let merged = tracker.merge(&config.merge_strategy);
434                            let fan_in_key = format!("{candidate}_fan_in");
435                            self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
436
437                            // Clean up tracking state
438                            self.pending_deferred.remove(&candidate);
439                            self.deferred_start_times.remove(&candidate);
440                            ready_nodes.push(candidate);
441                        } else {
442                            // Zero upstream paths completed — return error
443                            self.pending_deferred.remove(&candidate);
444                            self.deferred_start_times.remove(&candidate);
445                            return Err(GraphError::FanInTimedOut {
446                                node: candidate,
447                                received,
448                                expected,
449                            });
450                        }
451                    }
452                }
453                // If not ready and no timeout (or timeout not yet elapsed), the node stays
454                // in pending_deferred and is NOT added to ready_nodes
455            } else {
456                // Not a deferred node — schedule normally
457                ready_nodes.push(candidate);
458            }
459        }
460
461        Ok(ready_nodes)
462    }
463
464    /// Initialize state from input and/or checkpoint
465    async fn initialize_state(&self, input: State) -> Result<State> {
466        // Start with schema defaults
467        let mut state = self.graph.schema.initialize_state();
468
469        // If resuming from checkpoint, load it
470        if let Some(checkpoint_id) = &self.config.resume_from {
471            if let Some(cp) = self.graph.checkpointer.as_ref() {
472                if let Some(checkpoint) = cp.load_by_id(checkpoint_id).await? {
473                    state = checkpoint.state;
474                }
475            }
476        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
477            // Try to load latest checkpoint for thread
478            if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
479                state = checkpoint.state;
480            }
481        }
482
483        // Merge input into state
484        for (key, value) in input {
485            self.graph.schema.apply_update(&mut state, &key, value);
486        }
487
488        Ok(state)
489    }
490
491    /// Execute one super-step (plan -> execute -> update)
492    async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
493        let mut result = SuperStepResult::default();
494
495        // Check for interrupt_before
496        for node_name in &self.pending_nodes {
497            if self.graph.interrupt_before.contains(node_name) {
498                return Ok(SuperStepResult {
499                    interrupt: Some(Interrupt::Before(node_name.clone())),
500                    ..Default::default()
501                });
502            }
503        }
504
505        // --- Node cache: check for cache hits before executing ---
506        #[cfg(feature = "node-cache")]
507        let mut cached_results: HashMap<String, serde_json::Value> = HashMap::new();
508        #[cfg(feature = "node-cache")]
509        let mut nodes_to_execute: Vec<String> = Vec::new();
510
511        #[cfg(feature = "node-cache")]
512        {
513            for node_name in &self.pending_nodes {
514                if let Some(cache) = self.node_caches.get(node_name) {
515                    let cache_key = compute_cache_key(node_name, &self.state);
516                    let cached_value = cache.get(&cache_key).await;
517                    tracing::debug!(
518                        node = %node_name,
519                        cache_hit = cached_value.is_some(),
520                        cache_key = %cache_key,
521                        "node cache lookup"
522                    );
523                    if let Some(value) = cached_value {
524                        // Cache hit — store the cached result for later application
525                        cached_results.insert(node_name.clone(), value);
526                    } else {
527                        // Cache miss — node needs execution
528                        nodes_to_execute.push(node_name.clone());
529                    }
530                } else {
531                    // No cache configured — node needs execution
532                    nodes_to_execute.push(node_name.clone());
533                }
534            }
535        }
536
537        // Apply cached results immediately
538        #[cfg(feature = "node-cache")]
539        {
540            for (node_name, cached_value) in &cached_results {
541                result.executed_nodes.push(node_name.clone());
542                result.events.push(StreamEvent::node_end(node_name, self.step, 0));
543
544                // Reconstruct updates from the cached JSON value (a map of key -> value)
545                if let Some(updates_map) = cached_value.as_object() {
546                    for (key, value) in updates_map {
547                        self.graph.schema.apply_update(&mut self.state, key, value.clone());
548                    }
549                }
550            }
551        }
552
553        // Determine which nodes to execute (all if cache feature is disabled)
554        #[cfg(feature = "node-cache")]
555        let pending_for_execution = &nodes_to_execute;
556        #[cfg(not(feature = "node-cache"))]
557        let pending_for_execution = &self.pending_nodes;
558
559        // Execute all pending nodes in parallel
560        let nodes: Vec<_> = pending_for_execution
561            .iter()
562            .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
563            .collect();
564
565        // Look up timeout policies for each node before spawning futures
566        let timeout_policies: Vec<_> =
567            nodes.iter().map(|(name, _)| self.graph.timeout_policy_for(name).cloned()).collect();
568
569        let futures: Vec<_> = nodes
570            .into_iter()
571            .zip(timeout_policies)
572            .map(|((name, node), policy)| {
573                let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
574
575                // Attach a ProgressHandle when idle timeout is configured
576                if let Some(ref p) = policy {
577                    if p.idle_timeout.is_some() {
578                        ctx.set_progress_handle(ProgressHandle::new());
579                    }
580                }
581
582                let step = self.step;
583                async move {
584                    let start = Instant::now();
585                    let output = match policy {
586                        Some(ref timeout_policy) => {
587                            execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
588                        }
589                        None => node.execute(&ctx).await,
590                    };
591                    let duration_ms = start.elapsed().as_millis() as u64;
592                    (name, output, duration_ms, step)
593                }
594            })
595            .collect();
596
597        let outputs: Vec<_> =
598            stream::iter(futures).buffer_unordered(pending_for_execution.len()).collect().await;
599
600        // Collect all updates and check for errors/interrupts
601        let mut all_updates = Vec::new();
602
603        for (node_name, output_result, duration_ms, step) in outputs {
604            result.executed_nodes.push(node_name.clone());
605            result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
606
607            match output_result {
608                Ok(output) => {
609                    // Check for dynamic interrupt
610                    if let Some(interrupt) = output.interrupt {
611                        return Ok(SuperStepResult {
612                            interrupt: Some(interrupt),
613                            executed_nodes: result.executed_nodes,
614                            events: result.events,
615                        });
616                    }
617
618                    // Collect custom events
619                    result.events.extend(output.events);
620
621                    // Store result in cache on miss
622                    #[cfg(feature = "node-cache")]
623                    {
624                        if let Some(cache) = self.node_caches.get(&node_name) {
625                            let cache_key = compute_cache_key(&node_name, &self.state);
626                            let updates_value = serde_json::to_value(&output.updates)
627                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
628                            let ttl = self.graph.cache_policies.get(&node_name).and_then(|p| p.ttl);
629                            cache.set(&cache_key, updates_value, ttl).await;
630                        }
631                    }
632
633                    // Collect updates
634                    all_updates.push(output.updates);
635                }
636                Err(e) => {
637                    return Err(GraphError::NodeExecutionFailed {
638                        node: node_name,
639                        message: e.to_string(),
640                    });
641                }
642            }
643        }
644
645        // Apply all updates atomically using reducers
646        for updates in all_updates {
647            for (key, value) in updates {
648                self.graph.schema.apply_update(&mut self.state, &key, value);
649            }
650        }
651
652        // Check for interrupt_after
653        for node_name in &result.executed_nodes {
654            if self.graph.interrupt_after.contains(node_name) {
655                return Ok(SuperStepResult {
656                    interrupt: Some(Interrupt::After(node_name.clone())),
657                    ..result
658                });
659            }
660        }
661
662        Ok(result)
663    }
664
665    /// Save a checkpoint
666    async fn save_checkpoint(&self) -> Result<String> {
667        if let Some(cp) = &self.graph.checkpointer {
668            let checkpoint = Checkpoint::new(
669                &self.config.thread_id,
670                self.state.clone(),
671                self.step,
672                self.pending_nodes.clone(),
673            );
674            return cp.save(&checkpoint).await;
675        }
676        Ok(String::new())
677    }
678}
679
680/// Convenience methods for CompiledGraph
681impl CompiledGraph {
682    /// Execute the graph synchronously
683    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
684        let mut executor = PregelExecutor::new(self, config);
685        executor.run(input).await
686    }
687
688    /// Execute with streaming
689    pub fn stream(
690        &self,
691        input: State,
692        config: ExecutionConfig,
693        mode: StreamMode,
694    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
695        tracing::debug!("CompiledGraph::stream called with mode {:?}", mode);
696        let executor = PregelExecutor::new(self, config);
697        executor.run_stream(input, mode)
698    }
699
700    /// Get current state for a thread
701    pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
702        if let Some(cp) = &self.checkpointer {
703            Ok(cp.load(thread_id).await?.map(|c| c.state))
704        } else {
705            Ok(None)
706        }
707    }
708
709    /// Update state for a thread (for human-in-the-loop)
710    pub async fn update_state(
711        &self,
712        thread_id: &str,
713        updates: impl IntoIterator<Item = (String, serde_json::Value)>,
714    ) -> Result<()> {
715        if let Some(cp) = &self.checkpointer {
716            if let Some(checkpoint) = cp.load(thread_id).await? {
717                let mut state = checkpoint.state;
718                for (key, value) in updates {
719                    self.schema.apply_update(&mut state, &key, value);
720                }
721                let new_checkpoint =
722                    Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
723                cp.save(&new_checkpoint).await?;
724            }
725        }
726        Ok(())
727    }
728}
729
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    /// A single node wired between START and END writes its update into the final state.
    #[tokio::test]
    async fn test_simple_execution() {
        let compiled = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_ctx| async {
                let output = NodeOutput::new().with_update("value", json!(42));
                Ok(output)
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let final_state = compiled.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        let expected = json!(42);
        assert_eq!(final_state.get("value"), Some(&expected));
    }

    /// Two chained nodes: the second reads the value written by the first and adds 10.
    #[tokio::test]
    async fn test_sequential_execution() {
        let compiled = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_ctx| async {
                let output = NodeOutput::new().with_update("value", json!(1));
                Ok(output)
            })
            .add_node_fn("step2", |ctx| async move {
                let previous = ctx.get("value").and_then(|raw| raw.as_i64()).unwrap_or(0);
                let output = NodeOutput::new().with_update("value", json!(previous + 10));
                Ok(output)
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let final_state = compiled.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        let expected = json!(11);
        assert_eq!(final_state.get("value"), Some(&expected));
    }

    /// The router copies the input `path` into `route`; the conditional edge then
    /// selects `path_a` or `path_b` accordingly.
    #[tokio::test]
    async fn test_conditional_routing() {
        let compiled = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |ctx| async move {
                let chosen = ctx.get("path").and_then(|raw| raw.as_str()).unwrap_or("a");
                let output = NodeOutput::new().with_update("route", json!(chosen));
                Ok(output)
            })
            .add_node_fn("path_a", |_ctx| async {
                let output = NodeOutput::new().with_update("result", json!("went to A"));
                Ok(output)
            })
            .add_node_fn("path_b", |_ctx| async {
                let output = NodeOutput::new().with_update("result", json!("went to B"));
                Ok(output)
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| {
                    let route = state.get("route").and_then(|raw| raw.as_str());
                    route.unwrap_or(END).to_string()
                },
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        // Path A first, then path B, each with a fresh input state.
        for (choice, expected_result) in [("a", "went to A"), ("b", "went to B")] {
            let mut input = State::new();
            input.insert("path".to_string(), json!(choice));

            let final_state = compiled.invoke(input, ExecutionConfig::new("test")).await.unwrap();

            let expected = json!(expected_result);
            assert_eq!(final_state.get("result"), Some(&expected));
        }
    }

    /// A self-looping node keeps incrementing `count` until the router sends it to END.
    #[tokio::test]
    async fn test_cycle_with_limit() {
        let compiled = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |ctx| async move {
                let seen = ctx.get("count").and_then(|raw| raw.as_i64()).unwrap_or(0);
                let output = NodeOutput::new().with_update("count", json!(seen + 1));
                Ok(output)
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |state| {
                    let seen = state.get("count").and_then(|raw| raw.as_i64()).unwrap_or(0);
                    // Stop once the counter has reached 5; otherwise loop again.
                    if seen >= 5 { END.to_string() } else { "increment".to_string() }
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let final_state = compiled.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        let expected = json!(5);
        assert_eq!(final_state.get("count"), Some(&expected));
    }

    /// An unconditional self-loop must be stopped by the configured recursion limit.
    #[tokio::test]
    async fn test_recursion_limit() {
        let compiled = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |ctx| async move {
                let seen = ctx.get("count").and_then(|raw| raw.as_i64()).unwrap_or(0);
                let output = NodeOutput::new().with_update("count", json!(seen + 1));
                Ok(output)
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop") // Infinite loop
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let outcome = compiled.invoke(State::new(), ExecutionConfig::new("test")).await;

        // The recursion limit check happens when step >= limit, so it will exceed at step 10
        let hit_limit = matches!(outcome, Err(GraphError::RecursionLimitExceeded(_)));
        assert!(hit_limit, "Expected RecursionLimitExceeded error, got: {:?}", outcome);
    }
}