Skip to main content

run/engine/
csharp.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Stdio};
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result, bail};
7use tempfile::{Builder, TempDir};
8
9use super::{
10    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, run_version_command,
11};
12
13pub struct CSharpEngine {
14    runtime: Option<PathBuf>,
15    target_framework: Option<String>,
16}
17
18impl Default for CSharpEngine {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl CSharpEngine {
25    pub fn new() -> Self {
26        let runtime = resolve_dotnet_runtime();
27        let target_framework = runtime
28            .as_ref()
29            .and_then(|path| detect_target_framework(path).ok());
30        Self {
31            runtime,
32            target_framework,
33        }
34    }
35
36    fn ensure_runtime(&self) -> Result<&Path> {
37        self.runtime.as_deref().ok_or_else(|| {
38            anyhow::anyhow!(
39                "C# support requires the `dotnet` CLI. Install the .NET SDK from https://dotnet.microsoft.com/download and ensure `dotnet` is on your PATH."
40            )
41        })
42    }
43
44    fn ensure_target_framework(&self) -> Result<&str> {
45        self.target_framework
46            .as_deref()
47            .ok_or_else(|| anyhow::anyhow!("Unable to detect installed .NET SDK target framework"))
48    }
49
50    fn prepare_source(&self, payload: &ExecutionPayload, dir: &Path) -> Result<PathBuf> {
51        let target = dir.join("Program.cs");
52        match payload {
53            ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
54                let mut contents = code.to_string();
55                if !contents.ends_with('\n') {
56                    contents.push('\n');
57                }
58                fs::write(&target, contents).with_context(|| {
59                    format!(
60                        "failed to write temporary C# source to {}",
61                        target.display()
62                    )
63                })?;
64            }
65            ExecutionPayload::File { path, .. } => {
66                fs::copy(path, &target).with_context(|| {
67                    format!(
68                        "failed to copy C# source from {} to {}",
69                        path.display(),
70                        target.display()
71                    )
72                })?;
73            }
74        }
75        Ok(target)
76    }
77
78    fn write_project_file(&self, dir: &Path, tfm: &str) -> Result<PathBuf> {
79        let project_path = dir.join("Run.csproj");
80        let contents = format!(
81            r#"<Project Sdk="Microsoft.NET.Sdk">
82  <PropertyGroup>
83    <OutputType>Exe</OutputType>
84    <TargetFramework>{}</TargetFramework>
85    <ImplicitUsings>enable</ImplicitUsings>
86    <Nullable>disable</Nullable>
87        <NoWarn>CS0219;CS8321</NoWarn>
88  </PropertyGroup>
89</Project>
90"#,
91            tfm
92        );
93        fs::write(&project_path, contents).with_context(|| {
94            format!(
95                "failed to write temporary C# project file to {}",
96                project_path.display()
97            )
98        })?;
99        Ok(project_path)
100    }
101
102    fn run_project(
103        &self,
104        runtime: &Path,
105        project: &Path,
106        workdir: &Path,
107        args: &[String],
108    ) -> Result<std::process::Output> {
109        let mut cmd = Command::new(runtime);
110        cmd.arg("run")
111            .arg("--project")
112            .arg(project)
113            .arg("--nologo")
114            .stdout(Stdio::piped())
115            .stderr(Stdio::piped())
116            .current_dir(workdir);
117        if !args.is_empty() {
118            cmd.arg("--").args(args);
119        }
120        cmd.stdin(Stdio::inherit());
121        cmd.env("DOTNET_CLI_TELEMETRY_OPTOUT", "1");
122        cmd.env("DOTNET_SKIP_FIRST_TIME_EXPERIENCE", "1");
123        cmd.output().with_context(|| {
124            format!(
125                "failed to execute dotnet run for project {} using {}",
126                project.display(),
127                runtime.display()
128            )
129        })
130    }
131}
132
133impl LanguageEngine for CSharpEngine {
134    fn id(&self) -> &'static str {
135        "csharp"
136    }
137
138    fn display_name(&self) -> &'static str {
139        "C#"
140    }
141
142    fn aliases(&self) -> &[&'static str] {
143        &["cs", "c#", "dotnet"]
144    }
145
146    fn supports_sessions(&self) -> bool {
147        self.runtime.is_some() && self.target_framework.is_some()
148    }
149
150    fn validate(&self) -> Result<()> {
151        let runtime = self.ensure_runtime()?;
152        let _tfm = self.ensure_target_framework()?;
153
154        let mut cmd = Command::new(runtime);
155        cmd.arg("--version")
156            .stdout(Stdio::null())
157            .stderr(Stdio::null());
158        cmd.status()
159            .with_context(|| format!("failed to invoke {}", runtime.display()))?
160            .success()
161            .then_some(())
162            .ok_or_else(|| anyhow::anyhow!("{} is not executable", runtime.display()))
163    }
164
165    fn toolchain_version(&self) -> Result<Option<String>> {
166        let runtime = self.ensure_runtime()?;
167        let mut cmd = Command::new(runtime);
168        cmd.arg("--version");
169        let context = format!("{}", runtime.display());
170        run_version_command(cmd, &context)
171    }
172
173    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
174        let runtime = self.ensure_runtime()?;
175        let tfm = self.ensure_target_framework()?;
176
177        let build_dir = Builder::new()
178            .prefix("run-csharp")
179            .tempdir()
180            .context("failed to create temporary directory for csharp build")?;
181        let dir_path = build_dir.path();
182
183        self.write_project_file(dir_path, tfm)?;
184        self.prepare_source(payload, dir_path)?;
185
186        let project_path = dir_path.join("Run.csproj");
187        let start = Instant::now();
188
189        let output = self.run_project(runtime, &project_path, dir_path, payload.args())?;
190
191        Ok(ExecutionOutcome {
192            language: self.id().to_string(),
193            exit_code: output.status.code(),
194            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
195            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
196            duration: start.elapsed(),
197        })
198    }
199
200    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
201        let runtime = self.ensure_runtime()?.to_path_buf();
202        let tfm = self.ensure_target_framework()?.to_string();
203
204        let dir = Builder::new()
205            .prefix("run-csharp-repl")
206            .tempdir()
207            .context("failed to create temporary directory for csharp repl")?;
208        let dir_path = dir.path();
209
210        let project_path = self.write_project_file(dir_path, &tfm)?;
211        let program_path = dir_path.join("Program.cs");
212        fs::write(&program_path, "// C# REPL session\n")
213            .with_context(|| format!("failed to initialize {}", program_path.display()))?;
214
215        Ok(Box::new(CSharpSession {
216            runtime,
217            dir,
218            project_path,
219            program_path,
220            snippets: Vec::new(),
221            previous_stdout: String::new(),
222            previous_stderr: String::new(),
223        }))
224    }
225}
226
227struct CSharpSession {
228    runtime: PathBuf,
229    dir: TempDir,
230    project_path: PathBuf,
231    program_path: PathBuf,
232    snippets: Vec<String>,
233    previous_stdout: String,
234    previous_stderr: String,
235}
236
237impl CSharpSession {
238    fn render_source(&self) -> String {
239        let mut source = String::from(
240            "using System;\nusing System.Collections.Generic;\nusing System.Linq;\nusing System.Text;\nusing System.Threading.Tasks;\n#nullable disable\n\nstatic void __run_print(object value)\n{\n    if (value is null)\n    {\n        Console.WriteLine(\"null\");\n        return;\n    }\n\n    if (value is string s)\n    {\n        Console.WriteLine(s);\n        return;\n    }\n\n    // Pretty-print enumerables: [a, b, c]\n    if (value is System.Collections.IEnumerable enumerable && value is not string)\n    {\n        var sb = new StringBuilder();\n        sb.Append('[');\n        var first = true;\n        foreach (var item in enumerable)\n        {\n            if (!first) sb.Append(\", \");\n            first = false;\n            sb.Append(item is null ? \"null\" : item.ToString());\n        }\n        sb.Append(']');\n        Console.WriteLine(sb.ToString());\n        return;\n    }\n\n    Console.WriteLine(value);\n}\n",
241        );
242        for snippet in &self.snippets {
243            source.push_str(snippet);
244            if !snippet.ends_with('\n') {
245                source.push('\n');
246            }
247        }
248        source
249    }
250
251    fn write_source(&self, contents: &str) -> Result<()> {
252        fs::write(&self.program_path, contents).with_context(|| {
253            format!(
254                "failed to write generated C# REPL source to {}",
255                self.program_path.display()
256            )
257        })
258    }
259
260    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
261        let source = self.render_source();
262        self.write_source(&source)?;
263
264        let output = run_dotnet_project(&self.runtime, &self.project_path, self.dir.path())?;
265        let stdout_full = String::from_utf8_lossy(&output.stdout).into_owned();
266        let stderr_full = String::from_utf8_lossy(&output.stderr).into_owned();
267
268        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
269        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
270
271        let success = output.status.success();
272        if success {
273            self.previous_stdout = stdout_full;
274            self.previous_stderr = stderr_full;
275        }
276
277        let outcome = ExecutionOutcome {
278            language: "csharp".to_string(),
279            exit_code: output.status.code(),
280            stdout: stdout_delta,
281            stderr: stderr_delta,
282            duration: start.elapsed(),
283        };
284
285        Ok((outcome, success))
286    }
287
288    fn run_snippet(&mut self, snippet: String) -> Result<ExecutionOutcome> {
289        self.snippets.push(snippet);
290        let start = Instant::now();
291        let (outcome, success) = self.run_current(start)?;
292        if !success {
293            let _ = self.snippets.pop();
294        }
295        Ok(outcome)
296    }
297
298    fn reset_state(&mut self) -> Result<()> {
299        self.snippets.clear();
300        self.previous_stdout.clear();
301        self.previous_stderr.clear();
302        let source = self.render_source();
303        self.write_source(&source)
304    }
305}
306
307impl LanguageSession for CSharpSession {
308    fn language_id(&self) -> &str {
309        "csharp"
310    }
311
312    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
313        let trimmed = code.trim();
314        if trimmed.is_empty() {
315            return Ok(ExecutionOutcome {
316                language: self.language_id().to_string(),
317                exit_code: None,
318                stdout: String::new(),
319                stderr: String::new(),
320                duration: Instant::now().elapsed(),
321            });
322        }
323
324        if trimmed.eq_ignore_ascii_case(":reset") {
325            self.reset_state()?;
326            return Ok(ExecutionOutcome {
327                language: self.language_id().to_string(),
328                exit_code: None,
329                stdout: String::new(),
330                stderr: String::new(),
331                duration: Duration::default(),
332            });
333        }
334
335        if trimmed.eq_ignore_ascii_case(":help") {
336            return Ok(ExecutionOutcome {
337                language: self.language_id().to_string(),
338                exit_code: None,
339                stdout:
340                    "C# commands:\n  :reset - clear session state\n  :help  - show this message\n"
341                        .to_string(),
342                stderr: String::new(),
343                duration: Duration::default(),
344            });
345        }
346
347        if should_treat_as_expression(trimmed) {
348            let snippet = wrap_expression(trimmed, self.snippets.len());
349            let outcome = self.run_snippet(snippet)?;
350            if outcome.exit_code.unwrap_or(0) == 0 {
351                return Ok(outcome);
352            }
353        }
354
355        let snippet = prepare_statement(code);
356        let outcome = self.run_snippet(snippet)?;
357        Ok(outcome)
358    }
359
360    fn shutdown(&mut self) -> Result<()> {
361        Ok(())
362    }
363}
364
365fn diff_output(previous: &str, current: &str) -> String {
366    if let Some(stripped) = current.strip_prefix(previous) {
367        stripped.to_string()
368    } else {
369        current.to_string()
370    }
371}
372
373fn should_treat_as_expression(code: &str) -> bool {
374    let trimmed = code.trim();
375    if trimmed.is_empty() {
376        return false;
377    }
378    if trimmed.contains('\n') {
379        return false;
380    }
381
382    let trimmed = trimmed.trim_end();
383    let without_trailing_semicolon = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
384    if without_trailing_semicolon.is_empty() {
385        return false;
386    }
387    if without_trailing_semicolon.contains(';') {
388        return false;
389    }
390
391    let lowered = without_trailing_semicolon.to_ascii_lowercase();
392    const KEYWORDS: [&str; 17] = [
393        "using ",
394        "namespace ",
395        "class ",
396        "struct ",
397        "record ",
398        "enum ",
399        "interface ",
400        "public ",
401        "private ",
402        "protected ",
403        "internal ",
404        "static ",
405        "if ",
406        "for ",
407        "while ",
408        "switch ",
409        "try ",
410    ];
411    if KEYWORDS.iter().any(|kw| lowered.starts_with(kw)) {
412        return false;
413    }
414    if lowered.starts_with("return ") || lowered.starts_with("throw ") {
415        return false;
416    }
417    if without_trailing_semicolon.starts_with("Console.")
418        || without_trailing_semicolon.starts_with("System.Console.")
419    {
420        return false;
421    }
422
423    if lowered.starts_with("new ") {
424        return true;
425    }
426
427    if without_trailing_semicolon.contains("++") || without_trailing_semicolon.contains("--") {
428        return false;
429    }
430
431    if without_trailing_semicolon.contains('=')
432        && !without_trailing_semicolon.contains("==")
433        && !without_trailing_semicolon.contains("!=")
434        && !without_trailing_semicolon.contains("<=")
435        && !without_trailing_semicolon.contains(">=")
436        && !without_trailing_semicolon.contains("=>")
437    {
438        return false;
439    }
440
441    const DECL_PREFIXES: [&str; 19] = [
442        "var ", "bool ", "byte ", "sbyte ", "char ", "short ", "ushort ", "int ", "uint ", "long ",
443        "ulong ", "float ", "double ", "decimal ", "string ", "object ", "dynamic ", "nint ",
444        "nuint ",
445    ];
446    if DECL_PREFIXES
447        .iter()
448        .any(|prefix| lowered.starts_with(prefix))
449    {
450        return false;
451    }
452
453    let expr = without_trailing_semicolon;
454
455    if expr == "true" || expr == "false" {
456        return true;
457    }
458    if expr.parse::<f64>().is_ok() {
459        return true;
460    }
461    if (expr.starts_with('"') || expr.starts_with("$\"")) && expr.ends_with('"') && expr.len() >= 2
462    {
463        return true;
464    }
465    if expr.starts_with('\'') && expr.ends_with('\'') && expr.len() >= 2 {
466        return true;
467    }
468
469    if expr.contains('(') && expr.ends_with(')') {
470        return true;
471    }
472
473    if expr.contains('[') && expr.ends_with(']') {
474        return true;
475    }
476
477    if expr.contains('.')
478        && expr
479            .chars()
480            .all(|c| !c.is_whitespace() && c != '{' && c != '}' && c != ';')
481        && expr
482            .chars()
483            .last()
484            .is_some_and(|c| c.is_ascii_alphanumeric() || c == '_')
485    {
486        return true;
487    }
488
489    if expr.contains("==")
490        || expr.contains("!=")
491        || expr.contains("<=")
492        || expr.contains(">=")
493        || expr.contains("&&")
494        || expr.contains("||")
495    {
496        return true;
497    }
498    if expr.contains('?') && expr.contains(':') {
499        return true;
500    }
501    if expr.chars().any(|c| "+-*/%<>^|&".contains(c)) {
502        return true;
503    }
504
505    if expr
506        .chars()
507        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
508    {
509        return true;
510    }
511
512    false
513}
514
515fn wrap_expression(code: &str, index: usize) -> String {
516    let expr = code.trim().trim_end_matches(';').trim_end();
517    let expr = match expr {
518        "null" => "(object)null",
519        "default" => "(object)null",
520        other => other,
521    };
522    format!("var __repl_val_{index} = ({expr});\n__run_print(__repl_val_{index});\n")
523}
524
525fn prepare_statement(code: &str) -> String {
526    let trimmed_end = code.trim_end_matches(['\r', '\n']);
527    if trimmed_end.contains('\n') {
528        let mut snippet = trimmed_end.to_string();
529        if !snippet.ends_with('\n') {
530            snippet.push('\n');
531        }
532        return snippet;
533    }
534
535    let line = trimmed_end.trim();
536    if line.is_empty() {
537        return "\n".to_string();
538    }
539
540    let lowered = line.to_ascii_lowercase();
541    let starts_with_control = [
542        "if ",
543        "for ",
544        "while ",
545        "switch ",
546        "try",
547        "catch",
548        "finally",
549        "else",
550        "do",
551        "using ",
552        "namespace ",
553        "class ",
554        "struct ",
555        "record ",
556        "enum ",
557        "interface ",
558    ]
559    .iter()
560    .any(|kw| lowered.starts_with(kw));
561
562    let looks_like_expr_stmt = line.ends_with("++")
563        || line.ends_with("--")
564        || line.starts_with("++")
565        || line.starts_with("--")
566        || line.contains('=')
567        || (line.contains('(') && line.ends_with(')'));
568
569    let mut snippet = String::new();
570    snippet.push_str(line);
571    if !line.ends_with(';') && !starts_with_control && looks_like_expr_stmt {
572        snippet.push(';');
573    }
574    snippet.push('\n');
575    snippet
576}
577
578fn resolve_dotnet_runtime() -> Option<PathBuf> {
579    which::which("dotnet").ok()
580}
581
582fn detect_target_framework(dotnet: &Path) -> Result<String> {
583    let output = Command::new(dotnet)
584        .arg("--list-sdks")
585        .stdout(Stdio::piped())
586        .stderr(Stdio::null())
587        .output()
588        .with_context(|| format!("failed to query SDKs via {}", dotnet.display()))?;
589
590    if !output.status.success() {
591        bail!(
592            "{} --list-sdks exited with status {}",
593            dotnet.display(),
594            output.status
595        );
596    }
597
598    let stdout = String::from_utf8_lossy(&output.stdout);
599    let mut best: Option<(u32, u32, String)> = None;
600
601    for line in stdout.lines() {
602        let version = line.split_whitespace().next().unwrap_or("");
603        if version.is_empty() {
604            continue;
605        }
606        if let Some((major, minor)) = parse_version(version) {
607            let tfm = format!("net{}.{}", major, minor);
608            match &best {
609                Some((b_major, b_minor, _)) if (*b_major, *b_minor) >= (major, minor) => {}
610                _ => best = Some((major, minor, tfm)),
611            }
612        }
613    }
614
615    best.map(|(_, _, tfm)| tfm).ok_or_else(|| {
616        anyhow::anyhow!("unable to infer target framework from dotnet --list-sdks output")
617    })
618}
619
620fn parse_version(version: &str) -> Option<(u32, u32)> {
621    let mut parts = version.split('.');
622    let major = parts.next()?.parse().ok()?;
623    let minor = parts.next().unwrap_or("0").parse().ok()?;
624    Some((major, minor))
625}
626
627fn run_dotnet_project(
628    runtime: &Path,
629    project: &Path,
630    workdir: &Path,
631) -> Result<std::process::Output> {
632    let mut cmd = Command::new(runtime);
633    cmd.arg("run")
634        .arg("--project")
635        .arg(project)
636        .arg("--nologo")
637        .stdout(Stdio::piped())
638        .stderr(Stdio::piped())
639        .current_dir(workdir);
640    cmd.env("DOTNET_CLI_TELEMETRY_OPTOUT", "1");
641    cmd.env("DOTNET_SKIP_FIRST_TIME_EXPERIENCE", "1");
642    cmd.output().with_context(|| {
643        format!(
644            "failed to execute dotnet run for project {} using {}",
645            project.display(),
646            runtime.display()
647        )
648    })
649}