zen-engine 0.55.0

Business rules engine
Documentation
use crate::nodes::definition::{NodeDataType, TraceDataType};
use crate::nodes::extensions::NodeHandlerExtensions;
use crate::nodes::function::v2::function::Function;
use crate::nodes::result::{NodeResponse, NodeResult};
use crate::nodes::NodeError;
use crate::ZEN_CONFIG;
use ahash::AHasher;
use jsonschema::ValidationError;
use serde::Serialize;
use serde_json::Value;
use std::cell::RefCell;
use std::fmt::{Display, Formatter};
use std::hash::Hasher;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use thiserror::Error;
use zen_types::variable::{ToVariable, Variable};

#[derive(Clone)]
pub struct NodeContext<NodeData, TraceData>
where
    NodeData: NodeDataType,
    TraceData: TraceDataType,
{
    pub id: Arc<str>,
    pub name: Arc<str>,
    pub node: NodeData,
    pub input: Variable,
    pub trace: Option<RefCell<TraceData>>,
    pub extensions: NodeHandlerExtensions,
    pub iteration: u8,
    pub config: NodeContextConfig,
}

impl<NodeData, TraceData> NodeContext<NodeData, TraceData>
where
    NodeData: NodeDataType,
    TraceData: TraceDataType,
{
    pub fn from_base(base: NodeContextBase, data: NodeData) -> Self {
        Self {
            id: base.id,
            name: base.name,
            input: base.input,
            extensions: base.extensions,
            iteration: base.iteration,
            trace: base.config.trace.then(|| Default::default()),
            node: data,
            config: base.config,
        }
    }

    pub fn trace<Function>(&self, mutator: Function)
    where
        Function: FnOnce(&mut TraceData),
    {
        if let Some(trace) = &self.trace {
            mutator(&mut *trace.borrow_mut());
        }
    }

    pub fn error<Error>(&self, error: Error) -> NodeResult
    where
        Error: Into<Box<dyn std::error::Error>>,
    {
        Err(self.make_error(error))
    }

    pub fn success(&self, output: Variable) -> NodeResult {
        Ok(NodeResponse {
            output,
            trace_data: self.trace.as_ref().map(|v| (*v.borrow()).to_variable()),
        })
    }

    pub(crate) fn make_error<Error>(&self, error: Error) -> NodeError
    where
        Error: Into<Box<dyn std::error::Error>>,
    {
        NodeError {
            node_id: self.id.clone(),
            trace: self.trace.as_ref().map(|v| (*v.borrow()).to_variable()),
            source: error.into(),
        }
    }

    pub(crate) async fn function_runtime(&self) -> Result<&Function, NodeError> {
        self.extensions.function_runtime().await.node_context(self)
    }

    pub fn validate(&self, schema: &Value, value: &Value) -> Result<(), NodeError> {
        let validator_cache = self.extensions.validator_cache();
        let hash = self.hash_node();

        let validator = validator_cache
            .get_or_insert(hash, schema)
            .node_context(self)?;

        validator
            .validate(value)
            .map_err(|err| ValidationErrorJson::from(err))
            .node_context(self)?;

        Ok(())
    }

    fn hash_node(&self) -> u64 {
        let mut hasher = AHasher::default();
        hasher.write(self.id.as_bytes());
        hasher.write(self.name.as_bytes());
        hasher.finish()
    }
}

pub trait NodeContextExt<T, Context>: Sized {
    type Error: Into<Box<dyn std::error::Error>>;

    fn with_node_context<Function, NewError>(
        self,
        ctx: &Context,
        f: Function,
    ) -> Result<T, NodeError>
    where
        Function: FnOnce(Self::Error) -> NewError,
        NewError: Into<Box<dyn std::error::Error>>;

    fn node_context(self, ctx: &Context) -> Result<T, NodeError> {
        self.with_node_context(ctx, |e| e.into())
    }

    fn node_context_message(self, ctx: &Context, message: &str) -> Result<T, NodeError> {
        self.with_node_context(ctx, |err| format!("{}: {}", message, err.into()))
    }
}

impl<T, E, NodeData, TraceData> NodeContextExt<T, NodeContext<NodeData, TraceData>> for Result<T, E>
where
    E: Into<Box<dyn std::error::Error>>,
    NodeData: NodeDataType,
    TraceData: TraceDataType,
{
    type Error = E;

    fn with_node_context<Function, NewError>(
        self,
        ctx: &NodeContext<NodeData, TraceData>,
        f: Function,
    ) -> Result<T, NodeError>
    where
        Function: FnOnce(Self::Error) -> NewError,
        NewError: Into<Box<dyn std::error::Error>>,
    {
        self.map_err(|err| ctx.make_error(f(err)))
    }
}

impl<T, NodeData, TraceData> NodeContextExt<T, NodeContext<NodeData, TraceData>> for Option<T>
where
    NodeData: NodeDataType,
    TraceData: TraceDataType,
{
    type Error = &'static str;

    fn with_node_context<Function, NewError>(
        self,
        ctx: &NodeContext<NodeData, TraceData>,
        f: Function,
    ) -> Result<T, NodeError>
    where
        Function: FnOnce(Self::Error) -> NewError,
        NewError: Into<Box<dyn std::error::Error>>,
    {
        self.ok_or_else(|| ctx.make_error(f("None")))
    }

    fn node_context_message(
        self,
        ctx: &NodeContext<NodeData, TraceData>,
        message: &str,
    ) -> Result<T, NodeError> {
        self.with_node_context(ctx, |_| message.to_string())
    }
}

#[derive(Clone)]
pub struct NodeContextBase {
    pub id: Arc<str>,
    pub name: Arc<str>,
    pub input: Variable,
    pub iteration: u8,
    pub extensions: NodeHandlerExtensions,
    pub config: NodeContextConfig,
    pub trace: Option<RefCell<Variable>>,
}

impl NodeContextBase {
    pub fn error<Error>(&self, error: Error) -> NodeResult
    where
        Error: Into<Box<dyn std::error::Error>>,
    {
        Err(self.make_error(error))
    }

    pub fn success(&self, output: Variable) -> NodeResult {
        Ok(NodeResponse {
            output,
            trace_data: self.trace.as_ref().map(|v| v.borrow().to_variable()),
        })
    }

    fn make_error<Error>(&self, error: Error) -> NodeError
    where
        Error: Into<Box<dyn std::error::Error>>,
    {
        NodeError {
            node_id: self.id.clone(),
            trace: self.trace.as_ref().map(|t| t.borrow().to_variable()),
            source: error.into(),
        }
    }

    pub fn trace<Function>(&self, mutator: Function)
    where
        Function: FnOnce(&mut Variable),
    {
        if let Some(trace) = &self.trace {
            mutator(&mut *trace.borrow_mut());
        }
    }
}

impl<NodeData, TraceData> From<NodeContext<NodeData, TraceData>> for NodeContextBase
where
    NodeData: NodeDataType,
    TraceData: TraceDataType,
{
    fn from(value: NodeContext<NodeData, TraceData>) -> Self {
        let trace = match value.config.trace {
            true => Some(RefCell::new(Variable::Null)),
            false => None,
        };

        Self {
            id: value.id,
            name: value.name,
            input: value.input,
            extensions: value.extensions,
            iteration: value.iteration,
            config: value.config,
            trace,
        }
    }
}

impl<T, E> NodeContextExt<T, NodeContextBase> for Result<T, E>
where
    E: Into<Box<dyn std::error::Error>>,
{
    type Error = E;

    fn with_node_context<Function, NewError>(
        self,
        ctx: &NodeContextBase,
        f: Function,
    ) -> Result<T, NodeError>
    where
        Function: FnOnce(Self::Error) -> NewError,
        NewError: Into<Box<dyn std::error::Error>>,
    {
        self.map_err(|err| ctx.make_error(f(err)))
    }
}

impl<T> NodeContextExt<T, NodeContextBase> for Option<T> {
    type Error = &'static str;

    fn with_node_context<Function, NewError>(
        self,
        ctx: &NodeContextBase,
        f: Function,
    ) -> Result<T, NodeError>
    where
        Function: FnOnce(Self::Error) -> NewError,
        NewError: Into<Box<dyn std::error::Error>>,
    {
        self.ok_or_else(|| ctx.make_error(f("None")))
    }

    fn node_context_message(self, ctx: &NodeContextBase, message: &str) -> Result<T, NodeError> {
        self.with_node_context(ctx, |_| message.to_string())
    }
}

#[derive(Debug, Serialize, Error)]
#[serde(rename_all = "camelCase")]
struct ValidationErrorJson {
    path: String,
    message: String,
}

impl Display for ValidationErrorJson {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.path, self.message)
    }
}

impl<'a> From<ValidationError<'a>> for ValidationErrorJson {
    fn from(value: ValidationError<'a>) -> Self {
        ValidationErrorJson {
            path: value.instance_path.to_string(),
            message: format!("{}", value),
        }
    }
}

#[derive(Clone)]
pub struct NodeContextConfig {
    pub trace: bool,
    pub nodes_in_context: bool,
    pub max_depth: u8,
    pub function_timeout_millis: u64,
    pub http_auth: bool,
}

impl Default for NodeContextConfig {
    fn default() -> Self {
        Self {
            trace: false,
            nodes_in_context: ZEN_CONFIG.nodes_in_context.load(Ordering::Relaxed),
            function_timeout_millis: ZEN_CONFIG.function_timeout_millis.load(Ordering::Relaxed),
            http_auth: ZEN_CONFIG.http_auth.load(Ordering::Relaxed),
            max_depth: 5,
        }
    }
}