mdvault_core/scripting/
engine.rs1use mlua::{Lua, Result as LuaResult, StdLib, Value};
7
8use super::bindings::register_mdv_table;
9use super::index_bindings::register_index_bindings;
10use super::types::{SandboxConfig, ScriptingError};
11use super::vault_bindings::register_vault_bindings;
12use super::vault_context::VaultContext;
13
14pub struct LuaEngine {
30 lua: Lua,
31 #[allow(dead_code)]
32 config: SandboxConfig,
33}
34
35impl LuaEngine {
36 pub fn new(config: SandboxConfig) -> Result<Self, ScriptingError> {
38 let libs = StdLib::TABLE | StdLib::STRING | StdLib::UTF8 | StdLib::MATH;
42
43 let lua = Lua::new_with(libs, mlua::LuaOptions::default())?;
44
45 if config.memory_limit > 0 {
47 lua.set_memory_limit(config.memory_limit)?;
48 }
49
50 Self::apply_sandbox(&lua)?;
52
53 register_mdv_table(&lua)?;
55
56 Ok(Self { lua, config })
57 }
58
59 pub fn sandboxed() -> Result<Self, ScriptingError> {
61 Self::new(SandboxConfig::restricted())
62 }
63
64 pub fn with_vault_context(
88 config: SandboxConfig,
89 vault_ctx: VaultContext,
90 ) -> Result<Self, ScriptingError> {
91 let libs = StdLib::TABLE | StdLib::STRING | StdLib::UTF8 | StdLib::MATH;
93 let lua = Lua::new_with(libs, mlua::LuaOptions::default())?;
94
95 if config.memory_limit > 0 {
97 lua.set_memory_limit(config.memory_limit)?;
98 }
99
100 Self::apply_sandbox(&lua)?;
102
103 register_mdv_table(&lua)?;
105
106 register_vault_bindings(&lua, vault_ctx)?;
108
109 register_index_bindings(&lua)?;
111
112 Ok(Self { lua, config })
113 }
114
115 pub fn eval(&self, script: &str) -> Result<Option<String>, ScriptingError> {
119 let value: Value = self.lua.load(script).eval()?;
120
121 match value {
122 Value::Nil => Ok(None),
123 Value::String(s) => Ok(Some(s.to_str()?.to_string())),
124 Value::Integer(i) => Ok(Some(i.to_string())),
125 Value::Number(n) => Ok(Some(n.to_string())),
126 Value::Boolean(b) => Ok(Some(b.to_string())),
127 _ => Ok(Some(format!("{:?}", value))),
128 }
129 }
130
131 pub fn eval_string(&self, script: &str) -> Result<String, ScriptingError> {
135 self.eval(script)?.ok_or_else(|| {
136 ScriptingError::Lua(mlua::Error::runtime("script returned nil"))
137 })
138 }
139
140 pub fn eval_bool(&self, script: &str) -> Result<bool, ScriptingError> {
142 let value: Value = self.lua.load(script).eval()?;
143 match value {
144 Value::Boolean(b) => Ok(b),
145 Value::Nil => Ok(false),
146 _ => {
147 Err(ScriptingError::Lua(mlua::Error::runtime("expected boolean result")))
148 }
149 }
150 }
151
152 pub fn lua(&self) -> &Lua {
154 &self.lua
155 }
156
157 fn apply_sandbox(lua: &Lua) -> LuaResult<()> {
159 let globals = lua.globals();
160
161 globals.set("dofile", Value::Nil)?;
170 globals.set("loadfile", Value::Nil)?;
171 globals.set("load", Value::Nil)?;
172 globals.set("require", Value::Nil)?;
173 globals.set("package", Value::Nil)?;
174 globals.set("io", Value::Nil)?;
175 globals.set("os", Value::Nil)?;
176 globals.set("debug", Value::Nil)?;
177 globals.set("collectgarbage", Value::Nil)?;
178
179 Ok(())
180 }
181}
182
183impl Default for LuaEngine {
184 fn default() -> Self {
185 Self::sandboxed().expect("failed to create default Lua engine")
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh engine with the restricted sandbox configuration.
    fn engine() -> LuaEngine {
        LuaEngine::sandboxed().unwrap()
    }

    /// Asserts `s` has the `YYYY-MM-DD` shape.
    fn assert_iso_date_shape(s: &str) {
        assert_eq!(s.len(), 10, "unexpected date: {s}");
        assert_eq!(s.chars().nth(4), Some('-'));
        assert_eq!(s.chars().nth(7), Some('-'));
    }

    #[test]
    fn test_date_basic() {
        let result = engine().eval_string(r#"mdv.date("today")"#).unwrap();
        assert_iso_date_shape(&result);
    }

    #[test]
    fn test_date_with_offset() {
        let result = engine().eval_string(r#"mdv.date("today + 1d")"#).unwrap();
        assert_iso_date_shape(&result);
    }

    #[test]
    fn test_date_with_format() {
        let result = engine().eval_string(r#"mdv.date("today", "%A")"#).unwrap();
        let valid_days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ];
        assert!(valid_days.contains(&result.as_str()));
    }

    #[test]
    fn test_date_week() {
        let result = engine().eval_string(r#"mdv.date("week")"#).unwrap();
        let week: u32 = result.parse().expect("week should be a number");
        assert!((1..=53).contains(&week));
    }

    #[test]
    fn test_date_year() {
        let result = engine().eval_string(r#"mdv.date("year")"#).unwrap();
        assert_eq!(result.len(), 4);
        let year: u32 = result.parse().expect("year should be a number");
        assert!(year >= 2020);
    }

    #[test]
    fn test_date_error_handling() {
        let result = engine().eval_string(r#"mdv.date("invalid_expr")"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_render_basic() {
        let result = engine()
            .eval_string(r#"mdv.render("Hello {{name}}", { name = "World" })"#)
            .unwrap();
        assert_eq!(result, "Hello World");
    }

    #[test]
    fn test_render_multiple_vars() {
        let result = engine()
            .eval_string(r#"mdv.render("{{greeting}}, {{name}}!", { greeting = "Hi", name = "Lua" })"#)
            .unwrap();
        assert_eq!(result, "Hi, Lua!");
    }

    #[test]
    fn test_render_with_numbers() {
        let result =
            engine().eval_string(r#"mdv.render("Count: {{n}}", { n = 42 })"#).unwrap();
        assert_eq!(result, "Count: 42");
    }

    #[test]
    fn test_render_with_date_expr() {
        let result = engine().eval_string(r#"mdv.render("Date: {{today}}", {})"#).unwrap();
        let date = result.strip_prefix("Date: ").expect("missing 'Date: ' prefix");
        assert_iso_date_shape(date);
    }

    #[test]
    fn test_is_date_expr_true() {
        let result = engine().eval_bool(r#"mdv.is_date_expr("today + 1d")"#).unwrap();
        assert!(result);
    }

    #[test]
    fn test_is_date_expr_false() {
        let result = engine().eval_bool(r#"mdv.is_date_expr("hello")"#).unwrap();
        assert!(!result);
    }

    #[test]
    fn test_is_date_expr_week() {
        let result = engine().eval_bool(r#"mdv.is_date_expr("week/start")"#).unwrap();
        assert!(result);
    }

    #[test]
    fn test_sandbox_no_io() {
        let result = engine().eval(r#"io"#).unwrap();
        assert!(result.is_none(), "io should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_os() {
        let result = engine().eval(r#"os"#).unwrap();
        assert!(result.is_none(), "os should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_require() {
        let result = engine().eval(r#"require"#).unwrap();
        assert!(result.is_none(), "require should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_load() {
        let result = engine().eval(r#"load"#).unwrap();
        assert!(result.is_none(), "load should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_debug() {
        let result = engine().eval(r#"debug"#).unwrap();
        assert!(result.is_none(), "debug should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_removes_remaining_globals() {
        let engine = engine();
        for name in ["dofile", "loadfile", "package", "collectgarbage"] {
            let result = engine.eval(name).unwrap();
            assert!(result.is_none(), "{name} should be nil in sandbox");
        }
    }

    #[test]
    fn test_pure_lua_math() {
        let result = engine().eval_string(r#"tostring(1 + 2)"#).unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn test_pure_lua_string() {
        let result = engine().eval_string(r#"string.upper("hello")"#).unwrap();
        assert_eq!(result, "HELLO");
    }

    #[test]
    fn test_pure_lua_table() {
        let result =
            engine().eval_string(r#"local t = {1, 2, 3}; return tostring(#t)"#).unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn test_pure_lua_math_functions() {
        let result = engine().eval_string(r#"tostring(math.floor(3.7))"#).unwrap();
        assert_eq!(result, "3");
    }

    #[test]
    fn test_eval_returns_none_for_nil() {
        let result = engine().eval(r#"nil"#).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_eval_returns_none_for_no_return() {
        let result = engine().eval(r#"local x = 1"#).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_eval_scalar_conversions() {
        let engine = engine();
        assert_eq!(engine.eval("42").unwrap().as_deref(), Some("42"));
        assert_eq!(engine.eval("1.5").unwrap().as_deref(), Some("1.5"));
        assert_eq!(engine.eval("true").unwrap().as_deref(), Some("true"));
    }

    #[test]
    fn test_eval_string_errors_on_nil() {
        assert!(engine().eval_string(r#"nil"#).is_err());
    }

    #[test]
    fn test_eval_bool_nil_is_false() {
        assert!(!engine().eval_bool(r#"nil"#).unwrap());
    }

    #[test]
    fn test_eval_bool_rejects_non_boolean() {
        assert!(engine().eval_bool(r#"1"#).is_err());
    }
}