moduforge_rules_engine/handler/
graph.rs

1use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
2use crate::handler::decision::DecisionHandler;
3use crate::handler::expression::ExpressionHandler;
4use crate::handler::function::function::{Function, FunctionConfig};
5use crate::handler::function::module::console::ConsoleListener;
6use crate::handler::function::module::custom::ModuforgeListener;
7use crate::handler::function::module::zen::ZenListener;
8use crate::handler::function::FunctionHandler;
9use crate::handler::function_v1;
10use crate::handler::function_v1::runtime::create_runtime;
11use crate::handler::node::{NodeRequest, PartialTraceError};
12use crate::handler::table::zen::DecisionTableHandler;
13use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph};
14use crate::loader::DecisionLoader;
15use crate::model::{DecisionContent, DecisionNodeKind, FunctionNodeContent};
16use crate::util::validator_cache::ValidatorCache;
17use crate::{EvaluationError, NodeError};
18use ahash::{HashMap, HashMapExt};
19use anyhow::anyhow;
20use petgraph::algo::is_cyclic_directed;
21use serde::ser::SerializeMap;
22use serde::{Deserialize, Serialize, Serializer};
23use serde_json::Value;
24use std::hash::{DefaultHasher, Hash, Hasher};
25use std::rc::Rc;
26use std::sync::Arc;
27use std::time::Instant;
28use thiserror::Error;
29use moduforge_rules_expression::variable::Variable;
30
31/// 决策图结构体
32/// 用于表示和管理决策图,包含图的构建、验证和评估功能
33pub struct DecisionGraph<
34    L: DecisionLoader + 'static,
35    A: CustomNodeAdapter + 'static,
36> {
37    /// 初始图结构,用于重置
38    initial_graph: StableDiDecisionGraph,
39    /// 当前图结构,可能经过修改
40    graph: StableDiDecisionGraph,
41    /// 自定义节点适配器
42    adapter: Arc<A>,
43    /// 决策加载器
44    loader: Arc<L>,
45    /// 是否启用跟踪
46    trace: bool,
47    /// 最大深度限制
48    max_depth: u8,
49    /// 当前迭代次数
50    iteration: u8,
51    /// 运行时函数
52    runtime: Option<Rc<Function>>,
53    /// 验证器缓存
54    validator_cache: ValidatorCache,
55}
56
57/// 决策图配置结构体
58/// 用于初始化决策图的配置参数
59pub struct DecisionGraphConfig<
60    L: DecisionLoader + 'static,
61    A: CustomNodeAdapter + 'static,
62> {
63    /// 决策加载器
64    pub loader: Arc<L>,
65    /// 自定义节点适配器
66    pub adapter: Arc<A>,
67    /// 决策内容
68    pub content: Arc<DecisionContent>,
69    /// 是否启用跟踪
70    pub trace: bool,
71    /// 迭代次数
72    pub iteration: u8,
73    /// 最大深度限制
74    pub max_depth: u8,
75    /// 验证器缓存
76    pub validator_cache: Option<ValidatorCache>,
77}
78
79impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static>
80    DecisionGraph<L, A>
81{
82    /// 创建新的决策图实例
83    ///
84    /// # 参数
85    /// * `config` - 决策图配置
86    ///
87    /// # 返回
88    /// * `Result<Self, DecisionGraphValidationError>` - 成功返回决策图实例,失败返回验证错误
89    pub fn try_new(
90        config: DecisionGraphConfig<L, A>
91    ) -> Result<Self, DecisionGraphValidationError> {
92        let content = config.content;
93        let mut graph = StableDiDecisionGraph::new();
94        let mut index_map = HashMap::new();
95
96        // 添加所有节点到图中
97        for node in &content.nodes {
98            let node_id = node.id.clone();
99            let node_index = graph.add_node(node.clone());
100            index_map.insert(node_id, node_index);
101        }
102
103        // 添加所有边到图中
104        for (_, edge) in content.edges.iter().enumerate() {
105            let source_index =
106                index_map.get(&edge.source_id).ok_or_else(|| {
107                    DecisionGraphValidationError::MissingNode(
108                        edge.source_id.to_string(),
109                    )
110                })?;
111
112            let target_index =
113                index_map.get(&edge.target_id).ok_or_else(|| {
114                    DecisionGraphValidationError::MissingNode(
115                        edge.target_id.to_string(),
116                    )
117                })?;
118
119            graph.add_edge(
120                source_index.clone(),
121                target_index.clone(),
122                edge.clone(),
123            );
124        }
125
126        Ok(Self {
127            initial_graph: graph.clone(),
128            graph,
129            iteration: config.iteration,
130            trace: config.trace,
131            loader: config.loader,
132            adapter: config.adapter,
133            max_depth: config.max_depth,
134            validator_cache: config.validator_cache.unwrap_or_default(),
135            runtime: None,
136        })
137    }
138
139    /// 设置运行时函数
140    pub(crate) fn with_function(
141        mut self,
142        runtime: Option<Rc<Function>>,
143    ) -> Self {
144        self.runtime = runtime;
145        self
146    }
147
148    /// 重置图到初始状态
149    pub(crate) fn reset_graph(&mut self) {
150        self.graph = self.initial_graph.clone();
151    }
152
153    /// 获取或创建运行时函数
154    async fn get_or_insert_function(&mut self) -> anyhow::Result<Rc<Function>> {
155        if let Some(function) = &self.runtime {
156            return Ok(function.clone());
157        }
158
159        // 创建新的运行时函数
160        let function = Function::create(FunctionConfig {
161            listeners: Some(vec![
162                Box::new(ModuforgeListener {}),
163                Box::new(ConsoleListener),
164                Box::new(ZenListener {
165                    loader: self.loader.clone(),
166                    adapter: self.adapter.clone(),
167                }),
168            ]),
169        })
170        .await
171        .map_err(|err| anyhow!(err.to_string()))?;
172        let rc_function = Rc::new(function);
173        self.runtime.replace(rc_function.clone());
174
175        Ok(rc_function)
176    }
177
178    /// 验证决策图的有效性
179    ///
180    /// # 验证内容
181    /// 1. 检查输入节点数量是否为1
182    /// 2. 检查是否存在循环依赖
183    pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
184        let input_count = self.input_node_count();
185        if input_count != 1 {
186            return Err(DecisionGraphValidationError::InvalidInputCount(
187                input_count as u32,
188            ));
189        }
190
191        if is_cyclic_directed(&self.graph) {
192            return Err(DecisionGraphValidationError::CyclicGraph);
193        }
194
195        Ok(())
196    }
197
198    /// 计算输入节点的数量
199    fn input_node_count(&self) -> usize {
200        self.graph
201            .node_weights()
202            .filter(|weight| {
203                matches!(
204                    weight.kind,
205                    DecisionNodeKind::InputNode { content: _ }
206                )
207            })
208            .count()
209    }
210
211    /// 评估决策图
212    ///
213    /// # 参数
214    /// * `context` - 输入上下文变量
215    ///
216    /// # 返回
217    /// * `Result<DecisionGraphResponse, NodeError>` - 评估结果或错误
218    pub async fn evaluate(
219        &mut self,
220        context: Variable,
221    ) -> Result<DecisionGraphResponse, NodeError> {
222        let root_start = Instant::now();
223
224        // 验证图的有效性
225        self.validate().map_err(|e| NodeError {
226            node_id: "".to_string(),
227            source: anyhow!(e),
228            trace: None,
229        })?;
230
231        // 检查是否超过最大深度限制
232        if self.iteration >= self.max_depth {
233            return Err(NodeError {
234                node_id: "".to_string(),
235                source: anyhow!(EvaluationError::DepthLimitExceeded),
236                trace: None,
237            });
238        }
239
240        // 创建图遍历器并开始遍历
241        let mut walker = GraphWalker::new(&self.graph);
242        let mut node_traces = self.trace.then(|| HashMap::default());
243
244        // 遍历图中的所有节点
245        while let Some(nid) = walker.next(
246            &mut self.graph,
247            self.trace.then_some(|mut trace: DecisionGraphTrace| {
248                if let Some(nt) = &mut node_traces {
249                    trace.order = nt.len() as u32;
250                    nt.insert(trace.id.clone(), trace);
251                };
252            }),
253        ) {
254            // 如果节点已有数据,跳过处理
255            if let Some(_) = walker.get_node_data(nid) {
256                continue;
257            }
258
259            let node = (&self.graph[nid]).clone();
260            let start = Instant::now();
261
262            // 定义跟踪宏
263            macro_rules! trace {
264                ({ $($field:ident: $value:expr),* $(,)? }) => {
265                    if let Some(nt) = &mut node_traces {
266                        nt.insert(
267                            node.id.clone(),
268                            DecisionGraphTrace {
269                                name: node.name.clone(),
270                                id: node.id.clone(),
271                                performance: Some(format!("{:.1?}", start.elapsed())),
272                                order: nt.len() as u32,
273                                $($field: $value,)*
274                            }
275                        );
276                    }
277                };
278            }
279
280            // 根据节点类型处理
281            match &node.kind {
282                // 处理输入节点
283                DecisionNodeKind::InputNode { content } => {
284                    trace!({
285                        input: Variable::Null,
286                        output: context.clone(),
287                        trace_data: None,
288                    });
289
290                    // 验证输入数据
291                    if let Some(json_schema) = content
292                        .schema
293                        .as_ref()
294                        .map(|s| serde_json::from_str::<Value>(&s).ok())
295                        .flatten()
296                    {
297                        let validator_key =
298                            create_validator_cache_key(&json_schema);
299                        let validator = self
300                            .validator_cache
301                            .get_or_insert(validator_key, &json_schema)
302                            .await
303                            .map_err(|e| NodeError {
304                                source: e.into(),
305                                node_id: node.id.clone(),
306                                trace: error_trace(&node_traces),
307                            })?;
308
309                        let context_json = context.to_value();
310                        validator.validate(&context_json).map_err(|e| {
311                            NodeError {
312                                source: anyhow!(
313                                    serde_json::to_value(Into::<
314                                        Box<EvaluationError>,
315                                    >::into(
316                                        e
317                                    ))
318                                    .unwrap_or_default()
319                                ),
320                                node_id: node.id.clone(),
321                                trace: error_trace(&node_traces),
322                            }
323                        })?;
324                    }
325
326                    walker.set_node_data(nid, context.clone());
327                },
328                // 处理输出节点
329                DecisionNodeKind::OutputNode { content } => {
330                    let incoming_data =
331                        walker.incoming_node_data(&self.graph, nid, false);
332
333                    trace!({
334                        input: incoming_data.clone(),
335                        output: Variable::Null,
336                        trace_data: None,
337                    });
338
339                    // 验证输出数据
340                    if let Some(json_schema) = content
341                        .schema
342                        .as_ref()
343                        .map(|s| serde_json::from_str::<Value>(&s).ok())
344                        .flatten()
345                    {
346                        let validator_key =
347                            create_validator_cache_key(&json_schema);
348                        let validator = self
349                            .validator_cache
350                            .get_or_insert(validator_key, &json_schema)
351                            .await
352                            .map_err(|e| NodeError {
353                                source: e.into(),
354                                node_id: node.id.clone(),
355                                trace: error_trace(&node_traces),
356                            })?;
357
358                        let incoming_data_json = incoming_data.to_value();
359                        validator.validate(&incoming_data_json).map_err(
360                            |e| NodeError {
361                                source: anyhow!(
362                                    serde_json::to_value(Into::<
363                                        Box<EvaluationError>,
364                                    >::into(
365                                        e
366                                    ))
367                                    .unwrap_or_default()
368                                ),
369                                node_id: node.id.clone(),
370                                trace: error_trace(&node_traces),
371                            },
372                        )?;
373                    }
374
375                    return Ok(DecisionGraphResponse {
376                        result: incoming_data,
377                        performance: format!("{:.1?}", root_start.elapsed()),
378                        trace: node_traces,
379                    });
380                },
381                // 处理开关节点
382                DecisionNodeKind::SwitchNode { .. } => {
383                    let input_data =
384                        walker.incoming_node_data(&self.graph, nid, false);
385                    walker.set_node_data(nid, input_data);
386                },
387                // 处理函数节点
388                DecisionNodeKind::FunctionNode { content } => {
389                    let function = self
390                        .get_or_insert_function()
391                        .await
392                        .map_err(|e| NodeError {
393                            source: e.into(),
394                            node_id: node.id.clone(),
395                            trace: error_trace(&node_traces),
396                        })?;
397
398                    let node_request = NodeRequest {
399                        node: node.clone(),
400                        iteration: self.iteration,
401                        input: walker.incoming_node_data(
402                            &self.graph,
403                            nid,
404                            true,
405                        ),
406                    };
407
408                    // 根据函数版本处理
409                    let res = match content {
410                        FunctionNodeContent::Version2(_) => {
411                            FunctionHandler::new(
412                                function,
413                                self.trace,
414                                self.iteration,
415                                self.max_depth,
416                            )
417                            .handle(node_request.clone())
418                            .await
419                            .map_err(|e| {
420                                if let Some(detailed_err) =
421                                    e.downcast_ref::<PartialTraceError>()
422                                {
423                                    trace!({
424                                        input: node_request.input.clone(),
425                                        output: Variable::Null,
426                                        trace_data: detailed_err.trace.clone(),
427                                    });
428                                }
429
430                                NodeError {
431                                    source: e.into(),
432                                    node_id: node.id.clone(),
433                                    trace: error_trace(&node_traces),
434                                }
435                            })?
436                        },
437                        FunctionNodeContent::Version1(_) => {
438                            let runtime =
439                                create_runtime().map_err(|e| NodeError {
440                                    source: e.into(),
441                                    node_id: node.id.clone(),
442                                    trace: error_trace(&node_traces),
443                                })?;
444
445                            function_v1::FunctionHandler::new(
446                                self.trace, runtime,
447                            )
448                            .handle(node_request.clone())
449                            .await
450                            .map_err(|e| {
451                                NodeError {
452                                    source: e.into(),
453                                    node_id: node.id.clone(),
454                                    trace: error_trace(&node_traces),
455                                }
456                            })?
457                        },
458                    };
459
460                    node_request.input.dot_remove("$nodes");
461                    res.output.dot_remove("$nodes");
462
463                    trace!({
464                        input: node_request.input,
465                        output: res.output.clone(),
466                        trace_data: res.trace_data,
467                    });
468                    walker.set_node_data(nid, res.output);
469                },
470                // 处理决策节点
471                DecisionNodeKind::DecisionNode { .. } => {
472                    let node_request = NodeRequest {
473                        node: node.clone(),
474                        iteration: self.iteration,
475                        input: walker.incoming_node_data(
476                            &self.graph,
477                            nid,
478                            true,
479                        ),
480                    };
481
482                    let res = DecisionHandler::new(
483                        self.trace,
484                        self.max_depth,
485                        self.loader.clone(),
486                        self.adapter.clone(),
487                        self.runtime.clone(),
488                        self.validator_cache.clone(),
489                    )
490                    .handle(node_request.clone())
491                    .await
492                    .map_err(|e| NodeError {
493                        source: e.into(),
494                        node_id: node.id.to_string(),
495                        trace: error_trace(&node_traces),
496                    })?;
497
498                    node_request.input.dot_remove("$nodes");
499                    res.output.dot_remove("$nodes");
500
501                    trace!({
502                        input: node_request.input,
503                        output: res.output.clone(),
504                        trace_data: res.trace_data,
505                    });
506                    walker.set_node_data(nid, res.output);
507                },
508                // 处理决策表节点
509                DecisionNodeKind::DecisionTableNode { .. } => {
510                    let node_request = NodeRequest {
511                        node: node.clone(),
512                        iteration: self.iteration,
513                        input: walker.incoming_node_data(
514                            &self.graph,
515                            nid,
516                            true,
517                        ),
518                    };
519
520                    let res = DecisionTableHandler::new(self.trace)
521                        .handle(node_request.clone())
522                        .await
523                        .map_err(|e| NodeError {
524                            node_id: node.id.clone(),
525                            source: e.into(),
526                            trace: error_trace(&node_traces),
527                        })?;
528
529                    node_request.input.dot_remove("$nodes");
530                    res.output.dot_remove("$nodes");
531                    res.output.dot_remove("$");
532
533                    trace!({
534                        input: node_request.input,
535                        output: res.output.clone(),
536                        trace_data: res.trace_data,
537                    });
538                    walker.set_node_data(nid, res.output);
539                },
540                // 处理表达式节点
541                DecisionNodeKind::ExpressionNode { .. } => {
542                    let node_request = NodeRequest {
543                        node: node.clone(),
544                        iteration: self.iteration,
545                        input: walker.incoming_node_data(
546                            &self.graph,
547                            nid,
548                            true,
549                        ),
550                    };
551
552                    let res = ExpressionHandler::new(self.trace)
553                        .handle(node_request.clone())
554                        .await
555                        .map_err(|e| {
556                            if let Some(detailed_err) =
557                                e.downcast_ref::<PartialTraceError>()
558                            {
559                                trace!({
560                                    input: node_request.input.clone(),
561                                    output: Variable::Null,
562                                    trace_data: detailed_err.trace.clone(),
563                                });
564                            }
565
566                            NodeError {
567                                node_id: node.id.clone(),
568                                source: e.into(),
569                                trace: error_trace(&node_traces),
570                            }
571                        })?;
572
573                    node_request.input.dot_remove("$nodes");
574                    res.output.dot_remove("$nodes");
575
576                    trace!({
577                        input: node_request.input,
578                        output: res.output.clone(),
579                        trace_data: res.trace_data,
580                    });
581                    walker.set_node_data(nid, res.output);
582                },
583                // 处理自定义节点
584                DecisionNodeKind::CustomNode { .. } => {
585                    let node_request = NodeRequest {
586                        node: node.clone(),
587                        iteration: self.iteration,
588                        input: walker.incoming_node_data(
589                            &self.graph,
590                            nid,
591                            true,
592                        ),
593                    };
594
595                    let res = self
596                        .adapter
597                        .handle(
598                            CustomNodeRequest::try_from(node_request.clone())
599                                .unwrap(),
600                        )
601                        .await
602                        .map_err(|e| NodeError {
603                            node_id: node.id.clone(),
604                            source: e.into(),
605                            trace: error_trace(&node_traces),
606                        })?;
607
608                    node_request.input.dot_remove("$nodes");
609                    res.output.dot_remove("$nodes");
610
611                    trace!({
612                        input: node_request.input,
613                        output: res.output.clone(),
614                        trace_data: res.trace_data,
615                    });
616                    walker.set_node_data(nid, res.output);
617                },
618            }
619        }
620
621        // 返回最终结果
622        Ok(DecisionGraphResponse {
623            result: walker.ending_variables(&self.graph),
624            performance: format!("{:.1?}", root_start.elapsed()),
625            trace: node_traces,
626        })
627    }
628}
629
630/// 决策图验证错误类型
631#[derive(Debug, Error)]
632pub enum DecisionGraphValidationError {
633    /// 输入节点数量无效
634    #[error("无效的输入节点数量: {0}")]
635    InvalidInputCount(u32),
636
637    /// 输出节点数量无效
638    #[error("无效的输出节点数量: {0}")]
639    InvalidOutputCount(u32),
640
641    /// 检测到循环依赖
642    #[error("检测到循环依赖")]
643    CyclicGraph,
644
645    /// 节点缺失
646    #[error("节点缺失: {0}")]
647    MissingNode(String),
648}
649
650/// 实现序列化特性
651impl Serialize for DecisionGraphValidationError {
652    fn serialize<S>(
653        &self,
654        serializer: S,
655    ) -> Result<S::Ok, S::Error>
656    where
657        S: Serializer,
658    {
659        let mut map = serializer.serialize_map(None)?;
660
661        match &self {
662            DecisionGraphValidationError::InvalidInputCount(count) => {
663                map.serialize_entry("type", "invalidInputCount")?;
664                map.serialize_entry("nodeCount", count)?;
665            },
666            DecisionGraphValidationError::InvalidOutputCount(count) => {
667                map.serialize_entry("type", "invalidOutputCount")?;
668                map.serialize_entry("nodeCount", count)?;
669            },
670            DecisionGraphValidationError::MissingNode(node_id) => {
671                map.serialize_entry("type", "missingNode")?;
672                map.serialize_entry("nodeId", node_id)?;
673            },
674            DecisionGraphValidationError::CyclicGraph => {
675                map.serialize_entry("type", "cyclicGraph")?;
676            },
677        }
678
679        map.end()
680    }
681}
682
683/// 决策图响应结构体
684#[derive(Debug, Clone, Serialize, Deserialize)]
685#[serde(rename_all = "camelCase")]
686pub struct DecisionGraphResponse {
687    /// 性能信息
688    pub performance: String,
689    /// 评估结果
690    pub result: Variable,
691    /// 可选的跟踪信息
692    #[serde(skip_serializing_if = "Option::is_none")]
693    pub trace: Option<HashMap<String, DecisionGraphTrace>>,
694}
695
696/// 决策图跟踪信息结构体
697#[derive(Debug, Clone, Serialize, Deserialize)]
698#[serde(rename_all = "camelCase")]
699pub struct DecisionGraphTrace {
700    /// 输入数据
701    pub input: Variable,
702    /// 输出数据
703    pub output: Variable,
704    /// 节点名称
705    pub name: String,
706    /// 节点ID
707    pub id: String,
708    /// 性能信息
709    pub performance: Option<String>,
710    /// 跟踪数据
711    pub trace_data: Option<Value>,
712    /// 执行顺序
713    pub order: u32,
714}
715
716/// 将跟踪信息转换为JSON值
717pub(crate) fn error_trace(
718    trace: &Option<HashMap<String, DecisionGraphTrace>>
719) -> Option<Value> {
720    trace.as_ref().map(|s| serde_json::to_value(s).ok()).flatten()
721}
722
/// Derives a validator-cache key by hashing `content` (normally a JSON schema).
///
/// Generic over any `Hash` value, so it is not tied to `serde_json::Value`;
/// existing callers passing `&Value` are unaffected. Every
/// `DefaultHasher::new()` instance hashes identically, so equal inputs give
/// equal keys within a process. The algorithm is not guaranteed stable
/// across Rust releases, so keys must not be persisted.
///
/// NOTE(review): the key is only 64 bits, so two distinct schemas could in
/// principle collide and share a cached validator.
fn create_validator_cache_key<T: Hash + ?Sized>(content: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}