1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{
10 ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, run_version_command,
11};
12
13pub struct LuaEngine {
14 interpreter: Option<PathBuf>,
15}
16
17impl Default for LuaEngine {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl LuaEngine {
24 pub fn new() -> Self {
25 Self {
26 interpreter: resolve_lua_binary(),
27 }
28 }
29
30 fn ensure_interpreter(&self) -> Result<&Path> {
31 self.interpreter.as_deref().ok_or_else(|| {
32 anyhow::anyhow!(
33 "Lua support requires the `lua` executable. Install it from https://www.lua.org/download.html and ensure it is on your PATH."
34 )
35 })
36 }
37
38 fn write_temp_script(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
39 let dir = Builder::new()
40 .prefix("run-lua")
41 .tempdir()
42 .context("failed to create temporary directory for lua source")?;
43 let path = dir.path().join("snippet.lua");
44 let mut contents = code.to_string();
45 if !contents.ends_with('\n') {
46 contents.push('\n');
47 }
48 std::fs::write(&path, contents).with_context(|| {
49 format!("failed to write temporary Lua source to {}", path.display())
50 })?;
51 Ok((dir, path))
52 }
53
54 fn execute_script(&self, script: &Path, args: &[String]) -> Result<std::process::Output> {
55 let interpreter = self.ensure_interpreter()?;
56 let mut cmd = Command::new(interpreter);
57 cmd.arg(script)
58 .args(args)
59 .stdout(Stdio::piped())
60 .stderr(Stdio::piped());
61 cmd.stdin(Stdio::inherit());
62 if let Some(dir) = script.parent() {
63 cmd.current_dir(dir);
64 }
65 cmd.output().with_context(|| {
66 format!(
67 "failed to execute {} with script {}",
68 interpreter.display(),
69 script.display()
70 )
71 })
72 }
73}
74
75impl LanguageEngine for LuaEngine {
76 fn id(&self) -> &'static str {
77 "lua"
78 }
79
80 fn display_name(&self) -> &'static str {
81 "Lua"
82 }
83
84 fn aliases(&self) -> &[&'static str] {
85 &[]
86 }
87
88 fn supports_sessions(&self) -> bool {
89 self.interpreter.is_some()
90 }
91
92 fn validate(&self) -> Result<()> {
93 let interpreter = self.ensure_interpreter()?;
94 let mut cmd = Command::new(interpreter);
95 cmd.arg("-v").stdout(Stdio::null()).stderr(Stdio::null());
96 cmd.status()
97 .with_context(|| format!("failed to invoke {}", interpreter.display()))?
98 .success()
99 .then_some(())
100 .ok_or_else(|| anyhow::anyhow!("{} is not executable", interpreter.display()))
101 }
102
103 fn toolchain_version(&self) -> Result<Option<String>> {
104 let interpreter = self.ensure_interpreter()?;
105 let mut cmd = Command::new(interpreter);
106 cmd.arg("-v");
107 let context = format!("{}", interpreter.display());
108 run_version_command(cmd, &context)
109 }
110
111 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
112 let start = Instant::now();
113 let (temp_dir, script_path) = match payload {
114 ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
115 let (dir, path) = self.write_temp_script(code)?;
116 (Some(dir), path)
117 }
118 ExecutionPayload::File { path, .. } => (None, path.clone()),
119 };
120
121 let output = self.execute_script(&script_path, payload.args())?;
122
123 drop(temp_dir);
124
125 Ok(ExecutionOutcome {
126 language: self.id().to_string(),
127 exit_code: output.status.code(),
128 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
129 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
130 duration: start.elapsed(),
131 })
132 }
133
134 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
135 let interpreter = self.ensure_interpreter()?.to_path_buf();
136 let session = LuaSession::new(interpreter)?;
137 Ok(Box::new(session))
138 }
139}
140
141fn resolve_lua_binary() -> Option<PathBuf> {
142 which::which("lua").ok()
143}
144
145const SESSION_MAIN_FILE: &str = "session.lua";
146
147struct LuaSession {
148 interpreter: PathBuf,
149 workspace: TempDir,
150 statements: Vec<String>,
151 last_stdout: String,
152 last_stderr: String,
153}
154
155impl LuaSession {
156 fn new(interpreter: PathBuf) -> Result<Self> {
157 let workspace = TempDir::new().context("failed to create Lua session workspace")?;
158 let session = Self {
159 interpreter,
160 workspace,
161 statements: Vec::new(),
162 last_stdout: String::new(),
163 last_stderr: String::new(),
164 };
165 session.persist_source()?;
166 Ok(session)
167 }
168
169 fn language_id(&self) -> &str {
170 "lua"
171 }
172
173 fn source_path(&self) -> PathBuf {
174 self.workspace.path().join(SESSION_MAIN_FILE)
175 }
176
177 fn persist_source(&self) -> Result<()> {
178 let path = self.source_path();
179 let mut source = String::new();
180 if self.statements.is_empty() {
181 source.push_str("-- session body\n");
182 } else {
183 for stmt in &self.statements {
184 source.push_str(stmt);
185 if !stmt.ends_with('\n') {
186 source.push('\n');
187 }
188 }
189 }
190 fs::write(&path, source)
191 .with_context(|| format!("failed to write Lua session source at {}", path.display()))
192 }
193
194 fn run_program(&self) -> Result<std::process::Output> {
195 let mut cmd = Command::new(&self.interpreter);
196 cmd.arg(SESSION_MAIN_FILE)
197 .stdout(Stdio::piped())
198 .stderr(Stdio::piped())
199 .current_dir(self.workspace.path());
200 cmd.output().with_context(|| {
201 format!(
202 "failed to execute {} for Lua session",
203 self.interpreter.display()
204 )
205 })
206 }
207
208 fn normalize_output(bytes: &[u8]) -> String {
209 String::from_utf8_lossy(bytes)
210 .replace("\r\n", "\n")
211 .replace('\r', "")
212 }
213
214 fn diff_outputs(previous: &str, current: &str) -> String {
215 if let Some(suffix) = current.strip_prefix(previous) {
216 suffix.to_string()
217 } else {
218 current.to_string()
219 }
220 }
221}
222
/// Heuristically decides whether a session input is an expression whose value
/// should be echoed, rather than a statement to keep in history.
///
/// Returns `false` in these cases:
/// - empty, blank or multi-line input;
/// - `--` comments;
/// - lines starting with a Lua statement keyword;
/// - lines containing a plain `=` assignment.
///
/// Keyword matching is case-sensitive, because Lua keywords are lowercase and
/// `Local` or `End` are ordinary identifiers. A keyword only counts if it ends
/// at an identifier boundary, so `return;` is a statement while `format(x)` is
/// not mistaken for `for`.
fn looks_like_expression_snippet(code: &str) -> bool {
    if code.is_empty() || code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.starts_with("--") {
        return false;
    }

    const CONTROL_KEYWORDS: &[&str] = &[
        "local", "function", "for", "while", "repeat", "if", "do", "return", "break", "goto", "end",
    ];
    let starts_with_keyword = CONTROL_KEYWORDS
        .iter()
        .any(|kw| match trimmed.strip_prefix(kw) {
            // The keyword must not be the prefix of a longer identifier.
            Some(rest) => !rest.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_'),
            None => false,
        });
    if starts_with_keyword {
        return false;
    }

    !has_assignment_operator(trimmed)
}

/// Returns `true` if `code` contains an `=` that is not part of `==`, `<=`,
/// `>=` or `~=`.
///
/// This is a purely lexical scan. An `=` inside a string literal or comment
/// also counts, which errs towards treating the input as a statement.
fn has_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    bytes.iter().enumerate().any(|(i, &byte)| {
        if byte != b'=' {
            return false;
        }
        let prev = i.checked_sub(1).map(|j| bytes[j]);
        let next = bytes.get(i + 1).copied();
        let part_of_comparison =
            matches!(prev, Some(b'=' | b'<' | b'>' | b'~')) || next == Some(b'=');
        !part_of_comparison
    })
}
277
278fn wrap_expression_snippet(code: &str) -> String {
279 let trimmed = code.trim();
280 format!(
281 "do\n local __run_pack = table.pack(({expr}))\n local __run_n = __run_pack.n or #__run_pack\n if __run_n > 0 then\n for __run_i = 1, __run_n do\n if __run_i > 1 then io.write(\"\\t\") end\n local __run_val = __run_pack[__run_i]\n if __run_val == nil then\n io.write(\"nil\")\n else\n io.write(tostring(__run_val))\n end\n end\n io.write(\"\\n\")\n end\nend\n",
282 expr = trimmed
283 )
284}
285impl LanguageSession for LuaSession {
286 fn language_id(&self) -> &str {
287 self.language_id()
288 }
289
290 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
291 let trimmed = code.trim();
292
293 if trimmed.eq_ignore_ascii_case(":reset") {
294 self.statements.clear();
295 self.last_stdout.clear();
296 self.last_stderr.clear();
297 self.persist_source()?;
298 return Ok(ExecutionOutcome {
299 language: self.language_id().to_string(),
300 exit_code: None,
301 stdout: String::new(),
302 stderr: String::new(),
303 duration: Duration::default(),
304 });
305 }
306
307 if trimmed.eq_ignore_ascii_case(":help") {
308 return Ok(ExecutionOutcome {
309 language: self.language_id().to_string(),
310 exit_code: None,
311 stdout:
312 "Lua commands:\n :reset - clear session state\n :help - show this message\n"
313 .to_string(),
314 stderr: String::new(),
315 duration: Duration::default(),
316 });
317 }
318
319 if trimmed.is_empty() {
320 return Ok(ExecutionOutcome {
321 language: self.language_id().to_string(),
322 exit_code: None,
323 stdout: String::new(),
324 stderr: String::new(),
325 duration: Duration::default(),
326 });
327 }
328
329 let (effective_code, force_expression) = if let Some(stripped) = trimmed.strip_prefix('=') {
330 (stripped.trim(), true)
331 } else {
332 (trimmed, false)
333 };
334
335 let is_expression = force_expression || looks_like_expression_snippet(effective_code);
336 let statement = if is_expression {
337 wrap_expression_snippet(effective_code)
338 } else {
339 format!("{}\n", code.trim_end_matches(['\r', '\n']))
340 };
341
342 let previous_stdout = self.last_stdout.clone();
343 let previous_stderr = self.last_stderr.clone();
344
345 self.statements.push(statement);
346 self.persist_source()?;
347
348 let start = Instant::now();
349 let output = self.run_program()?;
350 let stdout_full = LuaSession::normalize_output(&output.stdout);
351 let stderr_full = LuaSession::normalize_output(&output.stderr);
352 let stdout = LuaSession::diff_outputs(&self.last_stdout, &stdout_full);
353 let stderr = LuaSession::diff_outputs(&self.last_stderr, &stderr_full);
354 let duration = start.elapsed();
355
356 if output.status.success() {
357 if is_expression {
358 self.statements.pop();
359 self.persist_source()?;
360 self.last_stdout = previous_stdout;
361 self.last_stderr = previous_stderr;
362 } else {
363 self.last_stdout = stdout_full;
364 self.last_stderr = stderr_full;
365 }
366 Ok(ExecutionOutcome {
367 language: self.language_id().to_string(),
368 exit_code: output.status.code(),
369 stdout,
370 stderr,
371 duration,
372 })
373 } else {
374 self.statements.pop();
375 self.persist_source()?;
376 self.last_stdout = previous_stdout;
377 self.last_stderr = previous_stderr;
378 Ok(ExecutionOutcome {
379 language: self.language_id().to_string(),
380 exit_code: output.status.code(),
381 stdout,
382 stderr,
383 duration,
384 })
385 }
386 }
387
388 fn shutdown(&mut self) -> Result<()> {
389 Ok(())
390 }
391}
392
393#[cfg(test)]
394mod tests {
395 use super::{LuaSession, looks_like_expression_snippet, wrap_expression_snippet};
396
397 #[test]
398 fn diff_outputs_appends_only_suffix() {
399 let previous = "a\nb\n";
400 let current = "a\nb\nc\n";
401 assert_eq!(LuaSession::diff_outputs(previous, current), "c\n");
402
403 let previous = "a\n";
404 let current = "x\na\n";
405 assert_eq!(LuaSession::diff_outputs(previous, current), "x\na\n");
406 }
407
408 #[test]
409 fn detects_simple_expression() {
410 assert!(looks_like_expression_snippet("a"));
411 assert!(looks_like_expression_snippet("foo(bar)"));
412 assert!(!looks_like_expression_snippet("local a = 1"));
413 assert!(!looks_like_expression_snippet("a = 1"));
414 }
415
416 #[test]
417 fn wraps_expression_with_print_block() {
418 let wrapped = wrap_expression_snippet("a");
419 assert!(wrapped.contains("table.pack((a))"));
420 assert!(wrapped.contains("io.write(\"\\n\")"));
421 }
422}