brainwires_code_interpreters/languages/
lua.rs1use mlua::{Lua, MultiValue, Result as LuaResult, Value};
17use std::sync::{Arc, Mutex};
18use std::time::Instant;
19
20use super::{LanguageExecutor, get_limits, truncate_output};
21use crate::types::{ExecutionLimits, ExecutionRequest, ExecutionResult};
22
23pub struct LuaExecutor {
25 _limits: ExecutionLimits,
26}
27
28impl LuaExecutor {
29 pub fn new() -> Self {
31 Self {
32 _limits: ExecutionLimits::default(),
33 }
34 }
35
36 pub fn with_limits(limits: ExecutionLimits) -> Self {
38 Self { _limits: limits }
39 }
40
41 pub fn execute_code(&self, request: &ExecutionRequest) -> ExecutionResult {
43 let limits = get_limits(request);
44 let start = Instant::now();
45
46 let lua = Lua::new();
48
49 let _ = lua.set_memory_limit(limits.max_memory_mb as usize * 1024 * 1024);
51
52 let output = Arc::new(Mutex::new(Vec::<String>::new()));
54
55 if let Err(e) = self.setup_print(&lua, output.clone()) {
57 return ExecutionResult::error(
58 format!("Failed to setup print: {}", e),
59 start.elapsed().as_millis() as u64,
60 );
61 }
62
63 if let Some(context) = &request.context
65 && let Err(e) = self.inject_context(&lua, context)
66 {
67 return ExecutionResult::error(
68 format!("Failed to inject context: {}", e),
69 start.elapsed().as_millis() as u64,
70 );
71 }
72
73 let result = lua.load(&request.code).eval::<Value>();
75 let timing_ms = start.elapsed().as_millis() as u64;
76
77 let stdout = output.lock().map(|out| out.join("\n")).unwrap_or_default();
79 let stdout = truncate_output(&stdout, limits.max_output_bytes);
80
81 let memory_used = lua.used_memory() as u64;
83
84 match result {
85 Ok(value) => {
86 let result_value = lua_to_json(&value);
87 let mut stdout_with_result = stdout;
88
89 if !matches!(value, Value::Nil) {
91 if !stdout_with_result.is_empty() {
92 stdout_with_result.push('\n');
93 }
94 stdout_with_result.push_str(&format_lua_value(&value));
95 }
96
97 ExecutionResult {
98 success: true,
99 stdout: stdout_with_result,
100 stderr: String::new(),
101 result: result_value,
102 error: None,
103 timing_ms,
104 memory_used_bytes: Some(memory_used),
105 operations_count: None,
106 }
107 }
108 Err(e) => {
109 let error_message = format_lua_error(&e);
110 ExecutionResult {
111 success: false,
112 stdout,
113 stderr: error_message.clone(),
114 result: None,
115 error: Some(error_message),
116 timing_ms,
117 memory_used_bytes: Some(memory_used),
118 operations_count: None,
119 }
120 }
121 }
122 }
123
124 fn setup_print(&self, lua: &Lua, output: Arc<Mutex<Vec<String>>>) -> LuaResult<()> {
126 let print_fn = lua.create_function(move |_, args: MultiValue| {
127 let parts: Vec<String> = args.into_iter().map(|v| format_lua_value(&v)).collect();
128 let line = parts.join("\t");
129
130 if let Ok(mut out) = output.lock() {
131 out.push(line);
132 }
133 Ok(())
134 })?;
135
136 lua.globals().set("print", print_fn)?;
137 Ok(())
138 }
139
140 fn inject_context(&self, lua: &Lua, context: &serde_json::Value) -> LuaResult<()> {
142 if let serde_json::Value::Object(map) = context {
143 let globals = lua.globals();
144 for (key, value) in map {
145 let lua_value = json_to_lua(lua, value)?;
146 globals.set(key.as_str(), lua_value)?;
147 }
148 }
149 Ok(())
150 }
151}
152
153impl Default for LuaExecutor {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl LanguageExecutor for LuaExecutor {
160 fn execute(&self, request: &ExecutionRequest) -> ExecutionResult {
161 self.execute_code(request)
162 }
163
164 fn language_name(&self) -> &'static str {
165 "lua"
166 }
167
168 fn language_version(&self) -> String {
169 "5.4".to_string()
170 }
171}
172
173fn json_to_lua(lua: &Lua, value: &serde_json::Value) -> LuaResult<Value> {
175 match value {
176 serde_json::Value::Null => Ok(Value::Nil),
177 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
178 serde_json::Value::Number(n) => {
179 if let Some(i) = n.as_i64() {
180 Ok(Value::Integer(i))
181 } else if let Some(f) = n.as_f64() {
182 Ok(Value::Number(f))
183 } else {
184 Ok(Value::Nil)
185 }
186 }
187 serde_json::Value::String(s) => {
188 let lua_string = lua.create_string(s)?;
189 Ok(Value::String(lua_string))
190 }
191 serde_json::Value::Array(arr) => {
192 let table = lua.create_table()?;
193 for (i, v) in arr.iter().enumerate() {
194 let lua_value = json_to_lua(lua, v)?;
195 table.set(i + 1, lua_value)?; }
197 Ok(Value::Table(table))
198 }
199 serde_json::Value::Object(obj) => {
200 let table = lua.create_table()?;
201 for (k, v) in obj {
202 let lua_value = json_to_lua(lua, v)?;
203 table.set(k.as_str(), lua_value)?;
204 }
205 Ok(Value::Table(table))
206 }
207 }
208}
209
210fn lua_to_json(value: &Value) -> Option<serde_json::Value> {
212 match value {
213 Value::Nil => None,
214 Value::Boolean(b) => Some(serde_json::Value::Bool(*b)),
215 Value::Integer(i) => Some(serde_json::Value::Number(serde_json::Number::from(*i))),
216 Value::Number(f) => serde_json::Number::from_f64(*f).map(serde_json::Value::Number),
217 Value::String(s) => s
218 .to_str()
219 .ok()
220 .map(|s| serde_json::Value::String(s.to_string())),
221 Value::Table(t) => {
222 let mut is_array = true;
224 let mut max_index = 0i64;
225 let mut has_string_keys = false;
226
227 if let Ok(pairs) = t
229 .clone()
230 .pairs::<Value, Value>()
231 .collect::<LuaResult<Vec<_>>>()
232 {
233 for (k, _) in &pairs {
234 match k {
235 Value::Integer(i) if *i > 0 => {
236 max_index = max_index.max(*i);
237 }
238 Value::String(_) => {
239 has_string_keys = true;
240 is_array = false;
241 }
242 _ => {
243 is_array = false;
244 }
245 }
246 }
247
248 if is_array && !has_string_keys && max_index > 0 {
249 let mut arr = Vec::new();
251 for i in 1..=max_index {
252 if let Ok(v) = t.get::<Value>(i) {
253 arr.push(lua_to_json(&v).unwrap_or(serde_json::Value::Null));
254 }
255 }
256 Some(serde_json::Value::Array(arr))
257 } else {
258 let mut map = serde_json::Map::new();
260 for (k, v) in pairs {
261 let key = format_lua_value(&k);
262 if let Some(json_v) = lua_to_json(&v) {
263 map.insert(key, json_v);
264 }
265 }
266 Some(serde_json::Value::Object(map))
267 }
268 } else {
269 None
270 }
271 }
272 Value::Function(_) => Some(serde_json::Value::String("[function]".to_string())),
273 Value::Thread(_) => Some(serde_json::Value::String("[thread]".to_string())),
274 Value::UserData(_) => Some(serde_json::Value::String("[userdata]".to_string())),
275 Value::LightUserData(_) => Some(serde_json::Value::String("[lightuserdata]".to_string())),
276 Value::Error(e) => Some(serde_json::Value::String(format!("[error: {}]", e))),
277 _ => Some(serde_json::Value::String("[unknown]".to_string())),
278 }
279}
280
281fn format_lua_value(value: &Value) -> String {
283 match value {
284 Value::Nil => "nil".to_string(),
285 Value::Boolean(b) => b.to_string(),
286 Value::Integer(i) => i.to_string(),
287 Value::Number(f) => f.to_string(),
288 Value::String(s) => s
289 .to_str()
290 .map(|s| s.to_string())
291 .unwrap_or_else(|_| "[invalid utf8]".to_string()),
292 Value::Table(t) => {
293 let mut parts = Vec::new();
295 if let Ok(pairs) = t
296 .clone()
297 .pairs::<Value, Value>()
298 .collect::<LuaResult<Vec<_>>>()
299 {
300 for (k, v) in pairs.iter().take(10) {
301 parts.push(format!("{}={}", format_lua_value(k), format_lua_value(v)));
302 }
303 if pairs.len() > 10 {
304 parts.push("...".to_string());
305 }
306 }
307 format!("{{{}}}", parts.join(", "))
308 }
309 Value::Function(_) => "[function]".to_string(),
310 Value::Thread(_) => "[thread]".to_string(),
311 Value::UserData(_) => "[userdata]".to_string(),
312 Value::LightUserData(_) => "[lightuserdata]".to_string(),
313 Value::Error(e) => format!("[error: {}]", e),
314 _ => "[unknown]".to_string(),
315 }
316}
317
318fn format_lua_error(error: &mlua::Error) -> String {
320 match error {
321 mlua::Error::SyntaxError { message, .. } => {
322 format!("Syntax error: {}", message)
323 }
324 mlua::Error::RuntimeError(msg) => {
325 format!("Runtime error: {}", msg)
326 }
327 mlua::Error::MemoryError(msg) => {
328 format!("Memory error: {}", msg)
329 }
330 mlua::Error::CallbackError { traceback, cause } => {
331 format!("Callback error: {}\nTraceback: {}", cause, traceback)
332 }
333 _ => format!("{}", error),
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Language;

    /// Builds a Lua request for `code`, leaving every other field at its default.
    fn make_request(code: &str) -> ExecutionRequest {
        ExecutionRequest {
            language: Language::Lua,
            code: code.to_string(),
            ..Default::default()
        }
    }

    /// Runs `code` through a freshly constructed executor.
    fn run(code: &str) -> ExecutionResult {
        LuaExecutor::new().execute(&make_request(code))
    }

    /// Asserts that `result` succeeded and that its stdout contains `needle`.
    fn assert_ok_with(result: &ExecutionResult, needle: &str) {
        assert!(result.success, "execution failed: {:?}", result.error);
        assert!(
            result.stdout.contains(needle),
            "stdout {:?} does not contain {:?}",
            result.stdout,
            needle
        );
    }

    #[test]
    fn test_simple_expression() {
        assert_ok_with(&run("return 1 + 2"), "3");
    }

    #[test]
    fn test_print() {
        assert_ok_with(&run(r#"print("Hello, World!")"#), "Hello, World!");
    }

    #[test]
    fn test_variables() {
        let code = r#"
            local x = 10
            local y = 20
            return x + y
        "#;
        assert_ok_with(&run(code), "30");
    }

    #[test]
    fn test_loop() {
        let code = r#"
            local sum = 0
            for i = 0, 9 do
                sum = sum + i
            end
            return sum
        "#;
        assert_ok_with(&run(code), "45");
    }

    #[test]
    fn test_table() {
        let code = r#"
            local t = {1, 2, 3, 4, 5}
            return #t
        "#;
        assert_ok_with(&run(code), "5");
    }

    #[test]
    fn test_syntax_error() {
        let result = run("local x = ");
        assert!(!result.success);
        assert!(result.error.is_some());
    }

    #[test]
    fn test_context_injection() {
        let mut request = make_request("return x + y");
        request.context = Some(serde_json::json!({ "x": 10, "y": 20 }));
        assert_ok_with(&LuaExecutor::new().execute(&request), "30");
    }

    #[test]
    fn test_string_operations() {
        let code = r#"
            local s = "hello"
            return string.upper(s)
        "#;
        assert_ok_with(&run(code), "HELLO");
    }

    #[test]
    fn test_function_definition() {
        let code = r#"
            local function add(a, b)
                return a + b
            end
            return add(3, 4)
        "#;
        assert_ok_with(&run(code), "7");
    }
}