mf_engine/handler/
traversal.rs

1// 导入必要的依赖
2use ahash::HashMap;
3use fixedbitset::FixedBitSet;
4use petgraph::data::DataMap;
5use petgraph::matrix_graph::Zero;
6use petgraph::prelude::{EdgeIndex, NodeIndex, StableDiGraph};
7use petgraph::visit::{EdgeRef, IntoNodeIdentifiers, VisitMap, Visitable};
8use petgraph::{Incoming, Outgoing};
9use serde_json::json;
10use std::rc::Rc;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::Instant;
14
15use crate::config::ZEN_CONFIG;
16use crate::model::{
17    DecisionEdge, DecisionNode, DecisionNodeKind, SwitchStatement,
18    SwitchStatementHitPolicy,
19};
20use crate::DecisionGraphTrace;
21use mf_expression::variable::Variable;
22use mf_expression::Isolate;
23
24/// # Petgraph 库说明
25///
26/// Petgraph 是一个功能强大的 Rust 图数据结构库,在本代码中主要用于实现决策图的遍历和管理。
27///
28/// ## 主要功能
29/// 1. 图数据结构
30///    - 支持有向图和无向图
31///    - 提供稳定的图结构(StableDiGraph)
32///    - 支持节点和边的权重
33///
34/// ## 核心组件
35/// 1. 图结构
36///    - `StableDiGraph`: 稳定的有向图实现,保证节点和边的索引在删除操作后保持稳定
37///    - `NodeIndex`: 节点索引类型,用于唯一标识图中的节点
38///    - `EdgeIndex`: 边索引类型,用于唯一标识图中的边
39///
40/// 2. 遍历功能
41///    - `Incoming`: 用于获取节点的入边
42///    - `Outgoing`: 用于获取节点的出边
43///    - `VisitMap`: 用于跟踪已访问的节点
44///    - `IntoNodeIdentifiers`: 用于遍历图中的所有节点
45///
46/// ## 在本代码中的应用
47/// 1. 决策图表示
48///    - 使用 `StableDiGraph` 存储决策节点和边
49///    - 节点存储 `DecisionNode` 信息
50///    - 边存储 `DecisionEdge` 信息
51///
52/// 2. 图遍历
53///    - 使用 `Incoming` 和 `Outgoing` 遍历节点的依赖关系
54///    - 使用 `VisitMap` 跟踪已访问的节点
55///    - 使用 `IntoNodeIdentifiers` 获取所有节点
56///
57/// 3. 图操作
58///    - 添加和删除节点
59///    - 添加和删除边
60///    - 查询节点和边的属性
61///
62/// ## 优势
63/// 1. 性能
64///    - 高效的图操作
65///    - 稳定的索引保证
66///    - 优化的内存使用
67///
68/// 2. 功能
69///    - 丰富的图算法支持
70///    - 灵活的数据结构
71///    - 良好的类型安全
72///
73/// 3. 可靠性
74///    - 经过充分测试
75///    - 活跃的维护
76///    - 良好的文档支持
77
78/// 定义决策图的类型别名,使用稳定的有向图结构
79/// 节点类型为 Arc<DecisionNode>,边类型为 Arc<DecisionEdge>
80pub(crate) type StableDiDecisionGraph =
81    StableDiGraph<Arc<DecisionNode>, Arc<DecisionEdge>>;
82
83/// 图遍历器,用于处理决策图的遍历和状态管理
84pub(crate) struct GraphWalker {
85    /// 记录已访问的节点
86    ordered: FixedBitSet,
87    /// 待访问的节点队列
88    to_visit: Vec<NodeIndex>,
89    /// 存储节点数据的映射
90    node_data: HashMap<NodeIndex, Variable>,
91    /// 当前迭代次数
92    iter: usize,
93    /// 已访问的开关节点列表
94    visited_switch_nodes: Vec<NodeIndex>,
95    /// 是否在上下文中包含节点信息
96    nodes_in_context: bool,
97}
98
99/// 最大迭代次数限制,防止无限循环
100const ITER_MAX: usize = 1_000;
101
102impl GraphWalker {
103    /// 创建新的图遍历器实例
104    /// 初始化遍历器并添加初始节点
105    pub fn new(graph: &StableDiDecisionGraph) -> Self {
106        let mut topo = Self::empty(graph);
107        topo.extend_with_initials(graph);
108        topo
109    }
110
111    /// 扩展初始节点(没有入边的节点)
112    /// 将输入节点添加到待访问队列中
113    fn extend_with_initials(
114        &mut self,
115        g: &StableDiDecisionGraph,
116    ) {
117        self.to_visit.extend(g.node_identifiers().filter(move |&nid| {
118            g.node_weight(nid).is_some_and(|n| {
119                matches!(n.kind, DecisionNodeKind::InputNode { content: _ })
120            })
121        }));
122    }
123
124    /// 创建空的图遍历器
125    /// 初始化所有字段为默认值
126    fn empty(graph: &StableDiDecisionGraph) -> Self {
127        Self {
128            ordered: graph.visit_map(),
129            to_visit: Vec::new(),
130            node_data: Default::default(),
131            visited_switch_nodes: Default::default(),
132            iter: 0,
133            nodes_in_context: ZEN_CONFIG
134                .nodes_in_context
135                .load(Ordering::Relaxed),
136        }
137    }
138
139    /// 重置图遍历器状态
140    /// 清空已访问节点和待访问队列,重新添加初始节点
141    pub fn reset(
142        &mut self,
143        g: &StableDiDecisionGraph,
144    ) {
145        self.ordered.clear();
146        self.to_visit.clear();
147        self.extend_with_initials(g);
148        self.iter += 1;
149    }
150
151    /// 获取指定节点的数据
152    /// 返回节点的变量数据,如果不存在则返回None
153    pub fn get_node_data(
154        &self,
155        node_id: NodeIndex,
156    ) -> Option<Variable> {
157        self.node_data.get(&node_id).cloned()
158    }
159
160    /// 获取所有结束节点的变量
161    /// 合并所有没有出边的已访问节点的数据
162    pub fn ending_variables(
163        &self,
164        g: &StableDiDecisionGraph,
165    ) -> Variable {
166        g.node_indices()
167            .filter(|nid| {
168                self.ordered.is_visited(nid)
169                    && g.neighbors_directed(*nid, Outgoing).count().is_zero()
170            })
171            .fold(Variable::empty_object(), |mut acc, curr| {
172                match self.node_data.get(&curr) {
173                    None => acc,
174                    Some(data) => acc.merge(data),
175                }
176            })
177    }
178
179    /// 获取所有节点数据
180    /// 将节点数据转换为变量对象
181    pub fn get_all_node_data(
182        &self,
183        g: &StableDiDecisionGraph,
184    ) -> Variable {
185        let node_values = self
186            .node_data
187            .iter()
188            .filter_map(|(idx, value)| {
189                let weight = g.node_weight(*idx)?;
190                Some((Rc::from(weight.name.as_str()), value.clone()))
191            })
192            .collect();
193
194        Variable::from_object(node_values)
195    }
196
197    /// 设置节点数据
198    /// 将变量数据存储到指定节点
199    pub fn set_node_data(
200        &mut self,
201        node_id: NodeIndex,
202        value: Variable,
203    ) {
204        self.node_data.insert(node_id, value);
205    }
206
207    /// 获取入边节点的数据
208    /// 合并所有入边节点的数据,可选择是否包含节点上下文
209    pub fn incoming_node_data(
210        &self,
211        g: &StableDiDecisionGraph,
212        node_id: NodeIndex,
213        with_nodes: bool,
214    ) -> Variable {
215        let value = self
216            .merge_node_data(g.neighbors_directed(node_id, Incoming))
217            .depth_clone(1);
218        if self.nodes_in_context {
219            if let Some(object_ref) =
220                with_nodes.then_some(value.as_object()).flatten()
221            {
222                let mut object = object_ref.borrow_mut();
223                object.insert(Rc::from("$nodes"), self.get_all_node_data(g));
224            }
225        }
226        value
227    }
228
229    /// 合并多个节点的数据
230    /// 将多个节点的数据合并为一个变量对象
231    pub fn merge_node_data<I>(
232        &self,
233        iter: I,
234    ) -> Variable
235    where
236        I: Iterator<Item = NodeIndex>,
237    {
238        let default_map = Variable::empty_object();
239        iter.fold(Variable::empty_object(), |mut prev, curr| {
240            let data = self.node_data.get(&curr).unwrap_or(&default_map);
241            prev.merge_clone(data)
242        })
243    }
244
245    /// 获取下一个要处理的节点
246    ///
247    /// # 功能说明
248    /// 实现图遍历的核心逻辑,负责:
249    /// 1. 按拓扑顺序遍历决策图
250    /// 2. 处理开关节点的条件评估
251    /// 3. 移除无效的边和节点
252    /// 4. 生成执行跟踪信息
253    ///
254    /// # 参数说明
255    /// * `g` - 要遍历的决策图
256    /// * `on_trace` - 可选的跟踪回调函数,用于记录节点执行信息
257    ///
258    /// # 返回值
259    /// * `Option<NodeIndex>` - 返回下一个要处理的节点索引,如果没有更多节点则返回None
260    ///
261    /// # 处理流程
262    /// 1. 检查迭代次数是否超过限制
263    /// 2. 从待访问队列中取出节点
264    /// 3. 检查节点依赖是否已解析
265    /// 4. 处理开关节点的条件评估
266    /// 5. 移除无效边和死分支
267    /// 6. 添加后继节点到待访问队列
268    pub fn next<F: FnMut(DecisionGraphTrace)>(
269        &mut self,
270        g: &mut StableDiDecisionGraph,
271        mut on_trace: Option<F>,
272    ) -> Option<NodeIndex> {
273        // 记录开始时间,用于性能跟踪
274        let start = Instant::now();
275
276        // 检查是否超过最大迭代次数限制
277        if self.iter >= ITER_MAX {
278            return None;
279        }
280
281        // 循环处理待访问队列中的节点
282        while let Some(nid) = self.to_visit.pop() {
283            // 获取当前节点的决策节点数据
284            let decision_node = g.node_weight(nid)?.clone();
285
286            // 跳过已访问的节点
287            if self.ordered.is_visited(&nid) {
288                continue;
289            }
290
291            // 检查节点的所有依赖是否已解析
292            // 如果有未解析的依赖,将当前节点和未解析的依赖重新加入队列
293            if !self.all_dependencies_resolved(g, nid) {
294                self.to_visit.push(nid);
295                self.to_visit.extend(self.get_unresolved_dependencies(g, nid));
296                continue;
297            }
298
299            // 标记当前节点为已访问
300            self.ordered.visit(nid);
301
302            // 处理开关节点
303            if let DecisionNodeKind::SwitchNode { content } =
304                &decision_node.kind
305            {
306                // 确保每个开关节点只处理一次
307                if !self.visited_switch_nodes.contains(&nid) {
308                    // 获取输入数据并准备执行环境
309                    let input_data = self.incoming_node_data(g, nid, true);
310                    let env = input_data.depth_clone(1);
311                    env.dot_insert("$", input_data.depth_clone(1));
312                    let mut isolate = Isolate::with_environment(env);
313
314                    // 根据命中策略处理开关语句
315                    let mut statement_iter = content.statements.iter();
316                    let valid_statements: Vec<&SwitchStatement> =
317                        match content.hit_policy {
318                            // First策略:找到第一个满足条件的语句
319                            SwitchStatementHitPolicy::First => statement_iter
320                                .find(|&s| {
321                                    switch_statement_evaluate(&mut isolate, &s)
322                                })
323                                .into_iter()
324                                .collect(),
325                            // Collect策略:收集所有满足条件的语句
326                            SwitchStatementHitPolicy::Collect => statement_iter
327                                .filter(|&s| {
328                                    switch_statement_evaluate(&mut isolate, &s)
329                                })
330                                .collect(),
331                        };
332
333                    // 生成跟踪数据,记录有效的语句ID
334                    let valid_statements_trace = Variable::from_array(
335                        valid_statements
336                            .iter()
337                            .map(|&statement| {
338                                let v = Variable::empty_object();
339                                v.dot_insert(
340                                    "id",
341                                    Variable::String(Rc::from(
342                                        statement.id.as_str(),
343                                    )),
344                                );
345                                v
346                            })
347                            .collect(),
348                    );
349
350                    // 移除节点上下文数据
351                    input_data.dot_remove("$nodes");
352
353                    // 执行跟踪回调,记录节点执行信息
354                    if let Some(on_trace) = &mut on_trace {
355                        on_trace(DecisionGraphTrace {
356                            id: decision_node.id.clone(),
357                            name: decision_node.name.clone(),
358                            input: input_data.shallow_clone(),
359                            output: input_data.shallow_clone(),
360                            order: 0,
361                            performance: Some(format!(
362                                "{:.1?}",
363                                start.elapsed()
364                            )),
365                            trace_data: Some(
366                                json!({ "statements": valid_statements_trace })
367                                    .into(),
368                            ),
369                        });
370                    }
371
372                    // 移除无效边
373                    // 找出所有不在有效语句列表中的边
374                    let edges_to_remove: Vec<EdgeIndex> = g
375                        .edges_directed(nid, Outgoing)
376                        .filter(|edge| {
377                            edge.weight().source_handle.as_ref().map_or(
378                                true,
379                                |handle| {
380                                    !valid_statements
381                                        .iter()
382                                        .any(|s| s.id == *handle)
383                                },
384                            )
385                        })
386                        .map(|edge| edge.id())
387                        .collect();
388
389                    // 记录移除的边数量
390                    let edges_remove_count = edges_to_remove.len();
391
392                    // 递归移除无效边及其相关的死分支
393                    for edge in edges_to_remove {
394                        remove_edge_recursive(g, edge);
395                    }
396
397                    // 标记当前开关节点为已访问
398                    self.visited_switch_nodes.push(nid);
399
400                    // 如果移除了边,重置图遍历器并继续
401                    if edges_remove_count > 0 {
402                        self.reset(g);
403                        continue;
404                    }
405                }
406            }
407
408            // 将当前节点的所有后继节点添加到待访问队列
409            let successors = g.neighbors_directed(nid, Outgoing);
410            self.to_visit.extend(successors);
411
412            // 返回当前处理的节点
413            return Some(nid);
414        }
415
416        // 如果没有更多节点要处理,返回None
417        None
418    }
419
420    /// 检查节点的所有依赖是否已解析
421    /// 确保所有入边节点都已被访问
422    fn all_dependencies_resolved(
423        &self,
424        g: &StableDiDecisionGraph,
425        nid: NodeIndex,
426    ) -> bool {
427        g.neighbors_directed(nid, Incoming)
428            .all(|dep| self.ordered.is_visited(&dep))
429    }
430
431    /// 获取未解析的依赖节点
432    /// 返回所有未被访问的入边节点
433    fn get_unresolved_dependencies(
434        &self,
435        g: &StableDiDecisionGraph,
436        nid: NodeIndex,
437    ) -> Vec<NodeIndex> {
438        g.neighbors_directed(nid, Incoming)
439            .filter(|dep| !self.ordered.is_visited(dep))
440            .collect()
441    }
442}
443
444/// 评估开关语句的条件
445/// 如果条件为空则返回true,否则在隔离环境中执行条件表达式
446fn switch_statement_evaluate<'a>(
447    isolate: &mut Isolate<'a>,
448    switch_statement: &'a SwitchStatement,
449) -> bool {
450    if switch_statement.condition.is_empty() {
451        return true;
452    }
453
454    // 直接使用 run_standard,表达式系统会自动使用 thread_local State
455    isolate
456        .run_standard(switch_statement.condition.as_str())
457        .map_or(false, |v| v.as_bool().unwrap_or(false))
458}
459
460/// 递归移除边及其相关的死分支
461/// 处理目标节点和源节点的死分支,确保图的完整性
462fn remove_edge_recursive(
463    g: &mut StableDiDecisionGraph,
464    edge_id: EdgeIndex,
465) {
466    let Some((source_nid, target_nid)) = g.edge_endpoints(edge_id) else {
467        return;
468    };
469
470    g.remove_edge(edge_id);
471
472    // 处理目标节点的死分支
473    let target_incoming_count = g.edges_directed(target_nid, Incoming).count();
474    if target_incoming_count.is_zero() {
475        let edge_ids: Vec<EdgeIndex> = g
476            .edges_directed(target_nid, Outgoing)
477            .map(|edge| edge.id())
478            .collect();
479
480        edge_ids.iter().for_each(|edge_id| {
481            remove_edge_recursive(g, edge_id.clone());
482        });
483
484        if g.edges(target_nid).count().is_zero() {
485            g.remove_node(target_nid);
486        }
487    }
488
489    // 处理源节点的死分支
490    let source_outgoing_count = g.edges_directed(source_nid, Outgoing).count();
491    if source_outgoing_count.is_zero() {
492        let edge_ids: Vec<EdgeIndex> = g
493            .edges_directed(source_nid, Incoming)
494            .map(|edge| edge.id())
495            .collect();
496
497        edge_ids.iter().for_each(|edge_id| {
498            remove_edge_recursive(g, edge_id.clone());
499        });
500
501        if g.edges(source_nid).count().is_zero() {
502            g.remove_node(source_nid);
503        }
504    }
505}