Skip to main content

run/engine/
lua.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{
10    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, run_version_command,
11};
12
/// Language engine that runs Lua code through an external `lua` interpreter.
///
/// Derives `Debug` so the engine can be logged or embedded in other
/// `Debug`-deriving types.
#[derive(Debug)]
pub struct LuaEngine {
    /// Path to the interpreter executable, or `None` when none was found at
    /// construction time.
    interpreter: Option<PathBuf>,
}
16
17impl Default for LuaEngine {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl LuaEngine {
24    pub fn new() -> Self {
25        Self {
26            interpreter: resolve_lua_binary(),
27        }
28    }
29
30    fn ensure_interpreter(&self) -> Result<&Path> {
31        self.interpreter.as_deref().ok_or_else(|| {
32            anyhow::anyhow!(
33                "Lua support requires the `lua` executable. Install it from https://www.lua.org/download.html and ensure it is on your PATH." 
34            )
35        })
36    }
37
38    fn write_temp_script(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
39        let dir = Builder::new()
40            .prefix("run-lua")
41            .tempdir()
42            .context("failed to create temporary directory for lua source")?;
43        let path = dir.path().join("snippet.lua");
44        let mut contents = code.to_string();
45        if !contents.ends_with('\n') {
46            contents.push('\n');
47        }
48        std::fs::write(&path, contents).with_context(|| {
49            format!("failed to write temporary Lua source to {}", path.display())
50        })?;
51        Ok((dir, path))
52    }
53
54    fn execute_script(&self, script: &Path, args: &[String]) -> Result<std::process::Output> {
55        let interpreter = self.ensure_interpreter()?;
56        let mut cmd = Command::new(interpreter);
57        cmd.arg(script)
58            .args(args)
59            .stdout(Stdio::piped())
60            .stderr(Stdio::piped());
61        cmd.stdin(Stdio::inherit());
62        if let Some(dir) = script.parent() {
63            cmd.current_dir(dir);
64        }
65        cmd.output().with_context(|| {
66            format!(
67                "failed to execute {} with script {}",
68                interpreter.display(),
69                script.display()
70            )
71        })
72    }
73}
74
75impl LanguageEngine for LuaEngine {
76    fn id(&self) -> &'static str {
77        "lua"
78    }
79
80    fn display_name(&self) -> &'static str {
81        "Lua"
82    }
83
84    fn aliases(&self) -> &[&'static str] {
85        &[]
86    }
87
88    fn supports_sessions(&self) -> bool {
89        self.interpreter.is_some()
90    }
91
92    fn validate(&self) -> Result<()> {
93        let interpreter = self.ensure_interpreter()?;
94        let mut cmd = Command::new(interpreter);
95        cmd.arg("-v").stdout(Stdio::null()).stderr(Stdio::null());
96        cmd.status()
97            .with_context(|| format!("failed to invoke {}", interpreter.display()))?
98            .success()
99            .then_some(())
100            .ok_or_else(|| anyhow::anyhow!("{} is not executable", interpreter.display()))
101    }
102
103    fn toolchain_version(&self) -> Result<Option<String>> {
104        let interpreter = self.ensure_interpreter()?;
105        let mut cmd = Command::new(interpreter);
106        cmd.arg("-v");
107        let context = format!("{}", interpreter.display());
108        run_version_command(cmd, &context)
109    }
110
111    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
112        let start = Instant::now();
113        let (temp_dir, script_path) = match payload {
114            ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
115                let (dir, path) = self.write_temp_script(code)?;
116                (Some(dir), path)
117            }
118            ExecutionPayload::File { path, .. } => (None, path.clone()),
119        };
120
121        let output = self.execute_script(&script_path, payload.args())?;
122
123        drop(temp_dir);
124
125        Ok(ExecutionOutcome {
126            language: self.id().to_string(),
127            exit_code: output.status.code(),
128            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
129            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
130            duration: start.elapsed(),
131        })
132    }
133
134    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
135        let interpreter = self.ensure_interpreter()?.to_path_buf();
136        let session = LuaSession::new(interpreter)?;
137        Ok(Box::new(session))
138    }
139}
140
141fn resolve_lua_binary() -> Option<PathBuf> {
142    which::which("lua").ok()
143}
144
/// File name, inside the session workspace, of the script that holds the
/// accumulated session statements.
const SESSION_MAIN_FILE: &str = "session.lua";
146
147struct LuaSession {
148    interpreter: PathBuf,
149    workspace: TempDir,
150    statements: Vec<String>,
151    last_stdout: String,
152    last_stderr: String,
153}
154
155impl LuaSession {
156    fn new(interpreter: PathBuf) -> Result<Self> {
157        let workspace = TempDir::new().context("failed to create Lua session workspace")?;
158        let session = Self {
159            interpreter,
160            workspace,
161            statements: Vec::new(),
162            last_stdout: String::new(),
163            last_stderr: String::new(),
164        };
165        session.persist_source()?;
166        Ok(session)
167    }
168
169    fn language_id(&self) -> &str {
170        "lua"
171    }
172
173    fn source_path(&self) -> PathBuf {
174        self.workspace.path().join(SESSION_MAIN_FILE)
175    }
176
177    fn persist_source(&self) -> Result<()> {
178        let path = self.source_path();
179        let mut source = String::new();
180        if self.statements.is_empty() {
181            source.push_str("-- session body\n");
182        } else {
183            for stmt in &self.statements {
184                source.push_str(stmt);
185                if !stmt.ends_with('\n') {
186                    source.push('\n');
187                }
188            }
189        }
190        fs::write(&path, source)
191            .with_context(|| format!("failed to write Lua session source at {}", path.display()))
192    }
193
194    fn run_program(&self) -> Result<std::process::Output> {
195        let mut cmd = Command::new(&self.interpreter);
196        cmd.arg(SESSION_MAIN_FILE)
197            .stdout(Stdio::piped())
198            .stderr(Stdio::piped())
199            .current_dir(self.workspace.path());
200        cmd.output().with_context(|| {
201            format!(
202                "failed to execute {} for Lua session",
203                self.interpreter.display()
204            )
205        })
206    }
207
208    fn normalize_output(bytes: &[u8]) -> String {
209        String::from_utf8_lossy(bytes)
210            .replace("\r\n", "\n")
211            .replace('\r', "")
212    }
213
214    fn diff_outputs(previous: &str, current: &str) -> String {
215        if let Some(suffix) = current.strip_prefix(previous) {
216            suffix.to_string()
217        } else {
218            current.to_string()
219        }
220    }
221}
222
223fn looks_like_expression_snippet(code: &str) -> bool {
224    if code.is_empty() || code.contains('\n') {
225        return false;
226    }
227
228    let trimmed = code.trim();
229    if trimmed.is_empty() {
230        return false;
231    }
232
233    let lower = trimmed.to_ascii_lowercase();
234    const CONTROL_KEYWORDS: &[&str] = &[
235        "local", "function", "for", "while", "repeat", "if", "do", "return", "break", "goto", "end",
236    ];
237
238    for kw in CONTROL_KEYWORDS {
239        if lower == *kw
240            || lower.starts_with(&format!("{} ", kw))
241            || lower.starts_with(&format!("{}(", kw))
242            || lower.starts_with(&format!("{}\t", kw))
243        {
244            return false;
245        }
246    }
247
248    if lower.starts_with("--") {
249        return false;
250    }
251
252    if has_assignment_operator(trimmed) {
253        return false;
254    }
255
256    true
257}
258
/// Reports whether `code` contains an `=` that is not part of a comparison
/// operator (`==`, `<=`, `>=`, `~=`).
///
/// The scan is purely byte-based: it does not skip string literals or
/// comments, so an `=` inside a quoted string also counts.
fn has_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    bytes
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'=')
        .any(|(idx, _)| {
            let before = idx.checked_sub(1).map(|prev| bytes[prev]);
            let after = bytes.get(idx + 1).copied();
            let comparison = matches!(before, Some(b'=' | b'<' | b'>' | b'~'))
                || after == Some(b'=');
            !comparison
        })
}
277
/// Wraps a (trimmed) expression in a Lua `do ... end` block that evaluates it
/// and writes the packed result to stdout, tab-separated and newline-terminated,
/// printing `nil` for nil values.
fn wrap_expression_snippet(code: &str) -> String {
    format!(
        concat!(
            "do\n",
            "    local __run_pack = table.pack(({expr}))\n",
            "    local __run_n = __run_pack.n or #__run_pack\n",
            "    if __run_n > 0 then\n",
            "        for __run_i = 1, __run_n do\n",
            "            if __run_i > 1 then io.write(\"\\t\") end\n",
            "            local __run_val = __run_pack[__run_i]\n",
            "            if __run_val == nil then\n",
            "                io.write(\"nil\")\n",
            "            else\n",
            "                io.write(tostring(__run_val))\n",
            "            end\n",
            "        end\n",
            "        io.write(\"\\n\")\n",
            "    end\n",
            "end\n",
        ),
        expr = code.trim()
    )
}
285impl LanguageSession for LuaSession {
286    fn language_id(&self) -> &str {
287        self.language_id()
288    }
289
290    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
291        let trimmed = code.trim();
292
293        if trimmed.eq_ignore_ascii_case(":reset") {
294            self.statements.clear();
295            self.last_stdout.clear();
296            self.last_stderr.clear();
297            self.persist_source()?;
298            return Ok(ExecutionOutcome {
299                language: self.language_id().to_string(),
300                exit_code: None,
301                stdout: String::new(),
302                stderr: String::new(),
303                duration: Duration::default(),
304            });
305        }
306
307        if trimmed.eq_ignore_ascii_case(":help") {
308            return Ok(ExecutionOutcome {
309                language: self.language_id().to_string(),
310                exit_code: None,
311                stdout:
312                    "Lua commands:\n  :reset - clear session state\n  :help  - show this message\n"
313                        .to_string(),
314                stderr: String::new(),
315                duration: Duration::default(),
316            });
317        }
318
319        if trimmed.is_empty() {
320            return Ok(ExecutionOutcome {
321                language: self.language_id().to_string(),
322                exit_code: None,
323                stdout: String::new(),
324                stderr: String::new(),
325                duration: Duration::default(),
326            });
327        }
328
329        let (effective_code, force_expression) = if let Some(stripped) = trimmed.strip_prefix('=') {
330            (stripped.trim(), true)
331        } else {
332            (trimmed, false)
333        };
334
335        let is_expression = force_expression || looks_like_expression_snippet(effective_code);
336        let statement = if is_expression {
337            wrap_expression_snippet(effective_code)
338        } else {
339            format!("{}\n", code.trim_end_matches(['\r', '\n']))
340        };
341
342        let previous_stdout = self.last_stdout.clone();
343        let previous_stderr = self.last_stderr.clone();
344
345        self.statements.push(statement);
346        self.persist_source()?;
347
348        let start = Instant::now();
349        let output = self.run_program()?;
350        let stdout_full = LuaSession::normalize_output(&output.stdout);
351        let stderr_full = LuaSession::normalize_output(&output.stderr);
352        let stdout = LuaSession::diff_outputs(&self.last_stdout, &stdout_full);
353        let stderr = LuaSession::diff_outputs(&self.last_stderr, &stderr_full);
354        let duration = start.elapsed();
355
356        if output.status.success() {
357            if is_expression {
358                self.statements.pop();
359                self.persist_source()?;
360                self.last_stdout = previous_stdout;
361                self.last_stderr = previous_stderr;
362            } else {
363                self.last_stdout = stdout_full;
364                self.last_stderr = stderr_full;
365            }
366            Ok(ExecutionOutcome {
367                language: self.language_id().to_string(),
368                exit_code: output.status.code(),
369                stdout,
370                stderr,
371                duration,
372            })
373        } else {
374            self.statements.pop();
375            self.persist_source()?;
376            self.last_stdout = previous_stdout;
377            self.last_stderr = previous_stderr;
378            Ok(ExecutionOutcome {
379                language: self.language_id().to_string(),
380                exit_code: output.status.code(),
381                stdout,
382                stderr,
383                duration,
384            })
385        }
386    }
387
388    fn shutdown(&mut self) -> Result<()> {
389        Ok(())
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::{LuaSession, looks_like_expression_snippet, wrap_expression_snippet};

    /// Only the new suffix is reported when the previous output is a prefix;
    /// otherwise the full current output is returned.
    #[test]
    fn diff_outputs_appends_only_suffix() {
        // (previous, current, expected diff)
        let cases = [
            ("a\nb\n", "a\nb\nc\n", "c\n"),
            ("a\n", "x\na\n", "x\na\n"),
        ];
        for (previous, current, expected) in cases {
            assert_eq!(LuaSession::diff_outputs(previous, current), expected);
        }
    }

    /// Bare names and calls are expressions; declarations and assignments are not.
    #[test]
    fn detects_simple_expression() {
        for expression in ["a", "foo(bar)"] {
            assert!(looks_like_expression_snippet(expression));
        }
        for statement in ["local a = 1", "a = 1"] {
            assert!(!looks_like_expression_snippet(statement));
        }
    }

    /// The wrapper packs the expression and terminates its output with a newline.
    #[test]
    fn wraps_expression_with_print_block() {
        let wrapped = wrap_expression_snippet("a");
        for fragment in ["table.pack((a))", "io.write(\"\\n\")"] {
            assert!(wrapped.contains(fragment));
        }
    }
}