Skip to main content

lellm_graph/
graph.rs

1//! Graph 和 GraphBuilder。
2//!
3//! Edge 三类边模型:
4//! - **条件边** (`edge_if`) — `if/else-if` 规则链,按注册顺序求值,first match wins
5//! - **普通边** (`edge`) — 无条件非 fallback,条件边无命中时生效
6//! - **Fallback 边** (`edge_fallback`) — 最后兜底
7//!
8//! 运行时安全由 `GraphExecutor::max_steps` 统一负责。
9
10use std::sync::Arc;
11
12use indexmap::IndexMap;
13
14use crate::error::{BuildError, BuildErrors, DiagnosticCategory, GraphDiagnostics};
15use crate::node::NodeKind;
16use crate::state::State;
17
18// ─── Edge ──────────────────────────────────────────────────────
19
20/// 边条件回调类型别名。
21/// Arc 包装以支持 Graph Clone(条件回调不可 Clone)。
22pub type EdgeCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
23
24/// 边(Edge)— 三类边模型。
25///
26/// 一个节点的出边分为三类,按固定顺序求值:
27/// 1. **条件边** — `condition` 非 None,`fallback` = false。按注册顺序求值,first match wins。
28/// 2. **普通边** — `condition` = None,`fallback` = false。条件边无命中时生效。
29/// 3. **Fallback 边** — `fallback` = true。最后兜底。
30#[derive(Clone)]
31pub struct Edge {
32    pub from: String,
33    pub to: String,
34    /// 路由条件。Some = 条件边;None = 普通边或 fallback 边。
35    pub condition: Option<EdgeCondition>,
36    /// 分析用约束(不参与 runtime 决策)
37    pub analysis: Option<EdgeAnalysis>,
38    /// 是否为 fallback 边(最后兜底)
39    pub fallback: bool,
40}
41
42impl Edge {
43    /// 判断是否为条件边。
44    pub fn is_conditional(&self) -> bool {
45        self.condition.is_some() && !self.fallback
46    }
47
48    /// 判断是否为普通边(无条件非 fallback)。
49    pub fn is_normal(&self) -> bool {
50        self.condition.is_none() && !self.fallback
51    }
52}
53
54impl std::fmt::Debug for Edge {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("Edge")
57            .field("from", &self.from)
58            .field("to", &self.to)
59            .field("has_condition", &self.condition.is_some())
60            .field("analysis", &self.analysis)
61            .field("fallback", &self.fallback)
62            .finish()
63    }
64}
65
/// Analysis-only constraints attached to an edge, consumed by `analyze_cycles()`.
///
/// They do not control execution. Runtime safety is the responsibility of
/// `GraphExecutor::max_steps`.
#[derive(Debug, Clone)]
pub struct EdgeAnalysis {
    /// Suggested maximum number of visits; used by the cycle-analysis diagnostics.
    pub max_visits: Option<usize>,
}
74
75// ─── Graph ─────────────────────────────────────────────────────
76
77/// 图(Graph)— 允许有环,循环保护由 GraphExecutor::max_steps 运行时熔断提供。
78#[derive(Clone)]
79pub struct Graph {
80    pub(crate) name: String,
81    pub(crate) nodes: IndexMap<String, NodeKind>,
82    pub(crate) edges: Vec<Edge>,
83    pub(crate) start: String,
84    pub(crate) end: String,
85}
86
87impl Graph {
88    pub fn name(&self) -> &str {
89        &self.name
90    }
91
92    pub fn node_names(&self) -> Vec<&str> {
93        self.nodes.keys().map(|s| s.as_str()).collect()
94    }
95
96    pub fn start_node(&self) -> &str {
97        &self.start
98    }
99
100    pub fn end_node(&self) -> &str {
101        &self.end
102    }
103
104    /// 计算图结构指纹 hash。
105    ///
106    /// 用于 Checkpoint 恢复时校验图结构是否变更。
107    /// 基于节点名和边定义生成简化的 hash 字符串。
108    pub fn hash(&self) -> String {
109        let mut s = String::new();
110        // 排序节点名,确保确定性
111        let mut names: Vec<&str> = self.nodes.keys().map(|k| k.as_str()).collect();
112        names.sort();
113        s.push_str(&names.join(","));
114        s.push('|');
115        // 排序边,确保确定性
116        let mut edge_strs: Vec<String> = self
117            .edges
118            .iter()
119            .map(|e| {
120                format!(
121                    "{}->{}{:?}{}",
122                    e.from,
123                    e.to,
124                    if e.condition.is_some() { "?" } else { "" },
125                    if e.fallback { "!" } else { "" }
126                )
127            })
128            .collect();
129        edge_strs.sort();
130        s.push_str(&edge_strs.join(","));
131        // Simple hash — FNV-1a
132        let hash = fnv_hash(&s);
133        format!("{:016x}", hash)
134    }
135
136    pub fn edges_from(&self, from: &str) -> Vec<&Edge> {
137        self.edges.iter().filter(|e| e.from == from).collect()
138    }
139
140    pub fn find_edge(&self, from: &str, to: &str) -> Option<&Edge> {
141        self.edges.iter().find(|e| e.from == from && e.to == to)
142    }
143
144    /// 查找指定节点的 fallback 边目标。
145    ///
146    /// 用于 Fallback 控制流:节点主动声明降级后,Executor 查找 fallback 边路由。
147    pub fn find_fallback_edge(&self, from: &str) -> Option<String> {
148        self.edges
149            .iter()
150            .find(|e| e.from == from && e.fallback)
151            .map(|e| e.to.clone())
152    }
153
154    /// 验证图结构(节点、边引用有效性)。
155    ///
156    /// 注意:不检测环 — 有环图是合法的,循环保护由 GraphExecutor::max_steps 提供。
157    pub fn validate(&self) -> Result<(), crate::error::TerminalError> {
158        if !self.nodes.contains_key(&self.start) {
159            return Err(crate::error::TerminalError::InvalidGraph(format!(
160                "start node '{}' not found",
161                self.start
162            )));
163        }
164
165        if !self.nodes.contains_key(&self.end) {
166            return Err(crate::error::TerminalError::InvalidGraph(format!(
167                "end node '{}' not found",
168                self.end
169            )));
170        }
171
172        for edge in &self.edges {
173            if !self.nodes.contains_key(&edge.from) {
174                return Err(crate::error::TerminalError::InvalidGraph(format!(
175                    "edge references non-existent source node '{}'",
176                    edge.from
177                )));
178            }
179            if !self.nodes.contains_key(&edge.to) {
180                return Err(crate::error::TerminalError::InvalidGraph(format!(
181                    "edge references non-existent target node '{}'",
182                    edge.to
183                )));
184            }
185        }
186
187        Ok(())
188    }
189
190    /// 完整图诊断分析。
191    ///
192    /// 检查以下维度并返回 `GraphDiagnostics`:
193    /// 1. **环检测** — 图中存在循环路径(Warning)
194    /// 2. **Fallback 参与循环** — fallback 边在环内(Warning)
195    /// 3. **不可达路径** — 从 start 无法到达的节点(Info)
196    /// 4. **End 节点出边** — end 节点定义了出边(Info)
197    ///
198    /// 与 `build()` 的关系:`build()` 只检查结构正确性;`analyze()` 检查风险性。
199    pub fn analyze(&self) -> GraphDiagnostics {
200        let mut diag = GraphDiagnostics::new();
201
202        // 1. 构建邻接表(复用)
203        let adj = self.build_adj();
204
205        // 2. 环检测
206        let cycles = self.find_all_cycles(&adj);
207        if !cycles.is_empty() {
208            let unprotected = self.filter_unprotected_cycles(&cycles);
209            for cycle in &unprotected {
210                let cycle_str = format_cycle(cycle);
211                diag.add_warning(
212                    DiagnosticCategory::Cycle,
213                    format!("cycle detected: {} → {}", cycle_str, cycle[0]),
214                );
215            }
216            // 受保护的环仅提示
217            for cycle in &cycles {
218                if !unprotected.contains(cycle) {
219                    let cycle_str = format_cycle(cycle);
220                    diag.add_info(
221                        DiagnosticCategory::Cycle,
222                        format!(
223                            "protected cycle: {} → {} (has max_visits)",
224                            cycle_str, cycle[0]
225                        ),
226                    );
227                }
228            }
229        }
230
231        // 3. Fallback 参与循环
232        check_fallback_in_cycles(self, &cycles, &mut diag);
233
234        // 4. 不可达路径(BFS 从 start 出发)
235        check_unreachable_nodes(self, &adj, &mut diag);
236
237        // 5. End 节点出边
238        check_end_node_outgoing(self, &mut diag);
239
240        diag
241    }
242
243    // ─── 内部辅助方法 ───────────────────────────────────────
244
245    /// 构建邻接表。
246    fn build_adj(&self) -> std::collections::HashMap<String, Vec<String>> {
247        let mut adj: std::collections::HashMap<String, Vec<String>> =
248            std::collections::HashMap::new();
249        for edge in &self.edges {
250            adj.entry(edge.from.clone())
251                .or_default()
252                .push(edge.to.clone());
253        }
254        adj
255    }
256
257    /// 查找所有环。
258    fn find_all_cycles(
259        &self,
260        adj: &std::collections::HashMap<String, Vec<String>>,
261    ) -> Vec<Vec<String>> {
262        let mut cycles = Vec::new();
263        for node in self.nodes.keys() {
264            let mut in_path = std::collections::HashSet::new();
265            let mut path = Vec::new();
266            self.dfs_cycles(node, node, adj, &mut in_path, &mut path, &mut cycles);
267        }
268        cycles
269    }
270
271    /// DFS 环检测。
272    fn dfs_cycles(
273        &self,
274        start: &str,
275        current: &str,
276        adj: &std::collections::HashMap<String, Vec<String>>,
277        in_path: &mut std::collections::HashSet<String>,
278        path: &mut Vec<String>,
279        cycles: &mut Vec<Vec<String>>,
280    ) {
281        if in_path.contains(current) {
282            return;
283        }
284
285        path.push(current.to_string());
286        in_path.insert(current.to_string());
287
288        if let Some(neighbors) = adj.get(current) {
289            for neighbor in neighbors {
290                if neighbor.as_str() == start && path.len() >= 2 {
291                    cycles.push(path.clone());
292                } else if neighbor.as_str() > start && !in_path.contains(neighbor) {
293                    self.dfs_cycles(start, neighbor, adj, in_path, path, cycles);
294                }
295            }
296        }
297
298        path.pop();
299        in_path.remove(current);
300    }
301
302    /// 过滤未受保护的环。
303    fn filter_unprotected_cycles(&self, cycles: &[Vec<String>]) -> Vec<Vec<String>> {
304        let mut unprotected: Vec<Vec<String>> = cycles
305            .iter()
306            .filter(|cycle| {
307                let has_protection = (0..cycle.len()).any(|i| {
308                    let next = (i + 1) % cycle.len();
309                    let from = cycle[i].as_str();
310                    let to = cycle[next].as_str();
311                    self.edges.iter().any(|e| {
312                        e.from == from
313                            && e.to == to
314                            && e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some())
315                    })
316                });
317                !has_protection
318            })
319            .cloned()
320            .collect();
321        unprotected.sort();
322        unprotected.dedup();
323        unprotected
324    }
325
326    // ─── 兼容方法 ─────────────────────────────────────────────
327
328    /// 分析图中所有环,生成诊断信息。
329    ///
330    /// @deprecated 使用 [`analyze()`](Self::analyze) 替代。
331    pub fn analyze_cycles(&self) -> CycleAnalysis {
332        let adj = self.build_adj();
333        let cycles = self.find_all_cycles(&adj);
334        let unprotected = self.filter_unprotected_cycles(&cycles);
335
336        CycleAnalysis {
337            has_cycles: !cycles.is_empty(),
338            cycles,
339            unprotected_cycles: unprotected,
340            total_edges: self.edges.len(),
341            protected_edges: self
342                .edges
343                .iter()
344                .filter(|e| e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some()))
345                .count(),
346        }
347    }
348}
349
/// Result of the cycle analysis.
#[derive(Debug, Clone)]
pub struct CycleAnalysis {
    /// Whether at least one cycle was found.
    pub has_cycles: bool,
    /// Every cycle found, each as the list of node names along it.
    pub cycles: Vec<Vec<String>>,
    /// The cycles that are not protected by a `max_visits` hint.
    pub unprotected_cycles: Vec<Vec<String>>,
    /// Total number of edges in the graph.
    pub total_edges: usize,
    /// Number of edges that have `max_visits` set.
    pub protected_edges: usize,
}
359
360impl CycleAnalysis {
361    pub fn all_protected(&self) -> bool {
362        self.unprotected_cycles.is_empty()
363    }
364
365    pub fn report(&self) -> String {
366        let mut lines = Vec::new();
367        lines.push("=== Graph Cycle Analysis ===".to_string());
368
369        if !self.has_cycles {
370            lines.push("No cycles detected — graph is a DAG.".to_string());
371            return lines.join("\n");
372        }
373
374        lines.push(format!("Found {} cycle(s).", self.cycles.len()));
375        lines.push(format!(
376            "Edge protection: {}/{} edges have analysis set.",
377            self.protected_edges, self.total_edges
378        ));
379
380        for (i, cycle) in self.cycles.iter().enumerate() {
381            let cycle_str = cycle.join(" → ");
382            lines.push(format!("  Cycle {}: {} → {}", i + 1, cycle_str, cycle[0]));
383
384            if self.unprotected_cycles.contains(cycle) {
385                lines.push("    ⚠️ UNPROTECTED — no max_visits on back-edge".into());
386            } else {
387                lines.push("    ✅ Protected by edge-level analysis".into());
388            }
389        }
390
391        if !self.all_protected() {
392            lines.push("".into());
393            lines.push("⚠️ Recommendation: Set analysis.max_visits on back-edges.".to_string());
394        }
395
396        lines.join("\n")
397    }
398}
399
400// ─── PendingEdge ──────────────────────────────────────────────
401
402/// 待完成的边 — 链式调用的中间句柄。
403///
404/// 由 `GraphBuilder::edge()` / `edge_if()` / `edge_fallback()` 返回。
405/// 通过 `.max_visits(n)` 附加循环分析约束。
406///
407/// ```rust,ignore
408/// // 条件回跳 + 循环分析
409/// g.edge_if("b", "a", |s| s.should_retry)?.max_visits(5);
410///
411/// // 普通边 + 循环分析
412/// g.edge("b", "a").max_visits(5);
413///
414/// // 不加分析(直接丢弃 PendingEdge)
415/// g.edge("b", "end");
416/// ```
417pub struct PendingEdge<'a> {
418    builder: &'a mut GraphBuilder,
419    edge_index: usize,
420}
421
422impl<'a> PendingEdge<'a> {
423    /// 附加循环分析约束(建议的最大访问次数)。
424    ///
425    /// 仅用于 `analyze_cycles()` 静态诊断,不参与运行时路由。
426    /// 返回 `&mut GraphBuilder` 以便继续链式调用。
427    pub fn max_visits(self, n: usize) -> &'a mut GraphBuilder {
428        self.builder.edges[self.edge_index].analysis = Some(EdgeAnalysis {
429            max_visits: Some(n),
430        });
431        self.builder
432    }
433}
434
435// ─── GraphBuilder ─────────────────────────────────────────────
436
437/// Graph 构建器。
438pub struct GraphBuilder {
439    name: String,
440    nodes: IndexMap<String, NodeKind>,
441    edges: Vec<Edge>,
442    start: Option<String>,
443    end: Option<String>,
444}
445
446impl GraphBuilder {
447    pub fn new(name: impl Into<String>) -> Self {
448        Self {
449            name: name.into(),
450            nodes: IndexMap::new(),
451            edges: Vec::new(),
452            start: None,
453            end: None,
454        }
455    }
456
457    pub fn start(&mut self, node: impl Into<String>) -> &mut Self {
458        self.start = Some(node.into());
459        self
460    }
461
462    pub fn end(&mut self, node: impl Into<String>) -> &mut Self {
463        self.end = Some(node.into());
464        self
465    }
466
467    pub fn node(&mut self, name: impl Into<String>, kind: NodeKind) -> &mut Self {
468        self.nodes.insert(name.into(), kind);
469        self
470    }
471
472    /// 添加边(无条件普通边)。
473    ///
474    /// 返回 [`PendingEdge`],可通过 `.max_visits(n)` 附加循环分析约束。
475    ///
476    /// ```rust,ignore
477    /// g.edge("a", "b");                    // 普通边
478    /// g.edge("b", "a").max_visits(5);      // 普通边 + 循环分析
479    /// ```
480    pub fn edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> PendingEdge<'_> {
481        let edge_index = self.edges.len();
482        self.edges.push(Edge {
483            from: from.into(),
484            to: to.into(),
485            condition: None,
486            analysis: None,
487            fallback: false,
488        });
489        PendingEdge {
490            builder: self,
491            edge_index,
492        }
493    }
494
495    /// 添加条件边(`if/else-if` 规则链)。
496    ///
497    /// 返回 [`PendingEdge`],可通过 `.max_visits(n)` 附加循环分析约束。
498    ///
499    /// ```rust,ignore
500    /// g.edge_if("agent", "retry", |s| s.has_tool_calls()).max_visits(10);
501    /// g.edge_if("agent", "end", |_| true);
502    /// ```
503    pub fn edge_if(
504        &mut self,
505        from: impl Into<String>,
506        to: impl Into<String>,
507        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
508    ) -> PendingEdge<'_> {
509        let edge_index = self.edges.len();
510        self.edges.push(Edge {
511            from: from.into(),
512            to: to.into(),
513            condition: Some(Arc::new(condition)),
514            analysis: None,
515            fallback: false,
516        });
517        PendingEdge {
518            builder: self,
519            edge_index,
520        }
521    }
522
523    /// 添加 fallback 边(无条件兜底)。
524    ///
525    /// 返回 [`PendingEdge`],可通过 `.max_visits(n)` 附加循环分析约束。
526    pub fn edge_fallback(
527        &mut self,
528        from: impl Into<String>,
529        to: impl Into<String>,
530    ) -> PendingEdge<'_> {
531        let edge_index = self.edges.len();
532        self.edges.push(Edge {
533            from: from.into(),
534            to: to.into(),
535            condition: None,
536            analysis: None,
537            fallback: true,
538        });
539        PendingEdge {
540            builder: self,
541            edge_index,
542        }
543    }
544
545    /// 构建 Graph。
546    ///
547    /// 收集所有错误后统一报告。`Warning` 变体不阻止 build 成功。
548    ///
549    /// ```rust,ignore
550    /// match builder.build() {
551    ///     Ok(graph) => { /* 使用 graph */ }
552    ///     Err(errors) => {
553    ///         for e in &errors.0 {
554    ///             eprintln!("{}", e);
555    ///         }
556    ///     }
557    /// }
558    /// ```
559    pub fn build(self) -> Result<Graph, BuildErrors> {
560        let mut errors = BuildErrors::new();
561
562        // 1. 检查入口/出口
563        let start = match self.start {
564            Some(s) => s,
565            None => {
566                errors.push(BuildError::MissingEntryPoint);
567                // 无法继续验证,提前返回
568                return Err(errors);
569            }
570        };
571        let end = match self.end {
572            Some(s) => s,
573            None => {
574                errors.push(BuildError::MissingExitPoint);
575                return Err(errors);
576            }
577        };
578
579        // 2. 检测重复节点名
580        let mut seen_nodes = std::collections::HashSet::new();
581        for name in self.nodes.keys() {
582            if !seen_nodes.insert(name.clone()) {
583                errors.push(BuildError::DuplicateNode { id: name.clone() });
584            }
585        }
586
587        // 3. 检查边引用的节点是否存在
588        for edge in &self.edges {
589            if !self.nodes.contains_key(&edge.from) {
590                errors.push(BuildError::MissingNode {
591                    from: edge.from.clone(),
592                    to: edge.to.clone(),
593                });
594            }
595            if !self.nodes.contains_key(&edge.to) {
596                errors.push(BuildError::MissingNode {
597                    from: edge.from.clone(),
598                    to: edge.to.clone(),
599                });
600            }
601        }
602
603        // 4. 有错误则返回(build() 是纯函数,不产生 Warning)
604        if !errors.is_empty() {
605            return Err(errors);
606        }
607
608        // 5. 构建 Graph
609        let graph = Graph {
610            name: self.name,
611            nodes: self.nodes,
612            edges: self.edges,
613            start,
614            end,
615        };
616
617        // 6. 结构验证(validate 检查 start/end 节点存在性等)
618        if let Err(e) = graph.validate() {
619            return Err(BuildErrors(vec![BuildError::InvalidEdgeDefinition {
620                from: "graph".to_string(),
621                to: "graph".to_string(),
622                reason: e.to_string(),
623            }]));
624        }
625
626        Ok(graph)
627    }
628
629    pub fn name(&self) -> &str {
630        &self.name
631    }
632}
633
634// ─── 诊断辅助函数 ───────────────────────────────────────────────
635
/// Renders a cycle path as text, e.g. `"a → b → c"`.
fn format_cycle(cycle: &[String]) -> String {
    const ARROW: &str = " → ";
    cycle.join(ARROW)
}
640
641/// 检查 fallback 边是否参与循环。
642fn check_fallback_in_cycles(graph: &Graph, cycles: &[Vec<String>], diag: &mut GraphDiagnostics) {
643    // 收集所有 fallback 边的 (from, to)
644    let fallback_edges: std::collections::HashSet<(&str, &str)> = graph
645        .edges
646        .iter()
647        .filter(|e| e.fallback)
648        .map(|e| (e.from.as_str(), e.to.as_str()))
649        .collect();
650
651    if fallback_edges.is_empty() {
652        return;
653    }
654
655    // 检查每个环是否包含 fallback 边
656    for cycle in cycles {
657        for i in 0..cycle.len() {
658            let next = (i + 1) % cycle.len();
659            let from = cycle[i].as_str();
660            let to = cycle[next].as_str();
661            if fallback_edges.contains(&(from, to)) {
662                let edge_str = format!("{} → {}", from, to);
663                diag.add_warning(
664                    DiagnosticCategory::FallbackInCycle,
665                    format!(
666                        "fallback edge {} participates in cycle: {} → {}",
667                        edge_str,
668                        format_cycle(cycle),
669                        cycle[0]
670                    ),
671                );
672            }
673        }
674    }
675}
676
677/// 检查从 start 节点不可达的节点。
678fn check_unreachable_nodes(
679    graph: &Graph,
680    adj: &std::collections::HashMap<String, Vec<String>>,
681    diag: &mut GraphDiagnostics,
682) {
683    // BFS 从 start 出发
684    let mut visited = std::collections::HashSet::new();
685    let mut queue = Vec::new();
686
687    queue.push(graph.start.clone());
688    visited.insert(graph.start.clone());
689
690    while let Some(node) = queue.pop() {
691        if let Some(neighbors) = adj.get(&node) {
692            for neighbor in neighbors {
693                if visited.insert(neighbor.clone()) {
694                    queue.push(neighbor.clone());
695                }
696            }
697        }
698    }
699
700    // 找出未访问的节点
701    for name in graph.nodes.keys() {
702        if !visited.contains(name) {
703            diag.add_info(
704                DiagnosticCategory::Unreachable,
705                format!(
706                    "node '{}' is not reachable from start node '{}'",
707                    name, graph.start
708                ),
709            );
710        }
711    }
712}
713
714/// 检查 end 节点是否有出边。
715fn check_end_node_outgoing(graph: &Graph, diag: &mut GraphDiagnostics) {
716    let outgoing: Vec<&Edge> = graph.edges.iter().filter(|e| e.from == graph.end).collect();
717
718    if !outgoing.is_empty() {
719        let targets: Vec<&str> = outgoing.iter().map(|e| e.to.as_str()).collect();
720        diag.add_info(
721            DiagnosticCategory::EndNodeOutgoing,
722            format!(
723                "end node '{}' has {} outgoing edge(s) to: {:?}",
724                graph.end,
725                outgoing.len(),
726                targets
727            ),
728        );
729    }
730}
731
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s` — a simple hash with no external dependency.
fn fnv_hash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    // FNV-1a: xor the byte in first, then multiply by the prime (wrapping on overflow).
    s.bytes().fold(OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}