Skip to main content

run/engine/
haskell.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// Language engine that runs Haskell programs through `runghc`.
pub struct HaskellEngine {
    /// Resolved `runghc` binary, or `None` when it could not be located.
    executable: Option<PathBuf>,
}
15
16impl Default for HaskellEngine {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl HaskellEngine {
23    pub fn new() -> Self {
24        Self {
25            executable: resolve_runghc_binary(),
26        }
27    }
28
29    fn ensure_executable(&self) -> Result<&Path> {
30        self.executable.as_deref().ok_or_else(|| {
31            anyhow::anyhow!(
32                "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
33            )
34        })
35    }
36
37    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
38        let dir = Builder::new()
39            .prefix("run-haskell")
40            .tempdir()
41            .context("failed to create temporary directory for Haskell source")?;
42        let path = dir.path().join("snippet.hs");
43        let mut contents = code.to_string();
44        if !contents.ends_with('\n') {
45            contents.push('\n');
46        }
47        fs::write(&path, contents).with_context(|| {
48            format!(
49                "failed to write temporary Haskell source to {}",
50                path.display()
51            )
52        })?;
53        Ok((dir, path))
54    }
55
56    fn execute_path(&self, path: &Path) -> Result<std::process::Output> {
57        let executable = self.ensure_executable()?;
58        let mut cmd = Command::new(executable);
59        cmd.arg(path).stdout(Stdio::piped()).stderr(Stdio::piped());
60        cmd.stdin(Stdio::inherit());
61        if let Some(parent) = path.parent() {
62            cmd.current_dir(parent);
63        }
64        cmd.output().with_context(|| {
65            format!(
66                "failed to execute {} with script {}",
67                executable.display(),
68                path.display()
69            )
70        })
71    }
72}
73
74impl LanguageEngine for HaskellEngine {
75    fn id(&self) -> &'static str {
76        "haskell"
77    }
78
79    fn display_name(&self) -> &'static str {
80        "Haskell"
81    }
82
83    fn aliases(&self) -> &[&'static str] {
84        &["hs", "ghci"]
85    }
86
87    fn supports_sessions(&self) -> bool {
88        self.executable.is_some()
89    }
90
91    fn validate(&self) -> Result<()> {
92        let executable = self.ensure_executable()?;
93        let mut cmd = Command::new(executable);
94        cmd.arg("--version")
95            .stdout(Stdio::null())
96            .stderr(Stdio::null());
97        cmd.status()
98            .with_context(|| format!("failed to invoke {}", executable.display()))?
99            .success()
100            .then_some(())
101            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
102    }
103
104    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
105        let start = Instant::now();
106        let (temp_dir, path) = match payload {
107            ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
108                let (dir, path) = self.write_temp_source(code)?;
109                (Some(dir), path)
110            }
111            ExecutionPayload::File { path } => (None, path.clone()),
112        };
113
114        let output = self.execute_path(&path)?;
115        drop(temp_dir);
116
117        Ok(ExecutionOutcome {
118            language: self.id().to_string(),
119            exit_code: output.status.code(),
120            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
121            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
122            duration: start.elapsed(),
123        })
124    }
125
126    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
127        let executable = self.ensure_executable()?.to_path_buf();
128        Ok(Box::new(HaskellSession::new(executable)?))
129    }
130}
131
132fn resolve_runghc_binary() -> Option<PathBuf> {
133    which::which("runghc").ok()
134}
135
/// Snippets accepted into a Haskell session; replayed on every run.
#[derive(Default)]
struct HaskellSessionState {
    /// Trimmed import lines, deduplicated and kept in sorted order.
    imports: BTreeSet<String>,
    /// Top-level declarations, in acceptance order.
    declarations: Vec<String>,
    /// Already-indented `main` statements, in acceptance order.
    statements: Vec<String>,
}
142
143struct HaskellSession {
144    executable: PathBuf,
145    workspace: TempDir,
146    state: HaskellSessionState,
147    previous_stdout: String,
148    previous_stderr: String,
149}
150
151impl HaskellSession {
152    fn new(executable: PathBuf) -> Result<Self> {
153        let workspace = Builder::new()
154            .prefix("run-haskell-repl")
155            .tempdir()
156            .context("failed to create temporary directory for Haskell repl")?;
157        let session = Self {
158            executable,
159            workspace,
160            state: HaskellSessionState::default(),
161            previous_stdout: String::new(),
162            previous_stderr: String::new(),
163        };
164        session.persist_source()?;
165        Ok(session)
166    }
167
168    fn source_path(&self) -> PathBuf {
169        self.workspace.path().join("session.hs")
170    }
171
172    fn persist_source(&self) -> Result<()> {
173        let source = self.render_source();
174        fs::write(self.source_path(), source)
175            .with_context(|| "failed to write Haskell session source".to_string())
176    }
177
178    fn render_source(&self) -> String {
179        let mut source = String::new();
180        source.push_str("import Prelude\n");
181        for import in &self.state.imports {
182            source.push_str(import);
183            if !import.ends_with('\n') {
184                source.push('\n');
185            }
186        }
187        source.push('\n');
188
189        for decl in &self.state.declarations {
190            source.push_str(decl);
191            if !decl.ends_with('\n') {
192                source.push('\n');
193            }
194            source.push('\n');
195        }
196
197        source.push_str("main :: IO ()\n");
198        source.push_str("main = do\n");
199        if self.state.statements.is_empty() {
200            source.push_str("    return ()\n");
201        } else {
202            for stmt in &self.state.statements {
203                source.push_str(stmt);
204                if !stmt.ends_with('\n') {
205                    source.push('\n');
206                }
207            }
208
209            if let Some(last) = self.state.statements.last()
210                && last.trim().starts_with("let ")
211            {
212                source.push_str("    return ()\n");
213            }
214        }
215
216        source
217    }
218
219    fn run_program(&self) -> Result<std::process::Output> {
220        let mut cmd = Command::new(&self.executable);
221        cmd.arg("session.hs")
222            .stdout(Stdio::piped())
223            .stderr(Stdio::piped())
224            .current_dir(self.workspace.path());
225        cmd.output().with_context(|| {
226            format!(
227                "failed to execute {} for Haskell session",
228                self.executable.display()
229            )
230        })
231    }
232
233    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
234        self.persist_source()?;
235        let output = self.run_program()?;
236        let stdout_full = normalize_output(&output.stdout);
237        let stderr_full = normalize_output(&output.stderr);
238
239        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
240        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
241
242        let success = output.status.success();
243        if success {
244            self.previous_stdout = stdout_full;
245            self.previous_stderr = stderr_full;
246        }
247
248        let outcome = ExecutionOutcome {
249            language: "haskell".to_string(),
250            exit_code: output.status.code(),
251            stdout: stdout_delta,
252            stderr: stderr_delta,
253            duration: start.elapsed(),
254        };
255
256        Ok((outcome, success))
257    }
258
259    fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
260        let mut inserted = Vec::new();
261        for line in code.lines() {
262            let trimmed = line.trim();
263            if trimmed.is_empty() {
264                continue;
265            }
266            let normalized = trimmed.to_string();
267            if self.state.imports.insert(normalized.clone()) {
268                inserted.push(normalized);
269            }
270        }
271
272        if inserted.is_empty() {
273            return Ok((
274                ExecutionOutcome {
275                    language: "haskell".to_string(),
276                    exit_code: None,
277                    stdout: String::new(),
278                    stderr: String::new(),
279                    duration: Duration::default(),
280                },
281                true,
282            ));
283        }
284
285        let start = Instant::now();
286        let (outcome, success) = self.run_current(start)?;
287        if !success {
288            for item in inserted {
289                self.state.imports.remove(&item);
290            }
291            self.persist_source()?;
292        }
293        Ok((outcome, success))
294    }
295
296    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
297        let snippet = ensure_trailing_newline(code);
298        self.state.declarations.push(snippet);
299        let start = Instant::now();
300        let (outcome, success) = self.run_current(start)?;
301        if !success {
302            let _ = self.state.declarations.pop();
303            self.persist_source()?;
304        }
305        Ok((outcome, success))
306    }
307
308    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
309        let snippet = indent_block(code);
310        self.state.statements.push(snippet);
311        let start = Instant::now();
312        let (outcome, success) = self.run_current(start)?;
313        if !success {
314            let _ = self.state.statements.pop();
315            self.persist_source()?;
316        }
317        Ok((outcome, success))
318    }
319
320    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
321        let wrapped = wrap_expression(code);
322        self.state.statements.push(wrapped);
323        let start = Instant::now();
324        let (outcome, success) = self.run_current(start)?;
325        if !success {
326            let _ = self.state.statements.pop();
327            self.persist_source()?;
328        }
329        Ok((outcome, success))
330    }
331
332    fn reset(&mut self) -> Result<()> {
333        self.state.imports.clear();
334        self.state.declarations.clear();
335        self.state.statements.clear();
336        self.previous_stdout.clear();
337        self.previous_stderr.clear();
338        self.persist_source()
339    }
340}
341
342impl LanguageSession for HaskellSession {
343    fn language_id(&self) -> &str {
344        "haskell"
345    }
346
347    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
348        let trimmed = code.trim();
349        if trimmed.is_empty() {
350            return Ok(ExecutionOutcome {
351                language: "haskell".to_string(),
352                exit_code: None,
353                stdout: String::new(),
354                stderr: String::new(),
355                duration: Duration::default(),
356            });
357        }
358
359        if trimmed.eq_ignore_ascii_case(":reset") {
360            self.reset()?;
361            return Ok(ExecutionOutcome {
362                language: "haskell".to_string(),
363                exit_code: None,
364                stdout: String::new(),
365                stderr: String::new(),
366                duration: Duration::default(),
367            });
368        }
369
370        if trimmed.eq_ignore_ascii_case(":help") {
371            return Ok(ExecutionOutcome {
372                language: "haskell".to_string(),
373                exit_code: None,
374                stdout: "Haskell commands:\n  :reset - clear session state\n  :help  - show this message\n"
375                    .to_string(),
376                stderr: String::new(),
377                duration: Duration::default(),
378            });
379        }
380
381        match classify_snippet(trimmed) {
382            HaskellSnippet::Import => {
383                let (outcome, _) = self.apply_import(code)?;
384                Ok(outcome)
385            }
386            HaskellSnippet::Declaration => {
387                let (outcome, _) = self.apply_declaration(code)?;
388                Ok(outcome)
389            }
390            HaskellSnippet::Expression => {
391                let (outcome, _) = self.apply_expression(trimmed)?;
392                Ok(outcome)
393            }
394            HaskellSnippet::Statement => {
395                let (outcome, _) = self.apply_statement(code)?;
396                Ok(outcome)
397            }
398        }
399    }
400
401    fn shutdown(&mut self) -> Result<()> {
402        Ok(())
403    }
404}
405
/// How a session snippet is spliced into the generated program.
enum HaskellSnippet {
    /// `import` lines, placed at the top of the module.
    Import,
    /// Top-level declarations (types, classes, signatures, bindings).
    Declaration,
    /// Single-line expression whose value is printed with `print`.
    Expression,
    /// Anything else, appended to `main`'s `do` block.
    Statement,
}
412
413fn classify_snippet(code: &str) -> HaskellSnippet {
414    if is_import(code) {
415        return HaskellSnippet::Import;
416    }
417
418    if is_declaration(code) {
419        return HaskellSnippet::Declaration;
420    }
421
422    if should_wrap_expression(code) {
423        return HaskellSnippet::Expression;
424    }
425
426    HaskellSnippet::Statement
427}
428
/// Whether every non-blank line of `code` is an `import` declaration.
///
/// Blank lines between imports are allowed, and any whitespace may follow the
/// keyword. Input with no non-blank lines is not an import.
fn is_import(code: &str) -> bool {
    let mut lines = code
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .peekable();
    lines.peek().is_some()
        && lines.all(|line| line.split_whitespace().next() == Some("import"))
}
433
/// Heuristically decides whether `code` is a top-level declaration that belongs
/// outside `main`.
///
/// Declarations are data/type/class/instance definitions, type signatures, and
/// function or pattern bindings. The rules are:
/// - Keywords are matched case-sensitively, as in Haskell. A constructor such as
///   `Default` is therefore not mistaken for a declaration.
/// - `::` marks a declaration only as a type signature on an unindented line
///   (`f :: T`, `f, g :: T`, `(op) :: T`). Annotated expressions such as
///   `read s :: Int` stay expressions.
/// - `=` marks a binding only when it stands alone at bracket depth zero, outside
///   string/char literals. Operators like `==`, `>=` or `/=` do not count.
///
/// Comments are not recognised.
fn is_declaration(code: &str) -> bool {
    let trimmed = code.trim_start();
    if trimmed.starts_with("let ") {
        return false;
    }

    const PREFIXES: [&str; 8] = [
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
        "foreign ",
        "default ",
    ];
    if PREFIXES.iter().any(|prefix| trimmed.starts_with(prefix)) {
        return true;
    }

    // Continuation lines of a multi-line snippet are indented, so only unindented
    // lines can start a top-level signature.
    if trimmed
        .lines()
        .filter(|line| !line.starts_with(char::is_whitespace))
        .any(is_type_signature)
    {
        return true;
    }

    let Some(offset) = binding_equals_offset(trimmed) else {
        return false;
    };
    let lhs = &trimmed[..offset];
    // A `let` or `do` before the `=` means the binding is nested inside a statement,
    // e.g. `forM_ xs $ \x -> do` followed by `let y = ...`.
    if lhs.split_whitespace().any(|token| token == "let" || token == "do") {
        return false;
    }
    // Bindings start with a name. `(` covers operator definitions such as
    // `(<+>) a b = ...` and tuple pattern bindings.
    lhs.trim_start()
        .chars()
        .next()
        .is_some_and(|c| c.is_alphabetic() || c == '_' || c == '(')
}

/// Whether `line` is a type signature such as `f :: Int`, `f, g :: Int` or `(<+>) :: ...`.
fn is_type_signature(line: &str) -> bool {
    line.split_once("::").is_some_and(|(names, _)| {
        names
            .split(',')
            .map(str::trim)
            .all(|name| is_variable_name(name) || is_operator_name(name))
    })
}

/// Whether `name` is a Haskell variable identifier (lowercase letter or `_` first).
fn is_variable_name(name: &str) -> bool {
    let mut chars = name.chars();
    chars.next().is_some_and(|c| c.is_lowercase() || c == '_') && chars.all(is_identifier_char)
}

/// Whether `name` is a parenthesised operator such as `(<+>)`.
fn is_operator_name(name: &str) -> bool {
    name.strip_prefix('(')
        .and_then(|inner| inner.strip_suffix(')'))
        .is_some_and(|op| !op.is_empty() && op.chars().all(is_symbol_char))
}

/// Characters allowed after the first one in a Haskell identifier.
fn is_identifier_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '\''
}

/// ASCII characters that make up Haskell operator symbols.
fn is_symbol_char(c: char) -> bool {
    "!#$%&*+./<=>?@\\^|-~:".contains(c)
}

/// Byte offset of the first `=` in `code` that acts as a binding.
///
/// Such an `=` is outside string and character literals, at bracket depth zero, and
/// not part of a longer operator.
fn binding_equals_offset(code: &str) -> Option<usize> {
    let chars: Vec<(usize, char)> = code.char_indices().collect();
    let char_at = |index: usize| chars.get(index).map(|&(_, c)| c);
    let mut depth = 0usize;
    let mut index = 0;
    while index < chars.len() {
        let (offset, ch) = chars[index];
        let prev = if index == 0 { None } else { char_at(index - 1) };
        match ch {
            '"' => index = closing_quote_index(&chars, index, '"'),
            // A quote right after an identifier character is a prime (`x'`), not a
            // character literal.
            '\'' if !prev.is_some_and(is_identifier_char) => {
                index = closing_quote_index(&chars, index, '\'')
            }
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            '=' if depth == 0
                && !prev.is_some_and(is_symbol_char)
                && !char_at(index + 1).is_some_and(is_symbol_char) =>
            {
                return Some(offset);
            }
            _ => {}
        }
        index += 1;
    }
    None
}

/// Index of the quote closing the literal that opens at `start`, skipping backslash
/// escapes. Returns `chars.len()` when the literal is unterminated.
fn closing_quote_index(chars: &[(usize, char)], start: usize, quote: char) -> usize {
    let mut index = start + 1;
    while index < chars.len() {
        match chars[index].1 {
            '\\' => index += 2,
            c if c == quote => return index,
            _ => index += 1,
        }
    }
    chars.len()
}

/// Decides whether a single-line snippet is a pure expression whose value should be
/// wrapped in `print`, as opposed to a `do` statement.
///
/// The following are left as statements:
/// - multi-line snippets;
/// - snippets starting with a statement or declaration keyword;
/// - snippets containing a binding `=`, `->` or `<-`.
///
/// Operators that merely contain `=` (`==`, `/=`, `>=`, ...) do not disqualify an
/// expression.
fn should_wrap_expression(code: &str) -> bool {
    if code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() {
        return false;
    }

    const STATEMENT_PREFIXES: [&str; 11] = [
        "let ",
        "case ",
        "if ",
        "do ",
        "import ",
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
    ];
    // Haskell keywords are case-sensitive.
    if STATEMENT_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(prefix))
    {
        return false;
    }

    binding_equals_offset(trimmed).is_none()
        && !trimmed.contains("->")
        && !trimmed.contains("<-")
}
519
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
527
/// Prefixes every line of `code` with four spaces and ensures each line ends with
/// `\n`. Empty input yields an empty string.
fn indent_block(code: &str) -> String {
    let mut indented = String::with_capacity(code.len() + 8);
    for segment in code.split_inclusive('\n') {
        indented.push_str("    ");
        indented.push_str(segment);
        if !segment.ends_with('\n') {
            indented.push('\n');
        }
    }
    indented
}
542
543fn wrap_expression(code: &str) -> String {
544    indent_block(&format!("print (({}))\n", code.trim()))
545}
546
/// Returns the part of `current` that follows `previous`.
///
/// When `current` does not start with `previous`, the whole of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current
        .strip_prefix(previous)
        .unwrap_or(current)
        .to_owned()
}
554
/// Decodes process output lossily as UTF-8 and strips every carriage return.
///
/// Stripping carriage returns also turns CRLF line endings into LF.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}