Skip to main content

run/engine/
haskell.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// Language engine that runs Haskell sources through GHC's `runghc` runner.
pub struct HaskellEngine {
    /// Resolved location of `runghc`, or `None` if it was not found at construction.
    executable: Option<PathBuf>,
}
15
16impl HaskellEngine {
17    pub fn new() -> Self {
18        Self {
19            executable: resolve_runghc_binary(),
20        }
21    }
22
23    fn ensure_executable(&self) -> Result<&Path> {
24        self.executable.as_deref().ok_or_else(|| {
25            anyhow::anyhow!(
26                "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
27            )
28        })
29    }
30
31    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
32        let dir = Builder::new()
33            .prefix("run-haskell")
34            .tempdir()
35            .context("failed to create temporary directory for Haskell source")?;
36        let path = dir.path().join("snippet.hs");
37        let mut contents = code.to_string();
38        if !contents.ends_with('\n') {
39            contents.push('\n');
40        }
41        fs::write(&path, contents).with_context(|| {
42            format!(
43                "failed to write temporary Haskell source to {}",
44                path.display()
45            )
46        })?;
47        Ok((dir, path))
48    }
49
50    fn execute_path(&self, path: &Path) -> Result<std::process::Output> {
51        let executable = self.ensure_executable()?;
52        let mut cmd = Command::new(executable);
53        cmd.arg(path).stdout(Stdio::piped()).stderr(Stdio::piped());
54        cmd.stdin(Stdio::inherit());
55        if let Some(parent) = path.parent() {
56            cmd.current_dir(parent);
57        }
58        cmd.output().with_context(|| {
59            format!(
60                "failed to execute {} with script {}",
61                executable.display(),
62                path.display()
63            )
64        })
65    }
66}
67
68impl LanguageEngine for HaskellEngine {
69    fn id(&self) -> &'static str {
70        "haskell"
71    }
72
73    fn display_name(&self) -> &'static str {
74        "Haskell"
75    }
76
77    fn aliases(&self) -> &[&'static str] {
78        &["hs", "ghci"]
79    }
80
81    fn supports_sessions(&self) -> bool {
82        self.executable.is_some()
83    }
84
85    fn validate(&self) -> Result<()> {
86        let executable = self.ensure_executable()?;
87        let mut cmd = Command::new(executable);
88        cmd.arg("--version")
89            .stdout(Stdio::null())
90            .stderr(Stdio::null());
91        cmd.status()
92            .with_context(|| format!("failed to invoke {}", executable.display()))?
93            .success()
94            .then_some(())
95            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
96    }
97
98    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
99        let start = Instant::now();
100        let (temp_dir, path) = match payload {
101            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
102                let (dir, path) = self.write_temp_source(code)?;
103                (Some(dir), path)
104            }
105            ExecutionPayload::File { path } => (None, path.clone()),
106        };
107
108        let output = self.execute_path(&path)?;
109        drop(temp_dir);
110
111        Ok(ExecutionOutcome {
112            language: self.id().to_string(),
113            exit_code: output.status.code(),
114            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
115            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
116            duration: start.elapsed(),
117        })
118    }
119
120    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
121        let executable = self.ensure_executable()?.to_path_buf();
122        Ok(Box::new(HaskellSession::new(executable)?))
123    }
124}
125
126fn resolve_runghc_binary() -> Option<PathBuf> {
127    which::which("runghc").ok()
128}
129
/// Accumulated snippets from which the session program is regenerated.
#[derive(Default)]
struct HaskellSessionState {
    /// `import` lines, stored sorted and de-duplicated.
    imports: BTreeSet<String>,
    /// Top-level declarations in the order they were accepted.
    declarations: Vec<String>,
    /// Statements of `main`'s `do` block, in the order they were accepted.
    statements: Vec<String>,
}
136
137struct HaskellSession {
138    executable: PathBuf,
139    workspace: TempDir,
140    state: HaskellSessionState,
141    previous_stdout: String,
142    previous_stderr: String,
143}
144
145impl HaskellSession {
146    fn new(executable: PathBuf) -> Result<Self> {
147        let workspace = Builder::new()
148            .prefix("run-haskell-repl")
149            .tempdir()
150            .context("failed to create temporary directory for Haskell repl")?;
151        let session = Self {
152            executable,
153            workspace,
154            state: HaskellSessionState::default(),
155            previous_stdout: String::new(),
156            previous_stderr: String::new(),
157        };
158        session.persist_source()?;
159        Ok(session)
160    }
161
162    fn source_path(&self) -> PathBuf {
163        self.workspace.path().join("session.hs")
164    }
165
166    fn persist_source(&self) -> Result<()> {
167        let source = self.render_source();
168        fs::write(self.source_path(), source)
169            .with_context(|| "failed to write Haskell session source".to_string())
170    }
171
172    fn render_source(&self) -> String {
173        let mut source = String::new();
174        source.push_str("import Prelude\n");
175        for import in &self.state.imports {
176            source.push_str(import);
177            if !import.ends_with('\n') {
178                source.push('\n');
179            }
180        }
181        source.push('\n');
182
183        for decl in &self.state.declarations {
184            source.push_str(decl);
185            if !decl.ends_with('\n') {
186                source.push('\n');
187            }
188            source.push('\n');
189        }
190
191        source.push_str("main :: IO ()\n");
192        source.push_str("main = do\n");
193        if self.state.statements.is_empty() {
194            source.push_str("    return ()\n");
195        } else {
196            for stmt in &self.state.statements {
197                source.push_str(stmt);
198                if !stmt.ends_with('\n') {
199                    source.push('\n');
200                }
201            }
202
203            if let Some(last) = self.state.statements.last() {
204                if last.trim().starts_with("let ") {
205                    source.push_str("    return ()\n");
206                }
207            }
208        }
209
210        source
211    }
212
213    fn run_program(&self) -> Result<std::process::Output> {
214        let mut cmd = Command::new(&self.executable);
215        cmd.arg("session.hs")
216            .stdout(Stdio::piped())
217            .stderr(Stdio::piped())
218            .current_dir(self.workspace.path());
219        cmd.output().with_context(|| {
220            format!(
221                "failed to execute {} for Haskell session",
222                self.executable.display()
223            )
224        })
225    }
226
227    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
228        self.persist_source()?;
229        let output = self.run_program()?;
230        let stdout_full = normalize_output(&output.stdout);
231        let stderr_full = normalize_output(&output.stderr);
232
233        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
234        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
235
236        let success = output.status.success();
237        if success {
238            self.previous_stdout = stdout_full;
239            self.previous_stderr = stderr_full;
240        }
241
242        let outcome = ExecutionOutcome {
243            language: "haskell".to_string(),
244            exit_code: output.status.code(),
245            stdout: stdout_delta,
246            stderr: stderr_delta,
247            duration: start.elapsed(),
248        };
249
250        Ok((outcome, success))
251    }
252
253    fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
254        let mut inserted = Vec::new();
255        for line in code.lines() {
256            let trimmed = line.trim();
257            if trimmed.is_empty() {
258                continue;
259            }
260            let normalized = trimmed.to_string();
261            if self.state.imports.insert(normalized.clone()) {
262                inserted.push(normalized);
263            }
264        }
265
266        if inserted.is_empty() {
267            return Ok((
268                ExecutionOutcome {
269                    language: "haskell".to_string(),
270                    exit_code: None,
271                    stdout: String::new(),
272                    stderr: String::new(),
273                    duration: Duration::default(),
274                },
275                true,
276            ));
277        }
278
279        let start = Instant::now();
280        let (outcome, success) = self.run_current(start)?;
281        if !success {
282            for item in inserted {
283                self.state.imports.remove(&item);
284            }
285            self.persist_source()?;
286        }
287        Ok((outcome, success))
288    }
289
290    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
291        let snippet = ensure_trailing_newline(code);
292        self.state.declarations.push(snippet);
293        let start = Instant::now();
294        let (outcome, success) = self.run_current(start)?;
295        if !success {
296            let _ = self.state.declarations.pop();
297            self.persist_source()?;
298        }
299        Ok((outcome, success))
300    }
301
302    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
303        let snippet = indent_block(code);
304        self.state.statements.push(snippet);
305        let start = Instant::now();
306        let (outcome, success) = self.run_current(start)?;
307        if !success {
308            let _ = self.state.statements.pop();
309            self.persist_source()?;
310        }
311        Ok((outcome, success))
312    }
313
314    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
315        let wrapped = wrap_expression(code);
316        self.state.statements.push(wrapped);
317        let start = Instant::now();
318        let (outcome, success) = self.run_current(start)?;
319        if !success {
320            let _ = self.state.statements.pop();
321            self.persist_source()?;
322        }
323        Ok((outcome, success))
324    }
325
326    fn reset(&mut self) -> Result<()> {
327        self.state.imports.clear();
328        self.state.declarations.clear();
329        self.state.statements.clear();
330        self.previous_stdout.clear();
331        self.previous_stderr.clear();
332        self.persist_source()
333    }
334}
335
336impl LanguageSession for HaskellSession {
337    fn language_id(&self) -> &str {
338        "haskell"
339    }
340
341    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
342        let trimmed = code.trim();
343        if trimmed.is_empty() {
344            return Ok(ExecutionOutcome {
345                language: "haskell".to_string(),
346                exit_code: None,
347                stdout: String::new(),
348                stderr: String::new(),
349                duration: Duration::default(),
350            });
351        }
352
353        if trimmed.eq_ignore_ascii_case(":reset") {
354            self.reset()?;
355            return Ok(ExecutionOutcome {
356                language: "haskell".to_string(),
357                exit_code: None,
358                stdout: String::new(),
359                stderr: String::new(),
360                duration: Duration::default(),
361            });
362        }
363
364        if trimmed.eq_ignore_ascii_case(":help") {
365            return Ok(ExecutionOutcome {
366                language: "haskell".to_string(),
367                exit_code: None,
368                stdout: "Haskell commands:\n  :reset - clear session state\n  :help  - show this message\n"
369                    .to_string(),
370                stderr: String::new(),
371                duration: Duration::default(),
372            });
373        }
374
375        match classify_snippet(trimmed) {
376            HaskellSnippet::Import => {
377                let (outcome, _) = self.apply_import(code)?;
378                Ok(outcome)
379            }
380            HaskellSnippet::Declaration => {
381                let (outcome, _) = self.apply_declaration(code)?;
382                Ok(outcome)
383            }
384            HaskellSnippet::Expression => {
385                let (outcome, _) = self.apply_expression(trimmed)?;
386                Ok(outcome)
387            }
388            HaskellSnippet::Statement => {
389                let (outcome, _) = self.apply_statement(code)?;
390                Ok(outcome)
391            }
392        }
393    }
394
395    fn shutdown(&mut self) -> Result<()> {
396        Ok(())
397    }
398}
399
/// Where a REPL snippet is spliced into the generated session program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HaskellSnippet {
    /// One or more `import` lines, placed at the top of the module.
    Import,
    /// Top-level declarations: bindings, type signatures, `data`, `class`, and so on.
    Declaration,
    /// A statement appended to `main`'s `do` block.
    Statement,
    /// An expression whose value is printed from `main`.
    Expression,
}

/// Decides which part of the session program `code` belongs to.
fn classify_snippet(code: &str) -> HaskellSnippet {
    if is_import(code) {
        HaskellSnippet::Import
    } else if is_declaration(code) {
        HaskellSnippet::Declaration
    } else if should_wrap_expression(code) {
        HaskellSnippet::Expression
    } else {
        HaskellSnippet::Statement
    }
}

/// Returns `true` when `code` has at least one non-blank line and every non-blank
/// line starts with `import `.
fn is_import(code: &str) -> bool {
    let mut lines = code
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .peekable();
    lines.peek().is_some() && lines.all(|line| line.starts_with("import "))
}

/// Returns `true` when `code` looks like a top-level declaration rather than
/// something that belongs inside `main`.
///
/// Recognised forms:
/// - keyword-led declarations (`data`, `class`, …; keywords are case-sensitive);
/// - type signatures (`f, g :: T`, `(<+>) :: T`);
/// - bindings whose defining `=` sits at the top level of the snippet and is not
///   preceded by a `let` or `do`.
///
/// Comments, string and char literals, and bracketed sub-expressions are ignored
/// while searching. As a result `print (1 :: Int)`, `x == 5` and `putStrLn "a = b"`
/// are not taken for declarations.
fn is_declaration(code: &str) -> bool {
    const KEYWORDS: [&str; 8] = [
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
        "foreign ",
        "default ",
    ];

    let flat: String = scrub(code, false).into_iter().collect();
    let flat = flat.trim_start();
    if flat.starts_with("let ") {
        return false;
    }
    if KEYWORDS.iter().any(|keyword| flat.starts_with(keyword)) {
        return true;
    }

    if let Some((head, _)) = flat.split_once("::") {
        if is_signature_head(head) {
            return true;
        }
    }

    let masked = scrub(code, true);
    let eq = match binding_equals(&masked) {
        Some(index) => index,
        None => return false,
    };
    let lhs: String = masked[..eq].iter().collect();
    let lhs = lhs.trim();
    // A top-level `=` after `let`/`do` belongs to a nested block inside a statement.
    if lhs
        .split_whitespace()
        .any(|token| token == "let" || token == "do")
    {
        return false;
    }
    lhs.chars()
        .next()
        .map_or(false, |c| c.is_alphabetic() || c == '_' || c == '(')
}

/// Returns `true` for single-line snippets that look like a value to print, rather
/// than a `do` statement.
///
/// Keyword-led lines and anything containing `=`, `->` or `<-` count as statements.
fn should_wrap_expression(code: &str) -> bool {
    if code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() {
        return false;
    }

    let lowered = trimmed.to_ascii_lowercase();
    const STATEMENT_PREFIXES: [&str; 11] = [
        "let ",
        "case ",
        "if ",
        "do ",
        "import ",
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
    ];

    if STATEMENT_PREFIXES
        .iter()
        .any(|prefix| lowered.starts_with(prefix))
    {
        return false;
    }

    !(trimmed.contains('=') || trimmed.contains("->") || trimmed.contains("<-"))
}

/// Returns `true` if `head`, the text before `::`, is a comma-separated list of
/// variable names or parenthesised operators.
fn is_signature_head(head: &str) -> bool {
    head.split(',').all(|name| {
        let name = name.trim();
        match name
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'))
        {
            Some(op) => {
                let op = op.trim();
                !op.is_empty() && op.chars().all(is_symbol_char)
            }
            None => {
                let mut chars = name.chars();
                chars.next().map_or(false, |c| c.is_lowercase() || c == '_')
                    && chars.all(is_ident_char)
            }
        }
    })
}

/// Index of the first `=` that is a standalone token rather than part of an operator
/// such as `==`, `>=` or `>>=`.
fn binding_equals(masked: &[char]) -> Option<usize> {
    (0..masked.len()).find(|&k| {
        masked[k] == '='
            && (k == 0 || !is_symbol_char(masked[k - 1]))
            && masked.get(k + 1).map_or(true, |&c| !is_symbol_char(c))
    })
}

/// Returns `code` as one char per input char, with comments and literals blanked out.
///
/// `--` line comments, string literals and char literals are replaced by spaces.
/// When `hide_nested` is set, text inside `()`, `[]` and `{}` is blanked too, but the
/// outermost brackets and newlines are kept.
fn scrub(code: &str, hide_nested: bool) -> Vec<char> {
    let src: Vec<char> = code.chars().collect();
    let mut out = Vec::with_capacity(src.len());
    let mut depth = 0usize;
    let mut i = 0;
    while i < src.len() {
        let c = src[i];
        let literal_end = match c {
            '"' => Some(string_end(&src, i)),
            '\'' if i == 0 || !is_ident_char(src[i - 1]) => char_literal_end(&src, i),
            '-' if starts_line_comment(&src, i) => Some(line_end(&src, i)),
            _ => None,
        };
        if let Some(end) = literal_end {
            // `out` always holds exactly one char per consumed input char.
            out.resize(end, ' ');
            i = end;
            continue;
        }

        let top_level = match c {
            '(' | '[' | '{' => {
                depth += 1;
                depth == 1
            }
            ')' | ']' | '}' => {
                let closes_outermost = depth <= 1;
                depth = depth.saturating_sub(1);
                closes_outermost
            }
            _ => depth == 0,
        };
        out.push(if top_level || !hide_nested || c == '\n' { c } else { ' ' });
        i += 1;
    }
    out
}

/// Index just past the string literal starting at `start` (or the input end if unterminated).
fn string_end(src: &[char], start: usize) -> usize {
    let mut j = start + 1;
    while j < src.len() {
        match src[j] {
            '\\' => j += 2,
            '"' => return j + 1,
            _ => j += 1,
        }
    }
    src.len()
}

/// Index just past a char literal (`'x'`, `'\n'`, `'\''`) starting at `start`, if any.
fn char_literal_end(src: &[char], start: usize) -> Option<usize> {
    match src.get(start + 1).copied() {
        Some('\\') => src
            .get(start + 3..)
            .and_then(|rest| rest.iter().position(|&c| c == '\''))
            .map(|p| start + 4 + p),
        Some(_) if src.get(start + 2) == Some(&'\'') => Some(start + 3),
        _ => None,
    }
}

/// Whether the `-` at `i` opens a Haskell line comment rather than an operator like `-->`.
fn starts_line_comment(src: &[char], i: usize) -> bool {
    if i > 0 && is_symbol_char(src[i - 1]) {
        return false;
    }
    let dashes = src[i..].iter().take_while(|&&c| c == '-').count();
    dashes >= 2 && src.get(i + dashes).map_or(true, |&c| !is_symbol_char(c))
}

/// Index of the newline ending the line that contains `i` (or the input end).
fn line_end(src: &[char], i: usize) -> usize {
    src[i..]
        .iter()
        .position(|&c| c == '\n')
        .map_or(src.len(), |p| i + p)
}

/// Characters that combine into Haskell operator symbols.
fn is_symbol_char(c: char) -> bool {
    "!#$%&*+./<=>?@\\^|-~:".contains(c)
}

/// Characters that may continue a Haskell identifier.
fn is_ident_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '\''
}
513
/// Returns an owned copy of `code` that is guaranteed to end with `'\n'`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
521
/// Prefixes every line of `code` with four spaces and newline-terminates the last one.
fn indent_block(code: &str) -> String {
    code.split_inclusive('\n')
        .map(|line| {
            if line.ends_with('\n') {
                format!("    {line}")
            } else {
                format!("    {line}\n")
            }
        })
        .collect()
}
536
/// Turns an expression into an indented `do` statement that prints its value.
///
/// Surrounding whitespace is trimmed, and each resulting line is indented by four
/// spaces.
fn wrap_expression(code: &str) -> String {
    let statement = format!("print (({}))\n", code.trim());
    statement
        .split_inclusive('\n')
        .map(|line| format!("    {line}"))
        .collect()
}
540
/// Returns the part of `current` that follows `previous`.
///
/// If `current` does not start with `previous`, the whole of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
548
/// Decodes process output as lossy UTF-8 and drops every carriage return.
///
/// This turns CRLF line endings into LF and removes stray `\r` characters.
fn normalize_output(bytes: &[u8]) -> String {
    let text = String::from_utf8_lossy(bytes);
    text.chars().filter(|&c| c != '\r').collect()
}