1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{
11 ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession, run_version_command,
12};
13
/// Language engine that runs Haskell code through GHC's `runghc` script runner.
pub struct HaskellEngine {
    /// Resolved `runghc` binary; `None` when it could not be located at construction.
    executable: Option<PathBuf>,
}
17
18impl Default for HaskellEngine {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl HaskellEngine {
25 pub fn new() -> Self {
26 Self {
27 executable: resolve_runghc_binary(),
28 }
29 }
30
31 fn ensure_executable(&self) -> Result<&Path> {
32 self.executable.as_deref().ok_or_else(|| {
33 anyhow::anyhow!(
34 "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
35 )
36 })
37 }
38
39 fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
40 let dir = Builder::new()
41 .prefix("run-haskell")
42 .tempdir()
43 .context("failed to create temporary directory for Haskell source")?;
44 let path = dir.path().join("snippet.hs");
45 let mut contents = code.to_string();
46 if !contents.ends_with('\n') {
47 contents.push('\n');
48 }
49 fs::write(&path, contents).with_context(|| {
50 format!(
51 "failed to write temporary Haskell source to {}",
52 path.display()
53 )
54 })?;
55 Ok((dir, path))
56 }
57
58 fn execute_path(&self, path: &Path, args: &[String]) -> Result<std::process::Output> {
59 let executable = self.ensure_executable()?;
60 let mut cmd = Command::new(executable);
61 cmd.arg(path)
62 .args(args)
63 .stdout(Stdio::piped())
64 .stderr(Stdio::piped());
65 cmd.stdin(Stdio::inherit());
66 if let Some(parent) = path.parent() {
67 cmd.current_dir(parent);
68 }
69 cmd.output().with_context(|| {
70 format!(
71 "failed to execute {} with script {}",
72 executable.display(),
73 path.display()
74 )
75 })
76 }
77}
78
79impl LanguageEngine for HaskellEngine {
80 fn id(&self) -> &'static str {
81 "haskell"
82 }
83
84 fn display_name(&self) -> &'static str {
85 "Haskell"
86 }
87
88 fn aliases(&self) -> &[&'static str] {
89 &["hs", "ghci"]
90 }
91
92 fn supports_sessions(&self) -> bool {
93 self.executable.is_some()
94 }
95
96 fn validate(&self) -> Result<()> {
97 let executable = self.ensure_executable()?;
98 let mut cmd = Command::new(executable);
99 cmd.arg("--version")
100 .stdout(Stdio::null())
101 .stderr(Stdio::null());
102 cmd.status()
103 .with_context(|| format!("failed to invoke {}", executable.display()))?
104 .success()
105 .then_some(())
106 .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
107 }
108
109 fn toolchain_version(&self) -> Result<Option<String>> {
110 let executable = self.ensure_executable()?;
111 let mut cmd = Command::new(executable);
112 cmd.arg("--version");
113 let context = format!("{}", executable.display());
114 run_version_command(cmd, &context)
115 }
116
117 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
118 let start = Instant::now();
119 let (temp_dir, path) = match payload {
120 ExecutionPayload::Inline { code, .. } | ExecutionPayload::Stdin { code, .. } => {
121 let (dir, path) = self.write_temp_source(code)?;
122 (Some(dir), path)
123 }
124 ExecutionPayload::File { path, .. } => (None, path.clone()),
125 };
126
127 let output = self.execute_path(&path, payload.args())?;
128 drop(temp_dir);
129
130 Ok(ExecutionOutcome {
131 language: self.id().to_string(),
132 exit_code: output.status.code(),
133 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
134 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
135 duration: start.elapsed(),
136 })
137 }
138
139 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
140 let executable = self.ensure_executable()?.to_path_buf();
141 Ok(Box::new(HaskellSession::new(executable)?))
142 }
143}
144
145fn resolve_runghc_binary() -> Option<PathBuf> {
146 which::which("runghc").ok()
147}
148
/// Snippets accepted so far in a Haskell session, grouped by where they go in
/// the generated program.
#[derive(Default)]
struct HaskellSessionState {
    /// Import lines, deduplicated and iterated in sorted order.
    imports: BTreeSet<String>,
    /// Top-level declarations, in the order they were entered.
    declarations: Vec<String>,
    /// Already-indented statements for `main`'s `do` block, in entry order.
    statements: Vec<String>,
}
155
156struct HaskellSession {
157 executable: PathBuf,
158 workspace: TempDir,
159 state: HaskellSessionState,
160 previous_stdout: String,
161 previous_stderr: String,
162}
163
164impl HaskellSession {
165 fn new(executable: PathBuf) -> Result<Self> {
166 let workspace = Builder::new()
167 .prefix("run-haskell-repl")
168 .tempdir()
169 .context("failed to create temporary directory for Haskell repl")?;
170 let session = Self {
171 executable,
172 workspace,
173 state: HaskellSessionState::default(),
174 previous_stdout: String::new(),
175 previous_stderr: String::new(),
176 };
177 session.persist_source()?;
178 Ok(session)
179 }
180
181 fn source_path(&self) -> PathBuf {
182 self.workspace.path().join("session.hs")
183 }
184
185 fn persist_source(&self) -> Result<()> {
186 let source = self.render_source();
187 fs::write(self.source_path(), source)
188 .with_context(|| "failed to write Haskell session source".to_string())
189 }
190
191 fn render_source(&self) -> String {
192 let mut source = String::new();
193 source.push_str("import Prelude\n");
194 for import in &self.state.imports {
195 source.push_str(import);
196 if !import.ends_with('\n') {
197 source.push('\n');
198 }
199 }
200 source.push('\n');
201
202 for decl in &self.state.declarations {
203 source.push_str(decl);
204 if !decl.ends_with('\n') {
205 source.push('\n');
206 }
207 source.push('\n');
208 }
209
210 source.push_str("main :: IO ()\n");
211 source.push_str("main = do\n");
212 if self.state.statements.is_empty() {
213 source.push_str(" return ()\n");
214 } else {
215 for stmt in &self.state.statements {
216 source.push_str(stmt);
217 if !stmt.ends_with('\n') {
218 source.push('\n');
219 }
220 }
221
222 if let Some(last) = self.state.statements.last()
223 && last.trim().starts_with("let ")
224 {
225 source.push_str(" return ()\n");
226 }
227 }
228
229 source
230 }
231
232 fn run_program(&self) -> Result<std::process::Output> {
233 let mut cmd = Command::new(&self.executable);
234 cmd.arg("session.hs")
235 .stdout(Stdio::piped())
236 .stderr(Stdio::piped())
237 .current_dir(self.workspace.path());
238 cmd.output().with_context(|| {
239 format!(
240 "failed to execute {} for Haskell session",
241 self.executable.display()
242 )
243 })
244 }
245
246 fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
247 self.persist_source()?;
248 let output = self.run_program()?;
249 let stdout_full = normalize_output(&output.stdout);
250 let stderr_full = normalize_output(&output.stderr);
251
252 let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
253 let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
254
255 let success = output.status.success();
256 if success {
257 self.previous_stdout = stdout_full;
258 self.previous_stderr = stderr_full;
259 }
260
261 let outcome = ExecutionOutcome {
262 language: "haskell".to_string(),
263 exit_code: output.status.code(),
264 stdout: stdout_delta,
265 stderr: stderr_delta,
266 duration: start.elapsed(),
267 };
268
269 Ok((outcome, success))
270 }
271
272 fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
273 let mut inserted = Vec::new();
274 for line in code.lines() {
275 let trimmed = line.trim();
276 if trimmed.is_empty() {
277 continue;
278 }
279 let normalized = trimmed.to_string();
280 if self.state.imports.insert(normalized.clone()) {
281 inserted.push(normalized);
282 }
283 }
284
285 if inserted.is_empty() {
286 return Ok((
287 ExecutionOutcome {
288 language: "haskell".to_string(),
289 exit_code: None,
290 stdout: String::new(),
291 stderr: String::new(),
292 duration: Duration::default(),
293 },
294 true,
295 ));
296 }
297
298 let start = Instant::now();
299 let (outcome, success) = self.run_current(start)?;
300 if !success {
301 for item in inserted {
302 self.state.imports.remove(&item);
303 }
304 self.persist_source()?;
305 }
306 Ok((outcome, success))
307 }
308
309 fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
310 let snippet = ensure_trailing_newline(code);
311 self.state.declarations.push(snippet);
312 let start = Instant::now();
313 let (outcome, success) = self.run_current(start)?;
314 if !success {
315 let _ = self.state.declarations.pop();
316 self.persist_source()?;
317 }
318 Ok((outcome, success))
319 }
320
321 fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
322 let snippet = indent_block(code);
323 self.state.statements.push(snippet);
324 let start = Instant::now();
325 let (outcome, success) = self.run_current(start)?;
326 if !success {
327 let _ = self.state.statements.pop();
328 self.persist_source()?;
329 }
330 Ok((outcome, success))
331 }
332
333 fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
334 let wrapped = wrap_expression(code);
335 self.state.statements.push(wrapped);
336 let start = Instant::now();
337 let (outcome, success) = self.run_current(start)?;
338 if !success {
339 let _ = self.state.statements.pop();
340 self.persist_source()?;
341 }
342 Ok((outcome, success))
343 }
344
345 fn reset(&mut self) -> Result<()> {
346 self.state.imports.clear();
347 self.state.declarations.clear();
348 self.state.statements.clear();
349 self.previous_stdout.clear();
350 self.previous_stderr.clear();
351 self.persist_source()
352 }
353}
354
355impl LanguageSession for HaskellSession {
356 fn language_id(&self) -> &str {
357 "haskell"
358 }
359
360 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
361 let trimmed = code.trim();
362 if trimmed.is_empty() {
363 return Ok(ExecutionOutcome {
364 language: "haskell".to_string(),
365 exit_code: None,
366 stdout: String::new(),
367 stderr: String::new(),
368 duration: Duration::default(),
369 });
370 }
371
372 if trimmed.eq_ignore_ascii_case(":reset") {
373 self.reset()?;
374 return Ok(ExecutionOutcome {
375 language: "haskell".to_string(),
376 exit_code: None,
377 stdout: String::new(),
378 stderr: String::new(),
379 duration: Duration::default(),
380 });
381 }
382
383 if trimmed.eq_ignore_ascii_case(":help") {
384 return Ok(ExecutionOutcome {
385 language: "haskell".to_string(),
386 exit_code: None,
387 stdout: "Haskell commands:\n :reset - clear session state\n :help - show this message\n"
388 .to_string(),
389 stderr: String::new(),
390 duration: Duration::default(),
391 });
392 }
393
394 match classify_snippet(trimmed) {
395 HaskellSnippet::Import => {
396 let (outcome, _) = self.apply_import(code)?;
397 Ok(outcome)
398 }
399 HaskellSnippet::Declaration => {
400 let (outcome, _) = self.apply_declaration(code)?;
401 Ok(outcome)
402 }
403 HaskellSnippet::Expression => {
404 let (outcome, _) = self.apply_expression(trimmed)?;
405 Ok(outcome)
406 }
407 HaskellSnippet::Statement => {
408 let (outcome, _) = self.apply_statement(code)?;
409 Ok(outcome)
410 }
411 }
412 }
413
414 fn shutdown(&mut self) -> Result<()> {
415 Ok(())
416 }
417}
418
/// How a REPL snippet is folded into the generated session program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HaskellSnippet {
    /// One or more `import` lines, placed at the top of the file.
    Import,
    /// A top-level declaration (type, class, signature, binding, ...).
    Declaration,
    /// A statement appended to `main`'s `do` block.
    Statement,
    /// A single-line expression whose value is printed from `main`.
    Expression,
}
425
426fn classify_snippet(code: &str) -> HaskellSnippet {
427 if is_import(code) {
428 return HaskellSnippet::Import;
429 }
430
431 if is_declaration(code) {
432 return HaskellSnippet::Declaration;
433 }
434
435 if should_wrap_expression(code) {
436 return HaskellSnippet::Expression;
437 }
438
439 HaskellSnippet::Statement
440}
441
/// Returns `true` when `code` has at least one non-blank line and every
/// non-blank line is an `import` declaration.
///
/// Blank lines are ignored, so a pasted block of imports separated by empty lines
/// is still recognized. The first word is compared as a whole token, so
/// `import\tData.Char` qualifies while `importantValue` does not.
fn is_import(code: &str) -> bool {
    let mut saw_import = false;
    for line in code.lines().map(str::trim).filter(|line| !line.is_empty()) {
        if line.split_whitespace().next() != Some("import") {
            return false;
        }
        saw_import = true;
    }
    saw_import
}
446
/// Characters that make up Haskell operator symbols.
const HASKELL_SYMBOL_CHARS: &str = "!#$%&*+./<=>?@\\^|-~:";

/// Returns `true` when `code` looks like a top-level Haskell declaration rather
/// than something that belongs inside `main`.
///
/// A snippet is a declaration when it:
/// * starts with a declaration keyword (`data`, `class`, `infixl`, ...);
/// * starts with a type signature (`f :: Int -> Int`, `a, b :: Bool`), ignoring
///   leading blank and `--` comment lines; or
/// * contains a binding `=` whose left-hand side starts with a name and is not
///   part of a `do` statement (`x <- ...`, `let ...`, a lambda, ...).
///
/// A snippet starting with `let ` is never a declaration; it belongs in `main`.
fn is_declaration(code: &str) -> bool {
    let trimmed = code.trim_start();
    if trimmed.starts_with("let ") {
        return false;
    }

    // Haskell keywords are case-sensitive, so they are matched exactly: an
    // expression that starts with a constructor such as `Default x` is not a
    // declaration.
    const PREFIXES: [&str; 11] = [
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
        "foreign ",
        "default ",
        "infix ",
        "infixl ",
        "infixr ",
    ];
    if PREFIXES.iter().any(|prefix| trimmed.starts_with(prefix)) {
        return true;
    }

    // `::` marks a declaration only when the text before it is a list of names.
    // Inline annotations such as `print (read s :: Int)` or
    // `n <- readLn :: IO Int` are expressions/statements.
    let signature_line = trimmed
        .lines()
        .map(str::trim)
        .find(|line| !line.is_empty() && !line.starts_with("--"))
        .unwrap_or("");
    if let Some((names, _)) = signature_line.split_once("::") {
        if !names.trim().is_empty() && names.split(',').all(is_signature_name) {
            return true;
        }
    }

    // Only a standalone `=` introduces a binding. `==`, `/=`, `<=`, `>=`, `>>=`
    // and any `=` inside a string or character literal do not.
    let Some(equals) = find_standalone_operator(trimmed, '=') else {
        return false;
    };
    let lhs = trimmed[..equals].trim();
    // Guards (`f x | x > 0 = ...`) may hold arbitrary expressions, so only the
    // part before the first guard bar is inspected.
    let head = match find_standalone_operator(lhs, '|') {
        Some(bar) => &lhs[..bar],
        None => lhs,
    };
    const STATEMENT_MARKERS: [&str; 8] = ["let", "do", "case", "of", "if", "then", "else", "in"];
    if head.contains("<-")
        || head
            .split_whitespace()
            .any(|token| token.starts_with('\\') || STATEMENT_MARKERS.contains(&token))
    {
        return false;
    }
    head.split_whitespace()
        .next()
        .and_then(|token| token.chars().next())
        .is_some_and(char::is_alphabetic)
}

/// Returns `true` when `name` (surrounding whitespace ignored) can stand to the
/// left of `::` in a type signature. That means a variable such as `go'` or a
/// parenthesized operator such as `(<+>)`.
fn is_signature_name(name: &str) -> bool {
    let name = name.trim();
    if let Some(operator) = name.strip_prefix('(').and_then(|rest| rest.strip_suffix(')')) {
        let operator = operator.trim();
        return !operator.is_empty() && operator.chars().all(|c| HASKELL_SYMBOL_CHARS.contains(c));
    }
    let mut chars = name.chars();
    chars.next().is_some_and(|c| c.is_lowercase() || c == '_')
        && chars.all(|c| c.is_alphanumeric() || c == '_' || c == '\'')
}

/// Returns the byte offset of the first `target` character in `code` that forms
/// an operator on its own and lies outside string and character literals.
///
/// "On its own" means it is not part of a longer symbol run such as `==`, `>=`,
/// `=>` or `||`.
fn find_standalone_operator(code: &str, target: char) -> Option<usize> {
    let is_symbol = |c: Option<char>| c.is_some_and(|c| HASKELL_SYMBOL_CHARS.contains(c));

    // Delimiter of the literal currently being skipped, if any.
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut prev: Option<char> = None;
    let mut chars = code.char_indices().peekable();
    while let Some((index, ch)) = chars.next() {
        if let Some(delimiter) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == delimiter {
                quote = None;
            }
        } else if ch == '"' {
            quote = Some('"');
        } else if ch == '\''
            && !prev.is_some_and(|p| p.is_alphanumeric() || p == '_' || p == '\'')
        {
            // A quote right after an identifier character is a prime (`x'`), not
            // the start of a character literal.
            quote = Some('\'');
        } else if ch == target
            && !is_symbol(prev)
            && !is_symbol(chars.peek().map(|&(_, next)| next))
        {
            return Some(index);
        }
        prev = Some(ch);
    }
    None
}
493
/// Decides whether a snippet should be evaluated as `print ((...))`.
///
/// Only single-line, non-blank input qualifies. Input is rejected when it starts
/// with a statement/declaration keyword (compared case-insensitively) or contains
/// `=`, `->` or `<-`; those cases are left to the statement path.
fn should_wrap_expression(code: &str) -> bool {
    const STATEMENT_PREFIXES: [&str; 11] = [
        "let ",
        "case ",
        "if ",
        "do ",
        "import ",
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
    ];
    const BINDING_TOKENS: [&str; 3] = ["=", "->", "<-"];

    let trimmed = code.trim();
    if code.contains('\n') || trimmed.is_empty() {
        return false;
    }

    let lowered = trimmed.to_ascii_lowercase();
    let keyword_led = STATEMENT_PREFIXES
        .iter()
        .any(|prefix| lowered.starts_with(prefix));
    let binding_like = BINDING_TOKENS
        .iter()
        .any(|token| trimmed.contains(*token));
    !(keyword_led || binding_like)
}
532
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
540
/// Prefixes every line of `code` with the `do`-block indent and makes sure each
/// line, including the last, ends with `\n`.
fn indent_block(code: &str) -> String {
    code.split_inclusive('\n')
        .map(|line| {
            let terminator = if line.ends_with('\n') { "" } else { "\n" };
            format!(" {line}{terminator}")
        })
        .collect()
}
555
556fn wrap_expression(code: &str) -> String {
557 indent_block(&format!("print (({}))\n", code.trim()))
558}
559
/// Returns the part of `current` that follows `previous`. If `current` does not
/// start with `previous` (e.g. the output changed), all of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
567
/// Decodes process output lossily (invalid UTF-8 becomes U+FFFD) and removes
/// every carriage return, which turns CRLF line endings into LF.
fn normalize_output(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes)
        .chars()
        .filter(|&c| c != '\r')
        .collect()
}