Skip to main content

lellm_graph/
graph.rs

1//! Graph 和 GraphBuilder。
2//!
3//! Edge 三类边模型:
4//! - **条件边** (`edge_if`) — `if/else-if` 规则链,按注册顺序求值,first match wins
5//! - **普通边** (`edge`) — 无条件非 fallback,条件边无命中时生效
6//! - **Fallback 边** (`edge_fallback`) — 最后兜底
7//!
8//! v0.4+: 泛型化 `Graph<S: WorkflowState>`,默认 `S = State`(向后兼容)。
9//!
10//! 运行时安全由 `GraphExecutor::max_steps` 统一负责。
11
12use std::sync::Arc;
13
14use indexmap::IndexMap;
15
16use crate::error::{BuildError, BuildErrors, DiagnosticCategory, GraphDiagnostics};
17use crate::node::{FlowNode, NodeKind};
18use crate::node_context::NodeContext;
19use crate::state::{State, StateMerge};
20use crate::workflow_state::{MergeStrategy, WorkflowState};
21
22// ─── Edge ──────────────────────────────────────────────────────
23
24/// 边条件回调类型别名。
25pub type EdgeCondition<S> = Arc<dyn Fn(&S) -> bool + Send + Sync>;
26
27/// 边(Edge)— 三类边模型。
28#[derive(Clone)]
29pub struct Edge<S: WorkflowState = State> {
30    pub from: String,
31    pub to: String,
32    /// 路由条件。Some = 条件边;None = 普通边或 fallback 边。
33    pub condition: Option<EdgeCondition<S>>,
34    /// 分析用约束(不参与 runtime 决策)
35    pub analysis: Option<EdgeAnalysis>,
36    /// 是否为 fallback 边(最后兜底)
37    pub fallback: bool,
38}
39
40impl<S: WorkflowState> Edge<S> {
41    /// 判断是否为条件边。
42    pub fn is_conditional(&self) -> bool {
43        self.condition.is_some() && !self.fallback
44    }
45
46    /// 判断是否为普通边(无条件非 fallback)。
47    pub fn is_normal(&self) -> bool {
48        self.condition.is_none() && !self.fallback
49    }
50}
51
52impl<S: WorkflowState> std::fmt::Debug for Edge<S> {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("Edge")
55            .field("from", &self.from)
56            .field("to", &self.to)
57            .field("has_condition", &self.condition.is_some())
58            .field("analysis", &self.analysis)
59            .field("fallback", &self.fallback)
60            .finish()
61    }
62}
63
/// Analysis-only constraint attached to an edge.
///
/// It is read by the static cycle analysis (`analyze()` / `analyze_cycles()`) and never by
/// runtime routing.
#[derive(Debug, Clone)]
pub struct EdgeAnalysis {
    /// Suggested maximum number of visits. A cycle with at least one edge carrying
    /// `Some(_)` here is treated as "protected" by the cycle analysis.
    pub max_visits: Option<usize>,
}
70
71// ─── Graph ─────────────────────────────────────────────────────
72
73/// 图(Graph)— 允许有环,循环保护由 GraphExecutor::max_steps 运行时熔断提供。
74#[derive(Clone)]
75pub struct Graph<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
76    pub(crate) name: String,
77    pub(crate) nodes: IndexMap<String, NodeKind<S, M>>,
78    pub(crate) edges: Vec<Edge<S>>,
79    pub(crate) start: String,
80    pub(crate) end: String,
81}
82
83impl<S: WorkflowState, M: MergeStrategy<S>> Graph<S, M> {
84    pub fn name(&self) -> &str {
85        &self.name
86    }
87
88    pub fn node_names(&self) -> Vec<&str> {
89        self.nodes.keys().map(|s| s.as_str()).collect()
90    }
91
92    pub fn start_node(&self) -> &str {
93        &self.start
94    }
95
96    pub fn end_node(&self) -> &str {
97        &self.end
98    }
99
100    /// 计算图结构指纹 hash。
101    pub fn hash(&self) -> String {
102        let mut s = String::new();
103        let mut names: Vec<&str> = self.nodes.keys().map(|k| k.as_str()).collect();
104        names.sort();
105        s.push_str(&names.join(","));
106        s.push('|');
107        let mut edge_strs: Vec<String> = self
108            .edges
109            .iter()
110            .map(|e| {
111                format!(
112                    "{}->{}{:?}{}",
113                    e.from,
114                    e.to,
115                    if e.condition.is_some() { "?" } else { "" },
116                    if e.fallback { "!" } else { "" }
117                )
118            })
119            .collect();
120        edge_strs.sort();
121        s.push_str(&edge_strs.join(","));
122        let hash = fnv_hash(&s);
123        format!("{:016x}", hash)
124    }
125
126    pub fn edges_from(&self, from: &str) -> Vec<&Edge<S>> {
127        self.edges.iter().filter(|e| e.from == from).collect()
128    }
129
130    pub fn find_edge(&self, from: &str, to: &str) -> Option<&Edge<S>> {
131        self.edges.iter().find(|e| e.from == from && e.to == to)
132    }
133
134    /// 获取节点映射表引用。
135    pub fn node_map(&self) -> &IndexMap<String, NodeKind<S, M>> {
136        &self.nodes
137    }
138
139    /// 路由解析 — 根据当前节点和 State 找到下一个节点。
140    pub fn resolve_next(&self, current: &str, state: &S) -> Option<String> {
141        let edges = self.edges_from(current);
142
143        // 1. 条件边
144        for edge in &edges {
145            if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
146                return Some(edge.to.clone());
147            }
148        }
149
150        // 2. 普通边
151        for edge in &edges {
152            if edge.is_normal() {
153                return Some(edge.to.clone());
154            }
155        }
156
157        // 3. Fallback 边
158        for edge in &edges {
159            if edge.fallback {
160                return Some(edge.to.clone());
161            }
162        }
163
164        None
165    }
166
167    pub fn find_fallback_edge(&self, from: &str) -> Option<String> {
168        self.edges
169            .iter()
170            .find(|e| e.from == from && e.fallback)
171            .map(|e| e.to.clone())
172    }
173
174    /// 验证图结构。
175    pub fn validate(&self) -> Result<(), crate::error::TerminalError> {
176        if !self.nodes.contains_key(&self.start) {
177            return Err(crate::error::TerminalError::InvalidGraph(format!(
178                "start node '{}' not found",
179                self.start
180            )));
181        }
182
183        if !self.nodes.contains_key(&self.end) {
184            return Err(crate::error::TerminalError::InvalidGraph(format!(
185                "end node '{}' not found",
186                self.end
187            )));
188        }
189
190        for edge in &self.edges {
191            if !self.nodes.contains_key(&edge.from) {
192                return Err(crate::error::TerminalError::InvalidGraph(format!(
193                    "edge references non-existent source node '{}'",
194                    edge.from
195                )));
196            }
197            if !self.nodes.contains_key(&edge.to) {
198                return Err(crate::error::TerminalError::InvalidGraph(format!(
199                    "edge references non-existent target node '{}'",
200                    edge.to
201                )));
202            }
203        }
204
205        Ok(())
206    }
207
208    /// 完整图诊断分析。
209    pub fn analyze(&self) -> GraphDiagnostics {
210        let mut diag = GraphDiagnostics::new();
211        let adj = self.build_adj();
212
213        let cycles = self.find_all_cycles(&adj);
214        if !cycles.is_empty() {
215            let unprotected = self.filter_unprotected_cycles(&cycles);
216            for cycle in &unprotected {
217                let cycle_str = format_cycle(cycle);
218                diag.add_warning(
219                    DiagnosticCategory::Cycle,
220                    format!("cycle detected: {} → {}", cycle_str, cycle[0]),
221                );
222            }
223            for cycle in &cycles {
224                if !unprotected.contains(cycle) {
225                    let cycle_str = format_cycle(cycle);
226                    diag.add_info(
227                        DiagnosticCategory::Cycle,
228                        format!(
229                            "protected cycle: {} → {} (has max_visits)",
230                            cycle_str, cycle[0]
231                        ),
232                    );
233                }
234            }
235        }
236
237        check_fallback_in_cycles(self, &cycles, &mut diag);
238        check_unreachable_nodes(self, &adj, &mut diag);
239        check_end_node_outgoing(self, &mut diag);
240
241        diag
242    }
243
244    fn build_adj(&self) -> std::collections::HashMap<String, Vec<String>> {
245        let mut adj: std::collections::HashMap<String, Vec<String>> =
246            std::collections::HashMap::new();
247        for edge in &self.edges {
248            adj.entry(edge.from.clone())
249                .or_default()
250                .push(edge.to.clone());
251        }
252        adj
253    }
254
255    // ─── 内联执行 ────────────────────────────────────────────
256
257    /// 内联执行 — 不产生 RuntimeEvent,不 Checkpoint。
258    pub async fn run_inline(
259        &self,
260        ctx: &mut NodeContext<'_, S>,
261        max_steps: usize,
262    ) -> Result<(), crate::error::GraphError> {
263        use crate::node_context::NextAction;
264
265        let mut current = self.start_node().to_string();
266        let mut step: usize = 0;
267
268        loop {
269            step += 1;
270            if step > max_steps {
271                return Err(crate::error::GraphError::Terminal(
272                    crate::error::TerminalError::StepsExceeded { limit: max_steps },
273                ));
274            }
275
276            let node = self.nodes.get(&current).ok_or_else(|| {
277                crate::error::GraphError::Terminal(crate::error::TerminalError::NodeNotFound(
278                    current.clone(),
279                ))
280            })?;
281
282            // 执行节点
283            node.execute(ctx).await?;
284
285            // 消费 Effects → apply 到 typed state(零序列化)
286            let effects = ctx.consume_effects();
287            ctx.state_mut().apply_batch(effects);
288
289            // 提取控制信号
290            let (next_action, _signal) = ctx.take_control();
291
292            // 处理路由
293            match next_action {
294                NextAction::End => return Ok(()),
295                NextAction::Goto(target) => {
296                    current = target;
297                }
298                NextAction::Next => {
299                    if current == self.end_node() {
300                        return Ok(());
301                    }
302                    current = self.resolve_next_inline(&current, ctx.state())?;
303                }
304            }
305        }
306    }
307
308    /// 内联路由解析。
309    fn resolve_next_inline(
310        &self,
311        current: &str,
312        state: &S,
313    ) -> Result<String, crate::error::GraphError> {
314        let edges = self.edges_from(current);
315
316        if edges.is_empty() {
317            return Err(crate::error::GraphError::Terminal(
318                crate::error::TerminalError::InvalidGraph(format!(
319                    "node '{}' has no outgoing edges and is not the end node",
320                    current
321                )),
322            ));
323        }
324
325        // 1. 条件边
326        for edge in &edges {
327            if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
328                return Ok(edge.to.clone());
329            }
330        }
331
332        // 2. 普通边
333        for edge in &edges {
334            if edge.is_normal() {
335                return Ok(edge.to.clone());
336            }
337        }
338
339        // 3. Fallback 边
340        for edge in &edges {
341            if edge.fallback {
342                return Ok(edge.to.clone());
343            }
344        }
345
346        Err(crate::error::GraphError::Terminal(
347            crate::error::TerminalError::InvalidGraph(format!(
348                "node '{}' has no matching outgoing edge",
349                current
350            )),
351        ))
352    }
353
354    /// 查找所有环。
355    fn find_all_cycles(
356        &self,
357        adj: &std::collections::HashMap<String, Vec<String>>,
358    ) -> Vec<Vec<String>> {
359        let mut cycles = Vec::new();
360        for node in self.nodes.keys() {
361            let mut in_path = std::collections::HashSet::new();
362            let mut path = Vec::new();
363            self.dfs_cycles(node, node, adj, &mut in_path, &mut path, &mut cycles);
364        }
365        cycles
366    }
367
368    fn dfs_cycles(
369        &self,
370        start: &str,
371        current: &str,
372        adj: &std::collections::HashMap<String, Vec<String>>,
373        in_path: &mut std::collections::HashSet<String>,
374        path: &mut Vec<String>,
375        cycles: &mut Vec<Vec<String>>,
376    ) {
377        if in_path.contains(current) {
378            return;
379        }
380
381        path.push(current.to_string());
382        in_path.insert(current.to_string());
383
384        if let Some(neighbors) = adj.get(current) {
385            for neighbor in neighbors {
386                if neighbor.as_str() == start && path.len() >= 2 {
387                    cycles.push(path.clone());
388                } else if neighbor.as_str() > start && !in_path.contains(neighbor) {
389                    self.dfs_cycles(start, neighbor, adj, in_path, path, cycles);
390                }
391            }
392        }
393
394        path.pop();
395        in_path.remove(current);
396    }
397
398    fn filter_unprotected_cycles(&self, cycles: &[Vec<String>]) -> Vec<Vec<String>> {
399        let mut unprotected: Vec<Vec<String>> = cycles
400            .iter()
401            .filter(|cycle| {
402                let has_protection = (0..cycle.len()).any(|i| {
403                    let next = (i + 1) % cycle.len();
404                    let from = cycle[i].as_str();
405                    let to = cycle[next].as_str();
406                    self.edges.iter().any(|e| {
407                        e.from == from
408                            && e.to == to
409                            && e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some())
410                    })
411                });
412                !has_protection
413            })
414            .cloned()
415            .collect();
416        unprotected.sort();
417        unprotected.dedup();
418        unprotected
419    }
420
421    /// @deprecated 使用 [`analyze()`](Self::analyze) 替代。
422    pub fn analyze_cycles(&self) -> CycleAnalysis {
423        let adj = self.build_adj();
424        let cycles = self.find_all_cycles(&adj);
425        let unprotected = self.filter_unprotected_cycles(&cycles);
426
427        CycleAnalysis {
428            has_cycles: !cycles.is_empty(),
429            cycles,
430            unprotected_cycles: unprotected,
431            total_edges: self.edges.len(),
432            protected_edges: self
433                .edges
434                .iter()
435                .filter(|e| e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some()))
436                .count(),
437        }
438    }
439}
440
/// Result of the legacy cycle analysis (`Graph::analyze_cycles`).
#[derive(Debug, Clone)]
pub struct CycleAnalysis {
    /// Whether at least one cycle was found.
    pub has_cycles: bool,
    /// Every cycle found, as node names in traversal order (closing edge implied).
    pub cycles: Vec<Vec<String>>,
    /// The cycles with no `max_visits` on any of their edges.
    pub unprotected_cycles: Vec<Vec<String>>,
    /// Total number of edges in the graph.
    pub total_edges: usize,
    /// Number of edges that carry `analysis.max_visits`.
    pub protected_edges: usize,
}

impl CycleAnalysis {
    /// Returns `true` when no cycle lacks `max_visits` protection.
    pub fn all_protected(&self) -> bool {
        self.unprotected_cycles.is_empty()
    }

    /// Renders a human-readable, newline-separated report (no trailing newline).
    pub fn report(&self) -> String {
        let mut lines: Vec<String> = vec!["=== Graph Cycle Analysis ===".to_string()];

        if self.has_cycles {
            lines.push(format!("Found {} cycle(s).", self.cycles.len()));
            lines.push(format!(
                "Edge protection: {}/{} edges have analysis set.",
                self.protected_edges, self.total_edges
            ));

            for (number, cycle) in (1..).zip(&self.cycles) {
                lines.push(format!(
                    "  Cycle {}: {} → {}",
                    number,
                    cycle.join(" → "),
                    cycle[0]
                ));
                let verdict = if self.unprotected_cycles.contains(cycle) {
                    "    ⚠️ UNPROTECTED — no max_visits on back-edge"
                } else {
                    "    ✅ Protected by edge-level analysis"
                };
                lines.push(verdict.to_string());
            }

            if !self.all_protected() {
                lines.push(String::new());
                lines.push("⚠️ Recommendation: Set analysis.max_visits on back-edges.".to_string());
            }
        } else {
            lines.push("No cycles detected — graph is a DAG.".to_string());
        }

        lines.join("\n")
    }
}
490
491// ─── PendingEdge ──────────────────────────────────────────────
492
493/// 待完成的边 — 链式调用的中间句柄。
494pub struct PendingEdge<'a, S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
495    builder: &'a mut GraphBuilder<S, M>,
496    edge_index: usize,
497}
498
499impl<'a, S: WorkflowState, M: MergeStrategy<S>> PendingEdge<'a, S, M> {
500    pub fn max_visits(self, n: usize) -> &'a mut GraphBuilder<S, M> {
501        self.builder.edges[self.edge_index].analysis = Some(EdgeAnalysis {
502            max_visits: Some(n),
503        });
504        self.builder
505    }
506}
507
508// ─── GraphBuilder ─────────────────────────────────────────────
509
510/// Graph 构建器。
511pub struct GraphBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
512    name: String,
513    nodes: IndexMap<String, NodeKind<S, M>>,
514    edges: Vec<Edge<S>>,
515    start: Option<String>,
516    end: Option<String>,
517}
518
519impl<S: WorkflowState, M: MergeStrategy<S>> GraphBuilder<S, M> {
520    /// 创建 GraphBuilder。
521    ///
522    /// 类型参数由调用方推断或显式指定。
523    /// - 默认: `GraphBuilder::new("name")` → `GraphBuilder<State, StateMerge>`
524    /// - 自定义: `GraphBuilder::<AgentState, _>::new("name")`
525    pub fn new(name: impl Into<String>) -> Self {
526        Self {
527            name: name.into(),
528            nodes: IndexMap::new(),
529            edges: Vec::new(),
530            start: None,
531            end: None,
532        }
533    }
534    pub fn start(&mut self, node: impl Into<String>) -> &mut Self {
535        self.start = Some(node.into());
536        self
537    }
538
539    pub fn end(&mut self, node: impl Into<String>) -> &mut Self {
540        self.end = Some(node.into());
541        self
542    }
543
544    pub fn node(&mut self, name: impl Into<String>, kind: NodeKind<S, M>) -> &mut Self {
545        self.nodes.insert(name.into(), kind);
546        self
547    }
548
549    pub fn edge(
550        &mut self,
551        from: impl Into<String>,
552        to: impl Into<String>,
553    ) -> PendingEdge<'_, S, M> {
554        let edge_index = self.edges.len();
555        self.edges.push(Edge {
556            from: from.into(),
557            to: to.into(),
558            condition: None,
559            analysis: None,
560            fallback: false,
561        });
562        PendingEdge {
563            builder: self,
564            edge_index,
565        }
566    }
567
568    pub fn edge_if(
569        &mut self,
570        from: impl Into<String>,
571        to: impl Into<String>,
572        condition: impl Fn(&S) -> bool + Send + Sync + 'static,
573    ) -> PendingEdge<'_, S, M> {
574        let edge_index = self.edges.len();
575        self.edges.push(Edge {
576            from: from.into(),
577            to: to.into(),
578            condition: Some(Arc::new(condition)),
579            analysis: None,
580            fallback: false,
581        });
582        PendingEdge {
583            builder: self,
584            edge_index,
585        }
586    }
587
588    pub fn edge_fallback(
589        &mut self,
590        from: impl Into<String>,
591        to: impl Into<String>,
592    ) -> PendingEdge<'_, S, M> {
593        let edge_index = self.edges.len();
594        self.edges.push(Edge {
595            from: from.into(),
596            to: to.into(),
597            condition: None,
598            analysis: None,
599            fallback: true,
600        });
601        PendingEdge {
602            builder: self,
603            edge_index,
604        }
605    }
606
607    pub fn build(self) -> Result<Graph<S, M>, BuildErrors> {
608        let mut errors = BuildErrors::new();
609
610        let start = match self.start {
611            Some(s) => s,
612            None => {
613                errors.push(BuildError::MissingEntryPoint);
614                return Err(errors);
615            }
616        };
617        let end = match self.end {
618            Some(s) => s,
619            None => {
620                errors.push(BuildError::MissingExitPoint);
621                return Err(errors);
622            }
623        };
624
625        let mut seen_nodes = std::collections::HashSet::new();
626        for name in self.nodes.keys() {
627            if !seen_nodes.insert(name.clone()) {
628                errors.push(BuildError::DuplicateNode { id: name.clone() });
629            }
630        }
631
632        for edge in &self.edges {
633            if !self.nodes.contains_key(&edge.from) {
634                errors.push(BuildError::MissingNode {
635                    from: edge.from.clone(),
636                    to: edge.to.clone(),
637                });
638            }
639            if !self.nodes.contains_key(&edge.to) {
640                errors.push(BuildError::MissingNode {
641                    from: edge.from.clone(),
642                    to: edge.to.clone(),
643                });
644            }
645        }
646
647        if !errors.is_empty() {
648            return Err(errors);
649        }
650
651        let graph = Graph {
652            name: self.name,
653            nodes: self.nodes,
654            edges: self.edges,
655            start,
656            end,
657        };
658
659        if let Err(e) = graph.validate() {
660            return Err(BuildErrors(vec![BuildError::InvalidEdgeDefinition {
661                from: "graph".to_string(),
662                to: "graph".to_string(),
663                reason: e.to_string(),
664            }]));
665        }
666
667        Ok(graph)
668    }
669
670    pub fn name(&self) -> &str {
671        &self.name
672    }
673}
674
675// ─── 诊断辅助函数 ───────────────────────────────────────────────
676
/// Joins a cycle's node names with ` → `.
///
/// The closing edge back to the first node is not appended.
fn format_cycle(cycle: &[String]) -> String {
    let mut rendered = String::new();
    for (position, node) in cycle.iter().enumerate() {
        if position > 0 {
            rendered.push_str(" → ");
        }
        rendered.push_str(node);
    }
    rendered
}
680
681fn check_fallback_in_cycles<S: WorkflowState, M: MergeStrategy<S>>(
682    graph: &Graph<S, M>,
683    cycles: &[Vec<String>],
684    diag: &mut GraphDiagnostics,
685) {
686    let fallback_edges: std::collections::HashSet<(&str, &str)> = graph
687        .edges
688        .iter()
689        .filter(|e| e.fallback)
690        .map(|e| (e.from.as_str(), e.to.as_str()))
691        .collect();
692
693    if fallback_edges.is_empty() {
694        return;
695    }
696
697    for cycle in cycles {
698        for i in 0..cycle.len() {
699            let next = (i + 1) % cycle.len();
700            let from = cycle[i].as_str();
701            let to = cycle[next].as_str();
702            if fallback_edges.contains(&(from, to)) {
703                let edge_str = format!("{} → {}", from, to);
704                diag.add_warning(
705                    DiagnosticCategory::FallbackInCycle,
706                    format!(
707                        "fallback edge {} participates in cycle: {} → {}",
708                        edge_str,
709                        format_cycle(cycle),
710                        cycle[0]
711                    ),
712                );
713            }
714        }
715    }
716}
717
718fn check_unreachable_nodes<S: WorkflowState, M: MergeStrategy<S>>(
719    graph: &Graph<S, M>,
720    adj: &std::collections::HashMap<String, Vec<String>>,
721    diag: &mut GraphDiagnostics,
722) {
723    let mut visited = std::collections::HashSet::new();
724    let mut queue = Vec::new();
725
726    queue.push(graph.start.clone());
727    visited.insert(graph.start.clone());
728
729    while let Some(node) = queue.pop() {
730        if let Some(neighbors) = adj.get(&node) {
731            for neighbor in neighbors {
732                if visited.insert(neighbor.clone()) {
733                    queue.push(neighbor.clone());
734                }
735            }
736        }
737    }
738
739    for name in graph.nodes.keys() {
740        if !visited.contains(name) {
741            diag.add_info(
742                DiagnosticCategory::Unreachable,
743                format!(
744                    "node '{}' is not reachable from start node '{}'",
745                    name, graph.start
746                ),
747            );
748        }
749    }
750}
751
752fn check_end_node_outgoing<S: WorkflowState, M: MergeStrategy<S>>(
753    graph: &Graph<S, M>,
754    diag: &mut GraphDiagnostics,
755) {
756    let outgoing: Vec<&Edge<S>> = graph.edges.iter().filter(|e| e.from == graph.end).collect();
757
758    if !outgoing.is_empty() {
759        let targets: Vec<&str> = outgoing.iter().map(|e| e.to.as_str()).collect();
760        diag.add_info(
761            DiagnosticCategory::EndNodeOutgoing,
762            format!(
763                "end node '{}' has {} outgoing edge(s) to: {:?}",
764                graph.end,
765                outgoing.len(),
766                targets
767            ),
768        );
769    }
770}
771
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
fn fnv_hash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    s.bytes()
        .fold(OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}