use mlua::{Lua, Result as LuaResult, StdLib, Value};
use super::bindings::register_mdv_table;
use super::index_bindings::register_index_bindings;
use super::types::{SandboxConfig, ScriptingError};
use super::vault_bindings::register_vault_bindings;
use super::vault_context::VaultContext;
/// A Lua interpreter with a restricted global environment and the `mdv` helper table.
///
/// Build it with [`LuaEngine::new`], [`LuaEngine::sandboxed`] or
/// [`LuaEngine::with_vault_context`]; each of them strips risky globals before any
/// script can run.
pub struct LuaEngine {
/// The underlying interpreter; every `eval*` call runs scripts in this state.
lua: Lua,
/// Configuration the engine was built with.
// NOTE(review): this field is not read anywhere in this file, hence the `dead_code`
// allowance — confirm whether it is kept for future limits or can be dropped.
#[allow(dead_code)]
config: SandboxConfig,
}
impl LuaEngine {
pub fn new(config: SandboxConfig) -> Result<Self, ScriptingError> {
let libs = StdLib::TABLE | StdLib::STRING | StdLib::UTF8 | StdLib::MATH;
let lua = Lua::new_with(libs, mlua::LuaOptions::default())?;
if config.memory_limit > 0 {
lua.set_memory_limit(config.memory_limit)?;
}
Self::apply_sandbox(&lua)?;
register_mdv_table(&lua)?;
Ok(Self { lua, config })
}
pub fn sandboxed() -> Result<Self, ScriptingError> {
Self::new(SandboxConfig::restricted())
}
pub fn with_vault_context(
config: SandboxConfig,
vault_ctx: VaultContext,
) -> Result<Self, ScriptingError> {
let libs = StdLib::TABLE | StdLib::STRING | StdLib::UTF8 | StdLib::MATH;
let lua = Lua::new_with(libs, mlua::LuaOptions::default())?;
if config.memory_limit > 0 {
lua.set_memory_limit(config.memory_limit)?;
}
Self::apply_sandbox(&lua)?;
register_mdv_table(&lua)?;
register_vault_bindings(&lua, vault_ctx)?;
register_index_bindings(&lua)?;
Ok(Self { lua, config })
}
pub fn eval(&self, script: &str) -> Result<Option<String>, ScriptingError> {
let value: Value = self.lua.load(script).eval()?;
match value {
Value::Nil => Ok(None),
Value::String(s) => Ok(Some(s.to_str()?.to_string())),
Value::Integer(i) => Ok(Some(i.to_string())),
Value::Number(n) => Ok(Some(n.to_string())),
Value::Boolean(b) => Ok(Some(b.to_string())),
_ => Ok(Some(format!("{:?}", value))),
}
}
pub fn eval_string(&self, script: &str) -> Result<String, ScriptingError> {
self.eval(script)?.ok_or_else(|| {
ScriptingError::Lua(mlua::Error::runtime("script returned nil"))
})
}
pub fn eval_bool(&self, script: &str) -> Result<bool, ScriptingError> {
let value: Value = self.lua.load(script).eval()?;
match value {
Value::Boolean(b) => Ok(b),
Value::Nil => Ok(false),
_ => {
Err(ScriptingError::Lua(mlua::Error::runtime("expected boolean result")))
}
}
}
pub fn lua(&self) -> &Lua {
&self.lua
}
fn apply_sandbox(lua: &Lua) -> LuaResult<()> {
let globals = lua.globals();
globals.set("dofile", Value::Nil)?;
globals.set("loadfile", Value::Nil)?;
globals.set("load", Value::Nil)?;
globals.set("require", Value::Nil)?;
globals.set("package", Value::Nil)?;
globals.set("io", Value::Nil)?;
globals.set("os", Value::Nil)?;
globals.set("debug", Value::Nil)?;
globals.set("collectgarbage", Value::Nil)?;
Ok(())
}
}
impl Default for LuaEngine {
fn default() -> Self {
Self::sandboxed().expect("failed to create default Lua engine")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh engine with the restricted default configuration.
    fn engine() -> LuaEngine {
        LuaEngine::sandboxed().expect("sandboxed engine should build")
    }

    /// Asserts that `s` has the `YYYY-MM-DD` shape that `mdv.date` produces.
    fn assert_iso_date(s: &str) {
        assert_eq!(s.len(), 10, "expected YYYY-MM-DD, got {:?}", s);
        let well_formed = s.chars().enumerate().all(|(i, c)| match i {
            4 | 7 => c == '-',
            _ => c.is_ascii_digit(),
        });
        assert!(well_formed, "expected YYYY-MM-DD, got {:?}", s);
    }

    /// Asserts that the global `name` evaluates to `nil` inside the sandbox.
    fn assert_global_is_nil(name: &str) {
        let result = engine().eval(name).unwrap();
        assert!(result.is_none(), "{} should be nil in sandbox", name);
    }

    #[test]
    fn test_date_basic() {
        let result = engine().eval_string(r#"mdv.date("today")"#).unwrap();
        assert_iso_date(&result);
    }

    #[test]
    fn test_date_with_offset() {
        let result = engine()
            .eval_string(r#"mdv.date("today + 1d")"#)
            .expect("offset expression should evaluate");
        assert_iso_date(&result);
    }

    #[test]
    fn test_date_with_format() {
        let result = engine().eval_string(r#"mdv.date("today", "%A")"#).unwrap();
        let valid_days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ];
        assert!(valid_days.contains(&result.as_str()));
    }

    #[test]
    fn test_date_week() {
        let result = engine().eval_string(r#"mdv.date("week")"#).unwrap();
        let week: u32 = result.parse().expect("week should be a number");
        assert!((1..=53).contains(&week));
    }

    #[test]
    fn test_date_year() {
        let result = engine().eval_string(r#"mdv.date("year")"#).unwrap();
        assert_eq!(result.len(), 4);
        let year: u32 = result.parse().expect("year should be a number");
        assert!(year >= 2020);
    }

    #[test]
    fn test_render_basic() {
        let result = engine()
            .eval_string(r#"mdv.render("Hello {{name}}", { name = "World" })"#)
            .unwrap();
        assert_eq!(result, "Hello World");
    }

    #[test]
    fn test_render_multiple_vars() {
        let result = engine()
            .eval_string(r#"mdv.render("{{greeting}}, {{name}}!", { greeting = "Hi", name = "Lua" })"#)
            .unwrap();
        assert_eq!(result, "Hi, Lua!");
    }

    #[test]
    fn test_render_with_numbers() {
        let result = engine().eval_string(r#"mdv.render("Count: {{n}}", { n = 42 })"#).unwrap();
        assert_eq!(result, "Count: 42");
    }

    #[test]
    fn test_render_with_date_expr() {
        let result = engine().eval_string(r#"mdv.render("Date: {{today}}", {})"#).unwrap();
        let date = result.strip_prefix("Date: ").expect("rendered text should keep its prefix");
        assert_iso_date(date);
    }

    #[test]
    fn test_is_date_expr_true() {
        let result = engine().eval_bool(r#"mdv.is_date_expr("today + 1d")"#).unwrap();
        assert!(result);
    }

    #[test]
    fn test_is_date_expr_false() {
        let result = engine().eval_bool(r#"mdv.is_date_expr("hello")"#).unwrap();
        assert!(!result);
    }

    #[test]
    fn test_is_date_expr_week() {
        let result = engine().eval_bool(r#"mdv.is_date_expr("week/start")"#).unwrap();
        assert!(result);
    }

    #[test]
    fn test_sandbox_no_io() {
        assert_global_is_nil("io");
    }

    #[test]
    fn test_sandbox_no_os() {
        assert_global_is_nil("os");
    }

    #[test]
    fn test_sandbox_no_require() {
        assert_global_is_nil("require");
    }

    #[test]
    fn test_sandbox_no_load() {
        assert_global_is_nil("load");
    }

    #[test]
    fn test_sandbox_no_debug() {
        assert_global_is_nil("debug");
    }

    #[test]
    fn test_sandbox_no_file_loaders() {
        assert_global_is_nil("dofile");
        assert_global_is_nil("loadfile");
        assert_global_is_nil("package");
    }

    #[test]
    fn test_sandbox_no_collectgarbage() {
        assert_global_is_nil("collectgarbage");
    }

    #[test]
    fn test_date_error_handling() {
        let result = engine().eval_string(r#"mdv.date("invalid_expr")"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_pure_lua_math() {
        let result = engine().eval_string(r#"tostring(1 + 2)"#).unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn test_pure_lua_string() {
        let result = engine().eval_string(r#"string.upper("hello")"#).unwrap();
        assert_eq!(result, "HELLO");
    }

    #[test]
    fn test_pure_lua_table() {
        let result = engine().eval_string(r#"local t = {1, 2, 3}; return tostring(#t)"#).unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn test_pure_lua_math_functions() {
        let result = engine().eval_string(r#"tostring(math.floor(3.7))"#).unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn test_eval_returns_none_for_nil() {
        let result = engine().eval(r#"nil"#).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_eval_returns_none_for_no_return() {
        let result = engine().eval(r#"local x = 1"#).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_eval_renders_scalars() {
        let engine = engine();
        assert_eq!(engine.eval("true").unwrap().as_deref(), Some("true"));
        assert_eq!(engine.eval("42").unwrap().as_deref(), Some("42"));
        assert_eq!(engine.eval(r#""hi""#).unwrap().as_deref(), Some("hi"));
    }

    #[test]
    fn test_eval_string_errors_on_nil() {
        assert!(engine().eval_string("nil").is_err());
    }

    #[test]
    fn test_eval_bool_nil_is_false() {
        assert!(!engine().eval_bool("nil").unwrap());
    }

    #[test]
    fn test_eval_bool_rejects_non_boolean() {
        assert!(engine().eval_bool("42").is_err());
        assert!(engine().eval_bool(r#""true""#).is_err());
    }
}