Skip to main content

adk_graph/
executor.rs

1//! Pregel-based execution engine for graphs
2//!
3//! Executes graphs using the Pregel model with super-steps.
4
5#[cfg(feature = "node-cache")]
6use crate::cache::{NodeCache, compute_cache_key};
7use crate::deferred::FanInTracker;
8use crate::error::{GraphError, InterruptedExecution, Result};
9use crate::graph::CompiledGraph;
10use crate::interrupt::Interrupt;
11use crate::node::{ExecutionConfig, NodeContext};
12use crate::state::{Checkpoint, State};
13use crate::stream::{StreamEvent, StreamMode};
14use crate::timeout::{ProgressHandle, execute_with_timeout};
15use futures::stream::{self, StreamExt};
16use std::collections::HashMap;
17use std::time::Instant;
18
/// Result of a single super-step execution.
///
/// Built by `PregelExecutor::execute_super_step` and inspected by the run loops
/// to decide whether to stop, interrupt, or schedule further nodes.
#[derive(Default)]
pub struct SuperStepResult {
    /// Names of the nodes handled in this step, including nodes served from the node cache
    pub executed_nodes: Vec<String>,
    /// Interrupt raised during the step, if any (`None` when the step ran through)
    pub interrupt: Option<Interrupt>,
    /// Stream events collected during the step: node-end markers plus events emitted by nodes
    pub events: Vec<StreamEvent>,
}
29
/// Pregel-based executor for graphs.
///
/// Borrows a compiled graph and drives it one super-step at a time, carrying the
/// evolving state, the step counter and the set of nodes scheduled to run next.
pub struct PregelExecutor<'a> {
    /// The compiled graph being executed (nodes, schema, checkpointer, policies).
    graph: &'a CompiledGraph,
    /// Per-run configuration (thread id, recursion limit, optional checkpoint id to resume from).
    config: ExecutionConfig,
    /// Current graph state; updates are applied to it through the graph schema.
    state: State,
    /// Index of the super-step that runs next; compared against the recursion limit.
    step: usize,
    /// Nodes scheduled for the next super-step; execution stops when this is empty.
    pending_nodes: Vec<String>,
    /// Tracks deferred nodes waiting for all upstream paths to complete.
    pending_deferred: HashMap<String, FanInTracker>,
    /// Tracks when each deferred node first entered the pending state (for fan-in timeout).
    deferred_start_times: HashMap<String, Instant>,
    /// Per-node caches initialized from `CompiledGraph::cache_policies`.
    #[cfg(feature = "node-cache")]
    node_caches: HashMap<String, NodeCache>,
}
45
46impl<'a> PregelExecutor<'a> {
47    /// Create a new executor
48    pub fn new(graph: &'a CompiledGraph, config: ExecutionConfig) -> Self {
49        #[cfg(feature = "node-cache")]
50        let node_caches = graph
51            .cache_policies
52            .iter()
53            .map(|(name, policy)| (name.clone(), NodeCache::from_policy(policy)))
54            .collect();
55
56        Self {
57            graph,
58            config,
59            state: State::new(),
60            step: 0,
61            pending_nodes: vec![],
62            pending_deferred: HashMap::new(),
63            deferred_start_times: HashMap::new(),
64            #[cfg(feature = "node-cache")]
65            node_caches,
66        }
67    }
68
69    /// Attempt to resume from an existing checkpoint.
70    ///
71    /// If a checkpoint is found (either by explicit `resume_from` ID or by latest
72    /// checkpoint for the thread), restores state, pending_nodes, and step from it,
73    /// then merges the provided input on top. Returns `true` if resumed.
74    ///
75    /// If no checkpoint is found, returns `false` so the caller can proceed with
76    /// fresh-start logic.
77    async fn try_resume_from_checkpoint(&mut self, input: &State) -> Result<bool> {
78        let checkpoint = if let Some(checkpoint_id) = &self.config.resume_from {
79            // Resume from a specific checkpoint by ID
80            if let Some(cp) = self.graph.checkpointer.as_ref() {
81                cp.load_by_id(checkpoint_id).await?
82            } else {
83                None
84            }
85        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
86            // Try to load the latest checkpoint for this thread
87            cp.load(&self.config.thread_id).await?
88        } else {
89            None
90        };
91
92        if let Some(checkpoint) = checkpoint {
93            // Restore state from checkpoint
94            self.state = checkpoint.state;
95            self.pending_nodes = checkpoint.pending_nodes;
96            self.step = checkpoint.step;
97
98            // Merge input on top of restored state
99            for (key, value) in input {
100                self.graph.schema.apply_update(&mut self.state, key, value.clone());
101            }
102
103            Ok(true)
104        } else {
105            Ok(false)
106        }
107    }
108
109    /// Run the graph to completion
110    pub async fn run(&mut self, input: State) -> Result<State> {
111        // Check for existing checkpoint to resume from
112        let resumed = self.try_resume_from_checkpoint(&input).await?;
113
114        if !resumed {
115            // No checkpoint found — fresh start
116            self.state = self.initialize_state(input).await?;
117            self.pending_nodes = self.graph.get_entry_nodes();
118        }
119
120        // Main execution loop
121        while !self.pending_nodes.is_empty() {
122            // Check recursion limit
123            if self.step >= self.config.recursion_limit {
124                return Err(GraphError::RecursionLimitExceeded(self.step));
125            }
126
127            // Execute super-step
128            let result = self.execute_super_step().await?;
129
130            // Handle interrupts
131            if let Some(interrupt) = result.interrupt {
132                let checkpoint_id = self.save_checkpoint().await?;
133                return Err(GraphError::Interrupted(Box::new(InterruptedExecution::new(
134                    self.config.thread_id.clone(),
135                    checkpoint_id,
136                    interrupt,
137                    self.state.clone(),
138                    self.step,
139                ))));
140            }
141
142            // Save checkpoint after each step
143            self.save_checkpoint().await?;
144
145            // Check if we're done (all paths led to END)
146            if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
147                let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
148                if next.is_empty() {
149                    break;
150                }
151            }
152
153            // Determine next nodes and apply deferred node filtering
154            let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
155            self.pending_nodes =
156                self.filter_deferred_nodes(next_candidates, &result.executed_nodes)?;
157            self.step += 1;
158        }
159
160        Ok(self.state.clone())
161    }
162
    /// Run the graph as a stream of [`StreamEvent`]s.
    ///
    /// Consumes the executor. Like `run`, it first tries to resume from a
    /// checkpoint and otherwise starts fresh from the entry nodes. What gets
    /// yielded depends on `mode`:
    ///
    /// - `Values`: the state before the first step and after every step.
    /// - `Updates`: a step-complete event per step listing the executed nodes.
    /// - `Debug` / `Custom`: node-start events before each step, then the events
    ///   collected by the step.
    /// - `Messages`: node-start events, `Message` events forwarded as each node
    ///   streams, then node-end events.
    ///
    /// Errors are yielded as `Err` items and end the stream. A normal finish
    /// yields a final `done` event.
    ///
    /// NOTE(review): unlike `run`, this method never calls `save_checkpoint` —
    /// confirm whether streamed runs are meant to be resumable.
    pub fn run_stream(
        mut self,
        input: State,
        mode: StreamMode,
    ) -> impl futures::Stream<Item = Result<StreamEvent>> + 'a {
        async_stream::stream! {
            // Check for existing checkpoint to resume from
            let resumed = match self.try_resume_from_checkpoint(&input).await {
                Ok(r) => r,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            if resumed {
                // Emit a resumed event indicating execution was restored from checkpoint
                yield Ok(StreamEvent::resumed(self.step, self.pending_nodes.clone()));
            } else {
                // No checkpoint found — fresh start
                match self.initialize_state(input).await {
                    Ok(state) => self.state = state,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
                self.pending_nodes = self.graph.get_entry_nodes();
            }

            // Stream initial state if requested
            if matches!(mode, StreamMode::Values) {
                yield Ok(StreamEvent::state(self.state.clone(), self.step));
            }

            // Main execution loop
            while !self.pending_nodes.is_empty() {
                // Check recursion limit
                if self.step >= self.config.recursion_limit {
                    yield Err(GraphError::RecursionLimitExceeded(self.step));
                    return;
                }

                // Emit node_start events BEFORE execution (Debug, Custom and Messages modes)
                if matches!(mode, StreamMode::Debug | StreamMode::Custom | StreamMode::Messages) {
                    for node_name in &self.pending_nodes {
                        yield Ok(StreamEvent::node_start(node_name, self.step));
                    }
                }

                // For Messages mode, stream from nodes directly.
                // Nodes are processed one after another here, not in parallel.
                // NOTE(review): this branch has no interrupt handling and no
                // `leads_to_end` check; it stops only when no next nodes remain.
                if matches!(mode, StreamMode::Messages) {
                    let mut result = SuperStepResult::default();

                    for node_name in &self.pending_nodes {
                        if let Some(node) = self.graph.nodes.get(node_name) {
                            let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);

                            // Attach progress handle if idle timeout is configured
                            let policy = self.graph.timeout_policy_for(node_name).cloned();
                            if let Some(ref p) = policy
                                && p.idle_timeout.is_some() {
                                    ctx.set_progress_handle(ProgressHandle::new());
                                }

                            let start = std::time::Instant::now();

                            let mut node_stream = node.execute_stream(&ctx);
                            let mut collected_events = Vec::new();

                            while let Some(event_result) = node_stream.next().await {
                                match event_result {
                                    Ok(event) => {
                                        // Yield Message events immediately
                                        if matches!(event, StreamEvent::Message { .. }) {
                                            yield Ok(event.clone());
                                        }
                                        collected_events.push(event);
                                    }
                                    Err(e) => {
                                        yield Err(e);
                                        return;
                                    }
                                }
                            }

                            // The measured duration covers only the streaming pass above.
                            let duration_ms = start.elapsed().as_millis() as u64;
                            result.executed_nodes.push(node_name.clone());
                            result.events.push(StreamEvent::node_end(node_name, self.step, duration_ms));
                            result.events.extend(collected_events);

                            // Get output from execute for state updates, with timeout if configured.
                            // NOTE(review): the node is invoked a second time here, after
                            // `execute_stream` already ran it — confirm this is intended.
                            let output_result = match policy {
                                Some(ref timeout_policy) => {
                                    execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
                                }
                                None => node.execute(&ctx).await,
                            };
                            // NOTE(review): an `Err` from this second run (including a
                            // timeout) is silently dropped and the state is left untouched.
                            if let Ok(output) = output_result {
                                for (key, value) in output.updates {
                                    self.graph.schema.apply_update(&mut self.state, &key, value);
                                }
                            }
                        }
                    }

                    // Yield node_end events
                    for event in &result.events {
                        if matches!(event, StreamEvent::NodeEnd { .. }) {
                            yield Ok(event.clone());
                        }
                    }

                    // Schedule successors, holding back deferred nodes that are not ready yet
                    self.pending_nodes = {
                        let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                        match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
                            Ok(nodes) => nodes,
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        }
                    };
                    self.step += 1;
                    continue;
                }

                // Execute super-step (non-streaming)
                let result = match self.execute_super_step().await {
                    Ok(r) => r,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                };

                // Yield events based on mode (node_end and custom events)
                for event in &result.events {
                    match (&mode, &event) {
                        // Skip node_start since we already emitted it above
                        (StreamMode::Custom | StreamMode::Debug, StreamEvent::NodeStart { .. }) => {}
                        (StreamMode::Custom, _) => yield Ok(event.clone()),
                        (StreamMode::Debug, _) => yield Ok(event.clone()),
                        _ => {}
                    }
                }

                // Yield state/updates
                match mode {
                    StreamMode::Values => {
                        yield Ok(StreamEvent::state(self.state.clone(), self.step));
                    }
                    StreamMode::Updates => {
                        yield Ok(StreamEvent::step_complete(
                            self.step,
                            result.executed_nodes.clone(),
                        ));
                    }
                    _ => {}
                }

                // Handle interrupts: report the first executed node (or "unknown"
                // when none ran) and end the stream without a `done` event
                if let Some(interrupt) = result.interrupt {
                    yield Ok(StreamEvent::interrupted(
                        result.executed_nodes.first().map(|s| s.as_str()).unwrap_or("unknown"),
                        &interrupt.to_string(),
                    ));
                    return;
                }

                // Check if done
                if self.graph.leads_to_end(&result.executed_nodes, &self.state) {
                    let next = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                    if next.is_empty() {
                        break;
                    }
                }

                // Schedule successors, holding back deferred nodes that are not ready yet
                self.pending_nodes = {
                    let next_candidates = self.graph.get_next_nodes(&result.executed_nodes, &self.state);
                    match self.filter_deferred_nodes(next_candidates, &result.executed_nodes) {
                        Ok(nodes) => nodes,
                        Err(e) => {
                            yield Err(e);
                            return;
                        }
                    }
                };
                self.step += 1;
            }

            // Final event: the finished state, reported with `step + 1`
            yield Ok(StreamEvent::done(self.state.clone(), self.step + 1));
        }
    }
358
359    /// Filter deferred nodes from the next candidates.
360    ///
361    /// For each candidate node that is configured as deferred, check whether all
362    /// upstream paths have completed. If not, hold the node in `pending_deferred`
363    /// and record the outputs from the just-executed nodes. If all upstream paths
364    /// have completed, inject the merged output into state and allow the node to
365    /// proceed.
366    ///
367    /// If a deferred node has a `fan_in_timeout` configured and the timeout has
368    /// elapsed:
369    /// - If at least one upstream path has completed, proceed with partial results.
370    /// - If zero upstream paths have completed, return `GraphError::FanInTimedOut`.
371    fn filter_deferred_nodes(
372        &mut self,
373        candidates: Vec<String>,
374        executed_nodes: &[String],
375    ) -> Result<Vec<String>> {
376        let mut ready_nodes = Vec::new();
377
378        for candidate in candidates {
379            if let Some(config) = self.graph.deferred_configs.get(&candidate) {
380                // This is a deferred node — check if all upstream paths are done
381                let upstream = self.graph.get_upstream_nodes(&candidate);
382
383                // Get or create the tracker for this deferred node
384                let tracker = self.pending_deferred.entry(candidate.clone()).or_insert_with(|| {
385                    let sources: Vec<&str> = upstream.iter().map(|s| s.as_str()).collect();
386                    FanInTracker::new(sources)
387                });
388
389                // Record the start time if this is the first time we see this deferred node
390                self.deferred_start_times.entry(candidate.clone()).or_insert_with(Instant::now);
391
392                // Record outputs from the just-executed nodes that are upstream of this deferred node
393                for executed in executed_nodes {
394                    if upstream.contains(executed) {
395                        // Use the current state as the output representation for this upstream node.
396                        // We capture a snapshot of the state that this upstream node contributed to.
397                        let output = self.state.get(executed).cloned().unwrap_or_else(|| {
398                            // If no state key matches the node name, capture the full state
399                            serde_json::Value::Object(
400                                self.state.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
401                            )
402                        });
403                        tracker.record(executed, output);
404                    }
405                }
406
407                if tracker.is_ready() {
408                    // All upstream paths have completed — merge and inject into state
409                    let merged = tracker.merge(&config.merge_strategy);
410                    let fan_in_key = format!("{candidate}_fan_in");
411                    self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
412
413                    // Remove from pending_deferred and start times since it's now ready
414                    self.pending_deferred.remove(&candidate);
415                    self.deferred_start_times.remove(&candidate);
416                    ready_nodes.push(candidate);
417                } else if let Some(timeout_duration) = config.fan_in_timeout {
418                    // Check if the fan-in timeout has elapsed
419                    let start_time = self.deferred_start_times[&candidate];
420                    if start_time.elapsed() >= timeout_duration {
421                        let received = tracker.received_count();
422                        let expected = tracker.expected_count();
423
424                        if received > 0 {
425                            // Proceed with partial results
426                            tracing::warn!(
427                                node = %candidate,
428                                received,
429                                expected,
430                                "fan-in timeout expired, proceeding with partial results"
431                            );
432                            let merged = tracker.merge(&config.merge_strategy);
433                            let fan_in_key = format!("{candidate}_fan_in");
434                            self.graph.schema.apply_update(&mut self.state, &fan_in_key, merged);
435
436                            // Clean up tracking state
437                            self.pending_deferred.remove(&candidate);
438                            self.deferred_start_times.remove(&candidate);
439                            ready_nodes.push(candidate);
440                        } else {
441                            // Zero upstream paths completed — return error
442                            self.pending_deferred.remove(&candidate);
443                            self.deferred_start_times.remove(&candidate);
444                            return Err(GraphError::FanInTimedOut {
445                                node: candidate,
446                                received,
447                                expected,
448                            });
449                        }
450                    }
451                }
452                // If not ready and no timeout (or timeout not yet elapsed), the node stays
453                // in pending_deferred and is NOT added to ready_nodes
454            } else {
455                // Not a deferred node — schedule normally
456                ready_nodes.push(candidate);
457            }
458        }
459
460        Ok(ready_nodes)
461    }
462
463    /// Initialize state from input and/or checkpoint
464    async fn initialize_state(&self, input: State) -> Result<State> {
465        // Start with schema defaults
466        let mut state = self.graph.schema.initialize_state();
467
468        // If resuming from checkpoint, load it
469        if let Some(checkpoint_id) = &self.config.resume_from {
470            if let Some(cp) = self.graph.checkpointer.as_ref()
471                && let Some(checkpoint) = cp.load_by_id(checkpoint_id).await?
472            {
473                state = checkpoint.state;
474            }
475        } else if let Some(cp) = self.graph.checkpointer.as_ref() {
476            // Try to load latest checkpoint for thread
477            if let Some(checkpoint) = cp.load(&self.config.thread_id).await? {
478                state = checkpoint.state;
479            }
480        }
481
482        // Merge input into state
483        for (key, value) in input {
484            self.graph.schema.apply_update(&mut state, &key, value);
485        }
486
487        Ok(state)
488    }
489
490    /// Execute one super-step (plan -> execute -> update)
491    async fn execute_super_step(&mut self) -> Result<SuperStepResult> {
492        let mut result = SuperStepResult::default();
493
494        // Check for interrupt_before
495        for node_name in &self.pending_nodes {
496            if self.graph.interrupt_before.contains(node_name) {
497                return Ok(SuperStepResult {
498                    interrupt: Some(Interrupt::Before(node_name.clone())),
499                    ..Default::default()
500                });
501            }
502        }
503
504        // --- Node cache: check for cache hits before executing ---
505        #[cfg(feature = "node-cache")]
506        let mut cached_results: HashMap<String, serde_json::Value> = HashMap::new();
507        #[cfg(feature = "node-cache")]
508        let mut nodes_to_execute: Vec<String> = Vec::new();
509
510        #[cfg(feature = "node-cache")]
511        {
512            for node_name in &self.pending_nodes {
513                if let Some(cache) = self.node_caches.get(node_name) {
514                    let cache_key = compute_cache_key(node_name, &self.state);
515                    let cached_value = cache.get(&cache_key).await;
516                    tracing::debug!(
517                        node = %node_name,
518                        cache_hit = cached_value.is_some(),
519                        cache_key = %cache_key,
520                        "node cache lookup"
521                    );
522                    if let Some(value) = cached_value {
523                        // Cache hit — store the cached result for later application
524                        cached_results.insert(node_name.clone(), value);
525                    } else {
526                        // Cache miss — node needs execution
527                        nodes_to_execute.push(node_name.clone());
528                    }
529                } else {
530                    // No cache configured — node needs execution
531                    nodes_to_execute.push(node_name.clone());
532                }
533            }
534        }
535
536        // Apply cached results immediately
537        #[cfg(feature = "node-cache")]
538        {
539            for (node_name, cached_value) in &cached_results {
540                result.executed_nodes.push(node_name.clone());
541                result.events.push(StreamEvent::node_end(node_name, self.step, 0));
542
543                // Reconstruct updates from the cached JSON value (a map of key -> value)
544                if let Some(updates_map) = cached_value.as_object() {
545                    for (key, value) in updates_map {
546                        self.graph.schema.apply_update(&mut self.state, key, value.clone());
547                    }
548                }
549            }
550        }
551
552        // Determine which nodes to execute (all if cache feature is disabled)
553        #[cfg(feature = "node-cache")]
554        let pending_for_execution = &nodes_to_execute;
555        #[cfg(not(feature = "node-cache"))]
556        let pending_for_execution = &self.pending_nodes;
557
558        // Execute all pending nodes in parallel
559        let nodes: Vec<_> = pending_for_execution
560            .iter()
561            .filter_map(|name| self.graph.nodes.get(name).map(|n| (name.clone(), n.clone())))
562            .collect();
563
564        // Look up timeout policies for each node before spawning futures
565        let timeout_policies: Vec<_> =
566            nodes.iter().map(|(name, _)| self.graph.timeout_policy_for(name).cloned()).collect();
567
568        let futures: Vec<_> = nodes
569            .into_iter()
570            .zip(timeout_policies)
571            .map(|((name, node), policy)| {
572                let mut ctx = NodeContext::new(self.state.clone(), self.config.clone(), self.step);
573
574                // Attach a ProgressHandle when idle timeout is configured
575                if let Some(ref p) = policy
576                    && p.idle_timeout.is_some()
577                {
578                    ctx.set_progress_handle(ProgressHandle::new());
579                }
580
581                let step = self.step;
582                async move {
583                    let start = Instant::now();
584                    let output = match policy {
585                        Some(ref timeout_policy) => {
586                            execute_with_timeout(node.as_ref(), &ctx, timeout_policy).await
587                        }
588                        None => node.execute(&ctx).await,
589                    };
590                    let duration_ms = start.elapsed().as_millis() as u64;
591                    (name, output, duration_ms, step)
592                }
593            })
594            .collect();
595
596        let outputs: Vec<_> =
597            stream::iter(futures).buffer_unordered(pending_for_execution.len()).collect().await;
598
599        // Collect all updates and check for errors/interrupts
600        let mut all_updates = Vec::new();
601
602        for (node_name, output_result, duration_ms, step) in outputs {
603            result.executed_nodes.push(node_name.clone());
604            result.events.push(StreamEvent::node_end(&node_name, step, duration_ms));
605
606            match output_result {
607                Ok(output) => {
608                    // Check for dynamic interrupt
609                    if let Some(interrupt) = output.interrupt {
610                        return Ok(SuperStepResult {
611                            interrupt: Some(interrupt),
612                            executed_nodes: result.executed_nodes,
613                            events: result.events,
614                        });
615                    }
616
617                    // Collect custom events
618                    result.events.extend(output.events);
619
620                    // Store result in cache on miss
621                    #[cfg(feature = "node-cache")]
622                    {
623                        if let Some(cache) = self.node_caches.get(&node_name) {
624                            let cache_key = compute_cache_key(&node_name, &self.state);
625                            let updates_value = serde_json::to_value(&output.updates)
626                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
627                            let ttl = self.graph.cache_policies.get(&node_name).and_then(|p| p.ttl);
628                            cache.set(&cache_key, updates_value, ttl).await;
629                        }
630                    }
631
632                    // Collect updates
633                    all_updates.push(output.updates);
634                }
635                Err(e) => {
636                    return Err(GraphError::NodeExecutionFailed {
637                        node: node_name,
638                        message: e.to_string(),
639                    });
640                }
641            }
642        }
643
644        // Apply all updates atomically using reducers
645        for updates in all_updates {
646            for (key, value) in updates {
647                self.graph.schema.apply_update(&mut self.state, &key, value);
648            }
649        }
650
651        // Check for interrupt_after
652        for node_name in &result.executed_nodes {
653            if self.graph.interrupt_after.contains(node_name) {
654                return Ok(SuperStepResult {
655                    interrupt: Some(Interrupt::After(node_name.clone())),
656                    ..result
657                });
658            }
659        }
660
661        Ok(result)
662    }
663
664    /// Save a checkpoint
665    async fn save_checkpoint(&self) -> Result<String> {
666        if let Some(cp) = &self.graph.checkpointer {
667            let checkpoint = Checkpoint::new(
668                &self.config.thread_id,
669                self.state.clone(),
670                self.step,
671                self.pending_nodes.clone(),
672            );
673            return cp.save(&checkpoint).await;
674        }
675        Ok(String::new())
676    }
677}
678
679/// Convenience methods for CompiledGraph
680impl CompiledGraph {
681    /// Execute the graph synchronously
682    pub async fn invoke(&self, input: State, config: ExecutionConfig) -> Result<State> {
683        let mut executor = PregelExecutor::new(self, config);
684        executor.run(input).await
685    }
686
687    /// Execute with streaming
688    pub fn stream(
689        &self,
690        input: State,
691        config: ExecutionConfig,
692        mode: StreamMode,
693    ) -> impl futures::Stream<Item = Result<StreamEvent>> + '_ {
694        tracing::debug!("CompiledGraph::stream called with mode {:?}", mode);
695        let executor = PregelExecutor::new(self, config);
696        executor.run_stream(input, mode)
697    }
698
699    /// Get current state for a thread
700    pub async fn get_state(&self, thread_id: &str) -> Result<Option<State>> {
701        if let Some(cp) = &self.checkpointer {
702            Ok(cp.load(thread_id).await?.map(|c| c.state))
703        } else {
704            Ok(None)
705        }
706    }
707
708    /// Update state for a thread (for human-in-the-loop)
709    pub async fn update_state(
710        &self,
711        thread_id: &str,
712        updates: impl IntoIterator<Item = (String, serde_json::Value)>,
713    ) -> Result<()> {
714        if let Some(cp) = &self.checkpointer
715            && let Some(checkpoint) = cp.load(thread_id).await?
716        {
717            let mut state = checkpoint.state;
718            for (key, value) in updates {
719                self.schema.apply_update(&mut state, &key, value);
720            }
721            let new_checkpoint =
722                Checkpoint::new(thread_id, state, checkpoint.step, checkpoint.pending_nodes);
723            cp.save(&new_checkpoint).await?;
724        }
725        Ok(())
726    }
727}
728
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{END, START};
    use crate::graph::StateGraph;
    use crate::node::NodeOutput;
    use serde_json::json;

    /// A single node writes a constant into `value`, and the final state
    /// returned by `invoke` contains it.
    #[tokio::test]
    async fn test_simple_execution() {
        let compiled = StateGraph::with_channels(&["value"])
            .add_node_fn("set_value", |_node_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(42)))
            })
            .add_edge(START, "set_value")
            .add_edge("set_value", END)
            .compile()
            .unwrap();

        let final_state =
            compiled.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(42)));
    }

    /// Two nodes run back to back; the second reads the value written by the
    /// first and adds ten to it.
    #[tokio::test]
    async fn test_sequential_execution() {
        let compiled = StateGraph::with_channels(&["value"])
            .add_node_fn("step1", |_node_ctx| async {
                Ok(NodeOutput::new().with_update("value", json!(1)))
            })
            .add_node_fn("step2", |node_ctx| async move {
                let seen = node_ctx.get("value").and_then(|raw| raw.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("value", json!(seen + 10)))
            })
            .add_edge(START, "step1")
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .compile()
            .unwrap();

        let final_state =
            compiled.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(final_state.get("value"), Some(&json!(11)));
    }

    /// The router copies the requested `path` into `route`, and the
    /// conditional edge then selects the matching branch node.
    #[tokio::test]
    async fn test_conditional_routing() {
        let compiled = StateGraph::with_channels(&["path", "result"])
            .add_node_fn("router", |node_ctx| async move {
                let requested = node_ctx.get("path").and_then(|raw| raw.as_str()).unwrap_or("a");
                Ok(NodeOutput::new().with_update("route", json!(requested)))
            })
            .add_node_fn("path_a", |_node_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to A")))
            })
            .add_node_fn("path_b", |_node_ctx| async {
                Ok(NodeOutput::new().with_update("result", json!("went to B")))
            })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |current| {
                    current.get("route").and_then(|raw| raw.as_str()).unwrap_or(END).to_string()
                },
                [("a", "path_a"), ("b", "path_b"), (END, END)],
            )
            .add_edge("path_a", END)
            .add_edge("path_b", END)
            .compile()
            .unwrap();

        // Exercise path A first, then path B, against the same compiled graph.
        for (choice, expected) in [("a", "went to A"), ("b", "went to B")] {
            let mut input = State::new();
            input.insert("path".to_string(), json!(choice));

            let final_state = compiled.invoke(input, ExecutionConfig::new("test")).await.unwrap();

            assert_eq!(final_state.get("result"), Some(&json!(expected)));
        }
    }

    /// A node loops back onto itself through a conditional edge until the
    /// counter reaches five, then routes to `END`.
    #[tokio::test]
    async fn test_cycle_with_limit() {
        let compiled = StateGraph::with_channels(&["count"])
            .add_node_fn("increment", |node_ctx| async move {
                let so_far = node_ctx.get("count").and_then(|raw| raw.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(so_far + 1)))
            })
            .add_edge(START, "increment")
            .add_conditional_edges(
                "increment",
                |current| {
                    let so_far = current.get("count").and_then(|raw| raw.as_i64()).unwrap_or(0);
                    let target = if so_far < 5 { "increment" } else { END };
                    target.to_string()
                },
                [("increment", "increment"), (END, END)],
            )
            .compile()
            .unwrap();

        let final_state =
            compiled.invoke(State::new(), ExecutionConfig::new("test")).await.unwrap();

        assert_eq!(final_state.get("count"), Some(&json!(5)));
    }

    /// A graph whose only node feeds itself unconditionally must be stopped
    /// by the configured recursion limit.
    #[tokio::test]
    async fn test_recursion_limit() {
        let compiled = StateGraph::with_channels(&["count"])
            .add_node_fn("loop", |node_ctx| async move {
                let so_far = node_ctx.get("count").and_then(|raw| raw.as_i64()).unwrap_or(0);
                Ok(NodeOutput::new().with_update("count", json!(so_far + 1)))
            })
            .add_edge(START, "loop")
            .add_edge("loop", "loop") // Infinite loop
            .compile()
            .unwrap()
            .with_recursion_limit(10);

        let outcome = compiled.invoke(State::new(), ExecutionConfig::new("test")).await;

        // The graph has no edge to END, so the run is expected to end with
        // the recursion-limit error (limit set to 10 above).
        assert!(
            matches!(outcome, Err(GraphError::RecursionLimitExceeded(_))),
            "Expected RecursionLimitExceeded error, got: {:?}",
            outcome
        );
    }
}