1use std::fs;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// Language engine that runs code through a Python interpreter found on the host.
///
/// Derives `Debug` and `Clone` so the engine can be logged and duplicated like
/// any other plain value type; its only state is the interpreter path.
#[derive(Debug, Clone)]
pub struct PythonEngine {
    /// Path (or bare command name) of the interpreter that gets launched.
    executable: PathBuf,
}
15
16impl PythonEngine {
17 pub fn new() -> Self {
18 let executable = resolve_python_binary();
19 Self { executable }
20 }
21
22 fn binary(&self) -> &Path {
23 &self.executable
24 }
25
26 fn run_command(&self) -> Command {
27 Command::new(self.binary())
28 }
29}
30
31impl LanguageEngine for PythonEngine {
32 fn id(&self) -> &'static str {
33 "python"
34 }
35
36 fn display_name(&self) -> &'static str {
37 "Python"
38 }
39
40 fn aliases(&self) -> &[&'static str] {
41 &["py", "python3", "py3"]
42 }
43
44 fn supports_sessions(&self) -> bool {
45 true
46 }
47
48 fn validate(&self) -> Result<()> {
49 let mut cmd = self.run_command();
50 cmd.arg("--version")
51 .stdout(Stdio::null())
52 .stderr(Stdio::null());
53 cmd.status()
54 .with_context(|| format!("failed to invoke {}", self.binary().display()))?
55 .success()
56 .then_some(())
57 .ok_or_else(|| anyhow::anyhow!("{} is not executable", self.binary().display()))
58 }
59
60 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
61 let start = Instant::now();
62 let mut cmd = self.run_command();
63 let output = match payload {
64 ExecutionPayload::Inline { code } => {
65 cmd.arg("-c").arg(code);
66 cmd.stdin(Stdio::inherit());
67 cmd.output()
68 }
69 ExecutionPayload::File { path } => {
70 cmd.arg(path);
71 cmd.stdin(Stdio::inherit());
72 cmd.output()
73 }
74 ExecutionPayload::Stdin { code } => {
75 cmd.arg("-")
76 .stdin(Stdio::piped())
77 .stdout(Stdio::piped())
78 .stderr(Stdio::piped());
79 let mut child = cmd.spawn().with_context(|| {
80 format!(
81 "failed to start {} for stdin execution",
82 self.binary().display()
83 )
84 })?;
85 if let Some(mut stdin) = child.stdin.take() {
86 stdin.write_all(code.as_bytes())?;
87 }
88 child.wait_with_output()
89 }
90 }?;
91
92 Ok(ExecutionOutcome {
93 language: self.id().to_string(),
94 exit_code: output.status.code(),
95 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
96 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
97 duration: start.elapsed(),
98 })
99 }
100
101 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
102 Ok(Box::new(PythonSession::new(self.executable.clone())?))
103 }
104}
105
/// State of one interactive Python session.
///
/// The session keeps every accepted snippet and re-runs the whole generated
/// script on each evaluation, reporting only the output that is new compared
/// with the last successful run.
struct PythonSession {
    /// Interpreter used to run the generated script.
    executable: PathBuf,
    /// Temporary directory that holds the script and serves as the
    /// interpreter's working directory.
    dir: TempDir,
    /// Location of the generated `session.py` inside `dir`.
    source_path: PathBuf,
    /// Snippets accepted so far, in the order they are rendered into the script.
    statements: Vec<String>,
    /// Normalized stdout of the last successful run; baseline for diffing.
    previous_stdout: String,
    /// Normalized stderr of the last successful run; baseline for diffing.
    previous_stderr: String,
}
114
115impl PythonSession {
116 fn new(executable: PathBuf) -> Result<Self> {
117 let dir = Builder::new()
118 .prefix("run-python-repl")
119 .tempdir()
120 .context("failed to create temporary directory for python repl")?;
121 let source_path = dir.path().join("session.py");
122 fs::write(&source_path, "# Python REPL session\n")
123 .with_context(|| format!("failed to initialize {}", source_path.display()))?;
124
125 Ok(Self {
126 executable,
127 dir,
128 source_path,
129 statements: Vec::new(),
130 previous_stdout: String::new(),
131 previous_stderr: String::new(),
132 })
133 }
134
135 fn render_source(&self) -> String {
136 let mut source = String::from("import sys\nfrom math import *\n\n");
137 for snippet in &self.statements {
138 source.push_str(snippet);
139 if !snippet.ends_with('\n') {
140 source.push('\n');
141 }
142 }
143 source
144 }
145
146 fn write_source(&self, contents: &str) -> Result<()> {
147 fs::write(&self.source_path, contents).with_context(|| {
148 format!(
149 "failed to write generated Python REPL source to {}",
150 self.source_path.display()
151 )
152 })
153 }
154
155 fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
156 let source = self.render_source();
157 self.write_source(&source)?;
158
159 let output = self.run_script()?;
160 let stdout_full = normalize_output(&output.stdout);
161 let stderr_full = normalize_output(&output.stderr);
162
163 let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
164 let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
165
166 let success = output.status.success();
167 if success {
168 self.previous_stdout = stdout_full;
169 self.previous_stderr = stderr_full;
170 }
171
172 let outcome = ExecutionOutcome {
173 language: "python".to_string(),
174 exit_code: output.status.code(),
175 stdout: stdout_delta,
176 stderr: stderr_delta,
177 duration: start.elapsed(),
178 };
179
180 Ok((outcome, success))
181 }
182
183 fn run_script(&self) -> Result<std::process::Output> {
184 let mut cmd = Command::new(&self.executable);
185 cmd.arg(&self.source_path)
186 .stdout(Stdio::piped())
187 .stderr(Stdio::piped())
188 .current_dir(self.dir.path());
189 cmd.output().with_context(|| {
190 format!(
191 "failed to run python session script {} with {}",
192 self.source_path.display(),
193 self.executable.display()
194 )
195 })
196 }
197
198 fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
199 self.statements.push(snippet);
200 let start = Instant::now();
201 let (outcome, success) = self.run_current(start)?;
202 if !success {
203 let _ = self.statements.pop();
204 let source = self.render_source();
205 self.write_source(&source)?;
206 }
207 Ok(outcome)
208 }
209
210 fn reset_state(&mut self) -> Result<()> {
211 self.statements.clear();
212 self.previous_stdout.clear();
213 self.previous_stderr.clear();
214 let source = self.render_source();
215 self.write_source(&source)
216 }
217}
218
219impl LanguageSession for PythonSession {
220 fn language_id(&self) -> &str {
221 "python"
222 }
223
224 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
225 let trimmed = code.trim();
226 if trimmed.is_empty() {
227 return Ok(ExecutionOutcome {
228 language: self.language_id().to_string(),
229 exit_code: None,
230 stdout: String::new(),
231 stderr: String::new(),
232 duration: Duration::default(),
233 });
234 }
235
236 if trimmed.eq_ignore_ascii_case(":reset") {
237 self.reset_state()?;
238 return Ok(ExecutionOutcome {
239 language: self.language_id().to_string(),
240 exit_code: None,
241 stdout: String::new(),
242 stderr: String::new(),
243 duration: Duration::default(),
244 });
245 }
246
247 if trimmed.eq_ignore_ascii_case(":help") {
248 return Ok(ExecutionOutcome {
249 language: self.language_id().to_string(),
250 exit_code: None,
251 stdout:
252 "Python commands:\n :reset — clear session state\n :help — show this message\n"
253 .to_string(),
254 stderr: String::new(),
255 duration: Duration::default(),
256 });
257 }
258
259 if should_treat_as_expression(trimmed) {
260 let snippet = wrap_expression(trimmed, self.statements.len());
261 let outcome = self.run_snippet(snippet)?;
262 if outcome.exit_code.unwrap_or(0) == 0 {
263 return Ok(outcome);
264 }
265 }
266
267 let snippet = ensure_trailing_newline(code);
268 self.run_snippet(snippet)
269 }
270
271 fn shutdown(&mut self) -> Result<()> {
272 Ok(())
274 }
275}
276
277fn resolve_python_binary() -> PathBuf {
278 let candidates = ["python3", "python", "py"]; for name in candidates {
280 if let Ok(path) = which::which(name) {
281 return path;
282 }
283 }
284 PathBuf::from("python3")
285}
286
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
294
/// Turns an expression into two statements: bind its value to a variable
/// numbered by `index`, then print that value's `repr`.
fn wrap_expression(code: &str, index: usize) -> String {
    let holder = format!("__run_value_{index}");
    let mut wrapped = String::new();
    wrapped.push_str(&holder);
    wrapped.push_str(" = (");
    wrapped.push_str(code);
    wrapped.push_str(")\nprint(repr(");
    wrapped.push_str(&holder);
    wrapped.push_str("), flush=True)\n");
    wrapped
}
298
/// Returns the part of `current` that follows `previous` when `current`
/// starts with it; otherwise returns all of `current`.
fn diff_output(previous: &str, current: &str) -> String {
    current
        .strip_prefix(previous)
        .unwrap_or(current)
        .to_string()
}
306
/// Decodes process output lossily as UTF-8 and removes every carriage return,
/// so `\r\n` becomes `\n` and lone `\r` characters disappear.
fn normalize_output(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    decoded.chars().filter(|&ch| ch != '\r').collect()
}
312
/// Heuristically decides whether a REPL input should be evaluated as an
/// expression (so its value can be echoed) instead of run as a statement.
///
/// Returns `false` for:
/// * empty or multi-line input, and lines ending in `:`;
/// * lines that start with a statement keyword (matched ASCII
///   case-insensitively; bare keywords such as `pass` or `try` must be
///   followed by a non-identifier character, so `password` and `try_count`
///   are not mistaken for statements);
/// * `print` calls and comments;
/// * lines containing an assignment `=` outside brackets and string literals
///   (so keyword arguments like `f(key=1)` and text like `"a=b"` still count
///   as expressions).
///
/// Everything else is treated as an expression.
fn should_treat_as_expression(code: &str) -> bool {
    /// True when `text` starts with `prefix` as a complete word. Prefixes that
    /// end in a space already carry their own delimiter.
    fn starts_with_keyword(text: &str, prefix: &str) -> bool {
        match text.strip_prefix(prefix) {
            None => false,
            Some(_) if prefix.ends_with(' ') => true,
            Some(rest) => !rest
                .chars()
                .next()
                .map_or(false, |c| c.is_alphanumeric() || c == '_'),
        }
    }

    /// True when `text` has an `=` that is not part of `==`, `!=`, `<=` or
    /// `>=`, and that sits outside brackets, string literals and comments.
    fn has_top_level_assignment(text: &str) -> bool {
        // Scanning bytes is safe here: every delimiter of interest is ASCII
        // and never appears inside a multi-byte UTF-8 sequence.
        let mut bytes = text.bytes().peekable();
        let mut depth = 0usize;
        let mut quote: Option<u8> = None;
        let mut prev = 0u8;
        while let Some(b) = bytes.next() {
            if let Some(open) = quote {
                if b == b'\\' {
                    // Skip the escaped character so `\"` does not end the literal.
                    bytes.next();
                } else if b == open {
                    quote = None;
                }
            } else {
                match b {
                    b'\'' | b'"' => quote = Some(b),
                    b'#' => break,
                    b'(' | b'[' | b'{' => depth += 1,
                    b')' | b']' | b'}' => depth = depth.saturating_sub(1),
                    b'=' if depth == 0 => {
                        let comparison = bytes.peek() == Some(&b'=')
                            || matches!(prev, b'=' | b'!' | b'<' | b'>');
                        if !comparison {
                            return true;
                        }
                    }
                    _ => {}
                }
            }
            prev = b;
        }
        false
    }

    const STATEMENT_PREFIXES: &[&str] = &[
        "import ",
        "from ",
        "def ",
        "class ",
        "if ",
        "for ",
        "while ",
        "try",
        "except",
        "finally",
        "with ",
        "return ",
        "raise ",
        "yield",
        "async ",
        "await ",
        "assert ",
        "del ",
        "global ",
        "nonlocal ",
        "pass",
    ];

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(':') {
        return false;
    }

    let lowered = trimmed.to_ascii_lowercase();
    if STATEMENT_PREFIXES
        .iter()
        .any(|prefix| starts_with_keyword(&lowered, prefix))
    {
        return false;
    }

    if lowered.starts_with("print(") || lowered.starts_with("print ") {
        return false;
    }

    if trimmed.starts_with('#') {
        return false;
    }

    !has_top_level_assignment(trimmed)
}