brainwires_code_interpreters/languages/
lua.rs1use mlua::{Lua, MultiValue, Result as LuaResult, Value};
17use std::sync::{Arc, Mutex};
18use std::time::Instant;
19
20use super::{get_limits, truncate_output, LanguageExecutor};
21use crate::types::{ExecutionLimits, ExecutionRequest, ExecutionResult};
22
23pub struct LuaExecutor {
25 _limits: ExecutionLimits,
26}
27
28impl LuaExecutor {
29 pub fn new() -> Self {
31 Self {
32 _limits: ExecutionLimits::default(),
33 }
34 }
35
36 pub fn with_limits(limits: ExecutionLimits) -> Self {
38 Self { _limits: limits }
39 }
40
41 pub fn execute_code(&self, request: &ExecutionRequest) -> ExecutionResult {
43 let limits = get_limits(request);
44 let start = Instant::now();
45
46 let lua = Lua::new();
48
49 let _ = lua.set_memory_limit(limits.max_memory_mb as usize * 1024 * 1024);
51
52 let output = Arc::new(Mutex::new(Vec::<String>::new()));
54
55 if let Err(e) = self.setup_print(&lua, output.clone()) {
57 return ExecutionResult::error(
58 format!("Failed to setup print: {}", e),
59 start.elapsed().as_millis() as u64,
60 );
61 }
62
63 if let Some(context) = &request.context
65 && let Err(e) = self.inject_context(&lua, context) {
66 return ExecutionResult::error(
67 format!("Failed to inject context: {}", e),
68 start.elapsed().as_millis() as u64,
69 );
70 }
71
72 let result = lua.load(&request.code).eval::<Value>();
74 let timing_ms = start.elapsed().as_millis() as u64;
75
76 let stdout = output
78 .lock()
79 .map(|out| out.join("\n"))
80 .unwrap_or_default();
81 let stdout = truncate_output(&stdout, limits.max_output_bytes);
82
83 let memory_used = lua.used_memory() as u64;
85
86 match result {
87 Ok(value) => {
88 let result_value = lua_to_json(&value);
89 let mut stdout_with_result = stdout;
90
91 if !matches!(value, Value::Nil) {
93 if !stdout_with_result.is_empty() {
94 stdout_with_result.push('\n');
95 }
96 stdout_with_result.push_str(&format_lua_value(&value));
97 }
98
99 ExecutionResult {
100 success: true,
101 stdout: stdout_with_result,
102 stderr: String::new(),
103 result: result_value,
104 error: None,
105 timing_ms,
106 memory_used_bytes: Some(memory_used),
107 operations_count: None,
108 }
109 }
110 Err(e) => {
111 let error_message = format_lua_error(&e);
112 ExecutionResult {
113 success: false,
114 stdout,
115 stderr: error_message.clone(),
116 result: None,
117 error: Some(error_message),
118 timing_ms,
119 memory_used_bytes: Some(memory_used),
120 operations_count: None,
121 }
122 }
123 }
124 }
125
126 fn setup_print(&self, lua: &Lua, output: Arc<Mutex<Vec<String>>>) -> LuaResult<()> {
128 let print_fn = lua.create_function(move |_, args: MultiValue| {
129 let parts: Vec<String> = args
130 .into_iter()
131 .map(|v| format_lua_value(&v))
132 .collect();
133 let line = parts.join("\t");
134
135 if let Ok(mut out) = output.lock() {
136 out.push(line);
137 }
138 Ok(())
139 })?;
140
141 lua.globals().set("print", print_fn)?;
142 Ok(())
143 }
144
145 fn inject_context(&self, lua: &Lua, context: &serde_json::Value) -> LuaResult<()> {
147 if let serde_json::Value::Object(map) = context {
148 let globals = lua.globals();
149 for (key, value) in map {
150 let lua_value = json_to_lua(lua, value)?;
151 globals.set(key.as_str(), lua_value)?;
152 }
153 }
154 Ok(())
155 }
156}
157
158impl Default for LuaExecutor {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
164impl LanguageExecutor for LuaExecutor {
165 fn execute(&self, request: &ExecutionRequest) -> ExecutionResult {
166 self.execute_code(request)
167 }
168
169 fn language_name(&self) -> &'static str {
170 "lua"
171 }
172
173 fn language_version(&self) -> String {
174 "5.4".to_string()
175 }
176}
177
178fn json_to_lua(lua: &Lua, value: &serde_json::Value) -> LuaResult<Value> {
180 match value {
181 serde_json::Value::Null => Ok(Value::Nil),
182 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
183 serde_json::Value::Number(n) => {
184 if let Some(i) = n.as_i64() {
185 Ok(Value::Integer(i))
186 } else if let Some(f) = n.as_f64() {
187 Ok(Value::Number(f))
188 } else {
189 Ok(Value::Nil)
190 }
191 }
192 serde_json::Value::String(s) => {
193 let lua_string = lua.create_string(s)?;
194 Ok(Value::String(lua_string))
195 }
196 serde_json::Value::Array(arr) => {
197 let table = lua.create_table()?;
198 for (i, v) in arr.iter().enumerate() {
199 let lua_value = json_to_lua(lua, v)?;
200 table.set(i + 1, lua_value)?; }
202 Ok(Value::Table(table))
203 }
204 serde_json::Value::Object(obj) => {
205 let table = lua.create_table()?;
206 for (k, v) in obj {
207 let lua_value = json_to_lua(lua, v)?;
208 table.set(k.as_str(), lua_value)?;
209 }
210 Ok(Value::Table(table))
211 }
212 }
213}
214
215fn lua_to_json(value: &Value) -> Option<serde_json::Value> {
217 match value {
218 Value::Nil => None,
219 Value::Boolean(b) => Some(serde_json::Value::Bool(*b)),
220 Value::Integer(i) => Some(serde_json::Value::Number(serde_json::Number::from(*i))),
221 Value::Number(f) => serde_json::Number::from_f64(*f).map(serde_json::Value::Number),
222 Value::String(s) => s.to_str().ok().map(|s| serde_json::Value::String(s.to_string())),
223 Value::Table(t) => {
224 let mut is_array = true;
226 let mut max_index = 0i64;
227 let mut has_string_keys = false;
228
229 if let Ok(pairs) = t.clone().pairs::<Value, Value>().collect::<LuaResult<Vec<_>>>() {
231 for (k, _) in &pairs {
232 match k {
233 Value::Integer(i) if *i > 0 => {
234 max_index = max_index.max(*i);
235 }
236 Value::String(_) => {
237 has_string_keys = true;
238 is_array = false;
239 }
240 _ => {
241 is_array = false;
242 }
243 }
244 }
245
246 if is_array && !has_string_keys && max_index > 0 {
247 let mut arr = Vec::new();
249 for i in 1..=max_index {
250 if let Ok(v) = t.get::<Value>(i) {
251 arr.push(lua_to_json(&v).unwrap_or(serde_json::Value::Null));
252 }
253 }
254 Some(serde_json::Value::Array(arr))
255 } else {
256 let mut map = serde_json::Map::new();
258 for (k, v) in pairs {
259 let key = format_lua_value(&k);
260 if let Some(json_v) = lua_to_json(&v) {
261 map.insert(key, json_v);
262 }
263 }
264 Some(serde_json::Value::Object(map))
265 }
266 } else {
267 None
268 }
269 }
270 Value::Function(_) => Some(serde_json::Value::String("[function]".to_string())),
271 Value::Thread(_) => Some(serde_json::Value::String("[thread]".to_string())),
272 Value::UserData(_) => Some(serde_json::Value::String("[userdata]".to_string())),
273 Value::LightUserData(_) => Some(serde_json::Value::String("[lightuserdata]".to_string())),
274 Value::Error(e) => Some(serde_json::Value::String(format!("[error: {}]", e))),
275 _ => Some(serde_json::Value::String("[unknown]".to_string())),
276 }
277}
278
279fn format_lua_value(value: &Value) -> String {
281 match value {
282 Value::Nil => "nil".to_string(),
283 Value::Boolean(b) => b.to_string(),
284 Value::Integer(i) => i.to_string(),
285 Value::Number(f) => f.to_string(),
286 Value::String(s) => s.to_str().map(|s| s.to_string()).unwrap_or_else(|_| "[invalid utf8]".to_string()),
287 Value::Table(t) => {
288 let mut parts = Vec::new();
290 if let Ok(pairs) = t.clone().pairs::<Value, Value>().collect::<LuaResult<Vec<_>>>() {
291 for (k, v) in pairs.iter().take(10) {
292 parts.push(format!("{}={}", format_lua_value(k), format_lua_value(v)));
293 }
294 if pairs.len() > 10 {
295 parts.push("...".to_string());
296 }
297 }
298 format!("{{{}}}", parts.join(", "))
299 }
300 Value::Function(_) => "[function]".to_string(),
301 Value::Thread(_) => "[thread]".to_string(),
302 Value::UserData(_) => "[userdata]".to_string(),
303 Value::LightUserData(_) => "[lightuserdata]".to_string(),
304 Value::Error(e) => format!("[error: {}]", e),
305 _ => "[unknown]".to_string(),
306 }
307}
308
309fn format_lua_error(error: &mlua::Error) -> String {
311 match error {
312 mlua::Error::SyntaxError { message, .. } => {
313 format!("Syntax error: {}", message)
314 }
315 mlua::Error::RuntimeError(msg) => {
316 format!("Runtime error: {}", msg)
317 }
318 mlua::Error::MemoryError(msg) => {
319 format!("Memory error: {}", msg)
320 }
321 mlua::Error::CallbackError { traceback, cause } => {
322 format!("Callback error: {}\nTraceback: {}", cause, traceback)
323 }
324 _ => format!("{}", error),
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Language;

    /// Builds a Lua request for `code`, leaving every other field at its default.
    fn make_request(code: &str) -> ExecutionRequest {
        ExecutionRequest {
            language: Language::Lua,
            code: code.to_string(),
            ..Default::default()
        }
    }

    /// Executes `request` on a freshly constructed executor.
    fn run_request(request: &ExecutionRequest) -> ExecutionResult {
        LuaExecutor::new().execute(request)
    }

    /// Executes `code` and asserts that it succeeded with `expected` somewhere in stdout.
    fn assert_runs_and_prints(code: &str, expected: &str) {
        let outcome = run_request(&make_request(code));
        assert!(outcome.success);
        assert!(outcome.stdout.contains(expected));
    }

    #[test]
    fn test_simple_expression() {
        assert_runs_and_prints("return 1 + 2", "3");
    }

    #[test]
    fn test_print() {
        assert_runs_and_prints(r#"print("Hello, World!")"#, "Hello, World!");
    }

    #[test]
    fn test_variables() {
        assert_runs_and_prints(
            r#"
            local x = 10
            local y = 20
            return x + y
            "#,
            "30",
        );
    }

    #[test]
    fn test_loop() {
        assert_runs_and_prints(
            r#"
            local sum = 0
            for i = 0, 9 do
                sum = sum + i
            end
            return sum
            "#,
            "45",
        );
    }

    #[test]
    fn test_table() {
        assert_runs_and_prints(
            r#"
            local t = {1, 2, 3, 4, 5}
            return #t
            "#,
            "5",
        );
    }

    #[test]
    fn test_syntax_error() {
        let outcome = run_request(&make_request("local x = "));
        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[test]
    fn test_context_injection() {
        let request = ExecutionRequest {
            context: Some(serde_json::json!({
                "x": 10,
                "y": 20
            })),
            ..make_request("return x + y")
        };
        let outcome = run_request(&request);
        assert!(outcome.success);
        assert!(outcome.stdout.contains("30"));
    }

    #[test]
    fn test_string_operations() {
        assert_runs_and_prints(
            r#"
            local s = "hello"
            return string.upper(s)
            "#,
            "HELLO",
        );
    }

    #[test]
    fn test_function_definition() {
        assert_runs_and_prints(
            r#"
            local function add(a, b)
                return a + b
            end
            return add(3, 4)
            "#,
            "7",
        );
    }
}