1use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
2use crate::handler::decision::DecisionHandler;
3use crate::handler::expression::ExpressionHandler;
4use crate::handler::function::function::{Function, FunctionConfig};
5use crate::handler::function::module::console::ConsoleListener;
6use crate::handler::function::module::zen::ZenListener;
7use crate::handler::function::FunctionHandler;
8use crate::handler::function_v1;
9use crate::handler::function_v1::runtime::create_runtime;
10use crate::handler::node::{NodeRequest, PartialTraceError};
11use crate::handler::table::zen::DecisionTableHandler;
12use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph};
13use crate::loader::DecisionLoader;
14use crate::model::{DecisionContent, DecisionNodeKind, FunctionNodeContent};
15use crate::util::validator_cache::ValidatorCache;
16use crate::{EvaluationError, NodeError};
17use ahash::{HashMap, HashMapExt};
18use anyhow::anyhow;
19use petgraph::algo::is_cyclic_directed;
20use serde::ser::SerializeMap;
21use serde::{Deserialize, Serialize, Serializer};
22use serde_json::Value;
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::rc::Rc;
25use std::sync::Arc;
26use std::time::Instant;
27use thiserror::Error;
28use mf_expression::variable::Variable;
29use crate::handler::function::module::mf::ModuforgeListener;
30
/// An executable decision graph built from [`DecisionContent`].
///
/// Holds the working graph traversed by [`GraphWalker`] during evaluation,
/// plus an untouched copy taken at construction so the working graph can be
/// restored between evaluations.
pub struct DecisionGraph<
    L: DecisionLoader + 'static,
    A: CustomNodeAdapter + 'static,
> {
    /// Snapshot of the graph as built in `try_new`; `reset_graph` copies it
    /// back into `graph`.
    initial_graph: StableDiDecisionGraph,
    /// Working graph. `evaluate` hands it mutably to the walker, which is why
    /// the pristine `initial_graph` copy is kept alongside it.
    graph: StableDiDecisionGraph,
    /// Adapter that executes custom nodes; also forwarded to nested decision
    /// handlers and the Zen function listener.
    adapter: Arc<A>,
    /// Loader forwarded to nested decision handlers and the Zen function
    /// listener.
    loader: Arc<L>,
    /// Whether per-node traces are collected during evaluation.
    trace: bool,
    /// Evaluation is rejected with `DepthLimitExceeded` once `iteration`
    /// reaches this value.
    max_depth: u8,
    /// Nesting iteration of this graph; forwarded to node handlers through
    /// `NodeRequest::iteration`.
    iteration: u8,
    /// Function runtime cached after first creation (or seeded through
    /// `with_function`).
    runtime: Option<Rc<Function>>,
    /// Cache of JSON-schema validators used by input/output nodes, keyed by a
    /// hash of the schema.
    validator_cache: ValidatorCache,
}
45
/// Construction parameters for [`DecisionGraph::try_new`].
pub struct DecisionGraphConfig<
    L: DecisionLoader + 'static,
    A: CustomNodeAdapter + 'static,
> {
    /// Loader made available to nested decisions and function listeners.
    pub loader: Arc<L>,
    /// Adapter used to execute custom nodes.
    pub adapter: Arc<A>,
    /// Decision content whose nodes and edges form the graph.
    pub content: Arc<DecisionContent>,
    /// Whether to collect per-node traces during evaluation.
    pub trace: bool,
    /// Nesting iteration the graph starts at.
    pub iteration: u8,
    /// Evaluation fails once `iteration` reaches this depth.
    pub max_depth: u8,
    /// Optional shared validator cache; a default cache is created when `None`.
    pub validator_cache: Option<ValidatorCache>,
}
58
59impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static>
60 DecisionGraph<L, A>
61{
62 pub fn try_new(
63 config: DecisionGraphConfig<L, A>
64 ) -> Result<Self, DecisionGraphValidationError> {
65 let content = config.content;
66 let mut graph = StableDiDecisionGraph::new();
67 let mut index_map = HashMap::new();
68
69 for node in &content.nodes {
70 let node_id = node.id.clone();
71 let node_index = graph.add_node(node.clone());
72
73 index_map.insert(node_id, node_index);
74 }
75
76 for (_, edge) in content.edges.iter().enumerate() {
77 let source_index =
78 index_map.get(&edge.source_id).ok_or_else(|| {
79 DecisionGraphValidationError::MissingNode(
80 edge.source_id.to_string(),
81 )
82 })?;
83
84 let target_index =
85 index_map.get(&edge.target_id).ok_or_else(|| {
86 DecisionGraphValidationError::MissingNode(
87 edge.target_id.to_string(),
88 )
89 })?;
90
91 graph.add_edge(
92 source_index.clone(),
93 target_index.clone(),
94 edge.clone(),
95 );
96 }
97
98 Ok(Self {
99 initial_graph: graph.clone(),
100 graph,
101 iteration: config.iteration,
102 trace: config.trace,
103 loader: config.loader,
104 adapter: config.adapter,
105 max_depth: config.max_depth,
106 validator_cache: config.validator_cache.unwrap_or_default(),
107 runtime: None,
108 })
109 }
110
111 pub(crate) fn with_function(
112 mut self,
113 runtime: Option<Rc<Function>>,
114 ) -> Self {
115 self.runtime = runtime;
116 self
117 }
118
119 pub(crate) fn reset_graph(&mut self) {
120 self.graph = self.initial_graph.clone();
121 }
122
123 async fn get_or_insert_function(&mut self) -> anyhow::Result<Rc<Function>> {
124 if let Some(function) = &self.runtime {
125 return Ok(function.clone());
126 }
127
128 let function = Function::create(FunctionConfig {
129 listeners: Some(vec![
130 Box::new(ConsoleListener),
131 Box::new(ZenListener {
132 loader: self.loader.clone(),
133 adapter: self.adapter.clone(),
134 }),
135 Box::new(ModuforgeListener {}),
136 ]),
137 })
138 .await
139 .map_err(|err| anyhow!(err.to_string()))?;
140 let rc_function = Rc::new(function);
141 self.runtime.replace(rc_function.clone());
142
143 Ok(rc_function)
144 }
145
146 pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
147 let input_count = self.input_node_count();
148 if input_count != 1 {
149 return Err(DecisionGraphValidationError::InvalidInputCount(
150 input_count as u32,
151 ));
152 }
153
154 if is_cyclic_directed(&self.graph) {
155 return Err(DecisionGraphValidationError::CyclicGraph);
156 }
157
158 Ok(())
159 }
160
161 fn input_node_count(&self) -> usize {
162 self.graph
163 .node_weights()
164 .filter(|weight| {
165 matches!(
166 weight.kind,
167 DecisionNodeKind::InputNode { content: _ }
168 )
169 })
170 .count()
171 }
172
173 pub async fn evaluate(
174 &mut self,
175 context: Variable,
176 ) -> Result<DecisionGraphResponse, NodeError> {
177 let root_start = Instant::now();
178
179 self.validate().map_err(|e| NodeError {
180 node_id: "".to_string(),
181 source: anyhow!(e),
182 trace: None,
183 })?;
184
185 if self.iteration >= self.max_depth {
186 return Err(NodeError {
187 node_id: "".to_string(),
188 source: anyhow!(EvaluationError::DepthLimitExceeded),
189 trace: None,
190 });
191 }
192
193 let mut walker = GraphWalker::new(&self.graph);
194 let mut node_traces = self.trace.then(|| HashMap::default());
195
196 while let Some(nid) = walker.next(
197 &mut self.graph,
198 self.trace.then_some(|mut trace: DecisionGraphTrace| {
199 if let Some(nt) = &mut node_traces {
200 trace.order = nt.len() as u32;
201 nt.insert(trace.id.clone(), trace);
202 };
203 }),
204 ) {
205 if let Some(_) = walker.get_node_data(nid) {
206 continue;
207 }
208
209 let node = (&self.graph[nid]).clone();
210 let start = Instant::now();
211
212 macro_rules! trace {
213 ({ $($field:ident: $value:expr),* $(,)? }) => {
214 if let Some(nt) = &mut node_traces {
215 nt.insert(
216 node.id.clone(),
217 DecisionGraphTrace {
218 name: node.name.clone(),
219 id: node.id.clone(),
220 performance: Some(format!("{:.1?}", start.elapsed())),
221 order: nt.len() as u32,
222 $($field: $value,)*
223 }
224 );
225 }
226 };
227 }
228
229 match &node.kind {
230 DecisionNodeKind::InputNode { content } => {
231 trace!({
232 input: Variable::Null,
233 output: context.clone(),
234 trace_data: None,
235 });
236
237 if let Some(json_schema) = content
238 .schema
239 .as_ref()
240 .map(|s| serde_json::from_str::<Value>(&s).ok())
241 .flatten()
242 {
243 let validator_key =
244 create_validator_cache_key(&json_schema);
245 let validator = self
246 .validator_cache
247 .get_or_insert(validator_key, &json_schema)
248 .await
249 .map_err(|e| NodeError {
250 source: e.into(),
251 node_id: node.id.clone(),
252 trace: error_trace(&node_traces),
253 })?;
254
255 let context_json = context.to_value();
256 validator.validate(&context_json).map_err(|e| {
257 NodeError {
258 source: anyhow!(
259 serde_json::to_value(Into::<
260 Box<EvaluationError>,
261 >::into(
262 e
263 ))
264 .unwrap_or_default()
265 ),
266 node_id: node.id.clone(),
267 trace: error_trace(&node_traces),
268 }
269 })?;
270 }
271
272 walker.set_node_data(nid, context.clone());
273 },
274 DecisionNodeKind::OutputNode { content } => {
275 let incoming_data =
276 walker.incoming_node_data(&self.graph, nid, false);
277
278 trace!({
279 input: incoming_data.clone(),
280 output: Variable::Null,
281 trace_data: None,
282 });
283
284 if let Some(json_schema) = content
285 .schema
286 .as_ref()
287 .map(|s| serde_json::from_str::<Value>(&s).ok())
288 .flatten()
289 {
290 let validator_key =
291 create_validator_cache_key(&json_schema);
292 let validator = self
293 .validator_cache
294 .get_or_insert(validator_key, &json_schema)
295 .await
296 .map_err(|e| NodeError {
297 source: e.into(),
298 node_id: node.id.clone(),
299 trace: error_trace(&node_traces),
300 })?;
301
302 let incoming_data_json = incoming_data.to_value();
303 validator.validate(&incoming_data_json).map_err(
304 |e| NodeError {
305 source: anyhow!(
306 serde_json::to_value(Into::<
307 Box<EvaluationError>,
308 >::into(
309 e
310 ))
311 .unwrap_or_default()
312 ),
313 node_id: node.id.clone(),
314 trace: error_trace(&node_traces),
315 },
316 )?;
317 }
318
319 return Ok(DecisionGraphResponse {
320 result: incoming_data,
321 performance: format!("{:.1?}", root_start.elapsed()),
322 trace: node_traces,
323 });
324 },
325 DecisionNodeKind::SwitchNode { .. } => {
326 let input_data =
327 walker.incoming_node_data(&self.graph, nid, false);
328
329 walker.set_node_data(nid, input_data);
330 },
331 DecisionNodeKind::FunctionNode { content } => {
332 let function = self
333 .get_or_insert_function()
334 .await
335 .map_err(|e| NodeError {
336 source: e.into(),
337 node_id: node.id.clone(),
338 trace: error_trace(&node_traces),
339 })?;
340
341 let node_request = NodeRequest {
342 node: node.clone(),
343 iteration: self.iteration,
344 input: walker.incoming_node_data(
345 &self.graph,
346 nid,
347 true,
348 ),
349 };
350 let res = match content {
351 FunctionNodeContent::Version2(_) => {
352 FunctionHandler::new(
353 function,
354 self.trace,
355 self.iteration,
356 self.max_depth,
357 )
358 .handle(node_request.clone())
359 .await
360 .map_err(|e| {
361 if let Some(detailed_err) =
362 e.downcast_ref::<PartialTraceError>()
363 {
364 trace!({
365 input: node_request.input.clone(),
366 output: Variable::Null,
367 trace_data: detailed_err.trace.clone(),
368 });
369 }
370
371 NodeError {
372 source: e.into(),
373 node_id: node.id.clone(),
374 trace: error_trace(&node_traces),
375 }
376 })?
377 },
378 FunctionNodeContent::Version1(_) => {
379 let runtime =
380 create_runtime().map_err(|e| NodeError {
381 source: e.into(),
382 node_id: node.id.clone(),
383 trace: error_trace(&node_traces),
384 })?;
385
386 function_v1::FunctionHandler::new(
387 self.trace, runtime,
388 )
389 .handle(node_request.clone())
390 .await
391 .map_err(|e| {
392 NodeError {
393 source: e.into(),
394 node_id: node.id.clone(),
395 trace: error_trace(&node_traces),
396 }
397 })?
398 },
399 };
400
401 node_request.input.dot_remove("$nodes");
402 res.output.dot_remove("$nodes");
403
404 trace!({
405 input: node_request.input,
406 output: res.output.clone(),
407 trace_data: res.trace_data,
408 });
409 walker.set_node_data(nid, res.output);
410 },
411 DecisionNodeKind::DecisionNode { .. } => {
412 let node_request = NodeRequest {
413 node: node.clone(),
414 iteration: self.iteration,
415 input: walker.incoming_node_data(
416 &self.graph,
417 nid,
418 true,
419 ),
420 };
421
422 let res = DecisionHandler::new(
423 self.trace,
424 self.max_depth,
425 self.loader.clone(),
426 self.adapter.clone(),
427 self.runtime.clone(),
428 self.validator_cache.clone(),
429 )
430 .handle(node_request.clone())
431 .await
432 .map_err(|e| NodeError {
433 source: e.into(),
434 node_id: node.id.to_string(),
435 trace: error_trace(&node_traces),
436 })?;
437
438 node_request.input.dot_remove("$nodes");
439 res.output.dot_remove("$nodes");
440
441 trace!({
442 input: node_request.input,
443 output: res.output.clone(),
444 trace_data: res.trace_data,
445 });
446 walker.set_node_data(nid, res.output);
447 },
448 DecisionNodeKind::DecisionTableNode { .. } => {
449 let node_request = NodeRequest {
450 node: node.clone(),
451 iteration: self.iteration,
452 input: walker.incoming_node_data(
453 &self.graph,
454 nid,
455 true,
456 ),
457 };
458
459 let res = DecisionTableHandler::new(self.trace)
460 .handle(node_request.clone())
461 .await
462 .map_err(|e| NodeError {
463 node_id: node.id.clone(),
464 source: e.into(),
465 trace: error_trace(&node_traces),
466 })?;
467
468 node_request.input.dot_remove("$nodes");
469 res.output.dot_remove("$nodes");
470
471 trace!({
472 input: node_request.input,
473 output: res.output.clone(),
474 trace_data: res.trace_data,
475 });
476 walker.set_node_data(nid, res.output);
477 },
478 DecisionNodeKind::ExpressionNode { .. } => {
479 let node_request = NodeRequest {
480 node: node.clone(),
481 iteration: self.iteration,
482 input: walker.incoming_node_data(
483 &self.graph,
484 nid,
485 true,
486 ),
487 };
488
489 let res = ExpressionHandler::new(self.trace)
490 .handle(node_request.clone())
491 .await
492 .map_err(|e| {
493 if let Some(detailed_err) =
494 e.downcast_ref::<PartialTraceError>()
495 {
496 trace!({
497 input: node_request.input.clone(),
498 output: Variable::Null,
499 trace_data: detailed_err.trace.clone(),
500 });
501 }
502
503 NodeError {
504 node_id: node.id.clone(),
505 source: e.into(),
506 trace: error_trace(&node_traces),
507 }
508 })?;
509
510 node_request.input.dot_remove("$nodes");
511 res.output.dot_remove("$nodes");
512
513 trace!({
514 input: node_request.input,
515 output: res.output.clone(),
516 trace_data: res.trace_data,
517 });
518 walker.set_node_data(nid, res.output);
519 },
520 DecisionNodeKind::CustomNode { .. } => {
521 let node_request = NodeRequest {
522 node: node.clone(),
523 iteration: self.iteration,
524 input: walker.incoming_node_data(
525 &self.graph,
526 nid,
527 true,
528 ),
529 };
530
531 let res = self
532 .adapter
533 .handle(
534 CustomNodeRequest::try_from(node_request.clone())
535 .unwrap(),
536 )
537 .await
538 .map_err(|e| NodeError {
539 node_id: node.id.clone(),
540 source: e.into(),
541 trace: error_trace(&node_traces),
542 })?;
543
544 node_request.input.dot_remove("$nodes");
545 res.output.dot_remove("$nodes");
546
547 trace!({
548 input: node_request.input,
549 output: res.output.clone(),
550 trace_data: res.trace_data,
551 });
552 walker.set_node_data(nid, res.output);
553 },
554 }
555 }
556
557 Ok(DecisionGraphResponse {
558 result: walker.ending_variables(&self.graph),
559 performance: format!("{:.1?}", root_start.elapsed()),
560 trace: node_traces,
561 })
562 }
563}
564
565#[derive(Debug, Error)]
566pub enum DecisionGraphValidationError {
567 #[error("Invalid input node count: {0}")]
568 InvalidInputCount(u32),
569
570 #[error("Invalid output node count: {0}")]
571 InvalidOutputCount(u32),
572
573 #[error("Cyclic graph detected")]
574 CyclicGraph,
575
576 #[error("Missing node")]
577 MissingNode(String),
578}
579
580impl Serialize for DecisionGraphValidationError {
581 fn serialize<S>(
582 &self,
583 serializer: S,
584 ) -> Result<S::Ok, S::Error>
585 where
586 S: Serializer,
587 {
588 let mut map = serializer.serialize_map(None)?;
589
590 match &self {
591 DecisionGraphValidationError::InvalidInputCount(count) => {
592 map.serialize_entry("type", "invalidInputCount")?;
593 map.serialize_entry("nodeCount", count)?;
594 },
595 DecisionGraphValidationError::InvalidOutputCount(count) => {
596 map.serialize_entry("type", "invalidOutputCount")?;
597 map.serialize_entry("nodeCount", count)?;
598 },
599 DecisionGraphValidationError::MissingNode(node_id) => {
600 map.serialize_entry("type", "missingNode")?;
601 map.serialize_entry("nodeId", node_id)?;
602 },
603 DecisionGraphValidationError::CyclicGraph => {
604 map.serialize_entry("type", "cyclicGraph")?;
605 },
606 }
607
608 map.end()
609 }
610}
611
/// Result of evaluating a decision graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphResponse {
    /// Total evaluation time, formatted as a `Duration` with `{:.1?}`.
    pub performance: String,
    /// Graph output: the data reaching the output node, or the walker's
    /// ending variables when no output node was evaluated.
    pub result: Variable,
    /// Per-node traces keyed by node id. Populated only when tracing is
    /// enabled, and omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace: Option<HashMap<String, DecisionGraphTrace>>,
}
620
/// Trace entry describing the evaluation of a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphTrace {
    /// Data the node received (`Variable::Null` for input nodes).
    pub input: Variable,
    /// Data the node produced (`Variable::Null` for output nodes and for
    /// failures recorded from partial traces).
    pub output: Variable,
    /// Display name of the node.
    pub name: String,
    /// Id of the node; also the key of this entry in the trace map.
    pub id: String,
    /// Elapsed time for the node, formatted with `{:.1?}`, when recorded.
    pub performance: Option<String>,
    /// Handler-specific extra trace data, if any.
    pub trace_data: Option<Value>,
    /// Size of the trace map at the moment this entry was inserted.
    pub order: u32,
}
632
633pub(crate) fn error_trace(
634 trace: &Option<HashMap<String, DecisionGraphTrace>>
635) -> Option<Value> {
636 trace.as_ref().map(|s| serde_json::to_value(s).ok()).flatten()
637}
638
639fn create_validator_cache_key(content: &Value) -> u64 {
640 let mut hasher = DefaultHasher::new();
641 content.hash(&mut hasher);
642 hasher.finish()
643}