1use crate::checkpoint::Checkpointer;
4use crate::edge::{END, Edge, EdgeTarget, RouterFn, START};
5use crate::error::{GraphError, Result};
6use crate::node::{FunctionNode, Node, NodeContext, NodeOutput};
7use crate::state::{State, StateSchema};
8use std::collections::{HashMap, HashSet};
9use std::future::Future;
10use std::sync::Arc;
11
/// Builder for a state graph: a state schema plus named nodes connected by edges.
///
/// Populate it with the `add_*` methods, then call `StateGraph::compile`, which
/// validates the wiring and returns a `CompiledGraph`.
pub struct StateGraph {
    /// State schema for the graph (e.g. built by `StateSchema::simple` in `with_channels`).
    pub schema: StateSchema,
    /// Registered nodes keyed by `Node::name()`; re-registering a name replaces the node.
    pub nodes: HashMap<String, Arc<dyn Node>>,
    /// Edges in insertion order. All `START` edges are merged into one `Edge::Entry`.
    pub edges: Vec<Edge>,
}
21
22impl StateGraph {
23 pub fn new(schema: StateSchema) -> Self {
25 Self { schema, nodes: HashMap::new(), edges: vec![] }
26 }
27
28 pub fn with_channels(channels: &[&str]) -> Self {
30 Self::new(StateSchema::simple(channels))
31 }
32
33 pub fn add_node<N: Node + 'static>(mut self, node: N) -> Self {
35 self.nodes.insert(node.name().to_string(), Arc::new(node));
36 self
37 }
38
39 pub fn add_node_fn<F, Fut>(self, name: &str, func: F) -> Self
41 where
42 F: Fn(NodeContext) -> Fut + Send + Sync + 'static,
43 Fut: Future<Output = Result<NodeOutput>> + Send + 'static,
44 {
45 self.add_node(FunctionNode::new(name, func))
46 }
47
48 pub fn add_edge(mut self, source: &str, target: &str) -> Self {
50 let target = EdgeTarget::from(target);
51
52 if source == START {
53 let entry_idx = self.edges.iter().position(|e| matches!(e, Edge::Entry { .. }));
55
56 match entry_idx {
57 Some(idx) => {
58 if let Edge::Entry { targets } = &mut self.edges[idx] {
59 if let EdgeTarget::Node(node) = &target {
60 if !targets.contains(node) {
61 targets.push(node.clone());
62 }
63 }
64 }
65 }
66 None => {
67 if let EdgeTarget::Node(node) = target {
68 self.edges.push(Edge::Entry { targets: vec![node] });
69 }
70 }
71 }
72 } else {
73 self.edges.push(Edge::Direct { source: source.to_string(), target });
74 }
75
76 self
77 }
78
79 pub fn add_conditional_edges<F, I>(mut self, source: &str, router: F, targets: I) -> Self
81 where
82 F: Fn(&State) -> String + Send + Sync + 'static,
83 I: IntoIterator<Item = (&'static str, &'static str)>,
84 {
85 let targets_map: HashMap<String, EdgeTarget> =
86 targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
87
88 self.edges.push(Edge::Conditional {
89 source: source.to_string(),
90 router: Arc::new(router),
91 targets: targets_map,
92 });
93
94 self
95 }
96
97 pub fn add_conditional_edges_arc<I>(
99 mut self,
100 source: &str,
101 router: RouterFn,
102 targets: I,
103 ) -> Self
104 where
105 I: IntoIterator<Item = (&'static str, &'static str)>,
106 {
107 let targets_map: HashMap<String, EdgeTarget> =
108 targets.into_iter().map(|(k, v)| (k.to_string(), EdgeTarget::from(v))).collect();
109
110 self.edges.push(Edge::Conditional {
111 source: source.to_string(),
112 router,
113 targets: targets_map,
114 });
115
116 self
117 }
118
119 pub fn compile(self) -> Result<CompiledGraph> {
121 self.validate()?;
122
123 Ok(CompiledGraph {
124 schema: self.schema,
125 nodes: self.nodes,
126 edges: self.edges,
127 checkpointer: None,
128 interrupt_before: HashSet::new(),
129 interrupt_after: HashSet::new(),
130 recursion_limit: 50,
131 timeout_policies: HashMap::new(),
132 default_timeout: None,
133 deferred_configs: HashMap::new(),
134 #[cfg(feature = "node-cache")]
135 cache_policies: HashMap::new(),
136 })
137 }
138
139 fn validate(&self) -> Result<()> {
141 let has_entry = self.edges.iter().any(|e| matches!(e, Edge::Entry { .. }));
143 if !has_entry {
144 return Err(GraphError::NoEntryPoint);
145 }
146
147 for edge in &self.edges {
149 match edge {
150 Edge::Direct { source, target } => {
151 if source != START && !self.nodes.contains_key(source) {
152 return Err(GraphError::NodeNotFound(source.clone()));
153 }
154 if let EdgeTarget::Node(name) = target {
155 if !self.nodes.contains_key(name) {
156 return Err(GraphError::EdgeTargetNotFound(name.clone()));
157 }
158 }
159 }
160 Edge::Conditional { source, targets, .. } => {
161 if !self.nodes.contains_key(source) {
162 return Err(GraphError::NodeNotFound(source.clone()));
163 }
164 for target in targets.values() {
165 if let EdgeTarget::Node(name) = target {
166 if !self.nodes.contains_key(name) {
167 return Err(GraphError::EdgeTargetNotFound(name.clone()));
168 }
169 }
170 }
171 }
172 Edge::Entry { targets } => {
173 for target in targets {
174 if !self.nodes.contains_key(target) {
175 return Err(GraphError::EdgeTargetNotFound(target.clone()));
176 }
177 }
178 }
179 }
180 }
181
182 Ok(())
183 }
184}
185
/// A graph that passed `StateGraph::compile` validation, together with its
/// execution settings (checkpointing, interrupts, limits, timeouts).
pub struct CompiledGraph {
    /// State schema carried over from the builder.
    pub(crate) schema: StateSchema,
    /// Nodes keyed by name, carried over from the builder.
    pub(crate) nodes: HashMap<String, Arc<dyn Node>>,
    /// Edges in the order they were added to the builder.
    pub(crate) edges: Vec<Edge>,
    /// Attached checkpointer; `None` unless set through `with_checkpointer*`.
    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Node names configured via `with_interrupt_before`.
    pub(crate) interrupt_before: HashSet<String>,
    /// Node names configured via `with_interrupt_after`.
    pub(crate) interrupt_after: HashSet<String>,
    /// Recursion limit (50 from `compile`). Presumably caps execution steps;
    /// NOTE(review): confirm against the executor.
    pub(crate) recursion_limit: usize,
    /// Per-node timeout policies keyed by node name; consulted first by `timeout_policy_for`.
    pub(crate) timeout_policies: HashMap<String, crate::timeout::TimeoutPolicy>,
    /// Fallback returned by `timeout_policy_for` when a node has no specific policy.
    pub(crate) default_timeout: Option<crate::timeout::TimeoutPolicy>,
    /// Deferred-node configuration; keys presumably node names (populated elsewhere).
    pub(crate) deferred_configs: HashMap<String, crate::deferred::DeferredNodeConfig>,
    /// Per-node cache policies, only compiled with the `node-cache` feature;
    /// keys presumably node names.
    #[cfg(feature = "node-cache")]
    pub(crate) cache_policies: HashMap<String, crate::cache::NodeCachePolicy>,
}
205
206impl CompiledGraph {
207 pub fn with_checkpointer<C: Checkpointer + 'static>(mut self, checkpointer: C) -> Self {
209 self.checkpointer = Some(Arc::new(checkpointer));
210 self
211 }
212
213 pub fn with_checkpointer_arc(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
215 self.checkpointer = Some(checkpointer);
216 self
217 }
218
219 pub fn with_interrupt_before(mut self, nodes: &[&str]) -> Self {
221 self.interrupt_before = nodes.iter().map(|s| s.to_string()).collect();
222 self
223 }
224
225 pub fn with_interrupt_after(mut self, nodes: &[&str]) -> Self {
227 self.interrupt_after = nodes.iter().map(|s| s.to_string()).collect();
228 self
229 }
230
231 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
233 self.recursion_limit = limit;
234 self
235 }
236
237 pub fn timeout_policy_for(&self, node_name: &str) -> Option<&crate::timeout::TimeoutPolicy> {
243 self.timeout_policies.get(node_name).or(self.default_timeout.as_ref())
244 }
245
246 pub fn get_entry_nodes(&self) -> Vec<String> {
248 for edge in &self.edges {
249 if let Edge::Entry { targets } = edge {
250 return targets.clone();
251 }
252 }
253 vec![]
254 }
255
256 pub fn get_next_nodes(&self, executed: &[String], state: &State) -> Vec<String> {
258 let mut next = Vec::new();
259
260 for edge in &self.edges {
261 match edge {
262 Edge::Direct { source, target: EdgeTarget::Node(n) }
263 if executed.contains(source) =>
264 {
265 if !next.contains(n) {
266 next.push(n.clone());
267 }
268 }
269 Edge::Conditional { source, router, targets } if executed.contains(source) => {
270 let route = router(state);
271 if let Some(EdgeTarget::Node(n)) = targets.get(&route) {
272 if !next.contains(n) {
273 next.push(n.clone());
274 }
275 }
276 }
278 _ => {}
279 }
280 }
281
282 next
283 }
284
285 pub fn leads_to_end(&self, executed: &[String], state: &State) -> bool {
287 for edge in &self.edges {
288 match edge {
289 Edge::Direct { source, target } if executed.contains(source) => {
290 if target.is_end() {
291 return true;
292 }
293 }
294 Edge::Conditional { source, router, targets } if executed.contains(source) => {
295 let route = router(state);
296 if route == END {
297 return true;
298 }
299 if let Some(target) = targets.get(&route) {
300 if target.is_end() {
301 return true;
302 }
303 }
304 }
305 _ => {}
306 }
307 }
308 false
309 }
310
311 pub fn get_upstream_nodes(&self, target_node: &str) -> Vec<String> {
320 let mut sources = Vec::new();
321
322 for edge in &self.edges {
323 match edge {
324 Edge::Direct { source, target } => {
325 if let EdgeTarget::Node(name) = target {
326 if name == target_node && !sources.contains(source) {
327 sources.push(source.clone());
328 }
329 }
330 }
331 Edge::Conditional { source, targets, .. } => {
332 for target in targets.values() {
333 if let EdgeTarget::Node(name) = target {
334 if name == target_node && !sources.contains(source) {
335 sources.push(source.clone());
336 }
337 }
338 }
339 }
340 Edge::Entry { targets } => {
341 if targets.contains(&target_node.to_string()) {
342 }
345 }
346 }
347 }
348
349 sources
350 }
351
352 pub fn schema(&self) -> &StateSchema {
354 &self.schema
355 }
356
357 pub fn checkpointer(&self) -> Option<&Arc<dyn Checkpointer>> {
359 self.checkpointer.as_ref()
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A minimal START -> process -> END graph must compile.
    #[test]
    fn test_basic_graph_construction() {
        let result = StateGraph::with_channels(&["input", "output"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "process")
            .add_edge("process", END)
            .compile();

        assert!(result.is_ok());
    }

    /// Without any edge from START, compilation reports `NoEntryPoint`.
    #[test]
    fn test_graph_missing_entry() {
        let builder = StateGraph::with_channels(&["input"])
            .add_node_fn("process", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge("process", END);

        let outcome = builder.compile();

        assert!(matches!(outcome, Err(GraphError::NoEntryPoint)));
    }

    /// An entry edge pointing at an unregistered node is rejected.
    #[test]
    fn test_graph_missing_node() {
        let builder = StateGraph::with_channels(&["input"]).add_edge(START, "nonexistent");

        let outcome = builder.compile();

        assert!(matches!(outcome, Err(GraphError::EdgeTargetNotFound(_))));
    }

    /// The router's choice (read from the "next" channel) selects the next node.
    #[test]
    fn test_conditional_edges() {
        let compiled = StateGraph::with_channels(&["next"])
            .add_node_fn("router", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_a", |_ctx| async { Ok(NodeOutput::new()) })
            .add_node_fn("path_b", |_ctx| async { Ok(NodeOutput::new()) })
            .add_edge(START, "router")
            .add_conditional_edges(
                "router",
                |state| state.get("next").and_then(|v| v.as_str()).unwrap_or(END).to_string(),
                [("path_a", "path_a"), ("path_b", "path_b"), (END, END)],
            )
            .compile()
            .unwrap();

        let executed = vec!["router".to_string()];
        let mut state = State::new();

        for branch in ["path_a", "path_b"] {
            state.insert("next".to_string(), json!(branch));
            let scheduled = compiled.get_next_nodes(&executed, &state);
            assert_eq!(scheduled, vec![branch.to_string()]);
        }
    }
}