Skip to main content

run/engine/
zig.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use tempfile::{Builder, TempDir};
8
9use super::{
10    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, cache_store, hash_source,
11    try_cached_execution,
12};
13
/// Language engine that runs Zig code by shelling out to the `zig` toolchain.
pub struct ZigEngine {
    /// Location of the `zig` binary found at construction time; `None` when it was not found.
    executable: Option<PathBuf>,
}
17
18impl Default for ZigEngine {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl ZigEngine {
25    pub fn new() -> Self {
26        Self {
27            executable: resolve_zig_binary(),
28        }
29    }
30
31    fn ensure_executable(&self) -> Result<&Path> {
32        self.executable.as_deref().ok_or_else(|| {
33            anyhow::anyhow!(
34                "Zig support requires the `zig` executable. Install it from https://ziglang.org/download/ and ensure it is on your PATH."
35            )
36        })
37    }
38
39    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
40        let dir = Builder::new()
41            .prefix("run-zig")
42            .tempdir()
43            .context("failed to create temporary directory for Zig source")?;
44        let path = dir.path().join("snippet.zig");
45        let mut contents = code.to_string();
46        if !contents.ends_with('\n') {
47            contents.push('\n');
48        }
49        std::fs::write(&path, contents).with_context(|| {
50            format!("failed to write temporary Zig source to {}", path.display())
51        })?;
52        Ok((dir, path))
53    }
54
55    fn run_source(&self, source: &Path) -> Result<std::process::Output> {
56        let executable = self.ensure_executable()?;
57        let mut cmd = Command::new(executable);
58        cmd.arg("run")
59            .arg(source)
60            .stdout(Stdio::piped())
61            .stderr(Stdio::piped());
62        cmd.stdin(Stdio::inherit());
63        if let Some(dir) = source.parent() {
64            cmd.current_dir(dir);
65        }
66        cmd.output().with_context(|| {
67            format!(
68                "failed to execute {} with source {}",
69                executable.display(),
70                source.display()
71            )
72        })
73    }
74}
75
76impl LanguageEngine for ZigEngine {
77    fn id(&self) -> &'static str {
78        "zig"
79    }
80
81    fn display_name(&self) -> &'static str {
82        "Zig"
83    }
84
85    fn aliases(&self) -> &[&'static str] {
86        &["ziglang"]
87    }
88
89    fn supports_sessions(&self) -> bool {
90        self.executable.is_some()
91    }
92
93    fn validate(&self) -> Result<()> {
94        let executable = self.ensure_executable()?;
95        let mut cmd = Command::new(executable);
96        cmd.arg("version")
97            .stdout(Stdio::null())
98            .stderr(Stdio::null());
99        cmd.status()
100            .with_context(|| format!("failed to invoke {}", executable.display()))?
101            .success()
102            .then_some(())
103            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
104    }
105
106    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
107        // Try cache for inline/stdin payloads
108        if let Some(code) = match payload {
109            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
110                Some(code.as_str())
111            }
112            _ => None,
113        } {
114            let snippet = wrap_inline_snippet(code);
115            let src_hash = hash_source(&snippet);
116            if let Some(output) = try_cached_execution(src_hash) {
117                let start = Instant::now();
118                return Ok(ExecutionOutcome {
119                    language: self.id().to_string(),
120                    exit_code: output.status.code(),
121                    stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
122                    stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
123                    duration: start.elapsed(),
124                });
125            }
126        }
127
128        let start = Instant::now();
129        let (temp_dir, source_path, cache_key) = match payload {
130            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
131                let snippet = wrap_inline_snippet(code);
132                let h = hash_source(&snippet);
133                let (dir, path) = self.write_temp_source(&snippet)?;
134                (Some(dir), path, Some(h))
135            }
136            ExecutionPayload::File { path } => {
137                if path.extension().and_then(|e| e.to_str()) != Some("zig") {
138                    let code = std::fs::read_to_string(path)?;
139                    let (dir, new_path) = self.write_temp_source(&code)?;
140                    (Some(dir), new_path, None)
141                } else {
142                    (None, path.clone(), None)
143                }
144            }
145        };
146
147        // For cacheable code, try zig build-exe + cache
148        if let Some(h) = cache_key {
149            let executable = self.ensure_executable()?;
150            let dir = source_path.parent().unwrap_or(std::path::Path::new("."));
151            let bin_path = dir.join("snippet");
152            let mut build_cmd = Command::new(executable);
153            build_cmd
154                .arg("build-exe")
155                .arg(&source_path)
156                .arg("-femit-bin=snippet")
157                .stdout(Stdio::piped())
158                .stderr(Stdio::piped())
159                .current_dir(dir);
160
161            if let Ok(build_output) = build_cmd.output()
162                && build_output.status.success()
163                && bin_path.exists()
164            {
165                cache_store(h, &bin_path);
166                let mut run_cmd = Command::new(&bin_path);
167                run_cmd
168                    .stdout(Stdio::piped())
169                    .stderr(Stdio::piped())
170                    .stdin(Stdio::inherit());
171                if let Ok(output) = run_cmd.output() {
172                    drop(temp_dir);
173                    return Ok(ExecutionOutcome {
174                        language: self.id().to_string(),
175                        exit_code: output.status.code(),
176                        stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
177                        stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
178                        duration: start.elapsed(),
179                    });
180                }
181            }
182        }
183
184        // Fallback to zig run
185        let output = self.run_source(&source_path)?;
186        drop(temp_dir);
187
188        let mut combined_stdout = String::from_utf8_lossy(&output.stdout).into_owned();
189        let stderr_str = String::from_utf8_lossy(&output.stderr).into_owned();
190
191        if output.status.success() && !stderr_str.contains("error:") {
192            if !combined_stdout.is_empty() && !stderr_str.is_empty() {
193                combined_stdout.push_str(&stderr_str);
194            } else if combined_stdout.is_empty() {
195                combined_stdout = stderr_str.clone();
196            }
197        }
198
199        Ok(ExecutionOutcome {
200            language: self.id().to_string(),
201            exit_code: output.status.code(),
202            stdout: combined_stdout,
203            stderr: if output.status.success() && !stderr_str.contains("error:") {
204                String::new()
205            } else {
206                stderr_str
207            },
208            duration: start.elapsed(),
209        })
210    }
211
212    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
213        let executable = self.ensure_executable()?.to_path_buf();
214        Ok(Box::new(ZigSession::new(executable)?))
215    }
216}
217
218fn resolve_zig_binary() -> Option<PathBuf> {
219    which::which("zig").ok()
220}
221
/// Type suffixes accepted on numeric literals by `rewrite_numeric_suffixes` (`10u8`, `2.5f32`).
///
/// `match_suffix` returns the first entry that prefixes the remaining text. No entry is
/// currently a prefix of another, so order does not affect matching. If an entry such as `u1`
/// is ever added, put it after the longer suffixes it begins.
const ZIG_NUMERIC_SUFFIXES: [&str; 17] = [
    // pointer-sized
    "usize", "isize",
    // 128- and 80-bit
    "u128", "i128", "f128", "f80",
    // 64-bit
    "u64", "i64", "f64",
    // 32-bit
    "u32", "i32", "f32",
    // 16-bit
    "u16", "i16", "f16",
    // 8-bit
    "u8", "i8",
];
226
/// Turns an inline snippet into a complete Zig program.
///
/// Code that already contains `pub fn main` is returned unchanged apart from a trailing newline.
/// Any other non-blank code is indented into the body of a generated `pub fn main() !void`
/// with `std` imported, so bare statements such as `std.debug.print(...)` run directly.
/// Blank input yields a program with an empty `main`, which compiles and prints nothing.
///
/// NOTE(review): constructs that are only legal at container level (for example `fn`
/// declarations) end up nested inside `main` and will not compile. Such snippets need their
/// own `pub fn main`.
fn wrap_inline_snippet(code: &str) -> String {
    const PRELUDE: &str = "const std = @import(\"std\");\n\npub fn main() !void {\n";

    if code.contains("pub fn main") {
        let mut owned = code.to_string();
        if !owned.ends_with('\n') {
            owned.push('\n');
        }
        return owned;
    }

    let mut program = String::with_capacity(PRELUDE.len() + code.len() + 16);
    program.push_str(PRELUDE);
    // Blank input leaves `main` empty instead of emitting a file with no entry point.
    if !code.trim().is_empty() {
        // `lines()` strips terminators, so every emitted line gets its own '\n'.
        for line in code.lines() {
            program.push_str("    ");
            program.push_str(line);
            program.push('\n');
        }
    }
    program.push_str("}\n");
    program
}
251
252struct ZigSession {
253    executable: PathBuf,
254    workspace: TempDir,
255    items: Vec<String>,
256    statements: Vec<String>,
257    last_stdout: String,
258    last_stderr: String,
259}
260
/// How a REPL input is folded into the session program.
enum ZigSnippetKind {
    /// Stored at container level (`const`, `var`, `fn`, `pub`, …).
    Declaration,
    /// Stored verbatim inside `main`.
    Statement,
    /// Wrapped in a `std.debug.print("{any}\n", …)` call inside `main`.
    Expression,
}
266
267impl ZigSession {
268    fn new(executable: PathBuf) -> Result<Self> {
269        let workspace = TempDir::new().context("failed to create Zig session workspace")?;
270        let session = Self {
271            executable,
272            workspace,
273            items: Vec::new(),
274            statements: Vec::new(),
275            last_stdout: String::new(),
276            last_stderr: String::new(),
277        };
278        session.persist_source()?;
279        Ok(session)
280    }
281
282    fn source_path(&self) -> PathBuf {
283        self.workspace.path().join("session.zig")
284    }
285
286    fn persist_source(&self) -> Result<()> {
287        let source = self.render_source();
288        fs::write(self.source_path(), source)
289            .with_context(|| "failed to write Zig session source".to_string())
290    }
291
292    fn render_source(&self) -> String {
293        let mut source = String::from("const std = @import(\"std\");\n\n");
294
295        for item in &self.items {
296            source.push_str(item);
297            if !item.ends_with('\n') {
298                source.push('\n');
299            }
300            source.push('\n');
301        }
302
303        source.push_str("pub fn main() !void {\n");
304        if self.statements.is_empty() {
305            source.push_str("    return;\n");
306        } else {
307            for snippet in &self.statements {
308                for line in snippet.lines() {
309                    source.push_str("    ");
310                    source.push_str(line);
311                    source.push('\n');
312                }
313            }
314        }
315        source.push_str("}\n");
316
317        source
318    }
319
320    fn run_program(&self) -> Result<std::process::Output> {
321        let mut cmd = Command::new(&self.executable);
322        cmd.arg("run")
323            .arg("session.zig")
324            .stdout(Stdio::piped())
325            .stderr(Stdio::piped())
326            .current_dir(self.workspace.path());
327        cmd.output().with_context(|| {
328            format!(
329                "failed to execute {} for Zig session",
330                self.executable.display()
331            )
332        })
333    }
334
335    fn run_standalone_program(&self, code: &str) -> Result<ExecutionOutcome> {
336        let start = Instant::now();
337        let path = self.workspace.path().join("standalone.zig");
338        let mut contents = code.to_string();
339        if !contents.ends_with('\n') {
340            contents.push('\n');
341        }
342        fs::write(&path, contents)
343            .with_context(|| "failed to write Zig standalone source".to_string())?;
344
345        let mut cmd = Command::new(&self.executable);
346        cmd.arg("run")
347            .arg("standalone.zig")
348            .stdout(Stdio::piped())
349            .stderr(Stdio::piped())
350            .current_dir(self.workspace.path());
351        let output = cmd.output().with_context(|| {
352            format!(
353                "failed to execute {} for Zig standalone snippet",
354                self.executable.display()
355            )
356        })?;
357
358        let mut stdout = Self::normalize_output(&output.stdout);
359        let stderr = Self::normalize_output(&output.stderr);
360
361        if output.status.success() && !stderr.contains("error:") {
362            if stdout.is_empty() {
363                stdout = stderr.clone();
364            } else {
365                stdout.push_str(&stderr);
366            }
367        }
368
369        Ok(ExecutionOutcome {
370            language: self.language_id().to_string(),
371            exit_code: output.status.code(),
372            stdout,
373            stderr: if output.status.success() && !stderr.contains("error:") {
374                String::new()
375            } else {
376                stderr
377            },
378            duration: start.elapsed(),
379        })
380    }
381
382    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
383        self.persist_source()?;
384        let output = self.run_program()?;
385        let mut stdout_full = Self::normalize_output(&output.stdout);
386        let stderr_full = Self::normalize_output(&output.stderr);
387
388        let success = output.status.success();
389
390        if success && !stderr_full.is_empty() && !stderr_full.contains("error:") {
391            if stdout_full.is_empty() {
392                stdout_full = stderr_full.clone();
393            } else {
394                stdout_full.push_str(&stderr_full);
395            }
396        }
397
398        let (stdout, stderr) = if success {
399            let stdout_delta = Self::diff_outputs(&self.last_stdout, &stdout_full);
400            let stderr_clean = if !stderr_full.contains("error:") {
401                String::new()
402            } else {
403                stderr_full.clone()
404            };
405            let stderr_delta = Self::diff_outputs(&self.last_stderr, &stderr_clean);
406            self.last_stdout = stdout_full;
407            self.last_stderr = stderr_clean;
408            (stdout_delta, stderr_delta)
409        } else {
410            (stdout_full, stderr_full)
411        };
412
413        let outcome = ExecutionOutcome {
414            language: self.language_id().to_string(),
415            exit_code: output.status.code(),
416            stdout,
417            stderr,
418            duration: start.elapsed(),
419        };
420
421        Ok((outcome, success))
422    }
423
424    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
425        let normalized = normalize_snippet(code);
426        let mut snippet = normalized;
427        if !snippet.ends_with('\n') {
428            snippet.push('\n');
429        }
430        self.items.push(snippet);
431        let start = Instant::now();
432        let (outcome, success) = self.run_current(start)?;
433        if !success {
434            let _ = self.items.pop();
435            self.persist_source()?;
436        }
437        Ok((outcome, success))
438    }
439
440    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
441        let normalized = normalize_snippet(code);
442        let snippet = ensure_trailing_newline(&normalized);
443        self.statements.push(snippet);
444        let start = Instant::now();
445        let (outcome, success) = self.run_current(start)?;
446        if !success {
447            let _ = self.statements.pop();
448            self.persist_source()?;
449        }
450        Ok((outcome, success))
451    }
452
453    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
454        let normalized = normalize_snippet(code);
455        let wrapped = wrap_expression(&normalized);
456        self.statements.push(wrapped);
457        let start = Instant::now();
458        let (outcome, success) = self.run_current(start)?;
459        if !success {
460            let _ = self.statements.pop();
461            self.persist_source()?;
462        }
463        Ok((outcome, success))
464    }
465
466    fn reset(&mut self) -> Result<()> {
467        self.items.clear();
468        self.statements.clear();
469        self.last_stdout.clear();
470        self.last_stderr.clear();
471        self.persist_source()
472    }
473
474    fn normalize_output(bytes: &[u8]) -> String {
475        String::from_utf8_lossy(bytes)
476            .replace("\r\n", "\n")
477            .replace('\r', "")
478    }
479
480    fn diff_outputs(previous: &str, current: &str) -> String {
481        current
482            .strip_prefix(previous)
483            .map(|s| s.to_string())
484            .unwrap_or_else(|| current.to_string())
485    }
486}
487
488impl LanguageSession for ZigSession {
489    fn language_id(&self) -> &str {
490        "zig"
491    }
492
493    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
494        let trimmed = code.trim();
495        if trimmed.is_empty() {
496            return Ok(ExecutionOutcome {
497                language: self.language_id().to_string(),
498                exit_code: None,
499                stdout: String::new(),
500                stderr: String::new(),
501                duration: Duration::default(),
502            });
503        }
504
505        if trimmed.eq_ignore_ascii_case(":reset") {
506            self.reset()?;
507            return Ok(ExecutionOutcome {
508                language: self.language_id().to_string(),
509                exit_code: None,
510                stdout: String::new(),
511                stderr: String::new(),
512                duration: Duration::default(),
513            });
514        }
515
516        if trimmed.eq_ignore_ascii_case(":help") {
517            return Ok(ExecutionOutcome {
518                language: self.language_id().to_string(),
519                exit_code: None,
520                stdout:
521                    "Zig commands:\n  :reset - clear session state\n  :help  - show this message\n"
522                        .to_string(),
523                stderr: String::new(),
524                duration: Duration::default(),
525            });
526        }
527
528        if trimmed.contains("pub fn main") {
529            return self.run_standalone_program(code);
530        }
531
532        match classify_snippet(trimmed) {
533            ZigSnippetKind::Declaration => {
534                let (outcome, _) = self.apply_declaration(code)?;
535                Ok(outcome)
536            }
537            ZigSnippetKind::Statement => {
538                let (outcome, _) = self.apply_statement(code)?;
539                Ok(outcome)
540            }
541            ZigSnippetKind::Expression => {
542                let (outcome, _) = self.apply_expression(trimmed)?;
543                Ok(outcome)
544            }
545        }
546    }
547
548    fn shutdown(&mut self) -> Result<()> {
549        Ok(())
550    }
551}
552
553fn classify_snippet(code: &str) -> ZigSnippetKind {
554    if looks_like_declaration(code) {
555        ZigSnippetKind::Declaration
556    } else if looks_like_statement(code) {
557        ZigSnippetKind::Statement
558    } else {
559        ZigSnippetKind::Expression
560    }
561}
562
/// Returns `true` when the input, ignoring leading whitespace, starts with a keyword that
/// introduces a container-level declaration.
fn looks_like_declaration(code: &str) -> bool {
    const DECLARATION_PREFIXES: [&str; 8] = [
        "const ",
        "var ",
        "pub ",
        "fn ",
        "usingnamespace ",
        "extern ",
        "comptime ",
        "test ",
    ];
    let head = code.trim_start();
    DECLARATION_PREFIXES
        .iter()
        .any(|&prefix| head.starts_with(prefix))
}
577
/// Heuristic for input that should be stored verbatim inside `main`.
///
/// Matches multi-line input, input ending in `;`, `}` or `:` (ignoring trailing whitespace),
/// and input that starts with a comment marker.
fn looks_like_statement(code: &str) -> bool {
    let body = code.trim_end();
    let spans_lines = body.contains('\n');
    let has_terminator = body.ends_with([';', '}', ':']);
    let is_comment = body.starts_with("//") || body.starts_with("/*");
    spans_lines || has_terminator || is_comment
}
587
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
595
/// Wraps an expression in a `std.debug.print` call that prints its value with `{any}`.
fn wrap_expression(code: &str) -> String {
    let mut call = String::from("std.debug.print(\"{any}\\n\", .{ ");
    call.push_str(code);
    call.push_str(" });");
    call
}
599
600fn normalize_snippet(code: &str) -> String {
601    rewrite_numeric_suffixes(code)
602}
603
604fn rewrite_numeric_suffixes(code: &str) -> String {
605    let bytes = code.as_bytes();
606    let mut result = String::with_capacity(code.len());
607    let mut i = 0;
608    while i < bytes.len() {
609        let ch = bytes[i] as char;
610
611        if ch == '"' {
612            let (segment, advance) = extract_string_literal(&code[i..]);
613            result.push_str(segment);
614            i += advance;
615            continue;
616        }
617
618        if ch == '\'' {
619            let (segment, advance) = extract_char_literal(&code[i..]);
620            result.push_str(segment);
621            i += advance;
622            continue;
623        }
624
625        if ch == '/' && i + 1 < bytes.len() {
626            let next = bytes[i + 1] as char;
627            if next == '/' {
628                result.push_str(&code[i..]);
629                break;
630            }
631            if next == '*' {
632                let (segment, advance) = extract_block_comment(&code[i..]);
633                result.push_str(segment);
634                i += advance;
635                continue;
636            }
637        }
638
639        if ch.is_ascii_digit() {
640            if i > 0 {
641                let prev = bytes[i - 1] as char;
642                if prev.is_ascii_alphanumeric() || prev == '_' {
643                    result.push(ch);
644                    i += 1;
645                    continue;
646                }
647            }
648
649            let literal_end = scan_numeric_literal(bytes, i);
650            if literal_end > i {
651                if let Some((suffix, suffix_len)) = match_suffix(&code[literal_end..])
652                    && !is_identifier_char(bytes, literal_end + suffix_len)
653                {
654                    let literal = &code[i..literal_end];
655                    result.push_str("@as(");
656                    result.push_str(suffix);
657                    result.push_str(", ");
658                    result.push_str(literal);
659                    result.push(')');
660                    i = literal_end + suffix_len;
661                    continue;
662                }
663
664                result.push_str(&code[i..literal_end]);
665                i = literal_end;
666                continue;
667            }
668        }
669
670        result.push(ch);
671        i += 1;
672    }
673
674    if result.len() == code.len() {
675        code.to_string()
676    } else {
677        result
678    }
679}
680
/// Splits off the double-quoted string literal at the start of `source`.
///
/// `source` must begin with `"`. Backslash escapes are skipped as two-byte units. Returns the
/// literal (including both quotes when terminated) and its byte length. An unterminated
/// literal extends to the end of `source`.
fn extract_string_literal(source: &str) -> (&str, usize) {
    let bytes = source.as_bytes();
    let mut end = 1; // skip the opening quote
    while end < bytes.len() {
        match bytes[end] {
            b'\\' => end += 2,
            b'"' => {
                end += 1;
                break;
            }
            _ => end += 1,
        }
    }
    // A trailing backslash steps one byte past the end; clamp so the slice stays in bounds.
    let end = end.min(bytes.len());
    (&source[..end], end)
}
698
/// Splits off the single-quoted character literal at the start of `source`.
///
/// `source` must begin with `'`. Backslash escapes are skipped as two-byte units. Returns the
/// literal (including both quotes when terminated) and its byte length. An unterminated
/// literal extends to the end of `source`.
fn extract_char_literal(source: &str) -> (&str, usize) {
    let bytes = source.as_bytes();
    let mut end = 1; // skip the opening quote
    while end < bytes.len() {
        match bytes[end] {
            b'\\' => end += 2,
            b'\'' => {
                end += 1;
                break;
            }
            _ => end += 1,
        }
    }
    // A trailing backslash steps one byte past the end; clamp so the slice stays in bounds.
    let end = end.min(bytes.len());
    (&source[..end], end)
}
716
/// Splits off the `/* … */` span at the start of `source`, which must begin with `/*`.
///
/// Returns the span (through the closing `*/`) and its byte length. An unterminated span
/// extends to the end of `source`.
fn extract_block_comment(source: &str) -> (&str, usize) {
    // The opener and closer are two bytes each, hence `offset + 4`.
    let end = match source[2..].find("*/") {
        Some(offset) => offset + 4,
        None => source.len(),
    };
    (&source[..end], end)
}
725
/// Returns the byte index just past the numeric literal beginning at `start`.
///
/// Recognizes `0x`/`0o`/`0b` integers and decimal literals with `_` separators, at most one
/// fractional `.` (only when a digit follows it, so `0..n` stops at `0`), and `e`/`E`/`p`/`P`
/// exponents with an optional sign. An exponent marker without digits ends the literal before
/// the marker. Returns `start` unchanged when it is out of range.
fn scan_numeric_literal(bytes: &[u8], start: usize) -> usize {
    let len = bytes.len();
    if start >= len {
        return start;
    }

    // Radix-prefixed integers consume only digits valid for that radix (plus `_`).
    if bytes[start] == b'0' && start + 1 < len {
        let radix: Option<u32> = match bytes[start + 1] {
            b'x' | b'X' => Some(16),
            b'o' | b'O' => Some(8),
            b'b' | b'B' => Some(2),
            _ => None,
        };
        if let Some(radix) = radix {
            let digits = bytes[start + 2..]
                .iter()
                .take_while(|&&b| b == b'_' || (b as char).is_digit(radix))
                .count();
            return start + 2 + digits;
        }
    }

    let mut pos = start;
    let mut seen_dot = false;
    while let Some(&b) = bytes.get(pos) {
        if b.is_ascii_digit() || b == b'_' {
            pos += 1;
        } else if b == b'.' && !seen_dot {
            match bytes.get(pos + 1) {
                Some(next) if next.is_ascii_digit() => {
                    seen_dot = true;
                    pos += 1;
                }
                _ => break,
            }
        } else if matches!(b, b'e' | b'E' | b'p' | b'P') {
            let mut cursor = pos + 1;
            if matches!(bytes.get(cursor), Some(b'+' | b'-')) {
                cursor += 1;
            }
            let exponent_len = bytes.get(cursor..).map_or(0, |rest| {
                rest.iter()
                    .take_while(|&&c| c.is_ascii_digit() || c == b'_')
                    .count()
            });
            if exponent_len == 0 {
                break;
            }
            pos = cursor + exponent_len;
        } else {
            break;
        }
    }

    pos
}
809
810fn match_suffix(rest: &str) -> Option<(&'static str, usize)> {
811    for &suffix in &ZIG_NUMERIC_SUFFIXES {
812        if rest.starts_with(suffix) {
813            return Some((suffix, suffix.len()));
814        }
815    }
816    None
817}
818
/// Returns `true` when `bytes[index]` exists and is an ASCII letter, digit, or `_`.
fn is_identifier_char(bytes: &[u8], index: usize) -> bool {
    bytes
        .get(index)
        .is_some_and(|&b| b.is_ascii_alphanumeric() || b == b'_')
}