Skip to main content

run/engine/
python.rs

1use std::fs;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{
11    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, execution_timeout,
12    wait_with_timeout,
13};
14
/// Language engine that runs code through a system Python interpreter.
///
/// The interpreter path is resolved once when the engine is constructed and is
/// reused for every validation, execution, and session it starts.
#[derive(Debug, Clone)]
pub struct PythonEngine {
    /// Path (or bare command name, e.g. `python3`) of the interpreter to invoke.
    executable: PathBuf,
}
18
19impl Default for PythonEngine {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl PythonEngine {
26    pub fn new() -> Self {
27        let executable = resolve_python_binary();
28        Self { executable }
29    }
30
31    fn binary(&self) -> &Path {
32        &self.executable
33    }
34
35    fn run_command(&self) -> Command {
36        Command::new(self.binary())
37    }
38}
39
40impl LanguageEngine for PythonEngine {
41    fn id(&self) -> &'static str {
42        "python"
43    }
44
45    fn display_name(&self) -> &'static str {
46        "Python"
47    }
48
49    fn aliases(&self) -> &[&'static str] {
50        &["py", "python3", "py3"]
51    }
52
53    fn supports_sessions(&self) -> bool {
54        true
55    }
56
57    fn validate(&self) -> Result<()> {
58        let mut cmd = self.run_command();
59        cmd.arg("--version")
60            .stdout(Stdio::null())
61            .stderr(Stdio::null());
62        cmd.status()
63            .with_context(|| format!("failed to invoke {}", self.binary().display()))?
64            .success()
65            .then_some(())
66            .ok_or_else(|| anyhow::anyhow!("{} is not executable", self.binary().display()))
67    }
68
69    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
70        let start = Instant::now();
71        let timeout = execution_timeout();
72        let mut cmd = self.run_command();
73        let output = match payload {
74            ExecutionPayload::Inline { code } => {
75                cmd.arg("-c")
76                    .arg(code)
77                    .stdin(Stdio::inherit())
78                    .stdout(Stdio::piped())
79                    .stderr(Stdio::piped());
80                let child = cmd
81                    .spawn()
82                    .with_context(|| format!("failed to start {}", self.binary().display()))?;
83                wait_with_timeout(child, timeout)?
84            }
85            ExecutionPayload::File { path } => {
86                cmd.arg(path)
87                    .stdin(Stdio::inherit())
88                    .stdout(Stdio::piped())
89                    .stderr(Stdio::piped());
90                let child = cmd
91                    .spawn()
92                    .with_context(|| format!("failed to start {}", self.binary().display()))?;
93                wait_with_timeout(child, timeout)?
94            }
95            ExecutionPayload::Stdin { code } => {
96                cmd.arg("-")
97                    .stdin(Stdio::piped())
98                    .stdout(Stdio::piped())
99                    .stderr(Stdio::piped());
100                let mut child = cmd.spawn().with_context(|| {
101                    format!(
102                        "failed to start {} for stdin execution",
103                        self.binary().display()
104                    )
105                })?;
106                if let Some(mut stdin) = child.stdin.take() {
107                    stdin.write_all(code.as_bytes())?;
108                }
109                wait_with_timeout(child, timeout)?
110            }
111        };
112
113        Ok(ExecutionOutcome {
114            language: self.id().to_string(),
115            exit_code: output.status.code(),
116            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
117            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
118            duration: start.elapsed(),
119        })
120    }
121
122    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
123        Ok(Box::new(PythonSession::new(self.executable.clone())?))
124    }
125}
126
127struct PythonSession {
128    executable: PathBuf,
129    dir: TempDir,
130    source_path: PathBuf,
131    statements: Vec<String>,
132    previous_stdout: String,
133    previous_stderr: String,
134}
135
136impl PythonSession {
137    fn new(executable: PathBuf) -> Result<Self> {
138        let dir = Builder::new()
139            .prefix("run-python-repl")
140            .tempdir()
141            .context("failed to create temporary directory for python repl")?;
142        let source_path = dir.path().join("session.py");
143        fs::write(&source_path, "# Python REPL session\n")
144            .with_context(|| format!("failed to initialize {}", source_path.display()))?;
145
146        Ok(Self {
147            executable,
148            dir,
149            source_path,
150            statements: Vec::new(),
151            previous_stdout: String::new(),
152            previous_stderr: String::new(),
153        })
154    }
155
156    fn render_source(&self) -> String {
157        let mut source = String::from("import sys\nfrom math import *\n\n");
158        for snippet in &self.statements {
159            source.push_str(snippet);
160            if !snippet.ends_with('\n') {
161                source.push('\n');
162            }
163        }
164        source
165    }
166
167    fn write_source(&self, contents: &str) -> Result<()> {
168        fs::write(&self.source_path, contents).with_context(|| {
169            format!(
170                "failed to write generated Python REPL source to {}",
171                self.source_path.display()
172            )
173        })
174    }
175
176    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
177        let source = self.render_source();
178        self.write_source(&source)?;
179
180        let output = self.run_script()?;
181        let stdout_full = normalize_output(&output.stdout);
182        let stderr_full = normalize_output(&output.stderr);
183
184        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
185        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
186
187        let success = output.status.success();
188        if success {
189            self.previous_stdout = stdout_full;
190            self.previous_stderr = stderr_full;
191        }
192
193        let outcome = ExecutionOutcome {
194            language: "python".to_string(),
195            exit_code: output.status.code(),
196            stdout: stdout_delta,
197            stderr: stderr_delta,
198            duration: start.elapsed(),
199        };
200
201        Ok((outcome, success))
202    }
203
204    fn run_script(&self) -> Result<std::process::Output> {
205        let mut cmd = Command::new(&self.executable);
206        cmd.arg(&self.source_path)
207            .stdout(Stdio::piped())
208            .stderr(Stdio::piped())
209            .current_dir(self.dir.path());
210        cmd.output().with_context(|| {
211            format!(
212                "failed to run python session script {} with {}",
213                self.source_path.display(),
214                self.executable.display()
215            )
216        })
217    }
218
219    fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
220        self.statements.push(snippet);
221        let start = Instant::now();
222        let (outcome, success) = self.run_current(start)?;
223        if !success {
224            let _ = self.statements.pop();
225            let source = self.render_source();
226            self.write_source(&source)?;
227        }
228        Ok(outcome)
229    }
230
231    fn reset_state(&mut self) -> Result<()> {
232        self.statements.clear();
233        self.previous_stdout.clear();
234        self.previous_stderr.clear();
235        let source = self.render_source();
236        self.write_source(&source)
237    }
238}
239
240impl LanguageSession for PythonSession {
241    fn language_id(&self) -> &str {
242        "python"
243    }
244
245    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
246        let trimmed = code.trim();
247        if trimmed.is_empty() {
248            return Ok(ExecutionOutcome {
249                language: self.language_id().to_string(),
250                exit_code: None,
251                stdout: String::new(),
252                stderr: String::new(),
253                duration: Duration::default(),
254            });
255        }
256
257        if trimmed.eq_ignore_ascii_case(":reset") {
258            self.reset_state()?;
259            return Ok(ExecutionOutcome {
260                language: self.language_id().to_string(),
261                exit_code: None,
262                stdout: String::new(),
263                stderr: String::new(),
264                duration: Duration::default(),
265            });
266        }
267
268        if trimmed.eq_ignore_ascii_case(":help") {
269            return Ok(ExecutionOutcome {
270                language: self.language_id().to_string(),
271                exit_code: None,
272                stdout:
273                    "Python commands:\n  :reset - clear session state\n  :help  - show this message\n"
274                        .to_string(),
275                stderr: String::new(),
276                duration: Duration::default(),
277            });
278        }
279
280        if should_treat_as_expression(trimmed) {
281            let snippet = wrap_expression(trimmed, self.statements.len());
282            let outcome = self.run_snippet(snippet)?;
283            if outcome.exit_code.unwrap_or(0) == 0 {
284                return Ok(outcome);
285            }
286        }
287
288        let snippet = ensure_trailing_newline(code);
289        self.run_snippet(snippet)
290    }
291
292    fn shutdown(&mut self) -> Result<()> {
293        Ok(())
294    }
295}
296
297fn resolve_python_binary() -> PathBuf {
298    let candidates = ["python3", "python", "py"]; // windows py launcher
299    for name in candidates {
300        if let Ok(path) = which::which(name) {
301            return path;
302        }
303    }
304    PathBuf::from("python3")
305}
306
/// Returns an owned copy of `code` that is guaranteed to end with `'\n'`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
314
/// Turns a single-line expression into Python statements that evaluate it
/// once, bind the value to a unique `__run_value_<index>` and to `_` (last
/// result), and print its `repr`.
///
/// The closing parenthesis is placed on its own line (implicit line joining
/// inside parentheses) so that a trailing `# comment` in `code` cannot comment
/// it out and turn the wrapper into a `SyntaxError`.
fn wrap_expression(code: &str, index: usize) -> String {
    format!(
        "__run_value_{index} = ({code}\n)\n_ = __run_value_{index}\nprint(repr(__run_value_{index}), flush=True)\n"
    )
}
321
/// Returns the part of `current` that follows `previous`, or all of `current`
/// when it does not start with `previous` (e.g. the replayed output diverged).
fn diff_output(previous: &str, current: &str) -> String {
    current
        .strip_prefix(previous)
        .unwrap_or(current)
        .to_owned()
}
329
/// Decodes raw process output (invalid UTF-8 becomes U+FFFD) and removes every
/// carriage return, so CRLF line endings become LF and stray `\r` disappear.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}
335
/// Heuristically decides whether a REPL line should be evaluated as an
/// expression (with its `repr` echoed) rather than executed as a statement.
///
/// Returns `false` for:
/// * empty or multi-line input;
/// * block headers (trailing `:`) and comment lines;
/// * lines whose first word is a statement keyword;
/// * `print(...)` calls, whose `None` result is not worth echoing;
/// * lines containing a top-level assignment.
///
/// A `true` answer is only a hint: the session falls back to running the line
/// as a statement if the wrapped expression fails.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty()
        || trimmed.contains('\n')
        || trimmed.ends_with(':')
        || trimmed.starts_with('#')
    {
        return false;
    }

    // Keywords that can only start a statement. They are compared with the whole
    // leading identifier, case-sensitively (as Python is), so names such as
    // `passwords` or `yield_value` are not mistaken for `pass` / `yield`.
    const STATEMENT_KEYWORDS: &[&str] = &[
        "import", "from", "def", "class", "if", "for", "while", "try", "except", "finally",
        "with", "return", "raise", "yield", "async", "await", "assert", "del", "global",
        "nonlocal", "pass", "break", "continue",
    ];
    let word_end = trimmed
        .find(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
        .unwrap_or(trimmed.len());
    let (first_word, rest) = trimmed.split_at(word_end);
    if STATEMENT_KEYWORDS.contains(&first_word) {
        return false;
    }

    // `print(...)` / `print x`: echoing the call's `None` result would be noise.
    if first_word == "print" && (rest.starts_with('(') || rest.starts_with(char::is_whitespace))
    {
        return false;
    }

    !has_top_level_assignment(trimmed)
}

/// Returns `true` if `line` contains an assignment `=` (plain, augmented, or
/// annotated) outside brackets, string literals, and any trailing comment.
///
/// `==`, `!=`, `<=`, `>=` and the walrus `:=` are not assignments here;
/// `<<=` / `>>=` are.
fn has_top_level_assignment(line: &str) -> bool {
    let chars: Vec<char> = line.chars().collect();
    let mut depth: usize = 0;
    let mut quote: Option<char> = None;
    let mut i = 0;
    while i < chars.len() {
        let ch = chars[i];
        if let Some(open) = quote {
            if ch == '\\' {
                // Skip the escaped character (it may be the quote itself).
                i += 2;
                continue;
            }
            if ch == open {
                quote = None;
            }
            i += 1;
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            '#' => break,
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            '=' if depth == 0 => {
                if chars.get(i + 1) == Some(&'=') {
                    // `==` comparison: skip both characters.
                    i += 2;
                    continue;
                }
                match i.checked_sub(1).map(|j| chars[j]) {
                    // `!=` comparison or `:=` walrus (valid once parenthesized).
                    Some('!') | Some(':') => {}
                    Some(prev @ ('<' | '>')) => {
                        // `<=` / `>=` compare; `<<=` / `>>=` assign.
                        if i >= 2 && chars[i - 2] == prev {
                            return true;
                        }
                    }
                    _ => return true,
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}