Skip to main content

run/engine/
lua.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
10
/// Language engine that runs Lua code through an external `lua` interpreter.
pub struct LuaEngine {
    /// Interpreter resolved when the engine was built; `None` if no `lua`
    /// executable could be located.
    interpreter: Option<PathBuf>,
}
14
15impl Default for LuaEngine {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl LuaEngine {
22    pub fn new() -> Self {
23        Self {
24            interpreter: resolve_lua_binary(),
25        }
26    }
27
28    fn ensure_interpreter(&self) -> Result<&Path> {
29        self.interpreter.as_deref().ok_or_else(|| {
30            anyhow::anyhow!(
31                "Lua support requires the `lua` executable. Install it from https://www.lua.org/download.html and ensure it is on your PATH." 
32            )
33        })
34    }
35
36    fn write_temp_script(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
37        let dir = Builder::new()
38            .prefix("run-lua")
39            .tempdir()
40            .context("failed to create temporary directory for lua source")?;
41        let path = dir.path().join("snippet.lua");
42        let mut contents = code.to_string();
43        if !contents.ends_with('\n') {
44            contents.push('\n');
45        }
46        std::fs::write(&path, contents).with_context(|| {
47            format!("failed to write temporary Lua source to {}", path.display())
48        })?;
49        Ok((dir, path))
50    }
51
52    fn execute_script(&self, script: &Path) -> Result<std::process::Output> {
53        let interpreter = self.ensure_interpreter()?;
54        let mut cmd = Command::new(interpreter);
55        cmd.arg(script)
56            .stdout(Stdio::piped())
57            .stderr(Stdio::piped());
58        cmd.stdin(Stdio::inherit());
59        if let Some(dir) = script.parent() {
60            cmd.current_dir(dir);
61        }
62        cmd.output().with_context(|| {
63            format!(
64                "failed to execute {} with script {}",
65                interpreter.display(),
66                script.display()
67            )
68        })
69    }
70}
71
72impl LanguageEngine for LuaEngine {
73    fn id(&self) -> &'static str {
74        "lua"
75    }
76
77    fn display_name(&self) -> &'static str {
78        "Lua"
79    }
80
81    fn aliases(&self) -> &[&'static str] {
82        &[]
83    }
84
85    fn supports_sessions(&self) -> bool {
86        self.interpreter.is_some()
87    }
88
89    fn validate(&self) -> Result<()> {
90        let interpreter = self.ensure_interpreter()?;
91        let mut cmd = Command::new(interpreter);
92        cmd.arg("-v").stdout(Stdio::null()).stderr(Stdio::null());
93        cmd.status()
94            .with_context(|| format!("failed to invoke {}", interpreter.display()))?
95            .success()
96            .then_some(())
97            .ok_or_else(|| anyhow::anyhow!("{} is not executable", interpreter.display()))
98    }
99
100    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
101        let start = Instant::now();
102        let (temp_dir, script_path) = match payload {
103            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
104                let (dir, path) = self.write_temp_script(code)?;
105                (Some(dir), path)
106            }
107            ExecutionPayload::File { path } => (None, path.clone()),
108        };
109
110        let output = self.execute_script(&script_path)?;
111
112        drop(temp_dir);
113
114        Ok(ExecutionOutcome {
115            language: self.id().to_string(),
116            exit_code: output.status.code(),
117            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
118            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
119            duration: start.elapsed(),
120        })
121    }
122
123    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
124        let interpreter = self.ensure_interpreter()?.to_path_buf();
125        let session = LuaSession::new(interpreter)?;
126        Ok(Box::new(session))
127    }
128}
129
130fn resolve_lua_binary() -> Option<PathBuf> {
131    which::which("lua").ok()
132}
133
/// File name, relative to a session's workspace, of the program that is
/// rewritten and re-run on every session evaluation.
const SESSION_MAIN_FILE: &str = "session.lua";
135
136struct LuaSession {
137    interpreter: PathBuf,
138    workspace: TempDir,
139    statements: Vec<String>,
140    last_stdout: String,
141    last_stderr: String,
142}
143
144impl LuaSession {
145    fn new(interpreter: PathBuf) -> Result<Self> {
146        let workspace = TempDir::new().context("failed to create Lua session workspace")?;
147        let session = Self {
148            interpreter,
149            workspace,
150            statements: Vec::new(),
151            last_stdout: String::new(),
152            last_stderr: String::new(),
153        };
154        session.persist_source()?;
155        Ok(session)
156    }
157
158    fn language_id(&self) -> &str {
159        "lua"
160    }
161
162    fn source_path(&self) -> PathBuf {
163        self.workspace.path().join(SESSION_MAIN_FILE)
164    }
165
166    fn persist_source(&self) -> Result<()> {
167        let path = self.source_path();
168        let mut source = String::new();
169        if self.statements.is_empty() {
170            source.push_str("-- session body\n");
171        } else {
172            for stmt in &self.statements {
173                source.push_str(stmt);
174                if !stmt.ends_with('\n') {
175                    source.push('\n');
176                }
177            }
178        }
179        fs::write(&path, source)
180            .with_context(|| format!("failed to write Lua session source at {}", path.display()))
181    }
182
183    fn run_program(&self) -> Result<std::process::Output> {
184        let mut cmd = Command::new(&self.interpreter);
185        cmd.arg(SESSION_MAIN_FILE)
186            .stdout(Stdio::piped())
187            .stderr(Stdio::piped())
188            .current_dir(self.workspace.path());
189        cmd.output().with_context(|| {
190            format!(
191                "failed to execute {} for Lua session",
192                self.interpreter.display()
193            )
194        })
195    }
196
197    fn normalize_output(bytes: &[u8]) -> String {
198        String::from_utf8_lossy(bytes)
199            .replace("\r\n", "\n")
200            .replace('\r', "")
201    }
202
203    fn diff_outputs(previous: &str, current: &str) -> String {
204        if let Some(suffix) = current.strip_prefix(previous) {
205            suffix.to_string()
206        } else {
207            current.to_string()
208        }
209    }
210}
211
212fn looks_like_expression_snippet(code: &str) -> bool {
213    if code.is_empty() || code.contains('\n') {
214        return false;
215    }
216
217    let trimmed = code.trim();
218    if trimmed.is_empty() {
219        return false;
220    }
221
222    let lower = trimmed.to_ascii_lowercase();
223    const CONTROL_KEYWORDS: &[&str] = &[
224        "local", "function", "for", "while", "repeat", "if", "do", "return", "break", "goto", "end",
225    ];
226
227    for kw in CONTROL_KEYWORDS {
228        if lower == *kw
229            || lower.starts_with(&format!("{} ", kw))
230            || lower.starts_with(&format!("{}(", kw))
231            || lower.starts_with(&format!("{}\t", kw))
232        {
233            return false;
234        }
235    }
236
237    if lower.starts_with("--") {
238        return false;
239    }
240
241    if has_assignment_operator(trimmed) {
242        return false;
243    }
244
245    true
246}
247
/// Reports whether `code` contains a Lua assignment `=`, i.e. an `=` that is
/// not part of one of the comparison operators `==`, `~=`, `<=` or `>=`.
///
/// `=` characters inside single- or double-quoted string literals (with
/// backslash escapes) are ignored, so expressions such as `"a=b"` or
/// `s == "="` are not mistaken for assignments. Long-bracket strings
/// (`[[...]]`) and comments are not understood, so `=` or quote characters
/// inside them can still influence the result.
fn has_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            quote @ (b'"' | b'\'') => {
                // Skip to the matching closing quote; a backslash escapes the
                // following byte. An unterminated string consumes the rest.
                i += 1;
                while i < bytes.len() && bytes[i] != quote {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            b'=' => {
                let prev = if i > 0 { bytes[i - 1] } else { b'\0' };
                let next = bytes.get(i + 1).copied().unwrap_or(b'\0');
                let part_of_comparison =
                    matches!(prev, b'=' | b'<' | b'>' | b'~') || next == b'=';
                if !part_of_comparison {
                    return true;
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}
266
267fn wrap_expression_snippet(code: &str) -> String {
268    let trimmed = code.trim();
269    format!(
270        "do\n    local __run_pack = table.pack(({expr}))\n    local __run_n = __run_pack.n or #__run_pack\n    if __run_n > 0 then\n        for __run_i = 1, __run_n do\n            if __run_i > 1 then io.write(\"\\t\") end\n            local __run_val = __run_pack[__run_i]\n            if __run_val == nil then\n                io.write(\"nil\")\n            else\n                io.write(tostring(__run_val))\n            end\n        end\n        io.write(\"\\n\")\n    end\nend\n",
271        expr = trimmed
272    )
273}
274impl LanguageSession for LuaSession {
275    fn language_id(&self) -> &str {
276        self.language_id()
277    }
278
279    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
280        let trimmed = code.trim();
281
282        if trimmed.eq_ignore_ascii_case(":reset") {
283            self.statements.clear();
284            self.last_stdout.clear();
285            self.last_stderr.clear();
286            self.persist_source()?;
287            return Ok(ExecutionOutcome {
288                language: self.language_id().to_string(),
289                exit_code: None,
290                stdout: String::new(),
291                stderr: String::new(),
292                duration: Duration::default(),
293            });
294        }
295
296        if trimmed.eq_ignore_ascii_case(":help") {
297            return Ok(ExecutionOutcome {
298                language: self.language_id().to_string(),
299                exit_code: None,
300                stdout:
301                    "Lua commands:\n  :reset - clear session state\n  :help  - show this message\n"
302                        .to_string(),
303                stderr: String::new(),
304                duration: Duration::default(),
305            });
306        }
307
308        if trimmed.is_empty() {
309            return Ok(ExecutionOutcome {
310                language: self.language_id().to_string(),
311                exit_code: None,
312                stdout: String::new(),
313                stderr: String::new(),
314                duration: Duration::default(),
315            });
316        }
317
318        let (effective_code, force_expression) = if let Some(stripped) = trimmed.strip_prefix('=') {
319            (stripped.trim(), true)
320        } else {
321            (trimmed, false)
322        };
323
324        let is_expression = force_expression || looks_like_expression_snippet(effective_code);
325        let statement = if is_expression {
326            wrap_expression_snippet(effective_code)
327        } else {
328            format!("{}\n", code.trim_end_matches(['\r', '\n']))
329        };
330
331        let previous_stdout = self.last_stdout.clone();
332        let previous_stderr = self.last_stderr.clone();
333
334        self.statements.push(statement);
335        self.persist_source()?;
336
337        let start = Instant::now();
338        let output = self.run_program()?;
339        let stdout_full = LuaSession::normalize_output(&output.stdout);
340        let stderr_full = LuaSession::normalize_output(&output.stderr);
341        let stdout = LuaSession::diff_outputs(&self.last_stdout, &stdout_full);
342        let stderr = LuaSession::diff_outputs(&self.last_stderr, &stderr_full);
343        let duration = start.elapsed();
344
345        if output.status.success() {
346            if is_expression {
347                self.statements.pop();
348                self.persist_source()?;
349                self.last_stdout = previous_stdout;
350                self.last_stderr = previous_stderr;
351            } else {
352                self.last_stdout = stdout_full;
353                self.last_stderr = stderr_full;
354            }
355            Ok(ExecutionOutcome {
356                language: self.language_id().to_string(),
357                exit_code: output.status.code(),
358                stdout,
359                stderr,
360                duration,
361            })
362        } else {
363            self.statements.pop();
364            self.persist_source()?;
365            self.last_stdout = previous_stdout;
366            self.last_stderr = previous_stderr;
367            Ok(ExecutionOutcome {
368                language: self.language_id().to_string(),
369                exit_code: output.status.code(),
370                stdout,
371                stderr,
372                duration,
373            })
374        }
375    }
376
377    fn shutdown(&mut self) -> Result<()> {
378        Ok(())
379    }
380}
381
382#[cfg(test)]
383mod tests {
384    use super::{LuaSession, looks_like_expression_snippet, wrap_expression_snippet};
385
386    #[test]
387    fn diff_outputs_appends_only_suffix() {
388        let previous = "a\nb\n";
389        let current = "a\nb\nc\n";
390        assert_eq!(LuaSession::diff_outputs(previous, current), "c\n");
391
392        let previous = "a\n";
393        let current = "x\na\n";
394        assert_eq!(LuaSession::diff_outputs(previous, current), "x\na\n");
395    }
396
397    #[test]
398    fn detects_simple_expression() {
399        assert!(looks_like_expression_snippet("a"));
400        assert!(looks_like_expression_snippet("foo(bar)"));
401        assert!(!looks_like_expression_snippet("local a = 1"));
402        assert!(!looks_like_expression_snippet("a = 1"));
403    }
404
405    #[test]
406    fn wraps_expression_with_print_block() {
407        let wrapped = wrap_expression_snippet("a");
408        assert!(wrapped.contains("table.pack((a))"));
409        assert!(wrapped.contains("io.write(\"\\n\")"));
410    }
411}