Skip to main content

run/engine/rust.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::Instant;
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{
10    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, cache_store,
11    execution_timeout, hash_source, try_cached_execution, wait_with_timeout,
12};
13
/// Language engine that compiles Rust sources with `rustc` and runs the result.
///
/// The compiler location is looked up once, when the engine is constructed.
/// `None` means no `rustc` was found; compile attempts then report an error.
pub struct RustEngine {
    /// Path to the `rustc` executable, if one was located.
    compiler: Option<PathBuf>,
}
17
18impl Default for RustEngine {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl RustEngine {
25    pub fn new() -> Self {
26        Self {
27            compiler: resolve_rustc_binary(),
28        }
29    }
30
31    fn ensure_compiler(&self) -> Result<&Path> {
32        self.compiler.as_deref().ok_or_else(|| {
33            anyhow::anyhow!(
34                "Rust support requires the `rustc` executable. Install it via Rustup and ensure it is on your PATH."
35            )
36        })
37    }
38
39    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
40        let compiler = self.ensure_compiler()?;
41        let mut cmd = Command::new(compiler);
42        cmd.arg("--color=never")
43            .arg("--edition=2021")
44            .arg("--crate-name")
45            .arg("run_snippet")
46            .arg(source)
47            .arg("-o")
48            .arg(output);
49        cmd.output()
50            .with_context(|| format!("failed to invoke rustc at {}", compiler.display()))
51    }
52
53    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
54        let mut cmd = Command::new(binary);
55        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
56        cmd.stdin(Stdio::inherit());
57        let child = cmd
58            .spawn()
59            .with_context(|| format!("failed to execute compiled binary {}", binary.display()))?;
60        wait_with_timeout(child, execution_timeout())
61    }
62
63    fn write_inline_source(&self, code: &str, dir: &Path) -> Result<PathBuf> {
64        let source_path = dir.join("main.rs");
65        std::fs::write(&source_path, code).with_context(|| {
66            format!(
67                "failed to write temporary Rust source to {}",
68                source_path.display()
69            )
70        })?;
71        Ok(source_path)
72    }
73
74    fn tmp_binary_path(dir: &Path) -> PathBuf {
75        let mut path = dir.join("run_rust_binary");
76        if let Some(ext) = std::env::consts::EXE_SUFFIX.strip_prefix('.') {
77            if !ext.is_empty() {
78                path.set_extension(ext);
79            }
80        } else if !std::env::consts::EXE_SUFFIX.is_empty() {
81            path = PathBuf::from(format!(
82                "{}{}",
83                path.display(),
84                std::env::consts::EXE_SUFFIX
85            ));
86        }
87        path
88    }
89}
90
91impl LanguageEngine for RustEngine {
92    fn id(&self) -> &'static str {
93        "rust"
94    }
95
96    fn display_name(&self) -> &'static str {
97        "Rust"
98    }
99
100    fn aliases(&self) -> &[&'static str] {
101        &["rs"]
102    }
103
104    fn supports_sessions(&self) -> bool {
105        true
106    }
107
108    fn validate(&self) -> Result<()> {
109        let compiler = self.ensure_compiler()?;
110        let mut cmd = Command::new(compiler);
111        cmd.arg("--version")
112            .stdout(Stdio::null())
113            .stderr(Stdio::null());
114        cmd.status()
115            .with_context(|| format!("failed to invoke {}", compiler.display()))?
116            .success()
117            .then_some(())
118            .ok_or_else(|| anyhow::anyhow!("{} is not executable", compiler.display()))
119    }
120
121    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
122        // Try cache for inline/stdin payloads
123        if let Some(code) = match payload {
124            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
125                Some(code.as_str())
126            }
127            _ => None,
128        } {
129            let src_hash = hash_source(code);
130            if let Some(output) = try_cached_execution(src_hash) {
131                let start = Instant::now();
132                return Ok(ExecutionOutcome {
133                    language: self.id().to_string(),
134                    exit_code: output.status.code(),
135                    stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
136                    stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
137                    duration: start.elapsed(),
138                });
139            }
140        }
141
142        let temp_dir = Builder::new()
143            .prefix("run-rust")
144            .tempdir()
145            .context("failed to create temporary directory for rust build")?;
146        let dir_path = temp_dir.path();
147
148        let (source_path, cleanup_source, cache_key): (PathBuf, bool, Option<u64>) = match payload {
149            ExecutionPayload::Inline { code } => {
150                let h = hash_source(code);
151                (self.write_inline_source(code, dir_path)?, true, Some(h))
152            }
153            ExecutionPayload::Stdin { code } => {
154                let h = hash_source(code);
155                (self.write_inline_source(code, dir_path)?, true, Some(h))
156            }
157            ExecutionPayload::File { path } => (path.clone(), false, None),
158        };
159
160        let binary_path = Self::tmp_binary_path(dir_path);
161        let start = Instant::now();
162
163        let compile_output = self.compile(&source_path, &binary_path)?;
164        if !compile_output.status.success() {
165            let stdout = String::from_utf8_lossy(&compile_output.stdout).into_owned();
166            let stderr = String::from_utf8_lossy(&compile_output.stderr).into_owned();
167            return Ok(ExecutionOutcome {
168                language: self.id().to_string(),
169                exit_code: compile_output.status.code(),
170                stdout,
171                stderr,
172                duration: start.elapsed(),
173            });
174        }
175
176        // Store in cache before running
177        if let Some(h) = cache_key {
178            cache_store(h, &binary_path);
179        }
180
181        let runtime_output = self.run_binary(&binary_path)?;
182        let outcome = ExecutionOutcome {
183            language: self.id().to_string(),
184            exit_code: runtime_output.status.code(),
185            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
186            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
187            duration: start.elapsed(),
188        };
189
190        if cleanup_source {
191            let _ = std::fs::remove_file(&source_path);
192        }
193        let _ = std::fs::remove_file(&binary_path);
194
195        Ok(outcome)
196    }
197
198    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
199        let compiler = self.ensure_compiler()?.to_path_buf();
200        let session = RustSession::new(compiler)?;
201        Ok(Box::new(session))
202    }
203}
204
205struct RustSession {
206    compiler: PathBuf,
207    workspace: TempDir,
208    items: Vec<String>,
209    statements: Vec<String>,
210    last_stdout: String,
211    last_stderr: String,
212}
213
/// Records which list of the session a snippet was appended to, so that a failed
/// evaluation can remove it from the right place.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RustSnippetKind {
    /// A top-level item rendered before `main`.
    Item,
    /// A statement or wrapped expression rendered inside `main`.
    Statement,
}
218
219impl RustSession {
220    fn new(compiler: PathBuf) -> Result<Self> {
221        let workspace = TempDir::new().context("failed to create Rust session workspace")?;
222        let session = Self {
223            compiler,
224            workspace,
225            items: Vec::new(),
226            statements: Vec::new(),
227            last_stdout: String::new(),
228            last_stderr: String::new(),
229        };
230        session.persist_source()?;
231        Ok(session)
232    }
233
234    fn language_id(&self) -> &str {
235        "rust"
236    }
237
238    fn source_path(&self) -> PathBuf {
239        self.workspace.path().join("session.rs")
240    }
241
242    fn binary_path(&self) -> PathBuf {
243        RustEngine::tmp_binary_path(self.workspace.path())
244    }
245
246    fn persist_source(&self) -> Result<()> {
247        let source = self.render_source();
248        fs::write(self.source_path(), source)
249            .with_context(|| "failed to write Rust session source".to_string())
250    }
251
252    fn render_source(&self) -> String {
253        let mut source = String::from(
254            r#"#![allow(unused_variables, unused_assignments, unused_mut, dead_code, unused_imports)]
255use std::fmt::Debug;
256
257fn __print<T: Debug>(value: T) {
258    println!("{:?}", value);
259}
260
261"#,
262        );
263
264        for item in &self.items {
265            source.push_str(item);
266            if !item.ends_with('\n') {
267                source.push('\n');
268            }
269            source.push('\n');
270        }
271
272        source.push_str("fn main() {\n");
273        if self.statements.is_empty() {
274            source.push_str("    // session body\n");
275        } else {
276            for snippet in &self.statements {
277                for line in snippet.lines() {
278                    source.push_str("    ");
279                    source.push_str(line);
280                    source.push('\n');
281                }
282            }
283        }
284        source.push_str("}\n");
285
286        source
287    }
288
289    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
290        let mut cmd = Command::new(&self.compiler);
291        cmd.arg("--color=never")
292            .arg("--edition=2021")
293            .arg("--crate-name")
294            .arg("run_snippet")
295            .arg(source)
296            .arg("-o")
297            .arg(output);
298        cmd.output()
299            .with_context(|| format!("failed to invoke rustc at {}", self.compiler.display()))
300    }
301
302    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
303        let mut cmd = Command::new(binary);
304        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
305        cmd.output().with_context(|| {
306            format!(
307                "failed to execute compiled Rust session binary {}",
308                binary.display()
309            )
310        })
311    }
312
313    fn run_standalone_program(&mut self, code: &str) -> Result<ExecutionOutcome> {
314        let start = Instant::now();
315        let source_path = self.workspace.path().join("standalone.rs");
316        fs::write(&source_path, code)
317            .with_context(|| "failed to write standalone Rust source".to_string())?;
318
319        let binary_path = self.binary_path();
320        let compile_output = self.compile(&source_path, &binary_path)?;
321        if !compile_output.status.success() {
322            let outcome = ExecutionOutcome {
323                language: self.language_id().to_string(),
324                exit_code: compile_output.status.code(),
325                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
326                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
327                duration: start.elapsed(),
328            };
329            let _ = fs::remove_file(&source_path);
330            let _ = fs::remove_file(&binary_path);
331            return Ok(outcome);
332        }
333
334        let runtime_output = self.run_binary(&binary_path)?;
335        let outcome = ExecutionOutcome {
336            language: self.language_id().to_string(),
337            exit_code: runtime_output.status.code(),
338            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
339            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
340            duration: start.elapsed(),
341        };
342
343        let _ = fs::remove_file(&source_path);
344        let _ = fs::remove_file(&binary_path);
345
346        Ok(outcome)
347    }
348
349    fn add_snippet(&mut self, code: &str) -> RustSnippetKind {
350        let trimmed = code.trim();
351        if trimmed.is_empty() {
352            return RustSnippetKind::Statement;
353        }
354
355        if is_item_snippet(trimmed) {
356            let mut snippet = code.to_string();
357            if !snippet.ends_with('\n') {
358                snippet.push('\n');
359            }
360            self.items.push(snippet);
361            RustSnippetKind::Item
362        } else {
363            let stored = if should_treat_as_expression(trimmed) {
364                wrap_expression(trimmed)
365            } else {
366                let mut snippet = code.to_string();
367                if !snippet.ends_with('\n') {
368                    snippet.push('\n');
369                }
370                snippet
371            };
372            self.statements.push(stored);
373            RustSnippetKind::Statement
374        }
375    }
376
377    fn rollback(&mut self, kind: RustSnippetKind) -> Result<()> {
378        match kind {
379            RustSnippetKind::Item => {
380                self.items.pop();
381            }
382            RustSnippetKind::Statement => {
383                self.statements.pop();
384            }
385        }
386        self.persist_source()
387    }
388
389    fn normalize_output(bytes: &[u8]) -> String {
390        String::from_utf8_lossy(bytes)
391            .replace("\r\n", "\n")
392            .replace('\r', "")
393    }
394
395    fn diff_outputs(previous: &str, current: &str) -> String {
396        if let Some(suffix) = current.strip_prefix(previous) {
397            suffix.to_string()
398        } else {
399            current.to_string()
400        }
401    }
402
403    fn run_snippet(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
404        let start = Instant::now();
405        let kind = self.add_snippet(code);
406        self.persist_source()?;
407
408        let source_path = self.source_path();
409        let binary_path = self.binary_path();
410
411        let compile_output = self.compile(&source_path, &binary_path)?;
412        if !compile_output.status.success() {
413            self.rollback(kind)?;
414            let outcome = ExecutionOutcome {
415                language: self.language_id().to_string(),
416                exit_code: compile_output.status.code(),
417                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
418                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
419                duration: start.elapsed(),
420            };
421            let _ = fs::remove_file(&binary_path);
422            return Ok((outcome, false));
423        }
424
425        let runtime_output = self.run_binary(&binary_path)?;
426        let stdout_full = Self::normalize_output(&runtime_output.stdout);
427        let stderr_full = Self::normalize_output(&runtime_output.stderr);
428
429        let stdout = Self::diff_outputs(&self.last_stdout, &stdout_full);
430        let stderr = Self::diff_outputs(&self.last_stderr, &stderr_full);
431        let success = runtime_output.status.success();
432
433        if success {
434            self.last_stdout = stdout_full;
435            self.last_stderr = stderr_full;
436        } else {
437            self.rollback(kind)?;
438        }
439
440        let outcome = ExecutionOutcome {
441            language: self.language_id().to_string(),
442            exit_code: runtime_output.status.code(),
443            stdout,
444            stderr,
445            duration: start.elapsed(),
446        };
447
448        let _ = fs::remove_file(&binary_path);
449
450        Ok((outcome, success))
451    }
452}
453
454impl LanguageSession for RustSession {
455    fn language_id(&self) -> &str {
456        RustSession::language_id(self)
457    }
458
459    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
460        let trimmed = code.trim();
461        if trimmed.is_empty() {
462            return Ok(ExecutionOutcome {
463                language: self.language_id().to_string(),
464                exit_code: None,
465                stdout: String::new(),
466                stderr: String::new(),
467                duration: Instant::now().elapsed(),
468            });
469        }
470
471        if contains_main_definition(trimmed) {
472            return self.run_standalone_program(code);
473        }
474
475        let (outcome, _) = self.run_snippet(code)?;
476        Ok(outcome)
477    }
478
479    fn shutdown(&mut self) -> Result<()> {
480        Ok(())
481    }
482}
483
484fn resolve_rustc_binary() -> Option<PathBuf> {
485    which::which("rustc").ok()
486}
487
/// Reports whether `code` begins with a top-level item (fn, struct, impl, use, ...)
/// that must be rendered outside the session's generated `main`.
///
/// Rules:
/// - Attributes (`#[...]`, `#![...]`) always count as items.
/// - A leading `pub ` or `pub(...)` visibility is skipped before the keyword is
///   inspected.
/// - Keywords must match as whole words, so identifiers such as `user`, `mode`,
///   `types` or `enumerate` are not mistaken for `use`, `mod`, `type` or `enum`.
///   Previously a bare prefix match hoisted such expressions out of `main`.
fn is_item_snippet(code: &str) -> bool {
    const ITEM_KEYWORDS: [&str; 12] = [
        "fn",
        "struct",
        "enum",
        "trait",
        "impl",
        "mod",
        "use",
        "type",
        "const",
        "static",
        "macro_rules!",
        "extern",
    ];

    let mut rest = code.trim_start();
    if rest.is_empty() {
        return false;
    }

    if rest.starts_with("#[") || rest.starts_with("#!") {
        return true;
    }

    if let Some(after_pub) = rest.strip_prefix("pub ") {
        rest = after_pub.trim_start();
    } else if rest.starts_with("pub(") {
        if let Some(close) = rest.find(')') {
            rest = rest[close + 1..].trim_start();
        }
    }

    ITEM_KEYWORDS.iter().any(|keyword| {
        rest.strip_prefix(keyword).is_some_and(|tail| {
            // `macro_rules!` already ends on a non-identifier character; the others
            // must not be followed by one, or they are just part of a longer name.
            keyword.ends_with('!')
                || !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_')
        })
    })
}
528
/// Decides whether a single-line, non-item snippet should be echoed as a value.
///
/// Returns `false` for:
/// - empty or multi-line input;
/// - anything ending in `;`;
/// - comments;
/// - lines beginning with a binding, item or control-flow keyword.
///
/// Everything else is treated as an expression whose `Debug` value is printed.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(';') {
        return false;
    }
    // A comment has no value, and wrapping it would comment out the closing `);`.
    if trimmed.starts_with("//") || trimmed.starts_with("/*") {
        return false;
    }
    // The trailing space (or `<` for generic impls) makes each entry a whole-word
    // match, so identifiers such as `implementation_count` are still echoed.
    const NOT_EXPRESSIONS: [&str; 16] = [
        "let ", "const ", "static ", "fn ", "struct ", "enum ", "impl ", "impl<", "trait ",
        "mod ", "while ", "for ", "if ", "loop ", "match ", "return ",
    ];
    !NOT_EXPRESSIONS
        .iter()
        .any(|prefix| trimmed.starts_with(prefix))
}
555
/// Wraps a single-line expression so the session program prints its `Debug` value.
///
/// The value is passed by reference. Because the session replays every stored
/// statement on each run, passing it by value would move an echoed variable
/// (e.g. `s`) and break every later snippet that uses it. `&T` prints exactly like
/// `T` under `{:?}`. The expression sits on its own line so a trailing
/// `// comment` cannot swallow the closing `));`.
fn wrap_expression(code: &str) -> String {
    format!("__print(&(\n{code}\n));\n")
}
559
/// Reports whether `code` defines a function named `main`.
///
/// Session input containing its own `main` is run as a standalone program
/// instead of being spliced into the session's generated `main`.
///
/// This is a lexical scan, not a parse:
/// - It skips line comments, nested block comments, string and byte-string
///   literals (including raw `r#"..."#` forms), and char literals.
/// - Lifetimes and labels such as `'static` do not open a literal.
/// - It then looks for the keyword `fn`, whitespace, and the identifier `main`,
///   followed by `(` or the end of input.
///
/// Qualifiers before `fn` (`pub`, `async`, ...) and preceding comment lines are
/// allowed.
fn contains_main_definition(code: &str) -> bool {
    /// Whether `b` can be part of an ASCII identifier.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    /// Given the index just past an `fn` keyword, checks for at least one
    /// whitespace byte, the identifier `main`, then `(` (after optional
    /// whitespace) or the end of input.
    fn names_main(bytes: &[u8], after_fn: usize) -> bool {
        let mut j = after_fn;
        while j < bytes.len() && bytes[j].is_ascii_whitespace() {
            j += 1;
        }
        // `fnmain` is an ordinary identifier, not `fn main`.
        if j == after_fn || !bytes[j..].starts_with(b"main") {
            return false;
        }
        let mut k = j + 4;
        if k < bytes.len() && is_ident_byte(bytes[k]) {
            return false; // `main_loop`, `mainly`, ...
        }
        while k < bytes.len() && bytes[k].is_ascii_whitespace() {
            k += 1;
        }
        k == bytes.len() || bytes[k] == b'('
    }

    /// If a raw string (`r"..."`, `r#"..."#`, ...) starts at the `r` at `r_idx`,
    /// returns the index just past its closing delimiter, or the end of input when
    /// it is unterminated. Returns `None` when this `r` does not open a raw string.
    fn raw_string_end(bytes: &[u8], r_idx: usize) -> Option<usize> {
        let mut j = r_idx + 1;
        let mut hashes = 0;
        while bytes.get(j) == Some(&b'#') {
            hashes += 1;
            j += 1;
        }
        if bytes.get(j) != Some(&b'"') {
            return None; // e.g. an ordinary identifier, or the raw identifier `r#type`
        }
        j += 1;
        while j < bytes.len() {
            if bytes[j] == b'"' {
                let close_end = j + 1 + hashes;
                if bytes
                    .get(j + 1..close_end)
                    .is_some_and(|tail| tail.iter().all(|&b| b == b'#'))
                {
                    return Some(close_end);
                }
            }
            j += 1;
        }
        Some(bytes.len())
    }

    /// Returns the index just past a char literal whose opening quote is at
    /// `q_idx`. Returns `None` when the quote introduces a lifetime or loop label
    /// (`'a`, `'static`), which has no closing quote.
    fn char_literal_end(code: &str, q_idx: usize) -> Option<usize> {
        let bytes = code.as_bytes();
        if bytes.get(q_idx + 1) == Some(&b'\\') {
            // Escaped char (`'\n'`, `'\''`, `'\u{1F600}'`): the closing quote follows
            // the escaped character, on the same line.
            let from = q_idx + 3;
            return bytes
                .get(from..)?
                .iter()
                .position(|&b| b == b'\'' || b == b'\n')
                .filter(|&offset| bytes[from + offset] == b'\'')
                .map(|offset| from + offset + 1);
        }
        // Unescaped: exactly one (possibly multi-byte) char, then a closing quote.
        let c = code.get(q_idx + 1..)?.chars().next()?;
        let close = q_idx + 1 + c.len_utf8();
        (bytes.get(close) == Some(&b'\'')).then_some(close + 1)
    }

    let bytes = code.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_line_comment = false;
    let mut block_depth = 0usize;
    let mut in_string = false;

    while i < len {
        let byte = bytes[i];
        let next = bytes.get(i + 1).copied();

        if in_line_comment {
            in_line_comment = byte != b'\n';
            i += 1;
            continue;
        }

        if in_string {
            match byte {
                b'\\' => i += 2, // skip the escaped character
                b'"' => {
                    in_string = false;
                    i += 1;
                }
                _ => i += 1,
            }
            continue;
        }

        if block_depth > 0 {
            // Rust block comments nest.
            match (byte, next) {
                (b'/', Some(b'*')) => {
                    block_depth += 1;
                    i += 2;
                }
                (b'*', Some(b'/')) => {
                    block_depth -= 1;
                    i += 2;
                }
                _ => i += 1,
            }
            continue;
        }

        // Only the byte immediately before matters: `pub fn` and `\nfn` start a
        // keyword, `xfn` does not.
        let prev = i.checked_sub(1).map(|p| bytes[p]);
        let prev_is_ident = prev.is_some_and(is_ident_byte);

        match byte {
            b'/' if next == Some(b'/') => {
                in_line_comment = true;
                i += 2;
            }
            b'/' if next == Some(b'*') => {
                block_depth = 1;
                i += 2;
            }
            b'"' => {
                in_string = true;
                i += 1;
            }
            b'\'' => i = char_literal_end(code, i).unwrap_or(i + 1),
            // `r` can only open a raw string at the start of a token, or right
            // after the `b` of a raw byte string (`br"..."`).
            b'r' if !prev_is_ident
                || (prev == Some(b'b') && (i < 2 || !is_ident_byte(bytes[i - 2]))) =>
            {
                i = raw_string_end(bytes, i).unwrap_or(i + 1);
            }
            b'f' if next == Some(b'n') && !prev_is_ident => {
                if names_main(bytes, i + 2) {
                    return true;
                }
                i += 2;
            }
            _ => i += 1,
        }
    }

    false
}