1use std::sync::Arc;
13
14use indexmap::IndexMap;
15
16use crate::error::{BuildError, BuildErrors, DiagnosticCategory, GraphDiagnostics};
17use crate::node::{FlowNode, NodeKind};
18use crate::node_context::NodeContext;
19use crate::state::{State, StateMerge};
20use crate::workflow_state::{MergeStrategy, WorkflowState};
21
/// Shared, thread-safe predicate evaluated against the workflow state to
/// decide whether a conditional edge may be taken.
pub type EdgeCondition<S> = Arc<dyn Fn(&S) -> bool + Sync + Send>;
26
27#[derive(Clone)]
29pub struct Edge<S: WorkflowState = State> {
30 pub from: String,
31 pub to: String,
32 pub condition: Option<EdgeCondition<S>>,
34 pub analysis: Option<EdgeAnalysis>,
36 pub fallback: bool,
38}
39
40impl<S: WorkflowState> Edge<S> {
41 pub fn is_conditional(&self) -> bool {
43 self.condition.is_some() && !self.fallback
44 }
45
46 pub fn is_normal(&self) -> bool {
48 self.condition.is_none() && !self.fallback
49 }
50}
51
52impl<S: WorkflowState> std::fmt::Debug for Edge<S> {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("Edge")
55 .field("from", &self.from)
56 .field("to", &self.to)
57 .field("has_condition", &self.condition.is_some())
58 .field("analysis", &self.analysis)
59 .field("fallback", &self.fallback)
60 .finish()
61 }
62}
63
/// Static-analysis metadata attached to an edge.
#[derive(Clone, Debug)]
pub struct EdgeAnalysis {
    /// Upper bound on traversals of this edge; a `Some` value marks any
    /// cycle through the edge as "protected" in the cycle diagnostics.
    pub max_visits: Option<usize>,
}
70
71#[derive(Clone)]
75pub struct Graph<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
76 pub(crate) name: String,
77 pub(crate) nodes: IndexMap<String, NodeKind<S, M>>,
78 pub(crate) edges: Vec<Edge<S>>,
79 pub(crate) start: String,
80 pub(crate) end: String,
81}
82
83impl<S: WorkflowState, M: MergeStrategy<S>> Graph<S, M> {
84 pub fn name(&self) -> &str {
85 &self.name
86 }
87
88 pub fn node_names(&self) -> Vec<&str> {
89 self.nodes.keys().map(|s| s.as_str()).collect()
90 }
91
92 pub fn start_node(&self) -> &str {
93 &self.start
94 }
95
96 pub fn end_node(&self) -> &str {
97 &self.end
98 }
99
100 pub fn hash(&self) -> String {
102 let mut s = String::new();
103 let mut names: Vec<&str> = self.nodes.keys().map(|k| k.as_str()).collect();
104 names.sort();
105 s.push_str(&names.join(","));
106 s.push('|');
107 let mut edge_strs: Vec<String> = self
108 .edges
109 .iter()
110 .map(|e| {
111 format!(
112 "{}->{}{:?}{}",
113 e.from,
114 e.to,
115 if e.condition.is_some() { "?" } else { "" },
116 if e.fallback { "!" } else { "" }
117 )
118 })
119 .collect();
120 edge_strs.sort();
121 s.push_str(&edge_strs.join(","));
122 let hash = fnv_hash(&s);
123 format!("{:016x}", hash)
124 }
125
126 pub fn edges_from(&self, from: &str) -> Vec<&Edge<S>> {
127 self.edges.iter().filter(|e| e.from == from).collect()
128 }
129
130 pub fn find_edge(&self, from: &str, to: &str) -> Option<&Edge<S>> {
131 self.edges.iter().find(|e| e.from == from && e.to == to)
132 }
133
134 pub fn node_map(&self) -> &IndexMap<String, NodeKind<S, M>> {
136 &self.nodes
137 }
138
139 pub fn resolve_next(&self, current: &str, state: &S) -> Option<String> {
141 let edges = self.edges_from(current);
142
143 for edge in &edges {
145 if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
146 return Some(edge.to.clone());
147 }
148 }
149
150 for edge in &edges {
152 if edge.is_normal() {
153 return Some(edge.to.clone());
154 }
155 }
156
157 for edge in &edges {
159 if edge.fallback {
160 return Some(edge.to.clone());
161 }
162 }
163
164 None
165 }
166
167 pub fn find_fallback_edge(&self, from: &str) -> Option<String> {
168 self.edges
169 .iter()
170 .find(|e| e.from == from && e.fallback)
171 .map(|e| e.to.clone())
172 }
173
174 pub fn validate(&self) -> Result<(), crate::error::TerminalError> {
176 if !self.nodes.contains_key(&self.start) {
177 return Err(crate::error::TerminalError::InvalidGraph(format!(
178 "start node '{}' not found",
179 self.start
180 )));
181 }
182
183 if !self.nodes.contains_key(&self.end) {
184 return Err(crate::error::TerminalError::InvalidGraph(format!(
185 "end node '{}' not found",
186 self.end
187 )));
188 }
189
190 for edge in &self.edges {
191 if !self.nodes.contains_key(&edge.from) {
192 return Err(crate::error::TerminalError::InvalidGraph(format!(
193 "edge references non-existent source node '{}'",
194 edge.from
195 )));
196 }
197 if !self.nodes.contains_key(&edge.to) {
198 return Err(crate::error::TerminalError::InvalidGraph(format!(
199 "edge references non-existent target node '{}'",
200 edge.to
201 )));
202 }
203 }
204
205 Ok(())
206 }
207
208 pub fn analyze(&self) -> GraphDiagnostics {
210 let mut diag = GraphDiagnostics::new();
211 let adj = self.build_adj();
212
213 let cycles = self.find_all_cycles(&adj);
214 if !cycles.is_empty() {
215 let unprotected = self.filter_unprotected_cycles(&cycles);
216 for cycle in &unprotected {
217 let cycle_str = format_cycle(cycle);
218 diag.add_warning(
219 DiagnosticCategory::Cycle,
220 format!("cycle detected: {} → {}", cycle_str, cycle[0]),
221 );
222 }
223 for cycle in &cycles {
224 if !unprotected.contains(cycle) {
225 let cycle_str = format_cycle(cycle);
226 diag.add_info(
227 DiagnosticCategory::Cycle,
228 format!(
229 "protected cycle: {} → {} (has max_visits)",
230 cycle_str, cycle[0]
231 ),
232 );
233 }
234 }
235 }
236
237 check_fallback_in_cycles(self, &cycles, &mut diag);
238 check_unreachable_nodes(self, &adj, &mut diag);
239 check_end_node_outgoing(self, &mut diag);
240
241 diag
242 }
243
244 fn build_adj(&self) -> std::collections::HashMap<String, Vec<String>> {
245 let mut adj: std::collections::HashMap<String, Vec<String>> =
246 std::collections::HashMap::new();
247 for edge in &self.edges {
248 adj.entry(edge.from.clone())
249 .or_default()
250 .push(edge.to.clone());
251 }
252 adj
253 }
254
255 pub async fn run_inline(
259 &self,
260 ctx: &mut NodeContext<'_, S>,
261 max_steps: usize,
262 ) -> Result<(), crate::error::GraphError> {
263 use crate::node_context::NextAction;
264
265 let mut current = self.start_node().to_string();
266 let mut step: usize = 0;
267
268 loop {
269 step += 1;
270 if step > max_steps {
271 return Err(crate::error::GraphError::Terminal(
272 crate::error::TerminalError::StepsExceeded { limit: max_steps },
273 ));
274 }
275
276 let node = self.nodes.get(¤t).ok_or_else(|| {
277 crate::error::GraphError::Terminal(crate::error::TerminalError::NodeNotFound(
278 current.clone(),
279 ))
280 })?;
281
282 node.execute(ctx).await?;
284
285 let effects = ctx.consume_effects();
287 ctx.state_mut().apply_batch(effects);
288
289 let (next_action, _signal) = ctx.take_control();
291
292 match next_action {
294 NextAction::End => return Ok(()),
295 NextAction::Goto(target) => {
296 current = target;
297 }
298 NextAction::Next => {
299 if current == self.end_node() {
300 return Ok(());
301 }
302 current = self.resolve_next_inline(¤t, ctx.state())?;
303 }
304 }
305 }
306 }
307
308 fn resolve_next_inline(
310 &self,
311 current: &str,
312 state: &S,
313 ) -> Result<String, crate::error::GraphError> {
314 let edges = self.edges_from(current);
315
316 if edges.is_empty() {
317 return Err(crate::error::GraphError::Terminal(
318 crate::error::TerminalError::InvalidGraph(format!(
319 "node '{}' has no outgoing edges and is not the end node",
320 current
321 )),
322 ));
323 }
324
325 for edge in &edges {
327 if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
328 return Ok(edge.to.clone());
329 }
330 }
331
332 for edge in &edges {
334 if edge.is_normal() {
335 return Ok(edge.to.clone());
336 }
337 }
338
339 for edge in &edges {
341 if edge.fallback {
342 return Ok(edge.to.clone());
343 }
344 }
345
346 Err(crate::error::GraphError::Terminal(
347 crate::error::TerminalError::InvalidGraph(format!(
348 "node '{}' has no matching outgoing edge",
349 current
350 )),
351 ))
352 }
353
354 fn find_all_cycles(
356 &self,
357 adj: &std::collections::HashMap<String, Vec<String>>,
358 ) -> Vec<Vec<String>> {
359 let mut cycles = Vec::new();
360 for node in self.nodes.keys() {
361 let mut in_path = std::collections::HashSet::new();
362 let mut path = Vec::new();
363 self.dfs_cycles(node, node, adj, &mut in_path, &mut path, &mut cycles);
364 }
365 cycles
366 }
367
368 fn dfs_cycles(
369 &self,
370 start: &str,
371 current: &str,
372 adj: &std::collections::HashMap<String, Vec<String>>,
373 in_path: &mut std::collections::HashSet<String>,
374 path: &mut Vec<String>,
375 cycles: &mut Vec<Vec<String>>,
376 ) {
377 if in_path.contains(current) {
378 return;
379 }
380
381 path.push(current.to_string());
382 in_path.insert(current.to_string());
383
384 if let Some(neighbors) = adj.get(current) {
385 for neighbor in neighbors {
386 if neighbor.as_str() == start && path.len() >= 2 {
387 cycles.push(path.clone());
388 } else if neighbor.as_str() > start && !in_path.contains(neighbor) {
389 self.dfs_cycles(start, neighbor, adj, in_path, path, cycles);
390 }
391 }
392 }
393
394 path.pop();
395 in_path.remove(current);
396 }
397
398 fn filter_unprotected_cycles(&self, cycles: &[Vec<String>]) -> Vec<Vec<String>> {
399 let mut unprotected: Vec<Vec<String>> = cycles
400 .iter()
401 .filter(|cycle| {
402 let has_protection = (0..cycle.len()).any(|i| {
403 let next = (i + 1) % cycle.len();
404 let from = cycle[i].as_str();
405 let to = cycle[next].as_str();
406 self.edges.iter().any(|e| {
407 e.from == from
408 && e.to == to
409 && e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some())
410 })
411 });
412 !has_protection
413 })
414 .cloned()
415 .collect();
416 unprotected.sort();
417 unprotected.dedup();
418 unprotected
419 }
420
421 pub fn analyze_cycles(&self) -> CycleAnalysis {
423 let adj = self.build_adj();
424 let cycles = self.find_all_cycles(&adj);
425 let unprotected = self.filter_unprotected_cycles(&cycles);
426
427 CycleAnalysis {
428 has_cycles: !cycles.is_empty(),
429 cycles,
430 unprotected_cycles: unprotected,
431 total_edges: self.edges.len(),
432 protected_edges: self
433 .edges
434 .iter()
435 .filter(|e| e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some()))
436 .count(),
437 }
438 }
439}
440
/// Summary of the cycles found in a graph and of their `max_visits` protection.
#[derive(Debug, Clone)]
pub struct CycleAnalysis {
    /// Whether at least one cycle was found.
    pub has_cycles: bool,
    /// Every cycle found, as the ordered list of node names on it.
    pub cycles: Vec<Vec<String>>,
    /// The subset of `cycles` with no `max_visits` edge.
    pub unprotected_cycles: Vec<Vec<String>>,
    /// Number of edges in the graph.
    pub total_edges: usize,
    /// Number of edges that carry `analysis.max_visits`.
    pub protected_edges: usize,
}

impl CycleAnalysis {
    /// `true` when no cycle lacks `max_visits` protection.
    pub fn all_protected(&self) -> bool {
        self.unprotected_cycles.is_empty()
    }

    /// Renders a multi-line, human-readable report of the analysis.
    pub fn report(&self) -> String {
        let mut out: Vec<String> = vec!["=== Graph Cycle Analysis ===".to_string()];

        if !self.has_cycles {
            out.push("No cycles detected — graph is a DAG.".to_string());
            return out.join("\n");
        }

        out.push(format!("Found {} cycle(s).", self.cycles.len()));
        out.push(format!(
            "Edge protection: {}/{} edges have analysis set.",
            self.protected_edges, self.total_edges
        ));

        for (number, cycle) in (1usize..).zip(&self.cycles) {
            out.push(format!(
                " Cycle {}: {} → {}",
                number,
                cycle.join(" → "),
                cycle[0]
            ));
            let verdict = if self.unprotected_cycles.contains(cycle) {
                " ⚠️ UNPROTECTED — no max_visits on back-edge"
            } else {
                " ✅ Protected by edge-level analysis"
            };
            out.push(verdict.to_string());
        }

        if !self.all_protected() {
            out.push(String::new());
            out.push("⚠️ Recommendation: Set analysis.max_visits on back-edges.".to_string());
        }

        out.join("\n")
    }
}
490
491pub struct PendingEdge<'a, S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
495 builder: &'a mut GraphBuilder<S, M>,
496 edge_index: usize,
497}
498
499impl<'a, S: WorkflowState, M: MergeStrategy<S>> PendingEdge<'a, S, M> {
500 pub fn max_visits(self, n: usize) -> &'a mut GraphBuilder<S, M> {
501 self.builder.edges[self.edge_index].analysis = Some(EdgeAnalysis {
502 max_visits: Some(n),
503 });
504 self.builder
505 }
506}
507
508pub struct GraphBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
512 name: String,
513 nodes: IndexMap<String, NodeKind<S, M>>,
514 edges: Vec<Edge<S>>,
515 start: Option<String>,
516 end: Option<String>,
517}
518
519impl<S: WorkflowState, M: MergeStrategy<S>> GraphBuilder<S, M> {
520 pub fn new(name: impl Into<String>) -> Self {
526 Self {
527 name: name.into(),
528 nodes: IndexMap::new(),
529 edges: Vec::new(),
530 start: None,
531 end: None,
532 }
533 }
534 pub fn start(&mut self, node: impl Into<String>) -> &mut Self {
535 self.start = Some(node.into());
536 self
537 }
538
539 pub fn end(&mut self, node: impl Into<String>) -> &mut Self {
540 self.end = Some(node.into());
541 self
542 }
543
544 pub fn node(&mut self, name: impl Into<String>, kind: NodeKind<S, M>) -> &mut Self {
545 self.nodes.insert(name.into(), kind);
546 self
547 }
548
549 pub fn edge(
550 &mut self,
551 from: impl Into<String>,
552 to: impl Into<String>,
553 ) -> PendingEdge<'_, S, M> {
554 let edge_index = self.edges.len();
555 self.edges.push(Edge {
556 from: from.into(),
557 to: to.into(),
558 condition: None,
559 analysis: None,
560 fallback: false,
561 });
562 PendingEdge {
563 builder: self,
564 edge_index,
565 }
566 }
567
568 pub fn edge_if(
569 &mut self,
570 from: impl Into<String>,
571 to: impl Into<String>,
572 condition: impl Fn(&S) -> bool + Send + Sync + 'static,
573 ) -> PendingEdge<'_, S, M> {
574 let edge_index = self.edges.len();
575 self.edges.push(Edge {
576 from: from.into(),
577 to: to.into(),
578 condition: Some(Arc::new(condition)),
579 analysis: None,
580 fallback: false,
581 });
582 PendingEdge {
583 builder: self,
584 edge_index,
585 }
586 }
587
588 pub fn edge_fallback(
589 &mut self,
590 from: impl Into<String>,
591 to: impl Into<String>,
592 ) -> PendingEdge<'_, S, M> {
593 let edge_index = self.edges.len();
594 self.edges.push(Edge {
595 from: from.into(),
596 to: to.into(),
597 condition: None,
598 analysis: None,
599 fallback: true,
600 });
601 PendingEdge {
602 builder: self,
603 edge_index,
604 }
605 }
606
607 pub fn build(self) -> Result<Graph<S, M>, BuildErrors> {
608 let mut errors = BuildErrors::new();
609
610 let start = match self.start {
611 Some(s) => s,
612 None => {
613 errors.push(BuildError::MissingEntryPoint);
614 return Err(errors);
615 }
616 };
617 let end = match self.end {
618 Some(s) => s,
619 None => {
620 errors.push(BuildError::MissingExitPoint);
621 return Err(errors);
622 }
623 };
624
625 let mut seen_nodes = std::collections::HashSet::new();
626 for name in self.nodes.keys() {
627 if !seen_nodes.insert(name.clone()) {
628 errors.push(BuildError::DuplicateNode { id: name.clone() });
629 }
630 }
631
632 for edge in &self.edges {
633 if !self.nodes.contains_key(&edge.from) {
634 errors.push(BuildError::MissingNode {
635 from: edge.from.clone(),
636 to: edge.to.clone(),
637 });
638 }
639 if !self.nodes.contains_key(&edge.to) {
640 errors.push(BuildError::MissingNode {
641 from: edge.from.clone(),
642 to: edge.to.clone(),
643 });
644 }
645 }
646
647 if !errors.is_empty() {
648 return Err(errors);
649 }
650
651 let graph = Graph {
652 name: self.name,
653 nodes: self.nodes,
654 edges: self.edges,
655 start,
656 end,
657 };
658
659 if let Err(e) = graph.validate() {
660 return Err(BuildErrors(vec![BuildError::InvalidEdgeDefinition {
661 from: "graph".to_string(),
662 to: "graph".to_string(),
663 reason: e.to_string(),
664 }]));
665 }
666
667 Ok(graph)
668 }
669
670 pub fn name(&self) -> &str {
671 &self.name
672 }
673}
674
/// Renders a cycle's node names joined by `" → "` (without repeating the
/// first node at the end; callers append it themselves).
fn format_cycle(cycle: &[String]) -> String {
    let mut rendered = String::new();
    for (position, node) in cycle.iter().enumerate() {
        if position > 0 {
            rendered.push_str(" → ");
        }
        rendered.push_str(node);
    }
    rendered
}
680
681fn check_fallback_in_cycles<S: WorkflowState, M: MergeStrategy<S>>(
682 graph: &Graph<S, M>,
683 cycles: &[Vec<String>],
684 diag: &mut GraphDiagnostics,
685) {
686 let fallback_edges: std::collections::HashSet<(&str, &str)> = graph
687 .edges
688 .iter()
689 .filter(|e| e.fallback)
690 .map(|e| (e.from.as_str(), e.to.as_str()))
691 .collect();
692
693 if fallback_edges.is_empty() {
694 return;
695 }
696
697 for cycle in cycles {
698 for i in 0..cycle.len() {
699 let next = (i + 1) % cycle.len();
700 let from = cycle[i].as_str();
701 let to = cycle[next].as_str();
702 if fallback_edges.contains(&(from, to)) {
703 let edge_str = format!("{} → {}", from, to);
704 diag.add_warning(
705 DiagnosticCategory::FallbackInCycle,
706 format!(
707 "fallback edge {} participates in cycle: {} → {}",
708 edge_str,
709 format_cycle(cycle),
710 cycle[0]
711 ),
712 );
713 }
714 }
715 }
716}
717
718fn check_unreachable_nodes<S: WorkflowState, M: MergeStrategy<S>>(
719 graph: &Graph<S, M>,
720 adj: &std::collections::HashMap<String, Vec<String>>,
721 diag: &mut GraphDiagnostics,
722) {
723 let mut visited = std::collections::HashSet::new();
724 let mut queue = Vec::new();
725
726 queue.push(graph.start.clone());
727 visited.insert(graph.start.clone());
728
729 while let Some(node) = queue.pop() {
730 if let Some(neighbors) = adj.get(&node) {
731 for neighbor in neighbors {
732 if visited.insert(neighbor.clone()) {
733 queue.push(neighbor.clone());
734 }
735 }
736 }
737 }
738
739 for name in graph.nodes.keys() {
740 if !visited.contains(name) {
741 diag.add_info(
742 DiagnosticCategory::Unreachable,
743 format!(
744 "node '{}' is not reachable from start node '{}'",
745 name, graph.start
746 ),
747 );
748 }
749 }
750}
751
752fn check_end_node_outgoing<S: WorkflowState, M: MergeStrategy<S>>(
753 graph: &Graph<S, M>,
754 diag: &mut GraphDiagnostics,
755) {
756 let outgoing: Vec<&Edge<S>> = graph.edges.iter().filter(|e| e.from == graph.end).collect();
757
758 if !outgoing.is_empty() {
759 let targets: Vec<&str> = outgoing.iter().map(|e| e.to.as_str()).collect();
760 diag.add_info(
761 DiagnosticCategory::EndNodeOutgoing,
762 format!(
763 "end node '{}' has {} outgoing edge(s) to: {:?}",
764 graph.end,
765 outgoing.len(),
766 targets
767 ),
768 );
769 }
770}
771
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
fn fnv_hash(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    s.bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}