Skip to main content

lellm_graph/
graph.rs

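//! `Graph` and `GraphBuilder`.
//!
//! Edges come in three kinds:
//! - **Conditional edges** (`edge_if`) — an `if / else if` rule chain evaluated in
//!   registration order; the first match wins.
//! - **Normal edges** (`edge`) — unconditional and non-fallback; taken when no conditional
//!   edge matches.
//! - **Fallback edges** (`edge_fallback`) — the last resort.
//!
//! Since v0.4 the graph is generic over its state (`Graph<S: WorkflowState>`); the default
//! `S = State` is kept for backward compatibility.
//!
//! Runtime protection against runaway cycles is provided by the `max_steps` argument of
//! `run_inline()`.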
11
12use std::sync::Arc;
13
14use indexmap::IndexMap;
15
16use crate::error::{BuildError, BuildErrors, GraphDiagnostics, GraphError, TerminalError};
17use crate::execution_engine::{ExecutionEngine, ExecutorState, NextAction};
18use crate::graph_analysis::{self, CycleAnalysis};
19use crate::node::{BarrierNode, ConditionNode, FlowNode, LeafNode, NodeKind};
20use crate::state::{State, StateMerge};
21use crate::workflow_state::{MergeStrategy, WorkflowState};
22
// ─── Edge ──────────────────────────────────────────────────────
24
/// Routing predicate attached to a conditional edge.
///
/// It receives the current workflow state and returns `true` when the edge should be taken.
/// The closure is stored behind an `Arc`, so cloning an edge only bumps a reference count.
/// The `Send + Sync` bounds allow the predicate to be shared across threads.
pub type EdgeCondition<S> = Arc<dyn Fn(&S) -> bool + Sync + Send>;
27
28/// 边(Edge)— 三类边模型。
29#[derive(Clone)]
30pub struct Edge<S: WorkflowState = State> {
31    pub from: String,
32    pub to: String,
33    /// 路由条件。Some = 条件边;None = 普通边或 fallback 边。
34    pub condition: Option<EdgeCondition<S>>,
35    /// 分析用约束(不参与 runtime 决策)
36    pub analysis: Option<EdgeAnalysis>,
37    /// 是否为 fallback 边(最后兜底)
38    pub fallback: bool,
39}
40
41impl<S: WorkflowState> Edge<S> {
42    /// 判断是否为条件边。
43    pub fn is_conditional(&self) -> bool {
44        self.condition.is_some() && !self.fallback
45    }
46
47    /// 判断是否为普通边(无条件非 fallback)。
48    pub fn is_normal(&self) -> bool {
49        self.condition.is_none() && !self.fallback
50    }
51}
52
53impl<S: WorkflowState> std::fmt::Debug for Edge<S> {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("Edge")
56            .field("from", &self.from)
57            .field("to", &self.to)
58            .field("has_condition", &self.condition.is_some())
59            .field("analysis", &self.analysis)
60            .field("fallback", &self.fallback)
61            .finish()
62    }
63}
64
/// Static-analysis constraints attached to an edge.
///
/// Only consulted by cycle analysis (`Graph::analyze_cycles()`); they have no effect on
/// runtime routing.
#[derive(Debug, Clone)]
pub struct EdgeAnalysis {
    /// Suggested upper bound on how many times this edge should be traversed.
    pub max_visits: Option<usize>,
}
71
// ─── Graph ─────────────────────────────────────────────────────
73
74/// 图(Graph)— 允许有环,循环保护由 GraphExecutor::max_steps 运行时熔断提供。
75#[derive(Clone)]
76pub struct Graph<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
77    pub(crate) name: String,
78    pub(crate) nodes: IndexMap<String, NodeKind<S, M>>,
79    pub(crate) edges: Vec<Edge<S>>,
80    pub(crate) start: String,
81    pub(crate) end: String,
82    /// P0-2: Canonical AST hash — 从 DSL 层计算,不依赖 HashMap 顺序。
83    /// 用于 Checkpoint 的 graph compatibility 校验。
84    pub(crate) canonical_hash: u64,
85}
86
87impl<S: WorkflowState, M: MergeStrategy<S>> Graph<S, M> {
88    pub fn name(&self) -> &str {
89        &self.name
90    }
91
92    pub fn node_names(&self) -> Vec<&str> {
93        self.nodes.keys().map(|s| s.as_str()).collect()
94    }
95
96    pub fn start_node(&self) -> &str {
97        &self.start
98    }
99
100    pub fn end_node(&self) -> &str {
101        &self.end
102    }
103
104    /// 获取 canonical AST hash — 从 DSL 层计算,不依赖 HashMap 顺序。
105    ///
106    /// 用于 Checkpoint 的 graph compatibility 校验。
107    /// 相同输入永远产生相同 hash,Checkpoint 不会因此失效。
108    pub fn canonical_hash(&self) -> u64 {
109        self.canonical_hash
110    }
111
112    /// 计算图结构指纹 hash(u64 原始值)— 基于 compiled graph 结构。
113    ///
114    /// 注意:此 hash 依赖 HashMap 迭代顺序,可能不稳定。
115    /// 优先使用 `canonical_hash()` 进行 Checkpoint 校验。
116    pub fn hash_u64(&self) -> u64 {
117        let mut s = String::new();
118        let mut names: Vec<&str> = self.nodes.keys().map(|k| k.as_str()).collect();
119        names.sort();
120        s.push_str(&names.join(","));
121        s.push('|');
122        let mut edge_strs: Vec<String> = self
123            .edges
124            .iter()
125            .map(|e| {
126                format!(
127                    "{}->{}{:?}{}",
128                    e.from,
129                    e.to,
130                    if e.condition.is_some() { "?" } else { "" },
131                    if e.fallback { "!" } else { "" }
132                )
133            })
134            .collect();
135        edge_strs.sort();
136        s.push_str(&edge_strs.join(","));
137        fnv_hash(&s)
138    }
139
140    /// 计算图结构指纹 hash(hex 字符串)。
141    pub fn hash(&self) -> String {
142        format!("{:016x}", self.canonical_hash)
143    }
144
145    pub fn edges_from(&self, from: &str) -> Vec<&Edge<S>> {
146        self.edges.iter().filter(|e| e.from == from).collect()
147    }
148
149    pub fn find_edge(&self, from: &str, to: &str) -> Option<&Edge<S>> {
150        self.edges.iter().find(|e| e.from == from && e.to == to)
151    }
152
153    /// 获取节点映射表引用。
154    pub fn node_map(&self) -> &IndexMap<String, NodeKind<S, M>> {
155        &self.nodes
156    }
157
158    /// 路由解析 — 根据当前节点和 State 找到下一个节点(返回 Option)。
159    ///
160    /// 内部统一使用的边评估逻辑。无匹配时返回 `None`(不区分"无边"和"无匹配")。
161    fn resolve_next(&self, current: &str, state: &S) -> Option<String> {
162        // 1. 条件边
163        for edge in self.edges_from(current) {
164            if edge.is_conditional() && edge.condition.as_ref().is_some_and(|c| c(state)) {
165                return Some(edge.to.clone());
166            }
167        }
168
169        // 2. 普通边
170        for edge in self.edges_from(current) {
171            if edge.is_normal() {
172                return Some(edge.to.clone());
173            }
174        }
175
176        // 3. Fallback 边
177        for edge in self.edges_from(current) {
178            if edge.fallback {
179                return Some(edge.to.clone());
180            }
181        }
182
183        None
184    }
185
186    /// 路由解析 — 内联执行使用,无匹配时返回错误。
187    pub(crate) fn resolve_next_inline(
188        &self,
189        current: &str,
190        state: &S,
191    ) -> Result<String, GraphError> {
192        if self.edges_from(current).is_empty() {
193            return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
194                "node '{}' has no outgoing edges and is not the end node",
195                current
196            ))));
197        }
198
199        self.resolve_next(current, state).ok_or_else(|| {
200            GraphError::Terminal(TerminalError::InvalidGraph(format!(
201                "node '{}' has no matching outgoing edge",
202                current
203            )))
204        })
205    }
206
207    pub fn find_fallback_edge(&self, from: &str) -> Option<String> {
208        self.edges
209            .iter()
210            .find(|e| e.from == from && e.fallback)
211            .map(|e| e.to.clone())
212    }
213
214    /// 验证图结构。
215    pub fn validate(&self) -> Result<(), TerminalError> {
216        if !self.nodes.contains_key(&self.start) {
217            return Err(TerminalError::InvalidGraph(format!(
218                "start node '{}' not found",
219                self.start
220            )));
221        }
222
223        if !self.nodes.contains_key(&self.end) {
224            return Err(TerminalError::InvalidGraph(format!(
225                "end node '{}' not found",
226                self.end
227            )));
228        }
229
230        for edge in &self.edges {
231            if !self.nodes.contains_key(&edge.from) {
232                return Err(TerminalError::InvalidGraph(format!(
233                    "edge references non-existent source node '{}'",
234                    edge.from
235                )));
236            }
237            if !self.nodes.contains_key(&edge.to) {
238                return Err(TerminalError::InvalidGraph(format!(
239                    "edge references non-existent target node '{}'",
240                    edge.to
241                )));
242            }
243        }
244
245        Ok(())
246    }
247
248    /// 完整图诊断分析。
249    pub fn analyze(&self) -> GraphDiagnostics {
250        graph_analysis::analyze_graph(self)
251    }
252
253    /// @deprecated 使用 [`analyze()`](Self::analyze) 替代。
254    pub fn analyze_cycles(&self) -> CycleAnalysis {
255        let cycles = graph_analysis::find_all_cycles(self);
256        let unprotected = graph_analysis::filter_unprotected_cycles(self, &cycles);
257
258        CycleAnalysis {
259            has_cycles: !cycles.is_empty(),
260            cycles,
261            unprotected_cycles: unprotected,
262            total_edges: self.edges.len(),
263            protected_edges: self
264                .edges
265                .iter()
266                .filter(|e| e.analysis.as_ref().is_some_and(|a| a.max_visits.is_some()))
267                .count(),
268        }
269    }
270
271    // ─── 内联执行 ────────────────────────────────────────────
272
273    /// 内联执行 — 不产生 RuntimeEvent,不 Checkpoint。
274    ///
275    /// 接收 [`ExecutionEngine`](拥有者),内部循环构建 [`NodeContext`](能力视图)
276    /// 供节点使用。执行完毕后通过 `take_*` 消费 Mutation 和控制信号。
277    ///
278    /// 数据流:
279    /// ```text
280    /// ExecutionEngine
281    ///   → build_node_context()  → NodeContext<'_, S>
282    ///   → node.execute(ctx)     → 节点 record() Mutations
283    ///   → drop(ctx)             → 释放借用
284    ///   → take_mutations()      → 消费 Mutation 缓冲
285    ///   → state.apply_batch()   → apply 到 State
286    ///   → take_control()        → 获取路由信号
287    /// ```
288    pub async fn run_inline(
289        &self,
290        exec_ctx: &mut ExecutionEngine<'_, S>,
291        max_steps: usize,
292    ) -> Result<(), GraphError> {
293        let mut current = self.start_node().to_string();
294        let mut step: usize = 0;
295
296        loop {
297            step += 1;
298            if step > max_steps {
299                return Err(GraphError::Terminal(TerminalError::StepsExceeded {
300                    limit: max_steps,
301                }));
302            }
303
304            let node = self.nodes.get(&current).ok_or_else(|| {
305                GraphError::Terminal(TerminalError::NodeNotFound(current.clone()))
306            })?;
307
308            // 根据 NodeKind 分发执行
309            match node {
310                NodeKind::Task(n) => {
311                    let mut ctx = exec_ctx.build_node_context();
312                    n.execute(&mut ctx).await?;
313                }
314                NodeKind::Condition(n) => {
315                    let mut ctx = exec_ctx.build_leaf_context();
316                    <ConditionNode<S> as LeafNode<S>>::execute(n, &mut ctx).await?;
317                }
318                NodeKind::Barrier(n) => {
319                    let mut ctx = exec_ctx.build_leaf_context();
320                    <BarrierNode<S> as LeafNode<S>>::execute(n, &mut ctx).await?;
321                }
322                NodeKind::External(n) => {
323                    let mut ctx = exec_ctx.build_node_context();
324                    n.execute(&mut ctx).await?;
325                }
326                NodeKind::ExternalLeaf(n) => {
327                    let mut ctx = exec_ctx.build_leaf_context();
328                    n.execute(&mut ctx).await?;
329                }
330                NodeKind::Parallel(p) => {
331                    // ExecutorOperation 直接接收 &mut ExecutionEngine
332                    p.execute(exec_ctx).await?;
333                }
334                NodeKind::Subgraph(spec) => {
335                    // Subgraph 执行 — 通过 CompiledSubgraph 的 StateProjector 递归执行内层 Graph
336                    let stream = exec_ctx.stream_sink();
337                    let cancel = exec_ctx.cancel_token().clone();
338                    spec.execute(exec_ctx.state_mut(), stream, cancel).await?;
339                }
340            }
341
342            // commit mutations (Unit of Work) — 对 Parallel 是空操作
343            exec_ctx.commit();
344
345            // 消费 FlowEvent 缓冲(积累到 engine,执行结束后由调用者取用)
346            let _flow_events = exec_ctx.take_flow_events();
347
348            // 提取控制信号
349            let (next_action, _signal) = exec_ctx.take_control();
350
351            // 处理路由
352            match next_action {
353                NextAction::End => return Ok(()),
354                NextAction::Goto(target) => {
355                    current = target;
356                }
357                NextAction::Next => {
358                    if current == self.end_node() {
359                        return Ok(());
360                    }
361                    current = self.resolve_next_inline(&current, exec_ctx.state())?;
362                }
363            }
364        }
365    }
366
367    /// P0-2: 设置 canonical hash — 由 DSL 层调用。
368    ///
369    /// ⚠️ 已废弃 — 优先使用 `GraphBuilder::canonical_hash()` 在构建时设置。
370    /// 此方法保留用于 `build_react_graph()` 等内部场景。
371    #[deprecated(since = "0.5.0", note = "使用 GraphBuilder::canonical_hash() 替代")]
372    #[doc(hidden)]
373    pub fn set_canonical_hash(&mut self, hash: u64) {
374        self.canonical_hash = hash;
375    }
376}
377
// ─── PendingEdge ──────────────────────────────────────────────
379
380/// 待完成的边 — 链式调用的中间句柄。
381pub struct PendingEdge<'a, S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
382    builder: &'a mut GraphBuilder<S, M>,
383    edge_index: usize,
384}
385
386impl<'a, S: WorkflowState, M: MergeStrategy<S>> PendingEdge<'a, S, M> {
387    pub fn max_visits(self, n: usize) -> &'a mut GraphBuilder<S, M> {
388        self.builder.edges[self.edge_index].analysis = Some(EdgeAnalysis {
389            max_visits: Some(n),
390        });
391        self.builder
392    }
393}
394
// ─── GraphBuilder ─────────────────────────────────────────────
396
397/// Graph 构建器。
398pub struct GraphBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
399    name: String,
400    nodes: IndexMap<String, NodeKind<S, M>>,
401    edges: Vec<Edge<S>>,
402    start: Option<String>,
403    end: Option<String>,
404    /// P0-2: 可选的 canonical hash — 如果 DSL 层设置了就使用,否则计算结构 hash。
405    canonical_hash: Option<u64>,
406}
407
408impl<S: WorkflowState, M: MergeStrategy<S>> GraphBuilder<S, M> {
409    /// 创建 GraphBuilder。
410    ///
411    /// 类型参数由调用方推断或显式指定。
412    /// - 默认: `GraphBuilder::new("name")` → `GraphBuilder<State, StateMerge>`
413    /// - 自定义: `GraphBuilder::<AgentState, _>::new("name")`
414    pub fn new(name: impl Into<String>) -> Self {
415        Self {
416            name: name.into(),
417            nodes: IndexMap::new(),
418            edges: Vec::new(),
419            start: None,
420            end: None,
421            canonical_hash: None,
422        }
423    }
424
425    /// P0-2: 设置 canonical hash — 由 DSL 层(如 AgentBuilder)调用。
426    ///
427    /// 如果不设置,`build()` 会自动计算一个基于图结构的 hash。
428    pub fn canonical_hash(&mut self, hash: u64) -> &mut Self {
429        self.canonical_hash = Some(hash);
430        self
431    }
432
433    pub fn start(&mut self, node: impl Into<String>) -> &mut Self {
434        self.start = Some(node.into());
435        self
436    }
437
438    pub fn end(&mut self, node: impl Into<String>) -> &mut Self {
439        self.end = Some(node.into());
440        self
441    }
442
443    pub fn node(&mut self, name: impl Into<String>, kind: NodeKind<S, M>) -> &mut Self {
444        self.nodes.insert(name.into(), kind);
445        self
446    }
447
448    /// 便捷方法 — 添加 Subgraph 节点。
449    ///
450    /// 自动将 [`SubgraphSpec`](crate::SubgraphSpec) 编译为 [`CompiledSubgraph`](crate::CompiledSubgraph)
451    /// 并注册为节点。
452    ///
453    /// # 示例
454    ///
455    /// ```ignore
456    /// use lellm_graph::{GraphBuilder, SubgraphSpec, StateLens};
457    ///
458    /// let agent_graph = AgentBuilder::new(model).tools([...]).build();
459    /// let spec = SubgraphSpec::new(agent_graph, AgentLens);
460    ///
461    /// let mut builder = GraphBuilder::<WorkflowState, _>::new("workflow");
462    /// builder.subgraph("agent", spec);  // 语法糖
463    /// // 等价于:
464    /// // builder.node("agent", NodeKind::Subgraph(spec.compile()));
465    /// ```
466    pub fn subgraph<Inner: WorkflowState, IM: MergeStrategy<Inner>, L: crate::StateLens<S, Inner>>(
467        &mut self,
468        name: impl Into<String>,
469        spec: crate::SubgraphSpec<S, Inner, IM, L>,
470    ) -> &mut Self
471    where
472        S: 'static,
473        Inner: 'static,
474        IM: 'static,
475        L: 'static,
476    {
477        let compiled = spec.compile();
478        self.nodes.insert(name.into(), NodeKind::Subgraph(compiled));
479        self
480    }
481
482    pub fn edge(
483        &mut self,
484        from: impl Into<String>,
485        to: impl Into<String>,
486    ) -> PendingEdge<'_, S, M> {
487        let edge_index = self.edges.len();
488        self.edges.push(Edge {
489            from: from.into(),
490            to: to.into(),
491            condition: None,
492            analysis: None,
493            fallback: false,
494        });
495        PendingEdge {
496            builder: self,
497            edge_index,
498        }
499    }
500
501    pub fn edge_if(
502        &mut self,
503        from: impl Into<String>,
504        to: impl Into<String>,
505        condition: impl Fn(&S) -> bool + Send + Sync + 'static,
506    ) -> PendingEdge<'_, S, M> {
507        let edge_index = self.edges.len();
508        self.edges.push(Edge {
509            from: from.into(),
510            to: to.into(),
511            condition: Some(Arc::new(condition)),
512            analysis: None,
513            fallback: false,
514        });
515        PendingEdge {
516            builder: self,
517            edge_index,
518        }
519    }
520
521    pub fn edge_fallback(
522        &mut self,
523        from: impl Into<String>,
524        to: impl Into<String>,
525    ) -> PendingEdge<'_, S, M> {
526        let edge_index = self.edges.len();
527        self.edges.push(Edge {
528            from: from.into(),
529            to: to.into(),
530            condition: None,
531            analysis: None,
532            fallback: true,
533        });
534        PendingEdge {
535            builder: self,
536            edge_index,
537        }
538    }
539
540    pub fn build(self) -> Result<Graph<S, M>, BuildErrors> {
541        let mut errors = BuildErrors::new();
542
543        let start = match self.start {
544            Some(s) => s,
545            None => {
546                errors.push(BuildError::MissingEntryPoint);
547                return Err(errors);
548            }
549        };
550        let end = match self.end {
551            Some(s) => s,
552            None => {
553                errors.push(BuildError::MissingExitPoint);
554                return Err(errors);
555            }
556        };
557
558        let mut seen_nodes = std::collections::HashSet::new();
559        for name in self.nodes.keys() {
560            if !seen_nodes.insert(name.clone()) {
561                errors.push(BuildError::DuplicateNode { id: name.clone() });
562            }
563        }
564
565        for edge in &self.edges {
566            if !self.nodes.contains_key(&edge.from) {
567                errors.push(BuildError::MissingNode {
568                    from: edge.from.clone(),
569                    to: edge.to.clone(),
570                });
571            }
572            if !self.nodes.contains_key(&edge.to) {
573                errors.push(BuildError::MissingNode {
574                    from: edge.from.clone(),
575                    to: edge.to.clone(),
576                });
577            }
578        }
579
580        if !errors.is_empty() {
581            return Err(errors);
582        }
583
584        // 计算临时结构 hash 用于验证(不依赖 HashMap 顺序)
585        let structural_hash = compute_structural_hash(&self.nodes, &self.edges);
586
587        let graph = Graph {
588            name: self.name,
589            nodes: self.nodes,
590            edges: self.edges,
591            start,
592            end,
593            canonical_hash: self.canonical_hash.unwrap_or(structural_hash),
594        };
595
596        if let Err(e) = graph.validate() {
597            return Err(BuildErrors(vec![BuildError::InvalidEdgeDefinition {
598                from: "graph".to_string(),
599                to: "graph".to_string(),
600                reason: e.to_string(),
601            }]));
602        }
603
604        Ok(graph)
605    }
606
607    pub fn name(&self) -> &str {
608        &self.name
609    }
610
611    /// 构建并编译 — 在 `build()` 之后运行 Compiler Pass(如 InlinePass)。
612    ///
613    /// 与 `build()` 的区别:
614    /// - `build()` — 仅验证 AST,返回原始 Graph
615    /// - `compile()` — 验证 + 运行优化 pass,返回优化后的 Graph
616    ///
617    /// # 示例
618    ///
619    /// ```ignore
620    /// let graph = builder.compile()?;  // 自动运行 InlinePass
621    /// ```
622    pub fn compile(self) -> Result<Graph<S, M>, BuildErrors> {
623        use crate::compiler::CompilerPass;
624
625        let mut graph = self.build()?;
626
627        // 运行 Compiler Pass
628        let mut ctx = crate::compiler::CompilerContext::<S>::new();
629        let pass = crate::compiler::InlinePass::new();
630        pass.run(&mut graph, &mut ctx);
631
632        if ctx.debug {
633            tracing::debug!(
634                inlined = ctx.stats.inlined_count,
635                skipped = ctx.stats.not_inlined_count,
636                "compile passes complete"
637            );
638        }
639
640        Ok(graph)
641    }
642}
643
// ─── Utilities ─────────────────────────────────────────────────
645
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// The hash is fixed and dependency-free: the same input always produces the same value,
/// which the graph fingerprints in this module rely on.
fn fnv_hash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    s.bytes()
        .fold(OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}
654
655/// 计算图结构 hash — 不依赖 HashMap 迭代顺序。
656///
657/// 对节点名和边定义排序后 hash,确保相同结构产生相同 hash。
658/// 用于 `build()` 时没有 DSL 层 canonical_hash 的 fallback。
659fn compute_structural_hash<S: WorkflowState, M: MergeStrategy<S>>(
660    nodes: &IndexMap<String, NodeKind<S, M>>,
661    edges: &[Edge<S>],
662) -> u64 {
663    let mut s = String::new();
664    // 节点名排序
665    let mut names: Vec<&str> = nodes.keys().map(|k| k.as_str()).collect();
666    names.sort();
667    s.push_str(&names.join(","));
668    s.push('|');
669    // 边排序
670    let mut edge_strs: Vec<String> = edges
671        .iter()
672        .map(|e| {
673            format!(
674                "{}->{}{:?}{}",
675                e.from,
676                e.to,
677                if e.condition.is_some() { "?" } else { "" },
678                if e.fallback { "!" } else { "" }
679            )
680        })
681        .collect();
682    edge_strs.sort();
683    s.push_str(&edge_strs.join(","));
684    fnv_hash(&s)
685}