Skip to main content

run/engine/
python.rs

1use std::fs;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, execution_timeout, wait_with_timeout};
11
12pub struct PythonEngine {
13    executable: PathBuf,
14}
15
16impl PythonEngine {
17    pub fn new() -> Self {
18        let executable = resolve_python_binary();
19        Self { executable }
20    }
21
22    fn binary(&self) -> &Path {
23        &self.executable
24    }
25
26    fn run_command(&self) -> Command {
27        Command::new(self.binary())
28    }
29}
30
31impl LanguageEngine for PythonEngine {
32    fn id(&self) -> &'static str {
33        "python"
34    }
35
36    fn display_name(&self) -> &'static str {
37        "Python"
38    }
39
40    fn aliases(&self) -> &[&'static str] {
41        &["py", "python3", "py3"]
42    }
43
44    fn supports_sessions(&self) -> bool {
45        true
46    }
47
48    fn validate(&self) -> Result<()> {
49        let mut cmd = self.run_command();
50        cmd.arg("--version")
51            .stdout(Stdio::null())
52            .stderr(Stdio::null());
53        cmd.status()
54            .with_context(|| format!("failed to invoke {}", self.binary().display()))?
55            .success()
56            .then_some(())
57            .ok_or_else(|| anyhow::anyhow!("{} is not executable", self.binary().display()))
58    }
59
60    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
61        let start = Instant::now();
62        let timeout = execution_timeout();
63        let mut cmd = self.run_command();
64        let output = match payload {
65            ExecutionPayload::Inline { code } => {
66                cmd.arg("-c")
67                    .arg(code)
68                    .stdin(Stdio::inherit())
69                    .stdout(Stdio::piped())
70                    .stderr(Stdio::piped());
71                let child = cmd.spawn().with_context(|| {
72                    format!("failed to start {}", self.binary().display())
73                })?;
74                wait_with_timeout(child, timeout)?
75            }
76            ExecutionPayload::File { path } => {
77                cmd.arg(path)
78                    .stdin(Stdio::inherit())
79                    .stdout(Stdio::piped())
80                    .stderr(Stdio::piped());
81                let child = cmd.spawn().with_context(|| {
82                    format!("failed to start {}", self.binary().display())
83                })?;
84                wait_with_timeout(child, timeout)?
85            }
86            ExecutionPayload::Stdin { code } => {
87                cmd.arg("-")
88                    .stdin(Stdio::piped())
89                    .stdout(Stdio::piped())
90                    .stderr(Stdio::piped());
91                let mut child = cmd.spawn().with_context(|| {
92                    format!(
93                        "failed to start {} for stdin execution",
94                        self.binary().display()
95                    )
96                })?;
97                if let Some(mut stdin) = child.stdin.take() {
98                    stdin.write_all(code.as_bytes())?;
99                }
100                wait_with_timeout(child, timeout)?
101            }
102        };
103
104        Ok(ExecutionOutcome {
105            language: self.id().to_string(),
106            exit_code: output.status.code(),
107            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
108            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
109            duration: start.elapsed(),
110        })
111    }
112
113    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
114        Ok(Box::new(PythonSession::new(self.executable.clone())?))
115    }
116}
117
118struct PythonSession {
119    executable: PathBuf,
120    dir: TempDir,
121    source_path: PathBuf,
122    statements: Vec<String>,
123    previous_stdout: String,
124    previous_stderr: String,
125}
126
127impl PythonSession {
128    fn new(executable: PathBuf) -> Result<Self> {
129        let dir = Builder::new()
130            .prefix("run-python-repl")
131            .tempdir()
132            .context("failed to create temporary directory for python repl")?;
133        let source_path = dir.path().join("session.py");
134        fs::write(&source_path, "# Python REPL session\n")
135            .with_context(|| format!("failed to initialize {}", source_path.display()))?;
136
137        Ok(Self {
138            executable,
139            dir,
140            source_path,
141            statements: Vec::new(),
142            previous_stdout: String::new(),
143            previous_stderr: String::new(),
144        })
145    }
146
147    fn render_source(&self) -> String {
148        let mut source = String::from("import sys\nfrom math import *\n\n");
149        for snippet in &self.statements {
150            source.push_str(snippet);
151            if !snippet.ends_with('\n') {
152                source.push('\n');
153            }
154        }
155        source
156    }
157
158    fn write_source(&self, contents: &str) -> Result<()> {
159        fs::write(&self.source_path, contents).with_context(|| {
160            format!(
161                "failed to write generated Python REPL source to {}",
162                self.source_path.display()
163            )
164        })
165    }
166
167    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
168        let source = self.render_source();
169        self.write_source(&source)?;
170
171        let output = self.run_script()?;
172        let stdout_full = normalize_output(&output.stdout);
173        let stderr_full = normalize_output(&output.stderr);
174
175        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
176        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
177
178        let success = output.status.success();
179        if success {
180            self.previous_stdout = stdout_full;
181            self.previous_stderr = stderr_full;
182        }
183
184        let outcome = ExecutionOutcome {
185            language: "python".to_string(),
186            exit_code: output.status.code(),
187            stdout: stdout_delta,
188            stderr: stderr_delta,
189            duration: start.elapsed(),
190        };
191
192        Ok((outcome, success))
193    }
194
195    fn run_script(&self) -> Result<std::process::Output> {
196        let mut cmd = Command::new(&self.executable);
197        cmd.arg(&self.source_path)
198            .stdout(Stdio::piped())
199            .stderr(Stdio::piped())
200            .current_dir(self.dir.path());
201        cmd.output().with_context(|| {
202            format!(
203                "failed to run python session script {} with {}",
204                self.source_path.display(),
205                self.executable.display()
206            )
207        })
208    }
209
210    fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
211        self.statements.push(snippet);
212        let start = Instant::now();
213        let (outcome, success) = self.run_current(start)?;
214        if !success {
215            let _ = self.statements.pop();
216            let source = self.render_source();
217            self.write_source(&source)?;
218        }
219        Ok(outcome)
220    }
221
222    fn reset_state(&mut self) -> Result<()> {
223        self.statements.clear();
224        self.previous_stdout.clear();
225        self.previous_stderr.clear();
226        let source = self.render_source();
227        self.write_source(&source)
228    }
229}
230
231impl LanguageSession for PythonSession {
232    fn language_id(&self) -> &str {
233        "python"
234    }
235
236    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
237        let trimmed = code.trim();
238        if trimmed.is_empty() {
239            return Ok(ExecutionOutcome {
240                language: self.language_id().to_string(),
241                exit_code: None,
242                stdout: String::new(),
243                stderr: String::new(),
244                duration: Duration::default(),
245            });
246        }
247
248        if trimmed.eq_ignore_ascii_case(":reset") {
249            self.reset_state()?;
250            return Ok(ExecutionOutcome {
251                language: self.language_id().to_string(),
252                exit_code: None,
253                stdout: String::new(),
254                stderr: String::new(),
255                duration: Duration::default(),
256            });
257        }
258
259        if trimmed.eq_ignore_ascii_case(":help") {
260            return Ok(ExecutionOutcome {
261                language: self.language_id().to_string(),
262                exit_code: None,
263                stdout:
264                    "Python commands:\n  :reset - clear session state\n  :help  - show this message\n"
265                        .to_string(),
266                stderr: String::new(),
267                duration: Duration::default(),
268            });
269        }
270
271        if should_treat_as_expression(trimmed) {
272            let snippet = wrap_expression(trimmed, self.statements.len());
273            let outcome = self.run_snippet(snippet)?;
274            if outcome.exit_code.unwrap_or(0) == 0 {
275                return Ok(outcome);
276            }
277        }
278
279        let snippet = ensure_trailing_newline(code);
280        self.run_snippet(snippet)
281    }
282
283    fn shutdown(&mut self) -> Result<()> {
284        Ok(())
285    }
286}
287
288fn resolve_python_binary() -> PathBuf {
289    let candidates = ["python3", "python", "py"]; // windows py launcher
290    for name in candidates {
291        if let Ok(path) = which::which(name) {
292            return path;
293        }
294    }
295    PathBuf::from("python3")
296}
297
/// Returns an owned copy of `code` that is guaranteed to end with `'\n'`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
305
/// Wraps a single-line expression so the session script evaluates it and
/// echoes its value the way the interactive interpreter does.
///
/// The value is stored in a per-snippet variable `__run_value_{index}`.
/// Mirroring `sys.displayhook`, only a non-`None` result is bound to `_` and
/// printed via `repr`. The expression sits on its own line inside the
/// parentheses, so a trailing `# comment` cannot swallow the closing `)`.
fn wrap_expression(code: &str, index: usize) -> String {
    let var = format!("__run_value_{index}");
    format!(
        "{var} = (\n{code}\n)\n\
         if {var} is not None:\n    _ = {var}\n    print(repr({var}), flush=True)\n"
    )
}
310
/// Returns the part of `current` that follows `previous`.
///
/// If `current` does not start with `previous` (e.g. the output changed
/// between runs), all of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    let fresh = current.strip_prefix(previous).unwrap_or(current);
    fresh.to_string()
}
318
/// Decodes process output lossily as UTF-8 and strips every carriage return.
///
/// This turns CRLF line endings into LF and drops lone `'\r'` characters.
fn normalize_output(bytes: &[u8]) -> String {
    let text = String::from_utf8_lossy(bytes);
    text.chars().filter(|&ch| ch != '\r').collect()
}
324
/// Heuristically decides whether a single REPL line should be evaluated as an
/// expression (with its value echoed) rather than executed as a statement.
///
/// Misclassifying a statement as an expression is cheap: the wrapped snippet
/// fails and the caller falls back to running it as a statement. The reverse
/// mistake silently suppresses the echo, so this only answers `false` for
/// lines that clearly start with a statement keyword, are comments or
/// `print` calls, or contain a top-level assignment.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(':') {
        return false;
    }
    if trimmed.starts_with('#') {
        return false;
    }

    // Matched as whole, case-sensitive words (as Python keywords are), so
    // identifiers such as `password`, `try_count` or `exceptions` still echo.
    const STATEMENT_KEYWORDS: &[&str] = &[
        "import", "from", "def", "class", "if", "for", "while", "try", "except", "finally",
        "with", "return", "raise", "yield", "async", "await", "assert", "del", "global",
        "nonlocal", "pass",
    ];
    if STATEMENT_KEYWORDS
        .iter()
        .any(|keyword| starts_with_word(trimmed, keyword))
    {
        return false;
    }

    // `print(...)` evaluates to `None`; run it as a plain statement.
    if trimmed.starts_with("print(") || trimmed.starts_with("print ") {
        return false;
    }

    !has_top_level_assignment(trimmed)
}

/// True when `text` begins with `word` followed by a non-identifier character
/// or nothing, i.e. `word` is a complete token.
fn starts_with_word(text: &str, word: &str) -> bool {
    match text.strip_prefix(word) {
        Some(rest) => rest
            .chars()
            .next()
            .map_or(true, |c| !(c.is_alphanumeric() || c == '_')),
        None => false,
    }
}

/// True when `code` contains an assignment operator outside brackets and
/// string literals.
///
/// Counted as assignments: `=`, augmented forms such as `+=` and `<<=`, and
/// annotated `x: T = v`. Not counted: the comparisons `==`, `!=`, `<=`, `>=`,
/// the walrus `:=`, and keyword arguments like `f(key=value)`. Scanning stops
/// at a `#` comment.
fn has_top_level_assignment(code: &str) -> bool {
    let bytes = code.as_bytes();
    let mut depth = 0usize;
    let mut quote: Option<u8> = None;
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if let Some(q) = quote {
            if b == b'\\' {
                // Skip the escaped byte so `\"` does not end the string.
                i += 2;
                continue;
            }
            if b == q {
                quote = None;
            }
            i += 1;
            continue;
        }
        match b {
            b'\'' | b'"' => quote = Some(b),
            b'#' => break,
            b'(' | b'[' | b'{' => depth += 1,
            b')' | b']' | b'}' => depth = depth.saturating_sub(1),
            b'=' if depth == 0 => {
                if bytes.get(i + 1) == Some(&b'=') {
                    // `==` comparison: skip both characters.
                    i += 2;
                    continue;
                }
                let prev = if i > 0 { bytes[i - 1] } else { 0 };
                let before_prev = if i > 1 { bytes[i - 2] } else { 0 };
                match prev {
                    // `<<=` / `>>=` are augmented assignments...
                    b'<' | b'>' if before_prev == prev => return true,
                    // ...while `<=`, `>=`, `!=` compare and `:=` is the walrus.
                    b'<' | b'>' | b'!' | b':' => {}
                    _ => return true,
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}