Skip to main content

langgraph_core_rs/graph/
state.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde_json::Value as JsonValue;
5use tokio::sync::mpsc;
6use tokio_stream::wrappers::ReceiverStream;
7use langgraph_checkpoint::config::{RunnableConfig, RunnableConfigExt};
8use langgraph_checkpoint::cache::base::BaseCache;
9use langgraph_checkpoint::store::base::BaseStore;
10use langgraph_checkpoint::checkpoint::base::BaseCheckpointSaver;
11use crate::channels::{Channel, EphemeralValue, NamedBarrierValue};
12use crate::constants::{START, END, RESUME, INTERRUPT, NULL_TASK_ID};
13use crate::runnable::{Runnable, RunnableError, IntoNodeFunction};
14use crate::graph::node::StateNodeSpec;
15use crate::graph::branch::BranchSpec;
16use crate::pregel::{PregelNode, PregelRunner, ChannelVersions, channels_from_checkpoint, PregelExecutableTask};
17use crate::pregel::algo::{prepare_next_tasks, apply_writes};
18use crate::pregel::io::{map_input, map_command, read_channels};
19use crate::stream::StreamPart;
20use crate::types::{Command, StreamMode, StateSnapshot, PregelTask, Interrupt};
21use langgraph_checkpoint::checkpoint::types::CheckpointMetadata;
22
/// A join edge: `(sources, target)`.
///
/// The target is only routed to once every node in `sources` has completed.
type WaitingEdge = (Vec<String>, String);
25
26/// Error type for graph building operations.
27#[derive(Debug, thiserror::Error)]
28pub enum GraphError {
29    #[error("node '{0}' already exists")]
30    DuplicateNode(String),
31
32    #[error("unknown node '{0}'")]
33    UnknownNode(String),
34
35    #[error("cannot use reserved name '{0}'")]
36    ReservedName(String),
37
38    #[error("START cannot be an edge target")]
39    StartAsTarget,
40
41    #[error("END cannot be an edge source")]
42    EndAsSource,
43
44    #[error("no outgoing edge from START")]
45    NoStartEdge,
46
47    #[error("graph validation failed: {0}")]
48    ValidationError(String),
49
50    #[error(transparent)]
51    Runnable(#[from] RunnableError),
52
53    #[error("checkpoint error: {0}")]
54    Checkpoint(String),
55}
56
57/// Builder for constructing a state graph.
58///
59/// `S` is the state type (typically a struct with `#[derive(StateGraph)]`).
60/// Channels are derived from `S::create_channels()` (the derive macro generates this).
61///
62/// # Example
63/// ```rust,ignore
64/// use langgraph::prelude::*;
65///
66/// let mut graph = StateGraph::new(channels);
67/// graph.add_node("agent", agent_fn);
68/// graph.add_edge(START, "agent");
69/// graph.add_edge("agent", END);
70/// let compiled = graph.compile(checkpointer, None, None, None, None, false, None);
71/// ```
72pub struct StateGraph {
73    /// Registered nodes keyed by name.
74    nodes: HashMap<String, StateNodeSpec>,
75    /// Simple directed edges: (source, target).
76    edges: HashSet<(String, String)>,
77    /// Multi-source "join" edges: ([source1, source2, ...], target).
78    waiting_edges: HashSet<WaitingEdge>,
79    /// Conditional edges: source -> branch_name -> BranchSpec.
80    branches: HashMap<String, HashMap<String, BranchSpec>>,
81    /// Channels derived from the state schema.
82    channels: HashMap<String, Box<dyn Channel>>,
83    /// Whether compile() has been called.
84    compiled: bool,
85}
86
87impl StateGraph {
88    /// Create a new StateGraph with the given channels.
89    ///
90    /// Typically called via the derive macro: `MyState::create_channels()`.
91    pub fn new(channels: HashMap<String, Box<dyn Channel>>) -> Self {
92        Self {
93            nodes: HashMap::new(),
94            edges: HashSet::new(),
95            waiting_edges: HashSet::new(),
96            branches: HashMap::new(),
97            channels,
98            compiled: false,
99        }
100    }
101
102    /// Add a node to the graph.
103    ///
104    /// Accepts async closures (the default), sync closures via `node_fn!()` or `SyncNodeFn`,
105    /// or pre-built `Arc<dyn Runnable>`.
106    ///
107    /// # Examples
108    /// ```ignore
109    /// // Async closure (default)
110    /// graph.add_node("agent", |input, _config| async move {
111    ///     Ok(json!({"result": "done"}))
112    /// })?;
113    ///
114    /// // Sync closure via node_fn! macro
115    /// graph.add_node("doubler", node_fn!(|input, _config| {
116    ///     let n = input.as_i64().unwrap_or(0);
117    ///     Ok(json!(n * 2))
118    /// }))?;
119    /// ```
120    pub fn add_node(
121        &mut self,
122        name: impl Into<String>,
123        action: impl IntoNodeFunction,
124    ) -> Result<&mut Self, GraphError> {
125        let name = name.into();
126        self.validate_node_name(&name)?;
127        let runnable = action.into_runnable(&name);
128        self.nodes.insert(name.clone(), StateNodeSpec::new(name, runnable));
129        Ok(self)
130    }
131
132    /// Add a direct edge from `start` to `end`.
133    ///
134    /// `start` can be a node name or `START`.
135    /// `end` can be a node name or `END`.
136    pub fn add_edge(
137        &mut self,
138        start: impl Into<String>,
139        end: impl Into<String>,
140    ) -> Result<&mut Self, GraphError> {
141        let start = start.into();
142        let end = end.into();
143
144        if start == END {
145            return Err(GraphError::EndAsSource);
146        }
147        if end == START {
148            return Err(GraphError::StartAsTarget);
149        }
150        if start != START && !self.nodes.contains_key(&start) {
151            return Err(GraphError::UnknownNode(start));
152        }
153        if end != END && !self.nodes.contains_key(&end) {
154            return Err(GraphError::UnknownNode(end));
155        }
156
157        self.edges.insert((start, end));
158        Ok(self)
159    }
160
161    /// Add a multi-source join edge.
162    ///
163    /// The graph waits for ALL `starts` to complete before routing to `end`.
164    pub fn add_join_edge(
165        &mut self,
166        starts: Vec<String>,
167        end: impl Into<String>,
168    ) -> Result<&mut Self, GraphError> {
169        let end = end.into();
170        if end == START {
171            return Err(GraphError::StartAsTarget);
172        }
173        for s in &starts {
174            if s == END {
175                return Err(GraphError::EndAsSource);
176            }
177            if s != START && !self.nodes.contains_key(s) {
178                return Err(GraphError::UnknownNode(s.clone()));
179            }
180        }
181        if end != END && !self.nodes.contains_key(&end) {
182            return Err(GraphError::UnknownNode(end));
183        }
184        self.waiting_edges.insert((starts, end));
185        Ok(self)
186    }
187
188    /// Add conditional edges from `source`.
189    ///
190    /// The `path` function evaluates the state and returns a routing key.
191    /// The `path_map` maps routing keys to destination node names.
192    /// If `path_map` is `None`, the routing key is used directly as the node name.
193    pub fn add_conditional_edges(
194        &mut self,
195        source: impl Into<String>,
196        path: impl IntoNodeFunction,
197        path_map: Option<HashMap<String, String>>,
198    ) -> Result<&mut Self, GraphError> {
199        let source = source.into();
200        if source != START && !self.nodes.contains_key(&source) {
201            return Err(GraphError::UnknownNode(source));
202        }
203
204        let branch_name = format!("branch:{}", source);
205        let runnable = path.into_runnable(&branch_name);
206        let branch = BranchSpec::new(runnable, path_map);
207
208        self.branches
209            .entry(source)
210            .or_default()
211            .insert(branch_name, branch);
212
213        Ok(self)
214    }
215
216    /// Set the entry point (equivalent to `add_edge(START, key)`).
217    pub fn set_entry_point(&mut self, key: impl Into<String>) -> Result<&mut Self, GraphError> {
218        self.add_edge(START, key)
219    }
220
221    /// Set the finish point (equivalent to `add_edge(key, END)`).
222    pub fn set_finish_point(&mut self, key: impl Into<String>) -> Result<&mut Self, GraphError> {
223        self.add_edge(key, END)
224    }
225
226    /// Compile the graph into an executable `CompiledStateGraph`.
227    ///
228    /// The compiled graph implements `Runnable` and can be invoked with state.
229    /// Uses all defaults (no checkpointer, no cache, no store, etc.).
230    ///
231    /// For custom configuration, use `compile_builder()`.
232    pub fn compile(&mut self) -> Result<CompiledStateGraph, GraphError> {
233        self.compile_with(None, None, None, None, None, false, None, None)
234    }
235
236    /// Start building compile options with a builder pattern.
237    ///
238    /// # Example
239    /// ```ignore
240    /// let compiled = graph.compile_builder()
241    ///     .debug(true)
242    ///     .name("my_graph")
243    ///     .build()?;
244    /// ```
245    pub fn compile_builder(&mut self) -> CompileBuilder<'_> {
246        CompileBuilder {
247            graph: self,
248            checkpointer: None,
249            cache: None,
250            store: None,
251            interrupt_before: None,
252            interrupt_after: None,
253            debug: false,
254            name: None,
255            recursion_limit: None,
256        }
257    }
258
259    /// Internal: compile with explicit parameters.
260    fn compile_with(
261        &mut self,
262        checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
263        cache: Option<Arc<dyn BaseCache>>,
264        store: Option<Arc<dyn BaseStore>>,
265        interrupt_before: Option<Vec<String>>,
266        interrupt_after: Option<Vec<String>>,
267        debug: bool,
268        name: Option<String>,
269        recursion_limit: Option<u64>,
270    ) -> Result<CompiledStateGraph, GraphError> {
271        self.validate()?;
272
273        // Add START channel (ephemeral)
274        self.channels.insert(
275            START.to_string(),
276            Box::new(EphemeralValue::new(START, false)),
277        );
278
279        // Add trigger channels for each node ("branch:to:{name}")
280        for name in self.nodes.keys() {
281            let trigger_key = format!("branch:to:{}", name);
282            self.channels
283                .insert(trigger_key.clone(), Box::new(EphemeralValue::new(trigger_key, false)));
284        }
285
286        // Add barrier channels for waiting edges
287        for (sources, target) in &self.waiting_edges {
288            let barrier_name = format!("join:{}:{}", sources.join("+"), target);
289            let names: HashSet<String> = sources.iter().cloned().collect();
290            self.channels.insert(
291                barrier_name.clone(),
292                Box::new(NamedBarrierValue::new(barrier_name, names)),
293            );
294        }
295
296        self.compiled = true;
297
298        let channels = self.channels
299            .iter()
300            .map(|(k, c)| (k.clone(), c.clone_channel()))
301            .collect();
302
303        Ok(CompiledStateGraph {
304            nodes: self.nodes.clone(),
305            edges: self.edges.clone(),
306            waiting_edges: self.waiting_edges.clone(),
307            branches: self.branches.clone(),
308            channels,
309            checkpointer,
310            cache,
311            store,
312            interrupt_before: interrupt_before.unwrap_or_default(),
313            interrupt_after: interrupt_after.unwrap_or_default(),
314            debug,
315            name: name.unwrap_or_else(|| "StateGraph".to_string()),
316            recursion_limit: recursion_limit.unwrap_or(DEFAULT_RECURSION_LIMIT),
317        })
318    }
319
320    fn validate_node_name(&self, name: &str) -> Result<(), GraphError> {
321        if name == START || name == END {
322            return Err(GraphError::ReservedName(name.to_string()));
323        }
324        if self.nodes.contains_key(name) {
325            return Err(GraphError::DuplicateNode(name.to_string()));
326        }
327        Ok(())
328    }
329
330    fn validate(&self) -> Result<(), GraphError> {
331        // START must have at least one outgoing edge
332        let has_start_edge = self.edges.iter().any(|(s, _)| s == START)
333            || self.waiting_edges.iter().any(|(s, _)| s.contains(&START.to_string()))
334            || self.branches.contains_key(START);
335        if !has_start_edge {
336            return Err(GraphError::NoStartEdge);
337        }
338
339        // Validate all edge endpoints exist
340        for (start, end) in &self.edges {
341            if start != START && !self.nodes.contains_key(start) {
342                return Err(GraphError::UnknownNode(start.clone()));
343            }
344            if end != END && !self.nodes.contains_key(end) {
345                return Err(GraphError::UnknownNode(end.clone()));
346            }
347        }
348
349        Ok(())
350    }
351}
352
353/// Builder for configuring `compile()` options.
354pub struct CompileBuilder<'a> {
355    graph: &'a mut StateGraph,
356    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
357    cache: Option<Arc<dyn BaseCache>>,
358    store: Option<Arc<dyn BaseStore>>,
359    interrupt_before: Option<Vec<String>>,
360    interrupt_after: Option<Vec<String>>,
361    debug: bool,
362    name: Option<String>,
363    recursion_limit: Option<u64>,
364}
365
366impl<'a> CompileBuilder<'a> {
367    pub fn checkpointer(mut self, cp: Arc<dyn BaseCheckpointSaver>) -> Self {
368        self.checkpointer = Some(cp);
369        self
370    }
371
372    pub fn cache(mut self, cache: Arc<dyn BaseCache>) -> Self {
373        self.cache = Some(cache);
374        self
375    }
376
377    pub fn store(mut self, store: Arc<dyn BaseStore>) -> Self {
378        self.store = Some(store);
379        self
380    }
381
382    pub fn interrupt_before(mut self, nodes: Vec<String>) -> Self {
383        self.interrupt_before = Some(nodes);
384        self
385    }
386
387    pub fn interrupt_after(mut self, nodes: Vec<String>) -> Self {
388        self.interrupt_after = Some(nodes);
389        self
390    }
391
392    pub fn debug(mut self, debug: bool) -> Self {
393        self.debug = debug;
394        self
395    }
396
397    pub fn name(mut self, name: impl Into<String>) -> Self {
398        self.name = Some(name.into());
399        self
400    }
401
402    pub fn recursion_limit(mut self, limit: u64) -> Self {
403        self.recursion_limit = Some(limit);
404        self
405    }
406
407    pub fn build(self) -> Result<CompiledStateGraph, GraphError> {
408        self.graph.compile_with(
409            self.checkpointer,
410            self.cache,
411            self.store,
412            self.interrupt_before,
413            self.interrupt_after,
414            self.debug,
415            self.name,
416            self.recursion_limit,
417        )
418    }
419}
420
421/// A compiled, executable state graph.
422///
423/// This is the result of `StateGraph::compile()` and implements `Runnable`.
424/// In the full Pregel engine (Phase 6), this will execute in BSP super-steps.
425/// Currently it provides a simplified sequential execution model.
426pub struct CompiledStateGraph {
427    nodes: HashMap<String, StateNodeSpec>,
428    edges: HashSet<(String, String)>,
429    waiting_edges: HashSet<WaitingEdge>,
430    branches: HashMap<String, HashMap<String, BranchSpec>>,
431    channels: HashMap<String, Box<dyn Channel>>,
432    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
433    #[allow(dead_code)]
434    cache: Option<Arc<dyn BaseCache>>,
435    store: Option<Arc<dyn BaseStore>>,
436    interrupt_before: Vec<String>,
437    interrupt_after: Vec<String>,
438    debug: bool,
439    name: String,
440    recursion_limit: u64,
441}
442
443impl CompiledStateGraph {
444    /// Get the node names in this graph.
445    pub fn node_names(&self) -> Vec<String> {
446        self.nodes.keys().cloned().collect()
447    }
448
449    /// Get the channel names in this graph.
450    pub fn channel_names(&self) -> Vec<String> {
451        self.channels.keys().cloned().collect()
452    }
453
454    /// Check if a node exists.
455    pub fn has_node(&self, name: &str) -> bool {
456        self.nodes.contains_key(name)
457    }
458
459    /// Get the graph name.
460    pub fn name(&self) -> &str {
461        &self.name
462    }
463
464    /// Get the checkpointer, if any.
465    pub fn checkpointer(&self) -> Option<&Arc<dyn BaseCheckpointSaver>> {
466        self.checkpointer.as_ref()
467    }
468
469    /// Get the store, if any.
470    pub fn store(&self) -> Option<&Arc<dyn BaseStore>> {
471        self.store.as_ref()
472    }
473
474    /// Save a checkpoint from current channel state.
475    fn save_checkpoint(
476        &self,
477        checkpointer: &Arc<dyn BaseCheckpointSaver>,
478        config: &RunnableConfig,
479        channels: &HashMap<String, Box<dyn Channel>>,
480        channel_versions: &ChannelVersions,
481        versions_seen: &HashMap<String, HashMap<String, JsonValue>>,
482    ) -> Option<RunnableConfig> {
483        use langgraph_checkpoint::checkpoint::id::uuid6;
484        use chrono::Utc;
485
486        // Collect all channel values (including trigger channels for state history)
487        let channel_values: HashMap<String, JsonValue> = channels
488            .iter()
489            .filter_map(|(k, v)| v.checkpoint().map(|val| (k.clone(), val)))
490            .collect();
491
492        let checkpoint = langgraph_checkpoint::Checkpoint {
493            v: 2,
494            id: uuid6(),
495            ts: Utc::now().to_rfc3339(),
496            channel_values,
497            channel_versions: channel_versions.clone(),
498            versions_seen: versions_seen.clone(),
499            updated_channels: None,
500        };
501
502        let metadata = CheckpointMetadata::default();
503        checkpointer.put(config, &checkpoint, &metadata, channel_versions).ok()
504    }
505
506    /// Determine which nodes should execute next given the current state.
507    ///
508    /// This is the "plan" phase of the BSP cycle.
509    pub fn get_next_nodes(&self, state: &HashMap<String, JsonValue>) -> Vec<String> {
510        let mut next = Vec::new();
511
512        // Check which nodes are triggered by edges from completed nodes
513        for (start, end) in &self.edges {
514            if (start == START || state.contains_key(&format!("branch:to:{}", start)))
515                && end != END {
516                    next.push(end.clone());
517                }
518        }
519
520        // Check conditional branches
521        for (source, branches) in &self.branches {
522            if source == START || state.contains_key(&format!("branch:to:{}", source)) {
523                for _branch in branches.values() {
524                    // Evaluate the branch path to determine routing
525                    // For now, we'd need to actually invoke the path runnable
526                    // This is handled by the Pregel engine in Phase 6
527                }
528            }
529        }
530
531        next
532    }
533
    /// Get the current state of the graph from the checkpointer.
    ///
    /// Returns a `StateSnapshot` containing the current channel values,
    /// the names of nodes that will execute next, pending tasks, and
    /// any unresolved interrupts.
    ///
    /// Requires a checkpointer to be configured. If no checkpoint exists for
    /// `config`, an empty snapshot is returned. It has an empty-object `values`,
    /// no next nodes, tasks or interrupts, and echoes `config`.
    ///
    /// Nothing computed here is written back to the checkpointer.
    ///
    /// # Errors
    /// * `GraphError::ValidationError` if no checkpointer is configured.
    /// * `GraphError::Checkpoint` if loading the checkpoint tuple fails.
    ///
    /// # Example
    /// ```ignore
    /// let snapshot = compiled.get_state(&config)?;
    /// println!("next: {:?}", snapshot.next);
    /// println!("values: {}", snapshot.values);
    /// ```
    pub fn get_state(&self, config: &RunnableConfig) -> Result<StateSnapshot, GraphError> {
        let checkpointer = self.checkpointer.as_ref().ok_or_else(|| {
            GraphError::ValidationError("No checkpointer set".to_string())
        })?;

        let saved = checkpointer
            .get_tuple(config)
            .map_err(|e| GraphError::Checkpoint(e.to_string()))?;

        // Nothing saved yet for this config: report an empty state.
        let Some(saved) = saved else {
            return Ok(StateSnapshot {
                values: JsonValue::Object(serde_json::Map::new()),
                next: vec![],
                config: config.clone(),
                metadata: None,
                created_at: None,
                parent_config: None,
                tasks: vec![],
                interrupts: vec![],
            });
        };

        // Reconstruct channels from checkpoint, starting from `self.channels`
        let cp_channels: HashMap<String, Option<JsonValue>> = saved
            .checkpoint
            .channel_values
            .iter()
            .map(|(k, v)| (k.clone(), Some(v.clone())))
            .collect();
        let mut channels = channels_from_checkpoint(&self.channels, &cp_channels);

        // Local working copies; prepare_next_tasks / apply_writes mutate them below.
        let mut channel_versions = saved.checkpoint.channel_versions.clone();
        let mut versions_seen = saved.checkpoint.versions_seen.clone();

        // Apply null-task pending writes (input writes not tied to a task).
        // Channel update errors are ignored here (best-effort).
        if let Some(ref pending) = saved.pending_writes {
            for (tid, chan, val) in pending {
                if tid == NULL_TASK_ID {
                    if let Some(ch) = channels.get(chan) {
                        ch.update(&[val.clone()]).ok();
                    }
                }
            }
        }

        // Build PregelNode specs and prepare next tasks
        let pregel_nodes = build_pregel_nodes(
            &self.nodes,
            &self.edges,
            &self.waiting_edges,
            &self.branches,
            &self.channels,
        );
        let trigger_to_nodes = crate::pregel::build_trigger_to_nodes(&pregel_nodes);

        // NOTE(review): tasks are prepared against a synthetic checkpoint id
        // (step 0, zero-padded to 32 digits). `get_state_history` passes
        // `saved.checkpoint.id` instead. If `prepare_next_tasks` derives task ids
        // from this value, the `tid` matching below only works when the runner
        // uses the same convention — confirm.
        let step = 0u64;
        let checkpoint_id = format!("{:032}", step);
        let pending_writes: Vec<(String, String, JsonValue)> = saved
            .pending_writes
            .as_ref()
            .map(|pw| pw.to_vec())
            .unwrap_or_default();

        let mut tasks = prepare_next_tasks(
            &pregel_nodes,
            &channels,
            config,
            step,
            &mut versions_seen,
            &trigger_to_nodes,
            None,
            &checkpoint_id,
            &pending_writes,
            &channel_versions,
        );

        // Apply non-INTERRUPT, non-ERROR pending writes to tasks
        // so that the snapshot values reflect completed task outputs.
        // Null-task writes were already applied to the channels above.
        if let Some(ref pending) = saved.pending_writes {
            for (tid, chan, val) in pending {
                if chan == INTERRUPT || chan == crate::constants::ERROR {
                    continue;
                }
                if tid == NULL_TASK_ID {
                    continue;
                }
                if let Some(task) = tasks.iter_mut().find(|t| &t.id == tid) {
                    task.writes.push((chan.clone(), val.clone()));
                }
            }
        }

        // Apply writes from completed tasks to get final channel state.
        // The closure produces the next channel version. It parses the current
        // version string as a u64 (0 if absent or unparsable) and returns
        // value + 1, zero-padded to 32 digits.
        apply_writes(
            &mut channels,
            &tasks,
            &mut versions_seen,
            &mut channel_versions,
            &trigger_to_nodes,
            |current| {
                let num = current
                    .and_then(|v| v.as_str())
                    .and_then(|s| s.parse::<u64>().ok())
                    .unwrap_or(0);
                JsonValue::String(format!("{:032}", num + 1))
            },
        );

        // Read channel values, hiding internal trigger ("branch:*"),
        // barrier ("join:*") and START channels
        let output_keys: Vec<String> = channels
            .keys()
            .filter(|k| !k.starts_with("branch:") && !k.starts_with("join:") && *k != START)
            .cloned()
            .collect();
        let values = read_channels(&channels, &output_keys);

        // Build next: names of tasks that have NOT written yet
        let next: Vec<String> = tasks
            .iter()
            .filter(|t| t.writes.is_empty())
            .map(|t| t.name.clone())
            .collect();

        // Extract interrupts from pending writes (across all tasks); entries
        // that do not deserialize as `Interrupt` are skipped
        let interrupts: Vec<Interrupt> = saved
            .pending_writes
            .as_ref()
            .map(|pw| {
                pw.iter()
                    .filter(|(_, chan, _)| chan == INTERRUPT)
                    .filter_map(|(_, _, val)| {
                        serde_json::from_value::<Interrupt>(val.clone()).ok()
                    })
                    .collect()
            })
            .unwrap_or_default();

        // Build PregelTask list for the snapshot, attaching each task's own
        // interrupts (matched by task id)
        let snapshot_tasks: Vec<PregelTask> = tasks
            .iter()
            .map(|t| {
                let task_interrupts: Vec<Interrupt> = saved
                    .pending_writes
                    .as_ref()
                    .map(|pw| {
                        pw.iter()
                            .filter(|(tid, chan, _)| tid == &t.id && chan == INTERRUPT)
                            .filter_map(|(_, _, val)| {
                                serde_json::from_value::<Interrupt>(val.clone()).ok()
                            })
                            .collect()
                    })
                    .unwrap_or_default();

                PregelTask {
                    id: t.id.clone(),
                    name: t.name.clone(),
                    path: vec![],
                    error: None,
                    interrupts: task_interrupts,
                    result: None,
                }
            })
            .collect();

        Ok(StateSnapshot {
            values,
            next,
            config: saved.config.clone(),
            metadata: Some(saved.metadata.clone()),
            created_at: Some(saved.checkpoint.ts.clone()),
            parent_config: saved.parent_config.clone(),
            tasks: snapshot_tasks,
            interrupts,
        })
    }
724
725    /// Manually update the graph state.
726    ///
727    /// Applies the given values to the current checkpoint's channels and
728    /// saves a new checkpoint. This allows updating custom state fields
729    /// (like `name`, `birthday`) outside of normal node execution.
730    ///
731    /// Requires a checkpointer to be configured.
732    ///
733    /// # Arguments
734    /// * `config` - The runnable config (must include `thread_id`)
735    /// * `values` - A JSON object of channel updates, e.g. `{"name": "LangGraph"}`
736    ///
737    /// # Example
738    /// ```ignore
739    /// compiled.update_state(&config, json!({"name": "LangGraph (library)"}))?;
740    /// let snapshot = compiled.get_state(&config)?;
741    /// assert_eq!(snapshot.values["name"], "LangGraph (library)");
742    /// ```
743    pub fn update_state(
744        &self,
745        config: &RunnableConfig,
746        values: &JsonValue,
747    ) -> Result<RunnableConfig, GraphError> {
748        let checkpointer = self.checkpointer.as_ref().ok_or_else(|| {
749            GraphError::ValidationError("No checkpointer set".to_string())
750        })?;
751
752        let saved = checkpointer
753            .get_tuple(config)
754            .map_err(|e| GraphError::Checkpoint(e.to_string()))?;
755
756        // Reconstruct channels from checkpoint (or fresh if none)
757        let channels: HashMap<String, Box<dyn Channel>> = if let Some(ref saved) = saved {
758            let cp_channels: HashMap<String, Option<JsonValue>> = saved
759                .checkpoint
760                .channel_values
761                .iter()
762                .map(|(k, v)| (k.clone(), Some(v.clone())))
763                .collect();
764            channels_from_checkpoint(&self.channels, &cp_channels)
765        } else {
766            self.channels
767                .iter()
768                .map(|(k, c)| (k.clone(), c.clone_channel()))
769                .collect()
770        };
771
772        let mut channel_versions = saved
773            .as_ref()
774            .map(|s| s.checkpoint.channel_versions.clone())
775            .unwrap_or_default();
776        let versions_seen = saved
777            .as_ref()
778            .map(|s| s.checkpoint.versions_seen.clone())
779            .unwrap_or_default();
780
781        // Apply the update values to channels
782        if let Some(obj) = values.as_object() {
783            for (key, val) in obj {
784                if let Some(ch) = channels.get(key) {
785                    ch.update(&[val.clone()]).ok();
786                    // Bump the channel version
787                    let new_version = channel_versions
788                        .get(key)
789                        .and_then(|v| v.as_str())
790                        .and_then(|s| s.parse::<u64>().ok())
791                        .unwrap_or(0)
792                        + 1;
793                    channel_versions.insert(
794                        key.clone(),
795                        JsonValue::String(format!("{:032}", new_version)),
796                    );
797                }
798            }
799        }
800
801        // Save the updated checkpoint
802        self.save_checkpoint(checkpointer, config, &channels, &channel_versions, &versions_seen);
803
804        Ok(config.clone())
805    }
806
807    /// Get the state history (all checkpoints) for a thread.
808    ///
809    /// Returns a list of `StateSnapshot` in reverse chronological order
810    /// (newest first). Each snapshot contains the checkpoint's channel values,
811    /// which node would execute next, and metadata.
812    ///
813    /// This enables "time travel" — reviewing past states and resuming
814    /// from any checkpoint.
815    ///
816    /// # Example
817    /// ```ignore
818    /// let history = compiled.get_state_history(&config)?;
819    /// for snapshot in &history {
820    ///     println!("messages: {}, next: {:?}", snapshot.values["messages"].as_array().map(|a| a.len()), snapshot.next);
821    /// }
822    /// ```
823    pub fn get_state_history(&self, config: &RunnableConfig) -> Result<Vec<StateSnapshot>, GraphError> {
824        let checkpointer = self.checkpointer.as_ref().ok_or_else(|| {
825            GraphError::ValidationError("No checkpointer set".to_string())
826        })?;
827
828        let tuples = checkpointer
829            .list(Some(config), None, None, None)
830            .map_err(|e| GraphError::Checkpoint(e.to_string()))?;
831
832        let mut snapshots = Vec::new();
833
834        // Build PregelNode specs for task preparation
835        let pregel_nodes = build_pregel_nodes(
836            &self.nodes,
837            &self.edges,
838            &self.waiting_edges,
839            &self.branches,
840            &self.channels,
841        );
842        let trigger_to_nodes = crate::pregel::build_trigger_to_nodes(&pregel_nodes);
843
844        for saved in &tuples {
845            // Reconstruct channels from checkpoint
846            let cp_channels: HashMap<String, Option<JsonValue>> = saved
847                .checkpoint
848                .channel_values
849                .iter()
850                .map(|(k, v)| (k.clone(), Some(v.clone())))
851                .collect();
852            let channels = channels_from_checkpoint(&self.channels, &cp_channels);
853
854            let channel_versions = saved.checkpoint.channel_versions.clone();
855            let mut versions_seen = saved.checkpoint.versions_seen.clone();
856
857            // Apply non-INTERRUPT pending writes to get the correct channel state
858            if let Some(ref pending) = saved.pending_writes {
859                for (tid, chan, val) in pending {
860                    if chan == INTERRUPT || chan == crate::constants::ERROR {
861                        continue;
862                    }
863                    if tid == NULL_TASK_ID {
864                        if let Some(ch) = channels.get(chan) {
865                            ch.update(&[val.clone()]).ok();
866                        }
867                        continue;
868                    }
869                    if let Some(ch) = channels.get(chan) {
870                        ch.update(&[val.clone()]).ok();
871                    }
872                }
873            }
874
875            // Read output values
876            let output_keys: Vec<String> = channels
877                .keys()
878                .filter(|k| !k.starts_with("branch:") && !k.starts_with("join:") && *k != START)
879                .cloned()
880                .collect();
881            let values = read_channels(&channels, &output_keys);
882
883            // Prepare tasks to determine what would execute next
884            let checkpoint_id = saved.checkpoint.id.clone();
885            let pending_writes: Vec<(String, String, JsonValue)> = saved
886                .pending_writes
887                .as_ref()
888                .map(|pw| pw.iter().map(|(t, c, v)| (t.clone(), c.clone(), v.clone())).collect())
889                .unwrap_or_default();
890
891            let tasks = prepare_next_tasks(
892                &pregel_nodes,
893                &channels,
894                &RunnableConfig::new(),
895                0,
896                &mut versions_seen,
897                &trigger_to_nodes,
898                None,
899                &checkpoint_id,
900                &pending_writes,
901                &channel_versions,
902            );
903
904            // Next = tasks that haven't written yet
905            let next: Vec<String> = tasks
906                .iter()
907                .filter(|t| t.writes.is_empty())
908                .map(|t| t.name.clone())
909                .collect();
910
911            // Extract interrupts from pending writes
912            let interrupts: Vec<Interrupt> = saved
913                .pending_writes
914                .as_ref()
915                .map(|pw| {
916                    pw.iter()
917                        .filter(|(_, chan, _)| chan == INTERRUPT)
918                        .filter_map(|(_, _, val)| {
919                            serde_json::from_value::<Interrupt>(val.clone()).ok()
920                        })
921                        .collect()
922                })
923                .unwrap_or_default();
924
925            snapshots.push(StateSnapshot {
926                values,
927                next,
928                config: saved.config.clone(),
929                metadata: Some(saved.metadata.clone()),
930                created_at: Some(saved.checkpoint.ts.clone()),
931                parent_config: saved.parent_config.clone(),
932                tasks: vec![],
933                interrupts,
934            });
935        }
936
937        Ok(snapshots)
938    }
939}
940
941impl Clone for CompiledStateGraph {
942    fn clone(&self) -> Self {
943        let channels: HashMap<String, Box<dyn Channel>> = self.channels
944            .iter()
945            .map(|(k, c)| (k.clone(), c.clone_channel()))
946            .collect();
947
948        // Manually clone branches (nested HashMap with non-Clone inner values already handled by Arc)
949        let branches: HashMap<String, HashMap<String, BranchSpec>> = self.branches
950            .iter()
951            .map(|(k, v)| (k.clone(), v.clone()))
952            .collect();
953
954        Self {
955            nodes: self.nodes.clone(),
956            edges: self.edges.clone(),
957            waiting_edges: self.waiting_edges.clone(),
958            branches,
959            channels,
960            checkpointer: self.checkpointer.clone(),
961            cache: self.cache.clone(),
962            store: self.store.clone(),
963            interrupt_before: self.interrupt_before.clone(),
964            interrupt_after: self.interrupt_after.clone(),
965            debug: self.debug,
966            name: self.name.clone(),
967            recursion_limit: self.recursion_limit,
968        }
969    }
970}
971
972/// Build PregelNode specs from the graph structure.
973///
974/// For each node, creates a combined runnable that:
975/// 1. Executes the node logic
976/// 2. Writes state updates to channels
977/// 3. Writes to trigger / barrier channels for edge targets
978///
979/// Join edges (from `add_join_edge`) use a `NamedBarrierValue` channel
980/// (named `join:{sources}:{target}`) instead of a plain `branch:to:{target}`.
981/// Each source node writes its own name into the barrier channel; the barrier
982/// becomes available only when ALL sources have written, at which point the
983/// join-target node is triggered.
984fn build_pregel_nodes(
985    nodes: &HashMap<String, StateNodeSpec>,
986    edges: &HashSet<(String, String)>,
987    waiting_edges: &HashSet<WaitingEdge>,
988    branches: &HashMap<String, HashMap<String, BranchSpec>>,
989    channels: &HashMap<String, Box<dyn Channel>>,
990) -> HashMap<String, PregelNode> {
991    let mut pregel_nodes = HashMap::new();
992
993    // Build a map of source -> [plain-edge targets] (excluding END)
994    let mut edge_targets: HashMap<String, Vec<String>> = HashMap::new();
995    for (start, end) in edges {
996        if end != END {
997            edge_targets.entry(start.clone()).or_default().push(end.clone());
998        }
999    }
1000
1001    // Build join-edge lookup maps from waiting_edges:
1002    //
1003    //   join_writes_for_source:  source_name -> [(barrier_channel_name, source_name)]
1004    //     When a source node completes, it writes its own name into every
1005    //     barrier channel it participates in.
1006    //
1007    //   join_trigger_for_target: target_name -> barrier_channel_name
1008    //     The join-target node uses the barrier channel as its sole trigger
1009    //     instead of the default "branch:to:{name}" ephemeral channel.
1010    let mut join_writes_for_source: HashMap<String, Vec<(String, String)>> = HashMap::new();
1011    let mut join_trigger_for_target: HashMap<String, String> = HashMap::new();
1012
1013    for (sources, target) in waiting_edges {
1014        // Barrier channel name must match what compile_with() created.
1015        // sources is a Vec so we preserve insertion order for the name.
1016        let barrier_name = format!("join:{}:{}", sources.join("+"), target);
1017
1018        // Each source must write its name into this barrier channel
1019        for source in sources {
1020            join_writes_for_source
1021                .entry(source.clone())
1022                .or_default()
1023                .push((barrier_name.clone(), source.clone()));
1024        }
1025
1026        // The target node is triggered by the barrier channel
1027        join_trigger_for_target.insert(target.clone(), barrier_name);
1028    }
1029
1030    // Build PregelNode for each registered node
1031    for (name, spec) in nodes {
1032        // Determine this node's trigger channel.
1033        // Join-target nodes use their barrier channel; all others use the
1034        // standard ephemeral "branch:to:{name}" channel.
1035        let trigger = join_trigger_for_target
1036            .get(name)
1037            .cloned()
1038            .unwrap_or_else(|| format!("branch:to:{}", name));
1039
1040        // Determine input channels — all non-special channels
1041        let input_channels: Vec<String> = channels
1042            .keys()
1043            .filter(|k| {
1044                !k.starts_with("branch:") && !k.starts_with("join:") && *k != START
1045            })
1046            .cloned()
1047            .collect();
1048
1049        // Plain edge targets for this node
1050        let targets: Vec<String> = edge_targets.get(name).cloned().unwrap_or_default();
1051
1052        // Barrier channel writes this node must emit when it completes
1053        // (participates in one or more join edges)
1054        let barrier_writes: Vec<(String, String)> = join_writes_for_source
1055            .get(name)
1056            .cloned()
1057            .unwrap_or_default();
1058
1059        // Branch specs
1060        let node_branches: Vec<BranchSpec> = branches
1061            .get(name)
1062            .map(|m| m.values().cloned().collect())
1063            .unwrap_or_default();
1064
1065        let node_runnable = spec.runnable.clone();
1066        let node_name = name.clone();
1067
1068        let combined: Arc<dyn Runnable> = Arc::new(
1069            crate::runnable::RunnableCallable::new(
1070                node_name.clone(),
1071                move |input, config| {
1072                    let node_runnable = node_runnable.clone();
1073                    let targets = targets.clone();
1074                    let barrier_writes = barrier_writes.clone();
1075                    let node_branches = node_branches.clone();
1076                    async move {
1077                        // 1. Execute the node logic
1078                        let output = node_runnable.ainvoke(&input, &config).await?;
1079
1080                        // 2. Build combined output: state updates + trigger writes
1081                        let mut result = serde_json::Map::new();
1082
1083                        // Copy state updates from node output
1084                        if let Some(obj) = output.as_object() {
1085                            for (k, v) in obj {
1086                                result.insert(k.clone(), v.clone());
1087                            }
1088                        }
1089
1090                        // 3. Write to plain trigger channels for simple edge targets
1091                        for target in &targets {
1092                            let trigger_ch = format!("branch:to:{}", target);
1093                            result.insert(trigger_ch, JsonValue::String(target.clone()));
1094                        }
1095
1096                        // 4. Write into barrier channels for join-edge participation.
1097                        // The value written is this node's own name so the
1098                        // NamedBarrierValue can track which sources have arrived.
1099                        for (barrier_ch, source_name) in &barrier_writes {
1100                            result.insert(
1101                                barrier_ch.clone(),
1102                                JsonValue::String(source_name.clone()),
1103                            );
1104                        }
1105
1106                        // 5. Evaluate conditional branches
1107                        for branch in &node_branches {
1108                            let branch_result = branch.path.ainvoke(&output, &config).await?;
1109                            let key = branch_result.as_str().unwrap_or("");
1110                            if let Some(target) = branch.resolve(key) {
1111                                let trigger_ch = format!("branch:to:{}", target);
1112                                result.insert(trigger_ch, JsonValue::String(target));
1113                            }
1114                        }
1115
1116                        Ok(JsonValue::Object(result))
1117                    }
1118                },
1119            ),
1120        );
1121
1122        let pregel_node = PregelNode::new(
1123            input_channels,
1124            vec![trigger],
1125            combined,
1126        );
1127
1128        pregel_nodes.insert(name.clone(), pregel_node);
1129    }
1130
1131    pregel_nodes
1132}
1133
// NOTE(review): not referenced in this part of the file; presumably the
// default for `CompiledStateGraph::recursion_limit`. Confirm at the use site.
/// Default recursion limit. In `run_pregel_inner` the recursion limit caps
/// the number of BSP super-steps a run may execute.
const DEFAULT_RECURSION_LIMIT: u64 = 25;
1136
1137// ── Streaming context ─────────────────────────────────────────────────────
1138//
1139// Passed to `run_pregel_inner` when streaming is enabled. When `None`, the
1140// inner loop runs in non-streaming mode (same logic, no emit calls).
1141
1142struct StreamCtx<'a> {
1143    modes: &'a HashSet<StreamMode>,
1144    tx: &'a mpsc::Sender<StreamPart>,
1145    /// Sender for the `Custom` stream channel. `None` when Custom mode is off.
1146    custom_tx: Option<mpsc::Sender<JsonValue>>,
1147}
1148
1149impl<'a> StreamCtx<'a> {
1150    fn has(&self, mode: &StreamMode) -> bool {
1151        self.modes.contains(mode)
1152    }
1153}
1154
1155// Helper: apply completed tasks' writes to channels, updating versions_seen
1156// and channel_versions for them. The interrupted task (identified by
1157// `interrupted_task_id`) is deliberately excluded so its trigger channels
1158// remain "unseen" and it re-triggers on resume.
1159//
1160// This mirrors Python's `_suppress_interrupt` in `_loop.py`.
1161fn apply_completed_writes(
1162    interrupted_task_id: &str,
1163    tasks: &[PregelExecutableTask],
1164    channels: &HashMap<String, Box<dyn Channel>>,
1165    versions_seen: &mut HashMap<String, HashMap<String, JsonValue>>,
1166    channel_versions: &mut ChannelVersions,
1167) {
1168    // Update versions_seen only for completed tasks
1169    for task in tasks.iter().filter(|t| t.id != interrupted_task_id && !t.writes.is_empty()) {
1170        let seen = versions_seen.entry(task.name.clone()).or_default();
1171        for trigger in &task.triggers {
1172            if let Some(ver) = channel_versions.get(trigger.as_str()) {
1173                seen.insert(trigger.clone(), ver.clone());
1174            }
1175        }
1176    }
1177
1178    // Compute a single global next_version from the max of all channel versions
1179    let max_ver = channel_versions
1180        .values()
1181        .filter_map(|v| v.as_str().and_then(|s| s.parse::<u64>().ok()))
1182        .max()
1183        .unwrap_or(0);
1184    let next_version = JsonValue::String(format!("{:032}", max_ver + 1));
1185
1186    // Collect and apply writes from completed tasks to channels.
1187    // Filter out all reserved keys (matching Python behavior).
1188    let mut writes_by_channel: HashMap<String, Vec<JsonValue>> = HashMap::new();
1189    for task in tasks.iter().filter(|t| t.id != interrupted_task_id && !t.writes.is_empty()) {
1190        for (chan, val) in &task.writes {
1191            if crate::constants::RESERVED.contains(&chan.as_str()) {
1192                continue;
1193            }
1194            writes_by_channel.entry(chan.clone()).or_default().push(val.clone());
1195        }
1196    }
1197
1198    for (chan, vals) in &writes_by_channel {
1199        if let Some(ch) = channels.get(chan.as_str()) {
1200            if ch.update(vals).unwrap_or(false) {
1201                channel_versions.insert(chan.clone(), next_version.clone());
1202            }
1203        }
1204    }
1205}
1206
1207// Helper: collect output channel keys (excluding internal routing channels).
1208fn output_channel_keys(channels: &HashMap<String, Box<dyn Channel>>) -> Vec<String> {
1209    channels
1210        .keys()
1211        .filter(|k| !k.starts_with("branch:") && !k.starts_with("join:") && *k != START)
1212        .cloned()
1213        .collect()
1214}
1215
1216// Helper: bump-version closure used in apply_writes.
1217fn bump_version(current: Option<&JsonValue>) -> JsonValue {
1218    let num = current
1219        .and_then(|v| v.as_str())
1220        .and_then(|s| s.parse::<u64>().ok())
1221        .unwrap_or(0);
1222    JsonValue::String(format!("{:032}", num + 1))
1223}
1224
1225impl CompiledStateGraph {
1226    // ────────────────────────────────────────────────────────────────────────
1227    // Public thin wrappers
1228    // ────────────────────────────────────────────────────────────────────────
1229
1230    /// Non-streaming invocation: runs the BSP loop and returns the final output.
1231    async fn run_pregel(
1232        &self,
1233        input: &JsonValue,
1234        config: &RunnableConfig,
1235    ) -> Result<JsonValue, RunnableError> {
1236        self.run_pregel_inner(input, config, None).await
1237    }
1238
1239    /// Streaming invocation: runs the BSP loop and emits `StreamPart`s via `tx`.
1240    async fn run_pregel_streaming(
1241        &self,
1242        input: &JsonValue,
1243        config: &RunnableConfig,
1244        modes: &HashSet<StreamMode>,
1245        tx: &mpsc::Sender<StreamPart>,
1246    ) -> Result<JsonValue, RunnableError> {
1247        // Set up the custom-stream forwarder if Custom mode is requested.
1248        let (custom_tx, has_custom) = if modes.contains(&StreamMode::Custom) {
1249            let (ctx, mut crx) = mpsc::channel::<JsonValue>(64);
1250            let tx_clone = tx.clone();
1251            tokio::spawn(async move {
1252                while let Some(data) = crx.recv().await {
1253                    let _ = tx_clone.send(StreamPart::custom(vec![], data)).await;
1254                }
1255            });
1256            (Some(ctx), true)
1257        } else {
1258            (None, false)
1259        };
1260        let _ = has_custom; // suppresses warning; custom_tx presence implies has_custom
1261
1262        let ctx = StreamCtx { modes, tx, custom_tx };
1263        self.run_pregel_inner(input, config, Some(&ctx)).await
1264    }
1265
1266    /// Public streaming API: returns a `ReceiverStream` of `StreamPart`s.
1267    pub fn astream(
1268        &self,
1269        input: &JsonValue,
1270        config: &RunnableConfig,
1271        stream_modes: Vec<StreamMode>,
1272    ) -> ReceiverStream<StreamPart> {
1273        let (tx, rx) = mpsc::channel(256);
1274        let modes: HashSet<StreamMode> = stream_modes.into_iter().collect();
1275
1276        let graph = self.clone();
1277        let input = input.clone();
1278        let config = config.clone();
1279
1280        tokio::spawn(async move {
1281            let result = graph.run_pregel_streaming(&input, &config, &modes, &tx).await;
1282            if let Err(e) = result {
1283                let _ = tx.send(StreamPart::debug(
1284                    vec![],
1285                    serde_json::json!({"error": e.to_string()}),
1286                )).await;
1287            }
1288        });
1289
1290        ReceiverStream::new(rx)
1291    }
1292
1293    // ────────────────────────────────────────────────────────────────────────
1294    // Unified BSP loop (previously duplicated as run_pregel / run_pregel_streaming)
1295    // ────────────────────────────────────────────────────────────────────────
1296    //
1297    // `stream` is `None` in non-streaming mode and `Some(&ctx)` in streaming
1298    // mode. Every streaming emit is guarded by `if let Some(s) = stream`.
1299    // The core logic — checkpoint loading, task preparation, execution,
1300    // apply_writes, interrupt handling — is identical in both modes.
1301    async fn run_pregel_inner(
1302        &self,
1303        input: &JsonValue,
1304        config: &RunnableConfig,
1305        stream: Option<&StreamCtx<'_>>,
1306    ) -> Result<JsonValue, RunnableError> {
1307        let mut config = config.clone();
1308        // ── Setup ────────────────────────────────────────────────────────────
1309
1310        let pregel_nodes = build_pregel_nodes(
1311            &self.nodes,
1312            &self.edges,
1313            &self.waiting_edges,
1314            &self.branches,
1315            &self.channels,
1316        );
1317        let trigger_to_nodes = crate::pregel::build_trigger_to_nodes(&pregel_nodes);
1318
1319        // Load checkpoint (for resume support)
1320        let mut saved_checkpoint_exists = false;
1321        let (mut channels, mut channel_versions, mut versions_seen) =
1322            if let Some(ref cp) = self.checkpointer {
1323                match cp.get_tuple(&config) {
1324                    Ok(Some(tuple)) => {
1325                        saved_checkpoint_exists = true;
1326                        let cp_channels: HashMap<String, Option<JsonValue>> = tuple
1327                            .checkpoint
1328                            .channel_values
1329                            .iter()
1330                            .map(|(k, v)| (k.clone(), Some(v.clone())))
1331                            .collect();
1332                        let restored = channels_from_checkpoint(&self.channels, &cp_channels);
1333
1334                        // Apply non-RESUME pending writes from the checkpoint
1335                        if let Some(ref pending) = tuple.pending_writes {
1336                            for (_task_id, channel, value) in pending {
1337                                if channel != RESUME {
1338                                    if let Some(ch) = restored.get(channel) {
1339                                        ch.update(&[value.clone()]).ok();
1340                                    }
1341                                }
1342                            }
1343                        }
1344
1345                        (
1346                            restored,
1347                            tuple.checkpoint.channel_versions.clone(),
1348                            tuple.checkpoint.versions_seen.clone(),
1349                        )
1350                    }
1351                    _ => (
1352                        self.channels.iter().map(|(k, c)| (k.clone(), c.clone_channel())).collect(),
1353                        HashMap::new(),
1354                        HashMap::new(),
1355                    ),
1356                }
1357            } else {
1358                (
1359                    self.channels.iter().map(|(k, c)| (k.clone(), c.clone_channel())).collect(),
1360                    HashMap::new(),
1361                    HashMap::new(),
1362                )
1363            };
1364
1365        // BSP loop counters
1366        let mut step: u64 = 0;
1367        let max_steps = config.get_recursion_limit().unwrap_or(self.recursion_limit);
1368        let mut last_output = JsonValue::Null;
1369        let mut pending_writes: Vec<(String, String, JsonValue)> = Vec::new();
1370
1371        // Version offset: ensures new trigger writes have strictly higher
1372        // versions than anything the checkpoint has already seen.
1373        let version_offset: u64 = if saved_checkpoint_exists {
1374            channel_versions
1375                .values()
1376                .filter_map(|v| v.as_str().and_then(|s| s.parse::<u64>().ok()))
1377                .max()
1378                .unwrap_or(0)
1379                + 1
1380        } else {
1381            0
1382        };
1383
1384        // Detect resume-from-Command vs. fresh invocation vs. fork
1385        let is_resuming = if let Ok(cmd) = serde_json::from_value::<Command>(input.clone()) {
1386            let cmd_writes = map_command(&cmd);
1387            let has_resume = cmd_writes.iter().any(|(_, chan, _)| chan == RESUME);
1388            pending_writes.extend(cmd_writes);
1389            has_resume
1390        } else {
1391            false
1392        };
1393        let is_fork = input.is_null() && saved_checkpoint_exists;
1394
1395        // Write input to channels on a fresh invocation only.
1396        // When resuming from an interrupt (is_resuming=true), the checkpoint already
1397        // has the full state; we must NOT re-trigger START because that would restart
1398        // the entire graph (compaction → llm → tools) from scratch, causing the
1399        // "memory confusion" / spurious LLM re-runs observed after tool denial.
1400        if !is_fork && !is_resuming {
1401            let input_writes = map_input(&[START.to_string()], input);
1402            for (chan, val) in &input_writes {
1403                if let Some(ch) = channels.get(chan) {
1404                    ch.update(&[val.clone()]).ok();
1405                }
1406            }
1407            if let Some(obj) = input.as_object() {
1408                for (key, val) in obj {
1409                    if key != START && !key.starts_with("branch:") && !key.starts_with("join:") {
1410                        if let Some(ch) = channels.get(key) {
1411                            ch.update(&[val.clone()]).ok();
1412                        }
1413                    }
1414                }
1415            }
1416            for (chan, _) in &input_writes {
1417                channel_versions.insert(
1418                    chan.clone(),
1419                    JsonValue::String(format!("{:032}", version_offset + step)),
1420                );
1421            }
1422            // Kick off the first nodes by writing START edge trigger channels
1423            for (start, end) in &self.edges {
1424                if start == START && end != END {
1425                    let trigger_ch = format!("branch:to:{}", end);
1426                    if let Some(ch) = channels.get(&trigger_ch) {
1427                        ch.update(&[JsonValue::String(end.clone())]).ok();
1428                        channel_versions.insert(
1429                            trigger_ch,
1430                            JsonValue::String(format!("{:032}", version_offset + step)),
1431                        );
1432                    }
1433                }
1434            }
1435        }
1436
1437
1438        // ── Super-step loop ──────────────────────────────────────────────────
1439
1440        while step < max_steps {
1441            let checkpoint_id = format!("{:032}", version_offset + step);
1442
1443            // PLAN: determine which nodes to run this step
1444            let mut tasks = prepare_next_tasks(
1445                &pregel_nodes,
1446                &channels,
1447                &config,
1448                version_offset + step,
1449                &mut versions_seen,
1450                &trigger_to_nodes,
1451                None,
1452                &checkpoint_id,
1453                &pending_writes,
1454                &channel_versions,
1455            );
1456
1457
1458
1459            if tasks.is_empty() {
1460                break;
1461            }
1462            
1463            // Consume pending writes (especially RESUME) so they don't apply to subsequent supersteps
1464            pending_writes.clear();
1465
1466            // ── Streaming: emit task-start events ───────────────────────────
1467            if let Some(s) = stream {
1468                if s.has(&StreamMode::Tasks) {
1469                    for task in &tasks {
1470                        let data = serde_json::json!({
1471                            "id": task.id,
1472                            "name": task.name,
1473                            "triggers": task.triggers,
1474                        });
1475                        let _ = s.tx.send(StreamPart::tasks(vec![], data)).await;
1476                    }
1477                }
1478            }
1479
1480            // interrupt_before: pause before running the matched nodes
1481            if !self.interrupt_before.is_empty() {
1482                let task_names: Vec<String> = tasks.iter().map(|t| t.name.clone()).collect();
1483                if task_names.iter().any(|n| self.interrupt_before.contains(n)) {
1484                    if let Some(ref cp) = self.checkpointer {
1485                        if let Some(new_config) = self.save_checkpoint(cp, &config, &channels, &channel_versions, &versions_seen) {
1486                            config = new_config;
1487                        }
1488                    }
1489                    // Streaming: emit values before returning
1490                    if let Some(s) = stream {
1491                        if s.has(&StreamMode::Values) {
1492                            let keys = output_channel_keys(&channels);
1493                            let _ = s.tx.send(StreamPart::values(vec![], read_channels(&channels, &keys))).await;
1494                        }
1495                    }
1496                    let keys = output_channel_keys(&channels);
1497                    return Ok(read_channels(&channels, &keys));
1498                }
1499            }
1500
1501            // EXECUTE: build runner (with custom-stream writer in streaming mode)
1502            let runner = if let Some(s) = stream {
1503                let runtime = Arc::new(crate::runtime::Runtime {
1504                    context: (),
1505                    store: self.store.clone(),
1506                    stream_writer: s.custom_tx.clone(),
1507                    previous: None,
1508                    execution_info: None,
1509                    server_info: None,
1510                });
1511                if s.custom_tx.is_some() {
1512                    PregelRunner::new(Some(runtime.clone()))
1513                        .with_stream_writer(s.custom_tx.clone().unwrap())
1514                } else {
1515                    PregelRunner::new(Some(runtime))
1516                }
1517            } else {
1518                PregelRunner::new(self.store.clone().map(|_| {
1519                    Arc::new(crate::runtime::Runtime {
1520                        context: (),
1521                        store: self.store.clone(),
1522                        stream_writer: None,
1523                        previous: None,
1524                        execution_info: None,
1525                        server_info: None,
1526                    })
1527                }))
1528            };
1529
1530            match runner.run_tasks(&mut tasks).await {
1531                Ok(()) => {}
1532
1533                Err(crate::pregel::runner::RunnerError::Interrupt { task_id, interrupt }) => {
1534                    // Mirrors Python's _suppress_interrupt:
1535                    // Apply writes from tasks that completed *before* the interrupt
1536                    // so their trigger-channel writes survive into the checkpoint.
1537                    // The interrupted task is excluded so it re-triggers on resume.
1538                    apply_completed_writes(
1539                        &task_id,
1540                        &tasks,
1541                        &channels,
1542                        &mut versions_seen,
1543                        &mut channel_versions,
1544                    );
1545
1546                    // Save checkpoint (now includes completed tasks' channel writes)
1547                    if let Some(ref cp) = self.checkpointer {
1548                        if let Some(new_config) = self.save_checkpoint(cp, &config, &channels, &channel_versions, &versions_seen) {
1549                            config = new_config;
1550                        }
1551                        // Save interrupt as pending writes for get_state()
1552                        let iw: Vec<(String, String, JsonValue)> = interrupt
1553                            .interrupts
1554                            .iter()
1555                            .map(|iv| {
1556                                let val = serde_json::to_value(iv).unwrap_or(JsonValue::Null);
1557                                (task_id.clone(), crate::constants::INTERRUPT.to_string(), val)
1558                            })
1559                            .collect();
1560                        if !iw.is_empty() {
1561                            if let Err(e) = cp.put_writes(&config, &iw, &task_id, "") {
1562                                eprintln!("[CHECKPOINT] Failed to save interrupt writes: {}", e);
1563                            }
1564                        }
1565                    }
1566
1567                    // Streaming: emit values before returning
1568                    if let Some(s) = stream {
1569                        if s.has(&StreamMode::Values) {
1570                            let keys = output_channel_keys(&channels);
1571                            let _ = s.tx.send(StreamPart::values(vec![], read_channels(&channels, &keys))).await;
1572                        }
1573                    }
1574
1575                    let keys = output_channel_keys(&channels);
1576                    return Ok(read_channels(&channels, &keys));
1577                }
1578
1579                Err(other) => return Err(RunnableError::Runner(other.to_string())),
1580            }
1581
1582            // ── Streaming: emit per-node updates ────────────────────────────
1583            if let Some(s) = stream {
1584                if s.has(&StreamMode::Updates) {
1585                    for task in &tasks {
1586                        if !task.writes.is_empty() {
1587                            let mut node_updates = serde_json::Map::new();
1588                            for (chan, val) in &task.writes {
1589                                if !chan.starts_with("branch:") && !chan.starts_with("join:") {
1590                                    node_updates.insert(chan.clone(), val.clone());
1591                                }
1592                            }
1593                            if !node_updates.is_empty() {
1594                                let data = serde_json::json!({ &task.name: node_updates });
1595                                let _ = s.tx.send(StreamPart::updates(vec![], data)).await;
1596                            }
1597                        }
1598                    }
1599                }
1600            }
1601
1602            // UPDATE: apply all task writes to channels
1603            apply_writes(
1604                &mut channels,
1605                &tasks,
1606                &mut versions_seen,
1607                &mut channel_versions,
1608                &trigger_to_nodes,
1609                bump_version,
1610            );
1611
1612            // ── DEBUG: 打印 apply_writes 后 messages channel 状态 ──
1613            // {
1614            //     let task_names: Vec<&str> = tasks.iter().map(|t| t.name.as_str()).collect();
1615            //     let msg_count = channels.get("messages")
1616            //         .and_then(|ch| ch.get().ok())
1617            //         .and_then(|v| v.as_array().map(|a| a.len()))
1618            //         .unwrap_or(0);
1619            //     eprintln!("[DEBUG][pregel] step={} tasks={:?} after apply_writes: messages.len={}", step, task_names, msg_count);
1620            // }
1621
1622            // Save "loop" checkpoint after each completed super-step
1623            if let Some(ref cp) = self.checkpointer {
1624                if let Some(new_config) = self.save_checkpoint(cp, &config, &channels, &channel_versions, &versions_seen) {
1625                    config = new_config;
1626                }
1627            }
1628
1629            // ── Streaming: emit values after writes ──────────────────────────
1630            if let Some(s) = stream {
1631                if s.has(&StreamMode::Values) {
1632                    let keys = output_channel_keys(&channels);
1633                    let _ = s.tx.send(StreamPart::values(vec![], read_channels(&channels, &keys))).await;
1634                }
1635            }
1636
1637            // Read output
1638            let keys = output_channel_keys(&channels);
1639            let output = read_channels(&channels, &keys);
1640            if !output.is_null() {
1641                last_output = output;
1642            }
1643
1644            // interrupt_after: pause after the matched nodes complete
1645            if !self.interrupt_after.is_empty() {
1646                let task_names: Vec<String> = tasks.iter().map(|t| t.name.clone()).collect();
1647                if task_names.iter().any(|n| self.interrupt_after.contains(n)) {
1648                    return Ok(last_output);
1649                }
1650            }
1651
1652            step += 1;
1653        }
1654
1655        Ok(last_output)
1656    }
1657
1658
1659}
1660
1661#[async_trait]
1662impl Runnable for CompiledStateGraph {
1663    fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
1664        // Block on the async implementation
1665        match tokio::runtime::Handle::try_current() {
1666            Ok(handle) => handle.block_on(self.run_pregel(input, config)),
1667            Err(_) => tokio::runtime::Runtime::new()
1668                .unwrap()
1669                .block_on(self.run_pregel(input, config)),
1670        }
1671    }
1672
1673    async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
1674        self.run_pregel(input, config).await
1675    }
1676
1677    fn name(&self) -> &str {
1678        &self.name
1679    }
1680}
1681
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::LastValue;
    use serde_json::json;

    /// Builds a channel map containing a single `LastValue` channel named "value".
    fn make_channels() -> HashMap<String, Box<dyn Channel>> {
        let mut channels = HashMap::new();
        channels.insert("value".to_string(), Box::new(LastValue::new("value")) as Box<dyn Channel>);
        channels
    }

    // Building and compiling a graph is synchronous, so no async runtime is needed.
    #[test]
    fn test_simple_linear_graph() {
        let mut graph = StateGraph::new(make_channels());

        graph
            .add_node("a", |_input, _config| async { Ok(json!({"value": 1})) })
            .unwrap();
        graph
            .add_node("b", |_input, _config| async { Ok(json!({"value": 2})) })
            .unwrap();

        graph.add_edge(START, "a").unwrap();
        graph.add_edge("a", "b").unwrap();
        graph.add_edge("b", END).unwrap();

        let compiled = graph.compile().unwrap();
        assert!(compiled.has_node("a"));
        assert!(compiled.has_node("b"));
        assert_eq!(compiled.node_names().len(), 2);
    }

    #[test]
    fn test_duplicate_node_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.add_node("a", |_input, _config| async { Ok(json!({})) });
        assert!(result.is_err());
    }

    #[test]
    fn test_reserved_name_error() {
        let mut graph = StateGraph::new(make_channels());
        let result = graph.add_node(START, |_input, _config| async { Ok(json!({})) });
        assert!(result.is_err());
    }

    #[test]
    fn test_end_as_source_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.add_edge(END, "a");
        assert!(result.is_err());
    }

    #[test]
    fn test_start_as_target_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.add_edge("a", START);
        assert!(result.is_err());
    }

    #[test]
    fn test_no_start_edge_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.compile();
        assert!(result.is_err());
    }

    #[test]
    fn test_join_edge() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        graph.add_node("b", |_input, _config| async { Ok(json!({})) }).unwrap();
        graph.add_node("c", |_input, _config| async { Ok(json!({})) }).unwrap();

        graph.add_edge(START, "a").unwrap();
        graph.add_edge(START, "b").unwrap();
        graph.add_join_edge(vec!["a".to_string(), "b".to_string()], "c").unwrap();
        graph.add_edge("c", END).unwrap();

        let compiled = graph.compile().unwrap();
        assert_eq!(compiled.node_names().len(), 3);
    }

    #[test]
    fn test_conditional_edges() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("agent", |_input, _config| async { Ok(json!({})) }).unwrap();
        graph.add_node("tools", |_input, _config| async { Ok(json!({})) }).unwrap();

        graph.add_edge(START, "agent").unwrap();
        graph
            .add_conditional_edges(
                "agent",
                |_input, _config| async { Ok(json!("continue")) },
                Some(HashMap::from([
                    ("continue".to_string(), "tools".to_string()),
                    ("end".to_string(), END.to_string()),
                ])),
            )
            .unwrap();
        graph.add_edge("tools", "agent").unwrap();

        let compiled = graph.compile().unwrap();
        assert!(compiled.has_node("agent"));
        assert!(compiled.has_node("tools"));
    }

    #[tokio::test]
    async fn test_invoke_linear_graph() {
        // End-to-end test: build graph → compile → invoke → check output
        let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
        channels.insert("count".to_string(), Box::new(LastValue::new("count")) as Box<dyn Channel>);

        let mut graph = StateGraph::new(channels);

        graph
            .add_node("increment", |_input, _config| async {
                Ok(json!({"count": 1}))
            })
            .unwrap();
        graph
            .add_node("double", |_input, _config| async {
                Ok(json!({"count": 2}))
            })
            .unwrap();

        graph.add_edge(START, "increment").unwrap();
        graph.add_edge("increment", "double").unwrap();
        graph.add_edge("double", END).unwrap();

        let compiled = graph.compile().unwrap();
        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({"count": 0}), &config).await.unwrap();

        // The output should contain the "count" channel value
        assert!(result.is_object());
        // After "double" runs, count should be 2
        assert_eq!(result.get("count"), Some(&json!(2)));
    }

    #[tokio::test]
    async fn test_invoke_single_node() {
        let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
        channels.insert("result".to_string(), Box::new(LastValue::new("result")) as Box<dyn Channel>);

        let mut graph = StateGraph::new(channels);
        graph
            .add_node("process", |_input, _config| async {
                Ok(json!({"result": 42}))
            })
            .unwrap();
        graph.add_edge(START, "process").unwrap();
        graph.add_edge("process", END).unwrap();

        let compiled = graph.compile().unwrap();
        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({}), &config).await.unwrap();

        assert_eq!(result.get("result"), Some(&json!(42)));
    }

    #[test]
    fn test_invoke_sync_without_runtime() {
        // The blocking `invoke` entry point called from a plain test: no Tokio runtime
        // is active on this thread, so `invoke` must supply one itself.
        let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
        channels.insert("result".to_string(), Box::new(LastValue::new("result")) as Box<dyn Channel>);

        let mut graph = StateGraph::new(channels);
        graph
            .add_node("process", |_input, _config| async {
                Ok(json!({"result": 7}))
            })
            .unwrap();
        graph.add_edge(START, "process").unwrap();
        graph.add_edge("process", END).unwrap();

        let compiled = graph.compile().unwrap();
        // Fully qualified so the trait method is exercised even if an inherent
        // `invoke` is ever added to the compiled graph.
        let result = Runnable::invoke(&compiled, &json!({}), &RunnableConfig::new()).unwrap();

        assert_eq!(result.get("result"), Some(&json!(7)));
    }

    #[tokio::test]
    async fn test_interrupt_before() {
        // Test interrupt_before: graph pauses before executing the specified node
        let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
        channels.insert("value".to_string(), Box::new(LastValue::new("value")) as Box<dyn Channel>);

        let mut graph = StateGraph::new(channels);

        graph
            .add_node("process", |_input, _config| async {
                Ok(json!({"value": 42}))
            })
            .unwrap();
        graph.add_edge(START, "process").unwrap();
        graph.add_edge("process", END).unwrap();

        let mut compiled = graph.compile().unwrap();
        // Set interrupt_before to pause before "process" node
        compiled.interrupt_before = vec!["process".to_string()];

        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({}), &config).await.unwrap();

        // Graph should return current state (empty since process hasn't run yet)
        assert!(result.is_object());
        // The "value" channel should not have been set yet: absent or explicitly null.
        assert!(matches!(result.get("value"), None | Some(JsonValue::Null)));
    }

    #[tokio::test]
    async fn test_interrupt_after() {
        // Test interrupt_after: graph pauses after executing the specified node
        let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
        channels.insert("value".to_string(), Box::new(LastValue::new("value")) as Box<dyn Channel>);

        let mut graph = StateGraph::new(channels);

        graph
            .add_node("process", |_input, _config| async {
                Ok(json!({"value": 42}))
            })
            .unwrap();
        graph.add_edge(START, "process").unwrap();
        graph.add_edge("process", END).unwrap();

        let mut compiled = graph.compile().unwrap();
        // Set interrupt_after to pause after "process" node
        compiled.interrupt_after = vec!["process".to_string()];

        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({}), &config).await.unwrap();

        // Graph should return current state with the value from "process"
        assert!(result.is_object());
        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    #[tokio::test]
    async fn test_update_state() {
        use langgraph_checkpoint::checkpoint::memory::InMemorySaver;

        let mut channels: HashMap<String, Box<dyn Channel>> = HashMap::new();
        channels.insert("name".to_string(), Box::new(LastValue::new("name")) as Box<dyn Channel>);
        channels.insert("value".to_string(), Box::new(LastValue::new("value")) as Box<dyn Channel>);

        let mut graph = StateGraph::new(channels);
        graph
            .add_node("set_value", |_input, _config| async {
                Ok(json!({"value": 42}))
            })
            .unwrap();
        graph.add_edge(START, "set_value").unwrap();
        graph.add_edge("set_value", END).unwrap();

        let checkpointer = Arc::new(InMemorySaver::new());
        let compiled = graph.compile_builder()
            .checkpointer(checkpointer)
            .build()
            .unwrap();

        let mut config = RunnableConfig::new();
        config.insert("configurable".to_string(), json!({"thread_id": "test-thread"}));

        // First invoke
        let result = compiled.ainvoke(&json!({"name": "original"}), &config).await.unwrap();
        assert_eq!(result.get("value"), Some(&json!(42)));

        // Verify get_state
        let snapshot = compiled.get_state(&config).unwrap();
        assert_eq!(snapshot.values.get("name").and_then(|v| v.as_str()), Some("original"));
        assert_eq!(snapshot.values.get("value").and_then(|v| v.as_i64()), Some(42));

        // Update state
        compiled.update_state(&config, &json!({"name": "updated"})).unwrap();

        // Verify update took effect
        let snapshot = compiled.get_state(&config).unwrap();
        assert_eq!(snapshot.values.get("name").and_then(|v| v.as_str()), Some("updated"));
        assert_eq!(snapshot.values.get("value").and_then(|v| v.as_i64()), Some(42));
    }
}