run/engine/
zig.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
10
/// Language engine that compiles and runs Zig code through the `zig` CLI.
#[derive(Debug)]
pub struct ZigEngine {
    /// Location of the `zig` executable found at construction time, or `None`
    /// when it could not be located (Zig support is then unavailable).
    executable: Option<PathBuf>,
}
14
15impl ZigEngine {
16    pub fn new() -> Self {
17        Self {
18            executable: resolve_zig_binary(),
19        }
20    }
21
22    fn ensure_executable(&self) -> Result<&Path> {
23        self.executable.as_deref().ok_or_else(|| {
24            anyhow::anyhow!(
25                "Zig support requires the `zig` executable. Install it from https://ziglang.org/download/ and ensure it is on your PATH."
26            )
27        })
28    }
29
30    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
31        let dir = Builder::new()
32            .prefix("run-zig")
33            .tempdir()
34            .context("failed to create temporary directory for Zig source")?;
35        let path = dir.path().join("snippet.zig");
36        let mut contents = code.to_string();
37        if !contents.ends_with('\n') {
38            contents.push('\n');
39        }
40        std::fs::write(&path, contents).with_context(|| {
41            format!("failed to write temporary Zig source to {}", path.display())
42        })?;
43        Ok((dir, path))
44    }
45
46    fn run_source(&self, source: &Path) -> Result<std::process::Output> {
47        let executable = self.ensure_executable()?;
48        let mut cmd = Command::new(executable);
49        cmd.arg("run")
50            .arg(source)
51            .stdout(Stdio::piped())
52            .stderr(Stdio::piped());
53        cmd.stdin(Stdio::inherit());
54        if let Some(dir) = source.parent() {
55            cmd.current_dir(dir);
56        }
57        cmd.output().with_context(|| {
58            format!(
59                "failed to execute {} with source {}",
60                executable.display(),
61                source.display()
62            )
63        })
64    }
65}
66
67impl LanguageEngine for ZigEngine {
68    fn id(&self) -> &'static str {
69        "zig"
70    }
71
72    fn display_name(&self) -> &'static str {
73        "Zig"
74    }
75
76    fn aliases(&self) -> &[&'static str] {
77        &["ziglang"]
78    }
79
80    fn supports_sessions(&self) -> bool {
81        self.executable.is_some()
82    }
83
84    fn validate(&self) -> Result<()> {
85        let executable = self.ensure_executable()?;
86        let mut cmd = Command::new(executable);
87        cmd.arg("version")
88            .stdout(Stdio::null())
89            .stderr(Stdio::null());
90        cmd.status()
91            .with_context(|| format!("failed to invoke {}", executable.display()))?
92            .success()
93            .then_some(())
94            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
95    }
96
97    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
98        let start = Instant::now();
99        let (temp_dir, source_path) = match payload {
100            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
101                let snippet = wrap_inline_snippet(code);
102                let (dir, path) = self.write_temp_source(&snippet)?;
103                (Some(dir), path)
104            }
105            ExecutionPayload::File { path } => {
106                if path.extension().and_then(|e| e.to_str()) != Some("zig") {
107                    let code = std::fs::read_to_string(path)?;
108                    let (dir, new_path) = self.write_temp_source(&code)?;
109                    (Some(dir), new_path)
110                } else {
111                    (None, path.clone())
112                }
113            }
114        };
115
116        let output = self.run_source(&source_path)?;
117        drop(temp_dir);
118
119        let mut combined_stdout = String::from_utf8_lossy(&output.stdout).into_owned();
120        let stderr_str = String::from_utf8_lossy(&output.stderr).into_owned();
121
122        if output.status.success() && !stderr_str.contains("error:") {
123            if !combined_stdout.is_empty() && !stderr_str.is_empty() {
124                combined_stdout.push_str(&stderr_str);
125            } else if combined_stdout.is_empty() {
126                combined_stdout = stderr_str.clone();
127            }
128        }
129
130        Ok(ExecutionOutcome {
131            language: self.id().to_string(),
132            exit_code: output.status.code(),
133            stdout: combined_stdout,
134            stderr: if output.status.success() && !stderr_str.contains("error:") {
135                String::new()
136            } else {
137                stderr_str
138            },
139            duration: start.elapsed(),
140        })
141    }
142
143    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
144        let executable = self.ensure_executable()?.to_path_buf();
145        Ok(Box::new(ZigSession::new(executable)?))
146    }
147}
148
149fn resolve_zig_binary() -> Option<PathBuf> {
150    which::which("zig").ok()
151}
152
/// Type suffixes accepted on numeric literals (e.g. `10u8`, `1.5f32`), which
/// the session rewrites to `@as(T, literal)`.
///
/// No entry is a prefix of another, so lookup order does not affect matching.
const ZIG_NUMERIC_SUFFIXES: [&str; 17] = [
    // Pointer-sized integers.
    "usize", "isize",
    // Unsigned fixed-width integers.
    "u128", "u64", "u32", "u16", "u8",
    // Signed fixed-width integers.
    "i128", "i64", "i32", "i16", "i8",
    // Floating point.
    "f128", "f80", "f64", "f32", "f16",
];
157
/// Turns an inline snippet into a runnable Zig program.
///
/// - Blank input, or input that already defines `pub fn main`, is returned
///   as-is with a trailing newline guaranteed.
/// - Anything else is placed, one 4-space-indented line at a time, inside a
///   generated `pub fn main() !void` that follows a `const std = @import("std");`.
fn wrap_inline_snippet(code: &str) -> String {
    let trimmed = code.trim();
    if trimmed.is_empty() || trimmed.contains("pub fn main") {
        let mut program = code.to_string();
        if !program.ends_with('\n') {
            program.push('\n');
        }
        return program;
    }

    // `trimmed` is non-empty here, so `lines()` yields at least one line and
    // the generated body is never empty.
    let mut program = String::from("const std = @import(\"std\");\n\npub fn main() !void {\n");
    for line in code.lines() {
        program.push_str("    ");
        program.push_str(line);
        program.push('\n');
    }
    program.push_str("}\n");
    program
}
182
183struct ZigSession {
184    executable: PathBuf,
185    workspace: TempDir,
186    items: Vec<String>,
187    statements: Vec<String>,
188    last_stdout: String,
189    last_stderr: String,
190}
191
/// How a session input is placed in the generated program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ZigSnippetKind {
    /// Emitted at file scope, before `main`.
    Declaration,
    /// Emitted verbatim inside `main`.
    Statement,
    /// Wrapped in a `std.debug.print` call inside `main`.
    Expression,
}
197
198impl ZigSession {
199    fn new(executable: PathBuf) -> Result<Self> {
200        let workspace = TempDir::new().context("failed to create Zig session workspace")?;
201        let session = Self {
202            executable,
203            workspace,
204            items: Vec::new(),
205            statements: Vec::new(),
206            last_stdout: String::new(),
207            last_stderr: String::new(),
208        };
209        session.persist_source()?;
210        Ok(session)
211    }
212
213    fn source_path(&self) -> PathBuf {
214        self.workspace.path().join("session.zig")
215    }
216
217    fn persist_source(&self) -> Result<()> {
218        let source = self.render_source();
219        fs::write(self.source_path(), source)
220            .with_context(|| "failed to write Zig session source".to_string())
221    }
222
223    fn render_source(&self) -> String {
224        let mut source = String::from("const std = @import(\"std\");\n\n");
225
226        for item in &self.items {
227            source.push_str(item);
228            if !item.ends_with('\n') {
229                source.push('\n');
230            }
231            source.push('\n');
232        }
233
234        source.push_str("pub fn main() !void {\n");
235        if self.statements.is_empty() {
236            source.push_str("    return;\n");
237        } else {
238            for snippet in &self.statements {
239                for line in snippet.lines() {
240                    source.push_str("    ");
241                    source.push_str(line);
242                    source.push('\n');
243                }
244            }
245        }
246        source.push_str("}\n");
247
248        source
249    }
250
251    fn run_program(&self) -> Result<std::process::Output> {
252        let mut cmd = Command::new(&self.executable);
253        cmd.arg("run")
254            .arg("session.zig")
255            .stdout(Stdio::piped())
256            .stderr(Stdio::piped())
257            .current_dir(self.workspace.path());
258        cmd.output().with_context(|| {
259            format!(
260                "failed to execute {} for Zig session",
261                self.executable.display()
262            )
263        })
264    }
265
266    fn run_standalone_program(&self, code: &str) -> Result<ExecutionOutcome> {
267        let start = Instant::now();
268        let path = self.workspace.path().join("standalone.zig");
269        let mut contents = code.to_string();
270        if !contents.ends_with('\n') {
271            contents.push('\n');
272        }
273        fs::write(&path, contents)
274            .with_context(|| "failed to write Zig standalone source".to_string())?;
275
276        let mut cmd = Command::new(&self.executable);
277        cmd.arg("run")
278            .arg("standalone.zig")
279            .stdout(Stdio::piped())
280            .stderr(Stdio::piped())
281            .current_dir(self.workspace.path());
282        let output = cmd.output().with_context(|| {
283            format!(
284                "failed to execute {} for Zig standalone snippet",
285                self.executable.display()
286            )
287        })?;
288
289        let mut stdout = Self::normalize_output(&output.stdout);
290        let stderr = Self::normalize_output(&output.stderr);
291
292        if output.status.success() && !stderr.contains("error:") {
293            if stdout.is_empty() {
294                stdout = stderr.clone();
295            } else {
296                stdout.push_str(&stderr);
297            }
298        }
299
300        Ok(ExecutionOutcome {
301            language: self.language_id().to_string(),
302            exit_code: output.status.code(),
303            stdout,
304            stderr: if output.status.success() && !stderr.contains("error:") {
305                String::new()
306            } else {
307                stderr
308            },
309            duration: start.elapsed(),
310        })
311    }
312
313    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
314        self.persist_source()?;
315        let output = self.run_program()?;
316        let mut stdout_full = Self::normalize_output(&output.stdout);
317        let stderr_full = Self::normalize_output(&output.stderr);
318
319        let success = output.status.success();
320
321        // Merge stderr to stdout for Zig (std.debug.print goes to stderr)
322        if success && !stderr_full.is_empty() && !stderr_full.contains("error:") {
323            if stdout_full.is_empty() {
324                stdout_full = stderr_full.clone();
325            } else {
326                stdout_full.push_str(&stderr_full);
327            }
328        }
329
330        let (stdout, stderr) = if success {
331            let stdout_delta = Self::diff_outputs(&self.last_stdout, &stdout_full);
332            let stderr_clean = if !stderr_full.contains("error:") {
333                String::new()
334            } else {
335                stderr_full.clone()
336            };
337            let stderr_delta = Self::diff_outputs(&self.last_stderr, &stderr_clean);
338            self.last_stdout = stdout_full;
339            self.last_stderr = stderr_clean;
340            (stdout_delta, stderr_delta)
341        } else {
342            (stdout_full, stderr_full)
343        };
344
345        let outcome = ExecutionOutcome {
346            language: self.language_id().to_string(),
347            exit_code: output.status.code(),
348            stdout,
349            stderr,
350            duration: start.elapsed(),
351        };
352
353        Ok((outcome, success))
354    }
355
356    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
357        let normalized = normalize_snippet(code);
358        let mut snippet = normalized;
359        if !snippet.ends_with('\n') {
360            snippet.push('\n');
361        }
362        self.items.push(snippet);
363        let start = Instant::now();
364        let (outcome, success) = self.run_current(start)?;
365        if !success {
366            let _ = self.items.pop();
367            self.persist_source()?;
368        }
369        Ok((outcome, success))
370    }
371
372    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
373        let normalized = normalize_snippet(code);
374        let snippet = ensure_trailing_newline(&normalized);
375        self.statements.push(snippet);
376        let start = Instant::now();
377        let (outcome, success) = self.run_current(start)?;
378        if !success {
379            let _ = self.statements.pop();
380            self.persist_source()?;
381        }
382        Ok((outcome, success))
383    }
384
385    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
386        let normalized = normalize_snippet(code);
387        let wrapped = wrap_expression(&normalized);
388        self.statements.push(wrapped);
389        let start = Instant::now();
390        let (outcome, success) = self.run_current(start)?;
391        if !success {
392            let _ = self.statements.pop();
393            self.persist_source()?;
394        }
395        Ok((outcome, success))
396    }
397
398    fn reset(&mut self) -> Result<()> {
399        self.items.clear();
400        self.statements.clear();
401        self.last_stdout.clear();
402        self.last_stderr.clear();
403        self.persist_source()
404    }
405
406    fn normalize_output(bytes: &[u8]) -> String {
407        String::from_utf8_lossy(bytes)
408            .replace("\r\n", "\n")
409            .replace('\r', "")
410    }
411
412    fn diff_outputs(previous: &str, current: &str) -> String {
413        current
414            .strip_prefix(previous)
415            .map(|s| s.to_string())
416            .unwrap_or_else(|| current.to_string())
417    }
418}
419
420impl LanguageSession for ZigSession {
421    fn language_id(&self) -> &str {
422        "zig"
423    }
424
425    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
426        let trimmed = code.trim();
427        if trimmed.is_empty() {
428            return Ok(ExecutionOutcome {
429                language: self.language_id().to_string(),
430                exit_code: None,
431                stdout: String::new(),
432                stderr: String::new(),
433                duration: Duration::default(),
434            });
435        }
436
437        if trimmed.eq_ignore_ascii_case(":reset") {
438            self.reset()?;
439            return Ok(ExecutionOutcome {
440                language: self.language_id().to_string(),
441                exit_code: None,
442                stdout: String::new(),
443                stderr: String::new(),
444                duration: Duration::default(),
445            });
446        }
447
448        if trimmed.eq_ignore_ascii_case(":help") {
449            return Ok(ExecutionOutcome {
450                language: self.language_id().to_string(),
451                exit_code: None,
452                stdout:
453                    "Zig commands:\n  :reset — clear session state\n  :help  — show this message\n"
454                        .to_string(),
455                stderr: String::new(),
456                duration: Duration::default(),
457            });
458        }
459
460        if trimmed.contains("pub fn main") {
461            return self.run_standalone_program(code);
462        }
463
464        match classify_snippet(trimmed) {
465            ZigSnippetKind::Declaration => {
466                let (outcome, _) = self.apply_declaration(code)?;
467                Ok(outcome)
468            }
469            ZigSnippetKind::Statement => {
470                let (outcome, _) = self.apply_statement(code)?;
471                Ok(outcome)
472            }
473            ZigSnippetKind::Expression => {
474                let (outcome, _) = self.apply_expression(trimmed)?;
475                Ok(outcome)
476            }
477        }
478    }
479
480    fn shutdown(&mut self) -> Result<()> {
481        Ok(())
482    }
483}
484
485fn classify_snippet(code: &str) -> ZigSnippetKind {
486    if looks_like_declaration(code) {
487        ZigSnippetKind::Declaration
488    } else if looks_like_statement(code) {
489        ZigSnippetKind::Statement
490    } else {
491        ZigSnippetKind::Expression
492    }
493}
494
/// Returns `true` when `code` (ignoring leading whitespace) begins with a
/// keyword that introduces a file-scope Zig declaration.
///
/// `export`, `threadlocal`, and `inline fn`/`noinline fn` forms are included
/// because they are only valid at file scope and would fail to compile if
/// placed inside the generated `main`. A bare `inline ` prefix is deliberately
/// excluded: `inline for`/`inline while` are statements.
fn looks_like_declaration(code: &str) -> bool {
    const DECLARATION_PREFIXES: &[&str] = &[
        "const ",
        "var ",
        "pub ",
        "fn ",
        "usingnamespace ",
        "extern ",
        "export ",
        "threadlocal ",
        "inline fn ",
        "noinline fn ",
        "comptime ",
        "test ",
    ];
    let trimmed = code.trim_start();
    DECLARATION_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(prefix))
}
509
/// Heuristically decides whether `code` should be emitted as a statement rather
/// than printed as an expression.
///
/// After trimming trailing whitespace, it counts as a statement if it spans
/// several lines, starts with a comment opener, or ends in `;`, `}` or `:`.
fn looks_like_statement(code: &str) -> bool {
    let code = code.trim_end();
    if code.contains('\n') {
        return true;
    }
    if code.starts_with("//") || code.starts_with("/*") {
        return true;
    }
    matches!(code.chars().last(), Some(';' | '}' | ':'))
}
519
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
527
/// Wraps an expression in a `std.debug.print("{any}\n", .{ expr });` statement
/// so its value is printed when the session program runs.
fn wrap_expression(code: &str) -> String {
    format!("std.debug.print(\"{{any}}\\n\", .{{ {code} }});")
}
531
532fn normalize_snippet(code: &str) -> String {
533    rewrite_numeric_suffixes(code)
534}
535
536fn rewrite_numeric_suffixes(code: &str) -> String {
537    let bytes = code.as_bytes();
538    let mut result = String::with_capacity(code.len());
539    let mut i = 0;
540    while i < bytes.len() {
541        let ch = bytes[i] as char;
542
543        if ch == '"' {
544            let (segment, advance) = extract_string_literal(&code[i..]);
545            result.push_str(segment);
546            i += advance;
547            continue;
548        }
549
550        if ch == '\'' {
551            let (segment, advance) = extract_char_literal(&code[i..]);
552            result.push_str(segment);
553            i += advance;
554            continue;
555        }
556
557        if ch == '/' && i + 1 < bytes.len() {
558            let next = bytes[i + 1] as char;
559            if next == '/' {
560                result.push_str(&code[i..]);
561                break;
562            }
563            if next == '*' {
564                let (segment, advance) = extract_block_comment(&code[i..]);
565                result.push_str(segment);
566                i += advance;
567                continue;
568            }
569        }
570
571        if ch.is_ascii_digit() {
572            if i > 0 {
573                let prev = bytes[i - 1] as char;
574                if prev.is_ascii_alphanumeric() || prev == '_' {
575                    result.push(ch);
576                    i += 1;
577                    continue;
578                }
579            }
580
581            let literal_end = scan_numeric_literal(bytes, i);
582            if literal_end > i {
583                if let Some((suffix, suffix_len)) = match_suffix(&code[literal_end..]) {
584                    if !is_identifier_char(bytes, literal_end + suffix_len) {
585                        let literal = &code[i..literal_end];
586                        result.push_str("@as(");
587                        result.push_str(suffix);
588                        result.push_str(", ");
589                        result.push_str(literal);
590                        result.push_str(")");
591                        i = literal_end + suffix_len;
592                        continue;
593                    }
594                }
595
596                result.push_str(&code[i..literal_end]);
597                i = literal_end;
598                continue;
599            }
600        }
601
602        result.push(ch);
603        i += 1;
604    }
605
606    if result.len() == code.len() {
607        code.to_string()
608    } else {
609        result
610    }
611}
612
/// Splits off a double-quoted string literal from the start of `source`.
///
/// `source` must begin with `"`. Returns the literal (including its closing
/// quote when present) and its byte length. Backslash escapes are honored.
/// An unterminated literal, including one ending in a lone backslash, extends
/// to the end of `source`.
fn extract_string_literal(source: &str) -> (&str, usize) {
    let bytes = source.as_bytes();
    let mut i = 1; // skip opening quote
    while i < bytes.len() {
        match bytes[i] {
            // Skip the escaped byte, clamped so a trailing `\` cannot overshoot.
            b'\\' => i = (i + 2).min(bytes.len()),
            b'"' => {
                i += 1;
                break;
            }
            _ => i += 1,
        }
    }
    // Clamp also covers an empty `source`, where `i` starts past the end.
    let end = i.min(bytes.len());
    (&source[..end], end)
}
630
/// Splits off a single-quoted character literal from the start of `source`.
///
/// `source` must begin with `'`. Returns the literal (including its closing
/// quote when present) and its byte length. Backslash escapes are honored.
/// An unterminated literal, including one ending in a lone backslash, extends
/// to the end of `source`.
fn extract_char_literal(source: &str) -> (&str, usize) {
    let bytes = source.as_bytes();
    let mut i = 1; // skip opening quote
    while i < bytes.len() {
        match bytes[i] {
            // Skip the escaped byte, clamped so a trailing `\` cannot overshoot.
            b'\\' => i = (i + 2).min(bytes.len()),
            b'\'' => {
                i += 1;
                break;
            }
            _ => i += 1,
        }
    }
    // Clamp also covers an empty `source`, where `i` starts past the end.
    let end = i.min(bytes.len());
    (&source[..end], end)
}
648
/// Splits off a `/* ... */` span from the start of `source`, which must begin
/// with `/*`.
///
/// Returns the span (including `*/` when found) and its byte length. An
/// unterminated span runs to the end of `source`.
fn extract_block_comment(source: &str) -> (&str, usize) {
    let end = match source[2..].find("*/") {
        // Offset is relative to the body; add the 2-byte opener and 2-byte closer.
        Some(offset) => offset + 4,
        None => source.len(),
    };
    (&source[..end], end)
}
657
/// Returns the byte index just past the numeric literal beginning at `start`.
///
/// Two shapes are recognised:
/// - radix-prefixed integers (`0x`, `0o`, `0b`, either case), with `_` separators;
/// - decimal numbers, with an optional single fractional part (only when a
///   digit follows the `.`) and exponents (`e`/`E`/`p`/`P`, optional sign, at
///   least one digit).
///
/// If `start` is out of range it is returned unchanged.
fn scan_numeric_literal(bytes: &[u8], start: usize) -> usize {
    /// Advances `pos` past every consecutive byte accepted by `accept`.
    fn skip_while(bytes: &[u8], mut pos: usize, accept: fn(u8) -> bool) -> usize {
        while pos < bytes.len() && accept(bytes[pos]) {
            pos += 1;
        }
        pos
    }

    fn is_hex(b: u8) -> bool {
        b.is_ascii_hexdigit() || b == b'_'
    }

    fn is_octal(b: u8) -> bool {
        matches!(b, b'0'..=b'7' | b'_')
    }

    fn is_binary(b: u8) -> bool {
        matches!(b, b'0' | b'1' | b'_')
    }

    fn is_decimal(b: u8) -> bool {
        b.is_ascii_digit() || b == b'_'
    }

    /// Scans a decimal integer/float starting at `start`.
    fn scan_decimal(bytes: &[u8], start: usize) -> usize {
        let mut pos = start;
        let mut seen_dot = false;
        while let Some(&b) = bytes.get(pos) {
            match b {
                b'0'..=b'9' | b'_' => pos += 1,
                b'.' if !seen_dot
                    && matches!(bytes.get(pos + 1), Some(d) if d.is_ascii_digit()) =>
                {
                    seen_dot = true;
                    pos += 1;
                }
                b'e' | b'E' | b'p' | b'P' => {
                    let mut digits_from = pos + 1;
                    if matches!(bytes.get(digits_from), Some(b'+' | b'-')) {
                        digits_from += 1;
                    }
                    let end = skip_while(bytes, digits_from, is_decimal);
                    if end == digits_from {
                        // No exponent digits: the marker is not part of the literal.
                        break;
                    }
                    pos = end;
                }
                _ => break,
            }
        }
        pos
    }

    if start >= bytes.len() {
        return start;
    }
    if bytes[start] != b'0' {
        return scan_decimal(bytes, start);
    }

    let accept: fn(u8) -> bool = match bytes.get(start + 1) {
        Some(b'x' | b'X') => is_hex,
        Some(b'o' | b'O') => is_octal,
        Some(b'b' | b'B') => is_binary,
        _ => return scan_decimal(bytes, start),
    };
    skip_while(bytes, start + 2, accept)
}
741
742fn match_suffix(rest: &str) -> Option<(&'static str, usize)> {
743    for &suffix in &ZIG_NUMERIC_SUFFIXES {
744        if rest.starts_with(suffix) {
745            return Some((suffix, suffix.len()));
746        }
747    }
748    None
749}
750
/// Reports whether `bytes[index]` exists and is an ASCII letter, digit, or `_`.
fn is_identifier_char(bytes: &[u8], index: usize) -> bool {
    matches!(bytes.get(index), Some(&b) if b.is_ascii_alphanumeric() || b == b'_')
}