// lellm_graph/graph.rs

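//! `Graph` and `GraphBuilder`.
//!
//! Edge model — three kinds of edges:
//! - **Conditional edges** (`edge_if`) — an `if / else if` rule chain evaluated in
//!   registration order; the first match wins.
//! - **Normal edges** (`edge`) — unconditional and not fallback; taken when no
//!   conditional edge matches.
//! - **Fallback edges** (`edge_fallback`) — the last resort.
//!
//! v0.4+: generic as `Graph<S: WorkflowState>`, with default `S = State` (backward
//! compatible).
//!
//! Graphs may contain cycles; runtime safety is provided by the `max_steps`
//! parameter of `run_inline()`.
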
use std::sync::Arc;

use indexmap::IndexMap;

use crate::error::{BuildError, BuildErrors, GraphDiagnostics, GraphError, TerminalError};
use crate::execution_engine::{ExecutionEngine, ExecutorState, NextAction};
use crate::graph_analysis::{self, CycleAnalysis};
use crate::node::{BarrierNode, ConditionNode, ExecutorOperation, FlowNode, LeafNode, NodeKind};
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};

// ─── Edge ──────────────────────────────────────────────────────

/// Shared routing predicate attached to a conditional edge.
///
/// The predicate receives the current workflow state and returns `true` when
/// the edge should be taken. It is reference-counted, so cloning an edge
/// shares the closure instead of copying it.
pub type EdgeCondition<S> = Arc<dyn Fn(&S) -> bool + Sync + Send>;

28/// 边(Edge)— 三类边模型。
29#[derive(Clone)]
30pub struct Edge<S: WorkflowState = State> {
31    pub from: String,
32    pub to: String,
33    /// 路由条件。Some = 条件边;None = 普通边或 fallback 边。
34    pub condition: Option<EdgeCondition<S>>,
35    /// 分析用约束(不参与 runtime 决策)
36    pub analysis: Option<EdgeAnalysis>,
37    /// 是否为 fallback 边(最后兜底)
38    pub fallback: bool,
39}
40
41impl<S: WorkflowState> Edge<S> {
42    /// 判断是否为条件边。
43    pub fn is_conditional(&self) -> bool {
44        self.condition.is_some() && !self.fallback
45    }
46
47    /// 判断是否为普通边(无条件非 fallback)。
48    pub fn is_normal(&self) -> bool {
49        self.condition.is_none() && !self.fallback
50    }
51}
52
53impl<S: WorkflowState> std::fmt::Debug for Edge<S> {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("Edge")
56            .field("from", &self.from)
57            .field("to", &self.to)
58            .field("has_condition", &self.condition.is_some())
59            .field("analysis", &self.analysis)
60            .field("fallback", &self.fallback)
61            .finish()
62    }
63}
64
/// Analysis-only constraints attached to an edge.
///
/// Read by static graph analysis; for example, `Graph::analyze_cycles()`
/// counts edges whose `max_visits` is set. Runtime routing never consults it.
///
/// Derives `Default`, `PartialEq` and `Eq` in addition to `Debug`/`Clone` so
/// callers can build an empty constraint with `EdgeAnalysis::default()` and
/// compare constraints directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EdgeAnalysis {
    /// Suggested upper bound on how many times this edge should be traversed.
    pub max_visits: Option<usize>,
}

72// ─── Graph ─────────────────────────────────────────────────────
73
74/// 图(Graph)— 允许有环,循环保护由 GraphExecutor::max_steps 运行时熔断提供。
75#[derive(Clone)]
76pub struct Graph<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
77    pub(crate) name: String,
78    pub(crate) nodes: IndexMap<String, NodeKind<S, M>>,
79    pub(crate) edges: Vec<Edge<S>>,
80    pub(crate) start: String,
81    pub(crate) end: String,
82}
83
84impl<S: WorkflowState, M: MergeStrategy<S>> Graph<S, M> {
85    pub fn name(&self) -> &str {
86        &self.name
87    }
88
89    pub fn node_names(&self) -> Vec<&str> {
90        self.nodes.keys().map(|s| s.as_str()).collect()
91    }
92
93    pub fn start_node(&self) -> &str {
94        &self.start
95    }
96
97    pub fn end_node(&self) -> &str {
98        &self.end
99    }
100
101    /// 计算图结构指纹 hash(u64 原始值)。
102    ///
103    /// 用于 Checkpoint 的 graph compatibility 校验。
104    pub fn hash_u64(&self) -> u64 {
105        let mut s = String::new();
106        let mut names: Vec<&str> = self.nodes.keys().map(|k| k.as_str()).collect();
107        names.sort();
108        s.push_str(&names.join(","));
109        s.push('|');
110        let mut edge_strs: Vec<String> = self
111            .edges
112            .iter()
113            .map(|e| {
114                format!(
115                    "{}->{}{:?}{}",
116                    e.from,
117                    e.to,
118                    if e.condition.is_some() { "?" } else { "" },
119                    if e.fallback { "!" } else { "" }
120                )
121            })
122            .collect();
123        edge_strs.sort();
124        s.push_str(&edge_strs.join(","));
125        fnv_hash(&s)
126    }
127
128    /// 计算图结构指纹 hash(hex 字符串)。
129    pub fn hash(&self) -> String {
130        format!("{:016x}", self.hash_u64())
131    }
132
133    pub fn edges_from(&self, from: &str) -> Vec<&Edge<S>> {
134        self.edges.iter().filter(|e| e.from == from).collect()
135    }
136
137    pub fn find_edge(&self, from: &str, to: &str) -> Option<&Edge<S>> {
138        self.edges.iter().find(|e| e.from == from && e.to == to)
139    }
140
141    /// 获取节点映射表引用。
142    pub fn node_map(&self) -> &IndexMap<String, NodeKind<S, M>> {
143        &self.nodes
144    }
145
146    /// 路由解析 — 根据当前节点和 State 找到下一个节点(返回 Option)。
147    ///
148    /// 内部统一使用的边评估逻辑。无匹配时返回 `None`(不区分"无边"和"无匹配")。
149    fn resolve_next(&self, current: &str, state: &S) -> Option<String> {
150        // 1. 条件边
151        for edge in self.edges_from(current) {
152            if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
153                return Some(edge.to.clone());
154            }
155        }
156
157        // 2. 普通边
158        for edge in self.edges_from(current) {
159            if edge.is_normal() {
160                return Some(edge.to.clone());
161            }
162        }
163
164        // 3. Fallback 边
165        for edge in self.edges_from(current) {
166            if edge.fallback {
167                return Some(edge.to.clone());
168            }
169        }
170
171        None
172    }
173
174    /// 路由解析 — 内联执行使用,无匹配时返回错误。
175    pub(crate) fn resolve_next_inline(
176        &self,
177        current: &str,
178        state: &S,
179    ) -> Result<String, GraphError> {
180        if self.edges_from(current).is_empty() {
181            return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
182                "node '{}' has no outgoing edges and is not the end node",
183                current
184            ))));
185        }
186
187        self.resolve_next(current, state).ok_or_else(|| {
188            GraphError::Terminal(TerminalError::InvalidGraph(format!(
189                "node '{}' has no matching outgoing edge",
190                current
191            )))
192        })
193    }
194
195    pub fn find_fallback_edge(&self, from: &str) -> Option<String> {
196        self.edges
197            .iter()
198            .find(|e| e.from == from && e.fallback)
199            .map(|e| e.to.clone())
200    }
201
202    /// 验证图结构。
203    pub fn validate(&self) -> Result<(), TerminalError> {
204        if !self.nodes.contains_key(&self.start) {
205            return Err(TerminalError::InvalidGraph(format!(
206                "start node '{}' not found",
207                self.start
208            )));
209        }
210
211        if !self.nodes.contains_key(&self.end) {
212            return Err(TerminalError::InvalidGraph(format!(
213                "end node '{}' not found",
214                self.end
215            )));
216        }
217
218        for edge in &self.edges {
219            if !self.nodes.contains_key(&edge.from) {
220                return Err(TerminalError::InvalidGraph(format!(
221                    "edge references non-existent source node '{}'",
222                    edge.from
223                )));
224            }
225            if !self.nodes.contains_key(&edge.to) {
226                return Err(TerminalError::InvalidGraph(format!(
227                    "edge references non-existent target node '{}'",
228                    edge.to
229                )));
230            }
231        }
232
233        Ok(())
234    }
235
236    /// 完整图诊断分析。
237    pub fn analyze(&self) -> GraphDiagnostics {
238        graph_analysis::analyze_graph(self)
239    }
240
241    /// @deprecated 使用 [`analyze()`](Self::analyze) 替代。
242    pub fn analyze_cycles(&self) -> CycleAnalysis {
243        let cycles = graph_analysis::find_all_cycles(self);
244        let unprotected = graph_analysis::filter_unprotected_cycles(self, &cycles);
245
246        CycleAnalysis {
247            has_cycles: !cycles.is_empty(),
248            cycles,
249            unprotected_cycles: unprotected,
250            total_edges: self.edges.len(),
251            protected_edges: self
252                .edges
253                .iter()
254                .filter(|e| e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some()))
255                .count(),
256        }
257    }
258
259    // ─── 内联执行 ────────────────────────────────────────────
260
261    /// 内联执行 — 不产生 RuntimeEvent,不 Checkpoint。
262    ///
263    /// 接收 [`ExecutionEngine`](拥有者),内部循环构建 [`NodeContext`](能力视图)
264    /// 供节点使用。执行完毕后通过 `take_*` 消费 Mutation 和控制信号。
265    ///
266    /// 数据流:
267    /// ```text
268    /// ExecutionEngine
269    ///   → build_node_context()  → NodeContext<'_, S>
270    ///   → node.execute(ctx)     → 节点 record() Mutations
271    ///   → drop(ctx)             → 释放借用
272    ///   → take_mutations()      → 消费 Mutation 缓冲
273    ///   → state.apply_batch()   → apply 到 State
274    ///   → take_control()        → 获取路由信号
275    /// ```
276    pub async fn run_inline(
277        &self,
278        exec_ctx: &mut ExecutionEngine<S>,
279        max_steps: usize,
280    ) -> Result<(), GraphError> {
281        let mut current = self.start_node().to_string();
282        let mut step: usize = 0;
283
284        loop {
285            step += 1;
286            if step > max_steps {
287                return Err(GraphError::Terminal(TerminalError::StepsExceeded {
288                    limit: max_steps,
289                }));
290            }
291
292            let node = self.nodes.get(&current).ok_or_else(|| {
293                GraphError::Terminal(TerminalError::NodeNotFound(current.clone()))
294            })?;
295
296            // 根据 NodeKind 分发执行
297            match node {
298                NodeKind::Task(n) => {
299                    let mut ctx = exec_ctx.build_node_context();
300                    n.execute(&mut ctx).await?;
301                }
302                NodeKind::Condition(n) => {
303                    let mut ctx = exec_ctx.build_leaf_context();
304                    <ConditionNode<S> as LeafNode<S>>::execute(n, &mut ctx).await?;
305                }
306                NodeKind::Barrier(n) => {
307                    let mut ctx = exec_ctx.build_leaf_context();
308                    <BarrierNode<S> as LeafNode<S>>::execute(n, &mut ctx).await?;
309                }
310                NodeKind::External(n) => {
311                    let mut ctx = exec_ctx.build_node_context();
312                    n.execute(&mut ctx).await?;
313                }
314                NodeKind::ExternalLeaf(n) => {
315                    let mut ctx = exec_ctx.build_leaf_context();
316                    n.execute(&mut ctx).await?;
317                }
318                NodeKind::Parallel(p) => {
319                    // ExecutorOperation 直接接收 &mut ExecutionEngine
320                    p.execute(exec_ctx).await?;
321                }
322            }
323
324            // commit mutations (Unit of Work) — 对 Parallel 是空操作
325            exec_ctx.commit();
326
327            // 消费 FlowEvent 缓冲(积累到 engine,执行结束后由调用者取用)
328            let _flow_events = exec_ctx.take_flow_events();
329
330            // 提取控制信号
331            let (next_action, _signal) = exec_ctx.take_control();
332
333            // 处理路由
334            match next_action {
335                NextAction::End => return Ok(()),
336                NextAction::Goto(target) => {
337                    current = target;
338                }
339                NextAction::Next => {
340                    if current == self.end_node() {
341                        return Ok(());
342                    }
343                    current = self.resolve_next_inline(&current, exec_ctx.state())?;
344                }
345            }
346        }
347    }
348}
349
350// ─── PendingEdge ──────────────────────────────────────────────
351
352/// 待完成的边 — 链式调用的中间句柄。
353pub struct PendingEdge<'a, S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
354    builder: &'a mut GraphBuilder<S, M>,
355    edge_index: usize,
356}
357
358impl<'a, S: WorkflowState, M: MergeStrategy<S>> PendingEdge<'a, S, M> {
359    pub fn max_visits(self, n: usize) -> &'a mut GraphBuilder<S, M> {
360        self.builder.edges[self.edge_index].analysis = Some(EdgeAnalysis {
361            max_visits: Some(n),
362        });
363        self.builder
364    }
365}
366
367// ─── GraphBuilder ─────────────────────────────────────────────
368
369/// Graph 构建器。
370pub struct GraphBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
371    name: String,
372    nodes: IndexMap<String, NodeKind<S, M>>,
373    edges: Vec<Edge<S>>,
374    start: Option<String>,
375    end: Option<String>,
376}
377
378impl<S: WorkflowState, M: MergeStrategy<S>> GraphBuilder<S, M> {
379    /// 创建 GraphBuilder。
380    ///
381    /// 类型参数由调用方推断或显式指定。
382    /// - 默认: `GraphBuilder::new("name")` → `GraphBuilder<State, StateMerge>`
383    /// - 自定义: `GraphBuilder::<AgentState, _>::new("name")`
384    pub fn new(name: impl Into<String>) -> Self {
385        Self {
386            name: name.into(),
387            nodes: IndexMap::new(),
388            edges: Vec::new(),
389            start: None,
390            end: None,
391        }
392    }
393
394    pub fn start(&mut self, node: impl Into<String>) -> &mut Self {
395        self.start = Some(node.into());
396        self
397    }
398
399    pub fn end(&mut self, node: impl Into<String>) -> &mut Self {
400        self.end = Some(node.into());
401        self
402    }
403
404    pub fn node(&mut self, name: impl Into<String>, kind: NodeKind<S, M>) -> &mut Self {
405        self.nodes.insert(name.into(), kind);
406        self
407    }
408
409    pub fn edge(
410        &mut self,
411        from: impl Into<String>,
412        to: impl Into<String>,
413    ) -> PendingEdge<'_, S, M> {
414        let edge_index = self.edges.len();
415        self.edges.push(Edge {
416            from: from.into(),
417            to: to.into(),
418            condition: None,
419            analysis: None,
420            fallback: false,
421        });
422        PendingEdge {
423            builder: self,
424            edge_index,
425        }
426    }
427
428    pub fn edge_if(
429        &mut self,
430        from: impl Into<String>,
431        to: impl Into<String>,
432        condition: impl Fn(&S) -> bool + Send + Sync + 'static,
433    ) -> PendingEdge<'_, S, M> {
434        let edge_index = self.edges.len();
435        self.edges.push(Edge {
436            from: from.into(),
437            to: to.into(),
438            condition: Some(Arc::new(condition)),
439            analysis: None,
440            fallback: false,
441        });
442        PendingEdge {
443            builder: self,
444            edge_index,
445        }
446    }
447
448    pub fn edge_fallback(
449        &mut self,
450        from: impl Into<String>,
451        to: impl Into<String>,
452    ) -> PendingEdge<'_, S, M> {
453        let edge_index = self.edges.len();
454        self.edges.push(Edge {
455            from: from.into(),
456            to: to.into(),
457            condition: None,
458            analysis: None,
459            fallback: true,
460        });
461        PendingEdge {
462            builder: self,
463            edge_index,
464        }
465    }
466
467    pub fn build(self) -> Result<Graph<S, M>, BuildErrors> {
468        let mut errors = BuildErrors::new();
469
470        let start = match self.start {
471            Some(s) => s,
472            None => {
473                errors.push(BuildError::MissingEntryPoint);
474                return Err(errors);
475            }
476        };
477        let end = match self.end {
478            Some(s) => s,
479            None => {
480                errors.push(BuildError::MissingExitPoint);
481                return Err(errors);
482            }
483        };
484
485        let mut seen_nodes = std::collections::HashSet::new();
486        for name in self.nodes.keys() {
487            if !seen_nodes.insert(name.clone()) {
488                errors.push(BuildError::DuplicateNode { id: name.clone() });
489            }
490        }
491
492        for edge in &self.edges {
493            if !self.nodes.contains_key(&edge.from) {
494                errors.push(BuildError::MissingNode {
495                    from: edge.from.clone(),
496                    to: edge.to.clone(),
497                });
498            }
499            if !self.nodes.contains_key(&edge.to) {
500                errors.push(BuildError::MissingNode {
501                    from: edge.from.clone(),
502                    to: edge.to.clone(),
503                });
504            }
505        }
506
507        if !errors.is_empty() {
508            return Err(errors);
509        }
510
511        let graph = Graph {
512            name: self.name,
513            nodes: self.nodes,
514            edges: self.edges,
515            start,
516            end,
517        };
518
519        if let Err(e) = graph.validate() {
520            return Err(BuildErrors(vec![BuildError::InvalidEdgeDefinition {
521                from: "graph".to_string(),
522                to: "graph".to_string(),
523                reason: e.to_string(),
524            }]));
525        }
526
527        Ok(graph)
528    }
529
530    pub fn name(&self) -> &str {
531        &self.name
532    }
533}
534
// ─── Utilities ─────────────────────────────────────────────────

/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// Graph fingerprints are built on this for checkpoint compatibility. The
/// offset basis and prime therefore must not change, or stored fingerprints
/// stop matching.
fn fnv_hash(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}