Skip to main content

run/engine/
haskell.rs

1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{
11    ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, run_version_command,
12};
13
/// Language engine that runs Haskell sources through `runghc`.
pub struct HaskellEngine {
    /// Resolved location of the `runghc` binary, or `None` when it could not
    /// be located when the engine was constructed.
    executable: Option<PathBuf>,
}
17
18impl Default for HaskellEngine {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl HaskellEngine {
25    pub fn new() -> Self {
26        Self {
27            executable: resolve_runghc_binary(),
28        }
29    }
30
31    fn ensure_executable(&self) -> Result<&Path> {
32        self.executable.as_deref().ok_or_else(|| {
33            anyhow::anyhow!(
34                "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
35            )
36        })
37    }
38
39    fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
40        let dir = Builder::new()
41            .prefix("run-haskell")
42            .tempdir()
43            .context("failed to create temporary directory for Haskell source")?;
44        let path = dir.path().join("snippet.hs");
45        let mut contents = code.to_string();
46        if !contents.ends_with('\n') {
47            contents.push('\n');
48        }
49        fs::write(&path, contents).with_context(|| {
50            format!(
51                "failed to write temporary Haskell source to {}",
52                path.display()
53            )
54        })?;
55        Ok((dir, path))
56    }
57
58    fn execute_path(&self, path: &Path, args: &[String]) -> Result<std::process::Output> {
59        let executable = self.ensure_executable()?;
60        let mut cmd = Command::new(executable);
61        cmd.arg(path)
62            .args(args)
63            .stdout(Stdio::piped())
64            .stderr(Stdio::piped());
65        cmd.stdin(Stdio::inherit());
66        if let Some(parent) = path.parent() {
67            cmd.current_dir(parent);
68        }
69        cmd.output().with_context(|| {
70            format!(
71                "failed to execute {} with script {}",
72                executable.display(),
73                path.display()
74            )
75        })
76    }
77}
78
79impl LanguageEngine for HaskellEngine {
80    fn id(&self) -> &'static str {
81        "haskell"
82    }
83
84    fn display_name(&self) -> &'static str {
85        "Haskell"
86    }
87
88    fn aliases(&self) -> &[&'static str] {
89        &["hs", "ghci"]
90    }
91
92    fn supports_sessions(&self) -> bool {
93        self.executable.is_some()
94    }
95
96    fn validate(&self) -> Result<()> {
97        let executable = self.ensure_executable()?;
98        let mut cmd = Command::new(executable);
99        cmd.arg("--version")
100            .stdout(Stdio::null())
101            .stderr(Stdio::null());
102        cmd.status()
103            .with_context(|| format!("failed to invoke {}", executable.display()))?
104            .success()
105            .then_some(())
106            .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
107    }
108
109    fn toolchain_version(&self) -> Result<Option<String>> {
110        let executable = self.ensure_executable()?;
111        let mut cmd = Command::new(executable);
112        cmd.arg("--version");
113        let context = format!("{}", executable.display());
114        run_version_command(cmd, &context)
115    }
116
117    fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
118        let start = Instant::now();
119        let (temp_dir, path) = match payload {
120            ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
121                let (dir, path) = self.write_temp_source(code)?;
122                (Some(dir), path)
123            }
124            ExecutionPayload::File { path, .. } => (None, path.clone()),
125        };
126
127        let output = self.execute_path(&path, payload.args())?;
128        drop(temp_dir);
129
130        Ok(ExecutionOutcome {
131            language: self.id().to_string(),
132            exit_code: output.status.code(),
133            stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
134            stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
135            duration: start.elapsed(),
136        })
137    }
138
139    fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
140        let executable = self.ensure_executable()?.to_path_buf();
141        Ok(Box::new(HaskellSession::new(executable)?))
142    }
143}
144
145fn resolve_runghc_binary() -> Option<PathBuf> {
146    which::which("runghc").ok()
147}
148
/// Everything a REPL session has accumulated so far; the whole program is
/// re-rendered from this state before every run.
#[derive(Default)]
struct HaskellSessionState {
    /// Import lines, kept sorted and de-duplicated by the set.
    imports: BTreeSet<String>,
    /// Top-level declarations in the order they were entered.
    declarations: Vec<String>,
    /// Already-indented statements of `main`'s `do` block, in order.
    statements: Vec<String>,
}
155
156struct HaskellSession {
157    executable: PathBuf,
158    workspace: TempDir,
159    state: HaskellSessionState,
160    previous_stdout: String,
161    previous_stderr: String,
162}
163
164impl HaskellSession {
165    fn new(executable: PathBuf) -> Result<Self> {
166        let workspace = Builder::new()
167            .prefix("run-haskell-repl")
168            .tempdir()
169            .context("failed to create temporary directory for Haskell repl")?;
170        let session = Self {
171            executable,
172            workspace,
173            state: HaskellSessionState::default(),
174            previous_stdout: String::new(),
175            previous_stderr: String::new(),
176        };
177        session.persist_source()?;
178        Ok(session)
179    }
180
181    fn source_path(&self) -> PathBuf {
182        self.workspace.path().join("session.hs")
183    }
184
185    fn persist_source(&self) -> Result<()> {
186        let source = self.render_source();
187        fs::write(self.source_path(), source)
188            .with_context(|| "failed to write Haskell session source".to_string())
189    }
190
191    fn render_source(&self) -> String {
192        let mut source = String::new();
193        source.push_str("import Prelude\n");
194        for import in &self.state.imports {
195            source.push_str(import);
196            if !import.ends_with('\n') {
197                source.push('\n');
198            }
199        }
200        source.push('\n');
201
202        for decl in &self.state.declarations {
203            source.push_str(decl);
204            if !decl.ends_with('\n') {
205                source.push('\n');
206            }
207            source.push('\n');
208        }
209
210        source.push_str("main :: IO ()\n");
211        source.push_str("main = do\n");
212        if self.state.statements.is_empty() {
213            source.push_str("    return ()\n");
214        } else {
215            for stmt in &self.state.statements {
216                source.push_str(stmt);
217                if !stmt.ends_with('\n') {
218                    source.push('\n');
219                }
220            }
221
222            if let Some(last) = self.state.statements.last()
223                && last.trim().starts_with("let ")
224            {
225                source.push_str("    return ()\n");
226            }
227        }
228
229        source
230    }
231
232    fn run_program(&self) -> Result<std::process::Output> {
233        let mut cmd = Command::new(&self.executable);
234        cmd.arg("session.hs")
235            .stdout(Stdio::piped())
236            .stderr(Stdio::piped())
237            .current_dir(self.workspace.path());
238        cmd.output().with_context(|| {
239            format!(
240                "failed to execute {} for Haskell session",
241                self.executable.display()
242            )
243        })
244    }
245
246    fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
247        self.persist_source()?;
248        let output = self.run_program()?;
249        let stdout_full = normalize_output(&output.stdout);
250        let stderr_full = normalize_output(&output.stderr);
251
252        let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
253        let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
254
255        let success = output.status.success();
256        if success {
257            self.previous_stdout = stdout_full;
258            self.previous_stderr = stderr_full;
259        }
260
261        let outcome = ExecutionOutcome {
262            language: "haskell".to_string(),
263            exit_code: output.status.code(),
264            stdout: stdout_delta,
265            stderr: stderr_delta,
266            duration: start.elapsed(),
267        };
268
269        Ok((outcome, success))
270    }
271
272    fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
273        let mut inserted = Vec::new();
274        for line in code.lines() {
275            let trimmed = line.trim();
276            if trimmed.is_empty() {
277                continue;
278            }
279            let normalized = trimmed.to_string();
280            if self.state.imports.insert(normalized.clone()) {
281                inserted.push(normalized);
282            }
283        }
284
285        if inserted.is_empty() {
286            return Ok((
287                ExecutionOutcome {
288                    language: "haskell".to_string(),
289                    exit_code: None,
290                    stdout: String::new(),
291                    stderr: String::new(),
292                    duration: Duration::default(),
293                },
294                true,
295            ));
296        }
297
298        let start = Instant::now();
299        let (outcome, success) = self.run_current(start)?;
300        if !success {
301            for item in inserted {
302                self.state.imports.remove(&item);
303            }
304            self.persist_source()?;
305        }
306        Ok((outcome, success))
307    }
308
309    fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
310        let snippet = ensure_trailing_newline(code);
311        self.state.declarations.push(snippet);
312        let start = Instant::now();
313        let (outcome, success) = self.run_current(start)?;
314        if !success {
315            let _ = self.state.declarations.pop();
316            self.persist_source()?;
317        }
318        Ok((outcome, success))
319    }
320
321    fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
322        let snippet = indent_block(code);
323        self.state.statements.push(snippet);
324        let start = Instant::now();
325        let (outcome, success) = self.run_current(start)?;
326        if !success {
327            let _ = self.state.statements.pop();
328            self.persist_source()?;
329        }
330        Ok((outcome, success))
331    }
332
333    fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
334        let wrapped = wrap_expression(code);
335        self.state.statements.push(wrapped);
336        let start = Instant::now();
337        let (outcome, success) = self.run_current(start)?;
338        if !success {
339            let _ = self.state.statements.pop();
340            self.persist_source()?;
341        }
342        Ok((outcome, success))
343    }
344
345    fn reset(&mut self) -> Result<()> {
346        self.state.imports.clear();
347        self.state.declarations.clear();
348        self.state.statements.clear();
349        self.previous_stdout.clear();
350        self.previous_stderr.clear();
351        self.persist_source()
352    }
353}
354
355impl LanguageSession for HaskellSession {
356    fn language_id(&self) -> &str {
357        "haskell"
358    }
359
360    fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
361        let trimmed = code.trim();
362        if trimmed.is_empty() {
363            return Ok(ExecutionOutcome {
364                language: "haskell".to_string(),
365                exit_code: None,
366                stdout: String::new(),
367                stderr: String::new(),
368                duration: Duration::default(),
369            });
370        }
371
372        if trimmed.eq_ignore_ascii_case(":reset") {
373            self.reset()?;
374            return Ok(ExecutionOutcome {
375                language: "haskell".to_string(),
376                exit_code: None,
377                stdout: String::new(),
378                stderr: String::new(),
379                duration: Duration::default(),
380            });
381        }
382
383        if trimmed.eq_ignore_ascii_case(":help") {
384            return Ok(ExecutionOutcome {
385                language: "haskell".to_string(),
386                exit_code: None,
387                stdout: "Haskell commands:\n  :reset - clear session state\n  :help  - show this message\n"
388                    .to_string(),
389                stderr: String::new(),
390                duration: Duration::default(),
391            });
392        }
393
394        match classify_snippet(trimmed) {
395            HaskellSnippet::Import => {
396                let (outcome, _) = self.apply_import(code)?;
397                Ok(outcome)
398            }
399            HaskellSnippet::Declaration => {
400                let (outcome, _) = self.apply_declaration(code)?;
401                Ok(outcome)
402            }
403            HaskellSnippet::Expression => {
404                let (outcome, _) = self.apply_expression(trimmed)?;
405                Ok(outcome)
406            }
407            HaskellSnippet::Statement => {
408                let (outcome, _) = self.apply_statement(code)?;
409                Ok(outcome)
410            }
411        }
412    }
413
414    fn shutdown(&mut self) -> Result<()> {
415        Ok(())
416    }
417}
418
/// How a REPL entry is spliced into the generated program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HaskellSnippet {
    /// One or more `import` lines, added to the module header.
    Import,
    /// A top-level declaration (type, class, signature, binding).
    Declaration,
    /// A `do` statement appended to `main`.
    Statement,
    /// An expression whose value is printed from `main`.
    Expression,
}
425
426fn classify_snippet(code: &str) -> HaskellSnippet {
427    if is_import(code) {
428        return HaskellSnippet::Import;
429    }
430
431    if is_declaration(code) {
432        return HaskellSnippet::Declaration;
433    }
434
435    if should_wrap_expression(code) {
436        return HaskellSnippet::Expression;
437    }
438
439    HaskellSnippet::Statement
440}
441
/// Returns `true` when every non-blank line of `code` is an `import`
/// declaration and there is at least one such line.
///
/// Blank lines are ignored so that a pasted group of imports separated by
/// empty lines is still recognised; `apply_import` skips blank lines too.
/// Any whitespace (not only a space) may follow the keyword.
fn is_import(code: &str) -> bool {
    let mut lines = code
        .lines()
        .filter(|line| !line.trim().is_empty())
        .peekable();
    lines.peek().is_some()
        && lines.all(|line| {
            line.trim_start()
                .strip_prefix("import")
                .is_some_and(|rest| rest.starts_with(char::is_whitespace))
        })
}
446
/// Keywords that begin a top-level declaration. Haskell keywords are
/// case-sensitive, so constructors such as `Data` or `Default` do not match.
const DECLARATION_KEYWORDS: [&str; 12] = [
    "module", "data", "type", "newtype", "class", "instance", "foreign", "default", "deriving",
    "infix", "infixl", "infixr",
];

/// Keywords that appear before a top-level `=` only when that `=` belongs to
/// a binding nested inside an expression or `do` block (e.g. a `let` inside a
/// multi-line statement), not to a declaration's own equation.
const NESTED_BINDING_KEYWORDS: [&str; 9] =
    ["let", "in", "do", "case", "of", "if", "then", "else", "where"];

/// Returns `true` when `code` belongs at the top level of the generated module
/// rather than inside `main`.
///
/// The following count as declarations:
/// - type, class, instance, fixity and similar declarations (by keyword);
/// - type signatures whose left side is a list of binders;
/// - equations whose head starts with a name or `(`.
///
/// The check is lexical. A `=` only counts when it is a standalone operator
/// (not `==`, `<=`, `>>=`, ...) outside brackets, string/char literals and
/// `--` comments.
fn is_declaration(code: &str) -> bool {
    let trimmed = code.trim_start();
    let head = leading_word(trimmed);
    if head == "let" {
        return false;
    }
    if DECLARATION_KEYWORDS.contains(&head) {
        return true;
    }

    // `name :: Type` / `a, b :: Type`. Annotated expressions such as
    // `read "5" :: Int` have a non-binder left side and are not signatures.
    if find_top_level_operator(trimmed, "::").is_some_and(|pos| is_binder_list(&trimmed[..pos])) {
        return true;
    }

    let Some(pos) = find_top_level_operator(trimmed, "=") else {
        return false;
    };
    let lhs = trimmed[..pos].trim();
    let starts_like_head = lhs
        .chars()
        .next()
        .is_some_and(|c| c.is_alphabetic() || c == '(');
    let has_lambda = lhs
        .split(|c: char| !is_haskell_symbol(c))
        .any(|run| run == "\\");
    let has_nested_keyword = lhs
        .split(|c: char| !is_identifier_char(c))
        .any(|word| NESTED_BINDING_KEYWORDS.contains(&word));
    starts_like_head && !has_lambda && !has_nested_keyword
}

/// Characters that may appear in a Haskell identifier.
fn is_identifier_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_' || c == '\''
}

/// ASCII characters that form Haskell operator symbols.
fn is_haskell_symbol(c: char) -> bool {
    "!#$%&*+./<=>?@\\^|-~:".contains(c)
}

/// The identifier-like word at the very start of `text` (possibly empty).
fn leading_word(text: &str) -> &str {
    let end = text
        .find(|c: char| !is_identifier_char(c))
        .unwrap_or(text.len());
    &text[..end]
}

/// `true` for `name`, `a, b` or `(<+>)`: the left side of a type signature.
fn is_binder_list(text: &str) -> bool {
    text.split(',').all(|binder| {
        let binder = binder.trim();
        is_variable_name(binder) || is_parenthesized_operator(binder)
    })
}

/// A variable identifier: lowercase letter or `_`, then identifier chars.
fn is_variable_name(text: &str) -> bool {
    let mut chars = text.chars();
    chars.next().is_some_and(|c| c.is_lowercase() || c == '_') && chars.all(is_identifier_char)
}

/// An operator in parentheses, such as `(<+>)`.
fn is_parenthesized_operator(text: &str) -> bool {
    text.strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'))
        .map(str::trim)
        .is_some_and(|op| !op.is_empty() && op.chars().all(is_haskell_symbol))
}

/// Byte offset of the first occurrence of operator `op` as a complete symbol
/// run at bracket depth 0.
///
/// String literals, character literals and `--` line comments are skipped.
/// Quotes directly after an identifier character are primes (`x'`), not
/// character literals.
fn find_top_level_operator(code: &str, op: &str) -> Option<usize> {
    let chars: Vec<(usize, char)> = code.char_indices().collect();
    let char_at = |index: usize| chars.get(index).map(|&(_, c)| c);
    let mut depth = 0usize;
    let mut i = 0;
    while let Some(c) = char_at(i) {
        match c {
            '"' => {
                i += 1;
                while let Some(inner) = char_at(i) {
                    i += if inner == '\\' { 2 } else { 1 };
                    if inner == '"' {
                        break;
                    }
                }
            }
            '\'' => {
                let is_prime = i > 0 && char_at(i - 1).is_some_and(is_identifier_char);
                let close = if char_at(i + 1) == Some('\\') { i + 3 } else { i + 2 };
                i = if !is_prime && char_at(close) == Some('\'') {
                    close + 1
                } else {
                    i + 1
                };
            }
            '(' | '[' | '{' => {
                depth += 1;
                i += 1;
            }
            ')' | ']' | '}' => {
                depth = depth.saturating_sub(1);
                i += 1;
            }
            _ if is_haskell_symbol(c) => {
                let start = i;
                while char_at(i).is_some_and(is_haskell_symbol) {
                    i += 1;
                }
                let run: String = chars[start..i].iter().map(|&(_, ch)| ch).collect();
                if run.len() >= 2 && run.chars().all(|ch| ch == '-') {
                    // Two or more dashes on their own start a line comment.
                    while char_at(i).is_some_and(|ch| ch != '\n') {
                        i += 1;
                    }
                } else if depth == 0 && run == op {
                    return Some(chars[start].0);
                }
            }
            _ => i += 1,
        }
    }
    None
}
493
/// Returns `true` when a single-line snippet looks like an expression whose
/// value should be printed (`print ((code))`) rather than run as a `do`
/// statement.
///
/// The following are left to run as statements:
/// - multi-line input;
/// - input starting with a statement or declaration keyword (matched
///   case-sensitively, as Haskell keywords are);
/// - input containing a binding-like operator built from `=` (`=`, `>>=`,
///   `=<<`, `=>`), an arrow `->`, or a bind `<-`.
///
/// The comparison operators `==`, `/=`, `<=` and `>=` do not disqualify a
/// snippet, so `x == 5` is printed.
fn should_wrap_expression(code: &str) -> bool {
    if code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() {
        return false;
    }

    const STATEMENT_PREFIXES: [&str; 11] = [
        "let ",
        "case ",
        "if ",
        "do ",
        "import ",
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
    ];

    if STATEMENT_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(prefix))
    {
        return false;
    }

    const OPERATOR_SYMBOLS: &str = "!#$%&*+./<=>?@\\^|-~:";
    // Split into maximal runs of operator symbols so that `==` is seen as one
    // comparison operator instead of containing a binding `=`.
    let has_binding_equals = trimmed
        .split(|c: char| !OPERATOR_SYMBOLS.contains(c))
        .any(|run| run.contains('=') && !matches!(run, "==" | "/=" | "<=" | ">="));

    !(has_binding_equals || trimmed.contains("->") || trimmed.contains("<-"))
}
532
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
540
/// Prefixes every line of `code` with four spaces (the indentation used for
/// `main`'s `do` block) and makes sure the last line ends with `\n`.
fn indent_block(code: &str) -> String {
    code.split_inclusive('\n')
        .map(|line| {
            let terminator = if line.ends_with('\n') { "" } else { "\n" };
            format!("    {line}{terminator}")
        })
        .collect()
}
555
556fn wrap_expression(code: &str) -> String {
557    indent_block(&format!("print (({}))\n", code.trim()))
558}
559
/// Returns the part of `current` that follows `previous`. When `current`
/// does not start with `previous`, all of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current
        .strip_prefix(previous)
        .unwrap_or(current)
        .to_owned()
}
567
/// Decodes process output lossily as UTF-8 and drops every carriage return.
/// This turns CRLF and lone CR line endings into plain `\n` text.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&ch| ch != '\r')
        .collect()
}