rush_lua_engine 0.1.0

The rules engine is based on the Rete algorithm.
Documentation
use anyhow::{anyhow, Error};
use mlua::{Function, Lua, LuaSerdeExt, Value};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::str::FromStr;
use std::time::Duration;
use tokio::sync::oneshot::*;
use wd_tools::{PFErr, PFOk};

/// A single request queued for the Lua worker thread.
#[derive(Debug)]
struct Task {
    /// JSON payload; converted to a Lua value and passed to the handler function.
    input: serde_json::Value,
    /// One-shot reply channel. It receives the handler's result converted back to JSON, or
    /// the error raised while converting the input, calling the handler, or converting the
    /// output.
    sender: Sender<anyhow::Result<serde_json::Value>>,
}

/// The value a script must `return` when it is first evaluated by [`LuaRuntime::new`].
///
/// Any field missing from the returned table falls back to its `Default` value
/// (`0` / empty string).
#[derive(Serialize, Deserialize, Default)]
#[serde(default)]
pub struct InitResult {
    /// `0` means success; any other value aborts initialization.
    code: isize,
    /// Included in the initialization error when `code` is non-zero.
    message: String,
    /// Name of the global Lua function that handles every request; must be non-empty when
    /// `code` is `0`.
    handle_function: String,
}

/// Cloneable handle to a Lua VM that runs on its own dedicated worker thread.
///
/// Requests are forwarded to the worker over a task channel. Every clone shares the same
/// channel, and therefore the same Lua state and handler function.
#[derive(Debug, Clone)]
pub struct LuaRuntime {
    sender: async_channel::Sender<Task>,
}

impl LuaRuntime {
    pub fn new(script: String, envs: HashMap<String, String>) -> anyhow::Result<LuaRuntime> {
        let (sender, receiver) = async_channel::bounded(32);
        let (sender_init, mut receiver_init) = channel::<anyhow::Result<InitResult>>();

        std::thread::spawn(move || {
            let lua = Lua::new();
            let global = lua.globals();

            for (k, v) in envs {
                if let Err(e) = global.set(k, v) {
                    let _ = sender_init.send(Err(Error::from(e)));
                    return;
                }
            }

            let result = lua.load(script).eval::<Option<Value>>();
            let val = match result {
                Ok(o) => {
                    if o.is_none() {
                        let _ = sender_init.send(
                            anyhow!("load script init failed, not return check result").err(),
                        );
                        return;
                    }
                    o.unwrap()
                }
                Err(e) => {
                    let _ = sender_init.send(Err(Error::from(e)));
                    return;
                }
            };
            let result = match lua.from_value::<InitResult>(val) {
                Ok(o) => o,
                Err(e) => {
                    let _ = sender_init.send(Err(Error::from(e)));
                    return;
                }
            };
            if result.code != 0 {
                let _ = sender_init.send(Ok(result));
                return;
            }
            if result.handle_function.is_empty() {
                let _ = sender_init.send(anyhow!("entry function must manifest").err());
                return;
            }
            let func = match global.get::<_, Function>(result.handle_function.as_str()) {
                Ok(f) => f,
                Err(e) => {
                    let _ = sender_init.send(Err(Error::from(e)));
                    return;
                }
            };

            if sender_init.send(InitResult::default().ok()).is_err() {
                return;
            }
            loop {
                let Task { input, sender } = match receiver.recv_blocking() {
                    Ok(o) => o,
                    Err(_) => return,
                };
                let val = match lua.to_value(&input) {
                    Ok(o) => o,
                    Err(e) => {
                        let _ = sender.send(Err(Error::from(e)));
                        continue;
                    }
                };
                let val = match func.call::<_, Value>(val) {
                    Ok(o) => o,
                    Err(e) => {
                        let _ = sender.send(Err(Error::from(e)));
                        continue;
                    }
                };

                match lua.from_value::<serde_json::Value>(val) {
                    Ok(o) => {
                        let _ = sender.send(Ok(o));
                    }
                    Err(e) => {
                        let _ = sender.send(Err(Error::from(e)));
                    }
                };
            }
        });

        loop {
            match receiver_init.try_recv() {
                Ok(Ok(o)) => {
                    if o.code != 0 {
                        return anyhow!("init lua runtime failed:{}", o.message).err();
                    }
                    break;
                }
                Ok(Err(e)) => {
                    return Err(e);
                }
                Err(error::TryRecvError::Closed) => {
                    return anyhow!("init lua runtime unknown error").err();
                }
                Err(error::TryRecvError::Empty) => {
                    std::thread::sleep(Duration::from_millis(1));
                }
            }
        }
        Ok(LuaRuntime { sender })
    }

    pub fn call<S: Serialize, Out: for<'a> Deserialize<'a>>(&self, req: S) -> anyhow::Result<Out> {
        let req = serde_json::to_value(req)?;
        let (sender, receiver) = channel();
        let task = Task { input: req, sender };
        if let Err(e) = self.sender.send_blocking(task) {
            let err = e.to_string();
            return anyhow!("lua runtime call failed: {}", err).err();
        }
        return match receiver.blocking_recv() {
            Ok(o) => {
                let out = serde_json::from_value::<Out>(o?)?;
                Ok(out)
            }
            Err(e) => anyhow!("lua runtime error:{}", e).err(),
        };
    }
    pub async fn async_call<S: Serialize, Out: for<'a> Deserialize<'a>>(
        &self,
        req: S,
    ) -> anyhow::Result<Out> {
        let req = serde_json::to_value(req)?;
        let (sender, receiver) = channel();
        let task = Task { input: req, sender };
        if let Err(e) = self.sender.send(task).await {
            let err = e.to_string();
            return anyhow!("lua runtime call failed: {}", err).err();
        }
        return match receiver.await {
            Ok(o) => {
                let out: Out = serde_json::from_value(o?)?;
                Ok(out)
            }
            Err(e) => anyhow!("lua runtime error:{}", e).err(),
        };
    }
    pub fn close(&self) {
        self.sender.close();
    }
}

impl Drop for LuaRuntime {
    /// Closes the task channel, but only when this is the last live handle.
    ///
    /// `LuaRuntime` is cloneable, and `close` shuts the channel for every clone. Closing
    /// unconditionally here would let dropping any single clone break all the others.
    /// Once the last handle is gone, the worker's `recv_blocking` fails and the worker
    /// thread exits.
    fn drop(&mut self) {
        if self.sender.sender_count() == 1 {
            self.close();
        }
    }
}

#[cfg(feature = "rule-flow")]
impl rush_core::RuleFlow for LuaRuntime {
    /// Synchronous rule-flow entry point; delegates to [`LuaRuntime::call`].
    fn flow<Obj: Serialize, Out: for<'a> Deserialize<'a>>(&self, obj: Obj) -> anyhow::Result<Out> {
        LuaRuntime::call(self, obj)
    }
}
#[cfg(feature = "rule-flow")]
#[async_trait::async_trait]
impl rush_core::AsyncRuleFlow for LuaRuntime {
    /// Asynchronous rule-flow entry point; delegates to [`LuaRuntime::async_call`].
    async fn async_flow<Obj: Serialize + Send, Out: for<'a> Deserialize<'a>>(
        &self,
        obj: Obj,
    ) -> anyhow::Result<Out> {
        let pending = LuaRuntime::async_call(self, obj);
        pending.await
    }
}

impl FromStr for LuaRuntime {
    type Err = anyhow::Error;

    /// Builds a runtime from the script text `s`, with no extra globals installed.
    ///
    /// See [`LuaRuntime::new`] for the script contract and the possible errors.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let no_envs: HashMap<String, String> = HashMap::new();
        Self::new(s.to_owned(), no_envs)
    }
}

#[cfg(test)]
mod test {
    use crate::LuaRuntime;
    use serde_json::Value;
    use std::collections::HashMap;

    /// Minimal valid script: prints each request field and always answers "success".
    const TEST_LUA_SCRIPT: &str = r#"
    function handle(req)
        for k, v in pairs(req) do
            print("--->",k,v)
        end
        local resp = "success"
        return resp
    end

    return {code=0,message="success",handle_function="handle"}
    "#;

    #[test]
    fn test_function_lua_runtime() {
        let rt = LuaRuntime::new(TEST_LUA_SCRIPT.to_string(), HashMap::new()).unwrap();
        let result: Value = rt
            .call(r#"{"order_id":"78632839429034208","status":1}"#.parse::<Value>().unwrap())
            .unwrap();
        assert_eq!(result, Value::String("success".into()))
    }

    #[test]
    fn test_non_zero_code_fails_init_with_message() {
        let script = r#"return {code=1,message="not ready"}"#;
        let err = LuaRuntime::new(script.to_string(), HashMap::new()).unwrap_err();
        assert!(err.to_string().contains("not ready"), "unexpected error: {}", err);
    }

    #[test]
    fn test_missing_handle_function_fails_init() {
        let script = r#"return {code=0}"#;
        let err = LuaRuntime::new(script.to_string(), HashMap::new()).unwrap_err();
        assert!(
            err.to_string().contains("entry function must manifest"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_envs_are_visible_as_globals() {
        let script = r#"
        function handle(req)
            return GREETING
        end
        return {code=0,handle_function="handle"}
        "#;
        let mut envs = HashMap::new();
        envs.insert("GREETING".to_string(), "hello".to_string());
        let rt = LuaRuntime::new(script.to_string(), envs).unwrap();
        let result: Value = rt.call("{}".parse::<Value>().unwrap()).unwrap();
        assert_eq!(result, Value::String("hello".into()));
    }
}