run/engine/
go.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::Instant;
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// Language engine that runs Go source through a locally installed `go` executable.
pub struct GoEngine {
    /// Path of the resolved `go` executable, or `None` when it could not be located.
    executable: Option<PathBuf>,
}
15
16impl GoEngine {
17    pub fn new() -> Self {
18        Self {
19            executable: resolve_go_binary(),
20        }
21    }
22
23    fn ensure_executable(&self) -> Result<&Path> {
24        self.executable.as_deref().ok_or_else(|| {
25            anyhow::anyhow!(
26                "Go support requires the `go` executable. Install it from https://go.dev/dl/ and ensure it is on your PATH."
27            )
28        })
29    }
30
31    fn write_temp_source(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
32        let dir = Builder::new()
33            .prefix("run-go")
34            .tempdir()
35            .context("failed to create temporary directory for go source")?;
36        let path = dir.path().join("main.go");
37        let mut contents = code.to_string();
38        if !contents.ends_with('\n') {
39            contents.push('\n');
40        }
41        std::fs::write(&path, contents).with_context(|| {
42            format!("failed to write temporary Go source to {}", path.display())
43        })?;
44        Ok((dir, path))
45    }
46
47    fn execute_with_path(&self, binary: &Path, source: &Path) -> Result<std::process::Output> {
48        let mut cmd = Command::new(binary);
49        cmd.arg("run")
50            .stdout(Stdio::piped())
51            .stderr(Stdio::piped())
52            .env("GO111MODULE", "off");
53        cmd.stdin(Stdio::inherit());
54
55        if let Some(parent) = source.parent() {
56            cmd.current_dir(parent);
57            if let Some(file_name) = source.file_name() {
58                cmd.arg(file_name);
59            } else {
60                cmd.arg(source);
61            }
62        } else {
63            cmd.arg(source);
64        }
65        cmd.output().with_context(|| {
66            format!(
67                "failed to invoke {} to run {}",
68                binary.display(),
69                source.display()
70            )
71        })
72    }
73}
74
75impl LanguageEngine for GoEngine {
76    fn id(&self) -> &'static str {
77        "go"
78    }
79
80    fn display_name(&self) -> &'static str {
81        "Go"
82    }
83
84    fn aliases(&self) -> &[&'static str] {
85        &["golang"]
86    }
87
88    fn supports_sessions(&self) -> bool {
89        true
90    }
91
92    fn validate(&self) -> Result<()> {
93        let binary = self.ensure_executable()?;
94        let mut cmd = Command::new(binary);
95        cmd.arg("version")
96            .stdout(Stdio::null())
97            .stderr(Stdio::null());
98        cmd.status()
99            .with_context(|| format!("failed to invoke {}", binary.display()))?
100            .success()
101            .then_some(())
102            .ok_or_else(|| anyhow::anyhow!("{} is not executable", binary.display()))
103    }
104
105    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
106        let binary = self.ensure_executable()?;
107        let start = Instant::now();
108
109        let (temp_dir, source_path) = match payload {
110            ExecutionPayload::Inline { code } => {
111                let (dir, path) = self.write_temp_source(code)?;
112                (Some(dir), path)
113            }
114            ExecutionPayload::Stdin { code } => {
115                let (dir, path) = self.write_temp_source(code)?;
116                (Some(dir), path)
117            }
118            ExecutionPayload::File { path } => (None, path.clone()),
119        };
120
121        let output = self.execute_with_path(binary, &source_path)?;
122
123        // Ensure temp_dir stays in scope until after the command runs
124        drop(temp_dir);
125
126        Ok(ExecutionOutcome {
127            language: self.id().to_string(),
128            exit_code: output.status.code(),
129            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
130            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
131            duration: start.elapsed(),
132        })
133    }
134
135    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
136        let binary = self.ensure_executable()?.to_path_buf();
137        let session = GoSession::new(binary)?;
138        Ok(Box::new(session))
139    }
140}
141
/// Looks up the `go` executable via `which`, returning `None` when the lookup fails.
fn resolve_go_binary() -> Option<PathBuf> {
    which::which("go").ok()
}
145
/// Heuristically decides whether `code` references the package bound by `import`.
///
/// `import` is an import spec as written in Go source: a quoted import path,
/// optionally preceded by an alias (`alias "path"`), a blank identifier
/// (`_ "path"`) or a dot (`. "path"`).
///
/// * Plain imports are considered used when `code` contains `<last path element>.`.
/// * Aliased imports are considered used when `code` contains `<alias>.`.
/// * Blank imports are always kept: they exist for side effects and never
///   trigger Go's "imported and not used" error.
/// * Dot imports are always kept because their names are referenced without a
///   qualifier, which this textual check cannot detect.
///
/// The check is a plain substring search, so it can report false positives
/// (for example a match inside a string literal or a longer identifier).
fn import_is_used_in_code(import: &str, code: &str) -> bool {
    let is_quote = |ch: char| ch == '"' || ch == '`';
    let spec = import.trim();

    // Everything before the first quote character is the optional binding name.
    let (binding, quoted_path) = match spec.find(is_quote) {
        Some(idx) => (spec[..idx].trim(), &spec[idx..]),
        None => ("", spec),
    };

    match binding {
        "_" | "." => true,
        "" => {
            let path = quoted_path.trim().trim_matches(is_quote);
            let package_name = path.rsplit('/').next().unwrap_or(path);
            code.contains(&format!("{}.", package_name))
        }
        alias => code.contains(&format!("{}.", alias)),
    }
}
152
/// File name, relative to the session workspace, of the generated session program.
const SESSION_MAIN_FILE: &str = "main.go";
154
/// Interactive Go session: accumulates snippets into one generated program that
/// is rewritten and re-run with `go run` after every accepted snippet.
struct GoSession {
    /// Path of the `go` executable used for every run.
    go_binary: PathBuf,
    /// Temporary directory holding the generated source files.
    workspace: TempDir,
    /// Import specs rendered into the program's `import ( ... )` block.
    imports: BTreeSet<String>,
    /// Top-level declarations rendered before `func main`.
    items: Vec<String>,
    /// Snippets rendered, in insertion order, inside the body of `func main`.
    statements: Vec<String>,
    /// Full normalized stdout of the last successful run; used to report only new output.
    last_stdout: String,
    /// Full normalized stderr of the last successful run; used to report only new output.
    last_stderr: String,
}
164
/// Records what a snippet added to the session so the addition can be rolled back.
enum GoSnippetKind {
    /// An import request: `Some(spec)` when the spec was newly inserted,
    /// `None` when it was already present and nothing changed.
    Import(Option<String>),
    /// A top-level declaration pushed onto `items`.
    Item,
    /// A snippet pushed onto `statements`.
    Statement,
}
170
171impl GoSession {
172    fn new(go_binary: PathBuf) -> Result<Self> {
173        let workspace = TempDir::new().context("failed to create Go session workspace")?;
174        let mut imports = BTreeSet::new();
175        imports.insert("\"fmt\"".to_string());
176        let session = Self {
177            go_binary,
178            workspace,
179            imports,
180            items: Vec::new(),
181            statements: Vec::new(),
182            last_stdout: String::new(),
183            last_stderr: String::new(),
184        };
185        session.persist_source()?;
186        Ok(session)
187    }
188
189    fn language_id(&self) -> &str {
190        "go"
191    }
192
193    fn source_path(&self) -> PathBuf {
194        self.workspace.path().join(SESSION_MAIN_FILE)
195    }
196
197    fn persist_source(&self) -> Result<()> {
198        let source = self.render_source();
199        fs::write(self.source_path(), source)
200            .with_context(|| "failed to write Go session source".to_string())
201    }
202
203    fn render_source(&self) -> String {
204        let mut source = String::from("package main\n\n");
205
206        if !self.imports.is_empty() {
207            source.push_str("import (\n");
208            for import in &self.imports {
209                source.push_str("\t");
210                source.push_str(import);
211                source.push('\n');
212            }
213            source.push_str(")\n\n");
214        }
215
216        source.push_str(concat!(
217            "func __print(value interface{}) {\n",
218            "\tif s, ok := value.(string); ok {\n",
219            "\t\tfmt.Println(s)\n",
220            "\t\treturn\n",
221            "\t}\n",
222            "\tfmt.Printf(\"%#v\\n\", value)\n",
223            "}\n\n",
224        ));
225
226        for item in &self.items {
227            source.push_str(item);
228            if !item.ends_with('\n') {
229                source.push('\n');
230            }
231            source.push('\n');
232        }
233
234        source.push_str("func main() {\n");
235        if self.statements.is_empty() {
236            source.push_str("\t// session body\n");
237        } else {
238            for snippet in &self.statements {
239                for line in snippet.lines() {
240                    source.push('\t');
241                    source.push_str(line);
242                    source.push('\n');
243                }
244            }
245        }
246        source.push_str("}\n");
247
248        source
249    }
250
251    fn run_program(&self) -> Result<std::process::Output> {
252        let mut cmd = Command::new(&self.go_binary);
253        cmd.arg("run")
254            .arg(SESSION_MAIN_FILE)
255            .env("GO111MODULE", "off")
256            .stdout(Stdio::piped())
257            .stderr(Stdio::piped())
258            .current_dir(self.workspace.path());
259        cmd.output().with_context(|| {
260            format!(
261                "failed to execute {} for Go session",
262                self.go_binary.display()
263            )
264        })
265    }
266
267    fn run_standalone_program(&self, code: &str) -> Result<ExecutionOutcome> {
268        let start = Instant::now();
269        let standalone_path = self.workspace.path().join("standalone.go");
270
271        let source = if has_package_declaration(code) {
272            let mut snippet = code.to_string();
273            if !snippet.ends_with('\n') {
274                snippet.push('\n');
275            }
276            snippet
277        } else {
278            let mut source = String::from("package main\n\n");
279
280            let used_imports: Vec<_> = self
281                .imports
282                .iter()
283                .filter(|import| import_is_used_in_code(import, code))
284                .cloned()
285                .collect();
286
287            if !used_imports.is_empty() {
288                source.push_str("import (\n");
289                for import in &used_imports {
290                    source.push_str("\t");
291                    source.push_str(import);
292                    source.push('\n');
293                }
294                source.push_str(")\n\n");
295            }
296
297            source.push_str(code);
298            if !code.ends_with('\n') {
299                source.push('\n');
300            }
301            source
302        };
303
304        fs::write(&standalone_path, source)
305            .with_context(|| "failed to write Go standalone source".to_string())?;
306
307        let mut cmd = Command::new(&self.go_binary);
308        cmd.arg("run")
309            .arg("standalone.go")
310            .env("GO111MODULE", "off")
311            .stdout(Stdio::piped())
312            .stderr(Stdio::piped())
313            .current_dir(self.workspace.path());
314
315        let output = cmd.output().with_context(|| {
316            format!(
317                "failed to execute {} for Go standalone program",
318                self.go_binary.display()
319            )
320        })?;
321
322        let outcome = ExecutionOutcome {
323            language: self.language_id().to_string(),
324            exit_code: output.status.code(),
325            stdout: Self::normalize_output(&output.stdout),
326            stderr: Self::normalize_output(&output.stderr),
327            duration: start.elapsed(),
328        };
329
330        let _ = fs::remove_file(&standalone_path);
331
332        Ok(outcome)
333    }
334
335    fn add_import(&mut self, spec: &str) -> GoSnippetKind {
336        let added = self.imports.insert(spec.to_string());
337        if added {
338            GoSnippetKind::Import(Some(spec.to_string()))
339        } else {
340            GoSnippetKind::Import(None)
341        }
342    }
343
344    fn add_item(&mut self, code: &str) -> GoSnippetKind {
345        let mut snippet = code.to_string();
346        if !snippet.ends_with('\n') {
347            snippet.push('\n');
348        }
349        self.items.push(snippet);
350        GoSnippetKind::Item
351    }
352
353    fn add_statement(&mut self, code: &str) -> GoSnippetKind {
354        let snippet = sanitize_statement(code);
355        self.statements.push(snippet);
356        GoSnippetKind::Statement
357    }
358
359    fn add_expression(&mut self, code: &str) -> GoSnippetKind {
360        let wrapped = wrap_expression(code);
361        self.statements.push(wrapped);
362        GoSnippetKind::Statement
363    }
364
365    fn rollback(&mut self, kind: GoSnippetKind) -> Result<()> {
366        match kind {
367            GoSnippetKind::Import(Some(spec)) => {
368                self.imports.remove(&spec);
369            }
370            GoSnippetKind::Import(None) => {}
371            GoSnippetKind::Item => {
372                self.items.pop();
373            }
374            GoSnippetKind::Statement => {
375                self.statements.pop();
376            }
377        }
378        self.persist_source()
379    }
380
381    fn normalize_output(bytes: &[u8]) -> String {
382        String::from_utf8_lossy(bytes)
383            .replace("\r\n", "\n")
384            .replace('\r', "")
385    }
386
387    fn diff_outputs(previous: &str, current: &str) -> String {
388        if let Some(suffix) = current.strip_prefix(previous) {
389            suffix.to_string()
390        } else {
391            current.to_string()
392        }
393    }
394
395    fn run_insertion(&mut self, kind: GoSnippetKind) -> Result<(ExecutionOutcome, bool)> {
396        match kind {
397            GoSnippetKind::Import(None) => Ok((
398                ExecutionOutcome {
399                    language: self.language_id().to_string(),
400                    exit_code: None,
401                    stdout: String::new(),
402                    stderr: String::new(),
403                    duration: Default::default(),
404                },
405                true,
406            )),
407            other_kind => {
408                self.persist_source()?;
409                let start = Instant::now();
410                let output = self.run_program()?;
411
412                let stdout_full = Self::normalize_output(&output.stdout);
413                let stderr_full = Self::normalize_output(&output.stderr);
414
415                let stdout = Self::diff_outputs(&self.last_stdout, &stdout_full);
416                let stderr = Self::diff_outputs(&self.last_stderr, &stderr_full);
417                let duration = start.elapsed();
418
419                if output.status.success() {
420                    self.last_stdout = stdout_full;
421                    self.last_stderr = stderr_full;
422                    let outcome = ExecutionOutcome {
423                        language: self.language_id().to_string(),
424                        exit_code: output.status.code(),
425                        stdout,
426                        stderr,
427                        duration,
428                    };
429                    return Ok((outcome, true));
430                }
431
432                if matches!(&other_kind, GoSnippetKind::Import(Some(_)))
433                    && stderr_full.contains("imported and not used")
434                {
435                    return Ok((
436                        ExecutionOutcome {
437                            language: self.language_id().to_string(),
438                            exit_code: None,
439                            stdout: String::new(),
440                            stderr: String::new(),
441                            duration,
442                        },
443                        true,
444                    ));
445                }
446
447                self.rollback(other_kind)?;
448                let outcome = ExecutionOutcome {
449                    language: self.language_id().to_string(),
450                    exit_code: output.status.code(),
451                    stdout,
452                    stderr,
453                    duration,
454                };
455                Ok((outcome, false))
456            }
457        }
458    }
459
460    fn run_import(&mut self, spec: &str) -> Result<(ExecutionOutcome, bool)> {
461        let kind = self.add_import(spec);
462        self.run_insertion(kind)
463    }
464
465    fn run_item(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
466        let kind = self.add_item(code);
467        self.run_insertion(kind)
468    }
469
470    fn run_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
471        let kind = self.add_statement(code);
472        self.run_insertion(kind)
473    }
474
475    fn run_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
476        let kind = self.add_expression(code);
477        self.run_insertion(kind)
478    }
479}
480
481impl LanguageSession for GoSession {
482    fn language_id(&self) -> &str {
483        GoSession::language_id(self)
484    }
485
486    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
487        let trimmed = code.trim();
488        if trimmed.is_empty() {
489            return Ok(ExecutionOutcome {
490                language: self.language_id().to_string(),
491                exit_code: None,
492                stdout: String::new(),
493                stderr: String::new(),
494                duration: Instant::now().elapsed(),
495            });
496        }
497
498        if trimmed.starts_with("package ") && !trimmed.contains('\n') {
499            return Ok(ExecutionOutcome {
500                language: self.language_id().to_string(),
501                exit_code: None,
502                stdout: String::new(),
503                stderr: String::new(),
504                duration: Instant::now().elapsed(),
505            });
506        }
507
508        if contains_main_definition(trimmed) {
509            let outcome = self.run_standalone_program(code)?;
510            return Ok(outcome);
511        }
512
513        if let Some(import) = parse_import_spec(trimmed) {
514            let (outcome, _) = self.run_import(&import)?;
515            return Ok(outcome);
516        }
517
518        if is_item_snippet(trimmed) {
519            let (outcome, _) = self.run_item(code)?;
520            return Ok(outcome);
521        }
522
523        if should_treat_as_expression(trimmed) {
524            let (outcome, success) = self.run_expression(trimmed)?;
525            if success {
526                return Ok(outcome);
527            }
528        }
529
530        let (outcome, _) = self.run_statement(code)?;
531        Ok(outcome)
532    }
533
534    fn shutdown(&mut self) -> Result<()> {
535        Ok(())
536    }
537}
538
/// Extracts the spec of a single-line `import <spec>` declaration.
///
/// Returns `None` when `code` does not start with `import ` (after leading
/// whitespace), when nothing follows the keyword, or when a parenthesised
/// import group follows it.
fn parse_import_spec(code: &str) -> Option<String> {
    let spec = code.trim_start().strip_prefix("import ")?.trim();
    if spec.is_empty() || spec.starts_with('(') {
        None
    } else {
        Some(spec.to_string())
    }
}
550
/// Reports whether `code` starts with a keyword that introduces a top-level
/// declaration (`type`, `const`, `var`, `func`, `package`, `import`).
///
/// The keyword must be followed by whitespace, an opening parenthesis, or the
/// end of the input, so identifiers such as `typed` do not match.
fn is_item_snippet(code: &str) -> bool {
    let snippet = code.trim_start();
    if snippet.is_empty() {
        return false;
    }

    ["type", "const", "var", "func", "package", "import"]
        .iter()
        .filter_map(|keyword| snippet.strip_prefix(*keyword))
        .any(|rest| match rest.chars().next() {
            Some(next) => next.is_whitespace() || next == '(',
            None => true,
        })
}
566
/// Decides whether a snippet should first be tried as an expression whose value
/// is printed.
///
/// Returns `false` for empty or multi-line input, input ending in `;`, input
/// containing an assignment or short variable declaration, and input starting
/// with a statement keyword. Comparison operators (`==`, `!=`, `<=`, `>=`) are
/// not mistaken for assignments.
fn should_treat_as_expression(code: &str) -> bool {
    const RESERVED: [&str; 8] = [
        "if ", "for ", "switch ", "select ", "return ", "go ", "defer ", "var ",
    ];

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(';') {
        return false;
    }
    if contains_assignment_operator(trimmed) {
        return false;
    }
    !RESERVED.iter().any(|kw| trimmed.starts_with(kw))
}

/// Reports whether `code` contains an `=` that is not part of a comparison
/// operator, i.e. a plain, compound (`+=`, `<<=`, ...) or short (`:=`) assignment.
///
/// The scan is purely textual and does not skip string literals.
fn contains_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    bytes.iter().enumerate().any(|(idx, &byte)| {
        if byte != b'=' {
            return false;
        }
        let prev = idx.checked_sub(1).map(|p| bytes[p]);
        let before_prev = idx.checked_sub(2).map(|p| bytes[p]);
        let next = bytes.get(idx + 1).copied();

        let is_comparison = match prev {
            // Either half of `==`.
            _ if next == Some(b'=') => true,
            Some(b'=') => true,
            Some(b'!') => true,
            // `<=` / `>=`, but not the shift assignments `<<=` / `>>=`.
            Some(angle @ (b'<' | b'>')) => before_prev != Some(angle),
            _ => false,
        };
        !is_comparison
    })
}
592
/// Wraps an expression in a call to the generated `__print` helper so that its
/// value is echoed when the session program runs.
fn wrap_expression(code: &str) -> String {
    ["__print(", code, ");\n"].concat()
}
596
/// Prepares a statement for insertion into the generated `main` function.
///
/// The result always ends with a newline. For a single-line statement that
/// declares variables (`x := ...`, `var x ...`, `const x ...`), a `_ = name`
/// line is appended per declared name so that Go does not reject the program
/// with "declared and not used". Multi-line input is returned unchanged apart
/// from the trailing newline.
fn sanitize_statement(code: &str) -> String {
    let mut snippet = code.to_string();
    if !snippet.ends_with('\n') {
        snippet.push('\n');
    }

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') {
        return snippet;
    }

    for name in declared_identifiers(trimmed) {
        snippet.push_str("_ = ");
        snippet.push_str(name);
        snippet.push('\n');
    }

    snippet
}

/// Extracts the names declared by a single-line `var`/`const` declaration or a
/// short variable declaration, excluding the blank identifier.
///
/// Returns an empty list when the text left of the declaration operator is not
/// a plain comma-separated identifier list (for example `for i := 0; ...` or
/// `if v, ok := m[k]; ok { ... }`), because those names are scoped to the
/// statement and a `_ = name` line after it would not compile.
fn declared_identifiers(statement: &str) -> Vec<&str> {
    let declaration = statement
        .strip_prefix("var ")
        .or_else(|| statement.strip_prefix("const "));

    let candidates: Vec<&str> = if let Some(rest) = declaration {
        let rest = rest.trim();
        if rest.starts_with('(') {
            // Grouped declarations are not analysed.
            return Vec::new();
        }
        let names_part = rest.split('=').next().unwrap_or(rest);
        let mut names = Vec::new();
        for segment in names_part.split(',') {
            let segment = segment.trim();
            let name = segment.split_whitespace().next().unwrap_or("");
            names.push(name);
            if name != segment {
                // A type follows this name, so the identifier list ends here;
                // later commas belong to the type (e.g. `func(a, b int)`).
                break;
            }
        }
        names
    } else if let Some(idx) = statement.find(":=") {
        statement[..idx].split(',').map(str::trim).collect()
    } else {
        Vec::new()
    };

    let malformed = candidates
        .iter()
        .any(|name| !name.is_empty() && !is_go_identifier(name));
    if malformed {
        return Vec::new();
    }

    candidates
        .into_iter()
        .filter(|name| !name.is_empty() && *name != "_")
        .collect()
}

/// Reports whether `name` has the shape of a Go identifier: a letter or `_`
/// followed by letters, digits or `_`.
fn is_go_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        Some(first) if first == '_' || first.is_alphabetic() => {
            chars.all(|ch| ch == '_' || ch.is_alphanumeric())
        }
        _ => false,
    }
}
674
/// Reports whether any line of `code` begins, after leading whitespace, with a
/// `package ` clause.
fn has_package_declaration(code: &str) -> bool {
    for line in code.lines() {
        if line.trim_start().starts_with("package ") {
            return true;
        }
    }
    false
}
679
/// Reports whether `code` defines a top-level style `func main(`.
///
/// The scan skips line comments, block comments, interpreted string literals,
/// raw (backtick) string literals and rune literals, so the text `func main(`
/// inside any of them is ignored. Backslash escapes are honoured only in
/// interpreted strings and rune literals; raw strings end at the next backtick.
/// Methods such as `func (r T) main()` and names such as `mainLoop` do not match.
fn contains_main_definition(code: &str) -> bool {
    #[derive(Clone, Copy, PartialEq)]
    enum Scan {
        Code,
        LineComment,
        BlockComment,
        Interpreted,
        Raw,
        Rune,
    }

    let bytes = code.as_bytes();
    let len = bytes.len();
    let mut state = Scan::Code;
    let mut i = 0;

    while i < len {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();

        match state {
            Scan::LineComment => {
                if b == b'\n' {
                    state = Scan::Code;
                }
                i += 1;
            }
            Scan::BlockComment => {
                if b == b'*' && next == Some(b'/') {
                    state = Scan::Code;
                    i += 2;
                } else {
                    i += 1;
                }
            }
            Scan::Interpreted | Scan::Rune => {
                let closing = if state == Scan::Interpreted { b'"' } else { b'\'' };
                if b == b'\\' {
                    // Skip the escaped byte so an escaped quote does not close the literal.
                    i += 2;
                } else {
                    if b == closing {
                        state = Scan::Code;
                    }
                    i += 1;
                }
            }
            Scan::Raw => {
                // Raw strings have no escapes: only a backtick terminates them.
                if b == b'`' {
                    state = Scan::Code;
                }
                i += 1;
            }
            Scan::Code => match b {
                b'/' if next == Some(b'/') => {
                    state = Scan::LineComment;
                    i += 2;
                }
                b'/' if next == Some(b'*') => {
                    state = Scan::BlockComment;
                    i += 2;
                }
                b'"' => {
                    state = Scan::Interpreted;
                    i += 1;
                }
                b'`' => {
                    state = Scan::Raw;
                    i += 1;
                }
                b'\'' => {
                    state = Scan::Rune;
                    i += 1;
                }
                b'f' if is_main_definition_at(bytes, i) => return true,
                _ => i += 1,
            },
        }
    }

    false
}

/// Checks whether the bytes at `start` spell `func`, whitespace, `main`,
/// optional whitespace and `(`, with `func` and `main` each being whole words.
fn is_main_definition_at(bytes: &[u8], start: usize) -> bool {
    let is_ident_byte = |byte: u8| byte.is_ascii_alphanumeric() || byte == b'_';
    let skip_whitespace = |mut pos: usize| {
        while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        pos
    };

    if !bytes[start..].starts_with(b"func") {
        return false;
    }
    // `func` must not be the tail of a longer identifier (e.g. `myfunc`).
    if start > 0 && is_ident_byte(bytes[start - 1]) {
        return false;
    }

    let after_func = start + 4;
    let name_start = skip_whitespace(after_func);
    // At least one whitespace byte must separate `func` from the name.
    if name_start == after_func || !bytes[name_start..].starts_with(b"main") {
        return false;
    }

    let after_name = name_start + 4;
    // `main` must not be the head of a longer identifier (e.g. `mainLoop`).
    if bytes.get(after_name).copied().map_or(false, is_ident_byte) {
        return false;
    }

    bytes.get(skip_whitespace(after_name)) == Some(&b'(')
}