use crate::hub::{get_hub, HubEvent};
use crate::run::path_consts::CUSTOM_LUA_DIR;
use crate::run::{get_devai_base_dir, RuntimeContext};
use crate::script::lua_script::helpers::{process_lua_eval_result, serde_to_lua_value};
use crate::script::lua_script::{
utils_cmd, utils_code, utils_devai, utils_file, utils_git, utils_hbs, utils_html, utils_json, utils_lua, utils_md,
utils_path, utils_rust, utils_text, utils_web,
};
use crate::Result;
use mlua::{IntoLua, Lua, Table, Value};
/// Wrapper around an `mlua` VM used to evaluate devai Lua scripts.
///
/// Build it with `LuaEngine::new`. That constructor installs the global `utils`
/// table, runs `utils_devai::init_module`, prepends custom entries to
/// `package.path`, and replaces `print` with a version that publishes to the hub.
pub struct LuaEngine {
    lua: Lua,
    // Stored at construction time, but no code visible here reads it back,
    // hence the lint allowance. NOTE(review): confirm it is still needed.
    #[allow(unused)]
    runtime_context: RuntimeContext,
}
impl LuaEngine {
    /// Creates a fresh Lua VM configured for devai scripts.
    ///
    /// Setup runs in this fixed order:
    /// 1. the global `utils` table (`init_utils`),
    /// 2. `utils_devai::init_module`,
    /// 3. the custom `package.path` prefixes (`init_package_path`),
    /// 4. the hub-forwarding `print` (`init_print`).
    ///
    /// # Errors
    /// Returns the first error produced by any of those steps.
    pub fn new(runtime_context: RuntimeContext) -> Result<Self> {
        let vm = Lua::new();

        init_utils(&vm, &runtime_context)?;
        utils_devai::init_module(&vm, &runtime_context)?;
        init_package_path(&vm, &runtime_context)?;
        init_print(&vm)?;

        Ok(Self {
            lua: vm,
            runtime_context,
        })
    }
}
impl LuaEngine {
    /// Evaluates `script` and returns the value it produces.
    ///
    /// - `scope`: when `Some`, this engine's globals are copied into the table and the
    ///   table becomes the script's environment. Values set on it beforehand (and not
    ///   overwritten by a global of the same name) are visible to the script as globals.
    /// - `addl_lua_paths`: base directories whose `lua/` subfolder should be searched by
    ///   `require`. NOTE: these are only applied when `scope` is `Some`. When applied,
    ///   they are prepended to the engine-wide `package.path`, so they remain in effect
    ///   for later evaluations.
    ///
    /// The raw evaluation result, together with the script text, is handed to
    /// `process_lua_eval_result`, which decides the final value or error.
    pub fn eval(&self, script: &str, scope: Option<Table>, addl_lua_paths: Option<&[&str]>) -> Result<Value> {
        let mut chunk = self.lua.load(script);
        if let Some(scope) = scope {
            let env = self.upgrade_scope(scope, addl_lua_paths)?;
            chunk = chunk.set_environment(env);
        }

        let raw_result = chunk.eval::<Value>();
        let value = process_lua_eval_result(&self.lua, raw_result, script)?;
        Ok(value)
    }

    /// Creates an empty Lua table owned by this engine's VM (e.g. to build a `scope`).
    pub fn create_table(&self) -> Result<Table> {
        self.lua.create_table().map_err(Into::into)
    }

    /// Converts a `serde_json::Value` into a Lua value owned by this engine's VM.
    pub fn serde_to_lua_value(&self, val: serde_json::Value) -> Result<Value> {
        serde_to_lua_value(&self.lua, val)
    }

    /// Converts any `IntoLua` value into a Lua value owned by this engine's VM.
    #[allow(unused)] // not called from the code visible in this file
    pub fn to_lua(&self, val: impl IntoLua) -> Result<Value> {
        val.into_lua(&self.lua).map_err(Into::into)
    }
}
impl LuaEngine {
    /// Turns `scope` into an environment table for a single chunk evaluation.
    ///
    /// Every key/value pair of this engine's globals is copied into `scope`. A key that
    /// already exists in `scope` is overwritten by the global of the same name. This lets
    /// the script still reach `utils`, `print`, `package`, and the other globals.
    ///
    /// When `addl_base_lua_paths` is given, the templates `<dir>/lua/?.lua` and
    /// `<dir>/lua/?/init.lua` for each dir are moved to the front of the engine-wide
    /// `package.path`. The change persists for later evaluations. Earlier copies of the
    /// same templates are removed first, so repeated evaluations no longer make the path
    /// grow without bound. The `require` search order is unchanged, because a later
    /// duplicate of an earlier template can never be the first match.
    ///
    /// If the `package` global is missing or is not a table, the extra paths are skipped.
    ///
    /// # Errors
    /// Fails if iterating the globals, writing into `scope`, or reading/writing
    /// `package.path` fails.
    fn upgrade_scope(&self, scope: Table, addl_base_lua_paths: Option<&[&str]>) -> Result<Table> {
        let globals = self.lua.globals();

        for pair in globals.pairs::<Value, Value>() {
            let (key, value) = pair?;
            scope.set(key, value)?;
        }

        if let Some(addl_lua_paths) = addl_base_lua_paths {
            if let Ok(lua_package) = globals.get::<Table>("package") {
                let current_path: String = lua_package.get("path")?;
                let new_path = prepend_lua_search_paths(&current_path, addl_lua_paths);
                lua_package.set("path", new_path)?;
            }
        }

        Ok(scope)
    }
}

/// Returns `current_path` (a `;`-separated Lua search path) with the
/// `<dir>/lua/?.lua` and `<dir>/lua/?/init.lua` templates of every `base_dirs` entry
/// in front. Any existing occurrence of those templates further down is dropped.
///
/// With an empty `base_dirs`, `current_path` is returned unchanged.
fn prepend_lua_search_paths(current_path: &str, base_dirs: &[&str]) -> String {
    if base_dirs.is_empty() {
        return current_path.to_string();
    }

    let new_entries: Vec<String> = base_dirs
        .iter()
        .flat_map(|dir| [format!("{dir}/lua/?.lua"), format!("{dir}/lua/?/init.lua")])
        .collect();

    // Keep every other entry (including empty ones) in its original order.
    let kept_entries = current_path
        .split(';')
        .filter(|entry| !new_entries.iter().any(|new_entry| new_entry.as_str() == *entry));

    new_entries
        .iter()
        .map(String::as_str)
        .chain(kept_entries)
        .collect::<Vec<_>>()
        .join(";")
}
/// Prepends devai's custom Lua directories to `package.path` so `require` finds them.
///
/// Always added:
/// - `<custom_lua_dir>/?.lua;<custom_lua_dir>/?/init.lua`, where `custom_lua_dir`
///   comes from the runtime's devai dir.
///
/// Added only if the directory exists on disk:
/// - the same two templates for `CUSTOM_LUA_DIR` under the devai base dir.
///
/// # Errors
/// Fails if `package`/`package.path` cannot be read or written, or if the custom
/// Lua dir cannot be resolved.
fn init_package_path(lua: &Lua, runtime_context: &RuntimeContext) -> Result<()> {
    let package: Table = lua.globals().get("package")?;
    let current_path: String = package.get("path")?;

    let devai_dir = runtime_context.dir_context().devai_dir();
    let custom_lua_dir = devai_dir.get_lua_custom_dir()?;
    let mut prefix = format!("{custom_lua_dir}/?.lua;{custom_lua_dir}/?/init.lua");

    let base_lua_dir = get_devai_base_dir().and_then(|base_dir| base_dir.join(CUSTOM_LUA_DIR).ok());
    if let Some(dir) = base_lua_dir.filter(|dir| dir.exists()) {
        prefix.push_str(&format!(";{dir}/?.lua;{dir}/?/init.lua"));
    }

    package.set("path", format!("{prefix};{current_path}"))?;
    Ok(())
}
/// Replaces the global `print` with one that publishes its output to the hub.
///
/// All arguments are rendered to text (see `print_arg_to_string`) and joined with a
/// tab. The result is published as a single `HubEvent::LuaPrint` instead of being
/// written to stdout.
///
/// # Errors
/// Fails if the Lua function cannot be created or the global cannot be set.
fn init_print(lua: &Lua) -> Result<()> {
    let print_fn = lua.create_function(|_, args: mlua::Variadic<Value>| {
        let text = args.iter().map(print_arg_to_string).collect::<Vec<_>>().join("\t");
        get_hub().publish_sync(HubEvent::LuaPrint(text.into()));
        Ok(())
    })?;
    lua.globals().set("print", print_fn)?;
    Ok(())
}

/// Renders one `print` argument as text.
///
/// - `nil` renders as `"nil"`, like standard Lua `print`; previously it fell through to
///   the "unsupported" placeholder.
/// - Strings with invalid UTF-8 are decoded lossily; previously they printed as an
///   empty string.
/// - Tables, functions, userdata, and similar values render as a placeholder.
fn print_arg_to_string(arg: &Value) -> String {
    match arg {
        Value::Nil => "nil".to_string(),
        Value::String(s) => s.to_string_lossy().into(),
        Value::Number(n) => n.to_string(),
        Value::Integer(n) => n.to_string(),
        Value::Boolean(b) => b.to_string(),
        _ => "<unsupported value for print args>".to_string(),
    }
}
/// For each listed module name `x`, calls `utils_x::init_module(vm, ctx)?` and stores
/// the returned value in `target` under the key `"x"`.
///
/// Modules are initialized and inserted in the order they are listed. Any error is
/// propagated with `?` from the calling function.
macro_rules! init_and_set {
    ($target:expr, $vm:expr, $ctx:expr, $($module:ident),*) => {
        paste::paste! {
            $(
                {
                    let module_value = [<utils_ $module>]::init_module($vm, $ctx)?;
                    $target.set(stringify!($module), module_value)?;
                }
            )*
        }
    };
}
/// Builds the global `utils` table, with one entry per `utils_*` module
/// (`utils.file`, `utils.git`, ..., `utils.hbs`), and installs it as a Lua global.
///
/// # Errors
/// Fails if any module's `init_module` fails or the table cannot be written.
fn init_utils(lua_vm: &Lua, runtime_context: &RuntimeContext) -> Result<()> {
    let utils_table = lua_vm.create_table()?;

    // `lua` is one of the module names passed below. The VM parameter is therefore kept
    // as `lua_vm`, so it can never be confused with (or shadowed by) that name.
    init_and_set!(
        utils_table,
        lua_vm,
        runtime_context,
        file, git, web, text, rust, path, md, json, html, cmd, lua, code, hbs
    );

    lua_vm.globals().set("utils", utils_table)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>;

    use super::*;
    use crate::_test_support::SANDBOX_01_DIR;
    use crate::run::Runtime;
    use simple_fs::ensure_dir;
    use std::path::Path;

    /// Evaluates `script` (with no additional Lua paths) and returns the result.
    /// The result must serialize to a JSON string.
    fn eval_as_string(engine: &LuaEngine, script: &str, scope: Option<Table>) -> Result<String> {
        let value = engine.eval(script, scope, None)?;
        let json = serde_json::to_value(value)?;
        let text = json.as_str().ok_or("Should be string")?;
        Ok(text.to_string())
    }

    /// Creates a scope table where `my_name` is set to `"Lua World"`.
    fn scope_with_name(engine: &LuaEngine) -> Result<Table> {
        let scope = engine.create_table()?;
        scope.set("my_name", "Lua World")?;
        Ok(scope)
    }

    #[tokio::test]
    async fn test_lua_engine_eval_simple_ok() -> Result<()> {
        let runtime = Runtime::new_test_runtime_sandbox_01()?;
        let engine = LuaEngine::new(runtime.context().clone())?;
        let script = r#"
local square_root = math.sqrt(25)
return "Hello " .. my_name .. " - " .. square_root
"#;

        let scope = scope_with_name(&engine)?;
        let text = eval_as_string(&engine, script, Some(scope))?;

        assert_eq!(text, "Hello Lua World - 5.0");
        Ok(())
    }

    #[tokio::test]
    async fn test_lua_engine_eval_file_load_ok() -> Result<()> {
        let runtime = Runtime::new_test_runtime_sandbox_01()?;
        let engine = LuaEngine::new(runtime.context().clone())?;
        let script = r#"
local file = utils.file.load("other/hello.txt")
return "Hello " .. my_name .. " - " .. file.content
"#;

        let scope = scope_with_name(&engine)?;
        let text = eval_as_string(&engine, script, Some(scope))?;

        assert_eq!(text, "Hello Lua World - hello from the other/hello.txt");
        Ok(())
    }

    #[tokio::test]
    async fn test_lua_engine_eval_require_ok() -> Result<()> {
        let runtime = Runtime::new_test_runtime_sandbox_01()?;

        // Put demo.lua in the custom Lua dir before the engine is created.
        let custom_lua_dir = "tests-data/sandbox-01/.devai/custom/lua";
        ensure_dir(custom_lua_dir)?;
        std::fs::copy(
            Path::new(SANDBOX_01_DIR).join("other/demo.lua"),
            format!("{custom_lua_dir}/demo.lua"),
        )?;

        let engine = LuaEngine::new(runtime.context().clone())?;
        let script = r#"
local demo = require("demo")
return "demo.name_one is " .. "'" .. demo.name_one .. "'"
"#;

        let text = eval_as_string(&engine, script, None)?;

        assert_eq!(text, "demo.name_one is 'Demo One'");
        Ok(())
    }
}