1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
10
11pub struct LuaEngine {
12 interpreter: Option<PathBuf>,
13}
14
15impl LuaEngine {
16 pub fn new() -> Self {
17 Self {
18 interpreter: resolve_lua_binary(),
19 }
20 }
21
22 fn ensure_interpreter(&self) -> Result<&Path> {
23 self.interpreter.as_deref().ok_or_else(|| {
24 anyhow::anyhow!(
25 "Lua support requires the `lua` executable. Install it from https://www.lua.org/download.html and ensure it is on your PATH."
26 )
27 })
28 }
29
30 fn write_temp_script(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
31 let dir = Builder::new()
32 .prefix("run-lua")
33 .tempdir()
34 .context("failed to create temporary directory for lua source")?;
35 let path = dir.path().join("snippet.lua");
36 let mut contents = code.to_string();
37 if !contents.ends_with('\n') {
38 contents.push('\n');
39 }
40 std::fs::write(&path, contents).with_context(|| {
41 format!("failed to write temporary Lua source to {}", path.display())
42 })?;
43 Ok((dir, path))
44 }
45
46 fn execute_script(&self, script: &Path) -> Result<std::process::Output> {
47 let interpreter = self.ensure_interpreter()?;
48 let mut cmd = Command::new(interpreter);
49 cmd.arg(script)
50 .stdout(Stdio::piped())
51 .stderr(Stdio::piped());
52 cmd.stdin(Stdio::inherit());
53 if let Some(dir) = script.parent() {
54 cmd.current_dir(dir);
55 }
56 cmd.output().with_context(|| {
57 format!(
58 "failed to execute {} with script {}",
59 interpreter.display(),
60 script.display()
61 )
62 })
63 }
64}
65
66impl LanguageEngine for LuaEngine {
67 fn id(&self) -> &'static str {
68 "lua"
69 }
70
71 fn display_name(&self) -> &'static str {
72 "Lua"
73 }
74
75 fn aliases(&self) -> &[&'static str] {
76 &[]
77 }
78
79 fn supports_sessions(&self) -> bool {
80 self.interpreter.is_some()
81 }
82
83 fn validate(&self) -> Result<()> {
84 let interpreter = self.ensure_interpreter()?;
85 let mut cmd = Command::new(interpreter);
86 cmd.arg("-v").stdout(Stdio::null()).stderr(Stdio::null());
87 cmd.status()
88 .with_context(|| format!("failed to invoke {}", interpreter.display()))?
89 .success()
90 .then_some(())
91 .ok_or_else(|| anyhow::anyhow!("{} is not executable", interpreter.display()))
92 }
93
94 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
95 let start = Instant::now();
96 let (temp_dir, script_path) = match payload {
97 ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
98 let (dir, path) = self.write_temp_script(code)?;
99 (Some(dir), path)
100 }
101 ExecutionPayload::File { path } => (None, path.clone()),
102 };
103
104 let output = self.execute_script(&script_path)?;
105
106 drop(temp_dir);
107
108 Ok(ExecutionOutcome {
109 language: self.id().to_string(),
110 exit_code: output.status.code(),
111 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
112 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
113 duration: start.elapsed(),
114 })
115 }
116
117 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
118 let interpreter = self.ensure_interpreter()?.to_path_buf();
119 let session = LuaSession::new(interpreter)?;
120 Ok(Box::new(session))
121 }
122}
123
124fn resolve_lua_binary() -> Option<PathBuf> {
125 which::which("lua").ok()
126}
127
128const SESSION_MAIN_FILE: &str = "session.lua";
129
130struct LuaSession {
131 interpreter: PathBuf,
132 workspace: TempDir,
133 statements: Vec<String>,
134 last_stdout: String,
135 last_stderr: String,
136}
137
138impl LuaSession {
139 fn new(interpreter: PathBuf) -> Result<Self> {
140 let workspace = TempDir::new().context("failed to create Lua session workspace")?;
141 let session = Self {
142 interpreter,
143 workspace,
144 statements: Vec::new(),
145 last_stdout: String::new(),
146 last_stderr: String::new(),
147 };
148 session.persist_source()?;
149 Ok(session)
150 }
151
152 fn language_id(&self) -> &str {
153 "lua"
154 }
155
156 fn source_path(&self) -> PathBuf {
157 self.workspace.path().join(SESSION_MAIN_FILE)
158 }
159
160 fn persist_source(&self) -> Result<()> {
161 let path = self.source_path();
162 let mut source = String::new();
163 if self.statements.is_empty() {
164 source.push_str("-- session body\n");
165 } else {
166 for stmt in &self.statements {
167 source.push_str(stmt);
168 if !stmt.ends_with('\n') {
169 source.push('\n');
170 }
171 }
172 }
173 fs::write(&path, source)
174 .with_context(|| format!("failed to write Lua session source at {}", path.display()))
175 }
176
177 fn run_program(&self) -> Result<std::process::Output> {
178 let mut cmd = Command::new(&self.interpreter);
179 cmd.arg(SESSION_MAIN_FILE)
180 .stdout(Stdio::piped())
181 .stderr(Stdio::piped())
182 .current_dir(self.workspace.path());
183 cmd.output().with_context(|| {
184 format!(
185 "failed to execute {} for Lua session",
186 self.interpreter.display()
187 )
188 })
189 }
190
191 fn normalize_output(bytes: &[u8]) -> String {
192 String::from_utf8_lossy(bytes)
193 .replace("\r\n", "\n")
194 .replace('\r', "")
195 }
196
197 fn diff_outputs(previous: &str, current: &str) -> String {
198 if let Some(suffix) = current.strip_prefix(previous) {
199 suffix.to_string()
200 } else {
201 current.to_string()
202 }
203 }
204}
205
/// Heuristically decides whether single-line REPL input is an expression whose value(s) should be
/// echoed, rather than a statement to append to the session history.
///
/// Returns `false` for multi-line input, blank input, comments (`--`), labels (`::name::`), input
/// starting with a statement keyword, and input containing a top-level assignment `=`.
fn looks_like_expression_snippet(code: &str) -> bool {
    if code.is_empty() || code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.starts_with("--") || trimmed.starts_with("::") {
        return false;
    }

    // Lua keywords are case-sensitive (`If` is an ordinary identifier), so match them verbatim.
    const STATEMENT_KEYWORDS: &[&str] = &[
        "local", "function", "for", "while", "repeat", "if", "do", "return", "break", "goto", "end",
    ];
    let starts_with_keyword = STATEMENT_KEYWORDS.iter().any(|kw| match trimmed.strip_prefix(kw) {
        // Only a whole word counts: `format(x)` / `end_time` do not start with `for` / `end`,
        // whereas `return{1}` or `if(x)` do start with a keyword.
        Some(rest) => !rest.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_'),
        None => false,
    });
    if starts_with_keyword {
        return false;
    }

    !has_assignment_operator(trimmed)
}

/// Returns `true` when `code` contains an assignment `=` outside brackets and string literals.
///
/// `==`, `~=`, `<=` and `>=` are comparisons, not assignments. An `=` nested in `()`, `[]` or `{}`
/// (a table constructor such as `{a = 1}`, or a long bracket such as `[==[`) or inside a quoted
/// string (`print("a=b")`) does not make the input an assignment statement.
fn has_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    let mut depth = 0usize;
    let mut quote: Option<u8> = None;
    let mut i = 0;
    while i < bytes.len() {
        let byte = bytes[i];
        if let Some(open) = quote {
            if byte == b'\\' {
                // Skip the escaped byte so `\"` does not terminate the string.
                i += 2;
                continue;
            }
            if byte == open {
                quote = None;
            }
        } else {
            match byte {
                b'"' | b'\'' => quote = Some(byte),
                b'(' | b'[' | b'{' => depth += 1,
                b')' | b']' | b'}' => depth = depth.saturating_sub(1),
                b'=' if depth == 0 => {
                    let prev = if i > 0 { bytes[i - 1] } else { b'\0' };
                    let next = bytes.get(i + 1).copied().unwrap_or(b'\0');
                    let part_of_comparison =
                        matches!(prev, b'=' | b'<' | b'>' | b'~') || next == b'=';
                    if !part_of_comparison {
                        return true;
                    }
                }
                _ => {}
            }
        }
        i += 1;
    }
    false
}
260
261fn wrap_expression_snippet(code: &str) -> String {
262 let trimmed = code.trim();
263 format!(
264 "do\n local __run_pack = table.pack(({expr}))\n local __run_n = __run_pack.n or #__run_pack\n if __run_n > 0 then\n for __run_i = 1, __run_n do\n if __run_i > 1 then io.write(\"\\t\") end\n local __run_val = __run_pack[__run_i]\n if __run_val == nil then\n io.write(\"nil\")\n else\n io.write(tostring(__run_val))\n end\n end\n io.write(\"\\n\")\n end\nend\n",
265 expr = trimmed
266 )
267}
268impl LanguageSession for LuaSession {
269 fn language_id(&self) -> &str {
270 self.language_id()
271 }
272
273 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
274 let trimmed = code.trim();
275
276 if trimmed.eq_ignore_ascii_case(":reset") {
277 self.statements.clear();
278 self.last_stdout.clear();
279 self.last_stderr.clear();
280 self.persist_source()?;
281 return Ok(ExecutionOutcome {
282 language: self.language_id().to_string(),
283 exit_code: None,
284 stdout: String::new(),
285 stderr: String::new(),
286 duration: Duration::default(),
287 });
288 }
289
290 if trimmed.eq_ignore_ascii_case(":help") {
291 return Ok(ExecutionOutcome {
292 language: self.language_id().to_string(),
293 exit_code: None,
294 stdout:
295 "Lua commands:\n :reset — clear session state\n :help — show this message\n"
296 .to_string(),
297 stderr: String::new(),
298 duration: Duration::default(),
299 });
300 }
301
302 if trimmed.is_empty() {
303 return Ok(ExecutionOutcome {
304 language: self.language_id().to_string(),
305 exit_code: None,
306 stdout: String::new(),
307 stderr: String::new(),
308 duration: Duration::default(),
309 });
310 }
311
312 let (effective_code, force_expression) = if trimmed.starts_with('=') {
313 (trimmed[1..].trim(), true)
314 } else {
315 (trimmed, false)
316 };
317
318 let is_expression = force_expression || looks_like_expression_snippet(effective_code);
319 let statement = if is_expression {
320 wrap_expression_snippet(effective_code)
321 } else {
322 format!("{}\n", code.trim_end_matches(|c| c == '\r' || c == '\n'))
323 };
324
325 let previous_stdout = self.last_stdout.clone();
326 let previous_stderr = self.last_stderr.clone();
327
328 self.statements.push(statement);
329 self.persist_source()?;
330
331 let start = Instant::now();
332 let output = self.run_program()?;
333 let stdout_full = LuaSession::normalize_output(&output.stdout);
334 let stderr_full = LuaSession::normalize_output(&output.stderr);
335 let stdout = LuaSession::diff_outputs(&self.last_stdout, &stdout_full);
336 let stderr = LuaSession::diff_outputs(&self.last_stderr, &stderr_full);
337 let duration = start.elapsed();
338
339 if output.status.success() {
340 if is_expression {
341 self.statements.pop();
342 self.persist_source()?;
343 self.last_stdout = previous_stdout;
344 self.last_stderr = previous_stderr;
345 } else {
346 self.last_stdout = stdout_full;
347 self.last_stderr = stderr_full;
348 }
349 Ok(ExecutionOutcome {
350 language: self.language_id().to_string(),
351 exit_code: output.status.code(),
352 stdout,
353 stderr,
354 duration,
355 })
356 } else {
357 self.statements.pop();
358 self.persist_source()?;
359 self.last_stdout = previous_stdout;
360 self.last_stderr = previous_stderr;
361 Ok(ExecutionOutcome {
362 language: self.language_id().to_string(),
363 exit_code: output.status.code(),
364 stdout,
365 stderr,
366 duration,
367 })
368 }
369 }
370
371 fn shutdown(&mut self) -> Result<()> {
372 Ok(())
373 }
374}
375
376#[cfg(test)]
377mod tests {
378 use super::{LuaSession, looks_like_expression_snippet, wrap_expression_snippet};
379
380 #[test]
381 fn diff_outputs_appends_only_suffix() {
382 let previous = "a\nb\n";
383 let current = "a\nb\nc\n";
384 assert_eq!(LuaSession::diff_outputs(previous, current), "c\n");
385
386 let previous = "a\n";
387 let current = "x\na\n";
388 assert_eq!(LuaSession::diff_outputs(previous, current), "x\na\n");
389 }
390
391 #[test]
392 fn detects_simple_expression() {
393 assert!(looks_like_expression_snippet("a"));
394 assert!(looks_like_expression_snippet("foo(bar)"));
395 assert!(!looks_like_expression_snippet("local a = 1"));
396 assert!(!looks_like_expression_snippet("a = 1"));
397 }
398
399 #[test]
400 fn wraps_expression_with_print_block() {
401 let wrapped = wrap_expression_snippet("a");
402 assert!(wrapped.contains("table.pack((a))"));
403 assert!(wrapped.contains("io.write(\"\\n\")"));
404 }
405}