Skip to main content

brainwires_code_interpreters/languages/
lua.rs

1//! Lua executor - Small, fast scripting language
2//!
3//! Lua is lightweight and fast, making it ideal for embedded scripting.
4//! Uses mlua which supports Lua 5.4 with vendored builds.
5//!
6//! ## Features
7//! - Very small runtime footprint
8//! - Fast execution
9//! - Good for game scripting and configuration
10//! - Memory limit support
11//!
12//! ## Limitations
13//! - Smaller ecosystem than Python/JS
14//! - 1-indexed arrays (can be confusing)
15
16use mlua::{Lua, MultiValue, Result as LuaResult, Value};
17use std::sync::{Arc, Mutex};
18use std::time::Instant;
19
20use super::{LanguageExecutor, get_limits, truncate_output};
21use crate::types::{ExecutionLimits, ExecutionRequest, ExecutionResult};
22
23/// Lua code executor
24pub struct LuaExecutor {
25    _limits: ExecutionLimits,
26}
27
28impl LuaExecutor {
29    /// Create a new Lua executor with default limits
30    pub fn new() -> Self {
31        Self {
32            _limits: ExecutionLimits::default(),
33        }
34    }
35
36    /// Create a new Lua executor with custom limits
37    pub fn with_limits(limits: ExecutionLimits) -> Self {
38        Self { _limits: limits }
39    }
40
41    /// Execute Lua code
42    pub fn execute_code(&self, request: &ExecutionRequest) -> ExecutionResult {
43        let limits = get_limits(request);
44        let start = Instant::now();
45
46        // Create Lua instance
47        let lua = Lua::new();
48
49        // Set memory limit
50        let _ = lua.set_memory_limit(limits.max_memory_mb as usize * 1024 * 1024);
51
52        // Capture print output
53        let output = Arc::new(Mutex::new(Vec::<String>::new()));
54
55        // Override print function to capture output
56        if let Err(e) = self.setup_print(&lua, output.clone()) {
57            return ExecutionResult::error(
58                format!("Failed to setup print: {}", e),
59                start.elapsed().as_millis() as u64,
60            );
61        }
62
63        // Inject context variables
64        if let Some(context) = &request.context
65            && let Err(e) = self.inject_context(&lua, context)
66        {
67            return ExecutionResult::error(
68                format!("Failed to inject context: {}", e),
69                start.elapsed().as_millis() as u64,
70            );
71        }
72
73        // Execute the code
74        let result = lua.load(&request.code).eval::<Value>();
75        let timing_ms = start.elapsed().as_millis() as u64;
76
77        // Get captured output
78        let stdout = output.lock().map(|out| out.join("\n")).unwrap_or_default();
79        let stdout = truncate_output(&stdout, limits.max_output_bytes);
80
81        // Get memory usage
82        let memory_used = lua.used_memory() as u64;
83
84        match result {
85            Ok(value) => {
86                let result_value = lua_to_json(&value);
87                let mut stdout_with_result = stdout;
88
89                // If there's a non-nil result, append it to stdout
90                if !matches!(value, Value::Nil) {
91                    if !stdout_with_result.is_empty() {
92                        stdout_with_result.push('\n');
93                    }
94                    stdout_with_result.push_str(&format_lua_value(&value));
95                }
96
97                ExecutionResult {
98                    success: true,
99                    stdout: stdout_with_result,
100                    stderr: String::new(),
101                    result: result_value,
102                    error: None,
103                    timing_ms,
104                    memory_used_bytes: Some(memory_used),
105                    operations_count: None,
106                }
107            }
108            Err(e) => {
109                let error_message = format_lua_error(&e);
110                ExecutionResult {
111                    success: false,
112                    stdout,
113                    stderr: error_message.clone(),
114                    result: None,
115                    error: Some(error_message),
116                    timing_ms,
117                    memory_used_bytes: Some(memory_used),
118                    operations_count: None,
119                }
120            }
121        }
122    }
123
124    /// Setup print function to capture output
125    fn setup_print(&self, lua: &Lua, output: Arc<Mutex<Vec<String>>>) -> LuaResult<()> {
126        let print_fn = lua.create_function(move |_, args: MultiValue| {
127            let parts: Vec<String> = args.into_iter().map(|v| format_lua_value(&v)).collect();
128            let line = parts.join("\t");
129
130            if let Ok(mut out) = output.lock() {
131                out.push(line);
132            }
133            Ok(())
134        })?;
135
136        lua.globals().set("print", print_fn)?;
137        Ok(())
138    }
139
140    /// Inject context variables into Lua globals
141    fn inject_context(&self, lua: &Lua, context: &serde_json::Value) -> LuaResult<()> {
142        if let serde_json::Value::Object(map) = context {
143            let globals = lua.globals();
144            for (key, value) in map {
145                let lua_value = json_to_lua(lua, value)?;
146                globals.set(key.as_str(), lua_value)?;
147            }
148        }
149        Ok(())
150    }
151}
152
153impl Default for LuaExecutor {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl LanguageExecutor for LuaExecutor {
160    fn execute(&self, request: &ExecutionRequest) -> ExecutionResult {
161        self.execute_code(request)
162    }
163
164    fn language_name(&self) -> &'static str {
165        "lua"
166    }
167
168    fn language_version(&self) -> String {
169        "5.4".to_string()
170    }
171}
172
173/// Convert JSON value to Lua value
174fn json_to_lua(lua: &Lua, value: &serde_json::Value) -> LuaResult<Value> {
175    match value {
176        serde_json::Value::Null => Ok(Value::Nil),
177        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
178        serde_json::Value::Number(n) => {
179            if let Some(i) = n.as_i64() {
180                Ok(Value::Integer(i))
181            } else if let Some(f) = n.as_f64() {
182                Ok(Value::Number(f))
183            } else {
184                Ok(Value::Nil)
185            }
186        }
187        serde_json::Value::String(s) => {
188            let lua_string = lua.create_string(s)?;
189            Ok(Value::String(lua_string))
190        }
191        serde_json::Value::Array(arr) => {
192            let table = lua.create_table()?;
193            for (i, v) in arr.iter().enumerate() {
194                let lua_value = json_to_lua(lua, v)?;
195                table.set(i + 1, lua_value)?; // Lua is 1-indexed
196            }
197            Ok(Value::Table(table))
198        }
199        serde_json::Value::Object(obj) => {
200            let table = lua.create_table()?;
201            for (k, v) in obj {
202                let lua_value = json_to_lua(lua, v)?;
203                table.set(k.as_str(), lua_value)?;
204            }
205            Ok(Value::Table(table))
206        }
207    }
208}
209
210/// Convert Lua value to JSON
211fn lua_to_json(value: &Value) -> Option<serde_json::Value> {
212    match value {
213        Value::Nil => None,
214        Value::Boolean(b) => Some(serde_json::Value::Bool(*b)),
215        Value::Integer(i) => Some(serde_json::Value::Number(serde_json::Number::from(*i))),
216        Value::Number(f) => serde_json::Number::from_f64(*f).map(serde_json::Value::Number),
217        Value::String(s) => s
218            .to_str()
219            .ok()
220            .map(|s| serde_json::Value::String(s.to_string())),
221        Value::Table(t) => {
222            // Try to determine if it's an array or object
223            let mut is_array = true;
224            let mut max_index = 0i64;
225            let mut has_string_keys = false;
226
227            // Check keys
228            if let Ok(pairs) = t
229                .clone()
230                .pairs::<Value, Value>()
231                .collect::<LuaResult<Vec<_>>>()
232            {
233                for (k, _) in &pairs {
234                    match k {
235                        Value::Integer(i) if *i > 0 => {
236                            max_index = max_index.max(*i);
237                        }
238                        Value::String(_) => {
239                            has_string_keys = true;
240                            is_array = false;
241                        }
242                        _ => {
243                            is_array = false;
244                        }
245                    }
246                }
247
248                if is_array && !has_string_keys && max_index > 0 {
249                    // It's an array
250                    let mut arr = Vec::new();
251                    for i in 1..=max_index {
252                        if let Ok(v) = t.get::<Value>(i) {
253                            arr.push(lua_to_json(&v).unwrap_or(serde_json::Value::Null));
254                        }
255                    }
256                    Some(serde_json::Value::Array(arr))
257                } else {
258                    // It's an object
259                    let mut map = serde_json::Map::new();
260                    for (k, v) in pairs {
261                        let key = format_lua_value(&k);
262                        if let Some(json_v) = lua_to_json(&v) {
263                            map.insert(key, json_v);
264                        }
265                    }
266                    Some(serde_json::Value::Object(map))
267                }
268            } else {
269                None
270            }
271        }
272        Value::Function(_) => Some(serde_json::Value::String("[function]".to_string())),
273        Value::Thread(_) => Some(serde_json::Value::String("[thread]".to_string())),
274        Value::UserData(_) => Some(serde_json::Value::String("[userdata]".to_string())),
275        Value::LightUserData(_) => Some(serde_json::Value::String("[lightuserdata]".to_string())),
276        Value::Error(e) => Some(serde_json::Value::String(format!("[error: {}]", e))),
277        _ => Some(serde_json::Value::String("[unknown]".to_string())),
278    }
279}
280
281/// Format Lua value for display
282fn format_lua_value(value: &Value) -> String {
283    match value {
284        Value::Nil => "nil".to_string(),
285        Value::Boolean(b) => b.to_string(),
286        Value::Integer(i) => i.to_string(),
287        Value::Number(f) => f.to_string(),
288        Value::String(s) => s
289            .to_str()
290            .map(|s| s.to_string())
291            .unwrap_or_else(|_| "[invalid utf8]".to_string()),
292        Value::Table(t) => {
293            // Simple table representation
294            let mut parts = Vec::new();
295            if let Ok(pairs) = t
296                .clone()
297                .pairs::<Value, Value>()
298                .collect::<LuaResult<Vec<_>>>()
299            {
300                for (k, v) in pairs.iter().take(10) {
301                    parts.push(format!("{}={}", format_lua_value(k), format_lua_value(v)));
302                }
303                if pairs.len() > 10 {
304                    parts.push("...".to_string());
305                }
306            }
307            format!("{{{}}}", parts.join(", "))
308        }
309        Value::Function(_) => "[function]".to_string(),
310        Value::Thread(_) => "[thread]".to_string(),
311        Value::UserData(_) => "[userdata]".to_string(),
312        Value::LightUserData(_) => "[lightuserdata]".to_string(),
313        Value::Error(e) => format!("[error: {}]", e),
314        _ => "[unknown]".to_string(),
315    }
316}
317
318/// Format Lua error for user display
319fn format_lua_error(error: &mlua::Error) -> String {
320    match error {
321        mlua::Error::SyntaxError { message, .. } => {
322            format!("Syntax error: {}", message)
323        }
324        mlua::Error::RuntimeError(msg) => {
325            format!("Runtime error: {}", msg)
326        }
327        mlua::Error::MemoryError(msg) => {
328            format!("Memory error: {}", msg)
329        }
330        mlua::Error::CallbackError { traceback, cause } => {
331            format!("Callback error: {}\nTraceback: {}", cause, traceback)
332        }
333        _ => format!("{}", error),
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Language;

    /// Build a Lua request for `code`, leaving every other field at its default.
    fn make_request(code: &str) -> ExecutionRequest {
        ExecutionRequest {
            code: code.to_string(),
            language: Language::Lua,
            ..Default::default()
        }
    }

    /// Run `request` on a fresh default executor.
    fn run(request: &ExecutionRequest) -> ExecutionResult {
        LuaExecutor::new().execute(request)
    }

    /// Assert that `request` succeeds and that `needle` appears in its stdout.
    fn expect_stdout(request: &ExecutionRequest, needle: &str) {
        let outcome = run(request);
        assert!(outcome.success, "execution failed: {:?}", outcome.error);
        assert!(
            outcome.stdout.contains(needle),
            "stdout {:?} does not contain {:?}",
            outcome.stdout,
            needle
        );
    }

    #[test]
    fn test_simple_expression() {
        expect_stdout(&make_request("return 1 + 2"), "3");
    }

    #[test]
    fn test_print() {
        expect_stdout(&make_request(r#"print("Hello, World!")"#), "Hello, World!");
    }

    #[test]
    fn test_variables() {
        let code = r#"
            local x = 10
            local y = 20
            return x + y
            "#;
        expect_stdout(&make_request(code), "30");
    }

    #[test]
    fn test_loop() {
        let code = r#"
            local sum = 0
            for i = 0, 9 do
                sum = sum + i
            end
            return sum
            "#;
        // 0 + 1 + ... + 9
        expect_stdout(&make_request(code), "45");
    }

    #[test]
    fn test_table() {
        let code = r#"
            local t = {1, 2, 3, 4, 5}
            return #t
            "#;
        expect_stdout(&make_request(code), "5");
    }

    #[test]
    fn test_syntax_error() {
        let outcome = run(&make_request("local x = "));
        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[test]
    fn test_context_injection() {
        let mut request = make_request("return x + y");
        request.context = Some(serde_json::json!({ "x": 10, "y": 20 }));
        expect_stdout(&request, "30");
    }

    #[test]
    fn test_string_operations() {
        let code = r#"
            local s = "hello"
            return string.upper(s)
            "#;
        expect_stdout(&make_request(code), "HELLO");
    }

    #[test]
    fn test_function_definition() {
        let code = r#"
            local function add(a, b)
                return a + b
            end
            return add(3, 4)
            "#;
        expect_stdout(&make_request(code), "7");
    }
}