Skip to main content

run/engine/
python.rs

1use std::fs;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
12pub struct PythonEngine {
13    executable: PathBuf,
14}
15
16impl PythonEngine {
17    pub fn new() -> Self {
18        let executable = resolve_python_binary();
19        Self { executable }
20    }
21
22    fn binary(&self) -> &Path {
23        &self.executable
24    }
25
26    fn run_command(&self) -> Command {
27        Command::new(self.binary())
28    }
29}
30
31impl LanguageEngine for PythonEngine {
32    fn id(&self) -> &'static str {
33        "python"
34    }
35
36    fn display_name(&self) -> &'static str {
37        "Python"
38    }
39
40    fn aliases(&self) -> &[&'static str] {
41        &["py", "python3", "py3"]
42    }
43
44    fn supports_sessions(&self) -> bool {
45        true
46    }
47
48    fn validate(&self) -> Result<()> {
49        let mut cmd = self.run_command();
50        cmd.arg("--version")
51            .stdout(Stdio::null())
52            .stderr(Stdio::null());
53        cmd.status()
54            .with_context(|| format!("failed to invoke {}", self.binary().display()))?
55            .success()
56            .then_some(())
57            .ok_or_else(|| anyhow::anyhow!("{} is not executable", self.binary().display()))
58    }
59
60    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
61        let start = Instant::now();
62        let mut cmd = self.run_command();
63        let output = match payload {
64            ExecutionPayload::Inline { code } => {
65                cmd.arg("-c").arg(code);
66                cmd.stdin(Stdio::inherit());
67                cmd.output()
68            }
69            ExecutionPayload::File { path } => {
70                cmd.arg(path);
71                cmd.stdin(Stdio::inherit());
72                cmd.output()
73            }
74            ExecutionPayload::Stdin { code } => {
75                cmd.arg("-")
76                    .stdin(Stdio::piped())
77                    .stdout(Stdio::piped())
78                    .stderr(Stdio::piped());
79                let mut child = cmd.spawn().with_context(|| {
80                    format!(
81                        "failed to start {} for stdin execution",
82                        self.binary().display()
83                    )
84                })?;
85                if let Some(mut stdin) = child.stdin.take() {
86                    stdin.write_all(code.as_bytes())?;
87                }
88                child.wait_with_output()
89            }
90        }?;
91
92        Ok(ExecutionOutcome {
93            language: self.id().to_string(),
94            exit_code: output.status.code(),
95            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
96            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
97            duration: start.elapsed(),
98        })
99    }
100
101    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
102        Ok(Box::new(PythonSession::new(self.executable.clone())?))
103    }
104}
105
106struct PythonSession {
107    executable: PathBuf,
108    dir: TempDir,
109    source_path: PathBuf,
110    statements: Vec<String>,
111    previous_stdout: String,
112    previous_stderr: String,
113}
114
115impl PythonSession {
116    fn new(executable: PathBuf) -> Result<Self> {
117        let dir = Builder::new()
118            .prefix("run-python-repl")
119            .tempdir()
120            .context("failed to create temporary directory for python repl")?;
121        let source_path = dir.path().join("session.py");
122        fs::write(&source_path, "# Python REPL session\n")
123            .with_context(|| format!("failed to initialize {}", source_path.display()))?;
124
125        Ok(Self {
126            executable,
127            dir,
128            source_path,
129            statements: Vec::new(),
130            previous_stdout: String::new(),
131            previous_stderr: String::new(),
132        })
133    }
134
135    fn render_source(&self) -> String {
136        let mut source = String::from("import sys\nfrom math import *\n\n");
137        for snippet in &self.statements {
138            source.push_str(snippet);
139            if !snippet.ends_with('\n') {
140                source.push('\n');
141            }
142        }
143        source
144    }
145
146    fn write_source(&self, contents: &str) -> Result<()> {
147        fs::write(&self.source_path, contents).with_context(|| {
148            format!(
149                "failed to write generated Python REPL source to {}",
150                self.source_path.display()
151            )
152        })
153    }
154
155    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
156        let source = self.render_source();
157        self.write_source(&source)?;
158
159        let output = self.run_script()?;
160        let stdout_full = normalize_output(&output.stdout);
161        let stderr_full = normalize_output(&output.stderr);
162
163        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
164        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
165
166        let success = output.status.success();
167        if success {
168            self.previous_stdout = stdout_full;
169            self.previous_stderr = stderr_full;
170        }
171
172        let outcome = ExecutionOutcome {
173            language: "python".to_string(),
174            exit_code: output.status.code(),
175            stdout: stdout_delta,
176            stderr: stderr_delta,
177            duration: start.elapsed(),
178        };
179
180        Ok((outcome, success))
181    }
182
183    fn run_script(&self) -> Result<std::process::Output> {
184        let mut cmd = Command::new(&self.executable);
185        cmd.arg(&self.source_path)
186            .stdout(Stdio::piped())
187            .stderr(Stdio::piped())
188            .current_dir(self.dir.path());
189        cmd.output().with_context(|| {
190            format!(
191                "failed to run python session script {} with {}",
192                self.source_path.display(),
193                self.executable.display()
194            )
195        })
196    }
197
198    fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
199        self.statements.push(snippet);
200        let start = Instant::now();
201        let (outcome, success) = self.run_current(start)?;
202        if !success {
203            let _ = self.statements.pop();
204            let source = self.render_source();
205            self.write_source(&source)?;
206        }
207        Ok(outcome)
208    }
209
210    fn reset_state(&mut self) -> Result<()> {
211        self.statements.clear();
212        self.previous_stdout.clear();
213        self.previous_stderr.clear();
214        let source = self.render_source();
215        self.write_source(&source)
216    }
217}
218
219impl LanguageSession for PythonSession {
220    fn language_id(&self) -> &str {
221        "python"
222    }
223
224    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
225        let trimmed = code.trim();
226        if trimmed.is_empty() {
227            return Ok(ExecutionOutcome {
228                language: self.language_id().to_string(),
229                exit_code: None,
230                stdout: String::new(),
231                stderr: String::new(),
232                duration: Duration::default(),
233            });
234        }
235
236        if trimmed.eq_ignore_ascii_case(":reset") {
237            self.reset_state()?;
238            return Ok(ExecutionOutcome {
239                language: self.language_id().to_string(),
240                exit_code: None,
241                stdout: String::new(),
242                stderr: String::new(),
243                duration: Duration::default(),
244            });
245        }
246
247        if trimmed.eq_ignore_ascii_case(":help") {
248            return Ok(ExecutionOutcome {
249                language: self.language_id().to_string(),
250                exit_code: None,
251                stdout:
252                    "Python commands:\n  :reset - clear session state\n  :help  - show this message\n"
253                        .to_string(),
254                stderr: String::new(),
255                duration: Duration::default(),
256            });
257        }
258
259        if should_treat_as_expression(trimmed) {
260            let snippet = wrap_expression(trimmed, self.statements.len());
261            let outcome = self.run_snippet(snippet)?;
262            if outcome.exit_code.unwrap_or(0) == 0 {
263                return Ok(outcome);
264            }
265        }
266
267        let snippet = ensure_trailing_newline(code);
268        self.run_snippet(snippet)
269    }
270
271    fn shutdown(&mut self) -> Result<()> {
272        Ok(())
273    }
274}
275
276fn resolve_python_binary() -> PathBuf {
277    let candidates = ["python3", "python", "py"]; // windows py launcher
278    for name in candidates {
279        if let Ok(path) = which::which(name) {
280            return path;
281        }
282    }
283    PathBuf::from("python3")
284}
285
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
293
/// Wraps a single-line expression so that running it binds the value and echoes it.
///
/// Like the interactive interpreter, the value's `repr` is printed unless it is `None`,
/// so e.g. `lst.append(1)` prints nothing. The expression sits on its own line inside
/// the parentheses so that a trailing `# comment` cannot swallow the closing `)`. Inside
/// parentheses, newlines are implicit line joins.
///
/// `index` makes the temporary variable name unique per snippet within a session.
fn wrap_expression(code: &str, index: usize) -> String {
    let var = format!("__run_value_{index}");
    format!("{var} = (\n{code}\n)\nif {var} is not None:\n    print(repr({var}), flush=True)\n")
}
297
/// Returns the part of `current` that follows `previous`.
///
/// If `current` does not start with `previous` (earlier output changed), the whole of
/// `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    let fresh = current.strip_prefix(previous).unwrap_or(current);
    fresh.to_owned()
}
305
/// Decodes captured process output as UTF-8 and removes every carriage return.
///
/// Invalid UTF-8 sequences become U+FFFD. Removing `\r` turns `\r\n` line endings into
/// `\n` and also drops lone `\r` characters.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}
311
/// Decides whether a single REPL input should be evaluated as an expression whose value
/// is echoed, rather than executed as a plain statement.
///
/// This is a heuristic. The following inputs are rejected (treated as statements):
/// - empty or multi-line input;
/// - block headers (a trailing `:`);
/// - comment lines;
/// - input whose first word is a statement keyword;
/// - `print` calls;
/// - input containing a top-level assignment (`=`, `+=`, `<<=`, `:=`, ...).
///
/// Keywords are matched as whole, case-sensitive words (as Python does), so identifiers
/// such as `password` or `trying` still count as expressions. An `=` inside brackets,
/// string literals or a trailing comment (e.g. `f(a=1)`, `"k=v"`) is not an assignment.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(':') {
        return false;
    }
    if trimmed.starts_with('#') {
        return false;
    }

    const STATEMENT_KEYWORDS: [&str; 21] = [
        "import", "from", "def", "class", "if", "for", "while", "try", "except", "finally",
        "with", "return", "raise", "yield", "async", "await", "assert", "del", "global",
        "nonlocal", "pass",
    ];
    let word_end = trimmed
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(trimmed.len());
    let first_word = &trimmed[..word_end];
    if STATEMENT_KEYWORDS.iter().any(|kw| *kw == first_word) {
        return false;
    }

    // `print(...)` evaluates to None; running it directly as a statement avoids
    // wrapping it for nothing.
    if trimmed.starts_with("print(") || trimmed.starts_with("print ") {
        return false;
    }

    !has_top_level_assignment(trimmed)
}

/// Reports whether `code` contains an assignment `=` at bracket depth zero, outside
/// string literals and before any `#` comment.
///
/// `==`, `!=`, `<=` and `>=` are comparisons. Every other top-level `=` is treated as
/// assignment: plain, augmented (`+=`, `<<=`, `>>=`, ...) or `:=`. All tokens involved
/// are ASCII, so scanning bytes is safe: UTF-8 continuation bytes never match them.
fn has_top_level_assignment(code: &str) -> bool {
    let bytes = code.as_bytes();
    let mut depth: usize = 0;
    let mut quote: Option<u8> = None;
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if let Some(q) = quote {
            if b == b'\\' {
                // Skip the escaped byte so `\"` does not close the literal.
                i += 1;
            } else if b == q {
                quote = None;
            }
            i += 1;
            continue;
        }
        match b {
            b'#' => break,
            b'\'' | b'"' => quote = Some(b),
            b'(' | b'[' | b'{' => depth += 1,
            b')' | b']' | b'}' => depth = depth.saturating_sub(1),
            b'=' if depth == 0 => {
                if bytes.get(i + 1) == Some(&b'=') {
                    // `==`: skip both bytes.
                    i += 2;
                    continue;
                }
                let prev = if i > 0 { Some(bytes[i - 1]) } else { None };
                match prev {
                    // Second byte of `!=` (or of a malformed `===`): not an assignment.
                    Some(b'!') | Some(b'=') => {}
                    // `<=` / `>=` compare, but `<<=` / `>>=` assign.
                    Some(p) if p == b'<' || p == b'>' => {
                        if i >= 2 && bytes[i - 2] == p {
                            return true;
                        }
                    }
                    _ => return true,
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}