1use std::fs;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// Language engine that runs Python code through an external interpreter process.
pub struct PythonEngine {
    /// Interpreter binary used for every invocation; resolved once in `new`.
    executable: PathBuf,
}
15
16impl PythonEngine {
17 pub fn new() -> Self {
18 let executable = resolve_python_binary();
19 Self { executable }
20 }
21
22 fn binary(&self) -> &Path {
23 &self.executable
24 }
25
26 fn run_command(&self) -> Command {
27 Command::new(self.binary())
28 }
29}
30
31impl LanguageEngine for PythonEngine {
32 fn id(&self) -> &'static str {
33 "python"
34 }
35
36 fn display_name(&self) -> &'static str {
37 "Python"
38 }
39
40 fn aliases(&self) -> &[&'static str] {
41 &["py", "python3", "py3"]
42 }
43
44 fn supports_sessions(&self) -> bool {
45 true
46 }
47
48 fn validate(&self) -> Result<()> {
49 let mut cmd = self.run_command();
50 cmd.arg("--version")
51 .stdout(Stdio::null())
52 .stderr(Stdio::null());
53 cmd.status()
54 .with_context(|| format!("failed to invoke {}", self.binary().display()))?
55 .success()
56 .then_some(())
57 .ok_or_else(|| anyhow::anyhow!("{} is not executable", self.binary().display()))
58 }
59
60 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
61 let start = Instant::now();
62 let mut cmd = self.run_command();
63 let output = match payload {
64 ExecutionPayload::Inline { code } => {
65 cmd.arg("-c").arg(code);
66 cmd.stdin(Stdio::inherit());
67 cmd.output()
68 }
69 ExecutionPayload::File { path } => {
70 cmd.arg(path);
71 cmd.stdin(Stdio::inherit());
72 cmd.output()
73 }
74 ExecutionPayload::Stdin { code } => {
75 cmd.arg("-")
76 .stdin(Stdio::piped())
77 .stdout(Stdio::piped())
78 .stderr(Stdio::piped());
79 let mut child = cmd.spawn().with_context(|| {
80 format!(
81 "failed to start {} for stdin execution",
82 self.binary().display()
83 )
84 })?;
85 if let Some(mut stdin) = child.stdin.take() {
86 stdin.write_all(code.as_bytes())?;
87 }
88 child.wait_with_output()
89 }
90 }?;
91
92 Ok(ExecutionOutcome {
93 language: self.id().to_string(),
94 exit_code: output.status.code(),
95 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
96 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
97 duration: start.elapsed(),
98 })
99 }
100
101 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
102 Ok(Box::new(PythonSession::new(self.executable.clone())?))
103 }
104}
105
/// REPL-style Python session. State persists by replaying every accepted
/// snippet as one generated script on each evaluation, then reporting only
/// the output that is new compared with the previous successful run.
struct PythonSession {
    /// Interpreter used to run the generated script.
    executable: PathBuf,
    /// Temporary directory holding the generated script; also the interpreter's
    /// working directory. (A `tempfile::TempDir`, which deletes the directory when dropped.)
    dir: TempDir,
    /// Path of the generated `session.py` inside `dir`.
    source_path: PathBuf,
    /// Snippets kept in the session history, in evaluation order. Snippets
    /// whose run failed are removed again.
    statements: Vec<String>,
    /// Full normalized stdout of the last successful run; baseline for the next delta.
    previous_stdout: String,
    /// Full normalized stderr of the last successful run; baseline for the next delta.
    previous_stderr: String,
}
114
115impl PythonSession {
116 fn new(executable: PathBuf) -> Result<Self> {
117 let dir = Builder::new()
118 .prefix("run-python-repl")
119 .tempdir()
120 .context("failed to create temporary directory for python repl")?;
121 let source_path = dir.path().join("session.py");
122 fs::write(&source_path, "# Python REPL session\n")
123 .with_context(|| format!("failed to initialize {}", source_path.display()))?;
124
125 Ok(Self {
126 executable,
127 dir,
128 source_path,
129 statements: Vec::new(),
130 previous_stdout: String::new(),
131 previous_stderr: String::new(),
132 })
133 }
134
135 fn render_source(&self) -> String {
136 let mut source = String::from("import sys\nfrom math import *\n\n");
137 for snippet in &self.statements {
138 source.push_str(snippet);
139 if !snippet.ends_with('\n') {
140 source.push('\n');
141 }
142 }
143 source
144 }
145
146 fn write_source(&self, contents: &str) -> Result<()> {
147 fs::write(&self.source_path, contents).with_context(|| {
148 format!(
149 "failed to write generated Python REPL source to {}",
150 self.source_path.display()
151 )
152 })
153 }
154
155 fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
156 let source = self.render_source();
157 self.write_source(&source)?;
158
159 let output = self.run_script()?;
160 let stdout_full = normalize_output(&output.stdout);
161 let stderr_full = normalize_output(&output.stderr);
162
163 let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
164 let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
165
166 let success = output.status.success();
167 if success {
168 self.previous_stdout = stdout_full;
169 self.previous_stderr = stderr_full;
170 }
171
172 let outcome = ExecutionOutcome {
173 language: "python".to_string(),
174 exit_code: output.status.code(),
175 stdout: stdout_delta,
176 stderr: stderr_delta,
177 duration: start.elapsed(),
178 };
179
180 Ok((outcome, success))
181 }
182
183 fn run_script(&self) -> Result<std::process::Output> {
184 let mut cmd = Command::new(&self.executable);
185 cmd.arg(&self.source_path)
186 .stdout(Stdio::piped())
187 .stderr(Stdio::piped())
188 .current_dir(self.dir.path());
189 cmd.output().with_context(|| {
190 format!(
191 "failed to run python session script {} with {}",
192 self.source_path.display(),
193 self.executable.display()
194 )
195 })
196 }
197
198 fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
199 self.statements.push(snippet);
200 let start = Instant::now();
201 let (outcome, success) = self.run_current(start)?;
202 if !success {
203 let _ = self.statements.pop();
204 let source = self.render_source();
205 self.write_source(&source)?;
206 }
207 Ok(outcome)
208 }
209
210 fn reset_state(&mut self) -> Result<()> {
211 self.statements.clear();
212 self.previous_stdout.clear();
213 self.previous_stderr.clear();
214 let source = self.render_source();
215 self.write_source(&source)
216 }
217}
218
219impl LanguageSession for PythonSession {
220 fn language_id(&self) -> &str {
221 "python"
222 }
223
224 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
225 let trimmed = code.trim();
226 if trimmed.is_empty() {
227 return Ok(ExecutionOutcome {
228 language: self.language_id().to_string(),
229 exit_code: None,
230 stdout: String::new(),
231 stderr: String::new(),
232 duration: Duration::default(),
233 });
234 }
235
236 if trimmed.eq_ignore_ascii_case(":reset") {
237 self.reset_state()?;
238 return Ok(ExecutionOutcome {
239 language: self.language_id().to_string(),
240 exit_code: None,
241 stdout: String::new(),
242 stderr: String::new(),
243 duration: Duration::default(),
244 });
245 }
246
247 if trimmed.eq_ignore_ascii_case(":help") {
248 return Ok(ExecutionOutcome {
249 language: self.language_id().to_string(),
250 exit_code: None,
251 stdout:
252 "Python commands:\n :reset - clear session state\n :help - show this message\n"
253 .to_string(),
254 stderr: String::new(),
255 duration: Duration::default(),
256 });
257 }
258
259 if should_treat_as_expression(trimmed) {
260 let snippet = wrap_expression(trimmed, self.statements.len());
261 let outcome = self.run_snippet(snippet)?;
262 if outcome.exit_code.unwrap_or(0) == 0 {
263 return Ok(outcome);
264 }
265 }
266
267 let snippet = ensure_trailing_newline(code);
268 self.run_snippet(snippet)
269 }
270
271 fn shutdown(&mut self) -> Result<()> {
272 Ok(())
273 }
274}
275
276fn resolve_python_binary() -> PathBuf {
277 let candidates = ["python3", "python", "py"]; for name in candidates {
279 if let Ok(path) = which::which(name) {
280 return path;
281 }
282 }
283 PathBuf::from("python3")
284}
285
/// Returns an owned copy of `code` that is guaranteed to end with `'\n'`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
293
/// Wraps a bare expression so its value is bound to `__run_value_{index}` and
/// echoed through `repr`, flushed immediately.
///
/// The expression goes on its own line inside the parentheses so that a
/// trailing `# comment` cannot swallow the closing `)`. Python allows newlines
/// inside brackets, so the expression is otherwise evaluated unchanged.
fn wrap_expression(code: &str, index: usize) -> String {
    format!("__run_value_{index} = (\n{code}\n)\nprint(repr(__run_value_{index}), flush=True)\n")
}
297
/// Returns the portion of `current` that follows `previous`. If `current`
/// does not start with `previous` (earlier output changed between runs),
/// returns all of `current`.
fn diff_output(previous: &str, current: &str) -> String {
    current
        .strip_prefix(previous)
        .unwrap_or(current)
        .to_owned()
}
305
/// Decodes process output lossily as UTF-8 and strips every carriage return.
///
/// Turning CRLF into LF and then dropping stray CRs is the same as removing
/// all `'\r'` characters, so this does it in a single pass.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}
311
/// Heuristically decides whether one line of REPL input is a bare Python
/// expression whose value should be echoed.
///
/// Returns `false` for:
/// - blank or multi-line input, and block headers ending in `:`;
/// - comments;
/// - lines whose first word is a statement keyword;
/// - `print` calls;
/// - lines containing a top-level assignment.
///
/// Keywords are matched as whole, case-sensitive words, as Python does.
/// Treating a statement as an expression is recoverable, because the session
/// falls back to running the input verbatim when the wrapped form fails.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(':') {
        return false;
    }
    if trimmed.starts_with('#') {
        return false;
    }

    const STATEMENT_KEYWORDS: [&str; 21] = [
        "import", "from", "def", "class", "if", "for", "while", "try", "except", "finally",
        "with", "return", "raise", "yield", "async", "await", "assert", "del", "global",
        "nonlocal", "pass",
    ];
    // Compare the leading identifier as a whole word so names such as
    // `passed`, `exceptions` or `tryout` are not mistaken for keywords.
    let first_word = trimmed
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .next()
        .unwrap_or("");
    if STATEMENT_KEYWORDS.contains(&first_word) {
        return false;
    }

    // `print(...)` evaluates to `None`; echoing it would add a spurious line.
    if trimmed.starts_with("print(") || trimmed.starts_with("print ") {
        return false;
    }

    !has_top_level_assignment(trimmed)
}

/// Reports whether `code` contains an assignment `=` (plain, augmented or
/// annotated) outside brackets, string literals and a trailing comment.
///
/// Not counted as assignment:
/// - keyword arguments such as `f(a=1)` (inside brackets);
/// - the comparisons `==`, `!=`, `<=`, `>=`;
/// - the walrus `:=`, which is an expression.
///
/// Shift-augmented `<<=` / `>>=` look like comparisons here; they then fail as
/// expressions and fall back to statement execution. Strings are tracked with
/// a simple single-line scan of `'` / `"` quotes that honours backslash escapes.
fn has_top_level_assignment(code: &str) -> bool {
    let chars: Vec<char> = code.chars().collect();
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for (i, &ch) in chars.iter().enumerate() {
        if let Some(open) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == open {
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            // Everything after an unquoted `#` is a comment.
            '#' => break,
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            '=' if depth == 0 => {
                let prev = i.checked_sub(1).and_then(|j| chars.get(j)).copied();
                let next = chars.get(i + 1).copied();
                let is_operator =
                    matches!(prev, Some('=' | '!' | '<' | '>' | ':')) || next == Some('=');
                if !is_operator {
                    return true;
                }
            }
            _ => {}
        }
    }
    false
}