run/engine/
rust.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::Instant;
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
10
11pub struct RustEngine {
12    compiler: Option<PathBuf>,
13}
14
15impl RustEngine {
16    pub fn new() -> Self {
17        Self {
18            compiler: resolve_rustc_binary(),
19        }
20    }
21
22    fn ensure_compiler(&self) -> Result<&Path> {
23        self.compiler.as_deref().ok_or_else(|| {
24            anyhow::anyhow!(
25                "Rust support requires the `rustc` executable. Install it via Rustup and ensure it is on your PATH."
26            )
27        })
28    }
29
30    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
31        let compiler = self.ensure_compiler()?;
32        let mut cmd = Command::new(compiler);
33        cmd.arg("--color=never")
34            .arg("--edition=2021")
35            .arg("--crate-name")
36            .arg("run_snippet")
37            .arg(source)
38            .arg("-o")
39            .arg(output);
40        cmd.output()
41            .with_context(|| format!("failed to invoke rustc at {}", compiler.display()))
42    }
43
44    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
45        let mut cmd = Command::new(binary);
46        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
47        cmd.stdin(Stdio::inherit());
48        cmd.output()
49            .with_context(|| format!("failed to execute compiled binary {}", binary.display()))
50    }
51
52    fn write_inline_source(&self, code: &str, dir: &Path) -> Result<PathBuf> {
53        let source_path = dir.join("main.rs");
54        std::fs::write(&source_path, code).with_context(|| {
55            format!(
56                "failed to write temporary Rust source to {}",
57                source_path.display()
58            )
59        })?;
60        Ok(source_path)
61    }
62
63    fn tmp_binary_path(dir: &Path) -> PathBuf {
64        let mut path = dir.join("run_rust_binary");
65        if let Some(ext) = std::env::consts::EXE_SUFFIX.strip_prefix('.') {
66            if !ext.is_empty() {
67                path.set_extension(ext);
68            }
69        } else if !std::env::consts::EXE_SUFFIX.is_empty() {
70            path = PathBuf::from(format!(
71                "{}{}",
72                path.display(),
73                std::env::consts::EXE_SUFFIX
74            ));
75        }
76        path
77    }
78}
79
80impl LanguageEngine for RustEngine {
81    fn id(&self) -> &'static str {
82        "rust"
83    }
84
85    fn display_name(&self) -> &'static str {
86        "Rust"
87    }
88
89    fn aliases(&self) -> &[&'static str] {
90        &["rs"]
91    }
92
93    fn supports_sessions(&self) -> bool {
94        true
95    }
96
97    fn validate(&self) -> Result<()> {
98        let compiler = self.ensure_compiler()?;
99        let mut cmd = Command::new(compiler);
100        cmd.arg("--version")
101            .stdout(Stdio::null())
102            .stderr(Stdio::null());
103        cmd.status()
104            .with_context(|| format!("failed to invoke {}", compiler.display()))?
105            .success()
106            .then_some(())
107            .ok_or_else(|| anyhow::anyhow!("{} is not executable", compiler.display()))
108    }
109
110    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
111        let temp_dir = Builder::new()
112            .prefix("run-rust")
113            .tempdir()
114            .context("failed to create temporary directory for rust build")?;
115        let dir_path = temp_dir.path();
116
117        let (source_path, cleanup_source): (PathBuf, bool) = match payload {
118            ExecutionPayload::Inline { code } => (self.write_inline_source(code, dir_path)?, true),
119            ExecutionPayload::Stdin { code } => (self.write_inline_source(code, dir_path)?, true),
120            ExecutionPayload::File { path } => (path.clone(), false),
121        };
122
123        let binary_path = Self::tmp_binary_path(dir_path);
124        let start = Instant::now();
125
126        let compile_output = self.compile(&source_path, &binary_path)?;
127        if !compile_output.status.success() {
128            let stdout = String::from_utf8_lossy(&compile_output.stdout).into_owned();
129            let stderr = String::from_utf8_lossy(&compile_output.stderr).into_owned();
130            return Ok(ExecutionOutcome {
131                language: self.id().to_string(),
132                exit_code: compile_output.status.code(),
133                stdout,
134                stderr,
135                duration: start.elapsed(),
136            });
137        }
138
139        let runtime_output = self.run_binary(&binary_path)?;
140        let outcome = ExecutionOutcome {
141            language: self.id().to_string(),
142            exit_code: runtime_output.status.code(),
143            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
144            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
145            duration: start.elapsed(),
146        };
147
148        if cleanup_source {
149            let _ = std::fs::remove_file(&source_path);
150        }
151        let _ = std::fs::remove_file(&binary_path);
152
153        Ok(outcome)
154    }
155
156    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
157        let compiler = self.ensure_compiler()?.to_path_buf();
158        let session = RustSession::new(compiler)?;
159        Ok(Box::new(session))
160    }
161}
162
163struct RustSession {
164    compiler: PathBuf,
165    workspace: TempDir,
166    items: Vec<String>,
167    statements: Vec<String>,
168    last_stdout: String,
169    last_stderr: String,
170}
171
172enum RustSnippetKind {
173    Item,
174    Statement,
175}
176
177impl RustSession {
178    fn new(compiler: PathBuf) -> Result<Self> {
179        let workspace = TempDir::new().context("failed to create Rust session workspace")?;
180        let session = Self {
181            compiler,
182            workspace,
183            items: Vec::new(),
184            statements: Vec::new(),
185            last_stdout: String::new(),
186            last_stderr: String::new(),
187        };
188        session.persist_source()?;
189        Ok(session)
190    }
191
192    fn language_id(&self) -> &str {
193        "rust"
194    }
195
196    fn source_path(&self) -> PathBuf {
197        self.workspace.path().join("session.rs")
198    }
199
200    fn binary_path(&self) -> PathBuf {
201        RustEngine::tmp_binary_path(self.workspace.path())
202    }
203
204    fn persist_source(&self) -> Result<()> {
205        let source = self.render_source();
206        fs::write(self.source_path(), source)
207            .with_context(|| "failed to write Rust session source".to_string())
208    }
209
210    fn render_source(&self) -> String {
211        let mut source = String::from(
212            r#"#![allow(unused_variables, unused_assignments, unused_mut, dead_code, unused_imports)]
213use std::fmt::Debug;
214
215fn __print<T: Debug>(value: T) {
216    println!("{:?}", value);
217}
218
219"#,
220        );
221
222        for item in &self.items {
223            source.push_str(item);
224            if !item.ends_with('\n') {
225                source.push('\n');
226            }
227            source.push('\n');
228        }
229
230        source.push_str("fn main() {\n");
231        if self.statements.is_empty() {
232            source.push_str("    // session body\n");
233        } else {
234            for snippet in &self.statements {
235                for line in snippet.lines() {
236                    source.push_str("    ");
237                    source.push_str(line);
238                    source.push('\n');
239                }
240            }
241        }
242        source.push_str("}\n");
243
244        source
245    }
246
247    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
248        let mut cmd = Command::new(&self.compiler);
249        cmd.arg("--color=never")
250            .arg("--edition=2021")
251            .arg("--crate-name")
252            .arg("run_snippet")
253            .arg(source)
254            .arg("-o")
255            .arg(output);
256        cmd.output()
257            .with_context(|| format!("failed to invoke rustc at {}", self.compiler.display()))
258    }
259
260    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
261        let mut cmd = Command::new(binary);
262        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
263        cmd.output().with_context(|| {
264            format!(
265                "failed to execute compiled Rust session binary {}",
266                binary.display()
267            )
268        })
269    }
270
271    fn run_standalone_program(&mut self, code: &str) -> Result<ExecutionOutcome> {
272        let start = Instant::now();
273        let source_path = self.workspace.path().join("standalone.rs");
274        fs::write(&source_path, code)
275            .with_context(|| "failed to write standalone Rust source".to_string())?;
276
277        let binary_path = self.binary_path();
278        let compile_output = self.compile(&source_path, &binary_path)?;
279        if !compile_output.status.success() {
280            let outcome = ExecutionOutcome {
281                language: self.language_id().to_string(),
282                exit_code: compile_output.status.code(),
283                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
284                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
285                duration: start.elapsed(),
286            };
287            let _ = fs::remove_file(&source_path);
288            let _ = fs::remove_file(&binary_path);
289            return Ok(outcome);
290        }
291
292        let runtime_output = self.run_binary(&binary_path)?;
293        let outcome = ExecutionOutcome {
294            language: self.language_id().to_string(),
295            exit_code: runtime_output.status.code(),
296            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
297            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
298            duration: start.elapsed(),
299        };
300
301        let _ = fs::remove_file(&source_path);
302        let _ = fs::remove_file(&binary_path);
303
304        Ok(outcome)
305    }
306
307    fn add_snippet(&mut self, code: &str) -> RustSnippetKind {
308        let trimmed = code.trim();
309        if trimmed.is_empty() {
310            return RustSnippetKind::Statement;
311        }
312
313        if is_item_snippet(trimmed) {
314            let mut snippet = code.to_string();
315            if !snippet.ends_with('\n') {
316                snippet.push('\n');
317            }
318            self.items.push(snippet);
319            RustSnippetKind::Item
320        } else {
321            let stored = if should_treat_as_expression(trimmed) {
322                wrap_expression(trimmed)
323            } else {
324                let mut snippet = code.to_string();
325                if !snippet.ends_with('\n') {
326                    snippet.push('\n');
327                }
328                snippet
329            };
330            self.statements.push(stored);
331            RustSnippetKind::Statement
332        }
333    }
334
335    fn rollback(&mut self, kind: RustSnippetKind) -> Result<()> {
336        match kind {
337            RustSnippetKind::Item => {
338                self.items.pop();
339            }
340            RustSnippetKind::Statement => {
341                self.statements.pop();
342            }
343        }
344        self.persist_source()
345    }
346
347    fn normalize_output(bytes: &[u8]) -> String {
348        String::from_utf8_lossy(bytes)
349            .replace("\r\n", "\n")
350            .replace('\r', "")
351    }
352
353    fn diff_outputs(previous: &str, current: &str) -> String {
354        if let Some(suffix) = current.strip_prefix(previous) {
355            suffix.to_string()
356        } else {
357            current.to_string()
358        }
359    }
360
361    fn run_snippet(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
362        let start = Instant::now();
363        let kind = self.add_snippet(code);
364        self.persist_source()?;
365
366        let source_path = self.source_path();
367        let binary_path = self.binary_path();
368
369        let compile_output = self.compile(&source_path, &binary_path)?;
370        if !compile_output.status.success() {
371            self.rollback(kind)?;
372            let outcome = ExecutionOutcome {
373                language: self.language_id().to_string(),
374                exit_code: compile_output.status.code(),
375                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
376                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
377                duration: start.elapsed(),
378            };
379            let _ = fs::remove_file(&binary_path);
380            return Ok((outcome, false));
381        }
382
383        let runtime_output = self.run_binary(&binary_path)?;
384        let stdout_full = Self::normalize_output(&runtime_output.stdout);
385        let stderr_full = Self::normalize_output(&runtime_output.stderr);
386
387        let stdout = Self::diff_outputs(&self.last_stdout, &stdout_full);
388        let stderr = Self::diff_outputs(&self.last_stderr, &stderr_full);
389        let success = runtime_output.status.success();
390
391        if success {
392            self.last_stdout = stdout_full;
393            self.last_stderr = stderr_full;
394        } else {
395            self.rollback(kind)?;
396        }
397
398        let outcome = ExecutionOutcome {
399            language: self.language_id().to_string(),
400            exit_code: runtime_output.status.code(),
401            stdout,
402            stderr,
403            duration: start.elapsed(),
404        };
405
406        let _ = fs::remove_file(&binary_path);
407
408        Ok((outcome, success))
409    }
410}
411
412impl LanguageSession for RustSession {
413    fn language_id(&self) -> &str {
414        RustSession::language_id(self)
415    }
416
417    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
418        let trimmed = code.trim();
419        if trimmed.is_empty() {
420            return Ok(ExecutionOutcome {
421                language: self.language_id().to_string(),
422                exit_code: None,
423                stdout: String::new(),
424                stderr: String::new(),
425                duration: Instant::now().elapsed(),
426            });
427        }
428
429        if contains_main_definition(trimmed) {
430            return self.run_standalone_program(code);
431        }
432
433        let (outcome, _) = self.run_snippet(code)?;
434        Ok(outcome)
435    }
436
437    fn shutdown(&mut self) -> Result<()> {
438        Ok(())
439    }
440}
441
442fn resolve_rustc_binary() -> Option<PathBuf> {
443    which::which("rustc").ok()
444}
445
/// Returns `true` when `code` looks like a module-level item that must live outside the
/// generated `fn main`. Items include functions, types, impls, `use`/`mod`/`extern`
/// declarations, consts/statics, macro definitions and attribute-prefixed items.
///
/// An optional visibility qualifier (`pub`, `pub(crate)`, `pub(in path)`) is skipped first.
/// Keywords are matched as whole tokens, so identifiers that merely begin with a keyword
/// (`user`, `model`, `types`, `implementation`) are not mistaken for items.
fn is_item_snippet(code: &str) -> bool {
    const ITEM_KEYWORDS: [&str; 12] = [
        "fn",
        "struct",
        "enum",
        "trait",
        "impl",
        "mod",
        "use",
        "type",
        "const",
        "static",
        "macro_rules!",
        "extern",
    ];

    /// `true` when `token` is `keyword` itself, or `keyword` followed by a character that
    /// cannot continue an identifier (e.g. `impl<T>`). Keywords ending in `!` need no
    /// boundary check.
    fn starts_with_keyword(token: &str, keyword: &str) -> bool {
        match token.strip_prefix(keyword) {
            Some(tail) => {
                keyword.ends_with('!')
                    || !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_')
            }
            None => false,
        }
    }

    let mut rest = code.trim_start();
    if rest.is_empty() {
        return false;
    }

    // Attributes (`#[derive(..)]`, `#![..]`) are treated as decorating an item.
    if rest.starts_with("#[") || rest.starts_with("#!") {
        return true;
    }

    // Skip a visibility qualifier: `pub <ws>` or `pub(...)`.
    if let Some(after_pub) = rest.strip_prefix("pub") {
        if after_pub.starts_with(char::is_whitespace) {
            rest = after_pub.trim_start();
        } else if after_pub.starts_with('(') {
            if let Some(close) = after_pub.find(')') {
                rest = after_pub[close + 1..].trim_start();
            }
        }
    }

    let first_token = rest.split_whitespace().next().unwrap_or("");
    ITEM_KEYWORDS
        .iter()
        .any(|keyword| starts_with_keyword(first_token, keyword))
}
486
/// Returns `true` when `code` is a single-line expression whose value should be printed.
///
/// Multi-line input, input ending in `;`, and input starting with a declaration or
/// control-flow keyword are not printed. Keywords are matched as whole tokens, so
/// `implementation.len()` or `format!(..)` still count as expressions, while `loop{}` or
/// `if(x) {..}` do not.
fn should_treat_as_expression(code: &str) -> bool {
    const STATEMENT_KEYWORDS: [&str; 15] = [
        "let", "const", "static", "fn", "struct", "enum", "impl", "trait", "mod", "while", "for",
        "if", "loop", "match", "return",
    ];

    /// `true` when `text` begins with `keyword` not immediately followed by an identifier
    /// character.
    fn starts_with_keyword(text: &str, keyword: &str) -> bool {
        match text.strip_prefix(keyword) {
            Some(tail) => !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_'),
            None => false,
        }
    }

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(';') {
        return false;
    }

    !STATEMENT_KEYWORDS
        .iter()
        .any(|keyword| starts_with_keyword(trimmed, keyword))
}
513
/// Wraps a single-line expression in a call to the session's `__print` helper so that its
/// value is shown when the program runs.
///
/// The value is passed by reference, so evaluating a variable (e.g. `name`) does not move it
/// and later snippets can still use it. This matters because the wrapped statement is
/// replayed on every subsequent run. The closing `));` is placed on its own line so that a
/// trailing `// comment` in `code` cannot comment it out.
fn wrap_expression(code: &str) -> String {
    format!("__print(&({}\n));\n", code)
}
517
/// Returns `true` when `code` contains a `fn main` definition outside of comments, string
/// literals and character literals.
///
/// This is a lexical scan, not a parse. `fn main` is recognised when:
/// - `fn` starts a token, so qualified forms such as `pub fn main` or `async fn main` count;
/// - `fn` is followed by whitespace and the exact identifier `main`;
/// - `main` is followed (after optional whitespace) by `(` or the end of input.
///
/// Line comments, nested block comments, string literals (with backslash escapes) and char
/// literals are skipped. Lifetimes and loop labels such as `'a` are told apart from char
/// literals so that generic code does not desynchronise the scanner. Raw strings are not
/// special-cased.
fn contains_main_definition(code: &str) -> bool {
    /// Byte length of the char literal whose opening `'` is at `quote`, or `None` when that
    /// quote starts a lifetime or loop label instead.
    fn char_literal_len(code: &str, quote: usize) -> Option<usize> {
        let bytes = code.as_bytes();
        let body = quote + 1;
        if bytes.get(body) == Some(&b'\\') {
            // Escaped literal ('\n', '\'', '\u{..}'): runs to the next quote after the escape.
            let mut end = body + 2;
            while end < bytes.len() && bytes[end] != b'\'' {
                end += 1;
            }
            return Some(end + 1 - quote);
        }
        // `body` is a char boundary because the byte before it is the ASCII quote.
        let ch = code[body..].chars().next()?;
        let close = body + ch.len_utf8();
        (bytes.get(close) == Some(&b'\'')).then_some(close + 1 - quote)
    }

    /// Given the bytes right after an `fn` token, checks for `<whitespace>main` followed by
    /// `(` or end of input.
    fn names_main(rest: &[u8]) -> bool {
        let gap = rest.iter().take_while(|b| b.is_ascii_whitespace()).count();
        if gap == 0 {
            // `fnmain` is a single identifier, not `fn main`.
            return false;
        }
        let after_gap = &rest[gap..];
        if !after_gap.starts_with(b"main") {
            return false;
        }
        let tail = &after_gap[4..];
        match tail.first() {
            // e.g. `fn mainly`
            Some(&b) if b.is_ascii_alphanumeric() || b == b'_' => false,
            _ => tail
                .iter()
                .find(|b| !b.is_ascii_whitespace())
                .map_or(true, |&b| b == b'('),
        }
    }

    let bytes = code.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    while i < len {
        match bytes[i] {
            b'/' if bytes.get(i + 1) == Some(&b'/') => {
                // Line comment: resume after the newline (or stop at end of input).
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |offset| i + offset + 1);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Block comment; Rust block comments nest.
                let mut depth = 1usize;
                i += 2;
                while i < len && depth > 0 {
                    if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
                        depth += 1;
                        i += 2;
                    } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                        depth -= 1;
                        i += 2;
                    } else {
                        i += 1;
                    }
                }
            }
            b'"' => {
                // String literal: a backslash escapes the following byte (e.g. `\"`).
                i += 1;
                while i < len && bytes[i] != b'"' {
                    i += if bytes[i] == b'\\' { 2 } else { 1 };
                }
                i += 1;
            }
            b'\'' => i += char_literal_len(code, i).unwrap_or(1),
            b'f' if bytes.get(i + 1) == Some(&b'n') => {
                // Only the byte immediately before `fn` decides whether it starts a token, so
                // `pub fn main` matches while `my_fn main` does not.
                let starts_token = i == 0 || {
                    let prev = bytes[i - 1];
                    !(prev.is_ascii_alphanumeric() || prev == b'_')
                };
                if starts_token && names_main(&bytes[i + 2..]) {
                    return true;
                }
                i += 2;
            }
            _ => i += 1,
        }
    }

    false
}