1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use async_trait::async_trait;
4use serde_json::Value as JsonValue;
5use tokio::sync::mpsc;
6use tokio_stream::wrappers::ReceiverStream;
7use langgraph_checkpoint::config::{RunnableConfig, RunnableConfigExt};
8use langgraph_checkpoint::cache::base::BaseCache;
9use langgraph_checkpoint::store::base::BaseStore;
10use langgraph_checkpoint::checkpoint::base::BaseCheckpointSaver;
11use crate::channels::{Channel, EphemeralValue, NamedBarrierValue};
12use crate::constants::{START, END, RESUME, INTERRUPT, NULL_TASK_ID};
13use crate::runnable::{Runnable, RunnableError, IntoNodeFunction};
14use crate::graph::node::StateNodeSpec;
15use crate::graph::branch::BranchSpec;
16use crate::pregel::{PregelNode, PregelRunner, ChannelVersions, channels_from_checkpoint, PregelExecutableTask};
17use crate::pregel::algo::{prepare_next_tasks, apply_writes};
18use crate::pregel::io::{map_input, map_command, read_channels};
19use crate::stream::StreamPart;
20use crate::types::{Command, StreamMode, StateSnapshot, PregelTask, Interrupt};
21use langgraph_checkpoint::checkpoint::types::CheckpointMetadata;
22
/// A join edge `(sources, target)` as registered by `StateGraph::add_join_edge`.
///
/// At compile time each one becomes a `NamedBarrierValue` channel named
/// `join:<sources joined by '+'>:<target>` over the set of source names.
type WaitingEdge = (Vec<String>, String);
25
/// Errors produced while building, compiling, or inspecting a [`StateGraph`].
#[derive(Debug, thiserror::Error)]
pub enum GraphError {
    /// `add_node` was called with a name that is already registered.
    #[error("node '{0}' already exists")]
    DuplicateNode(String),

    /// An edge or branch referenced a node that has not been added.
    #[error("unknown node '{0}'")]
    UnknownNode(String),

    /// A node was given the reserved `START` or `END` name.
    #[error("cannot use reserved name '{0}'")]
    ReservedName(String),

    /// An edge tried to point *into* `START`.
    #[error("START cannot be an edge target")]
    StartAsTarget,

    /// An edge tried to originate *from* `END`.
    #[error("END cannot be an edge source")]
    EndAsSource,

    /// Compilation found no edge, join edge, or branch leaving `START`.
    #[error("no outgoing edge from START")]
    NoStartEdge,

    /// General validation/configuration failure carrying a message
    /// (also used when an operation needs a checkpointer but none is set).
    #[error("graph validation failed: {0}")]
    ValidationError(String),

    /// Wraps a [`RunnableError`]; `#[from]` lets `?` convert it automatically.
    #[error(transparent)]
    Runnable(#[from] RunnableError),

    /// A checkpointer call failed; carries the underlying error's message.
    #[error("checkpoint error: {0}")]
    Checkpoint(String),
}
56
/// Mutable builder for a state graph: nodes, plain edges, join edges,
/// conditional branches, and the state channels they read and write.
///
/// Call `compile` or `compile_builder` to produce a [`CompiledStateGraph`].
pub struct StateGraph {
    /// Registered nodes keyed by name.
    nodes: HashMap<String, StateNodeSpec>,
    /// Plain `(source, target)` edges; `source` may be `START`, `target` may be `END`.
    edges: HashSet<(String, String)>,
    /// Join edges `(sources, target)`.
    waiting_edges: HashSet<WaitingEdge>,
    /// Conditional branches: source node -> (branch name -> branch spec).
    branches: HashMap<String, HashMap<String, BranchSpec>>,
    /// State channels. Compilation also inserts the internal `START`,
    /// `branch:to:<node>` and `join:*` channels into this map.
    channels: HashMap<String, Box<dyn Channel>>,
    /// Set to `true` by compilation; not read anywhere in the code visible here.
    compiled: bool,
}
86
87impl StateGraph {
88 pub fn new(channels: HashMap<String, Box<dyn Channel>>) -> Self {
92 Self {
93 nodes: HashMap::new(),
94 edges: HashSet::new(),
95 waiting_edges: HashSet::new(),
96 branches: HashMap::new(),
97 channels,
98 compiled: false,
99 }
100 }
101
102 pub fn add_node(
121 &mut self,
122 name: impl Into<String>,
123 action: impl IntoNodeFunction,
124 ) -> Result<&mut Self, GraphError> {
125 let name = name.into();
126 self.validate_node_name(&name)?;
127 let runnable = action.into_runnable(&name);
128 self.nodes.insert(name.clone(), StateNodeSpec::new(name, runnable));
129 Ok(self)
130 }
131
132 pub fn add_edge(
137 &mut self,
138 start: impl Into<String>,
139 end: impl Into<String>,
140 ) -> Result<&mut Self, GraphError> {
141 let start = start.into();
142 let end = end.into();
143
144 if start == END {
145 return Err(GraphError::EndAsSource);
146 }
147 if end == START {
148 return Err(GraphError::StartAsTarget);
149 }
150 if start != START && !self.nodes.contains_key(&start) {
151 return Err(GraphError::UnknownNode(start));
152 }
153 if end != END && !self.nodes.contains_key(&end) {
154 return Err(GraphError::UnknownNode(end));
155 }
156
157 self.edges.insert((start, end));
158 Ok(self)
159 }
160
161 pub fn add_join_edge(
165 &mut self,
166 starts: Vec<String>,
167 end: impl Into<String>,
168 ) -> Result<&mut Self, GraphError> {
169 let end = end.into();
170 if end == START {
171 return Err(GraphError::StartAsTarget);
172 }
173 for s in &starts {
174 if s == END {
175 return Err(GraphError::EndAsSource);
176 }
177 if s != START && !self.nodes.contains_key(s) {
178 return Err(GraphError::UnknownNode(s.clone()));
179 }
180 }
181 if end != END && !self.nodes.contains_key(&end) {
182 return Err(GraphError::UnknownNode(end));
183 }
184 self.waiting_edges.insert((starts, end));
185 Ok(self)
186 }
187
188 pub fn add_conditional_edges(
194 &mut self,
195 source: impl Into<String>,
196 path: impl IntoNodeFunction,
197 path_map: Option<HashMap<String, String>>,
198 ) -> Result<&mut Self, GraphError> {
199 let source = source.into();
200 if source != START && !self.nodes.contains_key(&source) {
201 return Err(GraphError::UnknownNode(source));
202 }
203
204 let branch_name = format!("branch:{}", source);
205 let runnable = path.into_runnable(&branch_name);
206 let branch = BranchSpec::new(runnable, path_map);
207
208 self.branches
209 .entry(source)
210 .or_default()
211 .insert(branch_name, branch);
212
213 Ok(self)
214 }
215
216 pub fn set_entry_point(&mut self, key: impl Into<String>) -> Result<&mut Self, GraphError> {
218 self.add_edge(START, key)
219 }
220
221 pub fn set_finish_point(&mut self, key: impl Into<String>) -> Result<&mut Self, GraphError> {
223 self.add_edge(key, END)
224 }
225
226 pub fn compile(&mut self) -> Result<CompiledStateGraph, GraphError> {
233 self.compile_with(None, None, None, None, None, false, None, None)
234 }
235
236 pub fn compile_builder(&mut self) -> CompileBuilder<'_> {
246 CompileBuilder {
247 graph: self,
248 checkpointer: None,
249 cache: None,
250 store: None,
251 interrupt_before: None,
252 interrupt_after: None,
253 debug: false,
254 name: None,
255 recursion_limit: None,
256 }
257 }
258
259 fn compile_with(
261 &mut self,
262 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
263 cache: Option<Arc<dyn BaseCache>>,
264 store: Option<Arc<dyn BaseStore>>,
265 interrupt_before: Option<Vec<String>>,
266 interrupt_after: Option<Vec<String>>,
267 debug: bool,
268 name: Option<String>,
269 recursion_limit: Option<u64>,
270 ) -> Result<CompiledStateGraph, GraphError> {
271 self.validate()?;
272
273 self.channels.insert(
275 START.to_string(),
276 Box::new(EphemeralValue::new(START, false)),
277 );
278
279 for name in self.nodes.keys() {
281 let trigger_key = format!("branch:to:{}", name);
282 self.channels
283 .insert(trigger_key.clone(), Box::new(EphemeralValue::new(trigger_key, false)));
284 }
285
286 for (sources, target) in &self.waiting_edges {
288 let barrier_name = format!("join:{}:{}", sources.join("+"), target);
289 let names: HashSet<String> = sources.iter().cloned().collect();
290 self.channels.insert(
291 barrier_name.clone(),
292 Box::new(NamedBarrierValue::new(barrier_name, names)),
293 );
294 }
295
296 self.compiled = true;
297
298 let channels = self.channels
299 .iter()
300 .map(|(k, c)| (k.clone(), c.clone_channel()))
301 .collect();
302
303 Ok(CompiledStateGraph {
304 nodes: self.nodes.clone(),
305 edges: self.edges.clone(),
306 waiting_edges: self.waiting_edges.clone(),
307 branches: self.branches.clone(),
308 channels,
309 checkpointer,
310 cache,
311 store,
312 interrupt_before: interrupt_before.unwrap_or_default(),
313 interrupt_after: interrupt_after.unwrap_or_default(),
314 debug,
315 name: name.unwrap_or_else(|| "StateGraph".to_string()),
316 recursion_limit: recursion_limit.unwrap_or(DEFAULT_RECURSION_LIMIT),
317 })
318 }
319
320 fn validate_node_name(&self, name: &str) -> Result<(), GraphError> {
321 if name == START || name == END {
322 return Err(GraphError::ReservedName(name.to_string()));
323 }
324 if self.nodes.contains_key(name) {
325 return Err(GraphError::DuplicateNode(name.to_string()));
326 }
327 Ok(())
328 }
329
330 fn validate(&self) -> Result<(), GraphError> {
331 let has_start_edge = self.edges.iter().any(|(s, _)| s == START)
333 || self.waiting_edges.iter().any(|(s, _)| s.contains(&START.to_string()))
334 || self.branches.contains_key(START);
335 if !has_start_edge {
336 return Err(GraphError::NoStartEdge);
337 }
338
339 for (start, end) in &self.edges {
341 if start != START && !self.nodes.contains_key(start) {
342 return Err(GraphError::UnknownNode(start.clone()));
343 }
344 if end != END && !self.nodes.contains_key(end) {
345 return Err(GraphError::UnknownNode(end.clone()));
346 }
347 }
348
349 Ok(())
350 }
351}
352
353pub struct CompileBuilder<'a> {
355 graph: &'a mut StateGraph,
356 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
357 cache: Option<Arc<dyn BaseCache>>,
358 store: Option<Arc<dyn BaseStore>>,
359 interrupt_before: Option<Vec<String>>,
360 interrupt_after: Option<Vec<String>>,
361 debug: bool,
362 name: Option<String>,
363 recursion_limit: Option<u64>,
364}
365
366impl<'a> CompileBuilder<'a> {
367 pub fn checkpointer(mut self, cp: Arc<dyn BaseCheckpointSaver>) -> Self {
368 self.checkpointer = Some(cp);
369 self
370 }
371
372 pub fn cache(mut self, cache: Arc<dyn BaseCache>) -> Self {
373 self.cache = Some(cache);
374 self
375 }
376
377 pub fn store(mut self, store: Arc<dyn BaseStore>) -> Self {
378 self.store = Some(store);
379 self
380 }
381
382 pub fn interrupt_before(mut self, nodes: Vec<String>) -> Self {
383 self.interrupt_before = Some(nodes);
384 self
385 }
386
387 pub fn interrupt_after(mut self, nodes: Vec<String>) -> Self {
388 self.interrupt_after = Some(nodes);
389 self
390 }
391
392 pub fn debug(mut self, debug: bool) -> Self {
393 self.debug = debug;
394 self
395 }
396
397 pub fn name(mut self, name: impl Into<String>) -> Self {
398 self.name = Some(name.into());
399 self
400 }
401
402 pub fn recursion_limit(mut self, limit: u64) -> Self {
403 self.recursion_limit = Some(limit);
404 self
405 }
406
407 pub fn build(self) -> Result<CompiledStateGraph, GraphError> {
408 self.graph.compile_with(
409 self.checkpointer,
410 self.cache,
411 self.store,
412 self.interrupt_before,
413 self.interrupt_after,
414 self.debug,
415 self.name,
416 self.recursion_limit,
417 )
418 }
419}
420
/// Compiled, executable form of a [`StateGraph`], produced by compilation.
pub struct CompiledStateGraph {
    /// Node specs copied from the builder.
    nodes: HashMap<String, StateNodeSpec>,
    /// Plain edges copied from the builder.
    edges: HashSet<(String, String)>,
    /// Join edges copied from the builder.
    waiting_edges: HashSet<WaitingEdge>,
    /// Conditional branches copied from the builder.
    branches: HashMap<String, HashMap<String, BranchSpec>>,
    /// Template channels (user state plus internal `START`/`branch:to:*`/`join:*`).
    /// Runs and state inspection clone or restore from these.
    channels: HashMap<String, Box<dyn Channel>>,
    /// Optional persistence backend used by get_state / update_state /
    /// get_state_history and by runs.
    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
    // The cache is carried along (and cloned) but never otherwise used in this
    // file, hence the lint allowance.
    #[allow(dead_code)]
    cache: Option<Arc<dyn BaseCache>>,
    /// Optional store supplied at compile time.
    store: Option<Arc<dyn BaseStore>>,
    /// Node names that trigger a checkpoint before a step containing them runs.
    interrupt_before: Vec<String>,
    /// Node names for after-step interrupts.
    interrupt_after: Vec<String>,
    /// Debug flag from the compile options.
    debug: bool,
    /// Graph name (`"StateGraph"` by default).
    name: String,
    /// Default step cap; a recursion limit in the run config overrides it.
    recursion_limit: u64,
}
442
443impl CompiledStateGraph {
444 pub fn node_names(&self) -> Vec<String> {
446 self.nodes.keys().cloned().collect()
447 }
448
449 pub fn channel_names(&self) -> Vec<String> {
451 self.channels.keys().cloned().collect()
452 }
453
454 pub fn has_node(&self, name: &str) -> bool {
456 self.nodes.contains_key(name)
457 }
458
459 pub fn name(&self) -> &str {
461 &self.name
462 }
463
464 pub fn checkpointer(&self) -> Option<&Arc<dyn BaseCheckpointSaver>> {
466 self.checkpointer.as_ref()
467 }
468
469 pub fn store(&self) -> Option<&Arc<dyn BaseStore>> {
471 self.store.as_ref()
472 }
473
474 fn save_checkpoint(
476 &self,
477 checkpointer: &Arc<dyn BaseCheckpointSaver>,
478 config: &RunnableConfig,
479 channels: &HashMap<String, Box<dyn Channel>>,
480 channel_versions: &ChannelVersions,
481 versions_seen: &HashMap<String, HashMap<String, JsonValue>>,
482 ) -> Option<RunnableConfig> {
483 use langgraph_checkpoint::checkpoint::id::uuid6;
484 use chrono::Utc;
485
486 let channel_values: HashMap<String, JsonValue> = channels
488 .iter()
489 .filter_map(|(k, v)| v.checkpoint().map(|val| (k.clone(), val)))
490 .collect();
491
492 let checkpoint = langgraph_checkpoint::Checkpoint {
493 v: 2,
494 id: uuid6(),
495 ts: Utc::now().to_rfc3339(),
496 channel_values,
497 channel_versions: channel_versions.clone(),
498 versions_seen: versions_seen.clone(),
499 updated_channels: None,
500 };
501
502 let metadata = CheckpointMetadata::default();
503 checkpointer.put(config, &checkpoint, &metadata, channel_versions).ok()
504 }
505
506 pub fn get_next_nodes(&self, state: &HashMap<String, JsonValue>) -> Vec<String> {
510 let mut next = Vec::new();
511
512 for (start, end) in &self.edges {
514 if (start == START || state.contains_key(&format!("branch:to:{}", start)))
515 && end != END {
516 next.push(end.clone());
517 }
518 }
519
520 for (source, branches) in &self.branches {
522 if source == START || state.contains_key(&format!("branch:to:{}", source)) {
523 for _branch in branches.values() {
524 }
528 }
529 }
530
531 next
532 }
533
534 pub fn get_state(&self, config: &RunnableConfig) -> Result<StateSnapshot, GraphError> {
549 let checkpointer = self.checkpointer.as_ref().ok_or_else(|| {
550 GraphError::ValidationError("No checkpointer set".to_string())
551 })?;
552
553 let saved = checkpointer
554 .get_tuple(config)
555 .map_err(|e| GraphError::Checkpoint(e.to_string()))?;
556
557 let Some(saved) = saved else {
558 return Ok(StateSnapshot {
559 values: JsonValue::Object(serde_json::Map::new()),
560 next: vec![],
561 config: config.clone(),
562 metadata: None,
563 created_at: None,
564 parent_config: None,
565 tasks: vec![],
566 interrupts: vec![],
567 });
568 };
569
570 let cp_channels: HashMap<String, Option<JsonValue>> = saved
572 .checkpoint
573 .channel_values
574 .iter()
575 .map(|(k, v)| (k.clone(), Some(v.clone())))
576 .collect();
577 let mut channels = channels_from_checkpoint(&self.channels, &cp_channels);
578
579 let mut channel_versions = saved.checkpoint.channel_versions.clone();
580 let mut versions_seen = saved.checkpoint.versions_seen.clone();
581
582 if let Some(ref pending) = saved.pending_writes {
584 for (tid, chan, val) in pending {
585 if tid == NULL_TASK_ID {
586 if let Some(ch) = channels.get(chan) {
587 ch.update(&[val.clone()]).ok();
588 }
589 }
590 }
591 }
592
593 let pregel_nodes = build_pregel_nodes(
595 &self.nodes,
596 &self.edges,
597 &self.waiting_edges,
598 &self.branches,
599 &self.channels,
600 );
601 let trigger_to_nodes = crate::pregel::build_trigger_to_nodes(&pregel_nodes);
602
603 let step = 0u64;
604 let checkpoint_id = format!("{:032}", step);
605 let pending_writes: Vec<(String, String, JsonValue)> = saved
606 .pending_writes
607 .as_ref()
608 .map(|pw| pw.to_vec())
609 .unwrap_or_default();
610
611 let mut tasks = prepare_next_tasks(
612 &pregel_nodes,
613 &channels,
614 config,
615 step,
616 &mut versions_seen,
617 &trigger_to_nodes,
618 None,
619 &checkpoint_id,
620 &pending_writes,
621 &channel_versions,
622 );
623
624 if let Some(ref pending) = saved.pending_writes {
627 for (tid, chan, val) in pending {
628 if chan == INTERRUPT || chan == crate::constants::ERROR {
629 continue;
630 }
631 if tid == NULL_TASK_ID {
632 continue;
633 }
634 if let Some(task) = tasks.iter_mut().find(|t| &t.id == tid) {
635 task.writes.push((chan.clone(), val.clone()));
636 }
637 }
638 }
639
640 apply_writes(
642 &mut channels,
643 &tasks,
644 &mut versions_seen,
645 &mut channel_versions,
646 &trigger_to_nodes,
647 |current| {
648 let num = current
649 .and_then(|v| v.as_str())
650 .and_then(|s| s.parse::<u64>().ok())
651 .unwrap_or(0);
652 JsonValue::String(format!("{:032}", num + 1))
653 },
654 );
655
656 let output_keys: Vec<String> = channels
658 .keys()
659 .filter(|k| !k.starts_with("branch:") && !k.starts_with("join:") && *k != START)
660 .cloned()
661 .collect();
662 let values = read_channels(&channels, &output_keys);
663
664 let next: Vec<String> = tasks
666 .iter()
667 .filter(|t| t.writes.is_empty())
668 .map(|t| t.name.clone())
669 .collect();
670
671 let interrupts: Vec<Interrupt> = saved
673 .pending_writes
674 .as_ref()
675 .map(|pw| {
676 pw.iter()
677 .filter(|(_, chan, _)| chan == INTERRUPT)
678 .filter_map(|(_, _, val)| {
679 serde_json::from_value::<Interrupt>(val.clone()).ok()
680 })
681 .collect()
682 })
683 .unwrap_or_default();
684
685 let snapshot_tasks: Vec<PregelTask> = tasks
687 .iter()
688 .map(|t| {
689 let task_interrupts: Vec<Interrupt> = saved
690 .pending_writes
691 .as_ref()
692 .map(|pw| {
693 pw.iter()
694 .filter(|(tid, chan, _)| tid == &t.id && chan == INTERRUPT)
695 .filter_map(|(_, _, val)| {
696 serde_json::from_value::<Interrupt>(val.clone()).ok()
697 })
698 .collect()
699 })
700 .unwrap_or_default();
701
702 PregelTask {
703 id: t.id.clone(),
704 name: t.name.clone(),
705 path: vec![],
706 error: None,
707 interrupts: task_interrupts,
708 result: None,
709 }
710 })
711 .collect();
712
713 Ok(StateSnapshot {
714 values,
715 next,
716 config: saved.config.clone(),
717 metadata: Some(saved.metadata.clone()),
718 created_at: Some(saved.checkpoint.ts.clone()),
719 parent_config: saved.parent_config.clone(),
720 tasks: snapshot_tasks,
721 interrupts,
722 })
723 }
724
725 pub fn update_state(
744 &self,
745 config: &RunnableConfig,
746 values: &JsonValue,
747 ) -> Result<RunnableConfig, GraphError> {
748 let checkpointer = self.checkpointer.as_ref().ok_or_else(|| {
749 GraphError::ValidationError("No checkpointer set".to_string())
750 })?;
751
752 let saved = checkpointer
753 .get_tuple(config)
754 .map_err(|e| GraphError::Checkpoint(e.to_string()))?;
755
756 let channels: HashMap<String, Box<dyn Channel>> = if let Some(ref saved) = saved {
758 let cp_channels: HashMap<String, Option<JsonValue>> = saved
759 .checkpoint
760 .channel_values
761 .iter()
762 .map(|(k, v)| (k.clone(), Some(v.clone())))
763 .collect();
764 channels_from_checkpoint(&self.channels, &cp_channels)
765 } else {
766 self.channels
767 .iter()
768 .map(|(k, c)| (k.clone(), c.clone_channel()))
769 .collect()
770 };
771
772 let mut channel_versions = saved
773 .as_ref()
774 .map(|s| s.checkpoint.channel_versions.clone())
775 .unwrap_or_default();
776 let versions_seen = saved
777 .as_ref()
778 .map(|s| s.checkpoint.versions_seen.clone())
779 .unwrap_or_default();
780
781 if let Some(obj) = values.as_object() {
783 for (key, val) in obj {
784 if let Some(ch) = channels.get(key) {
785 ch.update(&[val.clone()]).ok();
786 let new_version = channel_versions
788 .get(key)
789 .and_then(|v| v.as_str())
790 .and_then(|s| s.parse::<u64>().ok())
791 .unwrap_or(0)
792 + 1;
793 channel_versions.insert(
794 key.clone(),
795 JsonValue::String(format!("{:032}", new_version)),
796 );
797 }
798 }
799 }
800
801 self.save_checkpoint(checkpointer, config, &channels, &channel_versions, &versions_seen);
803
804 Ok(config.clone())
805 }
806
807 pub fn get_state_history(&self, config: &RunnableConfig) -> Result<Vec<StateSnapshot>, GraphError> {
824 let checkpointer = self.checkpointer.as_ref().ok_or_else(|| {
825 GraphError::ValidationError("No checkpointer set".to_string())
826 })?;
827
828 let tuples = checkpointer
829 .list(Some(config), None, None, None)
830 .map_err(|e| GraphError::Checkpoint(e.to_string()))?;
831
832 let mut snapshots = Vec::new();
833
834 let pregel_nodes = build_pregel_nodes(
836 &self.nodes,
837 &self.edges,
838 &self.waiting_edges,
839 &self.branches,
840 &self.channels,
841 );
842 let trigger_to_nodes = crate::pregel::build_trigger_to_nodes(&pregel_nodes);
843
844 for saved in &tuples {
845 let cp_channels: HashMap<String, Option<JsonValue>> = saved
847 .checkpoint
848 .channel_values
849 .iter()
850 .map(|(k, v)| (k.clone(), Some(v.clone())))
851 .collect();
852 let channels = channels_from_checkpoint(&self.channels, &cp_channels);
853
854 let channel_versions = saved.checkpoint.channel_versions.clone();
855 let mut versions_seen = saved.checkpoint.versions_seen.clone();
856
857 if let Some(ref pending) = saved.pending_writes {
859 for (tid, chan, val) in pending {
860 if chan == INTERRUPT || chan == crate::constants::ERROR {
861 continue;
862 }
863 if tid == NULL_TASK_ID {
864 if let Some(ch) = channels.get(chan) {
865 ch.update(&[val.clone()]).ok();
866 }
867 continue;
868 }
869 if let Some(ch) = channels.get(chan) {
870 ch.update(&[val.clone()]).ok();
871 }
872 }
873 }
874
875 let output_keys: Vec<String> = channels
877 .keys()
878 .filter(|k| !k.starts_with("branch:") && !k.starts_with("join:") && *k != START)
879 .cloned()
880 .collect();
881 let values = read_channels(&channels, &output_keys);
882
883 let checkpoint_id = saved.checkpoint.id.clone();
885 let pending_writes: Vec<(String, String, JsonValue)> = saved
886 .pending_writes
887 .as_ref()
888 .map(|pw| pw.iter().map(|(t, c, v)| (t.clone(), c.clone(), v.clone())).collect())
889 .unwrap_or_default();
890
891 let tasks = prepare_next_tasks(
892 &pregel_nodes,
893 &channels,
894 &RunnableConfig::new(),
895 0,
896 &mut versions_seen,
897 &trigger_to_nodes,
898 None,
899 &checkpoint_id,
900 &pending_writes,
901 &channel_versions,
902 );
903
904 let next: Vec<String> = tasks
906 .iter()
907 .filter(|t| t.writes.is_empty())
908 .map(|t| t.name.clone())
909 .collect();
910
911 let interrupts: Vec<Interrupt> = saved
913 .pending_writes
914 .as_ref()
915 .map(|pw| {
916 pw.iter()
917 .filter(|(_, chan, _)| chan == INTERRUPT)
918 .filter_map(|(_, _, val)| {
919 serde_json::from_value::<Interrupt>(val.clone()).ok()
920 })
921 .collect()
922 })
923 .unwrap_or_default();
924
925 snapshots.push(StateSnapshot {
926 values,
927 next,
928 config: saved.config.clone(),
929 metadata: Some(saved.metadata.clone()),
930 created_at: Some(saved.checkpoint.ts.clone()),
931 parent_config: saved.parent_config.clone(),
932 tasks: vec![],
933 interrupts,
934 });
935 }
936
937 Ok(snapshots)
938 }
939}
940
941impl Clone for CompiledStateGraph {
942 fn clone(&self) -> Self {
943 let channels: HashMap<String, Box<dyn Channel>> = self.channels
944 .iter()
945 .map(|(k, c)| (k.clone(), c.clone_channel()))
946 .collect();
947
948 let branches: HashMap<String, HashMap<String, BranchSpec>> = self.branches
950 .iter()
951 .map(|(k, v)| (k.clone(), v.clone()))
952 .collect();
953
954 Self {
955 nodes: self.nodes.clone(),
956 edges: self.edges.clone(),
957 waiting_edges: self.waiting_edges.clone(),
958 branches,
959 channels,
960 checkpointer: self.checkpointer.clone(),
961 cache: self.cache.clone(),
962 store: self.store.clone(),
963 interrupt_before: self.interrupt_before.clone(),
964 interrupt_after: self.interrupt_after.clone(),
965 debug: self.debug,
966 name: self.name.clone(),
967 recursion_limit: self.recursion_limit,
968 }
969 }
970}
971
972fn build_pregel_nodes(
985 nodes: &HashMap<String, StateNodeSpec>,
986 edges: &HashSet<(String, String)>,
987 waiting_edges: &HashSet<WaitingEdge>,
988 branches: &HashMap<String, HashMap<String, BranchSpec>>,
989 channels: &HashMap<String, Box<dyn Channel>>,
990) -> HashMap<String, PregelNode> {
991 let mut pregel_nodes = HashMap::new();
992
993 let mut edge_targets: HashMap<String, Vec<String>> = HashMap::new();
995 for (start, end) in edges {
996 if end != END {
997 edge_targets.entry(start.clone()).or_default().push(end.clone());
998 }
999 }
1000
1001 let mut join_writes_for_source: HashMap<String, Vec<(String, String)>> = HashMap::new();
1011 let mut join_trigger_for_target: HashMap<String, String> = HashMap::new();
1012
1013 for (sources, target) in waiting_edges {
1014 let barrier_name = format!("join:{}:{}", sources.join("+"), target);
1017
1018 for source in sources {
1020 join_writes_for_source
1021 .entry(source.clone())
1022 .or_default()
1023 .push((barrier_name.clone(), source.clone()));
1024 }
1025
1026 join_trigger_for_target.insert(target.clone(), barrier_name);
1028 }
1029
1030 for (name, spec) in nodes {
1032 let trigger = join_trigger_for_target
1036 .get(name)
1037 .cloned()
1038 .unwrap_or_else(|| format!("branch:to:{}", name));
1039
1040 let input_channels: Vec<String> = channels
1042 .keys()
1043 .filter(|k| {
1044 !k.starts_with("branch:") && !k.starts_with("join:") && *k != START
1045 })
1046 .cloned()
1047 .collect();
1048
1049 let targets: Vec<String> = edge_targets.get(name).cloned().unwrap_or_default();
1051
1052 let barrier_writes: Vec<(String, String)> = join_writes_for_source
1055 .get(name)
1056 .cloned()
1057 .unwrap_or_default();
1058
1059 let node_branches: Vec<BranchSpec> = branches
1061 .get(name)
1062 .map(|m| m.values().cloned().collect())
1063 .unwrap_or_default();
1064
1065 let node_runnable = spec.runnable.clone();
1066 let node_name = name.clone();
1067
1068 let combined: Arc<dyn Runnable> = Arc::new(
1069 crate::runnable::RunnableCallable::new(
1070 node_name.clone(),
1071 move |input, config| {
1072 let node_runnable = node_runnable.clone();
1073 let targets = targets.clone();
1074 let barrier_writes = barrier_writes.clone();
1075 let node_branches = node_branches.clone();
1076 async move {
1077 let output = node_runnable.ainvoke(&input, &config).await?;
1079
1080 let mut result = serde_json::Map::new();
1082
1083 if let Some(obj) = output.as_object() {
1085 for (k, v) in obj {
1086 result.insert(k.clone(), v.clone());
1087 }
1088 }
1089
1090 for target in &targets {
1092 let trigger_ch = format!("branch:to:{}", target);
1093 result.insert(trigger_ch, JsonValue::String(target.clone()));
1094 }
1095
1096 for (barrier_ch, source_name) in &barrier_writes {
1100 result.insert(
1101 barrier_ch.clone(),
1102 JsonValue::String(source_name.clone()),
1103 );
1104 }
1105
1106 for branch in &node_branches {
1108 let branch_result = branch.path.ainvoke(&output, &config).await?;
1109 let key = branch_result.as_str().unwrap_or("");
1110 if let Some(target) = branch.resolve(key) {
1111 let trigger_ch = format!("branch:to:{}", target);
1112 result.insert(trigger_ch, JsonValue::String(target));
1113 }
1114 }
1115
1116 Ok(JsonValue::Object(result))
1117 }
1118 },
1119 ),
1120 );
1121
1122 let pregel_node = PregelNode::new(
1123 input_channels,
1124 vec![trigger],
1125 combined,
1126 );
1127
1128 pregel_nodes.insert(name.clone(), pregel_node);
1129 }
1130
1131 pregel_nodes
1132}
1133
/// Step cap used when neither the compile options nor the run config
/// supply a recursion limit.
const DEFAULT_RECURSION_LIMIT: u64 = 25;
1136
/// Streaming state borrowed for the duration of one streaming run.
struct StreamCtx<'a> {
    /// Stream modes the caller subscribed to.
    modes: &'a HashSet<StreamMode>,
    /// Sink for stream parts delivered to the consumer.
    tx: &'a mpsc::Sender<StreamPart>,
    /// Present only when `StreamMode::Custom` was requested. Values sent here
    /// are re-emitted on `tx` as custom stream parts by a spawned forwarder task.
    custom_tx: Option<mpsc::Sender<JsonValue>>,
}
1148
1149impl<'a> StreamCtx<'a> {
1150 fn has(&self, mode: &StreamMode) -> bool {
1151 self.modes.contains(mode)
1152 }
1153}
1154
1155fn apply_completed_writes(
1162 interrupted_task_id: &str,
1163 tasks: &[PregelExecutableTask],
1164 channels: &HashMap<String, Box<dyn Channel>>,
1165 versions_seen: &mut HashMap<String, HashMap<String, JsonValue>>,
1166 channel_versions: &mut ChannelVersions,
1167) {
1168 for task in tasks.iter().filter(|t| t.id != interrupted_task_id && !t.writes.is_empty()) {
1170 let seen = versions_seen.entry(task.name.clone()).or_default();
1171 for trigger in &task.triggers {
1172 if let Some(ver) = channel_versions.get(trigger.as_str()) {
1173 seen.insert(trigger.clone(), ver.clone());
1174 }
1175 }
1176 }
1177
1178 let max_ver = channel_versions
1180 .values()
1181 .filter_map(|v| v.as_str().and_then(|s| s.parse::<u64>().ok()))
1182 .max()
1183 .unwrap_or(0);
1184 let next_version = JsonValue::String(format!("{:032}", max_ver + 1));
1185
1186 let mut writes_by_channel: HashMap<String, Vec<JsonValue>> = HashMap::new();
1189 for task in tasks.iter().filter(|t| t.id != interrupted_task_id && !t.writes.is_empty()) {
1190 for (chan, val) in &task.writes {
1191 if crate::constants::RESERVED.contains(&chan.as_str()) {
1192 continue;
1193 }
1194 writes_by_channel.entry(chan.clone()).or_default().push(val.clone());
1195 }
1196 }
1197
1198 for (chan, vals) in &writes_by_channel {
1199 if let Some(ch) = channels.get(chan.as_str()) {
1200 if ch.update(vals).unwrap_or(false) {
1201 channel_versions.insert(chan.clone(), next_version.clone());
1202 }
1203 }
1204 }
1205}
1206
1207fn output_channel_keys(channels: &HashMap<String, Box<dyn Channel>>) -> Vec<String> {
1209 channels
1210 .keys()
1211 .filter(|k| !k.starts_with("branch:") && !k.starts_with("join:") && *k != START)
1212 .cloned()
1213 .collect()
1214}
1215
1216fn bump_version(current: Option<&JsonValue>) -> JsonValue {
1218 let num = current
1219 .and_then(|v| v.as_str())
1220 .and_then(|s| s.parse::<u64>().ok())
1221 .unwrap_or(0);
1222 JsonValue::String(format!("{:032}", num + 1))
1223}
1224
1225impl CompiledStateGraph {
1226 async fn run_pregel(
1232 &self,
1233 input: &JsonValue,
1234 config: &RunnableConfig,
1235 ) -> Result<JsonValue, RunnableError> {
1236 self.run_pregel_inner(input, config, None).await
1237 }
1238
1239 async fn run_pregel_streaming(
1241 &self,
1242 input: &JsonValue,
1243 config: &RunnableConfig,
1244 modes: &HashSet<StreamMode>,
1245 tx: &mpsc::Sender<StreamPart>,
1246 ) -> Result<JsonValue, RunnableError> {
1247 let (custom_tx, has_custom) = if modes.contains(&StreamMode::Custom) {
1249 let (ctx, mut crx) = mpsc::channel::<JsonValue>(64);
1250 let tx_clone = tx.clone();
1251 tokio::spawn(async move {
1252 while let Some(data) = crx.recv().await {
1253 let _ = tx_clone.send(StreamPart::custom(vec![], data)).await;
1254 }
1255 });
1256 (Some(ctx), true)
1257 } else {
1258 (None, false)
1259 };
1260 let _ = has_custom; let ctx = StreamCtx { modes, tx, custom_tx };
1263 self.run_pregel_inner(input, config, Some(&ctx)).await
1264 }
1265
1266 pub fn astream(
1268 &self,
1269 input: &JsonValue,
1270 config: &RunnableConfig,
1271 stream_modes: Vec<StreamMode>,
1272 ) -> ReceiverStream<StreamPart> {
1273 let (tx, rx) = mpsc::channel(256);
1274 let modes: HashSet<StreamMode> = stream_modes.into_iter().collect();
1275
1276 let graph = self.clone();
1277 let input = input.clone();
1278 let config = config.clone();
1279
1280 tokio::spawn(async move {
1281 let result = graph.run_pregel_streaming(&input, &config, &modes, &tx).await;
1282 if let Err(e) = result {
1283 let _ = tx.send(StreamPart::debug(
1284 vec![],
1285 serde_json::json!({"error": e.to_string()}),
1286 )).await;
1287 }
1288 });
1289
1290 ReceiverStream::new(rx)
1291 }
1292
1293 async fn run_pregel_inner(
1302 &self,
1303 input: &JsonValue,
1304 config: &RunnableConfig,
1305 stream: Option<&StreamCtx<'_>>,
1306 ) -> Result<JsonValue, RunnableError> {
1307 let mut config = config.clone();
1308 let pregel_nodes = build_pregel_nodes(
1311 &self.nodes,
1312 &self.edges,
1313 &self.waiting_edges,
1314 &self.branches,
1315 &self.channels,
1316 );
1317 let trigger_to_nodes = crate::pregel::build_trigger_to_nodes(&pregel_nodes);
1318
1319 let mut saved_checkpoint_exists = false;
1321 let (mut channels, mut channel_versions, mut versions_seen) =
1322 if let Some(ref cp) = self.checkpointer {
1323 match cp.get_tuple(&config) {
1324 Ok(Some(tuple)) => {
1325 saved_checkpoint_exists = true;
1326 let cp_channels: HashMap<String, Option<JsonValue>> = tuple
1327 .checkpoint
1328 .channel_values
1329 .iter()
1330 .map(|(k, v)| (k.clone(), Some(v.clone())))
1331 .collect();
1332 let restored = channels_from_checkpoint(&self.channels, &cp_channels);
1333
1334 if let Some(ref pending) = tuple.pending_writes {
1336 for (_task_id, channel, value) in pending {
1337 if channel != RESUME {
1338 if let Some(ch) = restored.get(channel) {
1339 ch.update(&[value.clone()]).ok();
1340 }
1341 }
1342 }
1343 }
1344
1345 (
1346 restored,
1347 tuple.checkpoint.channel_versions.clone(),
1348 tuple.checkpoint.versions_seen.clone(),
1349 )
1350 }
1351 _ => (
1352 self.channels.iter().map(|(k, c)| (k.clone(), c.clone_channel())).collect(),
1353 HashMap::new(),
1354 HashMap::new(),
1355 ),
1356 }
1357 } else {
1358 (
1359 self.channels.iter().map(|(k, c)| (k.clone(), c.clone_channel())).collect(),
1360 HashMap::new(),
1361 HashMap::new(),
1362 )
1363 };
1364
1365 let mut step: u64 = 0;
1367 let max_steps = config.get_recursion_limit().unwrap_or(self.recursion_limit);
1368 let mut last_output = JsonValue::Null;
1369 let mut pending_writes: Vec<(String, String, JsonValue)> = Vec::new();
1370
1371 let version_offset: u64 = if saved_checkpoint_exists {
1374 channel_versions
1375 .values()
1376 .filter_map(|v| v.as_str().and_then(|s| s.parse::<u64>().ok()))
1377 .max()
1378 .unwrap_or(0)
1379 + 1
1380 } else {
1381 0
1382 };
1383
1384 let is_resuming = if let Ok(cmd) = serde_json::from_value::<Command>(input.clone()) {
1386 let cmd_writes = map_command(&cmd);
1387 let has_resume = cmd_writes.iter().any(|(_, chan, _)| chan == RESUME);
1388 pending_writes.extend(cmd_writes);
1389 has_resume
1390 } else {
1391 false
1392 };
1393 let is_fork = input.is_null() && saved_checkpoint_exists;
1394
1395 if !is_fork && !is_resuming {
1401 let input_writes = map_input(&[START.to_string()], input);
1402 for (chan, val) in &input_writes {
1403 if let Some(ch) = channels.get(chan) {
1404 ch.update(&[val.clone()]).ok();
1405 }
1406 }
1407 if let Some(obj) = input.as_object() {
1408 for (key, val) in obj {
1409 if key != START && !key.starts_with("branch:") && !key.starts_with("join:") {
1410 if let Some(ch) = channels.get(key) {
1411 ch.update(&[val.clone()]).ok();
1412 }
1413 }
1414 }
1415 }
1416 for (chan, _) in &input_writes {
1417 channel_versions.insert(
1418 chan.clone(),
1419 JsonValue::String(format!("{:032}", version_offset + step)),
1420 );
1421 }
1422 for (start, end) in &self.edges {
1424 if start == START && end != END {
1425 let trigger_ch = format!("branch:to:{}", end);
1426 if let Some(ch) = channels.get(&trigger_ch) {
1427 ch.update(&[JsonValue::String(end.clone())]).ok();
1428 channel_versions.insert(
1429 trigger_ch,
1430 JsonValue::String(format!("{:032}", version_offset + step)),
1431 );
1432 }
1433 }
1434 }
1435 }
1436
1437
1438 while step < max_steps {
1441 let checkpoint_id = format!("{:032}", version_offset + step);
1442
1443 let mut tasks = prepare_next_tasks(
1445 &pregel_nodes,
1446 &channels,
1447 &config,
1448 version_offset + step,
1449 &mut versions_seen,
1450 &trigger_to_nodes,
1451 None,
1452 &checkpoint_id,
1453 &pending_writes,
1454 &channel_versions,
1455 );
1456
1457
1458
1459 if tasks.is_empty() {
1460 break;
1461 }
1462
1463 pending_writes.clear();
1465
1466 if let Some(s) = stream {
1468 if s.has(&StreamMode::Tasks) {
1469 for task in &tasks {
1470 let data = serde_json::json!({
1471 "id": task.id,
1472 "name": task.name,
1473 "triggers": task.triggers,
1474 });
1475 let _ = s.tx.send(StreamPart::tasks(vec![], data)).await;
1476 }
1477 }
1478 }
1479
1480 if !self.interrupt_before.is_empty() {
1482 let task_names: Vec<String> = tasks.iter().map(|t| t.name.clone()).collect();
1483 if task_names.iter().any(|n| self.interrupt_before.contains(n)) {
1484 if let Some(ref cp) = self.checkpointer {
1485 if let Some(new_config) = self.save_checkpoint(cp, &config, &channels, &channel_versions, &versions_seen) {
1486 config = new_config;
1487 }
1488 }
1489 if let Some(s) = stream {
1491 if s.has(&StreamMode::Values) {
1492 let keys = output_channel_keys(&channels);
1493 let _ = s.tx.send(StreamPart::values(vec![], read_channels(&channels, &keys))).await;
1494 }
1495 }
1496 let keys = output_channel_keys(&channels);
1497 return Ok(read_channels(&channels, &keys));
1498 }
1499 }
1500
1501 let runner = if let Some(s) = stream {
1503 let runtime = Arc::new(crate::runtime::Runtime {
1504 context: (),
1505 store: self.store.clone(),
1506 stream_writer: s.custom_tx.clone(),
1507 previous: None,
1508 execution_info: None,
1509 server_info: None,
1510 });
1511 if s.custom_tx.is_some() {
1512 PregelRunner::new(Some(runtime.clone()))
1513 .with_stream_writer(s.custom_tx.clone().unwrap())
1514 } else {
1515 PregelRunner::new(Some(runtime))
1516 }
1517 } else {
1518 PregelRunner::new(self.store.clone().map(|_| {
1519 Arc::new(crate::runtime::Runtime {
1520 context: (),
1521 store: self.store.clone(),
1522 stream_writer: None,
1523 previous: None,
1524 execution_info: None,
1525 server_info: None,
1526 })
1527 }))
1528 };
1529
1530 match runner.run_tasks(&mut tasks).await {
1531 Ok(()) => {}
1532
1533 Err(crate::pregel::runner::RunnerError::Interrupt { task_id, interrupt }) => {
1534 apply_completed_writes(
1539 &task_id,
1540 &tasks,
1541 &channels,
1542 &mut versions_seen,
1543 &mut channel_versions,
1544 );
1545
1546 if let Some(ref cp) = self.checkpointer {
1548 if let Some(new_config) = self.save_checkpoint(cp, &config, &channels, &channel_versions, &versions_seen) {
1549 config = new_config;
1550 }
1551 let iw: Vec<(String, String, JsonValue)> = interrupt
1553 .interrupts
1554 .iter()
1555 .map(|iv| {
1556 let val = serde_json::to_value(iv).unwrap_or(JsonValue::Null);
1557 (task_id.clone(), crate::constants::INTERRUPT.to_string(), val)
1558 })
1559 .collect();
1560 if !iw.is_empty() {
1561 if let Err(e) = cp.put_writes(&config, &iw, &task_id, "") {
1562 eprintln!("[CHECKPOINT] Failed to save interrupt writes: {}", e);
1563 }
1564 }
1565 }
1566
1567 if let Some(s) = stream {
1569 if s.has(&StreamMode::Values) {
1570 let keys = output_channel_keys(&channels);
1571 let _ = s.tx.send(StreamPart::values(vec![], read_channels(&channels, &keys))).await;
1572 }
1573 }
1574
1575 let keys = output_channel_keys(&channels);
1576 return Ok(read_channels(&channels, &keys));
1577 }
1578
1579 Err(other) => return Err(RunnableError::Runner(other.to_string())),
1580 }
1581
1582 if let Some(s) = stream {
1584 if s.has(&StreamMode::Updates) {
1585 for task in &tasks {
1586 if !task.writes.is_empty() {
1587 let mut node_updates = serde_json::Map::new();
1588 for (chan, val) in &task.writes {
1589 if !chan.starts_with("branch:") && !chan.starts_with("join:") {
1590 node_updates.insert(chan.clone(), val.clone());
1591 }
1592 }
1593 if !node_updates.is_empty() {
1594 let data = serde_json::json!({ &task.name: node_updates });
1595 let _ = s.tx.send(StreamPart::updates(vec![], data)).await;
1596 }
1597 }
1598 }
1599 }
1600 }
1601
1602 apply_writes(
1604 &mut channels,
1605 &tasks,
1606 &mut versions_seen,
1607 &mut channel_versions,
1608 &trigger_to_nodes,
1609 bump_version,
1610 );
1611
1612 if let Some(ref cp) = self.checkpointer {
1624 if let Some(new_config) = self.save_checkpoint(cp, &config, &channels, &channel_versions, &versions_seen) {
1625 config = new_config;
1626 }
1627 }
1628
1629 if let Some(s) = stream {
1631 if s.has(&StreamMode::Values) {
1632 let keys = output_channel_keys(&channels);
1633 let _ = s.tx.send(StreamPart::values(vec![], read_channels(&channels, &keys))).await;
1634 }
1635 }
1636
1637 let keys = output_channel_keys(&channels);
1639 let output = read_channels(&channels, &keys);
1640 if !output.is_null() {
1641 last_output = output;
1642 }
1643
1644 if !self.interrupt_after.is_empty() {
1646 let task_names: Vec<String> = tasks.iter().map(|t| t.name.clone()).collect();
1647 if task_names.iter().any(|n| self.interrupt_after.contains(n)) {
1648 return Ok(last_output);
1649 }
1650 }
1651
1652 step += 1;
1653 }
1654
1655 Ok(last_output)
1656 }
1657
1658
1659}
1660
1661#[async_trait]
1662impl Runnable for CompiledStateGraph {
1663 fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
1664 match tokio::runtime::Handle::try_current() {
1666 Ok(handle) => handle.block_on(self.run_pregel(input, config)),
1667 Err(_) => tokio::runtime::Runtime::new()
1668 .unwrap()
1669 .block_on(self.run_pregel(input, config)),
1670 }
1671 }
1672
1673 async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
1674 self.run_pregel(input, config).await
1675 }
1676
1677 fn name(&self) -> &str {
1678 &self.name
1679 }
1680}
1681
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::LastValue;
    use serde_json::json;

    /// Builds a channel map holding one `LastValue` channel per key.
    fn last_value_channels(keys: &[&'static str]) -> HashMap<String, Box<dyn Channel>> {
        keys.iter()
            .map(|&key| (key.to_string(), Box::new(LastValue::new(key)) as Box<dyn Channel>))
            .collect()
    }

    /// Channel map with a single `value` channel, shared by the structural tests.
    fn make_channels() -> HashMap<String, Box<dyn Channel>> {
        last_value_channels(&["value"])
    }

    /// Uncompiled graph `START -> process -> END`, where `process` writes `{"value": 42}`.
    fn process_graph() -> StateGraph {
        let mut graph = StateGraph::new(make_channels());
        graph
            .add_node("process", |_input, _config| async { Ok(json!({"value": 42})) })
            .unwrap();
        graph.add_edge(START, "process").unwrap();
        graph.add_edge("process", END).unwrap();
        graph
    }

    // Purely structural: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_simple_linear_graph() {
        let mut graph = StateGraph::new(make_channels());

        graph
            .add_node("a", |_input, _config| async { Ok(json!({"value": 1})) })
            .unwrap();
        graph
            .add_node("b", |_input, _config| async { Ok(json!({"value": 2})) })
            .unwrap();

        graph.add_edge(START, "a").unwrap();
        graph.add_edge("a", "b").unwrap();
        graph.add_edge("b", END).unwrap();

        let compiled = graph.compile().unwrap();
        assert!(compiled.has_node("a"));
        assert!(compiled.has_node("b"));
        assert_eq!(compiled.node_names().len(), 2);
    }

    #[test]
    fn test_duplicate_node_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.add_node("a", |_input, _config| async { Ok(json!({})) });
        assert!(result.is_err());
    }

    #[test]
    fn test_reserved_name_error() {
        let mut graph = StateGraph::new(make_channels());
        let result = graph.add_node(START, |_input, _config| async { Ok(json!({})) });
        assert!(result.is_err());
    }

    #[test]
    fn test_end_as_source_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.add_edge(END, "a");
        assert!(result.is_err());
    }

    #[test]
    fn test_start_as_target_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.add_edge("a", START);
        assert!(result.is_err());
    }

    #[test]
    fn test_no_start_edge_error() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        let result = graph.compile();
        assert!(result.is_err());
    }

    #[test]
    fn test_join_edge() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("a", |_input, _config| async { Ok(json!({})) }).unwrap();
        graph.add_node("b", |_input, _config| async { Ok(json!({})) }).unwrap();
        graph.add_node("c", |_input, _config| async { Ok(json!({})) }).unwrap();

        graph.add_edge(START, "a").unwrap();
        graph.add_edge(START, "b").unwrap();
        graph.add_join_edge(vec!["a".to_string(), "b".to_string()], "c").unwrap();
        graph.add_edge("c", END).unwrap();

        let compiled = graph.compile().unwrap();
        assert_eq!(compiled.node_names().len(), 3);
    }

    #[test]
    fn test_conditional_edges() {
        let mut graph = StateGraph::new(make_channels());
        graph.add_node("agent", |_input, _config| async { Ok(json!({})) }).unwrap();
        graph.add_node("tools", |_input, _config| async { Ok(json!({})) }).unwrap();

        graph.add_edge(START, "agent").unwrap();
        graph
            .add_conditional_edges(
                "agent",
                |_input, _config| async { Ok(json!("continue")) },
                Some(HashMap::from([
                    ("continue".to_string(), "tools".to_string()),
                    ("end".to_string(), END.to_string()),
                ])),
            )
            .unwrap();
        graph.add_edge("tools", "agent").unwrap();

        let compiled = graph.compile().unwrap();
        assert!(compiled.has_node("agent"));
        assert!(compiled.has_node("tools"));
    }

    #[tokio::test]
    async fn test_invoke_linear_graph() {
        let mut graph = StateGraph::new(last_value_channels(&["count"]));

        graph
            .add_node("increment", |_input, _config| async {
                Ok(json!({"count": 1}))
            })
            .unwrap();
        graph
            .add_node("double", |_input, _config| async {
                Ok(json!({"count": 2}))
            })
            .unwrap();

        graph.add_edge(START, "increment").unwrap();
        graph.add_edge("increment", "double").unwrap();
        graph.add_edge("double", END).unwrap();

        let compiled = graph.compile().unwrap();
        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({"count": 0}), &config).await.unwrap();

        // The last writer in the chain ("double") determines the final value.
        assert!(result.is_object());
        assert_eq!(result.get("count"), Some(&json!(2)));
    }

    #[tokio::test]
    async fn test_invoke_single_node() {
        let mut graph = StateGraph::new(last_value_channels(&["result"]));
        graph
            .add_node("process", |_input, _config| async {
                Ok(json!({"result": 42}))
            })
            .unwrap();
        graph.add_edge(START, "process").unwrap();
        graph.add_edge("process", END).unwrap();

        let compiled = graph.compile().unwrap();
        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({}), &config).await.unwrap();

        assert_eq!(result.get("result"), Some(&json!(42)));
    }

    // The blocking entry point, called with no Tokio runtime on the current thread.
    #[test]
    fn test_invoke_blocking_outside_runtime() {
        let compiled = process_graph().compile().unwrap();
        let config = RunnableConfig::new();
        let result = compiled.invoke(&json!({}), &config).unwrap();

        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    #[tokio::test]
    async fn test_interrupt_before() {
        let mut compiled = process_graph().compile().unwrap();
        compiled.interrupt_before = vec!["process".to_string()];

        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({}), &config).await.unwrap();

        // Execution stops before `process` runs, so `value` is absent or null.
        assert!(result.is_object());
        assert!(result.get("value").map_or(true, JsonValue::is_null));
    }

    #[tokio::test]
    async fn test_interrupt_after() {
        let mut compiled = process_graph().compile().unwrap();
        compiled.interrupt_after = vec!["process".to_string()];

        let config = RunnableConfig::new();
        let result = compiled.ainvoke(&json!({}), &config).await.unwrap();

        // Execution stops after `process`, so its write is already visible.
        assert!(result.is_object());
        assert_eq!(result.get("value"), Some(&json!(42)));
    }

    #[tokio::test]
    async fn test_update_state() {
        use langgraph_checkpoint::checkpoint::memory::InMemorySaver;

        let mut graph = StateGraph::new(last_value_channels(&["name", "value"]));
        graph
            .add_node("set_value", |_input, _config| async {
                Ok(json!({"value": 42}))
            })
            .unwrap();
        graph.add_edge(START, "set_value").unwrap();
        graph.add_edge("set_value", END).unwrap();

        let checkpointer = Arc::new(InMemorySaver::new());
        let compiled = graph.compile_builder()
            .checkpointer(checkpointer)
            .build()
            .unwrap();

        let mut config = RunnableConfig::new();
        config.insert("configurable".to_string(), json!({"thread_id": "test-thread"}));

        let result = compiled.ainvoke(&json!({"name": "original"}), &config).await.unwrap();
        assert_eq!(result.get("value"), Some(&json!(42)));

        // The checkpoint holds both the input-provided name and the node's write.
        let snapshot = compiled.get_state(&config).unwrap();
        assert_eq!(snapshot.values.get("name").and_then(|v| v.as_str()), Some("original"));
        assert_eq!(snapshot.values.get("value").and_then(|v| v.as_i64()), Some(42));

        compiled.update_state(&config, &json!({"name": "updated"})).unwrap();

        // Only the updated key changes; untouched channels keep their values.
        let snapshot = compiled.get_state(&config).unwrap();
        assert_eq!(snapshot.values.get("name").and_then(|v| v.as_str()), Some("updated"));
        assert_eq!(snapshot.values.get("value").and_then(|v| v.as_i64()), Some(42));
    }
}