Skip to main content

lellm_graph/graph/
graph_analysis.rs

//! Static graph analysis — cycle detection, unreachable nodes, and fallback diagnostics.
//!
//! Split out of graph.rs to keep the core file lean.
4
5use super::{Edge, Graph};
6use crate::error::{DiagnosticCategory, GraphDiagnostics};
7use crate::state::workflow_state::{MergeStrategy, WorkflowState};
8
9// ─── 环检测 ──────────────────────────────────────────────────────
10
11/// 查找所有环。
12pub(crate) fn find_all_cycles<S: WorkflowState, M: MergeStrategy<S>>(
13    graph: &Graph<S, M>,
14) -> Vec<Vec<String>> {
15    let adj = build_adj(graph);
16    let mut cycles = Vec::new();
17    for node in graph.nodes.keys() {
18        let mut in_path = std::collections::HashSet::new();
19        let mut path = Vec::new();
20        dfs_cycles(node, node, &adj, &mut in_path, &mut path, &mut cycles);
21    }
22    cycles
23}
24
/// Depth-first search that records every simple cycle passing through `start`.
///
/// * `start` – node the search was launched from; a cycle is recorded when an
///   edge leads back to it.
/// * `current` – node being expanded in this call.
/// * `adj` – adjacency list (`from` → list of `to`).
/// * `in_path` / `path` – nodes on the current DFS path, as a set and in
///   visiting order. Both are restored to their entry state before returning.
/// * `cycles` – output; each cycle is a copy of `path` at the moment the edge
///   back to `start` is seen.
///
/// Only neighbours that compare greater than `start` (string ordering) are
/// explored, so a cycle is reported solely by the search launched from its
/// smallest node. Self-loops (`start → start`) are not reported because a
/// cycle needs at least two nodes on the path.
///
/// A neighbour that appears several times in one adjacency list (parallel
/// edges) is expanded only once, so the same node sequence is never recorded
/// more than once.
fn dfs_cycles(
    start: &str,
    current: &str,
    adj: &std::collections::HashMap<String, Vec<String>>,
    in_path: &mut std::collections::HashSet<String>,
    path: &mut Vec<String>,
    cycles: &mut Vec<Vec<String>>,
) {
    // `insert` returns false when `current` is already on the path.
    if !in_path.insert(current.to_string()) {
        return;
    }
    path.push(current.to_string());

    if let Some(neighbors) = adj.get(current) {
        // Parallel edges repeat a neighbour; following it again would only
        // rediscover the cycles already recorded.
        let mut expanded: std::collections::HashSet<&str> = std::collections::HashSet::new();
        for neighbor in neighbors.iter().map(String::as_str) {
            if !expanded.insert(neighbor) {
                continue;
            }
            if neighbor == start {
                // Edge back to the start closes a cycle (self-loops excluded).
                if path.len() >= 2 {
                    cycles.push(path.clone());
                }
            } else if neighbor > start && !in_path.contains(neighbor) {
                dfs_cycles(start, neighbor, adj, in_path, path, cycles);
            }
        }
    }

    // Backtrack.
    path.pop();
    in_path.remove(current);
}
53
54/// 过滤未保护的环(环上所有边都没有 max_visits 约束)。
55pub(crate) fn filter_unprotected_cycles<S: WorkflowState, M: MergeStrategy<S>>(
56    graph: &Graph<S, M>,
57    cycles: &[Vec<String>],
58) -> Vec<Vec<String>> {
59    let mut unprotected: Vec<Vec<String>> = cycles
60        .iter()
61        .filter(|cycle| {
62            let has_protection = (0..cycle.len()).any(|i| {
63                let next = (i + 1) % cycle.len();
64                let from = cycle[i].as_str();
65                let to = cycle[next].as_str();
66                graph.edges.iter().any(|e| {
67                    e.from == from
68                        && e.to == to
69                        && e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some())
70                })
71            });
72            !has_protection
73        })
74        .cloned()
75        .collect();
76    unprotected.sort();
77    unprotected.dedup();
78    unprotected
79}
80
81/// 构建邻接表。
82pub(crate) fn build_adj<S: WorkflowState, M: MergeStrategy<S>>(
83    graph: &Graph<S, M>,
84) -> std::collections::HashMap<String, Vec<String>> {
85    let mut adj: std::collections::HashMap<String, Vec<String>> = std::collections::HashMap::new();
86    for edge in &graph.edges {
87        adj.entry(edge.from.clone())
88            .or_default()
89            .push(edge.to.clone());
90    }
91    adj
92}
93
94// ─── CycleAnalysis ─────────────────────────────────────────────
95
/// Diagnostic result of a cycle analysis.
#[derive(Debug, Clone)]
pub struct CycleAnalysis {
    /// Whether any cycle was found. When `false`, [`CycleAnalysis::report`]
    /// prints only the "graph is a DAG" message.
    pub has_cycles: bool,
    /// Every cycle, as the ordered list of node names along it. The closing
    /// edge back to the first node is implied.
    pub cycles: Vec<Vec<String>>,
    /// The cycles from `cycles` that are considered unprotected.
    pub unprotected_cycles: Vec<Vec<String>>,
    /// Denominator of the "Edge protection" line in the report — presumably
    /// the total number of edges in the graph (set by the caller).
    pub total_edges: usize,
    /// Numerator of the "Edge protection" line in the report — presumably the
    /// number of edges that have `analysis` set (set by the caller).
    pub protected_edges: usize,
}

impl CycleAnalysis {
    /// Returns `true` when `unprotected_cycles` is empty.
    pub fn all_protected(&self) -> bool {
        self.unprotected_cycles.is_empty()
    }

    /// Renders a multi-line, human-readable report.
    ///
    /// The report lists every cycle with its protection status, followed by a
    /// recommendation if any cycle is unprotected. An empty entry in `cycles`
    /// is rendered as `(empty)` instead of panicking.
    pub fn report(&self) -> String {
        let mut lines = vec!["=== Graph Cycle Analysis ===".to_string()];

        if !self.has_cycles {
            lines.push("No cycles detected — graph is a DAG.".to_string());
            return lines.join("\n");
        }

        lines.push(format!("Found {} cycle(s).", self.cycles.len()));
        lines.push(format!(
            "Edge protection: {}/{} edges have analysis set.",
            self.protected_edges, self.total_edges
        ));

        for (i, cycle) in self.cycles.iter().enumerate() {
            // The fields are public, so an empty cycle can be supplied by a
            // caller; avoid indexing into it.
            let header = match cycle.first() {
                Some(first) => {
                    format!("  Cycle {}: {} → {}", i + 1, cycle.join(" → "), first)
                }
                None => format!("  Cycle {}: (empty)", i + 1),
            };
            lines.push(header);

            if self.unprotected_cycles.contains(cycle) {
                lines.push("    ⚠️ UNPROTECTED — no max_visits on back-edge".into());
            } else {
                lines.push("    ✅ Protected by edge-level analysis".into());
            }
        }

        if !self.all_protected() {
            lines.push(String::new());
            lines.push("⚠️ Recommendation: Set analysis.max_visits on back-edges.".to_string());
        }

        lines.join("\n")
    }
}
145
146// ─── Graph::analyze() 调用的分析逻辑 ─────────────────────────
147
148/// 执行完整图诊断分析(由 Graph::analyze() 调用)。
149pub(crate) fn analyze_graph<S: WorkflowState, M: MergeStrategy<S>>(
150    graph: &Graph<S, M>,
151) -> GraphDiagnostics {
152    let mut diag = GraphDiagnostics::new();
153    let adj = build_adj(graph);
154
155    let cycles = find_all_cycles(graph);
156    if !cycles.is_empty() {
157        let unprotected = filter_unprotected_cycles(graph, &cycles);
158        for cycle in &unprotected {
159            let cycle_str = format_cycle(cycle);
160            diag.add_warning(
161                DiagnosticCategory::Cycle,
162                format!("cycle detected: {} → {}", cycle_str, cycle[0]),
163            );
164        }
165        for cycle in &cycles {
166            if !unprotected.contains(cycle) {
167                let cycle_str = format_cycle(cycle);
168                diag.add_info(
169                    DiagnosticCategory::Cycle,
170                    format!(
171                        "protected cycle: {} → {} (has max_visits)",
172                        cycle_str, cycle[0]
173                    ),
174                );
175            }
176        }
177    }
178
179    check_fallback_in_cycles(graph, &cycles, &mut diag);
180    check_unreachable_nodes(graph, &adj, &mut diag);
181    check_end_node_outgoing(graph, &mut diag);
182
183    diag
184}
185
186// ─── 诊断辅助函数 ───────────────────────────────────────────────
187
/// Renders the nodes of a cycle as `a → b → c` (without the closing node).
fn format_cycle(cycle: &[String]) -> String {
    const ARROW: &str = " → ";
    cycle.join(ARROW)
}
191
192fn check_fallback_in_cycles<S: WorkflowState, M: MergeStrategy<S>>(
193    graph: &Graph<S, M>,
194    cycles: &[Vec<String>],
195    diag: &mut GraphDiagnostics,
196) {
197    let fallback_edges: std::collections::HashSet<(&str, &str)> = graph
198        .edges
199        .iter()
200        .filter(|e| e.fallback)
201        .map(|e| (e.from.as_str(), e.to.as_str()))
202        .collect();
203
204    if fallback_edges.is_empty() {
205        return;
206    }
207
208    for cycle in cycles {
209        for i in 0..cycle.len() {
210            let next = (i + 1) % cycle.len();
211            let from = cycle[i].as_str();
212            let to = cycle[next].as_str();
213            if fallback_edges.contains(&(from, to)) {
214                let edge_str = format!("{} → {}", from, to);
215                diag.add_warning(
216                    DiagnosticCategory::FallbackInCycle,
217                    format!(
218                        "fallback edge {} participates in cycle: {} → {}",
219                        edge_str,
220                        format_cycle(cycle),
221                        cycle[0]
222                    ),
223                );
224            }
225        }
226    }
227}
228
229fn check_unreachable_nodes<S: WorkflowState, M: MergeStrategy<S>>(
230    graph: &Graph<S, M>,
231    adj: &std::collections::HashMap<String, Vec<String>>,
232    diag: &mut GraphDiagnostics,
233) {
234    let mut visited = std::collections::HashSet::new();
235    let mut queue = Vec::new();
236
237    queue.push(graph.start.clone());
238    visited.insert(graph.start.clone());
239
240    while let Some(node) = queue.pop() {
241        if let Some(neighbors) = adj.get(&node) {
242            for neighbor in neighbors {
243                if visited.insert(neighbor.clone()) {
244                    queue.push(neighbor.clone());
245                }
246            }
247        }
248    }
249
250    for name in graph.nodes.keys() {
251        if !visited.contains(name) {
252            diag.add_info(
253                DiagnosticCategory::Unreachable,
254                format!(
255                    "node '{}' is not reachable from start node '{}'",
256                    name, graph.start
257                ),
258            );
259        }
260    }
261}
262
263fn check_end_node_outgoing<S: WorkflowState, M: MergeStrategy<S>>(
264    graph: &Graph<S, M>,
265    diag: &mut GraphDiagnostics,
266) {
267    let outgoing: Vec<&Edge<S>> = graph.edges.iter().filter(|e| e.from == graph.end).collect();
268
269    if !outgoing.is_empty() {
270        let targets: Vec<&str> = outgoing.iter().map(|e| e.to.as_str()).collect();
271        diag.add_info(
272            DiagnosticCategory::EndNodeOutgoing,
273            format!(
274                "end node '{}' has {} outgoing edge(s) to: {:?}",
275                graph.end,
276                outgoing.len(),
277                targets
278            ),
279        );
280    }
281}