Skip to main content

run/engine/
go.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::Instant;
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
12pub struct GoEngine {
13    executable: Option<PathBuf>,
14}
15
16impl GoEngine {
17    pub fn new() -> Self {
18        Self {
19            executable: resolve_go_binary(),
20        }
21    }
22
23    fn ensure_executable(&self) -> Result<&Path> {
24        self.executable.as_deref().ok_or_else(|| {
25            anyhow::anyhow!(
26                "Go support requires the `go` executable. Install it from https://go.dev/dl/ and ensure it is on your PATH."
27            )
28        })
29    }
30
31    fn write_temp_source(&self, code: &str) -> Result<(tempfile::TempDir, PathBuf)> {
32        let dir = Builder::new()
33            .prefix("run-go")
34            .tempdir()
35            .context("failed to create temporary directory for go source")?;
36        let path = dir.path().join("main.go");
37        let mut contents = code.to_string();
38        if !contents.ends_with('\n') {
39            contents.push('\n');
40        }
41        std::fs::write(&path, contents).with_context(|| {
42            format!("failed to write temporary Go source to {}", path.display())
43        })?;
44        Ok((dir, path))
45    }
46
47    fn execute_with_path(&self, binary: &Path, source: &Path) -> Result<std::process::Output> {
48        let mut cmd = Command::new(binary);
49        cmd.arg("run")
50            .stdout(Stdio::piped())
51            .stderr(Stdio::piped())
52            .env("GO111MODULE", "off");
53        cmd.stdin(Stdio::inherit());
54
55        if let Some(parent) = source.parent() {
56            cmd.current_dir(parent);
57            if let Some(file_name) = source.file_name() {
58                cmd.arg(file_name);
59            } else {
60                cmd.arg(source);
61            }
62        } else {
63            cmd.arg(source);
64        }
65        cmd.output().with_context(|| {
66            format!(
67                "failed to invoke {} to run {}",
68                binary.display(),
69                source.display()
70            )
71        })
72    }
73}
74
75impl LanguageEngine for GoEngine {
76    fn id(&self) -> &'static str {
77        "go"
78    }
79
80    fn display_name(&self) -> &'static str {
81        "Go"
82    }
83
84    fn aliases(&self) -> &[&'static str] {
85        &["golang"]
86    }
87
88    fn supports_sessions(&self) -> bool {
89        true
90    }
91
92    fn validate(&self) -> Result<()> {
93        let binary = self.ensure_executable()?;
94        let mut cmd = Command::new(binary);
95        cmd.arg("version")
96            .stdout(Stdio::null())
97            .stderr(Stdio::null());
98        cmd.status()
99            .with_context(|| format!("failed to invoke {}", binary.display()))?
100            .success()
101            .then_some(())
102            .ok_or_else(|| anyhow::anyhow!("{} is not executable", binary.display()))
103    }
104
105    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
106        let binary = self.ensure_executable()?;
107        let start = Instant::now();
108
109        let (temp_dir, source_path) = match payload {
110            ExecutionPayload::Inline { code } => {
111                let (dir, path) = self.write_temp_source(code)?;
112                (Some(dir), path)
113            }
114            ExecutionPayload::Stdin { code } => {
115                let (dir, path) = self.write_temp_source(code)?;
116                (Some(dir), path)
117            }
118            ExecutionPayload::File { path } => (None, path.clone()),
119        };
120
121        let output = self.execute_with_path(binary, &source_path)?;
122
123        drop(temp_dir);
124
125        Ok(ExecutionOutcome {
126            language: self.id().to_string(),
127            exit_code: output.status.code(),
128            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
129            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
130            duration: start.elapsed(),
131        })
132    }
133
134    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
135        let binary = self.ensure_executable()?.to_path_buf();
136        let session = GoSession::new(binary)?;
137        Ok(Box::new(session))
138    }
139}
140
141fn resolve_go_binary() -> Option<PathBuf> {
142    which::which("go").ok()
143}
144
/// Reports whether `code` appears to reference the package bound by the import
/// spec `import` (for example `"fmt"`, `str "strings"` or `"math/rand/v2"`).
///
/// The bound name is taken as follows:
/// - an explicit alias, when one is given;
/// - otherwise the last path element;
/// - or the element before it, when the last one is a `vN` major-version
///   suffix.
///
/// The check is textual: it looks for `name.` not preceded by an identifier
/// character or `.`, so `myfmt.X` is not counted as a use of `fmt`. Blank
/// (`_`) and dot (`.`) imports bind no name and are never reported as used.
fn import_is_used_in_code(import: &str, code: &str) -> bool {
    let spec = import.trim();

    // Split `alias "path"` into its alias and the text between the quotes.
    let (alias, path) = match spec.find(|c: char| c == '"' || c == '`') {
        Some(open) => {
            let rest = &spec[open + 1..];
            let close = rest
                .find(|c: char| c == '"' || c == '`')
                .unwrap_or(rest.len());
            (spec[..open].trim(), &rest[..close])
        }
        None => ("", spec),
    };

    let name = match alias {
        "_" | "." => return false,
        "" => {
            let mut elements = path.rsplit('/');
            let last = elements.next().unwrap_or(path);
            let is_major_version = last.len() > 1
                && last.starts_with('v')
                && last[1..].bytes().all(|b| b.is_ascii_digit());
            if is_major_version {
                elements.next().unwrap_or(last)
            } else {
                last
            }
        }
        alias => alias,
    };
    if name.is_empty() {
        return false;
    }

    let pattern = format!("{}.", name);
    code.match_indices(&pattern).any(|(idx, _)| {
        !matches!(
            code[..idx].chars().next_back(),
            Some(prev) if prev == '_' || prev == '.' || prev.is_alphanumeric()
        )
    })
}
151
152const SESSION_MAIN_FILE: &str = "main.go";
153
154struct GoSession {
155    go_binary: PathBuf,
156    workspace: TempDir,
157    imports: BTreeSet<String>,
158    items: Vec<String>,
159    statements: Vec<String>,
160    last_stdout: String,
161    last_stderr: String,
162}
163
164enum GoSnippetKind {
165    Import(Option<String>),
166    Item,
167    Statement,
168}
169
170impl GoSession {
171    fn new(go_binary: PathBuf) -> Result<Self> {
172        let workspace = TempDir::new().context("failed to create Go session workspace")?;
173        let mut imports = BTreeSet::new();
174        imports.insert("\"fmt\"".to_string());
175        let session = Self {
176            go_binary,
177            workspace,
178            imports,
179            items: Vec::new(),
180            statements: Vec::new(),
181            last_stdout: String::new(),
182            last_stderr: String::new(),
183        };
184        session.persist_source()?;
185        Ok(session)
186    }
187
188    fn language_id(&self) -> &str {
189        "go"
190    }
191
192    fn source_path(&self) -> PathBuf {
193        self.workspace.path().join(SESSION_MAIN_FILE)
194    }
195
196    fn persist_source(&self) -> Result<()> {
197        let source = self.render_source();
198        fs::write(self.source_path(), source)
199            .with_context(|| "failed to write Go session source".to_string())
200    }
201
202    fn render_source(&self) -> String {
203        let mut source = String::from("package main\n\n");
204
205        if !self.imports.is_empty() {
206            source.push_str("import (\n");
207            for import in &self.imports {
208                source.push_str("\t");
209                source.push_str(import);
210                source.push('\n');
211            }
212            source.push_str(")\n\n");
213        }
214
215        source.push_str(concat!(
216            "func __print(value interface{}) {\n",
217            "\tif s, ok := value.(string); ok {\n",
218            "\t\tfmt.Println(s)\n",
219            "\t\treturn\n",
220            "\t}\n",
221            "\tfmt.Printf(\"%#v\\n\", value)\n",
222            "}\n\n",
223        ));
224
225        for item in &self.items {
226            source.push_str(item);
227            if !item.ends_with('\n') {
228                source.push('\n');
229            }
230            source.push('\n');
231        }
232
233        source.push_str("func main() {\n");
234        if self.statements.is_empty() {
235            source.push_str("\t// session body\n");
236        } else {
237            for snippet in &self.statements {
238                for line in snippet.lines() {
239                    source.push('\t');
240                    source.push_str(line);
241                    source.push('\n');
242                }
243            }
244        }
245        source.push_str("}\n");
246
247        source
248    }
249
250    fn run_program(&self) -> Result<std::process::Output> {
251        let mut cmd = Command::new(&self.go_binary);
252        cmd.arg("run")
253            .arg(SESSION_MAIN_FILE)
254            .env("GO111MODULE", "off")
255            .stdout(Stdio::piped())
256            .stderr(Stdio::piped())
257            .current_dir(self.workspace.path());
258        cmd.output().with_context(|| {
259            format!(
260                "failed to execute {} for Go session",
261                self.go_binary.display()
262            )
263        })
264    }
265
266    fn run_standalone_program(&self, code: &str) -> Result<ExecutionOutcome> {
267        let start = Instant::now();
268        let standalone_path = self.workspace.path().join("standalone.go");
269
270        let source = if has_package_declaration(code) {
271            let mut snippet = code.to_string();
272            if !snippet.ends_with('\n') {
273                snippet.push('\n');
274            }
275            snippet
276        } else {
277            let mut source = String::from("package main\n\n");
278
279            let used_imports: Vec<_> = self
280                .imports
281                .iter()
282                .filter(|import| import_is_used_in_code(import, code))
283                .cloned()
284                .collect();
285
286            if !used_imports.is_empty() {
287                source.push_str("import (\n");
288                for import in &used_imports {
289                    source.push_str("\t");
290                    source.push_str(import);
291                    source.push('\n');
292                }
293                source.push_str(")\n\n");
294            }
295
296            source.push_str(code);
297            if !code.ends_with('\n') {
298                source.push('\n');
299            }
300            source
301        };
302
303        fs::write(&standalone_path, source)
304            .with_context(|| "failed to write Go standalone source".to_string())?;
305
306        let mut cmd = Command::new(&self.go_binary);
307        cmd.arg("run")
308            .arg("standalone.go")
309            .env("GO111MODULE", "off")
310            .stdout(Stdio::piped())
311            .stderr(Stdio::piped())
312            .current_dir(self.workspace.path());
313
314        let output = cmd.output().with_context(|| {
315            format!(
316                "failed to execute {} for Go standalone program",
317                self.go_binary.display()
318            )
319        })?;
320
321        let outcome = ExecutionOutcome {
322            language: self.language_id().to_string(),
323            exit_code: output.status.code(),
324            stdout: Self::normalize_output(&output.stdout),
325            stderr: Self::normalize_output(&output.stderr),
326            duration: start.elapsed(),
327        };
328
329        let _ = fs::remove_file(&standalone_path);
330
331        Ok(outcome)
332    }
333
334    fn add_import(&mut self, spec: &str) -> GoSnippetKind {
335        let added = self.imports.insert(spec.to_string());
336        if added {
337            GoSnippetKind::Import(Some(spec.to_string()))
338        } else {
339            GoSnippetKind::Import(None)
340        }
341    }
342
343    fn add_item(&mut self, code: &str) -> GoSnippetKind {
344        let mut snippet = code.to_string();
345        if !snippet.ends_with('\n') {
346            snippet.push('\n');
347        }
348        self.items.push(snippet);
349        GoSnippetKind::Item
350    }
351
352    fn add_statement(&mut self, code: &str) -> GoSnippetKind {
353        let snippet = sanitize_statement(code);
354        self.statements.push(snippet);
355        GoSnippetKind::Statement
356    }
357
358    fn add_expression(&mut self, code: &str) -> GoSnippetKind {
359        let wrapped = wrap_expression(code);
360        self.statements.push(wrapped);
361        GoSnippetKind::Statement
362    }
363
364    fn rollback(&mut self, kind: GoSnippetKind) -> Result<()> {
365        match kind {
366            GoSnippetKind::Import(Some(spec)) => {
367                self.imports.remove(&spec);
368            }
369            GoSnippetKind::Import(None) => {}
370            GoSnippetKind::Item => {
371                self.items.pop();
372            }
373            GoSnippetKind::Statement => {
374                self.statements.pop();
375            }
376        }
377        self.persist_source()
378    }
379
380    fn normalize_output(bytes: &[u8]) -> String {
381        String::from_utf8_lossy(bytes)
382            .replace("\r\n", "\n")
383            .replace('\r', "")
384    }
385
386    fn diff_outputs(previous: &str, current: &str) -> String {
387        if let Some(suffix) = current.strip_prefix(previous) {
388            suffix.to_string()
389        } else {
390            current.to_string()
391        }
392    }
393
394    fn run_insertion(&mut self, kind: GoSnippetKind) -> Result<(ExecutionOutcome, bool)> {
395        match kind {
396            GoSnippetKind::Import(None) => Ok((
397                ExecutionOutcome {
398                    language: self.language_id().to_string(),
399                    exit_code: None,
400                    stdout: String::new(),
401                    stderr: String::new(),
402                    duration: Default::default(),
403                },
404                true,
405            )),
406            other_kind => {
407                self.persist_source()?;
408                let start = Instant::now();
409                let output = self.run_program()?;
410
411                let stdout_full = Self::normalize_output(&output.stdout);
412                let stderr_full = Self::normalize_output(&output.stderr);
413
414                let stdout = Self::diff_outputs(&self.last_stdout, &stdout_full);
415                let stderr = Self::diff_outputs(&self.last_stderr, &stderr_full);
416                let duration = start.elapsed();
417
418                if output.status.success() {
419                    self.last_stdout = stdout_full;
420                    self.last_stderr = stderr_full;
421                    let outcome = ExecutionOutcome {
422                        language: self.language_id().to_string(),
423                        exit_code: output.status.code(),
424                        stdout,
425                        stderr,
426                        duration,
427                    };
428                    return Ok((outcome, true));
429                }
430
431                if matches!(&other_kind, GoSnippetKind::Import(Some(_)))
432                    && stderr_full.contains("imported and not used")
433                {
434                    return Ok((
435                        ExecutionOutcome {
436                            language: self.language_id().to_string(),
437                            exit_code: None,
438                            stdout: String::new(),
439                            stderr: String::new(),
440                            duration,
441                        },
442                        true,
443                    ));
444                }
445
446                self.rollback(other_kind)?;
447                let outcome = ExecutionOutcome {
448                    language: self.language_id().to_string(),
449                    exit_code: output.status.code(),
450                    stdout,
451                    stderr,
452                    duration,
453                };
454                Ok((outcome, false))
455            }
456        }
457    }
458
459    fn run_import(&mut self, spec: &str) -> Result<(ExecutionOutcome, bool)> {
460        let kind = self.add_import(spec);
461        self.run_insertion(kind)
462    }
463
464    fn run_item(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
465        let kind = self.add_item(code);
466        self.run_insertion(kind)
467    }
468
469    fn run_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
470        let kind = self.add_statement(code);
471        self.run_insertion(kind)
472    }
473
474    fn run_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
475        let kind = self.add_expression(code);
476        self.run_insertion(kind)
477    }
478}
479
480impl LanguageSession for GoSession {
481    fn language_id(&self) -> &str {
482        GoSession::language_id(self)
483    }
484
485    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
486        let trimmed = code.trim();
487        if trimmed.is_empty() {
488            return Ok(ExecutionOutcome {
489                language: self.language_id().to_string(),
490                exit_code: None,
491                stdout: String::new(),
492                stderr: String::new(),
493                duration: Instant::now().elapsed(),
494            });
495        }
496
497        if trimmed.starts_with("package ") && !trimmed.contains('\n') {
498            return Ok(ExecutionOutcome {
499                language: self.language_id().to_string(),
500                exit_code: None,
501                stdout: String::new(),
502                stderr: String::new(),
503                duration: Instant::now().elapsed(),
504            });
505        }
506
507        if contains_main_definition(trimmed) {
508            let outcome = self.run_standalone_program(code)?;
509            return Ok(outcome);
510        }
511
512        if let Some(import) = parse_import_spec(trimmed) {
513            let (outcome, _) = self.run_import(&import)?;
514            return Ok(outcome);
515        }
516
517        if is_item_snippet(trimmed) {
518            let (outcome, _) = self.run_item(code)?;
519            return Ok(outcome);
520        }
521
522        if should_treat_as_expression(trimmed) {
523            let (outcome, success) = self.run_expression(trimmed)?;
524            if success {
525                return Ok(outcome);
526            }
527        }
528
529        let (outcome, _) = self.run_statement(code)?;
530        Ok(outcome)
531    }
532
533    fn shutdown(&mut self) -> Result<()> {
534        Ok(())
535    }
536}
537
/// Extracts the spec from a single-line `import <spec>` declaration
/// (e.g. `"fmt"` or `str "strings"`).
///
/// Returns `None` when the code is not an `import ` declaration, when nothing
/// follows the keyword, or for the parenthesised group form.
fn parse_import_spec(code: &str) -> Option<String> {
    let declaration = code.trim_start();
    if !declaration.starts_with("import ") {
        return None;
    }
    let spec = declaration["import".len()..].trim();
    match spec.chars().next() {
        None | Some('(') => None,
        Some(_) => Some(spec.to_owned()),
    }
}
549
/// Reports whether `code` starts with a package-level declaration keyword
/// (`type`, `const`, `var`, `func`, `package`, `import`). The keyword must be
/// followed by whitespace, `(`, or the end of input.
fn is_item_snippet(code: &str) -> bool {
    const KEYWORDS: [&str; 6] = ["type", "const", "var", "func", "package", "import"];
    let head = code.trim_start();
    KEYWORDS
        .iter()
        .filter_map(|keyword| head.strip_prefix(*keyword))
        .any(|rest| match rest.chars().next() {
            None => true,
            Some(next) => next.is_whitespace() || next == '(',
        })
}
565
/// Decides whether a snippet should first be tried as an expression (wrapped in
/// `__print(...)`).
///
/// Returns `false` for:
/// - empty or multi-line input;
/// - `;`-terminated input;
/// - short variable declarations (`:=`);
/// - assignments (`=`, `+=`, `<<=`, ...);
/// - lines starting with a statement keyword.
///
/// Comparison operators (`==`, `!=`, `<=`, `>=`) and `=` characters inside
/// string, raw-string or rune literals are not assignments. A `true` result is
/// only a first guess: the caller falls back to statement mode when the
/// wrapped expression is rejected.
fn should_treat_as_expression(code: &str) -> bool {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') || trimmed.ends_with(';') {
        return false;
    }
    if trimmed.contains(":=") || contains_assignment_operator(trimmed) {
        return false;
    }
    const RESERVED: [&str; 8] = [
        "if ", "for ", "switch ", "select ", "return ", "go ", "defer ", "var ",
    ];
    !RESERVED.iter().any(|kw| trimmed.starts_with(kw))
}

/// Reports whether `code` contains an assignment operator outside string,
/// raw-string and rune literals. `==`, `!=`, `<=` and `>=` are comparisons;
/// `<<=` and `>>=` are assignments.
fn contains_assignment_operator(code: &str) -> bool {
    let bytes = code.as_bytes();
    let mut quote: Option<u8> = None;
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if let Some(q) = quote {
            // Escapes exist in interpreted strings and rune literals, not raw strings.
            if b == b'\\' && q != b'`' {
                i += 2;
                continue;
            }
            if b == q {
                quote = None;
            }
            i += 1;
            continue;
        }
        match b {
            b'"' | b'`' | b'\'' => quote = Some(b),
            b'=' => {
                if bytes.get(i + 1) == Some(&b'=') {
                    // `==`: skip both characters.
                    i += 2;
                    continue;
                }
                let prev = if i > 0 { bytes[i - 1] } else { 0 };
                let is_comparison = match prev {
                    b'!' | b'=' => true,
                    // `<=` / `>=` compare, while `<<=` / `>>=` assign.
                    b'<' | b'>' => !(i >= 2 && bytes[i - 2] == prev),
                    _ => false,
                };
                if !is_comparison {
                    return true;
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}
591
/// Wraps an expression in a `__print(...)` call statement, newline-terminated.
fn wrap_expression(code: &str) -> String {
    let mut call = String::with_capacity(code.len() + 11);
    call.push_str("__print(");
    call.push_str(code);
    call.push_str(");\n");
    call
}
595
/// Prepares a statement snippet for insertion into the session's `main`.
///
/// The snippet always ends with a newline. For a single-line declaration, an
/// `_ = name` line is appended for every declared identifier other than `_`,
/// so Go's "declared and not used" check does not reject the program.
/// Supported declarations are:
/// - `a, b := ...`;
/// - `var ...`;
/// - `const ...`.
///
/// Nothing is appended when the left-hand side is not a plain identifier list
/// (e.g. `for i := 0; ...`, `if v := f(); ...`, or `:=` inside a string
/// literal).
fn sanitize_statement(code: &str) -> String {
    let mut snippet = code.to_string();
    if !snippet.ends_with('\n') {
        snippet.push('\n');
    }

    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains('\n') {
        return snippet;
    }

    for name in declared_identifiers(trimmed) {
        snippet.push_str("_ = ");
        snippet.push_str(name);
        snippet.push('\n');
    }
    snippet
}

/// Returns the identifiers (excluding `_`) declared by a single-line `var`,
/// `const` or short variable declaration. Returns an empty list when the
/// statement is not such a declaration or any declared name is not a valid Go
/// identifier.
fn declared_identifiers(statement: &str) -> Vec<&str> {
    let names: Vec<&str> = if let Some(rest) = statement
        .strip_prefix("var ")
        .or_else(|| statement.strip_prefix("const "))
    {
        let rest = rest.trim();
        if rest.starts_with('(') {
            return Vec::new();
        }
        let names_part = rest.split('=').next().unwrap_or(rest);
        let mut names = Vec::new();
        for segment in names_part.split(',') {
            let mut tokens = segment.split_whitespace();
            names.push(tokens.next().unwrap_or(""));
            // A second token is the declared type, which ends the name list.
            if tokens.next().is_some() {
                break;
            }
        }
        names
    } else if let Some(idx) = statement.find(":=") {
        statement[..idx].split(',').map(str::trim).collect()
    } else {
        return Vec::new();
    };

    if !names.iter().all(|name| is_go_identifier(name)) {
        return Vec::new();
    }
    names.into_iter().filter(|name| *name != "_").collect()
}

/// Reports whether `name` is a syntactically valid Go identifier.
fn is_go_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        Some(first) if first == '_' || first.is_alphabetic() => {
            chars.all(|c| c == '_' || c.is_alphanumeric())
        }
        _ => false,
    }
}
673
/// Reports whether any line of `code`, ignoring leading whitespace, begins
/// with a `package ` clause.
fn has_package_declaration(code: &str) -> bool {
    for line in code.lines() {
        if line.trim_start().starts_with("package ") {
            return true;
        }
    }
    false
}
678
/// Reports whether `code` defines `func main(`.
///
/// The scan skips the following, then looks for the keyword `func`, at least
/// one whitespace character, the identifier `main`, and an opening
/// parenthesis (optionally preceded by whitespace):
/// - `//` and `/* */` comments;
/// - interpreted (`"..."`) and raw (`` `...` ``) string literals;
/// - rune literals.
///
/// A method such as `func (r T) main()` does not match, because `(` follows
/// `func`.
fn contains_main_definition(code: &str) -> bool {
    let bytes = code.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_line_comment = false;
    let mut in_block_comment = false;
    let mut in_string = false;
    let mut string_delim = b'"';
    let mut in_char = false;

    while i < len {
        let b = bytes[i];

        if in_line_comment {
            if b == b'\n' {
                in_line_comment = false;
            }
            i += 1;
            continue;
        }

        if in_block_comment {
            if b == b'*' && i + 1 < len && bytes[i + 1] == b'/' {
                in_block_comment = false;
                i += 2;
                continue;
            }
            i += 1;
            continue;
        }

        if in_string {
            // Backslash escapes exist only in interpreted strings; inside a raw
            // string a `\` is literal and must not swallow the closing backtick.
            if b == b'\\' && string_delim == b'"' {
                i = (i + 2).min(len);
                continue;
            }
            if b == string_delim {
                in_string = false;
            }
            i += 1;
            continue;
        }

        if in_char {
            if b == b'\\' {
                i = (i + 2).min(len);
                continue;
            }
            if b == b'\'' {
                in_char = false;
            }
            i += 1;
            continue;
        }

        match b {
            b'/' if i + 1 < len && bytes[i + 1] == b'/' => {
                in_line_comment = true;
                i += 2;
                continue;
            }
            b'/' if i + 1 < len && bytes[i + 1] == b'*' => {
                in_block_comment = true;
                i += 2;
                continue;
            }
            b'"' | b'`' => {
                in_string = true;
                string_delim = b;
                i += 1;
                continue;
            }
            b'\'' => {
                in_char = true;
                i += 1;
                continue;
            }
            b'f' if i + 4 <= len && &bytes[i..i + 4] == b"func" => {
                // `func` must not be the tail of a longer identifier.
                if i > 0 {
                    let prev = bytes[i - 1];
                    if prev.is_ascii_alphanumeric() || prev == b'_' {
                        i += 1;
                        continue;
                    }
                }

                let mut j = i + 4;
                while j < len && bytes[j].is_ascii_whitespace() {
                    j += 1;
                }

                // Whitespace is required between `func` and `main`
                // (`funcmain(` is a call, not a definition).
                if j == i + 4 || j + 4 > len || &bytes[j..j + 4] != b"main" {
                    i += 1;
                    continue;
                }

                // `main` must not be the head of a longer identifier.
                let after = j + 4;
                if after < len {
                    let ch = bytes[after];
                    if ch.is_ascii_alphanumeric() || ch == b'_' {
                        i += 1;
                        continue;
                    }
                }

                let mut k = after;
                while k < len && bytes[k].is_ascii_whitespace() {
                    k += 1;
                }
                if k < len && bytes[k] == b'(' {
                    return true;
                }
            }
            _ => {}
        }

        i += 1;
    }

    false
}