1use crate::edge::Edge;
4use crate::error::{GraphError, GraphResult};
5use crate::executor::ExecutionOptions;
6use crate::node::{BaseNode, Node, NodeDef, NodeResult};
7use crate::state::{generate_run_id, GraphRunContext, GraphRunResult, GraphState};
8use std::collections::HashMap;
9
10pub struct Graph<State, Deps = (), End = ()>
12where
13 State: GraphState,
14{
15 name: Option<String>,
16 pub nodes: HashMap<String, NodeDef<State, Deps, End>>,
18 edges: Vec<Edge<State>>,
19 entry_node: Option<String>,
20 finish_nodes: Vec<String>,
21 max_steps: u32,
22 auto_instrument: bool,
23}
24
25impl<State, Deps, End> Graph<State, Deps, End>
26where
27 State: GraphState,
28 Deps: Send + Sync + 'static,
29 End: Send + Sync + 'static,
30{
31 pub fn new() -> Self {
33 Self {
34 name: None,
35 nodes: HashMap::new(),
36 edges: Vec::new(),
37 entry_node: None,
38 finish_nodes: Vec::new(),
39 max_steps: 100,
40 auto_instrument: true,
41 }
42 }
43
44 pub fn with_name(mut self, name: impl Into<String>) -> Self {
46 self.name = Some(name.into());
47 self
48 }
49
50 pub fn with_max_steps(mut self, max: u32) -> Self {
52 self.max_steps = max;
53 self
54 }
55
56 pub fn without_instrumentation(mut self) -> Self {
58 self.auto_instrument = false;
59 self
60 }
61
62 pub fn node<N>(mut self, name: impl Into<String>, node: N) -> Self
64 where
65 N: BaseNode<State, Deps, End> + 'static,
66 {
67 let name = name.into();
68 self.nodes.insert(name.clone(), NodeDef::new(name, node));
69 self
70 }
71
72 pub fn edge<F>(mut self, from: impl Into<String>, to: impl Into<String>, condition: F) -> Self
74 where
75 F: Fn(&State) -> bool + Send + Sync + 'static,
76 {
77 self.edges.push(Edge::new(from, to, condition));
78 self
79 }
80
81 pub fn edge_always(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
83 self.edges.push(Edge::unconditional(from, to));
84 self
85 }
86
87 pub fn entry(mut self, name: impl Into<String>) -> Self {
89 self.entry_node = Some(name.into());
90 self
91 }
92
93 pub fn finish(mut self, names: &[&str]) -> Self {
95 self.finish_nodes = names.iter().map(|s| s.to_string()).collect();
96 self
97 }
98
99 pub fn add_finish(mut self, name: impl Into<String>) -> Self {
101 self.finish_nodes.push(name.into());
102 self
103 }
104
105 pub fn name(&self) -> Option<&str> {
107 self.name.as_deref()
108 }
109
110 pub fn node_names(&self) -> impl Iterator<Item = &str> {
112 self.nodes.keys().map(|s| s.as_str())
113 }
114
115 pub fn node_count(&self) -> usize {
117 self.nodes.len()
118 }
119
120 pub fn edge_count(&self) -> usize {
122 self.edges.len()
123 }
124
125 pub fn edges(&self) -> &[Edge<State>] {
127 &self.edges
128 }
129
130 fn detect_cycle(
131 node: &str,
132 adjacency: &HashMap<String, Vec<String>>,
133 visiting: &mut std::collections::HashSet<String>,
134 visited: &mut std::collections::HashSet<String>,
135 ) -> bool {
136 if visited.contains(node) {
137 return false;
138 }
139 if visiting.contains(node) {
140 return true;
141 }
142
143 visiting.insert(node.to_string());
144 if let Some(neighbors) = adjacency.get(node) {
145 for neighbor in neighbors {
146 if Self::detect_cycle(neighbor, adjacency, visiting, visited) {
147 return true;
148 }
149 }
150 }
151 visiting.remove(node);
152 visited.insert(node.to_string());
153 false
154 }
155
156 pub fn validate(&self) -> GraphResult<()> {
158 if let Some(ref entry) = self.entry_node {
160 if !self.nodes.contains_key(entry) {
161 return Err(GraphError::node_not_found(entry));
162 }
163 } else {
164 return Err(GraphError::NoEntryNode);
165 }
166
167 for edge in &self.edges {
169 if !self.nodes.contains_key(&edge.from) {
170 return Err(GraphError::node_not_found(&edge.from));
171 }
172 if !self.nodes.contains_key(&edge.to) {
173 return Err(GraphError::node_not_found(&edge.to));
174 }
175 }
176
177 for finish in &self.finish_nodes {
179 if !self.nodes.contains_key(finish) {
180 return Err(GraphError::node_not_found(finish));
181 }
182 }
183
184 let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
185 for edge in &self.edges {
186 adjacency
187 .entry(edge.from.clone())
188 .or_default()
189 .push(edge.to.clone());
190 }
191
192 let mut visiting = std::collections::HashSet::new();
193 let mut visited = std::collections::HashSet::new();
194 for node in self.nodes.keys() {
195 if Self::detect_cycle(node, &adjacency, &mut visiting, &mut visited) {
196 return Err(GraphError::CycleDetected);
197 }
198 }
199
200 Ok(())
201 }
202
203 pub fn build(self) -> GraphResult<Self> {
205 self.validate()?;
206 Ok(self)
207 }
208}
209
210impl<State, Deps, End> Graph<State, Deps, End>
211where
212 State: GraphState,
213 Deps: Clone + Send + Sync + 'static,
214 End: Clone + Send + Sync + 'static,
215{
216 pub async fn run(&self, state: State, deps: Deps) -> GraphResult<GraphRunResult<State, End>> {
218 let options = ExecutionOptions::new()
219 .max_steps(self.max_steps)
220 .tracing(self.auto_instrument);
221 self.run_with_options(state, deps, options).await
222 }
223
224 pub async fn run_with_options(
226 &self,
227 state: State,
228 deps: Deps,
229 options: ExecutionOptions,
230 ) -> GraphResult<GraphRunResult<State, End>> {
231 let entry = self.entry_node.as_ref().ok_or(GraphError::NoEntryNode)?;
232 let start_node = self
233 .nodes
234 .get(entry)
235 .ok_or_else(|| GraphError::node_not_found(entry))?;
236
237 self.run_from_with_options(&*start_node.node, state, deps, options)
238 .await
239 }
240
241 pub async fn run_from<N>(
243 &self,
244 start: &N,
245 state: State,
246 deps: Deps,
247 ) -> GraphResult<GraphRunResult<State, End>>
248 where
249 N: BaseNode<State, Deps, End> + ?Sized,
250 {
251 let options = ExecutionOptions::new()
252 .max_steps(self.max_steps)
253 .tracing(self.auto_instrument);
254 self.run_from_with_options(start, state, deps, options)
255 .await
256 }
257
258 pub async fn run_from_with_options<N>(
260 &self,
261 start: &N,
262 state: State,
263 deps: Deps,
264 mut options: ExecutionOptions,
265 ) -> GraphResult<GraphRunResult<State, End>>
266 where
267 N: BaseNode<State, Deps, End> + ?Sized,
268 {
269 let run_id = options.run_id.take().unwrap_or_else(generate_run_id);
270 let max_steps = options.max_steps;
271 let mut ctx = GraphRunContext::new(state, deps, &run_id).with_max_steps(max_steps);
272 let mut history = Vec::new();
273 let mut steps = 0;
274
275 steps += 1;
276 if steps > max_steps {
277 return Err(GraphError::MaxStepsExceeded(max_steps));
278 }
279 ctx.increment_step();
280 let node_name = start.name().to_string();
281 history.push(node_name);
282
283 let mut result = start.run(&mut ctx).await?;
284
285 loop {
286 match result {
287 NodeResult::Next(next) => {
288 steps += 1;
289 if steps > max_steps {
290 return Err(GraphError::MaxStepsExceeded(max_steps));
291 }
292 ctx.increment_step();
293 let name = next.name().to_string();
294 history.push(name);
295 result = next.run(&mut ctx).await?;
296 }
297 NodeResult::NextNamed(name) => {
298 let node = self
299 .nodes
300 .get(&name)
301 .ok_or_else(|| GraphError::node_not_found(&name))?;
302 steps += 1;
303 if steps > max_steps {
304 return Err(GraphError::MaxStepsExceeded(max_steps));
305 }
306 ctx.increment_step();
307 history.push(name);
308 result = node.node.run(&mut ctx).await?;
309 }
310 NodeResult::End(end) => {
311 return Ok(
312 GraphRunResult::new(end, ctx.state, ctx.step, run_id).with_history(history)
313 );
314 }
315 }
316 }
317 }
318}
319
320impl<State, Deps, End> Default for Graph<State, Deps, End>
321where
322 State: GraphState,
323 Deps: Send + Sync + 'static,
324 End: Send + Sync + 'static,
325{
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331pub struct SimpleGraph<State: GraphState> {
333 nodes: HashMap<String, Box<dyn Node<State>>>,
334 edges: Vec<Edge<State>>,
335 entry_node: Option<String>,
336 finish_nodes: Vec<String>,
337}
338
339impl<State: GraphState + 'static> SimpleGraph<State> {
340 pub fn new() -> Self {
342 Self {
343 nodes: HashMap::new(),
344 edges: Vec::new(),
345 entry_node: None,
346 finish_nodes: Vec::new(),
347 }
348 }
349
350 pub fn add_node(mut self, name: impl Into<String>, node: impl Node<State> + 'static) -> Self {
352 self.nodes.insert(name.into(), Box::new(node));
353 self
354 }
355
356 pub fn add_edge<F>(
358 mut self,
359 from: impl Into<String>,
360 to: impl Into<String>,
361 condition: F,
362 ) -> Self
363 where
364 F: Fn(&State) -> bool + Send + Sync + 'static,
365 {
366 self.edges.push(Edge::new(from, to, condition));
367 self
368 }
369
370 pub fn set_entry(mut self, name: impl Into<String>) -> Self {
372 self.entry_node = Some(name.into());
373 self
374 }
375
376 pub fn set_finish(mut self, names: &[&str]) -> Self {
378 self.finish_nodes = names.iter().map(|s| s.to_string()).collect();
379 self
380 }
381
382 pub fn build(self) -> GraphResult<Self> {
384 if self.entry_node.is_none() {
385 return Err(GraphError::NoEntryNode);
386 }
387 Ok(self)
388 }
389
390 pub async fn run(&self, mut state: State) -> GraphResult<State> {
392 let entry = self.entry_node.as_ref().ok_or(GraphError::NoEntryNode)?;
393 let mut current = entry.clone();
394
395 loop {
396 if self.finish_nodes.contains(¤t) {
397 break;
398 }
399
400 let node = self
401 .nodes
402 .get(¤t)
403 .ok_or_else(|| GraphError::node_not_found(¤t))?;
404
405 state = node.execute(state).await?;
406
407 let next = self
409 .edges
410 .iter()
411 .find(|e| e.from == current && e.matches(&state));
412 match next {
413 Some(edge) => current = edge.to.clone(),
414 None => break,
415 }
416 }
417
418 Ok(state)
419 }
420}
421
422impl<State: GraphState + 'static> Default for SimpleGraph<State> {
423 fn default() -> Self {
424 Self::new()
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    /// Increments `value` and ends the run once it reaches 3.
    struct IncrementNode;

    #[async_trait]
    impl BaseNode<TestState, (), i32> for IncrementNode {
        fn name(&self) -> &str {
            "increment"
        }

        async fn run(
            &self,
            ctx: &mut GraphRunContext<TestState, ()>,
        ) -> GraphResult<NodeResult<TestState, (), i32>> {
            ctx.state.value += 1;
            if ctx.state.value >= 3 {
                Ok(NodeResult::end(ctx.state.value))
            } else {
                Ok(NodeResult::next(IncrementNode))
            }
        }
    }

    /// Routes by name to the node registered as `"start"`.
    struct JumpNode;

    #[async_trait]
    impl BaseNode<TestState, (), i32> for JumpNode {
        fn name(&self) -> &str {
            "jump"
        }

        async fn run(
            &self,
            _ctx: &mut GraphRunContext<TestState, ()>,
        ) -> GraphResult<NodeResult<TestState, (), i32>> {
            Ok(NodeResult::NextNamed("start".to_string()))
        }
    }

    #[tokio::test]
    async fn test_simple_graph_run() {
        let graph = Graph::<TestState, (), i32>::new()
            .with_name("test")
            .node("start", IncrementNode)
            .entry("start")
            .build()
            .unwrap();

        let result = graph.run(TestState::default(), ()).await.unwrap();
        assert_eq!(result.result, 3);
        assert_eq!(result.steps, 3);
    }

    #[tokio::test]
    async fn test_next_named_routes_to_registered_node() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("jump", JumpNode)
            .node("start", IncrementNode)
            .entry("jump")
            .build()
            .unwrap();

        // jump (1) -> start (2) -> increment (3) -> increment (4) ends with value 3.
        let result = graph.run(TestState::default(), ()).await.unwrap();
        assert_eq!(result.result, 3);
        assert_eq!(result.steps, 4);
    }

    #[tokio::test]
    async fn test_max_steps_exceeded() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("start", IncrementNode)
            .entry("start")
            .with_max_steps(2)
            .build()
            .unwrap();

        // IncrementNode needs three steps to finish, one more than allowed.
        let result = graph.run(TestState::default(), ()).await;
        assert!(matches!(result, Err(GraphError::MaxStepsExceeded(2))));
    }

    #[test]
    fn test_graph_validation() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .entry("missing");

        assert!(graph.build().is_err());
    }

    #[test]
    fn test_graph_no_entry() {
        let graph = Graph::<TestState, (), i32>::new().node("a", IncrementNode);

        assert!(graph.build().is_err());
    }

    #[test]
    fn test_graph_rejects_cycle() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .node("b", IncrementNode)
            .edge_always("a", "b")
            .edge_always("b", "a")
            .entry("a");

        assert!(matches!(graph.build(), Err(GraphError::CycleDetected)));
    }

    #[test]
    fn test_graph_rejects_dangling_edge() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .edge_always("a", "missing")
            .entry("a");

        assert!(graph.build().is_err());
    }

    #[test]
    fn test_graph_rejects_unknown_finish_node() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .entry("a")
            .finish(&["missing"]);

        assert!(graph.build().is_err());
    }
}