1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
12pub struct HaskellEngine {
13 executable: Option<PathBuf>,
14}
15
16impl HaskellEngine {
17 pub fn new() -> Self {
18 Self {
19 executable: resolve_runghc_binary(),
20 }
21 }
22
23 fn ensure_executable(&self) -> Result<&Path> {
24 self.executable.as_deref().ok_or_else(|| {
25 anyhow::anyhow!(
26 "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
27 )
28 })
29 }
30
31 fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
32 let dir = Builder::new()
33 .prefix("run-haskell")
34 .tempdir()
35 .context("failed to create temporary directory for Haskell source")?;
36 let path = dir.path().join("snippet.hs");
37 let mut contents = code.to_string();
38 if !contents.ends_with('\n') {
39 contents.push('\n');
40 }
41 fs::write(&path, contents).with_context(|| {
42 format!(
43 "failed to write temporary Haskell source to {}",
44 path.display()
45 )
46 })?;
47 Ok((dir, path))
48 }
49
50 fn execute_path(&self, path: &Path) -> Result<std::process::Output> {
51 let executable = self.ensure_executable()?;
52 let mut cmd = Command::new(executable);
53 cmd.arg(path).stdout(Stdio::piped()).stderr(Stdio::piped());
54 cmd.stdin(Stdio::inherit());
55 if let Some(parent) = path.parent() {
56 cmd.current_dir(parent);
57 }
58 cmd.output().with_context(|| {
59 format!(
60 "failed to execute {} with script {}",
61 executable.display(),
62 path.display()
63 )
64 })
65 }
66}
67
68impl LanguageEngine for HaskellEngine {
69 fn id(&self) -> &'static str {
70 "haskell"
71 }
72
73 fn display_name(&self) -> &'static str {
74 "Haskell"
75 }
76
77 fn aliases(&self) -> &[&'static str] {
78 &["hs", "ghci"]
79 }
80
81 fn supports_sessions(&self) -> bool {
82 self.executable.is_some()
83 }
84
85 fn validate(&self) -> Result<()> {
86 let executable = self.ensure_executable()?;
87 let mut cmd = Command::new(executable);
88 cmd.arg("--version")
89 .stdout(Stdio::null())
90 .stderr(Stdio::null());
91 cmd.status()
92 .with_context(|| format!("failed to invoke {}", executable.display()))?
93 .success()
94 .then_some(())
95 .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
96 }
97
98 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
99 let start = Instant::now();
100 let (temp_dir, path) = match payload {
101 ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
102 let (dir, path) = self.write_temp_source(code)?;
103 (Some(dir), path)
104 }
105 ExecutionPayload::File { path } => (None, path.clone()),
106 };
107
108 let output = self.execute_path(&path)?;
109 drop(temp_dir);
110
111 Ok(ExecutionOutcome {
112 language: self.id().to_string(),
113 exit_code: output.status.code(),
114 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
115 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
116 duration: start.elapsed(),
117 })
118 }
119
120 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
121 let executable = self.ensure_executable()?.to_path_buf();
122 Ok(Box::new(HaskellSession::new(executable)?))
123 }
124}
125
126fn resolve_runghc_binary() -> Option<PathBuf> {
127 which::which("runghc").ok()
128}
129
/// Snippets accepted so far in a session; `session.hs` is regenerated from
/// this on every evaluation.
#[derive(Default)]
struct HaskellSessionState {
    /// Trimmed import lines; the set keeps them sorted and de-duplicated.
    imports: BTreeSet<String>,
    /// Top-level declarations, in the order they were accepted.
    declarations: Vec<String>,
    /// Already-indented statements that form the body of `main`.
    statements: Vec<String>,
}
136
/// REPL-style Haskell session that re-runs an accumulated program.
///
/// Every evaluation rewrites `session.hs` and executes it from scratch with
/// `runghc`. Only the output that extends the previous successful run is
/// reported back.
struct HaskellSession {
    /// Path to the `runghc` executable.
    executable: PathBuf,
    /// Temporary directory holding `session.hs`.
    workspace: TempDir,
    /// Snippets accepted so far.
    state: HaskellSessionState,
    /// Normalized stdout of the last successful run (baseline for deltas).
    previous_stdout: String,
    /// Normalized stderr of the last successful run (baseline for deltas).
    previous_stderr: String,
}
144
145impl HaskellSession {
146 fn new(executable: PathBuf) -> Result<Self> {
147 let workspace = Builder::new()
148 .prefix("run-haskell-repl")
149 .tempdir()
150 .context("failed to create temporary directory for Haskell repl")?;
151 let session = Self {
152 executable,
153 workspace,
154 state: HaskellSessionState::default(),
155 previous_stdout: String::new(),
156 previous_stderr: String::new(),
157 };
158 session.persist_source()?;
159 Ok(session)
160 }
161
162 fn source_path(&self) -> PathBuf {
163 self.workspace.path().join("session.hs")
164 }
165
166 fn persist_source(&self) -> Result<()> {
167 let source = self.render_source();
168 fs::write(self.source_path(), source)
169 .with_context(|| "failed to write Haskell session source".to_string())
170 }
171
172 fn render_source(&self) -> String {
173 let mut source = String::new();
174 source.push_str("import Prelude\n");
175 for import in &self.state.imports {
176 source.push_str(import);
177 if !import.ends_with('\n') {
178 source.push('\n');
179 }
180 }
181 source.push('\n');
182
183 for decl in &self.state.declarations {
184 source.push_str(decl);
185 if !decl.ends_with('\n') {
186 source.push('\n');
187 }
188 source.push('\n');
189 }
190
191 source.push_str("main :: IO ()\n");
192 source.push_str("main = do\n");
193 if self.state.statements.is_empty() {
194 source.push_str(" return ()\n");
195 } else {
196 for stmt in &self.state.statements {
197 source.push_str(stmt);
198 if !stmt.ends_with('\n') {
199 source.push('\n');
200 }
201 }
202
203 if let Some(last) = self.state.statements.last() {
204 if last.trim().starts_with("let ") {
205 source.push_str(" return ()\n");
206 }
207 }
208 }
209
210 source
211 }
212
213 fn run_program(&self) -> Result<std::process::Output> {
214 let mut cmd = Command::new(&self.executable);
215 cmd.arg("session.hs")
216 .stdout(Stdio::piped())
217 .stderr(Stdio::piped())
218 .current_dir(self.workspace.path());
219 cmd.output().with_context(|| {
220 format!(
221 "failed to execute {} for Haskell session",
222 self.executable.display()
223 )
224 })
225 }
226
227 fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
228 self.persist_source()?;
229 let output = self.run_program()?;
230 let stdout_full = normalize_output(&output.stdout);
231 let stderr_full = normalize_output(&output.stderr);
232
233 let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
234 let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
235
236 let success = output.status.success();
237 if success {
238 self.previous_stdout = stdout_full;
239 self.previous_stderr = stderr_full;
240 }
241
242 let outcome = ExecutionOutcome {
243 language: "haskell".to_string(),
244 exit_code: output.status.code(),
245 stdout: stdout_delta,
246 stderr: stderr_delta,
247 duration: start.elapsed(),
248 };
249
250 Ok((outcome, success))
251 }
252
253 fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
254 let mut inserted = Vec::new();
255 for line in code.lines() {
256 let trimmed = line.trim();
257 if trimmed.is_empty() {
258 continue;
259 }
260 let normalized = trimmed.to_string();
261 if self.state.imports.insert(normalized.clone()) {
262 inserted.push(normalized);
263 }
264 }
265
266 if inserted.is_empty() {
267 return Ok((
268 ExecutionOutcome {
269 language: "haskell".to_string(),
270 exit_code: None,
271 stdout: String::new(),
272 stderr: String::new(),
273 duration: Duration::default(),
274 },
275 true,
276 ));
277 }
278
279 let start = Instant::now();
280 let (outcome, success) = self.run_current(start)?;
281 if !success {
282 for item in inserted {
283 self.state.imports.remove(&item);
284 }
285 self.persist_source()?;
286 }
287 Ok((outcome, success))
288 }
289
290 fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
291 let snippet = ensure_trailing_newline(code);
292 self.state.declarations.push(snippet);
293 let start = Instant::now();
294 let (outcome, success) = self.run_current(start)?;
295 if !success {
296 let _ = self.state.declarations.pop();
297 self.persist_source()?;
298 }
299 Ok((outcome, success))
300 }
301
302 fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
303 let snippet = indent_block(code);
304 self.state.statements.push(snippet);
305 let start = Instant::now();
306 let (outcome, success) = self.run_current(start)?;
307 if !success {
308 let _ = self.state.statements.pop();
309 self.persist_source()?;
310 }
311 Ok((outcome, success))
312 }
313
314 fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
315 let wrapped = wrap_expression(code);
316 self.state.statements.push(wrapped);
317 let start = Instant::now();
318 let (outcome, success) = self.run_current(start)?;
319 if !success {
320 let _ = self.state.statements.pop();
321 self.persist_source()?;
322 }
323 Ok((outcome, success))
324 }
325
326 fn reset(&mut self) -> Result<()> {
327 self.state.imports.clear();
328 self.state.declarations.clear();
329 self.state.statements.clear();
330 self.previous_stdout.clear();
331 self.previous_stderr.clear();
332 self.persist_source()
333 }
334}
335
336impl LanguageSession for HaskellSession {
337 fn language_id(&self) -> &str {
338 "haskell"
339 }
340
341 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
342 let trimmed = code.trim();
343 if trimmed.is_empty() {
344 return Ok(ExecutionOutcome {
345 language: "haskell".to_string(),
346 exit_code: None,
347 stdout: String::new(),
348 stderr: String::new(),
349 duration: Duration::default(),
350 });
351 }
352
353 if trimmed.eq_ignore_ascii_case(":reset") {
354 self.reset()?;
355 return Ok(ExecutionOutcome {
356 language: "haskell".to_string(),
357 exit_code: None,
358 stdout: String::new(),
359 stderr: String::new(),
360 duration: Duration::default(),
361 });
362 }
363
364 if trimmed.eq_ignore_ascii_case(":help") {
365 return Ok(ExecutionOutcome {
366 language: "haskell".to_string(),
367 exit_code: None,
368 stdout: "Haskell commands:\n :reset — clear session state\n :help — show this message\n"
369 .to_string(),
370 stderr: String::new(),
371 duration: Duration::default(),
372 });
373 }
374
375 match classify_snippet(trimmed) {
376 HaskellSnippet::Import => {
377 let (outcome, _) = self.apply_import(code)?;
378 Ok(outcome)
379 }
380 HaskellSnippet::Declaration => {
381 let (outcome, _) = self.apply_declaration(code)?;
382 Ok(outcome)
383 }
384 HaskellSnippet::Expression => {
385 let (outcome, _) = self.apply_expression(trimmed)?;
386 Ok(outcome)
387 }
388 HaskellSnippet::Statement => {
389 let (outcome, _) = self.apply_statement(code)?;
390 Ok(outcome)
391 }
392 }
393 }
394
395 fn shutdown(&mut self) -> Result<()> {
396 Ok(())
397 }
398}
399
/// Category of a REPL snippet, as decided by `classify_snippet`.
enum HaskellSnippet {
    /// One or more `import` lines, rendered at the top of the module.
    Import,
    /// A top-level declaration, rendered before `main`.
    Declaration,
    /// A statement appended to `main`'s `do` block.
    Statement,
    /// A single-line expression appended to `main` wrapped in `print`.
    Expression,
}
406
407fn classify_snippet(code: &str) -> HaskellSnippet {
408 if is_import(code) {
409 return HaskellSnippet::Import;
410 }
411
412 if is_declaration(code) {
413 return HaskellSnippet::Declaration;
414 }
415
416 if should_wrap_expression(code) {
417 return HaskellSnippet::Expression;
418 }
419
420 HaskellSnippet::Statement
421}
422
/// Returns `true` when `code` consists solely of `import` lines.
///
/// Blank lines between imports are ignored, matching how imports are stored.
/// Input with no non-blank lines is not an import.
fn is_import(code: &str) -> bool {
    let mut lines = code
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .peekable();
    lines.peek().is_some() && lines.all(|line| line.starts_with("import "))
}
427
/// Returns `true` when `code` belongs at the top level of the generated module
/// rather than inside `main`.
///
/// A snippet is a declaration when any of these hold:
/// - it starts with a declaration keyword (`data`, `class`, `instance`, …);
/// - it begins with a type signature (`f :: T`, `a, b :: T`, `(<+>) :: T`);
/// - it contains a binding `=` whose left-hand side starts with a letter
///   (`f x = …`).
///
/// `let` bindings are statements, never declarations. Comparison operators
/// (`==`, `/=`, `<=`, `>=`), an `=` inside a string or the char literal `'='`,
/// and `::` annotations inside expressions (`print (1 :: Int)`) do not count.
fn is_declaration(code: &str) -> bool {
    let trimmed = code.trim_start();
    if trimmed.starts_with("let ") {
        return false;
    }
    // Haskell keywords are lower-case; an upper-case `Data …` is a constructor.
    const PREFIXES: [&str; 8] = [
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
        "foreign ",
        "default ",
    ];
    if PREFIXES.iter().any(|prefix| trimmed.starts_with(prefix)) {
        return true;
    }

    if let Some(index) = trimmed.find("::") {
        if is_signature_head(&trimmed[..index]) {
            return true;
        }
    }

    let lhs = match find_binding_equals(trimmed) {
        Some(index) => trimmed[..index].trim(),
        None => return false,
    };
    if lhs.is_empty() {
        return false;
    }
    let first_token = lhs.split_whitespace().next().unwrap_or("");
    if first_token == "let" {
        return false;
    }
    first_token
        .chars()
        .next()
        .map_or(false, |c| c.is_alphabetic())
}

/// Returns `true` for characters that can form a Haskell operator symbol.
fn is_symbol_char(c: char) -> bool {
    "!#$%&*+./<=>?@\\^|-~:".contains(c)
}

/// Returns `true` when `head` (the text before `::`) is a comma-separated list
/// of variable names or parenthesised operators, i.e. a signature's left side.
fn is_signature_head(head: &str) -> bool {
    let head = head.trim();
    !head.is_empty()
        && head.split(',').all(|name| {
            let name = name.trim();
            match name.strip_prefix('(').and_then(|rest| rest.strip_suffix(')')) {
                Some(op) => !op.is_empty() && op.chars().all(is_symbol_char),
                None => {
                    let mut chars = name.chars();
                    matches!(chars.next(), Some(c) if c.is_lowercase() || c == '_')
                        && chars.all(|c| c.is_alphanumeric() || c == '_' || c == '\'')
                }
            }
        })
}

/// Byte index of the first `=` that acts as a binding sign.
///
/// An `=` is skipped when it is part of a longer operator (`==`, `/=`, `>=`,
/// `=>`, `>>=`, …), inside a string literal, or the char literal `'='`.
fn find_binding_equals(code: &str) -> Option<usize> {
    let mut in_string = false;
    let mut escaped = false;
    let mut prev: Option<char> = None;
    let mut chars = code.char_indices().peekable();
    while let Some((index, c)) = chars.next() {
        let next = chars.peek().map(|&(_, n)| n);
        let char_literal = prev == Some('\'') && next == Some('\'');
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
        } else if c == '"' && !char_literal {
            in_string = true;
        } else if c == '=' && !char_literal {
            let part_of_operator =
                prev.map_or(false, is_symbol_char) || next.map_or(false, is_symbol_char);
            if !part_of_operator {
                return Some(index);
            }
        }
        prev = Some(c);
    }
    None
}
476
/// Decides whether a snippet should be evaluated as `print ((…))`.
///
/// Only single-line, non-blank snippets qualify. The snippet must not start
/// (case-insensitively) with a statement or declaration keyword, and must not
/// contain `=`, `->` or `<-`.
fn should_wrap_expression(code: &str) -> bool {
    const STATEMENT_PREFIXES: [&str; 11] = [
        "let ",
        "case ",
        "if ",
        "do ",
        "import ",
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
    ];

    let line = code.trim();
    if code.contains('\n') || line.is_empty() {
        return false;
    }

    let lowered = line.to_ascii_lowercase();
    let starts_like_statement = STATEMENT_PREFIXES
        .iter()
        .any(|prefix| lowered.starts_with(prefix));
    let has_binding_syntax = ["=", "->", "<-"].iter().any(|token| line.contains(token));

    !(starts_like_statement || has_binding_syntax)
}
515
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{}\n", code)
    }
}
523
/// Prefixes every line of `code` with the statement indent.
///
/// Existing line endings are kept, and the final line is newline-terminated.
/// Empty input yields an empty string.
fn indent_block(code: &str) -> String {
    code.split_inclusive('\n')
        .map(|line| {
            let terminator = if line.ends_with('\n') { "" } else { "\n" };
            format!(" {}{}", line, terminator)
        })
        .collect()
}
538
539fn wrap_expression(code: &str) -> String {
540 indent_block(&format!("print (({}))\n", code.trim()))
541}
542
/// Returns the part of `current` that follows `previous`.
///
/// If `current` does not start with `previous`, the whole of `current` is
/// returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
550
/// Decodes process output lossily as UTF-8 and removes every carriage return.
///
/// This turns CRLF line endings into LF and drops stray `\r` characters.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&c| c != '\r')
        .collect()
}