1use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
2use crate::handler::decision::DecisionHandler;
3use crate::handler::expression::ExpressionHandler;
4use crate::handler::function::function::{Function, FunctionConfig};
5use crate::handler::function::module::console::ConsoleListener;
6use crate::handler::function::module::custom::ModuforgeListener;
7use crate::handler::function::module::zen::ZenListener;
8use crate::handler::function::FunctionHandler;
9use crate::handler::function_v1;
10use crate::handler::function_v1::runtime::create_runtime;
11use crate::handler::node::{NodeRequest, PartialTraceError};
12use crate::handler::table::zen::DecisionTableHandler;
13use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph};
14use crate::loader::DecisionLoader;
15use crate::model::{DecisionContent, DecisionNodeKind, FunctionNodeContent};
16use crate::util::validator_cache::ValidatorCache;
17use crate::{EvaluationError, NodeError};
18use ahash::{HashMap, HashMapExt};
19use anyhow::anyhow;
20use petgraph::algo::is_cyclic_directed;
21use serde::ser::SerializeMap;
22use serde::{Deserialize, Serialize, Serializer};
23use serde_json::Value;
24use std::hash::{DefaultHasher, Hash, Hasher};
25use std::rc::Rc;
26use std::sync::Arc;
27use std::time::Instant;
28use thiserror::Error;
29use moduforge_rules_expression::variable::Variable;
30
/// Executable form of a [`DecisionContent`]: the node/edge graph together with
/// the collaborators (loader, custom-node adapter, function runtime, schema
/// validator cache) needed to evaluate it.
pub struct DecisionGraph<
    L: DecisionLoader + 'static,
    A: CustomNodeAdapter + 'static,
> {
    /// Copy of the graph exactly as built by `try_new`; `reset_graph`
    /// restores `graph` from it.
    initial_graph: StableDiDecisionGraph,
    /// Working graph; it is handed to the `GraphWalker` as `&mut` during
    /// `evaluate`, so it may differ from `initial_graph` afterwards.
    graph: StableDiDecisionGraph,
    /// Adapter that evaluates custom nodes; also passed to nested decisions
    /// and to the function runtime's `ZenListener`.
    adapter: Arc<A>,
    /// Loader passed to `DecisionHandler` (decision nodes) and to the function
    /// runtime's `ZenListener`.
    loader: Arc<L>,
    /// When `true`, per-node traces are collected and returned.
    trace: bool,
    /// Depth limit: `evaluate` fails with `DepthLimitExceeded` once
    /// `iteration` reaches this value.
    max_depth: u8,
    /// Current evaluation depth of this graph (compared against `max_depth`
    /// and forwarded in every `NodeRequest`).
    iteration: u8,
    /// Function runtime, created lazily by `get_or_insert_function` or
    /// injected through `with_function`.
    runtime: Option<Rc<Function>>,
    /// Cache of compiled JSON-schema validators, keyed by schema hash.
    validator_cache: ValidatorCache,
}
56
/// Everything needed to build a [`DecisionGraph`] with `DecisionGraph::try_new`.
pub struct DecisionGraphConfig<
    L: DecisionLoader + 'static,
    A: CustomNodeAdapter + 'static,
> {
    /// Loader used for nested decision nodes and by the function runtime.
    pub loader: Arc<L>,
    /// Adapter used to evaluate custom nodes.
    pub adapter: Arc<A>,
    /// Decision definition (nodes and edges) the graph is built from.
    pub content: Arc<DecisionContent>,
    /// When `true`, per-node traces are collected and returned in the response.
    pub trace: bool,
    /// Current evaluation depth of the graph being built.
    pub iteration: u8,
    /// Depth limit; evaluation fails once `iteration` reaches it.
    pub max_depth: u8,
    /// Validator cache to share; `None` falls back to `ValidatorCache::default()`.
    pub validator_cache: Option<ValidatorCache>,
}
78
79impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static>
80 DecisionGraph<L, A>
81{
82 pub fn try_new(
90 config: DecisionGraphConfig<L, A>
91 ) -> Result<Self, DecisionGraphValidationError> {
92 let content = config.content;
93 let mut graph = StableDiDecisionGraph::new();
94 let mut index_map = HashMap::new();
95
96 for node in &content.nodes {
98 let node_id = node.id.clone();
99 let node_index = graph.add_node(node.clone());
100 index_map.insert(node_id, node_index);
101 }
102
103 for (_, edge) in content.edges.iter().enumerate() {
105 let source_index =
106 index_map.get(&edge.source_id).ok_or_else(|| {
107 DecisionGraphValidationError::MissingNode(
108 edge.source_id.to_string(),
109 )
110 })?;
111
112 let target_index =
113 index_map.get(&edge.target_id).ok_or_else(|| {
114 DecisionGraphValidationError::MissingNode(
115 edge.target_id.to_string(),
116 )
117 })?;
118
119 graph.add_edge(
120 source_index.clone(),
121 target_index.clone(),
122 edge.clone(),
123 );
124 }
125
126 Ok(Self {
127 initial_graph: graph.clone(),
128 graph,
129 iteration: config.iteration,
130 trace: config.trace,
131 loader: config.loader,
132 adapter: config.adapter,
133 max_depth: config.max_depth,
134 validator_cache: config.validator_cache.unwrap_or_default(),
135 runtime: None,
136 })
137 }
138
139 pub(crate) fn with_function(
141 mut self,
142 runtime: Option<Rc<Function>>,
143 ) -> Self {
144 self.runtime = runtime;
145 self
146 }
147
148 pub(crate) fn reset_graph(&mut self) {
150 self.graph = self.initial_graph.clone();
151 }
152
153 async fn get_or_insert_function(&mut self) -> anyhow::Result<Rc<Function>> {
155 if let Some(function) = &self.runtime {
156 return Ok(function.clone());
157 }
158
159 let function = Function::create(FunctionConfig {
161 listeners: Some(vec![
162 Box::new(ModuforgeListener {}),
163 Box::new(ConsoleListener),
164 Box::new(ZenListener {
165 loader: self.loader.clone(),
166 adapter: self.adapter.clone(),
167 }),
168 ]),
169 })
170 .await
171 .map_err(|err| anyhow!(err.to_string()))?;
172 let rc_function = Rc::new(function);
173 self.runtime.replace(rc_function.clone());
174
175 Ok(rc_function)
176 }
177
178 pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
184 let input_count = self.input_node_count();
185 if input_count != 1 {
186 return Err(DecisionGraphValidationError::InvalidInputCount(
187 input_count as u32,
188 ));
189 }
190
191 if is_cyclic_directed(&self.graph) {
192 return Err(DecisionGraphValidationError::CyclicGraph);
193 }
194
195 Ok(())
196 }
197
198 fn input_node_count(&self) -> usize {
200 self.graph
201 .node_weights()
202 .filter(|weight| {
203 matches!(
204 weight.kind,
205 DecisionNodeKind::InputNode { content: _ }
206 )
207 })
208 .count()
209 }
210
211 pub async fn evaluate(
219 &mut self,
220 context: Variable,
221 ) -> Result<DecisionGraphResponse, NodeError> {
222 let root_start = Instant::now();
223
224 self.validate().map_err(|e| NodeError {
226 node_id: "".to_string(),
227 source: anyhow!(e),
228 trace: None,
229 })?;
230
231 if self.iteration >= self.max_depth {
233 return Err(NodeError {
234 node_id: "".to_string(),
235 source: anyhow!(EvaluationError::DepthLimitExceeded),
236 trace: None,
237 });
238 }
239
240 let mut walker = GraphWalker::new(&self.graph);
242 let mut node_traces = self.trace.then(|| HashMap::default());
243
244 while let Some(nid) = walker.next(
246 &mut self.graph,
247 self.trace.then_some(|mut trace: DecisionGraphTrace| {
248 if let Some(nt) = &mut node_traces {
249 trace.order = nt.len() as u32;
250 nt.insert(trace.id.clone(), trace);
251 };
252 }),
253 ) {
254 if let Some(_) = walker.get_node_data(nid) {
256 continue;
257 }
258
259 let node = (&self.graph[nid]).clone();
260 let start = Instant::now();
261
262 macro_rules! trace {
264 ({ $($field:ident: $value:expr),* $(,)? }) => {
265 if let Some(nt) = &mut node_traces {
266 nt.insert(
267 node.id.clone(),
268 DecisionGraphTrace {
269 name: node.name.clone(),
270 id: node.id.clone(),
271 performance: Some(format!("{:.1?}", start.elapsed())),
272 order: nt.len() as u32,
273 $($field: $value,)*
274 }
275 );
276 }
277 };
278 }
279
280 match &node.kind {
282 DecisionNodeKind::InputNode { content } => {
284 trace!({
285 input: Variable::Null,
286 output: context.clone(),
287 trace_data: None,
288 });
289
290 if let Some(json_schema) = content
292 .schema
293 .as_ref()
294 .map(|s| serde_json::from_str::<Value>(&s).ok())
295 .flatten()
296 {
297 let validator_key =
298 create_validator_cache_key(&json_schema);
299 let validator = self
300 .validator_cache
301 .get_or_insert(validator_key, &json_schema)
302 .await
303 .map_err(|e| NodeError {
304 source: e.into(),
305 node_id: node.id.clone(),
306 trace: error_trace(&node_traces),
307 })?;
308
309 let context_json = context.to_value();
310 validator.validate(&context_json).map_err(|e| {
311 NodeError {
312 source: anyhow!(
313 serde_json::to_value(Into::<
314 Box<EvaluationError>,
315 >::into(
316 e
317 ))
318 .unwrap_or_default()
319 ),
320 node_id: node.id.clone(),
321 trace: error_trace(&node_traces),
322 }
323 })?;
324 }
325
326 walker.set_node_data(nid, context.clone());
327 },
328 DecisionNodeKind::OutputNode { content } => {
330 let incoming_data =
331 walker.incoming_node_data(&self.graph, nid, false);
332
333 trace!({
334 input: incoming_data.clone(),
335 output: Variable::Null,
336 trace_data: None,
337 });
338
339 if let Some(json_schema) = content
341 .schema
342 .as_ref()
343 .map(|s| serde_json::from_str::<Value>(&s).ok())
344 .flatten()
345 {
346 let validator_key =
347 create_validator_cache_key(&json_schema);
348 let validator = self
349 .validator_cache
350 .get_or_insert(validator_key, &json_schema)
351 .await
352 .map_err(|e| NodeError {
353 source: e.into(),
354 node_id: node.id.clone(),
355 trace: error_trace(&node_traces),
356 })?;
357
358 let incoming_data_json = incoming_data.to_value();
359 validator.validate(&incoming_data_json).map_err(
360 |e| NodeError {
361 source: anyhow!(
362 serde_json::to_value(Into::<
363 Box<EvaluationError>,
364 >::into(
365 e
366 ))
367 .unwrap_or_default()
368 ),
369 node_id: node.id.clone(),
370 trace: error_trace(&node_traces),
371 },
372 )?;
373 }
374
375 return Ok(DecisionGraphResponse {
376 result: incoming_data,
377 performance: format!("{:.1?}", root_start.elapsed()),
378 trace: node_traces,
379 });
380 },
381 DecisionNodeKind::SwitchNode { .. } => {
383 let input_data =
384 walker.incoming_node_data(&self.graph, nid, false);
385 walker.set_node_data(nid, input_data);
386 },
387 DecisionNodeKind::FunctionNode { content } => {
389 let function = self
390 .get_or_insert_function()
391 .await
392 .map_err(|e| NodeError {
393 source: e.into(),
394 node_id: node.id.clone(),
395 trace: error_trace(&node_traces),
396 })?;
397
398 let node_request = NodeRequest {
399 node: node.clone(),
400 iteration: self.iteration,
401 input: walker.incoming_node_data(
402 &self.graph,
403 nid,
404 true,
405 ),
406 };
407
408 let res = match content {
410 FunctionNodeContent::Version2(_) => {
411 FunctionHandler::new(
412 function,
413 self.trace,
414 self.iteration,
415 self.max_depth,
416 )
417 .handle(node_request.clone())
418 .await
419 .map_err(|e| {
420 if let Some(detailed_err) =
421 e.downcast_ref::<PartialTraceError>()
422 {
423 trace!({
424 input: node_request.input.clone(),
425 output: Variable::Null,
426 trace_data: detailed_err.trace.clone(),
427 });
428 }
429
430 NodeError {
431 source: e.into(),
432 node_id: node.id.clone(),
433 trace: error_trace(&node_traces),
434 }
435 })?
436 },
437 FunctionNodeContent::Version1(_) => {
438 let runtime =
439 create_runtime().map_err(|e| NodeError {
440 source: e.into(),
441 node_id: node.id.clone(),
442 trace: error_trace(&node_traces),
443 })?;
444
445 function_v1::FunctionHandler::new(
446 self.trace, runtime,
447 )
448 .handle(node_request.clone())
449 .await
450 .map_err(|e| {
451 NodeError {
452 source: e.into(),
453 node_id: node.id.clone(),
454 trace: error_trace(&node_traces),
455 }
456 })?
457 },
458 };
459
460 node_request.input.dot_remove("$nodes");
461 res.output.dot_remove("$nodes");
462
463 trace!({
464 input: node_request.input,
465 output: res.output.clone(),
466 trace_data: res.trace_data,
467 });
468 walker.set_node_data(nid, res.output);
469 },
470 DecisionNodeKind::DecisionNode { .. } => {
472 let node_request = NodeRequest {
473 node: node.clone(),
474 iteration: self.iteration,
475 input: walker.incoming_node_data(
476 &self.graph,
477 nid,
478 true,
479 ),
480 };
481
482 let res = DecisionHandler::new(
483 self.trace,
484 self.max_depth,
485 self.loader.clone(),
486 self.adapter.clone(),
487 self.runtime.clone(),
488 self.validator_cache.clone(),
489 )
490 .handle(node_request.clone())
491 .await
492 .map_err(|e| NodeError {
493 source: e.into(),
494 node_id: node.id.to_string(),
495 trace: error_trace(&node_traces),
496 })?;
497
498 node_request.input.dot_remove("$nodes");
499 res.output.dot_remove("$nodes");
500
501 trace!({
502 input: node_request.input,
503 output: res.output.clone(),
504 trace_data: res.trace_data,
505 });
506 walker.set_node_data(nid, res.output);
507 },
508 DecisionNodeKind::DecisionTableNode { .. } => {
510 let node_request = NodeRequest {
511 node: node.clone(),
512 iteration: self.iteration,
513 input: walker.incoming_node_data(
514 &self.graph,
515 nid,
516 true,
517 ),
518 };
519
520 let res = DecisionTableHandler::new(self.trace)
521 .handle(node_request.clone())
522 .await
523 .map_err(|e| NodeError {
524 node_id: node.id.clone(),
525 source: e.into(),
526 trace: error_trace(&node_traces),
527 })?;
528
529 node_request.input.dot_remove("$nodes");
530 res.output.dot_remove("$nodes");
531 res.output.dot_remove("$");
532
533 trace!({
534 input: node_request.input,
535 output: res.output.clone(),
536 trace_data: res.trace_data,
537 });
538 walker.set_node_data(nid, res.output);
539 },
540 DecisionNodeKind::ExpressionNode { .. } => {
542 let node_request = NodeRequest {
543 node: node.clone(),
544 iteration: self.iteration,
545 input: walker.incoming_node_data(
546 &self.graph,
547 nid,
548 true,
549 ),
550 };
551
552 let res = ExpressionHandler::new(self.trace)
553 .handle(node_request.clone())
554 .await
555 .map_err(|e| {
556 if let Some(detailed_err) =
557 e.downcast_ref::<PartialTraceError>()
558 {
559 trace!({
560 input: node_request.input.clone(),
561 output: Variable::Null,
562 trace_data: detailed_err.trace.clone(),
563 });
564 }
565
566 NodeError {
567 node_id: node.id.clone(),
568 source: e.into(),
569 trace: error_trace(&node_traces),
570 }
571 })?;
572
573 node_request.input.dot_remove("$nodes");
574 res.output.dot_remove("$nodes");
575
576 trace!({
577 input: node_request.input,
578 output: res.output.clone(),
579 trace_data: res.trace_data,
580 });
581 walker.set_node_data(nid, res.output);
582 },
583 DecisionNodeKind::CustomNode { .. } => {
585 let node_request = NodeRequest {
586 node: node.clone(),
587 iteration: self.iteration,
588 input: walker.incoming_node_data(
589 &self.graph,
590 nid,
591 true,
592 ),
593 };
594
595 let res = self
596 .adapter
597 .handle(
598 CustomNodeRequest::try_from(node_request.clone())
599 .unwrap(),
600 )
601 .await
602 .map_err(|e| NodeError {
603 node_id: node.id.clone(),
604 source: e.into(),
605 trace: error_trace(&node_traces),
606 })?;
607
608 node_request.input.dot_remove("$nodes");
609 res.output.dot_remove("$nodes");
610
611 trace!({
612 input: node_request.input,
613 output: res.output.clone(),
614 trace_data: res.trace_data,
615 });
616 walker.set_node_data(nid, res.output);
617 },
618 }
619 }
620
621 Ok(DecisionGraphResponse {
623 result: walker.ending_variables(&self.graph),
624 performance: format!("{:.1?}", root_start.elapsed()),
625 trace: node_traces,
626 })
627 }
628}
629
/// Structural problems detected while building or validating a
/// [`DecisionGraph`]. The display messages are in Chinese; English
/// equivalents are given on each variant.
#[derive(Debug, Error)]
pub enum DecisionGraphValidationError {
    /// The graph does not have exactly one input node; holds the count found.
    /// Display: "invalid input node count: {0}".
    #[error("无效的输入节点数量: {0}")]
    InvalidInputCount(u32),

    /// Invalid number of output nodes; holds the count found.
    /// Display: "invalid output node count: {0}".
    #[error("无效的输出节点数量: {0}")]
    InvalidOutputCount(u32),

    /// The graph contains a directed cycle.
    /// Display: "cyclic dependency detected".
    #[error("检测到循环依赖")]
    CyclicGraph,

    /// An edge references a node id that does not exist; holds that id.
    /// Display: "missing node: {0}".
    #[error("节点缺失: {0}")]
    MissingNode(String),
}
649
650impl Serialize for DecisionGraphValidationError {
652 fn serialize<S>(
653 &self,
654 serializer: S,
655 ) -> Result<S::Ok, S::Error>
656 where
657 S: Serializer,
658 {
659 let mut map = serializer.serialize_map(None)?;
660
661 match &self {
662 DecisionGraphValidationError::InvalidInputCount(count) => {
663 map.serialize_entry("type", "invalidInputCount")?;
664 map.serialize_entry("nodeCount", count)?;
665 },
666 DecisionGraphValidationError::InvalidOutputCount(count) => {
667 map.serialize_entry("type", "invalidOutputCount")?;
668 map.serialize_entry("nodeCount", count)?;
669 },
670 DecisionGraphValidationError::MissingNode(node_id) => {
671 map.serialize_entry("type", "missingNode")?;
672 map.serialize_entry("nodeId", node_id)?;
673 },
674 DecisionGraphValidationError::CyclicGraph => {
675 map.serialize_entry("type", "cyclicGraph")?;
676 },
677 }
678
679 map.end()
680 }
681}
682
/// Result of `DecisionGraph::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphResponse {
    /// Total evaluation time, formatted with `{:.1?}`.
    pub performance: String,
    /// Input of the first output node reached, or the walker's ending
    /// variables when evaluation finishes without reaching an output node.
    pub result: Variable,
    /// Per-node traces keyed by node id; `None` when tracing is disabled, in
    /// which case the field is omitted from the serialized form.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace: Option<HashMap<String, DecisionGraphTrace>>,
}
695
/// Trace entry for a single node visited during evaluation
/// (serialized with camelCase field names, e.g. `traceData`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphTrace {
    /// Data the node received; `Variable::Null` for the input node.
    pub input: Variable,
    /// Data the node produced; `Variable::Null` for the output node and for
    /// failed nodes.
    pub output: Variable,
    /// Display name of the node.
    pub name: String,
    /// Id of the node; also the key of this entry in the trace map.
    pub id: String,
    /// Time spent on the node, formatted with `{:.1?}`; may be `None` for
    /// traces not produced by `evaluate` itself (TODO confirm with GraphWalker).
    pub performance: Option<String>,
    /// Handler-specific trace details, if any.
    pub trace_data: Option<Value>,
    /// Number of trace entries recorded before this one.
    pub order: u32,
}
715
716pub(crate) fn error_trace(
718 trace: &Option<HashMap<String, DecisionGraphTrace>>
719) -> Option<Value> {
720 trace.as_ref().map(|s| serde_json::to_value(s).ok()).flatten()
721}
722
723fn create_validator_cache_key(content: &Value) -> u64 {
725 let mut hasher = DefaultHasher::new();
726 content.hash(&mut hasher);
727 hasher.finish()
728}