run/engine/haskell.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// Language engine that executes Haskell programs through `runghc`.
pub struct HaskellEngine {
    /// Location of `runghc`, or `None` when it was not found on `PATH`.
    executable: Option<PathBuf>,
}
15
16impl HaskellEngine {
17    pub fn new() -> Self {
18        Self {
19            executable: resolve_runghc_binary(),
20        }
21    }
22
23    fn ensure_executable(&self) -> Result<&Path> {
24        self.executable.as_deref().ok_or_else(|| {
25            anyhow::anyhow!(
26                "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
27            )
28        })
29    }
30
31    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
32        let dir = Builder::new()
33            .prefix("run-haskell")
34            .tempdir()
35            .context("failed to create temporary directory for Haskell source")?;
36        let path = dir.path().join("snippet.hs");
37        let mut contents = code.to_string();
38        if !contents.ends_with('\n') {
39            contents.push('\n');
40        }
41        fs::write(&path, contents).with_context(|| {
42            format!(
43                "failed to write temporary Haskell source to {}",
44                path.display()
45            )
46        })?;
47        Ok((dir, path))
48    }
49
50    fn execute_path(&self, path: &Path) -> Result<std::process::Output> {
51        let executable = self.ensure_executable()?;
52        let mut cmd = Command::new(executable);
53        cmd.arg(path).stdout(Stdio::piped()).stderr(Stdio::piped());
54        cmd.stdin(Stdio::inherit());
55        if let Some(parent) = path.parent() {
56            cmd.current_dir(parent);
57        }
58        cmd.output().with_context(|| {
59            format!(
60                "failed to execute {} with script {}",
61                executable.display(),
62                path.display()
63            )
64        })
65    }
66}
67
68impl LanguageEngine for HaskellEngine {
69    fn id(&self) -> &'static str {
70        "haskell"
71    }
72
73    fn display_name(&self) -> &'static str {
74        "Haskell"
75    }
76
77    fn aliases(&self) -> &[&'static str] {
78        &["hs", "ghci"]
79    }
80
81    fn supports_sessions(&self) -> bool {
82        self.executable.is_some()
83    }
84
85    fn validate(&self) -> Result<()> {
86        let executable = self.ensure_executable()?;
87        let mut cmd = Command::new(executable);
88        cmd.arg("--version")
89            .stdout(Stdio::null())
90            .stderr(Stdio::null());
91        cmd.status()
92            .with_context(|| format!("failed to invoke {}", executable.display()))?
93            .success()
94            .then_some(())
95            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
96    }
97
98    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
99        let start = Instant::now();
100        let (temp_dir, path) = match payload {
101            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
102                let (dir, path) = self.write_temp_source(code)?;
103                (Some(dir), path)
104            }
105            ExecutionPayload::File { path } => (None, path.clone()),
106        };
107
108        let output = self.execute_path(&path)?;
109        drop(temp_dir);
110
111        Ok(ExecutionOutcome {
112            language: self.id().to_string(),
113            exit_code: output.status.code(),
114            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
115            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
116            duration: start.elapsed(),
117        })
118    }
119
120    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
121        let executable = self.ensure_executable()?.to_path_buf();
122        Ok(Box::new(HaskellSession::new(executable)?))
123    }
124}
125
126fn resolve_runghc_binary() -> Option<PathBuf> {
127    which::which("runghc").ok()
128}
129
/// Snippets accepted so far in a Haskell session, grouped by where they are
/// placed in the generated program.
#[derive(Default)]
struct HaskellSessionState {
    /// `import` lines; the `BTreeSet` de-duplicates them and yields them sorted.
    imports: BTreeSet<String>,
    /// Top-level declarations, in submission order.
    declarations: Vec<String>,
    /// Already-indented lines for `main`'s `do` block, in submission order.
    statements: Vec<String>,
}
136
137struct HaskellSession {
138    executable: PathBuf,
139    workspace: TempDir,
140    state: HaskellSessionState,
141    previous_stdout: String,
142    previous_stderr: String,
143}
144
145impl HaskellSession {
146    fn new(executable: PathBuf) -> Result<Self> {
147        let workspace = Builder::new()
148            .prefix("run-haskell-repl")
149            .tempdir()
150            .context("failed to create temporary directory for Haskell repl")?;
151        let session = Self {
152            executable,
153            workspace,
154            state: HaskellSessionState::default(),
155            previous_stdout: String::new(),
156            previous_stderr: String::new(),
157        };
158        session.persist_source()?;
159        Ok(session)
160    }
161
162    fn source_path(&self) -> PathBuf {
163        self.workspace.path().join("session.hs")
164    }
165
166    fn persist_source(&self) -> Result<()> {
167        let source = self.render_source();
168        fs::write(self.source_path(), source)
169            .with_context(|| "failed to write Haskell session source".to_string())
170    }
171
172    fn render_source(&self) -> String {
173        let mut source = String::new();
174        source.push_str("import Prelude\n");
175        for import in &self.state.imports {
176            source.push_str(import);
177            if !import.ends_with('\n') {
178                source.push('\n');
179            }
180        }
181        source.push('\n');
182
183        for decl in &self.state.declarations {
184            source.push_str(decl);
185            if !decl.ends_with('\n') {
186                source.push('\n');
187            }
188            source.push('\n');
189        }
190
191        source.push_str("main :: IO ()\n");
192        source.push_str("main = do\n");
193        if self.state.statements.is_empty() {
194            source.push_str("    return ()\n");
195        } else {
196            for stmt in &self.state.statements {
197                source.push_str(stmt);
198                if !stmt.ends_with('\n') {
199                    source.push('\n');
200                }
201            }
202
203            if let Some(last) = self.state.statements.last() {
204                if last.trim().starts_with("let ") {
205                    source.push_str("    return ()\n");
206                }
207            }
208        }
209
210        source
211    }
212
213    fn run_program(&self) -> Result<std::process::Output> {
214        let mut cmd = Command::new(&self.executable);
215        cmd.arg("session.hs")
216            .stdout(Stdio::piped())
217            .stderr(Stdio::piped())
218            .current_dir(self.workspace.path());
219        cmd.output().with_context(|| {
220            format!(
221                "failed to execute {} for Haskell session",
222                self.executable.display()
223            )
224        })
225    }
226
227    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
228        self.persist_source()?;
229        let output = self.run_program()?;
230        let stdout_full = normalize_output(&output.stdout);
231        let stderr_full = normalize_output(&output.stderr);
232
233        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
234        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
235
236        let success = output.status.success();
237        if success {
238            self.previous_stdout = stdout_full;
239            self.previous_stderr = stderr_full;
240        }
241
242        let outcome = ExecutionOutcome {
243            language: "haskell".to_string(),
244            exit_code: output.status.code(),
245            stdout: stdout_delta,
246            stderr: stderr_delta,
247            duration: start.elapsed(),
248        };
249
250        Ok((outcome, success))
251    }
252
253    fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
254        let mut inserted = Vec::new();
255        for line in code.lines() {
256            let trimmed = line.trim();
257            if trimmed.is_empty() {
258                continue;
259            }
260            let normalized = trimmed.to_string();
261            if self.state.imports.insert(normalized.clone()) {
262                inserted.push(normalized);
263            }
264        }
265
266        if inserted.is_empty() {
267            return Ok((
268                ExecutionOutcome {
269                    language: "haskell".to_string(),
270                    exit_code: None,
271                    stdout: String::new(),
272                    stderr: String::new(),
273                    duration: Duration::default(),
274                },
275                true,
276            ));
277        }
278
279        let start = Instant::now();
280        let (outcome, success) = self.run_current(start)?;
281        if !success {
282            for item in inserted {
283                self.state.imports.remove(&item);
284            }
285            self.persist_source()?;
286        }
287        Ok((outcome, success))
288    }
289
290    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
291        let snippet = ensure_trailing_newline(code);
292        self.state.declarations.push(snippet);
293        let start = Instant::now();
294        let (outcome, success) = self.run_current(start)?;
295        if !success {
296            let _ = self.state.declarations.pop();
297            self.persist_source()?;
298        }
299        Ok((outcome, success))
300    }
301
302    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
303        let snippet = indent_block(code);
304        self.state.statements.push(snippet);
305        let start = Instant::now();
306        let (outcome, success) = self.run_current(start)?;
307        if !success {
308            let _ = self.state.statements.pop();
309            self.persist_source()?;
310        }
311        Ok((outcome, success))
312    }
313
314    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
315        let wrapped = wrap_expression(code);
316        self.state.statements.push(wrapped);
317        let start = Instant::now();
318        let (outcome, success) = self.run_current(start)?;
319        if !success {
320            let _ = self.state.statements.pop();
321            self.persist_source()?;
322        }
323        Ok((outcome, success))
324    }
325
326    fn reset(&mut self) -> Result<()> {
327        self.state.imports.clear();
328        self.state.declarations.clear();
329        self.state.statements.clear();
330        self.previous_stdout.clear();
331        self.previous_stderr.clear();
332        self.persist_source()
333    }
334}
335
336impl LanguageSession for HaskellSession {
337    fn language_id(&self) -> &str {
338        "haskell"
339    }
340
341    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
342        let trimmed = code.trim();
343        if trimmed.is_empty() {
344            return Ok(ExecutionOutcome {
345                language: "haskell".to_string(),
346                exit_code: None,
347                stdout: String::new(),
348                stderr: String::new(),
349                duration: Duration::default(),
350            });
351        }
352
353        if trimmed.eq_ignore_ascii_case(":reset") {
354            self.reset()?;
355            return Ok(ExecutionOutcome {
356                language: "haskell".to_string(),
357                exit_code: None,
358                stdout: String::new(),
359                stderr: String::new(),
360                duration: Duration::default(),
361            });
362        }
363
364        if trimmed.eq_ignore_ascii_case(":help") {
365            return Ok(ExecutionOutcome {
366                language: "haskell".to_string(),
367                exit_code: None,
368                stdout: "Haskell commands:\n  :reset — clear session state\n  :help  — show this message\n"
369                    .to_string(),
370                stderr: String::new(),
371                duration: Duration::default(),
372            });
373        }
374
375        match classify_snippet(trimmed) {
376            HaskellSnippet::Import => {
377                let (outcome, _) = self.apply_import(code)?;
378                Ok(outcome)
379            }
380            HaskellSnippet::Declaration => {
381                let (outcome, _) = self.apply_declaration(code)?;
382                Ok(outcome)
383            }
384            HaskellSnippet::Expression => {
385                let (outcome, _) = self.apply_expression(trimmed)?;
386                Ok(outcome)
387            }
388            HaskellSnippet::Statement => {
389                let (outcome, _) = self.apply_statement(code)?;
390                Ok(outcome)
391            }
392        }
393    }
394
395    fn shutdown(&mut self) -> Result<()> {
396        Ok(())
397    }
398}
399
/// How a REPL snippet is incorporated into the generated session program.
enum HaskellSnippet {
    /// One or more `import` lines, placed in the module header.
    Import,
    /// A top-level declaration, placed before `main`.
    Declaration,
    /// An action indented into `main`'s `do` block as written.
    Statement,
    /// An expression whose value `main` prints.
    Expression,
}
406
407fn classify_snippet(code: &str) -> HaskellSnippet {
408    if is_import(code) {
409        return HaskellSnippet::Import;
410    }
411
412    if is_declaration(code) {
413        return HaskellSnippet::Declaration;
414    }
415
416    if should_wrap_expression(code) {
417        return HaskellSnippet::Expression;
418    }
419
420    HaskellSnippet::Statement
421}
422
/// Returns `true` when `code` has at least one non-blank line and every
/// non-blank line is an `import` declaration.
///
/// Blank lines are skipped, matching how imports are stored, so a pasted group
/// of imports separated by empty lines is still recognised. The `import`
/// keyword may be followed by any whitespace, not just a single space.
fn is_import(code: &str) -> bool {
    let mut lines = code
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .peekable();
    lines.peek().is_some()
        && lines.all(|line| {
            line.strip_prefix("import")
                .map_or(false, |rest| rest.starts_with(char::is_whitespace))
        })
}
427
/// Heuristically decides whether `code` is a top-level Haskell declaration
/// (placed before `main`) rather than something to run inside `main`.
///
/// A snippet counts as a declaration when it:
/// * starts with a declaration keyword (`data`, `class`, `instance`, ...),
/// * has a line shaped like a type signature (`f, g :: ...`, `(<+>) :: ...`), or
/// * has a binding `=` whose left-hand side starts with a name
///   (`f x = ...`, `Just y = ...`).
///
/// Keywords are matched case-sensitively, since Haskell keywords are
/// lowercase and constructors such as `Type` or `Default` are not keywords.
/// A `::` inside an expression (`print (1 :: Int)`) is not a signature.
/// Comparison operators (`==`, `/=`, `<=`, `>=`) and `=` inside string or
/// character literals are not binding equals signs.
fn is_declaration(code: &str) -> bool {
    const KEYWORDS: [&str; 12] = [
        "module", "data", "type", "newtype", "class", "instance", "foreign", "default",
        "deriving", "infix", "infixl", "infixr",
    ];
    // Expression-level keywords can never begin a top-level declaration.
    const NON_BINDING_STARTS: [&str; 4] = ["let", "if", "case", "do"];
    // Characters that form Haskell operator symbols (Haskell 2010 `symbol`).
    const SYMBOL_CHARS: &str = "!#$%&*+./<=>?@\\^|-~:";

    /// `true` when `text` starts with `keyword` followed by whitespace.
    fn starts_with_keyword(text: &str, keyword: &str) -> bool {
        text.strip_prefix(keyword)
            .map_or(false, |rest| rest.starts_with(char::is_whitespace))
    }

    fn is_symbol(c: char) -> bool {
        SYMBOL_CHARS.contains(c)
    }

    /// A name that can appear before `::` in a signature: a lowercase/underscore
    /// identifier or a parenthesised operator such as `(<+>)`.
    fn is_binder_name(name: &str) -> bool {
        if let Some(op) = name.strip_prefix('(').and_then(|rest| rest.strip_suffix(')')) {
            return !op.is_empty() && op.chars().all(is_symbol);
        }
        let mut chars = name.chars();
        match chars.next() {
            Some(first) if first.is_lowercase() || first == '_' => {
                chars.all(|c| c.is_alphanumeric() || c == '_' || c == '\'')
            }
            _ => false,
        }
    }

    /// `true` for lines like `f, g :: Int` or `(<+>) :: a -> a -> a`.
    fn is_signature_line(line: &str) -> bool {
        match line.split_once("::") {
            Some((names, _)) => {
                let names = names.trim();
                !names.is_empty() && names.split(',').all(|name| is_binder_name(name.trim()))
            }
            None => false,
        }
    }

    /// Byte index of the first `=` that is a binding (not part of an operator
    /// like `==`/`>>=`, and not inside a string or character literal).
    fn find_binding_equals(text: &str) -> Option<usize> {
        let mut in_string = false;
        let mut escaped = false;
        let mut prev: Option<char> = None;
        let mut chars = text.char_indices().peekable();
        while let Some((idx, c)) = chars.next() {
            if in_string {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == '"' {
                    in_string = false;
                }
            } else if c == '"' {
                in_string = true;
            } else if c == '=' {
                let next = chars.peek().map(|&(_, n)| n);
                let char_literal = prev == Some('\'') && next == Some('\'');
                let in_operator = prev.map_or(false, is_symbol) || next.map_or(false, is_symbol);
                if !char_literal && !in_operator {
                    return Some(idx);
                }
            }
            prev = Some(c);
        }
        None
    }

    let trimmed = code.trim_start();
    let first_word = trimmed.split_whitespace().next().unwrap_or("");
    if NON_BINDING_STARTS.contains(&first_word) {
        return false;
    }

    if KEYWORDS.iter().any(|&kw| starts_with_keyword(trimmed, kw)) {
        return true;
    }

    if trimmed.lines().any(is_signature_line) {
        return true;
    }

    match find_binding_equals(trimmed) {
        Some(idx) => {
            let head = trimmed[..idx].split_whitespace().next().unwrap_or("");
            !NON_BINDING_STARTS.contains(&head)
                && head.chars().next().map_or(false, char::is_alphabetic)
        }
        None => false,
    }
}
476
/// Heuristically decides whether a single-line snippet is a pure expression
/// whose value should be shown with `print`, as GHCi would show it.
///
/// Returns `false` (the snippet runs as a `main` statement instead) for:
/// * multi-line or blank input;
/// * snippets starting with a statement/declaration keyword (case-sensitive,
///   followed by whitespace);
/// * snippets using binding, monadic or lambda syntax (`=`, `>>=`, `=<<`,
///   `->`, `<-`);
/// * snippets starting with a well-known output action such as `putStrLn` or
///   `mapM_`, whose `IO ()` result has no `Show` instance.
///
/// Comparison operators (`==`, `/=`, `<=`, `>=`) do not prevent wrapping, and
/// operator characters inside string literals are ignored.
fn should_wrap_expression(code: &str) -> bool {
    const STATEMENT_KEYWORDS: [&str; 11] = [
        "let", "case", "if", "do", "import", "module", "data", "type", "newtype", "class",
        "instance",
    ];
    const IO_ACTIONS: [&str; 9] = [
        "print", "putStr", "putStrLn", "putChar", "mapM_", "forM_", "interact", "writeFile",
        "appendFile",
    ];
    const COMPARISONS: [&str; 4] = ["==", "/=", "<=", ">="];
    const SYMBOL_CHARS: &str = "!#$%&*+./<=>?@\\^|-~:";

    if code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() {
        return false;
    }

    let starts_with_keyword = |kw: &str| {
        trimmed
            .strip_prefix(kw)
            .map_or(false, |rest| rest.starts_with(char::is_whitespace))
    };
    if STATEMENT_KEYWORDS.iter().any(|&kw| starts_with_keyword(kw)) {
        return false;
    }

    // Leading identifier, so `print(1)` and `print $ x` are both recognised.
    let head_len = trimmed
        .find(|c: char| !(c.is_alphanumeric() || c == '_' || c == '\''))
        .unwrap_or(trimmed.len());
    if IO_ACTIONS.contains(&&trimmed[..head_len]) {
        return false;
    }

    // Scan maximal operator runs outside string literals. A trailing space
    // flushes the final run.
    let blocks_wrapping = |op: &str| {
        op.contains("->")
            || op.contains("<-")
            || (op.contains('=') && !COMPARISONS.contains(&op))
    };
    let mut run = String::new();
    let mut in_string = false;
    let mut escaped = false;
    for c in trimmed.chars().chain(std::iter::once(' ')) {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        if SYMBOL_CHARS.contains(c) {
            run.push(c);
            continue;
        }
        if blocks_wrapping(&run) {
            return false;
        }
        run.clear();
        if c == '"' {
            in_string = true;
        }
    }

    true
}
515
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
523
/// Prefixes every line of `code` with four spaces (the indentation of
/// `main`'s `do` block) and makes sure the final line is newline-terminated.
///
/// Empty input yields an empty string.
fn indent_block(code: &str) -> String {
    code.split_inclusive('\n')
        .map(|line| {
            let terminator = if line.ends_with('\n') { "" } else { "\n" };
            format!("    {line}{terminator}")
        })
        .collect()
}
538
539fn wrap_expression(code: &str) -> String {
540    indent_block(&format!("print (({}))\n", code.trim()))
541}
542
/// Returns the part of `current` that follows `previous`.
///
/// When `current` does not start with `previous` (for example, the program's
/// earlier output changed), the whole of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
550
/// Decodes process output lossily as UTF-8 and strips every carriage return,
/// so CRLF becomes LF and lone `\r` characters are dropped.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}