Skip to main content

run/engine/
python.rs

1use std::fs;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{
11    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, execution_timeout,
12    run_version_command, wait_with_timeout,
13};
14
/// Language engine that runs Python code by invoking an external interpreter process.
pub struct PythonEngine {
    /// Path (or bare command name) of the interpreter every command is launched with.
    executable: PathBuf,
}
18
19impl Default for PythonEngine {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl PythonEngine {
26    pub fn new() -> Self {
27        let executable = resolve_python_binary();
28        Self { executable }
29    }
30
31    fn binary(&self) -> &Path {
32        &self.executable
33    }
34
35    fn run_command(&self) -> Command {
36        Command::new(self.binary())
37    }
38}
39
40impl LanguageEngine for PythonEngine {
41    fn id(&self) -> &'static str {
42        "python"
43    }
44
45    fn display_name(&self) -> &'static str {
46        "Python"
47    }
48
49    fn aliases(&self) -> &[&'static str] {
50        &["py", "python3", "py3"]
51    }
52
53    fn supports_sessions(&self) -> bool {
54        true
55    }
56
57    fn validate(&self) -> Result<()> {
58        let mut cmd = self.run_command();
59        cmd.arg("--version")
60            .stdout(Stdio::null())
61            .stderr(Stdio::null());
62        cmd.status()
63            .with_context(|| format!("failed to invoke {}", self.binary().display()))?
64            .success()
65            .then_some(())
66            .ok_or_else(|| anyhow::anyhow!("{} is not executable", self.binary().display()))
67    }
68
69    fn toolchain_version(&self) -> Result<Option<String>> {
70        let mut cmd = self.run_command();
71        cmd.arg("--version");
72        let context = format!("{}", self.binary().display());
73        run_version_command(cmd, &context)
74    }
75
76    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
77        let start = Instant::now();
78        let timeout = execution_timeout();
79        let mut cmd = self.run_command();
80        let args = payload.args();
81        let output = match payload {
82            ExecutionPayload::Inline { code, .. } => {
83                cmd.arg("-c")
84                    .arg(code)
85                    .args(args)
86                    .stdin(Stdio::inherit())
87                    .stdout(Stdio::piped())
88                    .stderr(Stdio::piped());
89                let child = cmd
90                    .spawn()
91                    .with_context(|| format!("failed to start {}", self.binary().display()))?;
92                wait_with_timeout(child, timeout)?
93            }
94            ExecutionPayload::File { path, .. } => {
95                cmd.arg(path)
96                    .args(args)
97                    .stdin(Stdio::inherit())
98                    .stdout(Stdio::piped())
99                    .stderr(Stdio::piped());
100                let child = cmd
101                    .spawn()
102                    .with_context(|| format!("failed to start {}", self.binary().display()))?;
103                wait_with_timeout(child, timeout)?
104            }
105            ExecutionPayload::Stdin { code, .. } => {
106                cmd.arg("-")
107                    .args(args)
108                    .stdin(Stdio::piped())
109                    .stdout(Stdio::piped())
110                    .stderr(Stdio::piped());
111                let mut child = cmd.spawn().with_context(|| {
112                    format!(
113                        "failed to start {} for stdin execution",
114                        self.binary().display()
115                    )
116                })?;
117                if let Some(mut stdin) = child.stdin.take() {
118                    stdin.write_all(code.as_bytes())?;
119                }
120                wait_with_timeout(child, timeout)?
121            }
122        };
123
124        Ok(ExecutionOutcome {
125            language: self.id().to_string(),
126            exit_code: output.status.code(),
127            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
128            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
129            duration: start.elapsed(),
130        })
131    }
132
133    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
134        Ok(Box::new(PythonSession::new(self.executable.clone())?))
135    }
136}
137
138struct PythonSession {
139    executable: PathBuf,
140    dir: TempDir,
141    source_path: PathBuf,
142    statements: Vec<String>,
143    previous_stdout: String,
144    previous_stderr: String,
145}
146
147impl PythonSession {
148    fn new(executable: PathBuf) -> Result<Self> {
149        let dir = Builder::new()
150            .prefix("run-python-repl")
151            .tempdir()
152            .context("failed to create temporary directory for python repl")?;
153        let source_path = dir.path().join("session.py");
154        fs::write(&source_path, "# Python REPL session\n")
155            .with_context(|| format!("failed to initialize {}", source_path.display()))?;
156
157        Ok(Self {
158            executable,
159            dir,
160            source_path,
161            statements: Vec::new(),
162            previous_stdout: String::new(),
163            previous_stderr: String::new(),
164        })
165    }
166
167    fn render_source(&self) -> String {
168        let mut source = String::from("import sys\nfrom math import *\n\n");
169        for snippet in &self.statements {
170            source.push_str(snippet);
171            if !snippet.ends_with('\n') {
172                source.push('\n');
173            }
174        }
175        source
176    }
177
178    fn write_source(&self, contents: &str) -> Result<()> {
179        fs::write(&self.source_path, contents).with_context(|| {
180            format!(
181                "failed to write generated Python REPL source to {}",
182                self.source_path.display()
183            )
184        })
185    }
186
187    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
188        let source = self.render_source();
189        self.write_source(&source)?;
190
191        let output = self.run_script()?;
192        let stdout_full = normalize_output(&output.stdout);
193        let stderr_full = normalize_output(&output.stderr);
194
195        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
196        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
197
198        let success = output.status.success();
199        if success {
200            self.previous_stdout = stdout_full;
201            self.previous_stderr = stderr_full;
202        }
203
204        let outcome = ExecutionOutcome {
205            language: "python".to_string(),
206            exit_code: output.status.code(),
207            stdout: stdout_delta,
208            stderr: stderr_delta,
209            duration: start.elapsed(),
210        };
211
212        Ok((outcome, success))
213    }
214
215    fn run_script(&self) -> Result<std::process::Output> {
216        let mut cmd = Command::new(&self.executable);
217        cmd.arg(&self.source_path)
218            .stdout(Stdio::piped())
219            .stderr(Stdio::piped())
220            .current_dir(self.dir.path());
221        cmd.output().with_context(|| {
222            format!(
223                "failed to run python session script {} with {}",
224                self.source_path.display(),
225                self.executable.display()
226            )
227        })
228    }
229
230    fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
231        self.statements.push(snippet);
232        let start = Instant::now();
233        let (outcome, success) = self.run_current(start)?;
234        if !success {
235            let _ = self.statements.pop();
236            let source = self.render_source();
237            self.write_source(&source)?;
238        }
239        Ok(outcome)
240    }
241
242    fn reset_state(&mut self) -> Result<()> {
243        self.statements.clear();
244        self.previous_stdout.clear();
245        self.previous_stderr.clear();
246        let source = self.render_source();
247        self.write_source(&source)
248    }
249}
250
251impl LanguageSession for PythonSession {
252    fn language_id(&self) -> &str {
253        "python"
254    }
255
256    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
257        let trimmed = code.trim();
258        if trimmed.is_empty() {
259            return Ok(ExecutionOutcome {
260                language: self.language_id().to_string(),
261                exit_code: None,
262                stdout: String::new(),
263                stderr: String::new(),
264                duration: Duration::default(),
265            });
266        }
267
268        if trimmed.eq_ignore_ascii_case(":reset") {
269            self.reset_state()?;
270            return Ok(ExecutionOutcome {
271                language: self.language_id().to_string(),
272                exit_code: None,
273                stdout: String::new(),
274                stderr: String::new(),
275                duration: Duration::default(),
276            });
277        }
278
279        if trimmed.eq_ignore_ascii_case(":help") {
280            return Ok(ExecutionOutcome {
281                language: self.language_id().to_string(),
282                exit_code: None,
283                stdout:
284                    "Python commands:\n  :reset - clear session state\n  :help  - show this message\n"
285                        .to_string(),
286                stderr: String::new(),
287                duration: Duration::default(),
288            });
289        }
290
291        if should_treat_as_expression(trimmed) {
292            let snippet = wrap_expression(trimmed, self.statements.len());
293            let outcome = self.run_snippet(snippet)?;
294            if outcome.exit_code.unwrap_or(0) == 0 {
295                return Ok(outcome);
296            }
297        }
298
299        let snippet = ensure_trailing_newline(code);
300        self.run_snippet(snippet)
301    }
302
303    fn shutdown(&mut self) -> Result<()> {
304        Ok(())
305    }
306}
307
308pub(super) fn resolve_python_binary() -> PathBuf {
309    let candidates = ["python3", "python", "py"]; // windows py launcher
310    for name in candidates {
311        if let Ok(path) = which::which(name) {
312            return path;
313        }
314    }
315    PathBuf::from("python3")
316}
317
/// Returns an owned copy of `code` that is guaranteed to end with `'\n'`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
325
/// Builds Python source that evaluates the single-line expression `code`, binds its value
/// to a unique `__run_value_{index}` and to `_` (last-result access), and prints its `repr`.
///
/// The expression sits on its own line inside the parentheses. A trailing `# comment` in
/// `code` therefore ends at that newline instead of swallowing the closing `)`.
/// Python's implicit line joining inside brackets keeps the assignment valid.
fn wrap_expression(code: &str, index: usize) -> String {
    format!(
        "__run_value_{index} = (\n{code}\n)\n_ = __run_value_{index}\nprint(repr(__run_value_{index}), flush=True)\n"
    )
}
332
/// Returns the part of `current` that follows `previous`. If `current` does not start with
/// `previous` (the earlier output changed), the whole of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
340
/// Decodes process output as UTF-8 (lossily) and strips carriage returns.
fn normalize_output(bytes: &[u8]) -> String {
    // Folding `\r\n` into `\n` and then deleting any remaining lone `\r` is the same as
    // deleting every carriage return, so a single filtering pass suffices.
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}
346
/// Heuristically decides whether one line of REPL input is an expression whose value
/// should be echoed, rather than a statement.
///
/// Returns `false` for:
/// - empty or multi-line input, and lines ending in `:`;
/// - comments;
/// - lines whose leading word is a statement keyword;
/// - `print(`/`print ` calls;
/// - lines containing a top-level assignment.
///
/// The session's `eval` retries a failed expression run as a statement. A wrong `true`
/// therefore only costs an extra run, while a wrong `false` silently loses the echo.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(':') {
        return false;
    }
    if trimmed.starts_with('#') {
        return false;
    }

    // Python keywords are case-sensitive and must be whole words: `password`,
    // `try_count` or `Pass()` are ordinary expressions, not `pass`/`try` statements.
    const STATEMENT_KEYWORDS: [&str; 21] = [
        "import", "from", "def", "class", "if", "for", "while", "try", "except", "finally",
        "with", "return", "raise", "yield", "async", "await", "assert", "del", "global",
        "nonlocal", "pass",
    ];
    let leading_word = trimmed
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .next()
        .unwrap_or("");
    if STATEMENT_KEYWORDS.contains(&leading_word) {
        return false;
    }

    // `print(...)` produces its own output; echoing its `None` result would be noise.
    if leading_word == "print" {
        let rest = &trimmed[leading_word.len()..];
        if rest.starts_with('(') || rest.starts_with(' ') {
            return false;
        }
    }

    !has_top_level_assignment(trimmed)
}

/// Returns `true` if `line` contains an assignment operator (`=`, `+=`, `<<=`, `x: T = v`, ...)
/// outside brackets, string literals and trailing comments. Keyword arguments such as
/// `f(a=1)`, comparisons (`==`, `!=`, `<=`, `>=`) and the walrus operator `:=` do not count.
fn has_top_level_assignment(line: &str) -> bool {
    let chars: Vec<char> = line.chars().collect();
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        if let Some(open) = quote {
            if c == '\\' {
                // Skip the escaped character so `\"` does not end the string.
                i += 2;
                continue;
            }
            if c == open {
                quote = None;
            }
            i += 1;
            continue;
        }
        match c {
            '\'' | '"' => quote = Some(c),
            '#' => break,
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            '=' if depth == 0 => {
                if chars.get(i + 1) == Some(&'=') {
                    // `==` comparison: consume both characters.
                    i += 2;
                    continue;
                }
                let prev = if i > 0 { chars[i - 1] } else { ' ' };
                match prev {
                    // `!=` comparison, `:=` walrus expression.
                    '!' | ':' => {}
                    // `<=`/`>=` compare; `<<=`/`>>=` are augmented assignments.
                    '<' | '>' => {
                        if i >= 2 && chars[i - 2] == prev {
                            return true;
                        }
                    }
                    _ => return true,
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}