Skip to main content

run/engine/
rust.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::Instant;
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{
10    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, cache_store,
11    execution_timeout, hash_source, run_version_command, try_cached_execution, wait_with_timeout,
12};
13
/// Language engine that compiles Rust sources with `rustc` and runs the
/// resulting binary.
pub struct RustEngine {
    /// Location of `rustc`, or `None` when it could not be found. Methods that
    /// need the compiler go through `ensure_compiler`, which turns `None` into
    /// an error.
    compiler: Option<std::path::PathBuf>,
}
17
18impl Default for RustEngine {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl RustEngine {
25    pub fn new() -> Self {
26        Self {
27            compiler: resolve_rustc_binary(),
28        }
29    }
30
31    fn ensure_compiler(&self) -> Result<&Path> {
32        self.compiler.as_deref().ok_or_else(|| {
33            anyhow::anyhow!(
34                "Rust support requires the `rustc` executable. Install it via Rustup and ensure it is on your PATH."
35            )
36        })
37    }
38
39    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
40        let compiler = self.ensure_compiler()?;
41        let mut cmd = Command::new(compiler);
42        cmd.arg("--color=never")
43            .arg("--edition=2021")
44            .arg("--crate-name")
45            .arg("run_snippet")
46            .arg(source)
47            .arg("-o")
48            .arg(output);
49        cmd.output()
50            .with_context(|| format!("failed to invoke rustc at {}", compiler.display()))
51    }
52
53    fn run_binary(&self, binary: &Path, args: &[String]) -> Result<std::process::Output> {
54        let mut cmd = Command::new(binary);
55        cmd.args(args).stdout(Stdio::piped()).stderr(Stdio::piped());
56        cmd.stdin(Stdio::inherit());
57        let child = cmd
58            .spawn()
59            .with_context(|| format!("failed to execute compiled binary {}", binary.display()))?;
60        wait_with_timeout(child, execution_timeout())
61    }
62
63    fn write_inline_source(&self, code: &str, dir: &Path) -> Result<PathBuf> {
64        let source_path = dir.join("main.rs");
65        std::fs::write(&source_path, code).with_context(|| {
66            format!(
67                "failed to write temporary Rust source to {}",
68                source_path.display()
69            )
70        })?;
71        Ok(source_path)
72    }
73
74    fn tmp_binary_path(dir: &Path) -> PathBuf {
75        let mut path = dir.join("run_rust_binary");
76        if let Some(ext) = std::env::consts::EXE_SUFFIX.strip_prefix('.') {
77            if !ext.is_empty() {
78                path.set_extension(ext);
79            }
80        } else if !std::env::consts::EXE_SUFFIX.is_empty() {
81            path = PathBuf::from(format!(
82                "{}{}",
83                path.display(),
84                std::env::consts::EXE_SUFFIX
85            ));
86        }
87        path
88    }
89}
90
91impl LanguageEngine for RustEngine {
92    fn id(&self) -> &'static str {
93        "rust"
94    }
95
96    fn display_name(&self) -> &'static str {
97        "Rust"
98    }
99
100    fn aliases(&self) -> &[&'static str] {
101        &["rs"]
102    }
103
104    fn supports_sessions(&self) -> bool {
105        true
106    }
107
108    fn validate(&self) -> Result<()> {
109        let compiler = self.ensure_compiler()?;
110        let mut cmd = Command::new(compiler);
111        cmd.arg("--version")
112            .stdout(Stdio::null())
113            .stderr(Stdio::null());
114        cmd.status()
115            .with_context(|| format!("failed to invoke {}", compiler.display()))?
116            .success()
117            .then_some(())
118            .ok_or_else(|| anyhow::anyhow!("{} is not executable", compiler.display()))
119    }
120
121    fn toolchain_version(&self) -> Result<Option<String>> {
122        let compiler = self.ensure_compiler()?;
123        let mut cmd = Command::new(compiler);
124        cmd.arg("--version");
125        let context = format!("{}", compiler.display());
126        run_version_command(cmd, &context)
127    }
128
129    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
130        // Try cache for inline/stdin payloads
131        let args = payload.args();
132
133        if let Some(code) = match payload {
134            ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
135                Some(code.as_str())
136            }
137            _ => None,
138        } {
139            let src_hash = hash_source(code);
140            if let Some(output) = try_cached_execution(src_hash) {
141                let start = Instant::now();
142                return Ok(ExecutionOutcome {
143                    language: self.id().to_string(),
144                    exit_code: output.status.code(),
145                    stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
146                    stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
147                    duration: start.elapsed(),
148                });
149            }
150        }
151
152        let temp_dir = Builder::new()
153            .prefix("run-rust")
154            .tempdir()
155            .context("failed to create temporary directory for rust build")?;
156        let dir_path = temp_dir.path();
157
158        let (source_path, cleanup_source, cache_key): (PathBuf, bool, Option<u64>) = match payload {
159            ExecutionPayload::Inline { code, .. } => {
160                let h = hash_source(code);
161                (self.write_inline_source(code, dir_path)?, true, Some(h))
162            }
163            ExecutionPayload::Stdin { code, .. } => {
164                let h = hash_source(code);
165                (self.write_inline_source(code, dir_path)?, true, Some(h))
166            }
167            ExecutionPayload::File { path, .. } => (path.clone(), false, None),
168        };
169
170        let binary_path = Self::tmp_binary_path(dir_path);
171        let start = Instant::now();
172
173        let compile_output = self.compile(&source_path, &binary_path)?;
174        if !compile_output.status.success() {
175            let stdout = String::from_utf8_lossy(&compile_output.stdout).into_owned();
176            let stderr = String::from_utf8_lossy(&compile_output.stderr).into_owned();
177            return Ok(ExecutionOutcome {
178                language: self.id().to_string(),
179                exit_code: compile_output.status.code(),
180                stdout,
181                stderr,
182                duration: start.elapsed(),
183            });
184        }
185
186        // Store in cache before running
187        if let Some(h) = cache_key {
188            cache_store(h, &binary_path);
189        }
190
191        let runtime_output = self.run_binary(&binary_path, args)?;
192        let outcome = ExecutionOutcome {
193            language: self.id().to_string(),
194            exit_code: runtime_output.status.code(),
195            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
196            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
197            duration: start.elapsed(),
198        };
199
200        if cleanup_source {
201            let _ = std::fs::remove_file(&source_path);
202        }
203        let _ = std::fs::remove_file(&binary_path);
204
205        Ok(outcome)
206    }
207
208    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
209        let compiler = self.ensure_compiler()?.to_path_buf();
210        let session = RustSession::new(compiler)?;
211        Ok(Box::new(session))
212    }
213}
214
215struct RustSession {
216    compiler: PathBuf,
217    workspace: TempDir,
218    items: Vec<String>,
219    statements: Vec<String>,
220    last_stdout: String,
221    last_stderr: String,
222}
223
/// Records which list `add_snippet` appended to, so `rollback` pops from the
/// same list.
///
/// The derives make the value printable in diagnostics, comparable, and cheap
/// to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RustSnippetKind {
    /// Stored among the module-level items.
    Item,
    /// Stored among the statements inside `fn main`.
    Statement,
}
228
229impl RustSession {
230    fn new(compiler: PathBuf) -> Result<Self> {
231        let workspace = TempDir::new().context("failed to create Rust session workspace")?;
232        let session = Self {
233            compiler,
234            workspace,
235            items: Vec::new(),
236            statements: Vec::new(),
237            last_stdout: String::new(),
238            last_stderr: String::new(),
239        };
240        session.persist_source()?;
241        Ok(session)
242    }
243
244    fn language_id(&self) -> &str {
245        "rust"
246    }
247
248    fn source_path(&self) -> PathBuf {
249        self.workspace.path().join("session.rs")
250    }
251
252    fn binary_path(&self) -> PathBuf {
253        RustEngine::tmp_binary_path(self.workspace.path())
254    }
255
256    fn persist_source(&self) -> Result<()> {
257        let source = self.render_source();
258        fs::write(self.source_path(), source)
259            .with_context(|| "failed to write Rust session source".to_string())
260    }
261
262    fn render_source(&self) -> String {
263        let mut source = String::from(
264            r#"#![allow(unused_variables, unused_assignments, unused_mut, dead_code, unused_imports)]
265use std::fmt::Debug;
266
267fn __print<T: Debug>(value: T) {
268    println!("{:?}", value);
269}
270
271"#,
272        );
273
274        for item in &self.items {
275            source.push_str(item);
276            if !item.ends_with('\n') {
277                source.push('\n');
278            }
279            source.push('\n');
280        }
281
282        source.push_str("fn main() {\n");
283        if self.statements.is_empty() {
284            source.push_str("    // session body\n");
285        } else {
286            for snippet in &self.statements {
287                for line in snippet.lines() {
288                    source.push_str("    ");
289                    source.push_str(line);
290                    source.push('\n');
291                }
292            }
293        }
294        source.push_str("}\n");
295
296        source
297    }
298
299    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
300        let mut cmd = Command::new(&self.compiler);
301        cmd.arg("--color=never")
302            .arg("--edition=2021")
303            .arg("--crate-name")
304            .arg("run_snippet")
305            .arg(source)
306            .arg("-o")
307            .arg(output);
308        cmd.output()
309            .with_context(|| format!("failed to invoke rustc at {}", self.compiler.display()))
310    }
311
312    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
313        let mut cmd = Command::new(binary);
314        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
315        cmd.output().with_context(|| {
316            format!(
317                "failed to execute compiled Rust session binary {}",
318                binary.display()
319            )
320        })
321    }
322
323    fn run_standalone_program(&mut self, code: &str) -> Result<ExecutionOutcome> {
324        let start = Instant::now();
325        let source_path = self.workspace.path().join("standalone.rs");
326        fs::write(&source_path, code)
327            .with_context(|| "failed to write standalone Rust source".to_string())?;
328
329        let binary_path = self.binary_path();
330        let compile_output = self.compile(&source_path, &binary_path)?;
331        if !compile_output.status.success() {
332            let outcome = ExecutionOutcome {
333                language: self.language_id().to_string(),
334                exit_code: compile_output.status.code(),
335                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
336                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
337                duration: start.elapsed(),
338            };
339            let _ = fs::remove_file(&source_path);
340            let _ = fs::remove_file(&binary_path);
341            return Ok(outcome);
342        }
343
344        let runtime_output = self.run_binary(&binary_path)?;
345        let outcome = ExecutionOutcome {
346            language: self.language_id().to_string(),
347            exit_code: runtime_output.status.code(),
348            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
349            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
350            duration: start.elapsed(),
351        };
352
353        let _ = fs::remove_file(&source_path);
354        let _ = fs::remove_file(&binary_path);
355
356        Ok(outcome)
357    }
358
359    fn add_snippet(&mut self, code: &str) -> RustSnippetKind {
360        let trimmed = code.trim();
361        if trimmed.is_empty() {
362            return RustSnippetKind::Statement;
363        }
364
365        if is_item_snippet(trimmed) {
366            let mut snippet = code.to_string();
367            if !snippet.ends_with('\n') {
368                snippet.push('\n');
369            }
370            self.items.push(snippet);
371            RustSnippetKind::Item
372        } else {
373            let stored = if should_treat_as_expression(trimmed) {
374                wrap_expression(trimmed)
375            } else {
376                let mut snippet = code.to_string();
377                if !snippet.ends_with('\n') {
378                    snippet.push('\n');
379                }
380                snippet
381            };
382            self.statements.push(stored);
383            RustSnippetKind::Statement
384        }
385    }
386
387    fn rollback(&mut self, kind: RustSnippetKind) -> Result<()> {
388        match kind {
389            RustSnippetKind::Item => {
390                self.items.pop();
391            }
392            RustSnippetKind::Statement => {
393                self.statements.pop();
394            }
395        }
396        self.persist_source()
397    }
398
399    fn normalize_output(bytes: &[u8]) -> String {
400        String::from_utf8_lossy(bytes)
401            .replace("\r\n", "\n")
402            .replace('\r', "")
403    }
404
405    fn diff_outputs(previous: &str, current: &str) -> String {
406        if let Some(suffix) = current.strip_prefix(previous) {
407            suffix.to_string()
408        } else {
409            current.to_string()
410        }
411    }
412
413    fn run_snippet(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
414        let start = Instant::now();
415        let kind = self.add_snippet(code);
416        self.persist_source()?;
417
418        let source_path = self.source_path();
419        let binary_path = self.binary_path();
420
421        let compile_output = self.compile(&source_path, &binary_path)?;
422        if !compile_output.status.success() {
423            self.rollback(kind)?;
424            let outcome = ExecutionOutcome {
425                language: self.language_id().to_string(),
426                exit_code: compile_output.status.code(),
427                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
428                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
429                duration: start.elapsed(),
430            };
431            let _ = fs::remove_file(&binary_path);
432            return Ok((outcome, false));
433        }
434
435        let runtime_output = self.run_binary(&binary_path)?;
436        let stdout_full = Self::normalize_output(&runtime_output.stdout);
437        let stderr_full = Self::normalize_output(&runtime_output.stderr);
438
439        let stdout = Self::diff_outputs(&self.last_stdout, &stdout_full);
440        let stderr = Self::diff_outputs(&self.last_stderr, &stderr_full);
441        let success = runtime_output.status.success();
442
443        if success {
444            self.last_stdout = stdout_full;
445            self.last_stderr = stderr_full;
446        } else {
447            self.rollback(kind)?;
448        }
449
450        let outcome = ExecutionOutcome {
451            language: self.language_id().to_string(),
452            exit_code: runtime_output.status.code(),
453            stdout,
454            stderr,
455            duration: start.elapsed(),
456        };
457
458        let _ = fs::remove_file(&binary_path);
459
460        Ok((outcome, success))
461    }
462}
463
464impl LanguageSession for RustSession {
465    fn language_id(&self) -> &str {
466        RustSession::language_id(self)
467    }
468
469    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
470        let trimmed = code.trim();
471        if trimmed.is_empty() {
472            return Ok(ExecutionOutcome {
473                language: self.language_id().to_string(),
474                exit_code: None,
475                stdout: String::new(),
476                stderr: String::new(),
477                duration: Instant::now().elapsed(),
478            });
479        }
480
481        if contains_main_definition(trimmed) {
482            return self.run_standalone_program(code);
483        }
484
485        let (outcome, _) = self.run_snippet(code)?;
486        Ok(outcome)
487    }
488
489    fn shutdown(&mut self) -> Result<()> {
490        Ok(())
491    }
492}
493
494fn resolve_rustc_binary() -> Option<PathBuf> {
495    which::which("rustc").ok()
496}
497
/// Returns `true` when `code` belongs at module level of the session program
/// rather than inside `fn main`.
///
/// Module-level code means an attribute (`#[…]` / `#!…`), or an item introduced
/// by one of the keywords below, optionally preceded by `pub` / `pub(…)`.
///
/// Keywords are matched as whole words. Identifiers that merely start with a
/// keyword (`user`, `modulo`, `constant`, `enumerate`, `fnord`) are not
/// mistaken for items.
fn is_item_snippet(code: &str) -> bool {
    const ITEM_KEYWORDS: [&str; 12] = [
        "fn",
        "struct",
        "enum",
        "trait",
        "impl",
        "mod",
        "use",
        "type",
        "const",
        "static",
        "macro_rules!",
        "extern",
    ];

    /// True when `text` starts with `word` and the next character cannot
    /// continue an identifier. `macro_rules!` already ends in punctuation, so
    /// anything may follow it.
    fn starts_with_word(text: &str, word: &str) -> bool {
        match text.strip_prefix(word) {
            Some(tail) => {
                word.ends_with('!')
                    || !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_')
            }
            None => false,
        }
    }

    let mut rest = code.trim_start();
    if rest.is_empty() {
        return false;
    }

    if rest.starts_with("#[") || rest.starts_with("#!") {
        return true;
    }

    // Skip a visibility qualifier so `pub fn` / `pub(crate) struct` are recognized.
    if let Some(after_pub) = rest.strip_prefix("pub ") {
        rest = after_pub.trim_start();
    } else if rest.starts_with("pub(") {
        if let Some(close) = rest.find(')') {
            rest = rest[close + 1..].trim_start();
        }
    }

    ITEM_KEYWORDS
        .iter()
        .any(|keyword| starts_with_word(rest, keyword))
}
538
/// Decides whether a non-item snippet is a single-line expression whose value
/// should be printed (via `wrap_expression`) rather than stored verbatim.
///
/// The snippet is rejected (stored verbatim) when it:
/// - is blank or spans several lines;
/// - ends in `;`;
/// - is a comment;
/// - starts with one of the keywords below.
///
/// Keywords are matched as whole words: `impl_count`, `letter`, `format!(…)`
/// and `matches!(…)` are still expressions, while `if(x)`, `loop{…}` and a bare
/// `return` are not.
fn should_treat_as_expression(code: &str) -> bool {
    const NON_EXPRESSION_KEYWORDS: [&str; 15] = [
        "let", "const", "static", "fn", "struct", "enum", "impl", "trait", "mod", "while", "for",
        "if", "loop", "match", "return",
    ];

    /// True when `text` starts with `word` and the next character cannot
    /// continue an identifier.
    fn starts_with_word(text: &str, word: &str) -> bool {
        text.strip_prefix(word).map_or(false, |tail| {
            !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_')
        })
    }

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(';') {
        return false;
    }
    // A comment is not a value. Wrapping it would comment out the closing `);`.
    if trimmed.starts_with("//") || trimmed.starts_with("/*") {
        return false;
    }
    !NON_EXPRESSION_KEYWORDS
        .iter()
        .any(|keyword| starts_with_word(trimmed, keyword))
}
565
/// Wraps a single-line expression in a call to the session prelude's `__print`
/// helper, so its `Debug` representation is printed.
///
/// The value is passed by reference (`&(…)`). Evaluating a non-`Copy`
/// variable such as a `String` therefore no longer moves it; a move would make
/// every later snippet that uses the variable fail to compile. `&T: Debug`
/// holds whenever `T: Debug`, so the output is unchanged. The parentheses keep
/// the `&` from binding to only the first operand of e.g. `a + b`.
fn wrap_expression(code: &str) -> String {
    format!("__print(&({}));\n", code)
}
569
/// Reports whether `code` defines its own `fn main`. When it does, the session
/// runs it as a standalone program instead of splicing it into the accumulated
/// session.
///
/// The scan skips line comments, nested block comments, string literals and
/// character literals, so `fn main` mentioned inside them is ignored.
///
/// A `'` starts a character literal only when it looks like one (`'x'` or
/// `'\…'`). Lifetimes and loop labels (`&'static str`, `'outer:`) therefore
/// do not swallow the rest of the input.
///
/// The `fn` keyword must stand alone: it may be preceded by modifiers
/// (`pub fn main`, `async fn main`) but not glued to an identifier (`myfn`,
/// `fnmain`). The definition must be followed by `(` or end of input.
///
/// Raw strings are scanned like ordinary strings, so backslashes or quotes
/// inside them can still confuse the scan.
fn contains_main_definition(code: &str) -> bool {
    /// Bytes that can continue an identifier. Non-ASCII bytes are treated
    /// conservatively as identifier bytes.
    fn is_ident_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_' || byte >= 0x80
    }

    /// `quote` indexes an ASCII `'`. Returns true when it opens a character
    /// literal rather than a lifetime or label.
    fn opens_char_literal(code: &str, quote: usize) -> bool {
        let mut chars = code[quote + 1..].chars();
        match chars.next() {
            Some('\\') => true,
            Some(_) => chars.next() == Some('\''),
            None => false,
        }
    }

    /// True when a standalone `fn` keyword at `start` names `main` and is
    /// followed by `(` or end of input.
    fn is_main_fn_at(bytes: &[u8], start: usize) -> bool {
        if !bytes[start..].starts_with(b"fn") {
            return false;
        }
        // Only the byte immediately before `fn` matters; whitespace before it is fine.
        if start > 0 && is_ident_byte(bytes[start - 1]) {
            return false;
        }
        let mut pos = start + 2;
        // Reject `fnord`, `fn_ptr`, `fnmain`.
        if pos < bytes.len() && is_ident_byte(bytes[pos]) {
            return false;
        }
        while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if !bytes[pos..].starts_with(b"main") {
            return false;
        }
        pos += 4;
        // Reject `mainly`, `main_loop`.
        if pos < bytes.len() && is_ident_byte(bytes[pos]) {
            return false;
        }
        while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        pos == bytes.len() || bytes[pos] == b'('
    }

    let bytes = code.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_line_comment = false;
    let mut block_depth = 0usize;
    let mut in_string = false;
    let mut in_char = false;

    while i < len {
        let byte = bytes[i];
        let rest = &bytes[i..];

        if in_line_comment {
            in_line_comment = byte != b'\n';
            i += 1;
            continue;
        }

        if in_string || in_char {
            if byte == b'\\' {
                // Skip the escaped byte so `\"` / `\'` do not end the literal.
                i = (i + 2).min(len);
                continue;
            }
            if (in_string && byte == b'"') || (in_char && byte == b'\'') {
                in_string = false;
                in_char = false;
            }
            i += 1;
            continue;
        }

        if block_depth > 0 {
            // Rust block comments nest.
            if rest.starts_with(b"/*") {
                block_depth += 1;
                i += 2;
            } else if rest.starts_with(b"*/") {
                block_depth -= 1;
                i += 2;
            } else {
                i += 1;
            }
            continue;
        }

        if rest.starts_with(b"//") {
            in_line_comment = true;
            i += 2;
            continue;
        }
        if rest.starts_with(b"/*") {
            block_depth = 1;
            i += 2;
            continue;
        }

        match byte {
            b'"' => in_string = true,
            b'\'' => in_char = opens_char_literal(code, i),
            b'f' if is_main_fn_at(bytes, i) => return true,
            _ => {}
        }
        i += 1;
    }

    false
}