Skip to main content

mdvault_core/scripting/
engine.rs

1//! Lua scripting engine with sandboxing.
2//!
3//! This module provides a sandboxed Lua execution environment for
4//! running user-defined scripts safely.
5
6use mlua::{Lua, Result as LuaResult, StdLib, Value};
7
8use super::bindings::register_mdv_table;
9use super::index_bindings::register_index_bindings;
10use super::types::{SandboxConfig, ScriptingError};
11use super::vault_bindings::register_vault_bindings;
12use super::vault_context::VaultContext;
13
14/// A sandboxed Lua execution environment.
15///
16/// The engine provides access to mdvault functionality through the `mdv`
17/// global table while restricting dangerous operations like file I/O
18/// and shell execution.
19///
20/// # Example
21///
22/// ```rust
23/// use mdvault_core::scripting::LuaEngine;
24///
25/// let engine = LuaEngine::sandboxed().unwrap();
26/// let result = engine.eval_string(r#"mdv.date("today + 7d")"#).unwrap();
27/// println!("One week from now: {}", result);
28/// ```
29pub struct LuaEngine {
30    lua: Lua,
31    #[allow(dead_code)]
32    config: SandboxConfig,
33}
34
35impl LuaEngine {
36    /// Create a new Lua engine with the given sandbox configuration.
37    pub fn new(config: SandboxConfig) -> Result<Self, ScriptingError> {
38        // Create Lua with restricted standard library
39        // Note: base functions (print, type, tostring, etc.) are always available
40        // We add: table, string, utf8, math
41        let libs = StdLib::TABLE | StdLib::STRING | StdLib::UTF8 | StdLib::MATH;
42
43        let lua = Lua::new_with(libs, mlua::LuaOptions::default())?;
44
45        // Apply memory limit if configured
46        if config.memory_limit > 0 {
47            lua.set_memory_limit(config.memory_limit)?;
48        }
49
50        // Remove dangerous globals
51        Self::apply_sandbox(&lua)?;
52
53        // Register mdv bindings
54        register_mdv_table(&lua)?;
55
56        Ok(Self { lua, config })
57    }
58
59    /// Create a new engine with default restrictive sandbox.
60    pub fn sandboxed() -> Result<Self, ScriptingError> {
61        Self::new(SandboxConfig::restricted())
62    }
63
64    /// Create a new Lua engine with vault context for hook execution.
65    ///
66    /// This provides access to `mdv.template()`, `mdv.capture()`, `mdv.macro()`
67    /// and index query functions in addition to the standard sandboxed bindings.
68    ///
69    /// # Example
70    ///
71    /// ```ignore
72    /// use mdvault_core::scripting::{LuaEngine, VaultContext, SandboxConfig};
73    ///
74    /// let vault_ctx = VaultContext::new(config, templates, captures, macros, types);
75    /// let engine = LuaEngine::with_vault_context(SandboxConfig::restricted(), vault_ctx)?;
76    ///
77    /// // Now Lua scripts can use vault operations
78    /// engine.eval_string(r#"
79    ///     local ok, err = mdv.capture("log-to-daily", { text = "Hello" })
80    /// "#)?;
81    ///
82    /// // And query the index (if available)
83    /// engine.eval_string(r#"
84    ///     local tasks = mdv.query({ type = "task" })
85    /// "#)?;
86    /// ```
87    pub fn with_vault_context(
88        config: SandboxConfig,
89        vault_ctx: VaultContext,
90    ) -> Result<Self, ScriptingError> {
91        // Create Lua with restricted standard library
92        let libs = StdLib::TABLE | StdLib::STRING | StdLib::UTF8 | StdLib::MATH;
93        let lua = Lua::new_with(libs, mlua::LuaOptions::default())?;
94
95        // Apply memory limit if configured
96        if config.memory_limit > 0 {
97            lua.set_memory_limit(config.memory_limit)?;
98        }
99
100        // Remove dangerous globals
101        Self::apply_sandbox(&lua)?;
102
103        // Register standard mdv bindings
104        register_mdv_table(&lua)?;
105
106        // Register vault operation bindings
107        register_vault_bindings(&lua, vault_ctx)?;
108
109        // Register index query bindings (uses VaultContext from app_data)
110        register_index_bindings(&lua)?;
111
112        Ok(Self { lua, config })
113    }
114
115    /// Execute a Lua script and return the result.
116    ///
117    /// Returns `None` if the script returns nil or no value.
118    pub fn eval(&self, script: &str) -> Result<Option<String>, ScriptingError> {
119        let value: Value = self.lua.load(script).eval()?;
120
121        match value {
122            Value::Nil => Ok(None),
123            Value::String(s) => Ok(Some(s.to_str()?.to_string())),
124            Value::Integer(i) => Ok(Some(i.to_string())),
125            Value::Number(n) => Ok(Some(n.to_string())),
126            Value::Boolean(b) => Ok(Some(b.to_string())),
127            _ => Ok(Some(format!("{:?}", value))),
128        }
129    }
130
131    /// Execute a Lua script that must return a string value.
132    ///
133    /// Returns an error if the script returns nil.
134    pub fn eval_string(&self, script: &str) -> Result<String, ScriptingError> {
135        self.eval(script)?.ok_or_else(|| {
136            ScriptingError::Lua(mlua::Error::runtime("script returned nil"))
137        })
138    }
139
140    /// Execute a Lua script that returns a boolean.
141    pub fn eval_bool(&self, script: &str) -> Result<bool, ScriptingError> {
142        let value: Value = self.lua.load(script).eval()?;
143        match value {
144            Value::Boolean(b) => Ok(b),
145            Value::Nil => Ok(false),
146            _ => {
147                Err(ScriptingError::Lua(mlua::Error::runtime("expected boolean result")))
148            }
149        }
150    }
151
152    /// Get a reference to the underlying Lua state (for advanced usage).
153    pub fn lua(&self) -> &Lua {
154        &self.lua
155    }
156
157    /// Apply sandbox restrictions by removing dangerous globals.
158    fn apply_sandbox(lua: &Lua) -> LuaResult<()> {
159        let globals = lua.globals();
160
161        // Remove dangerous functions that could:
162        // - Execute arbitrary code: load, loadfile, dofile
163        // - Access the filesystem: io
164        // - Execute system commands: os
165        // - Load external modules: require, package
166        // - Inspect/modify internals: debug
167        // - Cause resource exhaustion: collectgarbage
168
169        globals.set("dofile", Value::Nil)?;
170        globals.set("loadfile", Value::Nil)?;
171        globals.set("load", Value::Nil)?;
172        globals.set("require", Value::Nil)?;
173        globals.set("package", Value::Nil)?;
174        globals.set("io", Value::Nil)?;
175        globals.set("os", Value::Nil)?;
176        globals.set("debug", Value::Nil)?;
177        globals.set("collectgarbage", Value::Nil)?;
178
179        Ok(())
180    }
181}
182
183impl Default for LuaEngine {
184    fn default() -> Self {
185        Self::sandboxed().expect("failed to create default Lua engine")
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fresh engine with the default restrictive sandbox.
    fn sandbox() -> LuaEngine {
        LuaEngine::sandboxed().unwrap()
    }

    /// Evaluate `script` in a fresh sandbox and unwrap its stringified result.
    fn run(script: &str) -> String {
        sandbox().eval_string(script).unwrap()
    }

    /// Evaluate `script` in a fresh sandbox and unwrap its boolean result.
    fn run_bool(script: &str) -> bool {
        sandbox().eval_bool(script).unwrap()
    }

    /// Assert that the global called `name` evaluates to nil in a fresh sandbox.
    fn assert_global_removed(name: &str, message: &str) {
        let value = sandbox().eval(name).unwrap();
        assert!(value.is_none(), "{}", message);
    }

    // --- mdv.date -----------------------------------------------------------

    #[test]
    fn test_date_basic() {
        let date = run(r#"mdv.date("today")"#);
        // Expect the YYYY-MM-DD layout
        assert_eq!(date.len(), 10);
        assert_eq!(date.chars().nth(4), Some('-'));
        assert_eq!(date.chars().nth(7), Some('-'));
    }

    #[test]
    fn test_date_with_offset() {
        // The exact value depends on the current date, so only check success
        let outcome = sandbox().eval_string(r#"mdv.date("today + 1d")"#);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_date_with_format() {
        const WEEKDAYS: [&str; 7] = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ];

        let weekday = run(r#"mdv.date("today", "%A")"#);
        // The formatted value must be one of the weekday names
        assert!(WEEKDAYS.contains(&weekday.as_str()));
    }

    #[test]
    fn test_date_week() {
        let raw = run(r#"mdv.date("week")"#);
        // Week numbers range from 1 to 53
        let week: u32 = raw.parse().expect("week should be a number");
        assert!((1..=53).contains(&week));
    }

    #[test]
    fn test_date_year() {
        let raw = run(r#"mdv.date("year")"#);
        // Expect a 4-digit year
        assert_eq!(raw.len(), 4);
        let year: u32 = raw.parse().expect("year should be a number");
        assert!(year >= 2020);
    }

    // --- mdv.render ---------------------------------------------------------

    #[test]
    fn test_render_basic() {
        let rendered = run(r#"mdv.render("Hello {{name}}", { name = "World" })"#);
        assert_eq!(rendered, "Hello World");
    }

    #[test]
    fn test_render_multiple_vars() {
        let rendered = run(
            r#"mdv.render("{{greeting}}, {{name}}!", { greeting = "Hi", name = "Lua" })"#,
        );
        assert_eq!(rendered, "Hi, Lua!");
    }

    #[test]
    fn test_render_with_numbers() {
        let rendered = run(r#"mdv.render("Count: {{n}}", { n = 42 })"#);
        assert_eq!(rendered, "Count: 42");
    }

    #[test]
    fn test_render_with_date_expr() {
        // The template engine is expected to resolve date expressions itself
        let rendered = run(r#"mdv.render("Date: {{today}}", {})"#);
        // "Date: " followed by a date
        assert!(rendered.starts_with("Date: "));
        assert_eq!(rendered.len(), 16); // "Date: " + "YYYY-MM-DD"
    }

    // --- mdv.is_date_expr ---------------------------------------------------

    #[test]
    fn test_is_date_expr_true() {
        assert!(run_bool(r#"mdv.is_date_expr("today + 1d")"#));
    }

    #[test]
    fn test_is_date_expr_false() {
        assert!(!run_bool(r#"mdv.is_date_expr("hello")"#));
    }

    #[test]
    fn test_is_date_expr_week() {
        assert!(run_bool(r#"mdv.is_date_expr("week/start")"#));
    }

    // --- sandbox restrictions -----------------------------------------------

    #[test]
    fn test_sandbox_no_io() {
        assert_global_removed(r#"io"#, "io should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_os() {
        assert_global_removed(r#"os"#, "os should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_require() {
        assert_global_removed(r#"require"#, "require should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_load() {
        assert_global_removed(r#"load"#, "load should be nil in sandbox");
    }

    #[test]
    fn test_sandbox_no_debug() {
        assert_global_removed(r#"debug"#, "debug should be nil in sandbox");
    }

    // --- error propagation --------------------------------------------------

    #[test]
    fn test_date_error_handling() {
        let outcome = sandbox().eval_string(r#"mdv.date("invalid_expr")"#);
        assert!(outcome.is_err());
    }

    // --- plain Lua still works ----------------------------------------------

    #[test]
    fn test_pure_lua_math() {
        assert_eq!(run(r#"tostring(1 + 2)"#), "3");
    }

    #[test]
    fn test_pure_lua_string() {
        assert_eq!(run(r#"string.upper("hello")"#), "HELLO");
    }

    #[test]
    fn test_pure_lua_table() {
        let length = run(r#"local t = {1, 2, 3}; return tostring(#t)"#);
        assert_eq!(length, "3");
    }

    #[test]
    fn test_pure_lua_math_functions() {
        assert_eq!(run(r#"tostring(math.floor(3.7))"#), "3");
    }

    // --- eval() nil handling ------------------------------------------------

    #[test]
    fn test_eval_returns_none_for_nil() {
        let value = sandbox().eval(r#"nil"#).unwrap();
        assert!(value.is_none());
    }

    #[test]
    fn test_eval_returns_none_for_no_return() {
        let value = sandbox().eval(r#"local x = 1"#).unwrap();
        assert!(value.is_none());
    }
}