mf_engine/handler/
decision.rs

1use crate::handler::custom_node_adapter::CustomNodeAdapter;
2use crate::handler::function::function::Function;
3use crate::handler::graph::{DecisionGraph, DecisionGraphConfig};
4use crate::handler::node::{NodeRequest, NodeResponse, NodeResult};
5use crate::loader::DecisionLoader;
6use crate::model::DecisionNodeKind;
7use crate::util::validator_cache::ValidatorCache;
8use anyhow::anyhow;
9use std::future::Future;
10use std::pin::Pin;
11use std::rc::Rc;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
15pub struct DecisionHandler<
16    L: DecisionLoader + 'static,
17    A: CustomNodeAdapter + 'static,
18> {
19    trace: bool,
20    loader: Arc<L>,
21    adapter: Arc<A>,
22    max_depth: u8,
23    js_function: Option<Rc<Function>>,
24    validator_cache: ValidatorCache,
25}
26
27impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static>
28    DecisionHandler<L, A>
29{
30    pub fn new(
31        trace: bool,
32        max_depth: u8,
33        loader: Arc<L>,
34        adapter: Arc<A>,
35        js_function: Option<Rc<Function>>,
36        validator_cache: ValidatorCache,
37    ) -> Self {
38        Self { trace, loader, adapter, max_depth, js_function, validator_cache }
39    }
40
41    pub fn handle<'s, 'arg, 'recursion>(
42        &'s self,
43        request: NodeRequest,
44    ) -> Pin<Box<dyn Future<Output = NodeResult> + 'recursion>>
45    where
46        's: 'recursion,
47        'arg: 'recursion,
48    {
49        Box::pin(async move {
50            let content = match &request.node.kind {
51                DecisionNodeKind::DecisionNode { content } => Ok(content),
52                _ => Err(anyhow!("Unexpected node type")),
53            }?;
54
55            let sub_decision = self.loader.load(&content.key).await?;
56            let sub_tree = DecisionGraph::try_new(DecisionGraphConfig {
57                content: sub_decision,
58                max_depth: self.max_depth,
59                loader: self.loader.clone(),
60                adapter: self.adapter.clone(),
61                iteration: request.iteration + 1,
62                trace: self.trace,
63                validator_cache: Some(self.validator_cache.clone()),
64            })?
65            .with_function(self.js_function.clone());
66
67            let sub_tree_mutex = Arc::new(Mutex::new(sub_tree));
68
69            content
70                .transform_attributes
71                .run_with(request.input, |input| {
72                    let sub_tree_mutex = sub_tree_mutex.clone();
73
74                    async move {
75                        let mut sub_tree_ref = sub_tree_mutex.lock().await;
76
77                        sub_tree_ref.reset_graph();
78                        sub_tree_ref
79                            .evaluate(input)
80                            .await
81                            .map(|r| NodeResponse {
82                                output: r.result,
83                                trace_data: serde_json::to_value(r.trace).ok(),
84                            })
85                            .map_err(|e| e.source)
86                    }
87                })
88                .await
89        })
90    }
91}