Skip to main content

brainwires_code_interpreters/languages/
lua.rs

1//! Lua executor - Small, fast scripting language
2//!
3//! Lua is lightweight and fast, making it ideal for embedded scripting.
4//! Uses mlua which supports Lua 5.4 with vendored builds.
5//!
6//! ## Features
7//! - Very small runtime footprint
8//! - Fast execution
9//! - Good for game scripting and configuration
10//! - Memory limit support
11//!
12//! ## Limitations
13//! - Smaller ecosystem than Python/JS
14//! - 1-indexed arrays (can be confusing)
15
16use mlua::{Lua, MultiValue, Result as LuaResult, Value};
17use std::sync::{Arc, Mutex};
18use std::time::Instant;
19
20use super::{get_limits, truncate_output, LanguageExecutor};
21use crate::types::{ExecutionLimits, ExecutionRequest, ExecutionResult};
22
/// Lua code executor
///
/// Each call to `execute_code` builds its own `Lua` state, so the executor
/// itself carries no interpreter state between runs.
pub struct LuaExecutor {
    // Limits supplied at construction time. Nothing in the visible code reads
    // this field (hence the underscore): `execute_code` takes its limits from
    // `get_limits(request)` instead.
    _limits: ExecutionLimits,
}
27
28impl LuaExecutor {
29    /// Create a new Lua executor with default limits
30    pub fn new() -> Self {
31        Self {
32            _limits: ExecutionLimits::default(),
33        }
34    }
35
36    /// Create a new Lua executor with custom limits
37    pub fn with_limits(limits: ExecutionLimits) -> Self {
38        Self { _limits: limits }
39    }
40
41    /// Execute Lua code
42    pub fn execute_code(&self, request: &ExecutionRequest) -> ExecutionResult {
43        let limits = get_limits(request);
44        let start = Instant::now();
45
46        // Create Lua instance
47        let lua = Lua::new();
48
49        // Set memory limit
50        let _ = lua.set_memory_limit(limits.max_memory_mb as usize * 1024 * 1024);
51
52        // Capture print output
53        let output = Arc::new(Mutex::new(Vec::<String>::new()));
54
55        // Override print function to capture output
56        if let Err(e) = self.setup_print(&lua, output.clone()) {
57            return ExecutionResult::error(
58                format!("Failed to setup print: {}", e),
59                start.elapsed().as_millis() as u64,
60            );
61        }
62
63        // Inject context variables
64        if let Some(context) = &request.context
65            && let Err(e) = self.inject_context(&lua, context) {
66                return ExecutionResult::error(
67                    format!("Failed to inject context: {}", e),
68                    start.elapsed().as_millis() as u64,
69                );
70            }
71
72        // Execute the code
73        let result = lua.load(&request.code).eval::<Value>();
74        let timing_ms = start.elapsed().as_millis() as u64;
75
76        // Get captured output
77        let stdout = output
78            .lock()
79            .map(|out| out.join("\n"))
80            .unwrap_or_default();
81        let stdout = truncate_output(&stdout, limits.max_output_bytes);
82
83        // Get memory usage
84        let memory_used = lua.used_memory() as u64;
85
86        match result {
87            Ok(value) => {
88                let result_value = lua_to_json(&value);
89                let mut stdout_with_result = stdout;
90
91                // If there's a non-nil result, append it to stdout
92                if !matches!(value, Value::Nil) {
93                    if !stdout_with_result.is_empty() {
94                        stdout_with_result.push('\n');
95                    }
96                    stdout_with_result.push_str(&format_lua_value(&value));
97                }
98
99                ExecutionResult {
100                    success: true,
101                    stdout: stdout_with_result,
102                    stderr: String::new(),
103                    result: result_value,
104                    error: None,
105                    timing_ms,
106                    memory_used_bytes: Some(memory_used),
107                    operations_count: None,
108                }
109            }
110            Err(e) => {
111                let error_message = format_lua_error(&e);
112                ExecutionResult {
113                    success: false,
114                    stdout,
115                    stderr: error_message.clone(),
116                    result: None,
117                    error: Some(error_message),
118                    timing_ms,
119                    memory_used_bytes: Some(memory_used),
120                    operations_count: None,
121                }
122            }
123        }
124    }
125
126    /// Setup print function to capture output
127    fn setup_print(&self, lua: &Lua, output: Arc<Mutex<Vec<String>>>) -> LuaResult<()> {
128        let print_fn = lua.create_function(move |_, args: MultiValue| {
129            let parts: Vec<String> = args
130                .into_iter()
131                .map(|v| format_lua_value(&v))
132                .collect();
133            let line = parts.join("\t");
134
135            if let Ok(mut out) = output.lock() {
136                out.push(line);
137            }
138            Ok(())
139        })?;
140
141        lua.globals().set("print", print_fn)?;
142        Ok(())
143    }
144
145    /// Inject context variables into Lua globals
146    fn inject_context(&self, lua: &Lua, context: &serde_json::Value) -> LuaResult<()> {
147        if let serde_json::Value::Object(map) = context {
148            let globals = lua.globals();
149            for (key, value) in map {
150                let lua_value = json_to_lua(lua, value)?;
151                globals.set(key.as_str(), lua_value)?;
152            }
153        }
154        Ok(())
155    }
156}
157
158impl Default for LuaExecutor {
159    fn default() -> Self {
160        Self::new()
161    }
162}
163
164impl LanguageExecutor for LuaExecutor {
165    fn execute(&self, request: &ExecutionRequest) -> ExecutionResult {
166        self.execute_code(request)
167    }
168
169    fn language_name(&self) -> &'static str {
170        "lua"
171    }
172
173    fn language_version(&self) -> String {
174        "5.4".to_string()
175    }
176}
177
178/// Convert JSON value to Lua value
179fn json_to_lua(lua: &Lua, value: &serde_json::Value) -> LuaResult<Value> {
180    match value {
181        serde_json::Value::Null => Ok(Value::Nil),
182        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
183        serde_json::Value::Number(n) => {
184            if let Some(i) = n.as_i64() {
185                Ok(Value::Integer(i))
186            } else if let Some(f) = n.as_f64() {
187                Ok(Value::Number(f))
188            } else {
189                Ok(Value::Nil)
190            }
191        }
192        serde_json::Value::String(s) => {
193            let lua_string = lua.create_string(s)?;
194            Ok(Value::String(lua_string))
195        }
196        serde_json::Value::Array(arr) => {
197            let table = lua.create_table()?;
198            for (i, v) in arr.iter().enumerate() {
199                let lua_value = json_to_lua(lua, v)?;
200                table.set(i + 1, lua_value)?; // Lua is 1-indexed
201            }
202            Ok(Value::Table(table))
203        }
204        serde_json::Value::Object(obj) => {
205            let table = lua.create_table()?;
206            for (k, v) in obj {
207                let lua_value = json_to_lua(lua, v)?;
208                table.set(k.as_str(), lua_value)?;
209            }
210            Ok(Value::Table(table))
211        }
212    }
213}
214
215/// Convert Lua value to JSON
216fn lua_to_json(value: &Value) -> Option<serde_json::Value> {
217    match value {
218        Value::Nil => None,
219        Value::Boolean(b) => Some(serde_json::Value::Bool(*b)),
220        Value::Integer(i) => Some(serde_json::Value::Number(serde_json::Number::from(*i))),
221        Value::Number(f) => serde_json::Number::from_f64(*f).map(serde_json::Value::Number),
222        Value::String(s) => s.to_str().ok().map(|s| serde_json::Value::String(s.to_string())),
223        Value::Table(t) => {
224            // Try to determine if it's an array or object
225            let mut is_array = true;
226            let mut max_index = 0i64;
227            let mut has_string_keys = false;
228
229            // Check keys
230            if let Ok(pairs) = t.clone().pairs::<Value, Value>().collect::<LuaResult<Vec<_>>>() {
231                for (k, _) in &pairs {
232                    match k {
233                        Value::Integer(i) if *i > 0 => {
234                            max_index = max_index.max(*i);
235                        }
236                        Value::String(_) => {
237                            has_string_keys = true;
238                            is_array = false;
239                        }
240                        _ => {
241                            is_array = false;
242                        }
243                    }
244                }
245
246                if is_array && !has_string_keys && max_index > 0 {
247                    // It's an array
248                    let mut arr = Vec::new();
249                    for i in 1..=max_index {
250                        if let Ok(v) = t.get::<Value>(i) {
251                            arr.push(lua_to_json(&v).unwrap_or(serde_json::Value::Null));
252                        }
253                    }
254                    Some(serde_json::Value::Array(arr))
255                } else {
256                    // It's an object
257                    let mut map = serde_json::Map::new();
258                    for (k, v) in pairs {
259                        let key = format_lua_value(&k);
260                        if let Some(json_v) = lua_to_json(&v) {
261                            map.insert(key, json_v);
262                        }
263                    }
264                    Some(serde_json::Value::Object(map))
265                }
266            } else {
267                None
268            }
269        }
270        Value::Function(_) => Some(serde_json::Value::String("[function]".to_string())),
271        Value::Thread(_) => Some(serde_json::Value::String("[thread]".to_string())),
272        Value::UserData(_) => Some(serde_json::Value::String("[userdata]".to_string())),
273        Value::LightUserData(_) => Some(serde_json::Value::String("[lightuserdata]".to_string())),
274        Value::Error(e) => Some(serde_json::Value::String(format!("[error: {}]", e))),
275        _ => Some(serde_json::Value::String("[unknown]".to_string())),
276    }
277}
278
279/// Format Lua value for display
280fn format_lua_value(value: &Value) -> String {
281    match value {
282        Value::Nil => "nil".to_string(),
283        Value::Boolean(b) => b.to_string(),
284        Value::Integer(i) => i.to_string(),
285        Value::Number(f) => f.to_string(),
286        Value::String(s) => s.to_str().map(|s| s.to_string()).unwrap_or_else(|_| "[invalid utf8]".to_string()),
287        Value::Table(t) => {
288            // Simple table representation
289            let mut parts = Vec::new();
290            if let Ok(pairs) = t.clone().pairs::<Value, Value>().collect::<LuaResult<Vec<_>>>() {
291                for (k, v) in pairs.iter().take(10) {
292                    parts.push(format!("{}={}", format_lua_value(k), format_lua_value(v)));
293                }
294                if pairs.len() > 10 {
295                    parts.push("...".to_string());
296                }
297            }
298            format!("{{{}}}", parts.join(", "))
299        }
300        Value::Function(_) => "[function]".to_string(),
301        Value::Thread(_) => "[thread]".to_string(),
302        Value::UserData(_) => "[userdata]".to_string(),
303        Value::LightUserData(_) => "[lightuserdata]".to_string(),
304        Value::Error(e) => format!("[error: {}]", e),
305        _ => "[unknown]".to_string(),
306    }
307}
308
309/// Format Lua error for user display
310fn format_lua_error(error: &mlua::Error) -> String {
311    match error {
312        mlua::Error::SyntaxError { message, .. } => {
313            format!("Syntax error: {}", message)
314        }
315        mlua::Error::RuntimeError(msg) => {
316            format!("Runtime error: {}", msg)
317        }
318        mlua::Error::MemoryError(msg) => {
319            format!("Memory error: {}", msg)
320        }
321        mlua::Error::CallbackError { traceback, cause } => {
322            format!("Callback error: {}\nTraceback: {}", cause, traceback)
323        }
324        _ => format!("{}", error),
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Language;

    /// Build a Lua request for `code`, leaving every other field at its default.
    fn make_request(code: &str) -> ExecutionRequest {
        ExecutionRequest {
            language: Language::Lua,
            code: code.to_string(),
            ..Default::default()
        }
    }

    /// Run `code` on a fresh executor with default limits.
    fn run(code: &str) -> ExecutionResult {
        LuaExecutor::new().execute(&make_request(code))
    }

    /// Assert that the run succeeded and that its stdout contains `needle`.
    fn assert_stdout_contains(outcome: &ExecutionResult, needle: &str) {
        assert!(outcome.success);
        assert!(outcome.stdout.contains(needle));
    }

    #[test]
    fn test_simple_expression() {
        let outcome = run("return 1 + 2");
        assert_stdout_contains(&outcome, "3");
    }

    #[test]
    fn test_print() {
        let outcome = run(r#"print("Hello, World!")"#);
        assert_stdout_contains(&outcome, "Hello, World!");
    }

    #[test]
    fn test_variables() {
        let outcome = run(
            r#"
            local x = 10
            local y = 20
            return x + y
            "#,
        );
        assert_stdout_contains(&outcome, "30");
    }

    #[test]
    fn test_loop() {
        let outcome = run(
            r#"
            local sum = 0
            for i = 0, 9 do
                sum = sum + i
            end
            return sum
            "#,
        );
        // 0 + 1 + ... + 9
        assert_stdout_contains(&outcome, "45");
    }

    #[test]
    fn test_table() {
        let outcome = run(
            r#"
            local t = {1, 2, 3, 4, 5}
            return #t
            "#,
        );
        assert_stdout_contains(&outcome, "5");
    }

    #[test]
    fn test_syntax_error() {
        let outcome = run("local x = ");
        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[test]
    fn test_context_injection() {
        let request = ExecutionRequest {
            context: Some(serde_json::json!({
                "x": 10,
                "y": 20
            })),
            ..make_request("return x + y")
        };
        let outcome = LuaExecutor::new().execute(&request);
        assert_stdout_contains(&outcome, "30");
    }

    #[test]
    fn test_string_operations() {
        let outcome = run(
            r#"
            local s = "hello"
            return string.upper(s)
            "#,
        );
        assert_stdout_contains(&outcome, "HELLO");
    }

    #[test]
    fn test_function_definition() {
        let outcome = run(
            r#"
            local function add(a, b)
                return a + b
            end
            return add(3, 4)
            "#,
        );
        assert_stdout_contains(&outcome, "7");
    }
}