Skip to main content

lellm_graph/
graph.rs

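//! `Graph` and `GraphBuilder`.
//!
//! An edge stacks three independent layers of semantics:
//! - `condition` — business routing condition (must be satisfied for the edge to be taken)
//! - `analysis` — analysis-only constraints (never consulted for runtime decisions)
//! - `policy` — runtime policy (only takes effect when explicitly declared)
//!
//! `fallback` marks the catch-all edge, tried first when no regular edge matches.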
9
10use std::sync::Arc;
11
12use indexmap::IndexMap;
13
14use crate::error::BuildError;
15use crate::node::NodeKind;
16use crate::state::State;
17
18// ─── Edge ──────────────────────────────────────────────────────
19
20/// 边条件回调类型别名。
21/// Arc 包装以支持 Graph Clone(条件回调不可 Clone)。
22pub type EdgeCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
23
24/// 边(Edge)— 三层语义叠加。
25pub struct Edge {
26    pub from: String,
27    pub to: String,
28    /// ① 业务路由条件(必须满足)
29    pub condition: Option<EdgeCondition>,
30    /// ② 分析用约束(不参与 runtime 决策)
31    pub analysis: Option<EdgeAnalysis>,
32    /// ③ runtime policy(显式声明才生效)
33    pub policy: Option<EdgePolicy>,
34    /// ④ fallback 标记 — 兜底边
35    pub fallback: bool,
36}
37
/// Analysis-only constraint, intended solely for the static
/// `analyze_cycles()` pass.
///
/// `analysis` means "this might go wrong"; it never controls execution.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EdgeAnalysis {
    /// Suggested maximum number of visits, used for cycle-analysis
    /// diagnostics only. Not enforced at runtime.
    pub max_visits: Option<usize>,
}
47
/// Runtime policy — an explicitly declared runtime interception rule.
///
/// `policy` means "I am stopping you right now": unlike `analysis`, it takes
/// part in execution control.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EdgePolicy {
    /// Limits how many times the edge may be traversed; once the limit is
    /// exceeded, `on_exceeded` decides what happens.
    MaxVisits {
        /// Maximum number of traversals.
        limit: usize,
        /// Strategy applied once `limit` has been exceeded.
        on_exceeded: EdgeExceededStrategy,
    },
}

/// How an [`EdgePolicy`] reacts once its limit has been exceeded.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum EdgeExceededStrategy {
    /// Strict mode (default) — the path fails and execution backtracks to
    /// the previous decision node.
    #[default]
    Strict,
    /// Soft degradation — try other edges whose condition holds, then the
    /// fallback edge, then fail.
    SoftFallback,
    /// Silent skip — no error is raised; execution continues with other logic.
    Drop,
}
68
69// ─── Graph ─────────────────────────────────────────────────────
70
71/// 图(Graph)— 允许有环,循环保护由 GraphExecutor::max_steps 运行时熔断提供。
72pub struct Graph {
73    pub(crate) nodes: IndexMap<String, NodeKind>,
74    pub(crate) edges: Vec<Edge>,
75    pub(crate) start: String,
76    pub(crate) end: String,
77}
78
79impl Graph {
80    pub fn node_names(&self) -> Vec<&str> {
81        self.nodes.keys().map(|s| s.as_str()).collect()
82    }
83
84    pub fn start_node(&self) -> &str {
85        &self.start
86    }
87
88    pub fn end_node(&self) -> &str {
89        &self.end
90    }
91
92    pub fn edges_from(&self, from: &str) -> Vec<&Edge> {
93        self.edges.iter().filter(|e| e.from == from).collect()
94    }
95
96    pub fn find_edge(&self, from: &str, to: &str) -> Option<&Edge> {
97        self.edges.iter().find(|e| e.from == from && e.to == to)
98    }
99
100    /// 查找指定节点的 fallback 边目标。
101    ///
102    /// 用于 RecoverableError 恢复:当边级 policy 触发 SoftFallback 时,
103    /// 寻找 fallback 边作为降级路径。
104    pub fn find_fallback_edge(&self, from: &str) -> Option<String> {
105        self.edges
106            .iter()
107            .find(|e| e.from == from && e.fallback)
108            .map(|e| e.to.clone())
109    }
110
111    /// 验证图结构(节点、边引用有效性)。
112    ///
113    /// 注意:不检测环 — 有环图是合法的,循环保护由 GraphExecutor::max_steps 提供。
114    pub fn validate(&self) -> Result<(), crate::error::TerminalError> {
115        if !self.nodes.contains_key(&self.start) {
116            return Err(crate::error::TerminalError::InvalidGraph(format!(
117                "start node '{}' not found",
118                self.start
119            )));
120        }
121
122        if !self.nodes.contains_key(&self.end) {
123            return Err(crate::error::TerminalError::InvalidGraph(format!(
124                "end node '{}' not found",
125                self.end
126            )));
127        }
128
129        for edge in &self.edges {
130            if !self.nodes.contains_key(&edge.from) {
131                return Err(crate::error::TerminalError::InvalidGraph(format!(
132                    "edge references non-existent source node '{}'",
133                    edge.from
134                )));
135            }
136            if !self.nodes.contains_key(&edge.to) {
137                return Err(crate::error::TerminalError::InvalidGraph(format!(
138                    "edge references non-existent target node '{}'",
139                    edge.to
140                )));
141            }
142        }
143
144        Ok(())
145    }
146
147    /// 分析图中所有环,生成诊断信息。
148    pub fn analyze_cycles(&self) -> CycleAnalysis {
149        let mut cycles = Vec::new();
150        let mut path = Vec::new();
151
152        let mut adj: std::collections::HashMap<String, Vec<String>> =
153            std::collections::HashMap::new();
154        for edge in &self.edges {
155            adj.entry(edge.from.clone()).or_default().push(edge.to.clone());
156        }
157
158        for node in self.nodes.keys() {
159            let mut in_path = std::collections::HashSet::new();
160            path.clear();
161            self.dfs_cycles(node, node, &adj, &mut in_path, &mut path, &mut cycles);
162        }
163
164        // 检查哪些环有 analysis 或 policy 保护
165        let mut unprotected = cycles
166            .iter()
167            .filter(|cycle| {
168                let has_protection = (0..cycle.len()).any(|i| {
169                    let next = (i + 1) % cycle.len();
170                    let from = cycle[i].as_str();
171                    let to = cycle[next].as_str();
172                    self.edges.iter().any(|e| {
173                        e.from == from
174                            && e.to == to
175                            && (e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some())
176                                || e.policy.is_some())
177                    })
178                });
179                !has_protection
180            })
181            .cloned()
182            .collect::<Vec<_>>();
183        unprotected.sort();
184        unprotected.dedup();
185
186        CycleAnalysis {
187            has_cycles: !cycles.is_empty(),
188            cycles,
189            unprotected_cycles: unprotected,
190            total_edges: self.edges.len(),
191            protected_edges: self.edges.iter().filter(|e| {
192                e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some()) || e.policy.is_some()
193            }).count(),
194        }
195    }
196
197    fn dfs_cycles(
198        &self,
199        start: &str,
200        current: &str,
201        adj: &std::collections::HashMap<String, Vec<String>>,
202        in_path: &mut std::collections::HashSet<String>,
203        path: &mut Vec<String>,
204        cycles: &mut Vec<Vec<String>>,
205    ) {
206        if in_path.contains(current) {
207            return;
208        }
209
210        path.push(current.to_string());
211        in_path.insert(current.to_string());
212
213        if let Some(neighbors) = adj.get(current) {
214            for neighbor in neighbors {
215                if neighbor.as_str() == start && path.len() >= 2 {
216                    cycles.push(path.clone());
217                } else if neighbor.as_str() > start && !in_path.contains(neighbor) {
218                    self.dfs_cycles(start, neighbor, adj, in_path, path, cycles);
219                }
220            }
221        }
222
223        path.pop();
224        in_path.remove(current);
225    }
226}
227
/// Diagnostic result of cycle analysis (`Graph::analyze_cycles`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CycleAnalysis {
    /// Whether at least one cycle was found.
    pub has_cycles: bool,
    /// Every cycle found, each as the ordered node names along it; the closing
    /// edge from the last node back to the first is implicit.
    pub cycles: Vec<Vec<String>>,
    /// The cycles with no protected edge.
    pub unprotected_cycles: Vec<Vec<String>>,
    /// Total number of edges in the analysed graph.
    pub total_edges: usize,
    /// Number of edges carrying `analysis.max_visits` or a `policy`.
    pub protected_edges: usize,
}

impl CycleAnalysis {
    /// Returns `true` when no unprotected cycle was reported.
    pub fn all_protected(&self) -> bool {
        self.unprotected_cycles.is_empty()
    }

    /// Renders a human-readable, multi-line report (lines joined with `\n`).
    pub fn report(&self) -> String {
        let mut lines = vec!["=== Graph Cycle Analysis ===".to_string()];

        if !self.has_cycles {
            lines.push("No cycles detected — graph is a DAG.".to_string());
            return lines.join("\n");
        }

        lines.push(format!("Found {} cycle(s).", self.cycles.len()));
        lines.push(format!(
            "Edge protection: {}/{} edges have analysis or policy set.",
            self.protected_edges, self.total_edges
        ));

        for (i, cycle) in self.cycles.iter().enumerate() {
            // Fields are public, so a hand-built value may contain an empty
            // path; label it instead of panicking on `cycle[0]`.
            let Some(head) = cycle.first() else {
                lines.push(format!("  Cycle {}: <empty>", i + 1));
                continue;
            };
            lines.push(format!("  Cycle {}: {} → {}", i + 1, cycle.join(" → "), head));

            if self.unprotected_cycles.contains(cycle) {
                lines.push("    ⚠️ UNPROTECTED — no max_visits or policy on back-edge".into());
            } else {
                lines.push("    ✅ Protected by edge-level analysis or policy".into());
            }
        }

        if !self.all_protected() {
            lines.push(String::new());
            lines.push(
                "⚠️ Recommendation: Set analysis.max_visits or policy on back-edges."
                    .to_string(),
            );
        }

        lines.join("\n")
    }
}
280
281// ─── GraphBuilder ─────────────────────────────────────────────
282
283/// Graph 构建器。
284pub struct GraphBuilder {
285    name: String,
286    nodes: IndexMap<String, NodeKind>,
287    edges: Vec<Edge>,
288    start: Option<String>,
289    end: Option<String>,
290}
291
292impl GraphBuilder {
293    pub fn new(name: impl Into<String>) -> Self {
294        Self {
295            name: name.into(),
296            nodes: IndexMap::new(),
297            edges: Vec::new(),
298            start: None,
299            end: None,
300        }
301    }
302
303    pub fn start(&mut self, node: impl Into<String>) -> Result<&mut Self, BuildError> {
304        self.start = Some(node.into());
305        Ok(self)
306    }
307
308    pub fn end(&mut self, node: impl Into<String>) -> Result<&mut Self, BuildError> {
309        self.end = Some(node.into());
310        Ok(self)
311    }
312
313    pub fn node(
314        &mut self,
315        name: impl Into<String>,
316        kind: NodeKind,
317    ) -> Result<&mut Self, BuildError> {
318        let name = name.into();
319        if self.nodes.contains_key(&name) {
320            return Err(BuildError::DuplicateNode { id: name });
321        }
322        self.nodes.insert(name, kind);
323        Ok(self)
324    }
325
326    /// 添加边(无条件,无 policy)。
327    pub fn edge(
328        &mut self,
329        from: impl Into<String>,
330        to: impl Into<String>,
331    ) -> Result<&mut Self, BuildError> {
332        self.edges.push(Edge {
333            from: from.into(),
334            to: to.into(),
335            condition: None,
336            analysis: None,
337            policy: None,
338            fallback: false,
339        });
340        Ok(self)
341    }
342
343    /// 添加条件边。
344    pub fn edge_if(
345        &mut self,
346        from: impl Into<String>,
347        to: impl Into<String>,
348        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
349    ) -> Result<&mut Self, BuildError> {
350        self.edges.push(Edge {
351            from: from.into(),
352            to: to.into(),
353            condition: Some(Arc::new(condition)),
354            analysis: None,
355            policy: None,
356            fallback: false,
357        });
358        Ok(self)
359    }
360
361    /// 添加 fallback 边(无条件兜底)。
362    pub fn edge_fallback(
363        &mut self,
364        from: impl Into<String>,
365        to: impl Into<String>,
366    ) -> Result<&mut Self, BuildError> {
367        self.edges.push(Edge {
368            from: from.into(),
369            to: to.into(),
370            condition: None,
371            analysis: None,
372            policy: None,
373            fallback: true,
374        });
375        Ok(self)
376    }
377
378    /// 添加带 analysis 约束的边(仅静态分析用,不参与 runtime)。
379    pub fn edge_analysis(
380        &mut self,
381        from: impl Into<String>,
382        to: impl Into<String>,
383        max_visits: usize,
384    ) -> Result<&mut Self, BuildError> {
385        self.edges.push(Edge {
386            from: from.into(),
387            to: to.into(),
388            condition: None,
389            analysis: Some(EdgeAnalysis {
390                max_visits: Some(max_visits),
391            }),
392            policy: None,
393            fallback: false,
394        });
395        Ok(self)
396    }
397
398    /// 添加带 runtime policy 的边(显式拦截)。
399    pub fn edge_policy(
400        &mut self,
401        from: impl Into<String>,
402        to: impl Into<String>,
403        policy: EdgePolicy,
404    ) -> Result<&mut Self, BuildError> {
405        self.edges.push(Edge {
406            from: from.into(),
407            to: to.into(),
408            condition: None,
409            analysis: None,
410            policy: Some(policy),
411            fallback: false,
412        });
413        Ok(self)
414    }
415
416    /// 构建 Graph。返回 `Result<Graph, BuildError>`。
417    pub fn build(self) -> Result<Graph, BuildError> {
418        let start = self.start.ok_or(BuildError::MissingEntryPoint)?;
419        let end = self.end.ok_or(BuildError::MissingExitPoint)?;
420
421        let graph = Graph {
422            nodes: self.nodes,
423            edges: self.edges,
424            start,
425            end,
426        };
427
428        // 结构验证
429        for edge in &graph.edges {
430            if !graph.nodes.contains_key(&edge.from) {
431                return Err(BuildError::MissingNode {
432                    from: edge.from.clone(),
433                    to: edge.from.clone(),
434                });
435            }
436            if !graph.nodes.contains_key(&edge.to) {
437                return Err(BuildError::MissingNode {
438                    from: edge.from.clone(),
439                    to: edge.to.clone(),
440                });
441            }
442        }
443
444        graph.validate().map_err(|e| match e {
445            crate::error::TerminalError::InvalidGraph(msg) => {
446                BuildError::InvalidEdgeDefinition {
447                    from: "unknown".into(),
448                    to: "unknown".into(),
449                    reason: msg,
450                }
451            }
452            _ => BuildError::InvalidEdgeDefinition {
453                from: "unknown".into(),
454                to: "unknown".into(),
455                reason: "validation failed".into(),
456            },
457        })?;
458
459        Ok(graph)
460    }
461
462    pub fn name(&self) -> &str {
463        &self.name
464    }
465}