1use crate::checkpoint::Checkpointer;
4use crate::edge::{END, Edge, EdgeTarget, RouterFn, START};
5use crate::error::{GraphError, Result};
6use crate::node::{FunctionNode, Node, NodeContext, NodeOutput};
7use crate::state::{State, StateSchema};
8use std::collections::{HashMap, HashSet};
9use std::future::Future;
10use std::sync::Arc;
11
/// Builder for a state graph: a set of named nodes plus the edges connecting them.
///
/// Nodes and edges are added through the consuming builder methods on this type;
/// call `StateGraph::compile` to validate the graph and obtain a `CompiledGraph`.
pub struct StateGraph {
    /// Schema describing the state channels the graph operates on.
    pub schema: StateSchema,
    /// Registered nodes, keyed by the name each node reports via `Node::name`.
    pub nodes: HashMap<String, Arc<dyn Node>>,
    /// Edges in insertion order. `add_edge` merges every edge leaving `START`
    /// into a single `Edge::Entry`.
    pub edges: Vec<Edge>,
}
21
22impl StateGraph {
23 pub fn new(schema: StateSchema) -> Self {
25 Self { schema, nodes: HashMap::new(), edges: vec![] }
26 }
27
28 pub fn with_channels(channels: &[&str]) -> Self {
30 Self::new(StateSchema::simple(channels))
31 }
32
33 pub fn add_node<N: Node + 'static>(mut self, node: N) -> Self {
35 self.nodes.insert(node.name().to_string(), Arc::new(node));
36 self
37 }
38
39 pub fn add_node_fn<F, Fut>(self, name: &str, func: F) -> Self
41 where
42 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
43 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
44 {
45 self.add_node(FunctionNode::new(name, func))
46 }
47
48 pub fn add_edge(mut self, source: &str, target: &str) -> Self {
50 let target = EdgeTarget::from(target);
51
52 if source == START {
53 let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
55
56 match entry_idx {
57 Some(idx) => {
58 if let Edge::Entry { targets } = &mut self.edges[idx]
59 && let EdgeTarget::Node(node) = &target
60 && !targets.contains(node)
61 {
62 targets.push(node.clone());
63 }
64 }
65 None => {
66 if let EdgeTarget::Node(node) = target {
67 self.edges.push(Edge::Entry { targets: vec![node] });
68 }
69 }
70 }
71 } else {
72 self.edges.push(Edge::Direct { source: source.to_string(), target });
73 }
74
75 self
76 }
77
78 pub fn add_conditional_edges<F, I>(mut self, source: &str, router: F, targets: I) -> Self
80 where
81 F: Fn(&State) -> String + Send + Sync + 'static,
82 I: IntoIterator<Item = (&'static str, &'static str)>,
83 {
84 let targets_map: HashMap<String, EdgeTarget> =
85 targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
86
87 self.edges.push(Edge::Conditional {
88 source: source.to_string(),
89 router: Arc::new(router),
90 targets: targets_map,
91 });
92
93 self
94 }
95
96 pub fn add_conditional_edges_arc<I>(
98 mut self,
99 source: &str,
100 router: RouterFn,
101 targets: I,
102 ) -> Self
103 where
104 I: IntoIterator<Item = (&'static str, &'static str)>,
105 {
106 let targets_map: HashMap<String, EdgeTarget> =
107 targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
108
109 self.edges.push(Edge::Conditional {
110 source: source.to_string(),
111 router,
112 targets: targets_map,
113 });
114
115 self
116 }
117
118 pub fn compile(self) -> Result<CompiledGraph> {
120 self.validate()?;
121
122 Ok(CompiledGraph {
123 schema: self.schema,
124 nodes: self.nodes,
125 edges: self.edges,
126 checkpointer: None,
127 interrupt_before: HashSet::new(),
128 interrupt_after: HashSet::new(),
129 recursion_limit: 50,
130 timeout_policies: HashMap::new(),
131 default_timeout: None,
132 deferred_configs: HashMap::new(),
133 #[cfg(feature = "node-cache")]
134 cache_policies: HashMap::new(),
135 })
136 }
137
138 fn validate(&self) -> Result<()> {
140 let has_entry = self.edges.iter().any(|e| matches!(e, Edge::Entry { .. }));
142 if !has_entry {
143 return Err(GraphError::NoEntryPoint);
144 }
145
146 for edge in &self.edges {
148 match edge {
149 Edge::Direct { source, target } => {
150 if source != START && !self.nodes.contains_key(source) {
151 return Err(GraphError::NodeNotFound(source.clone()));
152 }
153 if let EdgeTarget::Node(name) = target
154 && !self.nodes.contains_key(name)
155 {
156 return Err(GraphError::EdgeTargetNotFound(name.clone()));
157 }
158 }
159 Edge::Conditional { source, targets, .. } => {
160 if !self.nodes.contains_key(source) {
161 return Err(GraphError::NodeNotFound(source.clone()));
162 }
163 for target in targets.values() {
164 if let EdgeTarget::Node(name) = target
165 && !self.nodes.contains_key(name)
166 {
167 return Err(GraphError::EdgeTargetNotFound(name.clone()));
168 }
169 }
170 }
171 Edge::Entry { targets } => {
172 for target in targets {
173 if !self.nodes.contains_key(target) {
174 return Err(GraphError::EdgeTargetNotFound(target.clone()));
175 }
176 }
177 }
178 }
179 }
180
181 Ok(())
182 }
183}
184
/// A validated graph produced by `StateGraph::compile`, together with the runtime
/// configuration set through its `with_*` builder methods.
pub struct CompiledGraph {
    /// State schema carried over from the builder.
    pub(crate) schema: StateSchema,
    /// Nodes keyed by name, carried over from the builder.
    pub(crate) nodes: HashMap<String, Arc<dyn Node>>,
    /// Edges in builder insertion order; the routing methods scan them in this order.
    pub(crate) edges: Vec<Edge>,
    /// Optional checkpointer; `None` right after `compile`.
    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Names of nodes to interrupt before. NOTE(review): enforced by the executor,
    /// which is not in this file — confirm.
    pub(crate) interrupt_before: HashSet<String>,
    /// Names of nodes to interrupt after. NOTE(review): enforced elsewhere — confirm.
    pub(crate) interrupt_after: HashSet<String>,
    /// Execution step limit; `compile` sets it to 50. Presumably checked by the
    /// executor — not visible here.
    pub(crate) recursion_limit: usize,
    /// Per-node timeout overrides keyed by node name; consulted first by
    /// `timeout_policy_for`.
    pub(crate) timeout_policies: HashMap<String, crate::timeout::TimeoutPolicy>,
    /// Fallback returned by `timeout_policy_for` when a node has no override.
    pub(crate) default_timeout: Option<crate::timeout::TimeoutPolicy>,
    /// Deferred-node configuration; presumably keyed by node name (not read in this
    /// file — verify).
    pub(crate) deferred_configs: HashMap<String, crate::deferred::DeferredNodeConfig>,
    /// Node cache policies; only present with the `node-cache` feature.
    #[cfg(feature = "node-cache")]
    pub(crate) cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
}
204
205impl CompiledGraph {
206 pub fn with_checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
208 self.checkpointer = Some(Arc::new(checkpointer));
209 self
210 }
211
212 pub fn with_checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
214 self.checkpointer = Some(checkpointer);
215 self
216 }
217
218 pub fn with_interrupt_before(mut self, nodes: &[&str]) -> Self {
220 self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
221 self
222 }
223
224 pub fn with_interrupt_after(mut self, nodes: &[&str]) -> Self {
226 self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
227 self
228 }
229
230 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
232 self.recursion_limit = limit;
233 self
234 }
235
236 pub fn timeout_policy_for(&self, node_name: &str) -> Option<&crate::timeout::TimeoutPolicy> {
242 self.timeout_policies.get(node_name).or(self.default_timeout.as_ref())
243 }
244
245 pub fn get_entry_nodes(&self) -> Vec<String> {
247 for edge in &self.edges {
248 if let Edge::Entry { targets } = edge {
249 return targets.clone();
250 }
251 }
252 vec![]
253 }
254
255 pub fn get_next_nodes(&self, executed: &[String], state: &State) -> Vec<String> {
257 let mut next = Vec::new();
258
259 for edge in &self.edges {
260 match edge {
261 Edge::Direct { source, target: EdgeTarget::Node(n) }
262 if executed.contains(source) =>
263 {
264 if !next.contains(n) {
265 next.push(n.clone());
266 }
267 }
268 Edge::Conditional { source, router, targets } if executed.contains(source) => {
269 let route = router(state);
270 if let Some(EdgeTarget::Node(n)) = targets.get(&route)
271 && !next.contains(n)
272 {
273 next.push(n.clone());
274 }
275 }
277 _ => {}
278 }
279 }
280
281 next
282 }
283
284 pub fn leads_to_end(&self, executed: &[String], state: &State) -> bool {
286 for edge in &self.edges {
287 match edge {
288 Edge::Direct { source, target } if executed.contains(source) => {
289 if target.is_end() {
290 return true;
291 }
292 }
293 Edge::Conditional { source, router, targets } if executed.contains(source) => {
294 let route = router(state);
295 if route == END {
296 return true;
297 }
298 if let Some(target) = targets.get(&route)
299 && target.is_end()
300 {
301 return true;
302 }
303 }
304 _ => {}
305 }
306 }
307 false
308 }
309
310 pub fn get_upstream_nodes(&self, target_node: &str) -> Vec<String> {
319 let mut sources = Vec::new();
320
321 for edge in &self.edges {
322 match edge {
323 Edge::Direct { source, target } => {
324 if let EdgeTarget::Node(name) = target
325 && name == target_node
326 && !sources.contains(source)
327 {
328 sources.push(source.clone());
329 }
330 }
331 Edge::Conditional { source, targets, .. } => {
332 for target in targets.values() {
333 if let EdgeTarget::Node(name) = target
334 && name == target_node
335 && !sources.contains(source)
336 {
337 sources.push(source.clone());
338 }
339 }
340 }
341 Edge::Entry { targets } => {
342 if targets.contains(&target_node.to_string()) {
343 }
346 }
347 }
348 }
349
350 sources
351 }
352
353 pub fn schema(&self) -> &StateSchema {
355 &self.schema
356 }
357
358 pub fn checkpointer(&self) -> Option<&Arc<dyn Checkpointer>> {
360 self.checkpointer.as_ref()
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_basic_graph_construction() {
        let graph = StateGraph::with_channels(&["input", "output"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "process")
            .add_edge("process", END)
            .compile();

        assert!(graph.is_ok());
    }

    #[test]
    fn test_graph_missing_entry() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge("process", END)
            .compile();

        assert!(matches!(graph, Err(GraphError::NoEntryPoint)));
    }

    #[test]
    fn test_graph_missing_node() {
        let graph = StateGraph::with_channels(&["input"]).add_edge(START, "nonexistent").compile();

        assert!(matches!(graph, Err(GraphError::EdgeTargetNotFound(_))));
    }

    #[test]
    fn test_conditional_edges() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("path_a", "path_a"), ("path_b", "path_b"), (END, END)],
            )
            .compile()
            .unwrap();

        let mut state = State::new();
        state.insert("next".to_string(), json!("path_a"));
        let next = graph.get_next_nodes(&["router".to_string()], &state);
        assert_eq!(next, vec!["path_a".to_string()]);

        state.insert("next".to_string(), json!("path_b"));
        let next = graph.get_next_nodes(&["router".to_string()], &state);
        assert_eq!(next, vec!["path_b".to_string()]);
    }

    #[test]
    fn test_start_edges_merge_into_single_entry() {
        let graph = StateGraph::with_channels(&["input"])
            .add_node_fn("a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "a")
            .add_edge(START, "b")
            .add_edge(START, "a")
            .add_edge("a", END)
            .add_edge("b", END)
            .compile()
            .unwrap();

        assert_eq!(graph.get_entry_nodes(), vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_leads_to_end() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("work", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("work", "work"), (END, END)],
            )
            .add_edge("work", END)
            .compile()
            .unwrap();

        let mut state = State::new();
        // No "next" key: the router falls back to END.
        assert!(graph.leads_to_end(&["router".to_string()], &state));
        assert!(graph.leads_to_end(&["work".to_string()], &state));

        state.insert("next".to_string(), json!("work"));
        assert!(!graph.leads_to_end(&["router".to_string()], &state));
    }

    #[test]
    fn test_upstream_nodes_exclude_start() {
        let graph = StateGraph::with_channels(&["next"])
            .add_node_fn("a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("c", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "a")
            .add_edge(START, "b")
            .add_edge("a", "c")
            .add_conditional_edges("b", |_state| "c".to_string(), [("c", "c"), (END, END)])
            .add_edge("c", END)
            .compile()
            .unwrap();

        assert_eq!(graph.get_upstream_nodes("c"), vec!["a".to_string(), "b".to_string()]);
        assert!(graph.get_upstream_nodes("a").is_empty());
    }
}