Skip to main content

run/engine/
zig.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
10
/// Language engine that runs Zig code by invoking an external `zig` binary.
pub struct ZigEngine {
    /// Path of the `zig` binary found when the engine was constructed, or
    /// `None` when the lookup failed.
    executable: Option<PathBuf>,
}
14
15impl ZigEngine {
16    pub fn new() -> Self {
17        Self {
18            executable: resolve_zig_binary(),
19        }
20    }
21
22    fn ensure_executable(&self) -> Result<&Path> {
23        self.executable.as_deref().ok_or_else(|| {
24            anyhow::anyhow!(
25                "Zig support requires the `zig` executable. Install it from https://ziglang.org/download/ and ensure it is on your PATH."
26            )
27        })
28    }
29
30    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
31        let dir = Builder::new()
32            .prefix("run-zig")
33            .tempdir()
34            .context("failed to create temporary directory for Zig source")?;
35        let path = dir.path().join("snippet.zig");
36        let mut contents = code.to_string();
37        if !contents.ends_with('\n') {
38            contents.push('\n');
39        }
40        std::fs::write(&path, contents).with_context(|| {
41            format!("failed to write temporary Zig source to {}", path.display())
42        })?;
43        Ok((dir, path))
44    }
45
46    fn run_source(&self, source: &Path) -> Result<std::process::Output> {
47        let executable = self.ensure_executable()?;
48        let mut cmd = Command::new(executable);
49        cmd.arg("run")
50            .arg(source)
51            .stdout(Stdio::piped())
52            .stderr(Stdio::piped());
53        cmd.stdin(Stdio::inherit());
54        if let Some(dir) = source.parent() {
55            cmd.current_dir(dir);
56        }
57        cmd.output().with_context(|| {
58            format!(
59                "failed to execute {} with source {}",
60                executable.display(),
61                source.display()
62            )
63        })
64    }
65}
66
67impl LanguageEngine for ZigEngine {
68    fn id(&self) -> &'static str {
69        "zig"
70    }
71
72    fn display_name(&self) -> &'static str {
73        "Zig"
74    }
75
76    fn aliases(&self) -> &[&'static str] {
77        &["ziglang"]
78    }
79
80    fn supports_sessions(&self) -> bool {
81        self.executable.is_some()
82    }
83
84    fn validate(&self) -> Result<()> {
85        let executable = self.ensure_executable()?;
86        let mut cmd = Command::new(executable);
87        cmd.arg("version")
88            .stdout(Stdio::null())
89            .stderr(Stdio::null());
90        cmd.status()
91            .with_context(|| format!("failed to invoke {}", executable.display()))?
92            .success()
93            .then_some(())
94            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
95    }
96
97    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
98        let start = Instant::now();
99        let (temp_dir, source_path) = match payload {
100            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
101                let snippet = wrap_inline_snippet(code);
102                let (dir, path) = self.write_temp_source(&snippet)?;
103                (Some(dir), path)
104            }
105            ExecutionPayload::File { path } => {
106                if path.extension().and_then(|e| e.to_str()) != Some("zig") {
107                    let code = std::fs::read_to_string(path)?;
108                    let (dir, new_path) = self.write_temp_source(&code)?;
109                    (Some(dir), new_path)
110                } else {
111                    (None, path.clone())
112                }
113            }
114        };
115
116        let output = self.run_source(&source_path)?;
117        drop(temp_dir);
118
119        let mut combined_stdout = String::from_utf8_lossy(&output.stdout).into_owned();
120        let stderr_str = String::from_utf8_lossy(&output.stderr).into_owned();
121
122        if output.status.success() && !stderr_str.contains("error:") {
123            if !combined_stdout.is_empty() && !stderr_str.is_empty() {
124                combined_stdout.push_str(&stderr_str);
125            } else if combined_stdout.is_empty() {
126                combined_stdout = stderr_str.clone();
127            }
128        }
129
130        Ok(ExecutionOutcome {
131            language: self.id().to_string(),
132            exit_code: output.status.code(),
133            stdout: combined_stdout,
134            stderr: if output.status.success() && !stderr_str.contains("error:") {
135                String::new()
136            } else {
137                stderr_str
138            },
139            duration: start.elapsed(),
140        })
141    }
142
143    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
144        let executable = self.ensure_executable()?.to_path_buf();
145        Ok(Box::new(ZigSession::new(executable)?))
146    }
147}
148
149fn resolve_zig_binary() -> Option<PathBuf> {
150    which::which("zig").ok()
151}
152
/// Type names recognised as a suffix glued to a numeric literal (such as the
/// `u8` in `42u8`). The order is significant: `match_suffix` returns the first
/// entry that matches.
#[rustfmt::skip]
const ZIG_NUMERIC_SUFFIXES: [&str; 17] = [
    "usize", "isize",
    "u128", "i128", "f128",
    "f80",
    "u64", "i64", "f64",
    "u32", "i32", "f32",
    "u16", "i16", "f16",
    "u8", "i8",
];
157
/// Turns an inline snippet into a complete Zig program.
///
/// Blank input, or input that already contains `pub fn main`, is returned
/// unchanged apart from a guaranteed trailing newline. Anything else is
/// indented by four spaces per line and placed inside a generated
/// `pub fn main() !void` that follows a `std` import.
fn wrap_inline_snippet(code: &str) -> String {
    let is_complete_program = code.trim().is_empty() || code.contains("pub fn main");
    if is_complete_program {
        return if code.ends_with('\n') {
            code.to_owned()
        } else {
            format!("{code}\n")
        };
    }

    // `lines()` strips the terminators, so each line gets exactly one back.
    // The input has non-whitespace content here, so the body is never empty.
    let body: String = code.lines().map(|line| format!("    {line}\n")).collect();

    format!("const std = @import(\"std\");\n\npub fn main() !void {{\n{body}}}\n")
}
182
183struct ZigSession {
184    executable: PathBuf,
185    workspace: TempDir,
186    items: Vec<String>,
187    statements: Vec<String>,
188    last_stdout: String,
189    last_stderr: String,
190}
191
/// How a session snippet is incorporated into the accumulated program.
enum ZigSnippetKind {
    /// Added at file scope, ahead of `main`.
    Declaration,
    /// Added to the body of `main` as written.
    Statement,
    /// Wrapped in a print call and then added to the body of `main`.
    Expression,
}
197
198impl ZigSession {
199    fn new(executable: PathBuf) -> Result<Self> {
200        let workspace = TempDir::new().context("failed to create Zig session workspace")?;
201        let session = Self {
202            executable,
203            workspace,
204            items: Vec::new(),
205            statements: Vec::new(),
206            last_stdout: String::new(),
207            last_stderr: String::new(),
208        };
209        session.persist_source()?;
210        Ok(session)
211    }
212
213    fn source_path(&self) -> PathBuf {
214        self.workspace.path().join("session.zig")
215    }
216
217    fn persist_source(&self) -> Result<()> {
218        let source = self.render_source();
219        fs::write(self.source_path(), source)
220            .with_context(|| "failed to write Zig session source".to_string())
221    }
222
223    fn render_source(&self) -> String {
224        let mut source = String::from("const std = @import(\"std\");\n\n");
225
226        for item in &self.items {
227            source.push_str(item);
228            if !item.ends_with('\n') {
229                source.push('\n');
230            }
231            source.push('\n');
232        }
233
234        source.push_str("pub fn main() !void {\n");
235        if self.statements.is_empty() {
236            source.push_str("    return;\n");
237        } else {
238            for snippet in &self.statements {
239                for line in snippet.lines() {
240                    source.push_str("    ");
241                    source.push_str(line);
242                    source.push('\n');
243                }
244            }
245        }
246        source.push_str("}\n");
247
248        source
249    }
250
251    fn run_program(&self) -> Result<std::process::Output> {
252        let mut cmd = Command::new(&self.executable);
253        cmd.arg("run")
254            .arg("session.zig")
255            .stdout(Stdio::piped())
256            .stderr(Stdio::piped())
257            .current_dir(self.workspace.path());
258        cmd.output().with_context(|| {
259            format!(
260                "failed to execute {} for Zig session",
261                self.executable.display()
262            )
263        })
264    }
265
266    fn run_standalone_program(&self, code: &str) -> Result<ExecutionOutcome> {
267        let start = Instant::now();
268        let path = self.workspace.path().join("standalone.zig");
269        let mut contents = code.to_string();
270        if !contents.ends_with('\n') {
271            contents.push('\n');
272        }
273        fs::write(&path, contents)
274            .with_context(|| "failed to write Zig standalone source".to_string())?;
275
276        let mut cmd = Command::new(&self.executable);
277        cmd.arg("run")
278            .arg("standalone.zig")
279            .stdout(Stdio::piped())
280            .stderr(Stdio::piped())
281            .current_dir(self.workspace.path());
282        let output = cmd.output().with_context(|| {
283            format!(
284                "failed to execute {} for Zig standalone snippet",
285                self.executable.display()
286            )
287        })?;
288
289        let mut stdout = Self::normalize_output(&output.stdout);
290        let stderr = Self::normalize_output(&output.stderr);
291
292        if output.status.success() && !stderr.contains("error:") {
293            if stdout.is_empty() {
294                stdout = stderr.clone();
295            } else {
296                stdout.push_str(&stderr);
297            }
298        }
299
300        Ok(ExecutionOutcome {
301            language: self.language_id().to_string(),
302            exit_code: output.status.code(),
303            stdout,
304            stderr: if output.status.success() && !stderr.contains("error:") {
305                String::new()
306            } else {
307                stderr
308            },
309            duration: start.elapsed(),
310        })
311    }
312
313    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
314        self.persist_source()?;
315        let output = self.run_program()?;
316        let mut stdout_full = Self::normalize_output(&output.stdout);
317        let stderr_full = Self::normalize_output(&output.stderr);
318
319        let success = output.status.success();
320
321        if success && !stderr_full.is_empty() && !stderr_full.contains("error:") {
322            if stdout_full.is_empty() {
323                stdout_full = stderr_full.clone();
324            } else {
325                stdout_full.push_str(&stderr_full);
326            }
327        }
328
329        let (stdout, stderr) = if success {
330            let stdout_delta = Self::diff_outputs(&self.last_stdout, &stdout_full);
331            let stderr_clean = if !stderr_full.contains("error:") {
332                String::new()
333            } else {
334                stderr_full.clone()
335            };
336            let stderr_delta = Self::diff_outputs(&self.last_stderr, &stderr_clean);
337            self.last_stdout = stdout_full;
338            self.last_stderr = stderr_clean;
339            (stdout_delta, stderr_delta)
340        } else {
341            (stdout_full, stderr_full)
342        };
343
344        let outcome = ExecutionOutcome {
345            language: self.language_id().to_string(),
346            exit_code: output.status.code(),
347            stdout,
348            stderr,
349            duration: start.elapsed(),
350        };
351
352        Ok((outcome, success))
353    }
354
355    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
356        let normalized = normalize_snippet(code);
357        let mut snippet = normalized;
358        if !snippet.ends_with('\n') {
359            snippet.push('\n');
360        }
361        self.items.push(snippet);
362        let start = Instant::now();
363        let (outcome, success) = self.run_current(start)?;
364        if !success {
365            let _ = self.items.pop();
366            self.persist_source()?;
367        }
368        Ok((outcome, success))
369    }
370
371    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
372        let normalized = normalize_snippet(code);
373        let snippet = ensure_trailing_newline(&normalized);
374        self.statements.push(snippet);
375        let start = Instant::now();
376        let (outcome, success) = self.run_current(start)?;
377        if !success {
378            let _ = self.statements.pop();
379            self.persist_source()?;
380        }
381        Ok((outcome, success))
382    }
383
384    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
385        let normalized = normalize_snippet(code);
386        let wrapped = wrap_expression(&normalized);
387        self.statements.push(wrapped);
388        let start = Instant::now();
389        let (outcome, success) = self.run_current(start)?;
390        if !success {
391            let _ = self.statements.pop();
392            self.persist_source()?;
393        }
394        Ok((outcome, success))
395    }
396
397    fn reset(&mut self) -> Result<()> {
398        self.items.clear();
399        self.statements.clear();
400        self.last_stdout.clear();
401        self.last_stderr.clear();
402        self.persist_source()
403    }
404
405    fn normalize_output(bytes: &[u8]) -> String {
406        String::from_utf8_lossy(bytes)
407            .replace("\r\n", "\n")
408            .replace('\r', "")
409    }
410
411    fn diff_outputs(previous: &str, current: &str) -> String {
412        current
413            .strip_prefix(previous)
414            .map(|s| s.to_string())
415            .unwrap_or_else(|| current.to_string())
416    }
417}
418
419impl LanguageSession for ZigSession {
420    fn language_id(&self) -> &str {
421        "zig"
422    }
423
424    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
425        let trimmed = code.trim();
426        if trimmed.is_empty() {
427            return Ok(ExecutionOutcome {
428                language: self.language_id().to_string(),
429                exit_code: None,
430                stdout: String::new(),
431                stderr: String::new(),
432                duration: Duration::default(),
433            });
434        }
435
436        if trimmed.eq_ignore_ascii_case(":reset") {
437            self.reset()?;
438            return Ok(ExecutionOutcome {
439                language: self.language_id().to_string(),
440                exit_code: None,
441                stdout: String::new(),
442                stderr: String::new(),
443                duration: Duration::default(),
444            });
445        }
446
447        if trimmed.eq_ignore_ascii_case(":help") {
448            return Ok(ExecutionOutcome {
449                language: self.language_id().to_string(),
450                exit_code: None,
451                stdout:
452                    "Zig commands:\n  :reset - clear session state\n  :help  - show this message\n"
453                        .to_string(),
454                stderr: String::new(),
455                duration: Duration::default(),
456            });
457        }
458
459        if trimmed.contains("pub fn main") {
460            return self.run_standalone_program(code);
461        }
462
463        match classify_snippet(trimmed) {
464            ZigSnippetKind::Declaration => {
465                let (outcome, _) = self.apply_declaration(code)?;
466                Ok(outcome)
467            }
468            ZigSnippetKind::Statement => {
469                let (outcome, _) = self.apply_statement(code)?;
470                Ok(outcome)
471            }
472            ZigSnippetKind::Expression => {
473                let (outcome, _) = self.apply_expression(trimmed)?;
474                Ok(outcome)
475            }
476        }
477    }
478
479    fn shutdown(&mut self) -> Result<()> {
480        Ok(())
481    }
482}
483
484fn classify_snippet(code: &str) -> ZigSnippetKind {
485    if looks_like_declaration(code) {
486        ZigSnippetKind::Declaration
487    } else if looks_like_statement(code) {
488        ZigSnippetKind::Statement
489    } else {
490        ZigSnippetKind::Expression
491    }
492}
493
/// Reports whether `code`, ignoring leading whitespace, begins with one of the
/// keyword-plus-space prefixes treated as a declaration.
fn looks_like_declaration(code: &str) -> bool {
    const DECLARATION_PREFIXES: [&str; 8] = [
        "const ",
        "var ",
        "pub ",
        "fn ",
        "usingnamespace ",
        "extern ",
        "comptime ",
        "test ",
    ];

    let head = code.trim_start();
    DECLARATION_PREFIXES
        .iter()
        .any(|prefix| head.starts_with(prefix))
}
508
/// Reports whether `code` reads as a statement. Trailing whitespace is ignored;
/// the text qualifies when it spans several lines, ends in `;`, `}` or `:`, or
/// starts with a `//` or `/*` comment marker.
fn looks_like_statement(code: &str) -> bool {
    let text = code.trim_end();
    if text.contains('\n') {
        return true;
    }

    let ends_like_statement = matches!(text.chars().next_back(), Some(';' | '}' | ':'));
    let starts_with_comment = ["//", "/*"].iter().any(|marker| text.starts_with(marker));
    ends_like_statement || starts_with_comment
}
518
/// Returns an owned copy of `code` that ends with a newline, adding one only
/// when it is missing.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
526
/// Wraps an expression in a `std.debug.print` call that formats it with
/// `{any}` followed by a newline escape.
fn wrap_expression(code: &str) -> String {
    ["std.debug.print(\"{any}\\n\", .{ ", code, " });"].concat()
}
530
531fn normalize_snippet(code: &str) -> String {
532    rewrite_numeric_suffixes(code)
533}
534
535fn rewrite_numeric_suffixes(code: &str) -> String {
536    let bytes = code.as_bytes();
537    let mut result = String::with_capacity(code.len());
538    let mut i = 0;
539    while i < bytes.len() {
540        let ch = bytes[i] as char;
541
542        if ch == '"' {
543            let (segment, advance) = extract_string_literal(&code[i..]);
544            result.push_str(segment);
545            i += advance;
546            continue;
547        }
548
549        if ch == '\'' {
550            let (segment, advance) = extract_char_literal(&code[i..]);
551            result.push_str(segment);
552            i += advance;
553            continue;
554        }
555
556        if ch == '/' && i + 1 < bytes.len() {
557            let next = bytes[i + 1] as char;
558            if next == '/' {
559                result.push_str(&code[i..]);
560                break;
561            }
562            if next == '*' {
563                let (segment, advance) = extract_block_comment(&code[i..]);
564                result.push_str(segment);
565                i += advance;
566                continue;
567            }
568        }
569
570        if ch.is_ascii_digit() {
571            if i > 0 {
572                let prev = bytes[i - 1] as char;
573                if prev.is_ascii_alphanumeric() || prev == '_' {
574                    result.push(ch);
575                    i += 1;
576                    continue;
577                }
578            }
579
580            let literal_end = scan_numeric_literal(bytes, i);
581            if literal_end > i {
582                if let Some((suffix, suffix_len)) = match_suffix(&code[literal_end..]) {
583                    if !is_identifier_char(bytes, literal_end + suffix_len) {
584                        let literal = &code[i..literal_end];
585                        result.push_str("@as(");
586                        result.push_str(suffix);
587                        result.push_str(", ");
588                        result.push_str(literal);
589                        result.push_str(")");
590                        i = literal_end + suffix_len;
591                        continue;
592                    }
593                }
594
595                result.push_str(&code[i..literal_end]);
596                i = literal_end;
597                continue;
598            }
599        }
600
601        result.push(ch);
602        i += 1;
603    }
604
605    if result.len() == code.len() {
606        code.to_string()
607    } else {
608        result
609    }
610}
611
/// Splits off the double-quoted string literal at the start of `source`.
///
/// `source` is expected to begin with the opening quote. Returns the literal
/// text (including both quotes) and its length in bytes. A backslash skips the
/// byte that follows it. An unterminated literal extends to the end of
/// `source`.
fn extract_string_literal(source: &str) -> (&str, usize) {
    let bytes = source.as_bytes();
    let mut i = 1; // skip opening quote
    while i < bytes.len() {
        match bytes[i] {
            b'\\' => i += 2,
            b'"' => {
                i += 1;
                break;
            }
            _ => i += 1,
        }
    }
    // A trailing backslash (or empty input) can push `i` past the end;
    // clamp so the slice below cannot panic.
    let end = i.min(bytes.len());
    (&source[..end], end)
}
629
/// Splits off the single-quoted character literal at the start of `source`.
///
/// `source` is expected to begin with the opening quote. Returns the literal
/// text (including both quotes) and its length in bytes. A backslash skips the
/// byte that follows it. An unterminated literal extends to the end of
/// `source`.
fn extract_char_literal(source: &str) -> (&str, usize) {
    let bytes = source.as_bytes();
    let mut i = 1; // skip opening quote
    while i < bytes.len() {
        match bytes[i] {
            b'\\' => i += 2,
            b'\'' => {
                i += 1;
                break;
            }
            _ => i += 1,
        }
    }
    // A trailing backslash (or empty input) can push `i` past the end;
    // clamp so the slice below cannot panic.
    let end = i.min(bytes.len());
    (&source[..end], end)
}
647
/// Splits off the `/* ... */` span at the start of `source`.
///
/// The search for the closing `*/` begins after the two-byte opener. Returns
/// the span (through the closer) and its byte length, or all of `source` when
/// no closer is found.
fn extract_block_comment(source: &str) -> (&str, usize) {
    let end = match source[2..].find("*/") {
        // +2 for the skipped opener, +2 for the closer itself.
        Some(offset) => offset + 4,
        None => source.len(),
    };
    (&source[..end], end)
}
656
/// Finds where the numeric literal beginning at `start` ends.
///
/// Returns the index one past the literal's last byte, or `start` itself when
/// `start` is out of range. `0x`/`0o`/`0b` prefixes (either case) are followed
/// by a run of digits valid for that radix. Otherwise the literal is a run of
/// decimal digits with at most one `.` (only when a digit follows it) and
/// exponent parts introduced by `e`/`E`/`p`/`P` with an optional sign.
/// Underscores are accepted anywhere digits are.
fn scan_numeric_literal(bytes: &[u8], start: usize) -> usize {
    fn is_hex(b: u8) -> bool {
        b.is_ascii_hexdigit() || b == b'_'
    }
    fn is_octal(b: u8) -> bool {
        matches!(b, b'0'..=b'7' | b'_')
    }
    fn is_binary(b: u8) -> bool {
        matches!(b, b'0' | b'1' | b'_')
    }
    fn is_decimal(b: u8) -> bool {
        b.is_ascii_digit() || b == b'_'
    }

    if start >= bytes.len() {
        return start;
    }

    // Radix-prefixed literals: the prefix is two bytes, then a digit run.
    if bytes[start] == b'0' {
        let radix_digit: Option<fn(u8) -> bool> = match bytes.get(start + 1) {
            Some(b'x' | b'X') => Some(is_hex as fn(u8) -> bool),
            Some(b'o' | b'O') => Some(is_octal as fn(u8) -> bool),
            Some(b'b' | b'B') => Some(is_binary as fn(u8) -> bool),
            _ => None,
        };
        if let Some(accepts) = radix_digit {
            let digits_from = start + 2;
            let run = bytes[digits_from..]
                .iter()
                .take_while(|&&b| accepts(b))
                .count();
            return digits_from + run;
        }
    }

    let mut pos = start;
    let mut fraction_seen = false;
    while let Some(&byte) = bytes.get(pos) {
        match byte {
            b'0'..=b'9' | b'_' => pos += 1,
            b'.' if !fraction_seen => {
                // A dot belongs to the literal only when a digit follows it.
                let digit_follows = bytes
                    .get(pos + 1)
                    .map_or(false, |next| next.is_ascii_digit());
                if !digit_follows {
                    break;
                }
                fraction_seen = true;
                pos += 1;
            }
            b'e' | b'E' | b'p' | b'P' => {
                let mut cursor = pos + 1;
                if matches!(bytes.get(cursor), Some(b'+' | b'-')) {
                    cursor += 1;
                }
                let run = bytes[cursor..]
                    .iter()
                    .take_while(|&&b| is_decimal(b))
                    .count();
                // Without exponent digits the marker is not part of the literal.
                if run == 0 {
                    break;
                }
                pos = cursor + run;
            }
            _ => break,
        }
    }

    pos
}
740
741fn match_suffix(rest: &str) -> Option<(&'static str, usize)> {
742    for &suffix in &ZIG_NUMERIC_SUFFIXES {
743        if rest.starts_with(suffix) {
744            return Some((suffix, suffix.len()));
745        }
746    }
747    None
748}
749
/// Reports whether the byte at `index` is an ASCII letter, digit or
/// underscore. An out-of-range `index` yields `false`.
fn is_identifier_char(bytes: &[u8], index: usize) -> bool {
    matches!(bytes.get(index), Some(b) if b.is_ascii_alphanumeric() || *b == b'_')
}
756}