Skip to main content

gravityfile_plugin/lua/
runtime.rs

1//! Lua runtime implementation.
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use mlua::{Function, Lua, MultiValue, Table, Value as LuaValue};
7
8use crate::config::{PluginConfig, PluginMetadata};
9use crate::hooks::{Hook, HookContext, HookResult};
10use crate::runtime::{BoxFuture, IsolatedContext, PluginHandle, PluginRuntime};
11use crate::sandbox::SandboxConfig;
12use crate::types::{PluginError, PluginResult, Value};
13
14use super::bindings;
15use super::isolate::LuaIsolatedContext;
16
17/// A loaded Lua plugin.
18struct LoadedLuaPlugin {
19    /// Plugin name/id.
20    name: String,
21
22    /// The plugin's module table.
23    module: mlua::RegistryKey,
24
25    /// Plugin metadata.
26    metadata: PluginMetadata,
27
28    /// Hooks implemented by this plugin.
29    hooks: Vec<String>,
30}
31
32/// Lua plugin runtime.
33pub struct LuaRuntime {
34    /// The main Lua state.
35    lua: Lua,
36
37    /// Loaded plugins by handle.
38    plugins: HashMap<PluginHandle, LoadedLuaPlugin>,
39
40    /// Next plugin handle ID.
41    next_handle: usize,
42
43    /// Runtime configuration.
44    config: Option<PluginConfig>,
45
46    /// Sandbox configuration used to gate filesystem API access.
47    sandbox: SandboxConfig,
48
49    /// Whether the runtime has been initialized.
50    initialized: bool,
51}
52
53impl LuaRuntime {
54    /// Create a new Lua runtime.
55    pub fn new() -> PluginResult<Self> {
56        let lua = Lua::new();
57
58        // Disable potentially dangerous standard library functions
59        {
60            let globals = lua.globals();
61            let map_err = |e: mlua::Error| PluginError::LoadError {
62                name: "lua".into(),
63                message: e.to_string(),
64            };
65            globals.set("loadfile", LuaValue::Nil).map_err(map_err)?;
66            globals
67                .set("dofile", LuaValue::Nil)
68                .map_err(|e| PluginError::LoadError {
69                    name: "lua".into(),
70                    message: e.to_string(),
71                })?;
72            globals
73                .set("load", LuaValue::Nil)
74                .map_err(|e| PluginError::LoadError {
75                    name: "lua".into(),
76                    message: e.to_string(),
77                })?;
78            globals
79                .set("os", LuaValue::Nil)
80                .map_err(|e| PluginError::LoadError {
81                    name: "lua".into(),
82                    message: e.to_string(),
83                })?;
84            globals
85                .set("io", LuaValue::Nil)
86                .map_err(|e| PluginError::LoadError {
87                    name: "lua".into(),
88                    message: e.to_string(),
89                })?;
90            globals
91                .set("debug", LuaValue::Nil)
92                .map_err(|e| PluginError::LoadError {
93                    name: "lua".into(),
94                    message: e.to_string(),
95                })?;
96            globals
97                .set("require", LuaValue::Nil)
98                .map_err(|e| PluginError::LoadError {
99                    name: "lua".into(),
100                    message: e.to_string(),
101                })?;
102            globals
103                .set("package", LuaValue::Nil)
104                .map_err(|e| PluginError::LoadError {
105                    name: "lua".into(),
106                    message: e.to_string(),
107                })?;
108            // Disable string.dump to prevent bytecode extraction
109            if let Ok(string_table) = globals.get::<mlua::Table>("string") {
110                string_table.set("dump", LuaValue::Nil).ok();
111            }
112        }
113
114        Ok(Self {
115            lua,
116            plugins: HashMap::new(),
117            next_handle: 0,
118            config: None,
119            sandbox: SandboxConfig::default(),
120            initialized: false,
121        })
122    }
123
124    /// Initialize the Lua runtime with the gravityfile API.
125    fn init_api(&self) -> PluginResult<()> {
126        let globals = self.lua.globals();
127
128        // Create the 'gf' namespace (gravityfile API)
129        let gf = self
130            .lua
131            .create_table()
132            .map_err(|e| PluginError::LoadError {
133                name: "lua".into(),
134                message: format!("Failed to create gf table: {}", e),
135            })?;
136
137        // Add version info
138        gf.set("version", env!("CARGO_PKG_VERSION"))
139            .map_err(|e| PluginError::LoadError {
140                name: "lua".into(),
141                message: e.to_string(),
142            })?;
143
144        // Add logging functions
145        let log_info = self
146            .lua
147            .create_function(|_, msg: String| {
148                tracing::info!(target: "plugin", "{}", msg);
149                Ok(())
150            })
151            .map_err(|e| PluginError::LoadError {
152                name: "lua".into(),
153                message: e.to_string(),
154            })?;
155        gf.set("log_info", log_info).ok();
156
157        let log_warn = self
158            .lua
159            .create_function(|_, msg: String| {
160                tracing::warn!(target: "plugin", "{}", msg);
161                Ok(())
162            })
163            .map_err(|e| PluginError::LoadError {
164                name: "lua".into(),
165                message: e.to_string(),
166            })?;
167        gf.set("log_warn", log_warn).ok();
168
169        let log_error = self
170            .lua
171            .create_function(|_, msg: String| {
172                tracing::error!(target: "plugin", "{}", msg);
173                Ok(())
174            })
175            .map_err(|e| PluginError::LoadError {
176                name: "lua".into(),
177                message: e.to_string(),
178            })?;
179        gf.set("log_error", log_error).ok();
180
181        // Add notify function
182        let notify = self
183            .lua
184            .create_function(|_, (msg, level): (String, Option<String>)| {
185                let level = level.unwrap_or_else(|| "info".to_string());
186                tracing::info!(target: "plugin_notify", level = level, "{}", msg);
187                Ok(())
188            })
189            .map_err(|e| PluginError::LoadError {
190                name: "lua".into(),
191                message: e.to_string(),
192            })?;
193        gf.set("notify", notify).ok();
194
195        globals.set("gf", gf).map_err(|e| PluginError::LoadError {
196            name: "lua".into(),
197            message: e.to_string(),
198        })?;
199
200        // Create the 'fs' namespace (filesystem API), gated by the runtime sandbox.
201        let fs = bindings::create_fs_api(&self.lua, Some(self.sandbox.clone()))?;
202        globals.set("fs", fs).map_err(|e| PluginError::LoadError {
203            name: "lua".into(),
204            message: e.to_string(),
205        })?;
206
207        // Create the 'ui' namespace (UI elements)
208        let ui = bindings::create_ui_api(&self.lua)?;
209        globals.set("ui", ui).map_err(|e| PluginError::LoadError {
210            name: "lua".into(),
211            message: e.to_string(),
212        })?;
213
214        Ok(())
215    }
216
217    /// Convert a Lua value to our Value type.
218    fn lua_to_value(lua_val: LuaValue) -> Value {
219        match lua_val {
220            LuaValue::Nil => Value::Null,
221            LuaValue::Boolean(b) => Value::Bool(b),
222            LuaValue::Integer(i) => Value::Integer(i),
223            LuaValue::Number(n) => Value::Float(n),
224            LuaValue::String(s) => Value::String(s.to_string_lossy()),
225            LuaValue::Table(t) => {
226                // Check if it's an array or object
227                let mut is_array = true;
228                let mut max_index = 0i64;
229
230                for pair in t.clone().pairs::<i64, LuaValue>() {
231                    if let Ok((k, _)) = pair {
232                        if k > 0 {
233                            max_index = max_index.max(k);
234                        } else {
235                            is_array = false;
236                            break;
237                        }
238                    } else {
239                        is_array = false;
240                        break;
241                    }
242                }
243
244                if is_array && max_index > 0 {
245                    let mut arr = Vec::new();
246                    for i in 1..=max_index {
247                        if let Ok(v) = t.get::<LuaValue>(i) {
248                            arr.push(Self::lua_to_value(v));
249                        }
250                    }
251                    Value::Array(arr)
252                } else {
253                    let mut obj = std::collections::HashMap::new();
254                    for (k, v) in t.pairs::<String, LuaValue>().flatten() {
255                        obj.insert(k, Self::lua_to_value(v));
256                    }
257                    Value::Object(obj)
258                }
259            }
260            _ => Value::Null,
261        }
262    }
263
264    /// Convert our Value type to a Lua value.
265    fn value_to_lua(&self, lua: &Lua, val: &Value) -> mlua::Result<LuaValue> {
266        match val {
267            Value::Null => Ok(LuaValue::Nil),
268            Value::Bool(b) => Ok(LuaValue::Boolean(*b)),
269            Value::Integer(i) => Ok(LuaValue::Integer(*i)),
270            Value::Float(f) => Ok(LuaValue::Number(*f)),
271            Value::String(s) => Ok(LuaValue::String(lua.create_string(s)?)),
272            Value::Array(arr) => {
273                let table = lua.create_table()?;
274                for (i, v) in arr.iter().enumerate() {
275                    table.set(i + 1, self.value_to_lua(lua, v)?)?;
276                }
277                Ok(LuaValue::Table(table))
278            }
279            Value::Object(obj) => {
280                let table = lua.create_table()?;
281                for (k, v) in obj {
282                    table.set(k.as_str(), self.value_to_lua(lua, v)?)?;
283                }
284                Ok(LuaValue::Table(table))
285            }
286            Value::Bytes(b) => Ok(LuaValue::String(lua.create_string(b)?)),
287        }
288    }
289
290    /// Convert a Hook to a Lua table.
291    fn hook_to_lua(&self, lua: &Lua, hook: &Hook) -> mlua::Result<Table> {
292        let table = lua.create_table()?;
293
294        // Serialize hook to JSON then to Lua table
295        let json = serde_json::to_string(hook).map_err(mlua::Error::external)?;
296        let json_val: serde_json::Value =
297            serde_json::from_str(&json).map_err(mlua::Error::external)?;
298
299        fn json_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<LuaValue> {
300            match val {
301                serde_json::Value::Null => Ok(LuaValue::Nil),
302                serde_json::Value::Bool(b) => Ok(LuaValue::Boolean(*b)),
303                serde_json::Value::Number(n) => {
304                    if let Some(i) = n.as_i64() {
305                        Ok(LuaValue::Integer(i))
306                    } else {
307                        Ok(LuaValue::Number(n.as_f64().unwrap_or(0.0)))
308                    }
309                }
310                serde_json::Value::String(s) => Ok(LuaValue::String(lua.create_string(s)?)),
311                serde_json::Value::Array(arr) => {
312                    let t = lua.create_table()?;
313                    for (i, v) in arr.iter().enumerate() {
314                        t.set(i + 1, json_to_lua(lua, v)?)?;
315                    }
316                    Ok(LuaValue::Table(t))
317                }
318                serde_json::Value::Object(obj) => {
319                    let t = lua.create_table()?;
320                    for (k, v) in obj {
321                        t.set(k.as_str(), json_to_lua(lua, v)?)?;
322                    }
323                    Ok(LuaValue::Table(t))
324                }
325            }
326        }
327
328        if let serde_json::Value::Object(obj) = json_val {
329            for (k, v) in obj {
330                table.set(k.as_str(), json_to_lua(lua, &v)?)?;
331            }
332        }
333
334        Ok(table)
335    }
336}
337
338impl Default for LuaRuntime {
339    fn default() -> Self {
340        Self::new().expect("Failed to create Lua runtime")
341    }
342}
343
344impl PluginRuntime for LuaRuntime {
345    fn name(&self) -> &'static str {
346        "lua"
347    }
348
349    fn file_extensions(&self) -> &'static [&'static str] {
350        &[".lua"]
351    }
352
353    fn init(&mut self, config: &PluginConfig) -> PluginResult<()> {
354        if self.initialized {
355            return Ok(());
356        }
357
358        // Build a sandbox from plugin config settings.
359        self.sandbox = SandboxConfig {
360            timeout_ms: config.default_timeout_ms,
361            max_memory: config.max_memory_mb * 1024 * 1024,
362            allow_network: config.allow_network,
363            ..SandboxConfig::default()
364        };
365
366        self.config = Some(config.clone());
367        self.init_api()?;
368        self.initialized = true;
369
370        Ok(())
371    }
372
373    fn load_plugin(&mut self, id: &str, source: &Path) -> PluginResult<PluginHandle> {
374        // Read the plugin source
375        let code = std::fs::read_to_string(source)?;
376
377        // Load and execute the plugin
378        let chunk = self.lua.load(&code).set_name(id);
379
380        let module: Table = chunk.eval().map_err(|e| PluginError::LoadError {
381            name: id.to_string(),
382            message: e.to_string(),
383        })?;
384
385        // Detect which hooks are implemented
386        let mut hooks = vec![];
387        for hook_name in [
388            "on_navigate",
389            "on_drill_down",
390            "on_back",
391            "on_scan_start",
392            "on_scan_progress",
393            "on_scan_complete",
394            "on_delete_start",
395            "on_delete_complete",
396            "on_copy_start",
397            "on_copy_complete",
398            "on_move_start",
399            "on_move_complete",
400            "on_render",
401            "on_action",
402            "on_mode_change",
403            "on_startup",
404            "on_shutdown",
405        ] {
406            if module.contains_key(hook_name).unwrap_or(false) {
407                hooks.push(hook_name.to_string());
408            }
409        }
410
411        // Store in registry
412        let key = self
413            .lua
414            .create_registry_value(module)
415            .map_err(|e| PluginError::LoadError {
416                name: id.to_string(),
417                message: e.to_string(),
418            })?;
419
420        let handle = PluginHandle::new(self.next_handle);
421        self.next_handle += 1;
422
423        // Create default metadata (would normally come from plugin.toml)
424        let metadata = PluginMetadata {
425            name: id.to_string(),
426            runtime: "lua".to_string(),
427            ..Default::default()
428        };
429
430        self.plugins.insert(
431            handle,
432            LoadedLuaPlugin {
433                name: id.to_string(),
434                module: key,
435                metadata,
436                hooks,
437            },
438        );
439
440        Ok(handle)
441    }
442
443    fn unload_plugin(&mut self, handle: PluginHandle) -> PluginResult<()> {
444        if let Some(plugin) = self.plugins.remove(&handle) {
445            self.lua.remove_registry_value(plugin.module).ok();
446        }
447        Ok(())
448    }
449
450    fn get_metadata(&self, handle: PluginHandle) -> Option<&PluginMetadata> {
451        self.plugins.get(&handle).map(|p| &p.metadata)
452    }
453
454    fn has_hook(&self, handle: PluginHandle, hook_name: &str) -> bool {
455        self.plugins
456            .get(&handle)
457            .map(|p| p.hooks.contains(&hook_name.to_string()))
458            .unwrap_or(false)
459    }
460
461    fn call_hook_sync(
462        &self,
463        handle: PluginHandle,
464        hook: &Hook,
465        _ctx: &HookContext,
466    ) -> PluginResult<HookResult> {
467        let plugin = self
468            .plugins
469            .get(&handle)
470            .ok_or_else(|| PluginError::NotFound {
471                path: std::path::PathBuf::new(),
472            })?;
473
474        let module: Table =
475            self.lua
476                .registry_value(&plugin.module)
477                .map_err(|e| PluginError::ExecutionError {
478                    name: plugin.name.clone(),
479                    message: e.to_string(),
480                })?;
481
482        let hook_name = hook.name();
483        let func: Function = match module.get(hook_name) {
484            Ok(f) => f,
485            Err(_) => return Ok(HookResult::default()),
486        };
487
488        // Convert hook and context to Lua
489        let hook_table =
490            self.hook_to_lua(&self.lua, hook)
491                .map_err(|e| PluginError::ExecutionError {
492                    name: plugin.name.clone(),
493                    message: e.to_string(),
494                })?;
495
496        // Call the function
497        let result: LuaValue =
498            func.call((module.clone(), hook_table))
499                .map_err(|e| PluginError::ExecutionError {
500                    name: plugin.name.clone(),
501                    message: e.to_string(),
502                })?;
503
504        // Convert result
505        let mut hook_result = HookResult::ok();
506        if let LuaValue::Table(t) = result {
507            if let Ok(prevent) = t.get::<bool>("prevent_default")
508                && prevent
509            {
510                hook_result = hook_result.prevent_default();
511            }
512            if let Ok(stop) = t.get::<bool>("stop_propagation")
513                && stop
514            {
515                hook_result = hook_result.stop_propagation();
516            }
517            if let Ok(val) = t.get::<LuaValue>("value") {
518                hook_result.value = Some(Self::lua_to_value(val));
519            }
520        }
521
522        Ok(hook_result)
523    }
524
525    fn call_hook_async<'a>(
526        &'a self,
527        handle: PluginHandle,
528        hook: &'a Hook,
529        ctx: &'a HookContext,
530    ) -> BoxFuture<'a, PluginResult<HookResult>> {
531        // For now, just call sync version
532        // TODO: Implement true async with spawn_blocking
533        Box::pin(async move { self.call_hook_sync(handle, hook, ctx) })
534    }
535
536    fn call_method<'a>(
537        &'a self,
538        handle: PluginHandle,
539        method: &'a str,
540        args: Vec<Value>,
541    ) -> BoxFuture<'a, PluginResult<Value>> {
542        Box::pin(async move {
543            let plugin = self
544                .plugins
545                .get(&handle)
546                .ok_or_else(|| PluginError::NotFound {
547                    path: std::path::PathBuf::new(),
548                })?;
549
550            let module: Table = self.lua.registry_value(&plugin.module).map_err(|e| {
551                PluginError::ExecutionError {
552                    name: plugin.name.clone(),
553                    message: e.to_string(),
554                }
555            })?;
556
557            let func: Function = module
558                .get(method)
559                .map_err(|e| PluginError::ExecutionError {
560                    name: plugin.name.clone(),
561                    message: format!("Method '{}' not found: {}", method, e),
562                })?;
563
564            // Convert args to Lua
565            let lua_args: Vec<LuaValue> = args
566                .iter()
567                .map(|v| self.value_to_lua(&self.lua, v))
568                .collect::<Result<_, _>>()
569                .map_err(|e| PluginError::ExecutionError {
570                    name: plugin.name.clone(),
571                    message: e.to_string(),
572                })?;
573
574            let result: LuaValue = func
575                .call(MultiValue::from_vec(
576                    std::iter::once(LuaValue::Table(module))
577                        .chain(lua_args)
578                        .collect(),
579                ))
580                .map_err(|e| PluginError::ExecutionError {
581                    name: plugin.name.clone(),
582                    message: e.to_string(),
583                })?;
584
585            Ok(Self::lua_to_value(result))
586        })
587    }
588
589    fn create_isolated_context(
590        &self,
591        sandbox: &SandboxConfig,
592    ) -> PluginResult<Box<dyn IsolatedContext>> {
593        Ok(Box::new(LuaIsolatedContext::new(sandbox.clone())?))
594    }
595
596    fn loaded_plugins(&self) -> Vec<PluginHandle> {
597        self.plugins.keys().copied().collect()
598    }
599
600    fn shutdown(&mut self) -> PluginResult<()> {
601        self.plugins.clear();
602        Ok(())
603    }
604}