mf_engine/handler/graph.rs

1use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
2use crate::handler::decision::DecisionHandler;
3use crate::handler::expression::ExpressionHandler;
4use crate::handler::function::function::{Function, FunctionConfig};
5use crate::handler::function::module::console::ConsoleListener;
6use crate::handler::function::module::zen::ZenListener;
7use crate::handler::function::FunctionHandler;
8use crate::handler::function_v1;
9use crate::handler::function_v1::runtime::create_runtime;
10use crate::handler::node::{NodeRequest, PartialTraceError};
11use crate::handler::table::zen::DecisionTableHandler;
12use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph};
13use crate::loader::DecisionLoader;
14use crate::model::{DecisionContent, DecisionNodeKind, FunctionNodeContent};
15use crate::util::validator_cache::ValidatorCache;
16use crate::{EvaluationError, NodeError};
17use ahash::{HashMap, HashMapExt};
18use anyhow::anyhow;
19use petgraph::algo::is_cyclic_directed;
20use serde::ser::SerializeMap;
21use serde::{Deserialize, Serialize, Serializer};
22use serde_json::Value;
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::rc::Rc;
25use std::sync::Arc;
26use std::time::Instant;
27use thiserror::Error;
28use mf_expression::variable::Variable;
29use crate::handler::function::module::mf::ModuforgeListener;
30
/// An executable decision graph built from [`DecisionContent`].
///
/// `L` is the decision loader and `A` the custom-node adapter; both are held
/// behind `Arc` and cloned into the handlers and runtime listeners created
/// during evaluation.
pub struct DecisionGraph<
    L: DecisionLoader + 'static,
    A: CustomNodeAdapter + 'static,
> {
    /// Copy of the graph as built by `try_new`; `reset_graph` restores
    /// `graph` from it.
    initial_graph: StableDiDecisionGraph,
    /// Working graph; handed mutably to the walker during `evaluate`.
    graph: StableDiDecisionGraph,
    /// Handler for `CustomNode` nodes; also passed to decision handlers and
    /// the function runtime's `ZenListener`.
    adapter: Arc<A>,
    /// Loader passed to decision-node handlers and the function runtime's
    /// `ZenListener`.
    loader: Arc<L>,
    /// Whether per-node traces are collected during evaluation.
    trace: bool,
    /// `evaluate` fails with `DepthLimitExceeded` once `iteration` reaches
    /// this value.
    max_depth: u8,
    /// Current nesting level, forwarded to node requests and handlers.
    /// NOTE(review): presumably incremented by nested decision evaluation —
    /// confirm in `DecisionHandler`.
    iteration: u8,
    /// Function runtime, created lazily by `get_or_insert_function` (or set
    /// via `with_function`) and shared with nested decision handlers.
    runtime: Option<Rc<Function>>,
    /// Cache of compiled JSON-schema validators used by input/output nodes.
    validator_cache: ValidatorCache,
}
45
/// Inputs for [`DecisionGraph::try_new`].
pub struct DecisionGraphConfig<
    L: DecisionLoader + 'static,
    A: CustomNodeAdapter + 'static,
> {
    /// Loader handed to decision-node handlers and the function runtime.
    pub loader: Arc<L>,
    /// Handler for custom nodes.
    pub adapter: Arc<A>,
    /// Nodes and edges the graph is built from.
    pub content: Arc<DecisionContent>,
    /// Collect per-node traces during evaluation.
    pub trace: bool,
    /// Starting nesting level of this graph.
    pub iteration: u8,
    /// Nesting level at which evaluation is refused.
    pub max_depth: u8,
    /// Validator cache to share; a default cache is created when `None`.
    pub validator_cache: Option<ValidatorCache>,
}
58
59impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static>
60    DecisionGraph<L, A>
61{
62    pub fn try_new(
63        config: DecisionGraphConfig<L, A>
64    ) -> Result<Self, DecisionGraphValidationError> {
65        let content = config.content;
66        let mut graph = StableDiDecisionGraph::new();
67        let mut index_map = HashMap::new();
68
69        for node in &content.nodes {
70            let node_id = node.id.clone();
71            let node_index = graph.add_node(node.clone());
72
73            index_map.insert(node_id, node_index);
74        }
75
76        for (_, edge) in content.edges.iter().enumerate() {
77            let source_index =
78                index_map.get(&edge.source_id).ok_or_else(|| {
79                    DecisionGraphValidationError::MissingNode(
80                        edge.source_id.to_string(),
81                    )
82                })?;
83
84            let target_index =
85                index_map.get(&edge.target_id).ok_or_else(|| {
86                    DecisionGraphValidationError::MissingNode(
87                        edge.target_id.to_string(),
88                    )
89                })?;
90
91            graph.add_edge(
92                source_index.clone(),
93                target_index.clone(),
94                edge.clone(),
95            );
96        }
97
98        Ok(Self {
99            initial_graph: graph.clone(),
100            graph,
101            iteration: config.iteration,
102            trace: config.trace,
103            loader: config.loader,
104            adapter: config.adapter,
105            max_depth: config.max_depth,
106            validator_cache: config.validator_cache.unwrap_or_default(),
107            runtime: None,
108        })
109    }
110
111    pub(crate) fn with_function(
112        mut self,
113        runtime: Option<Rc<Function>>,
114    ) -> Self {
115        self.runtime = runtime;
116        self
117    }
118
119    pub(crate) fn reset_graph(&mut self) {
120        self.graph = self.initial_graph.clone();
121    }
122
123    async fn get_or_insert_function(&mut self) -> anyhow::Result<Rc<Function>> {
124        if let Some(function) = &self.runtime {
125            return Ok(function.clone());
126        }
127
128        let function = Function::create(FunctionConfig {
129            listeners: Some(vec![
130                Box::new(ConsoleListener),
131                Box::new(ZenListener {
132                    loader: self.loader.clone(),
133                    adapter: self.adapter.clone(),
134                }),
135                Box::new(ModuforgeListener {}),
136            ]),
137        })
138        .await
139        .map_err(|err| anyhow!(err.to_string()))?;
140        let rc_function = Rc::new(function);
141        self.runtime.replace(rc_function.clone());
142
143        Ok(rc_function)
144    }
145
146    pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
147        let input_count = self.input_node_count();
148        if input_count != 1 {
149            return Err(DecisionGraphValidationError::InvalidInputCount(
150                input_count as u32,
151            ));
152        }
153
154        if is_cyclic_directed(&self.graph) {
155            return Err(DecisionGraphValidationError::CyclicGraph);
156        }
157
158        Ok(())
159    }
160
161    fn input_node_count(&self) -> usize {
162        self.graph
163            .node_weights()
164            .filter(|weight| {
165                matches!(
166                    weight.kind,
167                    DecisionNodeKind::InputNode { content: _ }
168                )
169            })
170            .count()
171    }
172
173    pub async fn evaluate(
174        &mut self,
175        context: Variable,
176    ) -> Result<DecisionGraphResponse, NodeError> {
177        let root_start = Instant::now();
178
179        self.validate().map_err(|e| NodeError {
180            node_id: "".to_string(),
181            source: anyhow!(e),
182            trace: None,
183        })?;
184
185        if self.iteration >= self.max_depth {
186            return Err(NodeError {
187                node_id: "".to_string(),
188                source: anyhow!(EvaluationError::DepthLimitExceeded),
189                trace: None,
190            });
191        }
192
193        let mut walker = GraphWalker::new(&self.graph);
194        let mut node_traces = self.trace.then(|| HashMap::default());
195
196        while let Some(nid) = walker.next(
197            &mut self.graph,
198            self.trace.then_some(|mut trace: DecisionGraphTrace| {
199                if let Some(nt) = &mut node_traces {
200                    trace.order = nt.len() as u32;
201                    nt.insert(trace.id.clone(), trace);
202                };
203            }),
204        ) {
205            if let Some(_) = walker.get_node_data(nid) {
206                continue;
207            }
208
209            let node = (&self.graph[nid]).clone();
210            let start = Instant::now();
211
212            macro_rules! trace {
213                ({ $($field:ident: $value:expr),* $(,)? }) => {
214                    if let Some(nt) = &mut node_traces {
215                        nt.insert(
216                            node.id.clone(),
217                            DecisionGraphTrace {
218                                name: node.name.clone(),
219                                id: node.id.clone(),
220                                performance: Some(format!("{:.1?}", start.elapsed())),
221                                order: nt.len() as u32,
222                                $($field: $value,)*
223                            }
224                        );
225                    }
226                };
227            }
228
229            match &node.kind {
230                DecisionNodeKind::InputNode { content } => {
231                    trace!({
232                        input: Variable::Null,
233                        output: context.clone(),
234                        trace_data: None,
235                    });
236
237                    if let Some(json_schema) = content
238                        .schema
239                        .as_ref()
240                        .map(|s| serde_json::from_str::<Value>(&s).ok())
241                        .flatten()
242                    {
243                        let validator_key =
244                            create_validator_cache_key(&json_schema);
245                        let validator = self
246                            .validator_cache
247                            .get_or_insert(validator_key, &json_schema)
248                            .await
249                            .map_err(|e| NodeError {
250                                source: e.into(),
251                                node_id: node.id.clone(),
252                                trace: error_trace(&node_traces),
253                            })?;
254
255                        let context_json = context.to_value();
256                        validator.validate(&context_json).map_err(|e| {
257                            NodeError {
258                                source: anyhow!(
259                                    serde_json::to_value(Into::<
260                                        Box<EvaluationError>,
261                                    >::into(
262                                        e
263                                    ))
264                                    .unwrap_or_default()
265                                ),
266                                node_id: node.id.clone(),
267                                trace: error_trace(&node_traces),
268                            }
269                        })?;
270                    }
271
272                    walker.set_node_data(nid, context.clone());
273                },
274                DecisionNodeKind::OutputNode { content } => {
275                    let incoming_data =
276                        walker.incoming_node_data(&self.graph, nid, false);
277
278                    trace!({
279                        input: incoming_data.clone(),
280                        output: Variable::Null,
281                        trace_data: None,
282                    });
283
284                    if let Some(json_schema) = content
285                        .schema
286                        .as_ref()
287                        .map(|s| serde_json::from_str::<Value>(&s).ok())
288                        .flatten()
289                    {
290                        let validator_key =
291                            create_validator_cache_key(&json_schema);
292                        let validator = self
293                            .validator_cache
294                            .get_or_insert(validator_key, &json_schema)
295                            .await
296                            .map_err(|e| NodeError {
297                                source: e.into(),
298                                node_id: node.id.clone(),
299                                trace: error_trace(&node_traces),
300                            })?;
301
302                        let incoming_data_json = incoming_data.to_value();
303                        validator.validate(&incoming_data_json).map_err(
304                            |e| NodeError {
305                                source: anyhow!(
306                                    serde_json::to_value(Into::<
307                                        Box<EvaluationError>,
308                                    >::into(
309                                        e
310                                    ))
311                                    .unwrap_or_default()
312                                ),
313                                node_id: node.id.clone(),
314                                trace: error_trace(&node_traces),
315                            },
316                        )?;
317                    }
318
319                    return Ok(DecisionGraphResponse {
320                        result: incoming_data,
321                        performance: format!("{:.1?}", root_start.elapsed()),
322                        trace: node_traces,
323                    });
324                },
325                DecisionNodeKind::SwitchNode { .. } => {
326                    let input_data =
327                        walker.incoming_node_data(&self.graph, nid, false);
328
329                    walker.set_node_data(nid, input_data);
330                },
331                DecisionNodeKind::FunctionNode { content } => {
332                    let function = self
333                        .get_or_insert_function()
334                        .await
335                        .map_err(|e| NodeError {
336                            source: e.into(),
337                            node_id: node.id.clone(),
338                            trace: error_trace(&node_traces),
339                        })?;
340
341                    let node_request = NodeRequest {
342                        node: node.clone(),
343                        iteration: self.iteration,
344                        input: walker.incoming_node_data(
345                            &self.graph,
346                            nid,
347                            true,
348                        ),
349                    };
350                    let res = match content {
351                        FunctionNodeContent::Version2(_) => {
352                            FunctionHandler::new(
353                                function,
354                                self.trace,
355                                self.iteration,
356                                self.max_depth,
357                            )
358                            .handle(node_request.clone())
359                            .await
360                            .map_err(|e| {
361                                if let Some(detailed_err) =
362                                    e.downcast_ref::<PartialTraceError>()
363                                {
364                                    trace!({
365                                        input: node_request.input.clone(),
366                                        output: Variable::Null,
367                                        trace_data: detailed_err.trace.clone(),
368                                    });
369                                }
370
371                                NodeError {
372                                    source: e.into(),
373                                    node_id: node.id.clone(),
374                                    trace: error_trace(&node_traces),
375                                }
376                            })?
377                        },
378                        FunctionNodeContent::Version1(_) => {
379                            let runtime =
380                                create_runtime().map_err(|e| NodeError {
381                                    source: e.into(),
382                                    node_id: node.id.clone(),
383                                    trace: error_trace(&node_traces),
384                                })?;
385
386                            function_v1::FunctionHandler::new(
387                                self.trace, runtime,
388                            )
389                            .handle(node_request.clone())
390                            .await
391                            .map_err(|e| {
392                                NodeError {
393                                    source: e.into(),
394                                    node_id: node.id.clone(),
395                                    trace: error_trace(&node_traces),
396                                }
397                            })?
398                        },
399                    };
400
401                    node_request.input.dot_remove("$nodes");
402                    res.output.dot_remove("$nodes");
403
404                    trace!({
405                        input: node_request.input,
406                        output: res.output.clone(),
407                        trace_data: res.trace_data,
408                    });
409                    walker.set_node_data(nid, res.output);
410                },
411                DecisionNodeKind::DecisionNode { .. } => {
412                    let node_request = NodeRequest {
413                        node: node.clone(),
414                        iteration: self.iteration,
415                        input: walker.incoming_node_data(
416                            &self.graph,
417                            nid,
418                            true,
419                        ),
420                    };
421
422                    let res = DecisionHandler::new(
423                        self.trace,
424                        self.max_depth,
425                        self.loader.clone(),
426                        self.adapter.clone(),
427                        self.runtime.clone(),
428                        self.validator_cache.clone(),
429                    )
430                    .handle(node_request.clone())
431                    .await
432                    .map_err(|e| NodeError {
433                        source: e.into(),
434                        node_id: node.id.to_string(),
435                        trace: error_trace(&node_traces),
436                    })?;
437
438                    node_request.input.dot_remove("$nodes");
439                    res.output.dot_remove("$nodes");
440
441                    trace!({
442                        input: node_request.input,
443                        output: res.output.clone(),
444                        trace_data: res.trace_data,
445                    });
446                    walker.set_node_data(nid, res.output);
447                },
448                DecisionNodeKind::DecisionTableNode { .. } => {
449                    let node_request = NodeRequest {
450                        node: node.clone(),
451                        iteration: self.iteration,
452                        input: walker.incoming_node_data(
453                            &self.graph,
454                            nid,
455                            true,
456                        ),
457                    };
458
459                    let res = DecisionTableHandler::new(self.trace)
460                        .handle(node_request.clone())
461                        .await
462                        .map_err(|e| NodeError {
463                            node_id: node.id.clone(),
464                            source: e.into(),
465                            trace: error_trace(&node_traces),
466                        })?;
467
468                    node_request.input.dot_remove("$nodes");
469                    res.output.dot_remove("$nodes");
470
471                    trace!({
472                        input: node_request.input,
473                        output: res.output.clone(),
474                        trace_data: res.trace_data,
475                    });
476                    walker.set_node_data(nid, res.output);
477                },
478                DecisionNodeKind::ExpressionNode { .. } => {
479                    let node_request = NodeRequest {
480                        node: node.clone(),
481                        iteration: self.iteration,
482                        input: walker.incoming_node_data(
483                            &self.graph,
484                            nid,
485                            true,
486                        ),
487                    };
488
489                    let res = ExpressionHandler::new(self.trace)
490                        .handle(node_request.clone())
491                        .await
492                        .map_err(|e| {
493                            if let Some(detailed_err) =
494                                e.downcast_ref::<PartialTraceError>()
495                            {
496                                trace!({
497                                    input: node_request.input.clone(),
498                                    output: Variable::Null,
499                                    trace_data: detailed_err.trace.clone(),
500                                });
501                            }
502
503                            NodeError {
504                                node_id: node.id.clone(),
505                                source: e.into(),
506                                trace: error_trace(&node_traces),
507                            }
508                        })?;
509
510                    node_request.input.dot_remove("$nodes");
511                    res.output.dot_remove("$nodes");
512
513                    trace!({
514                        input: node_request.input,
515                        output: res.output.clone(),
516                        trace_data: res.trace_data,
517                    });
518                    walker.set_node_data(nid, res.output);
519                },
520                DecisionNodeKind::CustomNode { .. } => {
521                    let node_request = NodeRequest {
522                        node: node.clone(),
523                        iteration: self.iteration,
524                        input: walker.incoming_node_data(
525                            &self.graph,
526                            nid,
527                            true,
528                        ),
529                    };
530
531                    let res = self
532                        .adapter
533                        .handle(
534                            CustomNodeRequest::try_from(node_request.clone())
535                                .unwrap(),
536                        )
537                        .await
538                        .map_err(|e| NodeError {
539                            node_id: node.id.clone(),
540                            source: e.into(),
541                            trace: error_trace(&node_traces),
542                        })?;
543
544                    node_request.input.dot_remove("$nodes");
545                    res.output.dot_remove("$nodes");
546
547                    trace!({
548                        input: node_request.input,
549                        output: res.output.clone(),
550                        trace_data: res.trace_data,
551                    });
552                    walker.set_node_data(nid, res.output);
553                },
554            }
555        }
556
557        Ok(DecisionGraphResponse {
558            result: walker.ending_variables(&self.graph),
559            performance: format!("{:.1?}", root_start.elapsed()),
560            trace: node_traces,
561        })
562    }
563}
564
565#[derive(Debug, Error)]
566pub enum DecisionGraphValidationError {
567    #[error("Invalid input node count: {0}")]
568    InvalidInputCount(u32),
569
570    #[error("Invalid output node count: {0}")]
571    InvalidOutputCount(u32),
572
573    #[error("Cyclic graph detected")]
574    CyclicGraph,
575
576    #[error("Missing node")]
577    MissingNode(String),
578}
579
580impl Serialize for DecisionGraphValidationError {
581    fn serialize<S>(
582        &self,
583        serializer: S,
584    ) -> Result<S::Ok, S::Error>
585    where
586        S: Serializer,
587    {
588        let mut map = serializer.serialize_map(None)?;
589
590        match &self {
591            DecisionGraphValidationError::InvalidInputCount(count) => {
592                map.serialize_entry("type", "invalidInputCount")?;
593                map.serialize_entry("nodeCount", count)?;
594            },
595            DecisionGraphValidationError::InvalidOutputCount(count) => {
596                map.serialize_entry("type", "invalidOutputCount")?;
597                map.serialize_entry("nodeCount", count)?;
598            },
599            DecisionGraphValidationError::MissingNode(node_id) => {
600                map.serialize_entry("type", "missingNode")?;
601                map.serialize_entry("nodeId", node_id)?;
602            },
603            DecisionGraphValidationError::CyclicGraph => {
604                map.serialize_entry("type", "cyclicGraph")?;
605            },
606        }
607
608        map.end()
609    }
610}
611
/// Result of [`DecisionGraph::evaluate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphResponse {
    /// Total evaluation time, formatted with `{:.1?}`.
    pub performance: String,
    /// Incoming data of the output node that ended evaluation, or the
    /// walker's ending variables when no output node was reached.
    pub result: Variable,
    /// Per-node traces keyed by node id. `Some` only when tracing is enabled;
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace: Option<HashMap<String, DecisionGraphTrace>>,
}
620
/// Trace entry for one node, recorded during [`DecisionGraph::evaluate`]
/// when tracing is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphTrace {
    /// Data the node received. This is `Variable::Null` for input nodes; for
    /// handler-based nodes, `$nodes` is removed.
    pub input: Variable,
    /// Data the node produced. This is `Variable::Null` for output nodes; for
    /// handler-based nodes, `$nodes` is removed.
    pub output: Variable,
    /// Name of the node.
    pub name: String,
    /// Id of the node; also the key of this entry in the trace map.
    pub id: String,
    /// Time spent on the node, formatted with `{:.1?}`.
    pub performance: Option<String>,
    /// Handler-specific trace details, if the handler provided any.
    pub trace_data: Option<Value>,
    /// Number of traces already recorded when this one was inserted, i.e.
    /// the order in which nodes were traced.
    pub order: u32,
}
632
633pub(crate) fn error_trace(
634    trace: &Option<HashMap<String, DecisionGraphTrace>>
635) -> Option<Value> {
636    trace.as_ref().map(|s| serde_json::to_value(s).ok()).flatten()
637}
638
639fn create_validator_cache_key(content: &Value) -> u64 {
640    let mut hasher = DefaultHasher::new();
641    content.hash(&mut hasher);
642    hasher.finish()
643}