Skip to main content

run/engine/
rust.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::Instant;
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{
10    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession,
11    cache_store, execution_timeout, hash_source, try_cached_execution, wait_with_timeout,
12};
13
/// Language engine that builds Rust sources with `rustc` and runs the result.
pub struct RustEngine {
    /// Location of `rustc` discovered at construction time; `None` when it was not found.
    compiler: Option<PathBuf>,
}
17
18impl RustEngine {
19    pub fn new() -> Self {
20        Self {
21            compiler: resolve_rustc_binary(),
22        }
23    }
24
25    fn ensure_compiler(&self) -> Result<&Path> {
26        self.compiler.as_deref().ok_or_else(|| {
27            anyhow::anyhow!(
28                "Rust support requires the `rustc` executable. Install it via Rustup and ensure it is on your PATH."
29            )
30        })
31    }
32
33    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
34        let compiler = self.ensure_compiler()?;
35        let mut cmd = Command::new(compiler);
36        cmd.arg("--color=never")
37            .arg("--edition=2021")
38            .arg("--crate-name")
39            .arg("run_snippet")
40            .arg(source)
41            .arg("-o")
42            .arg(output);
43        cmd.output()
44            .with_context(|| format!("failed to invoke rustc at {}", compiler.display()))
45    }
46
47    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
48        let mut cmd = Command::new(binary);
49        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
50        cmd.stdin(Stdio::inherit());
51        let child = cmd.spawn()
52            .with_context(|| format!("failed to execute compiled binary {}", binary.display()))?;
53        wait_with_timeout(child, execution_timeout())
54    }
55
56    fn write_inline_source(&self, code: &str, dir: &Path) -> Result<PathBuf> {
57        let source_path = dir.join("main.rs");
58        std::fs::write(&source_path, code).with_context(|| {
59            format!(
60                "failed to write temporary Rust source to {}",
61                source_path.display()
62            )
63        })?;
64        Ok(source_path)
65    }
66
67    fn tmp_binary_path(dir: &Path) -> PathBuf {
68        let mut path = dir.join("run_rust_binary");
69        if let Some(ext) = std::env::consts::EXE_SUFFIX.strip_prefix('.') {
70            if !ext.is_empty() {
71                path.set_extension(ext);
72            }
73        } else if !std::env::consts::EXE_SUFFIX.is_empty() {
74            path = PathBuf::from(format!(
75                "{}{}",
76                path.display(),
77                std::env::consts::EXE_SUFFIX
78            ));
79        }
80        path
81    }
82}
83
84impl LanguageEngine for RustEngine {
85    fn id(&self) -> &'static str {
86        "rust"
87    }
88
89    fn display_name(&self) -> &'static str {
90        "Rust"
91    }
92
93    fn aliases(&self) -> &[&'static str] {
94        &["rs"]
95    }
96
97    fn supports_sessions(&self) -> bool {
98        true
99    }
100
101    fn validate(&self) -> Result<()> {
102        let compiler = self.ensure_compiler()?;
103        let mut cmd = Command::new(compiler);
104        cmd.arg("--version")
105            .stdout(Stdio::null())
106            .stderr(Stdio::null());
107        cmd.status()
108            .with_context(|| format!("failed to invoke {}", compiler.display()))?
109            .success()
110            .then_some(())
111            .ok_or_else(|| anyhow::anyhow!("{} is not executable", compiler.display()))
112    }
113
114    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
115        // Try cache for inline/stdin payloads
116        if let Some(code) = match payload {
117            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => Some(code.as_str()),
118            _ => None,
119        } {
120            let src_hash = hash_source(code);
121            if let Some(output) = try_cached_execution(src_hash) {
122                let start = Instant::now();
123                return Ok(ExecutionOutcome {
124                    language: self.id().to_string(),
125                    exit_code: output.status.code(),
126                    stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
127                    stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
128                    duration: start.elapsed(),
129                });
130            }
131        }
132
133        let temp_dir = Builder::new()
134            .prefix("run-rust")
135            .tempdir()
136            .context("failed to create temporary directory for rust build")?;
137        let dir_path = temp_dir.path();
138
139        let (source_path, cleanup_source, cache_key): (PathBuf, bool, Option<u64>) = match payload {
140            ExecutionPayload::Inline { code } => {
141                let h = hash_source(code);
142                (self.write_inline_source(code, dir_path)?, true, Some(h))
143            }
144            ExecutionPayload::Stdin { code } => {
145                let h = hash_source(code);
146                (self.write_inline_source(code, dir_path)?, true, Some(h))
147            }
148            ExecutionPayload::File { path } => (path.clone(), false, None),
149        };
150
151        let binary_path = Self::tmp_binary_path(dir_path);
152        let start = Instant::now();
153
154        let compile_output = self.compile(&source_path, &binary_path)?;
155        if !compile_output.status.success() {
156            let stdout = String::from_utf8_lossy(&compile_output.stdout).into_owned();
157            let stderr = String::from_utf8_lossy(&compile_output.stderr).into_owned();
158            return Ok(ExecutionOutcome {
159                language: self.id().to_string(),
160                exit_code: compile_output.status.code(),
161                stdout,
162                stderr,
163                duration: start.elapsed(),
164            });
165        }
166
167        // Store in cache before running
168        if let Some(h) = cache_key {
169            cache_store(h, &binary_path);
170        }
171
172        let runtime_output = self.run_binary(&binary_path)?;
173        let outcome = ExecutionOutcome {
174            language: self.id().to_string(),
175            exit_code: runtime_output.status.code(),
176            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
177            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
178            duration: start.elapsed(),
179        };
180
181        if cleanup_source {
182            let _ = std::fs::remove_file(&source_path);
183        }
184        let _ = std::fs::remove_file(&binary_path);
185
186        Ok(outcome)
187    }
188
189    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
190        let compiler = self.ensure_compiler()?.to_path_buf();
191        let session = RustSession::new(compiler)?;
192        Ok(Box::new(session))
193    }
194}
195
196struct RustSession {
197    compiler: PathBuf,
198    workspace: TempDir,
199    items: Vec<String>,
200    statements: Vec<String>,
201    last_stdout: String,
202    last_stderr: String,
203}
204
/// Records which list `add_snippet` appended to, so `rollback` can pop the right one.
enum RustSnippetKind {
    /// Appended to `items` (rendered at top level, before `fn main`).
    Item,
    /// Appended to `statements` (rendered inside `fn main`).
    Statement,
}
209
210impl RustSession {
211    fn new(compiler: PathBuf) -> Result<Self> {
212        let workspace = TempDir::new().context("failed to create Rust session workspace")?;
213        let session = Self {
214            compiler,
215            workspace,
216            items: Vec::new(),
217            statements: Vec::new(),
218            last_stdout: String::new(),
219            last_stderr: String::new(),
220        };
221        session.persist_source()?;
222        Ok(session)
223    }
224
225    fn language_id(&self) -> &str {
226        "rust"
227    }
228
229    fn source_path(&self) -> PathBuf {
230        self.workspace.path().join("session.rs")
231    }
232
233    fn binary_path(&self) -> PathBuf {
234        RustEngine::tmp_binary_path(self.workspace.path())
235    }
236
237    fn persist_source(&self) -> Result<()> {
238        let source = self.render_source();
239        fs::write(self.source_path(), source)
240            .with_context(|| "failed to write Rust session source".to_string())
241    }
242
243    fn render_source(&self) -> String {
244        let mut source = String::from(
245            r#"#![allow(unused_variables, unused_assignments, unused_mut, dead_code, unused_imports)]
246use std::fmt::Debug;
247
248fn __print<T: Debug>(value: T) {
249    println!("{:?}", value);
250}
251
252"#,
253        );
254
255        for item in &self.items {
256            source.push_str(item);
257            if !item.ends_with('\n') {
258                source.push('\n');
259            }
260            source.push('\n');
261        }
262
263        source.push_str("fn main() {\n");
264        if self.statements.is_empty() {
265            source.push_str("    // session body\n");
266        } else {
267            for snippet in &self.statements {
268                for line in snippet.lines() {
269                    source.push_str("    ");
270                    source.push_str(line);
271                    source.push('\n');
272                }
273            }
274        }
275        source.push_str("}\n");
276
277        source
278    }
279
280    fn compile(&self, source: &Path, output: &Path) -> Result<std::process::Output> {
281        let mut cmd = Command::new(&self.compiler);
282        cmd.arg("--color=never")
283            .arg("--edition=2021")
284            .arg("--crate-name")
285            .arg("run_snippet")
286            .arg(source)
287            .arg("-o")
288            .arg(output);
289        cmd.output()
290            .with_context(|| format!("failed to invoke rustc at {}", self.compiler.display()))
291    }
292
293    fn run_binary(&self, binary: &Path) -> Result<std::process::Output> {
294        let mut cmd = Command::new(binary);
295        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
296        cmd.output().with_context(|| {
297            format!(
298                "failed to execute compiled Rust session binary {}",
299                binary.display()
300            )
301        })
302    }
303
304    fn run_standalone_program(&mut self, code: &str) -> Result<ExecutionOutcome> {
305        let start = Instant::now();
306        let source_path = self.workspace.path().join("standalone.rs");
307        fs::write(&source_path, code)
308            .with_context(|| "failed to write standalone Rust source".to_string())?;
309
310        let binary_path = self.binary_path();
311        let compile_output = self.compile(&source_path, &binary_path)?;
312        if !compile_output.status.success() {
313            let outcome = ExecutionOutcome {
314                language: self.language_id().to_string(),
315                exit_code: compile_output.status.code(),
316                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
317                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
318                duration: start.elapsed(),
319            };
320            let _ = fs::remove_file(&source_path);
321            let _ = fs::remove_file(&binary_path);
322            return Ok(outcome);
323        }
324
325        let runtime_output = self.run_binary(&binary_path)?;
326        let outcome = ExecutionOutcome {
327            language: self.language_id().to_string(),
328            exit_code: runtime_output.status.code(),
329            stdout: String::from_utf8_lossy(&runtime_output.stdout).into_owned(),
330            stderr: String::from_utf8_lossy(&runtime_output.stderr).into_owned(),
331            duration: start.elapsed(),
332        };
333
334        let _ = fs::remove_file(&source_path);
335        let _ = fs::remove_file(&binary_path);
336
337        Ok(outcome)
338    }
339
340    fn add_snippet(&mut self, code: &str) -> RustSnippetKind {
341        let trimmed = code.trim();
342        if trimmed.is_empty() {
343            return RustSnippetKind::Statement;
344        }
345
346        if is_item_snippet(trimmed) {
347            let mut snippet = code.to_string();
348            if !snippet.ends_with('\n') {
349                snippet.push('\n');
350            }
351            self.items.push(snippet);
352            RustSnippetKind::Item
353        } else {
354            let stored = if should_treat_as_expression(trimmed) {
355                wrap_expression(trimmed)
356            } else {
357                let mut snippet = code.to_string();
358                if !snippet.ends_with('\n') {
359                    snippet.push('\n');
360                }
361                snippet
362            };
363            self.statements.push(stored);
364            RustSnippetKind::Statement
365        }
366    }
367
368    fn rollback(&mut self, kind: RustSnippetKind) -> Result<()> {
369        match kind {
370            RustSnippetKind::Item => {
371                self.items.pop();
372            }
373            RustSnippetKind::Statement => {
374                self.statements.pop();
375            }
376        }
377        self.persist_source()
378    }
379
380    fn normalize_output(bytes: &[u8]) -> String {
381        String::from_utf8_lossy(bytes)
382            .replace("\r\n", "\n")
383            .replace('\r', "")
384    }
385
386    fn diff_outputs(previous: &str, current: &str) -> String {
387        if let Some(suffix) = current.strip_prefix(previous) {
388            suffix.to_string()
389        } else {
390            current.to_string()
391        }
392    }
393
394    fn run_snippet(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
395        let start = Instant::now();
396        let kind = self.add_snippet(code);
397        self.persist_source()?;
398
399        let source_path = self.source_path();
400        let binary_path = self.binary_path();
401
402        let compile_output = self.compile(&source_path, &binary_path)?;
403        if !compile_output.status.success() {
404            self.rollback(kind)?;
405            let outcome = ExecutionOutcome {
406                language: self.language_id().to_string(),
407                exit_code: compile_output.status.code(),
408                stdout: String::from_utf8_lossy(&compile_output.stdout).into_owned(),
409                stderr: String::from_utf8_lossy(&compile_output.stderr).into_owned(),
410                duration: start.elapsed(),
411            };
412            let _ = fs::remove_file(&binary_path);
413            return Ok((outcome, false));
414        }
415
416        let runtime_output = self.run_binary(&binary_path)?;
417        let stdout_full = Self::normalize_output(&runtime_output.stdout);
418        let stderr_full = Self::normalize_output(&runtime_output.stderr);
419
420        let stdout = Self::diff_outputs(&self.last_stdout, &stdout_full);
421        let stderr = Self::diff_outputs(&self.last_stderr, &stderr_full);
422        let success = runtime_output.status.success();
423
424        if success {
425            self.last_stdout = stdout_full;
426            self.last_stderr = stderr_full;
427        } else {
428            self.rollback(kind)?;
429        }
430
431        let outcome = ExecutionOutcome {
432            language: self.language_id().to_string(),
433            exit_code: runtime_output.status.code(),
434            stdout,
435            stderr,
436            duration: start.elapsed(),
437        };
438
439        let _ = fs::remove_file(&binary_path);
440
441        Ok((outcome, success))
442    }
443}
444
445impl LanguageSession for RustSession {
446    fn language_id(&self) -> &str {
447        RustSession::language_id(self)
448    }
449
450    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
451        let trimmed = code.trim();
452        if trimmed.is_empty() {
453            return Ok(ExecutionOutcome {
454                language: self.language_id().to_string(),
455                exit_code: None,
456                stdout: String::new(),
457                stderr: String::new(),
458                duration: Instant::now().elapsed(),
459            });
460        }
461
462        if contains_main_definition(trimmed) {
463            return self.run_standalone_program(code);
464        }
465
466        let (outcome, _) = self.run_snippet(code)?;
467        Ok(outcome)
468    }
469
470    fn shutdown(&mut self) -> Result<()> {
471        Ok(())
472    }
473}
474
475fn resolve_rustc_binary() -> Option<PathBuf> {
476    which::which("rustc").ok()
477}
478
/// Returns `true` when a snippet should be placed at top level (before `fn main`)
/// rather than inside it.
///
/// A snippet counts as an item if it starts with an attribute (`#[` / `#!`),
/// with `macro_rules!`, or with one of the item keywords below. An optional
/// `pub` / `pub(...)` visibility is allowed first. The keyword must be the
/// *whole* leading identifier. The previous prefix test misclassified
/// expressions like `constants.len()`, `users.push(1)` or `types`, which then
/// failed to compile outside `main`.
fn is_item_snippet(code: &str) -> bool {
    const ITEM_KEYWORDS: [&str; 11] = [
        "fn", "struct", "enum", "trait", "impl", "mod", "use", "type", "const", "static",
        "extern",
    ];

    let mut rest = code.trim_start();
    if rest.is_empty() {
        return false;
    }

    if rest.starts_with("#[") || rest.starts_with("#!") {
        return true;
    }

    // Strip a visibility qualifier: `pub`, `pub(crate)`, `pub(in some::path)`.
    if let Some(after_pub) = rest.strip_prefix("pub") {
        if after_pub.starts_with('(') {
            if let Some(idx) = after_pub.find(')') {
                rest = after_pub[idx + 1..].trim_start();
            }
        } else if after_pub.starts_with(char::is_whitespace) {
            rest = after_pub.trim_start();
        }
    }

    if rest.starts_with("macro_rules!") {
        return true;
    }

    // Compare the complete leading identifier, so `impl<T>` matches `impl`
    // but `constants` does not match `const`.
    let ident_len = rest
        .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .unwrap_or(rest.len());
    let leading_ident = &rest[..ident_len];
    ITEM_KEYWORDS.contains(&leading_ident)
}
519
/// Decides whether a non-item snippet is a bare expression whose value should be echoed.
///
/// The snippet qualifies only if it is a single non-empty line that is not
/// terminated by `;` and does not begin with a binding, declaration, loop,
/// control-flow or `return` keyword.
fn should_treat_as_expression(code: &str) -> bool {
    const NON_EXPRESSION_PREFIXES: [&str; 15] = [
        "let ", "const ", "static ", "fn ", "struct ", "enum ", "impl", "trait ", "mod ",
        "while ", "for ", "if ", "loop ", "match ", "return ",
    ];

    let candidate = code.trim();
    let single_line = !candidate.is_empty() && !candidate.contains('\n');
    let unterminated = !candidate.ends_with(';');
    let keyword_led = NON_EXPRESSION_PREFIXES
        .iter()
        .any(|prefix| candidate.starts_with(prefix));

    single_line && unterminated && !keyword_led
}
546
/// Wraps a bare expression snippet in a `__print` call so its value is echoed with `{:?}`.
///
/// The value is passed by reference. The session program replays every stored
/// statement on each eval, so printing by value moved non-`Copy` bindings (e.g.
/// echoing `v` for a `Vec`). Every later snippet using that binding then failed
/// with "use of moved value". `Debug` for `&T` delegates to `T`, so the printed
/// text is unchanged.
fn wrap_expression(code: &str) -> String {
    format!("__print(&({}));\n", code)
}
550
/// Returns `true` when `code` contains a `fn main` definition. The session then runs
/// the input as a standalone program instead of splicing it into its own `main`.
///
/// This is a lexical scan, not a parse. Text inside line comments, nested block
/// comments, string / byte-string / raw-string literals and char literals is ignored.
/// Char literals are distinguished from lifetimes and loop labels (`'a`, `'static`),
/// which have no closing quote.
///
/// A match is the word `fn`, at least one whitespace character, the word `main`, and
/// then `(` after optional whitespace (or end of input). Qualifiers before `fn`
/// (`pub`, `async`, `unsafe`, ...) and comments on preceding lines no longer prevent
/// a match. The old backward scan skipped whitespace and rejected them, and any
/// lifetime sent it into char-literal mode until end of input.
fn contains_main_definition(code: &str) -> bool {
    let bytes = code.as_bytes();
    let mut i = 0;

    while i < bytes.len() {
        let starts_token = i == 0 || !is_ident_byte(bytes[i - 1]);
        match bytes[i] {
            b'/' if bytes.get(i + 1) == Some(&b'/') => i = skip_line_comment(bytes, i + 2),
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i + 2),
            b'"' => i = skip_quoted(bytes, i + 1, b'"'),
            b'\'' => i = skip_char_or_lifetime(code, i),
            b'r' if can_start_raw_string(bytes, i) => {
                i = raw_string_end(bytes, i).unwrap_or(i + 1);
            }
            b'f' if starts_token && is_fn_main_at(bytes, i) => return true,
            _ => i += 1,
        }
    }

    false
}

/// Whether `b` can be part of an (ASCII) identifier.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// Index just past the newline ending a line comment whose body starts at `i`.
fn skip_line_comment(bytes: &[u8], i: usize) -> usize {
    bytes[i..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |offset| i + offset + 1)
}

/// Index just past the `*/` closing a (possibly nested) block comment whose body starts at `i`.
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let mut depth = 1usize;
    while i < bytes.len() {
        if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
            depth += 1;
            i += 2;
        } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
            depth -= 1;
            i += 2;
            if depth == 0 {
                return i;
            }
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Index just past the closing `quote` of an escaped literal whose body starts at `i`.
fn skip_quoted(bytes: &[u8], mut i: usize, quote: u8) -> usize {
    while i < bytes.len() {
        match bytes[i] {
            b'\\' => i += 2,
            b if b == quote => return i + 1,
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Given the `'` at byte `i`, skips a char literal (`'x'`, `'\n'`, `'é'`), or only the
/// tick of a lifetime / loop label (`'a`, `'static`), which has no closing quote.
fn skip_char_or_lifetime(code: &str, i: usize) -> usize {
    let bytes = code.as_bytes();
    let start = i + 1; // a char boundary: bytes[i] is the ASCII `'`
    if bytes.get(start) == Some(&b'\\') {
        return skip_quoted(bytes, start, b'\'');
    }
    match code[start..].chars().next() {
        Some(ch) if bytes.get(start + ch.len_utf8()) == Some(&b'\'') => {
            start + ch.len_utf8() + 1
        }
        _ => start,
    }
}

/// Whether the `r` at `bytes[i]` may begin a raw string. It must start a token, or
/// follow a `b`/`c` prefix that itself starts a token (`br"…"`, `cr"…"`).
fn can_start_raw_string(bytes: &[u8], i: usize) -> bool {
    match i.checked_sub(1).map(|prev| bytes[prev]) {
        None => true,
        Some(b'b') | Some(b'c') => i < 2 || !is_ident_byte(bytes[i - 2]),
        Some(prev) => !is_ident_byte(prev),
    }
}

/// If a raw string (`r"…"`, `r#"…"#`, ...) starts at the `r` at `bytes[i]`, returns
/// the index just past it (or `bytes.len()` if unterminated); otherwise `None`.
fn raw_string_end(bytes: &[u8], i: usize) -> Option<usize> {
    let hashes_start = i + 1;
    let mut quote = hashes_start;
    while bytes.get(quote) == Some(&b'#') {
        quote += 1;
    }
    if bytes.get(quote) != Some(&b'"') {
        return None;
    }
    let closing_hashes = &bytes[hashes_start..quote];
    let mut k = quote + 1;
    while k < bytes.len() {
        if bytes[k] == b'"' && bytes[k + 1..].starts_with(closing_hashes) {
            return Some(k + 1 + closing_hashes.len());
        }
        k += 1;
    }
    Some(bytes.len())
}

/// Index of the first non-ASCII-whitespace byte at or after `i`.
fn skip_ascii_whitespace(bytes: &[u8], mut i: usize) -> usize {
    while i < bytes.len() && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    i
}

/// Whether `bytes[i..]` reads `fn`, whitespace, `main` (as a whole word), then `(`
/// after optional whitespace, or end of input. The caller checks the boundary before `fn`.
fn is_fn_main_at(bytes: &[u8], i: usize) -> bool {
    if !bytes[i..].starts_with(b"fn") {
        return false;
    }
    let name_start = skip_ascii_whitespace(bytes, i + 2);
    if name_start == i + 2 || !bytes[name_start..].starts_with(b"main") {
        return false;
    }
    let name_end = name_start + 4;
    if matches!(bytes.get(name_end), Some(&b) if is_ident_byte(b)) {
        return false;
    }
    let after = skip_ascii_whitespace(bytes, name_end);
    after == bytes.len() || bytes[after] == b'('
}