run/engine/lua.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
10
/// Language engine that runs Lua snippets and files through an external `lua` interpreter.
pub struct LuaEngine {
    /// Path of the interpreter as resolved by `LuaEngine::new`; `None` when no `lua`
    /// executable was found.
    interpreter: Option<PathBuf>,
}
14
15impl LuaEngine {
16    pub fn new() -> Self {
17        Self {
18            interpreter: resolve_lua_binary(),
19        }
20    }
21
22    fn ensure_interpreter(&self) -> Result<&Path> {
23        self.interpreter.as_deref().ok_or_else(|| {
24            anyhow::anyhow!(
25                "Lua support requires the `lua` executable. Install it from https://www.lua.org/download.html and ensure it is on your PATH." 
26            )
27        })
28    }
29
30    fn write_temp_script(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
31        let dir = Builder::new()
32            .prefix("run-lua")
33            .tempdir()
34            .context("failed to create temporary directory for lua source")?;
35        let path = dir.path().join("snippet.lua");
36        let mut contents = code.to_string();
37        if !contents.ends_with('\n') {
38            contents.push('\n');
39        }
40        std::fs::write(&path, contents).with_context(|| {
41            format!("failed to write temporary Lua source to {}", path.display())
42        })?;
43        Ok((dir, path))
44    }
45
46    fn execute_script(&self, script: &Path) -> Result<std::process::Output> {
47        let interpreter = self.ensure_interpreter()?;
48        let mut cmd = Command::new(interpreter);
49        cmd.arg(script)
50            .stdout(Stdio::piped())
51            .stderr(Stdio::piped());
52        cmd.stdin(Stdio::inherit());
53        if let Some(dir) = script.parent() {
54            cmd.current_dir(dir);
55        }
56        cmd.output().with_context(|| {
57            format!(
58                "failed to execute {} with script {}",
59                interpreter.display(),
60                script.display()
61            )
62        })
63    }
64}
65
66impl LanguageEngine for LuaEngine {
67    fn id(&self) -> &'static str {
68        "lua"
69    }
70
71    fn display_name(&self) -> &'static str {
72        "Lua"
73    }
74
75    fn aliases(&self) -> &[&'static str] {
76        &[]
77    }
78
79    fn supports_sessions(&self) -> bool {
80        self.interpreter.is_some()
81    }
82
83    fn validate(&self) -> Result<()> {
84        let interpreter = self.ensure_interpreter()?;
85        let mut cmd = Command::new(interpreter);
86        cmd.arg("-v").stdout(Stdio::null()).stderr(Stdio::null());
87        cmd.status()
88            .with_context(|| format!("failed to invoke {}", interpreter.display()))?
89            .success()
90            .then_some(())
91            .ok_or_else(|| anyhow::anyhow!("{} is not executable", interpreter.display()))
92    }
93
94    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
95        let start = Instant::now();
96        let (temp_dir, script_path) = match payload {
97            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
98                let (dir, path) = self.write_temp_script(code)?;
99                (Some(dir), path)
100            }
101            ExecutionPayload::File { path } => (None, path.clone()),
102        };
103
104        let output = self.execute_script(&script_path)?;
105
106        drop(temp_dir);
107
108        Ok(ExecutionOutcome {
109            language: self.id().to_string(),
110            exit_code: output.status.code(),
111            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
112            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
113            duration: start.elapsed(),
114        })
115    }
116
117    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
118        let interpreter = self.ensure_interpreter()?.to_path_buf();
119        let session = LuaSession::new(interpreter)?;
120        Ok(Box::new(session))
121    }
122}
123
124fn resolve_lua_binary() -> Option<PathBuf> {
125    which::which("lua").ok()
126}
127
/// File name of the session program, relative to the session workspace directory.
const SESSION_MAIN_FILE: &str = "session.lua";
129
130struct LuaSession {
131    interpreter: PathBuf,
132    workspace: TempDir,
133    statements: Vec<String>,
134    last_stdout: String,
135    last_stderr: String,
136}
137
138impl LuaSession {
139    fn new(interpreter: PathBuf) -> Result<Self> {
140        let workspace = TempDir::new().context("failed to create Lua session workspace")?;
141        let session = Self {
142            interpreter,
143            workspace,
144            statements: Vec::new(),
145            last_stdout: String::new(),
146            last_stderr: String::new(),
147        };
148        session.persist_source()?;
149        Ok(session)
150    }
151
152    fn language_id(&self) -> &str {
153        "lua"
154    }
155
156    fn source_path(&self) -> PathBuf {
157        self.workspace.path().join(SESSION_MAIN_FILE)
158    }
159
160    fn persist_source(&self) -> Result<()> {
161        let path = self.source_path();
162        let mut source = String::new();
163        if self.statements.is_empty() {
164            source.push_str("-- session body\n");
165        } else {
166            for stmt in &self.statements {
167                source.push_str(stmt);
168                if !stmt.ends_with('\n') {
169                    source.push('\n');
170                }
171            }
172        }
173        fs::write(&path, source)
174            .with_context(|| format!("failed to write Lua session source at {}", path.display()))
175    }
176
177    fn run_program(&self) -> Result<std::process::Output> {
178        let mut cmd = Command::new(&self.interpreter);
179        cmd.arg(SESSION_MAIN_FILE)
180            .stdout(Stdio::piped())
181            .stderr(Stdio::piped())
182            .current_dir(self.workspace.path());
183        cmd.output().with_context(|| {
184            format!(
185                "failed to execute {} for Lua session",
186                self.interpreter.display()
187            )
188        })
189    }
190
191    fn normalize_output(bytes: &[u8]) -> String {
192        String::from_utf8_lossy(bytes)
193            .replace("\r\n", "\n")
194            .replace('\r', "")
195    }
196
197    fn diff_outputs(previous: &str, current: &str) -> String {
198        if let Some(suffix) = current.strip_prefix(previous) {
199            suffix.to_string()
200        } else {
201            current.to_string()
202        }
203    }
204}
205
206fn looks_like_expression_snippet(code: &str) -> bool {
207    if code.is_empty() || code.contains('\n') {
208        return false;
209    }
210
211    let trimmed = code.trim();
212    if trimmed.is_empty() {
213        return false;
214    }
215
216    let lower = trimmed.to_ascii_lowercase();
217    const CONTROL_KEYWORDS: &[&str] = &[
218        "local", "function", "for", "while", "repeat", "if", "do", "return", "break", "goto", "end",
219    ];
220
221    for kw in CONTROL_KEYWORDS {
222        if lower == *kw
223            || lower.starts_with(&format!("{} ", kw))
224            || lower.starts_with(&format!("{}(", kw))
225            || lower.starts_with(&format!("{}\t", kw))
226        {
227            return false;
228        }
229    }
230
231    if lower.starts_with("--") {
232        return false;
233    }
234
235    if has_assignment_operator(trimmed) {
236        return false;
237    }
238
239    true
240}
241
/// Reports whether `code` contains an `=` that is not part of `==`, `<=`, `>=` or `~=`.
///
/// This is a purely lexical scan: `=` characters inside strings or comments count too.
fn has_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    bytes.iter().enumerate().any(|(idx, &ch)| {
        if ch != b'=' {
            return false;
        }
        let before = idx.checked_sub(1).map_or(b'\0', |j| bytes[j]);
        let after = bytes.get(idx + 1).copied().unwrap_or(b'\0');
        let in_comparison = after == b'=' || matches!(before, b'=' | b'<' | b'>' | b'~');
        !in_comparison
    })
}
260
/// Wraps a single-line Lua expression in a `do ... end` block that prints its value(s).
///
/// Each packed value is written with `tostring` (or as `nil`). Values are separated by
/// tabs and followed by a newline; nothing is printed when the pack is empty.
///
/// NOTE(review): the expression is enclosed in grouping parentheses, which in Lua adjust
/// it to exactly one value. As a result, multiple results are cut to the first, and a call
/// returning nothing is shown as `nil`. Confirm whether that is intended.
fn wrap_expression_snippet(code: &str) -> String {
    // Everything after the `table.pack` line, emitted verbatim.
    const PRINT_PACKED: &str = r#"    local __run_n = __run_pack.n or #__run_pack
    if __run_n > 0 then
        for __run_i = 1, __run_n do
            if __run_i > 1 then io.write("\t") end
            local __run_val = __run_pack[__run_i]
            if __run_val == nil then
                io.write("nil")
            else
                io.write(tostring(__run_val))
            end
        end
        io.write("\n")
    end
end
"#;

    let expr = code.trim();
    let mut wrapped = String::with_capacity(64 + expr.len() + PRINT_PACKED.len());
    wrapped.push_str("do\n    local __run_pack = table.pack((");
    wrapped.push_str(expr);
    wrapped.push_str("))\n");
    wrapped.push_str(PRINT_PACKED);
    wrapped
}
268impl LanguageSession for LuaSession {
269    fn language_id(&self) -> &str {
270        self.language_id()
271    }
272
273    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
274        let trimmed = code.trim();
275
276        if trimmed.eq_ignore_ascii_case(":reset") {
277            self.statements.clear();
278            self.last_stdout.clear();
279            self.last_stderr.clear();
280            self.persist_source()?;
281            return Ok(ExecutionOutcome {
282                language: self.language_id().to_string(),
283                exit_code: None,
284                stdout: String::new(),
285                stderr: String::new(),
286                duration: Duration::default(),
287            });
288        }
289
290        if trimmed.eq_ignore_ascii_case(":help") {
291            return Ok(ExecutionOutcome {
292                language: self.language_id().to_string(),
293                exit_code: None,
294                stdout:
295                    "Lua commands:\n  :reset — clear session state\n  :help  — show this message\n"
296                        .to_string(),
297                stderr: String::new(),
298                duration: Duration::default(),
299            });
300        }
301
302        if trimmed.is_empty() {
303            return Ok(ExecutionOutcome {
304                language: self.language_id().to_string(),
305                exit_code: None,
306                stdout: String::new(),
307                stderr: String::new(),
308                duration: Duration::default(),
309            });
310        }
311
312        let (effective_code, force_expression) = if trimmed.starts_with('=') {
313            (trimmed[1..].trim(), true)
314        } else {
315            (trimmed, false)
316        };
317
318        let is_expression = force_expression || looks_like_expression_snippet(effective_code);
319        let statement = if is_expression {
320            wrap_expression_snippet(effective_code)
321        } else {
322            format!("{}\n", code.trim_end_matches(|c| c == '\r' || c == '\n'))
323        };
324
325        let previous_stdout = self.last_stdout.clone();
326        let previous_stderr = self.last_stderr.clone();
327
328        self.statements.push(statement);
329        self.persist_source()?;
330
331        let start = Instant::now();
332        let output = self.run_program()?;
333        let stdout_full = LuaSession::normalize_output(&output.stdout);
334        let stderr_full = LuaSession::normalize_output(&output.stderr);
335        let stdout = LuaSession::diff_outputs(&self.last_stdout, &stdout_full);
336        let stderr = LuaSession::diff_outputs(&self.last_stderr, &stderr_full);
337        let duration = start.elapsed();
338
339        if output.status.success() {
340            if is_expression {
341                self.statements.pop();
342                self.persist_source()?;
343                self.last_stdout = previous_stdout;
344                self.last_stderr = previous_stderr;
345            } else {
346                self.last_stdout = stdout_full;
347                self.last_stderr = stderr_full;
348            }
349            Ok(ExecutionOutcome {
350                language: self.language_id().to_string(),
351                exit_code: output.status.code(),
352                stdout,
353                stderr,
354                duration,
355            })
356        } else {
357            self.statements.pop();
358            self.persist_source()?;
359            self.last_stdout = previous_stdout;
360            self.last_stderr = previous_stderr;
361            Ok(ExecutionOutcome {
362                language: self.language_id().to_string(),
363                exit_code: output.status.code(),
364                stdout,
365                stderr,
366                duration,
367            })
368        }
369    }
370
371    fn shutdown(&mut self) -> Result<()> {
372        Ok(())
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::{LuaSession, looks_like_expression_snippet, wrap_expression_snippet};

    #[test]
    fn diff_outputs_appends_only_suffix() {
        // Output that extends the previous run yields only the new tail...
        assert_eq!(LuaSession::diff_outputs("a\nb\n", "a\nb\nc\n"), "c\n");
        // ...while output that diverged from it is reported in full.
        assert_eq!(LuaSession::diff_outputs("a\n", "x\na\n"), "x\na\n");
    }

    #[test]
    fn detects_simple_expression() {
        let cases = [
            ("a", true),
            ("foo(bar)", true),
            ("local a = 1", false),
            ("a = 1", false),
        ];
        for (snippet, expected) in cases {
            assert_eq!(
                looks_like_expression_snippet(snippet),
                expected,
                "snippet: {:?}",
                snippet
            );
        }
    }

    #[test]
    fn wraps_expression_with_print_block() {
        let wrapped = wrap_expression_snippet("a");
        for needle in ["table.pack((a))", "io.write(\"\\n\")"] {
            assert!(
                wrapped.contains(needle),
                "missing {:?} in {:?}",
                needle,
                wrapped
            );
        }
    }
}