1use std::collections::BTreeSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::process::{Command, Stdio};
5use std::time::{Duration, Instant};
6
7use anyhow::{Context, Result};
8use tempfile::{Builder, TempDir};
9
10use super::{ExecutionOutcome, ExecutionPayload, LanguageEngine, LanguageSession};
11
/// [`LanguageEngine`] implementation that runs Haskell code through `runghc`.
pub struct HaskellEngine {
    /// Path to `runghc` found via `which::which` when the engine was built; `None` when
    /// the lookup failed, in which case every operation reports an installation hint.
    executable: Option<PathBuf>,
}
15
16impl Default for HaskellEngine {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl HaskellEngine {
23 pub fn new() -> Self {
24 Self {
25 executable: resolve_runghc_binary(),
26 }
27 }
28
29 fn ensure_executable(&self) -> Result<&Path> {
30 self.executable.as_deref().ok_or_else(|| {
31 anyhow::anyhow!(
32 "Haskell support requires the `runghc` executable. Install the GHC toolchain from https://www.haskell.org/ghc/ (or via ghcup) and ensure `runghc` is on your PATH."
33 )
34 })
35 }
36
37 fn write_temp_source(&self, code: &str) -> Result<(TempDir, PathBuf)> {
38 let dir = Builder::new()
39 .prefix("run-haskell")
40 .tempdir()
41 .context("failed to create temporary directory for Haskell source")?;
42 let path = dir.path().join("snippet.hs");
43 let mut contents = code.to_string();
44 if !contents.ends_with('\n') {
45 contents.push('\n');
46 }
47 fs::write(&path, contents).with_context(|| {
48 format!(
49 "failed to write temporary Haskell source to {}",
50 path.display()
51 )
52 })?;
53 Ok((dir, path))
54 }
55
56 fn execute_path(&self, path: &Path) -> Result<std::process::Output> {
57 let executable = self.ensure_executable()?;
58 let mut cmd = Command::new(executable);
59 cmd.arg(path).stdout(Stdio::piped()).stderr(Stdio::piped());
60 cmd.stdin(Stdio::inherit());
61 if let Some(parent) = path.parent() {
62 cmd.current_dir(parent);
63 }
64 cmd.output().with_context(|| {
65 format!(
66 "failed to execute {} with script {}",
67 executable.display(),
68 path.display()
69 )
70 })
71 }
72}
73
74impl LanguageEngine for HaskellEngine {
75 fn id(&self) -> &'static str {
76 "haskell"
77 }
78
79 fn display_name(&self) -> &'static str {
80 "Haskell"
81 }
82
83 fn aliases(&self) -> &[&'static str] {
84 &["hs", "ghci"]
85 }
86
87 fn supports_sessions(&self) -> bool {
88 self.executable.is_some()
89 }
90
91 fn validate(&self) -> Result<()> {
92 let executable = self.ensure_executable()?;
93 let mut cmd = Command::new(executable);
94 cmd.arg("--version")
95 .stdout(Stdio::null())
96 .stderr(Stdio::null());
97 cmd.status()
98 .with_context(|| format!("failed to invoke {}", executable.display()))?
99 .success()
100 .then_some(())
101 .ok_or_else(|| anyhow::anyhow!("{} is not executable", executable.display()))
102 }
103
104 fn execute(&self, payload: &ExecutionPayload) -> Result<ExecutionOutcome> {
105 let start = Instant::now();
106 let (temp_dir, path) = match payload {
107 ExecutionPayload::Inline { code } | ExecutionPayload::Stdin { code } => {
108 let (dir, path) = self.write_temp_source(code)?;
109 (Some(dir), path)
110 }
111 ExecutionPayload::File { path } => (None, path.clone()),
112 };
113
114 let output = self.execute_path(&path)?;
115 drop(temp_dir);
116
117 Ok(ExecutionOutcome {
118 language: self.id().to_string(),
119 exit_code: output.status.code(),
120 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
121 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
122 duration: start.elapsed(),
123 })
124 }
125
126 fn start_session(&self) -> Result<Box<dyn LanguageSession>> {
127 let executable = self.ensure_executable()?.to_path_buf();
128 Ok(Box::new(HaskellSession::new(executable)?))
129 }
130}
131
132fn resolve_runghc_binary() -> Option<PathBuf> {
133 which::which("runghc").ok()
134}
135
/// Snippets accepted so far; the whole program is re-rendered from this on every evaluation.
#[derive(Default)]
struct HaskellSessionState {
    /// Trimmed import lines; the `BTreeSet` de-duplicates them and renders them sorted.
    imports: BTreeSet<String>,
    /// Top-level declarations, in the order they were accepted.
    declarations: Vec<String>,
    /// Already-indented statements emitted inside `main`'s `do` block, in order.
    statements: Vec<String>,
}
142
/// Incremental Haskell session: each evaluation rewrites and re-runs the entire accumulated
/// program, then reports only the output that was not present after the previous success.
struct HaskellSession {
    /// Path to the `runghc` executable.
    executable: PathBuf,
    /// Temporary directory holding `session.hs` (tempfile's `TempDir` removes it on drop).
    workspace: TempDir,
    /// Snippets accepted so far.
    state: HaskellSessionState,
    /// Normalized stdout of the last successful run; new output is diffed against it.
    previous_stdout: String,
    /// Normalized stderr of the last successful run; new output is diffed against it.
    previous_stderr: String,
}
150
151impl HaskellSession {
152 fn new(executable: PathBuf) -> Result<Self> {
153 let workspace = Builder::new()
154 .prefix("run-haskell-repl")
155 .tempdir()
156 .context("failed to create temporary directory for Haskell repl")?;
157 let session = Self {
158 executable,
159 workspace,
160 state: HaskellSessionState::default(),
161 previous_stdout: String::new(),
162 previous_stderr: String::new(),
163 };
164 session.persist_source()?;
165 Ok(session)
166 }
167
168 fn source_path(&self) -> PathBuf {
169 self.workspace.path().join("session.hs")
170 }
171
172 fn persist_source(&self) -> Result<()> {
173 let source = self.render_source();
174 fs::write(self.source_path(), source)
175 .with_context(|| "failed to write Haskell session source".to_string())
176 }
177
178 fn render_source(&self) -> String {
179 let mut source = String::new();
180 source.push_str("import Prelude\n");
181 for import in &self.state.imports {
182 source.push_str(import);
183 if !import.ends_with('\n') {
184 source.push('\n');
185 }
186 }
187 source.push('\n');
188
189 for decl in &self.state.declarations {
190 source.push_str(decl);
191 if !decl.ends_with('\n') {
192 source.push('\n');
193 }
194 source.push('\n');
195 }
196
197 source.push_str("main :: IO ()\n");
198 source.push_str("main = do\n");
199 if self.state.statements.is_empty() {
200 source.push_str(" return ()\n");
201 } else {
202 for stmt in &self.state.statements {
203 source.push_str(stmt);
204 if !stmt.ends_with('\n') {
205 source.push('\n');
206 }
207 }
208
209 if let Some(last) = self.state.statements.last()
210 && last.trim().starts_with("let ")
211 {
212 source.push_str(" return ()\n");
213 }
214 }
215
216 source
217 }
218
219 fn run_program(&self) -> Result<std::process::Output> {
220 let mut cmd = Command::new(&self.executable);
221 cmd.arg("session.hs")
222 .stdout(Stdio::piped())
223 .stderr(Stdio::piped())
224 .current_dir(self.workspace.path());
225 cmd.output().with_context(|| {
226 format!(
227 "failed to execute {} for Haskell session",
228 self.executable.display()
229 )
230 })
231 }
232
233 fn run_current(&mut self, start: Instant) -> Result<(ExecutionOutcome, bool)> {
234 self.persist_source()?;
235 let output = self.run_program()?;
236 let stdout_full = normalize_output(&output.stdout);
237 let stderr_full = normalize_output(&output.stderr);
238
239 let stdout_delta = diff_output(&self.previous_stdout, &stdout_full);
240 let stderr_delta = diff_output(&self.previous_stderr, &stderr_full);
241
242 let success = output.status.success();
243 if success {
244 self.previous_stdout = stdout_full;
245 self.previous_stderr = stderr_full;
246 }
247
248 let outcome = ExecutionOutcome {
249 language: "haskell".to_string(),
250 exit_code: output.status.code(),
251 stdout: stdout_delta,
252 stderr: stderr_delta,
253 duration: start.elapsed(),
254 };
255
256 Ok((outcome, success))
257 }
258
259 fn apply_import(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
260 let mut inserted = Vec::new();
261 for line in code.lines() {
262 let trimmed = line.trim();
263 if trimmed.is_empty() {
264 continue;
265 }
266 let normalized = trimmed.to_string();
267 if self.state.imports.insert(normalized.clone()) {
268 inserted.push(normalized);
269 }
270 }
271
272 if inserted.is_empty() {
273 return Ok((
274 ExecutionOutcome {
275 language: "haskell".to_string(),
276 exit_code: None,
277 stdout: String::new(),
278 stderr: String::new(),
279 duration: Duration::default(),
280 },
281 true,
282 ));
283 }
284
285 let start = Instant::now();
286 let (outcome, success) = self.run_current(start)?;
287 if !success {
288 for item in inserted {
289 self.state.imports.remove(&item);
290 }
291 self.persist_source()?;
292 }
293 Ok((outcome, success))
294 }
295
296 fn apply_declaration(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
297 let snippet = ensure_trailing_newline(code);
298 self.state.declarations.push(snippet);
299 let start = Instant::now();
300 let (outcome, success) = self.run_current(start)?;
301 if !success {
302 let _ = self.state.declarations.pop();
303 self.persist_source()?;
304 }
305 Ok((outcome, success))
306 }
307
308 fn apply_statement(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
309 let snippet = indent_block(code);
310 self.state.statements.push(snippet);
311 let start = Instant::now();
312 let (outcome, success) = self.run_current(start)?;
313 if !success {
314 let _ = self.state.statements.pop();
315 self.persist_source()?;
316 }
317 Ok((outcome, success))
318 }
319
320 fn apply_expression(&mut self, code: &str) -> Result<(ExecutionOutcome, bool)> {
321 let wrapped = wrap_expression(code);
322 self.state.statements.push(wrapped);
323 let start = Instant::now();
324 let (outcome, success) = self.run_current(start)?;
325 if !success {
326 let _ = self.state.statements.pop();
327 self.persist_source()?;
328 }
329 Ok((outcome, success))
330 }
331
332 fn reset(&mut self) -> Result<()> {
333 self.state.imports.clear();
334 self.state.declarations.clear();
335 self.state.statements.clear();
336 self.previous_stdout.clear();
337 self.previous_stderr.clear();
338 self.persist_source()
339 }
340}
341
342impl LanguageSession for HaskellSession {
343 fn language_id(&self) -> &str {
344 "haskell"
345 }
346
347 fn eval(&mut self, code: &str) -> Result<ExecutionOutcome> {
348 let trimmed = code.trim();
349 if trimmed.is_empty() {
350 return Ok(ExecutionOutcome {
351 language: "haskell".to_string(),
352 exit_code: None,
353 stdout: String::new(),
354 stderr: String::new(),
355 duration: Duration::default(),
356 });
357 }
358
359 if trimmed.eq_ignore_ascii_case(":reset") {
360 self.reset()?;
361 return Ok(ExecutionOutcome {
362 language: "haskell".to_string(),
363 exit_code: None,
364 stdout: String::new(),
365 stderr: String::new(),
366 duration: Duration::default(),
367 });
368 }
369
370 if trimmed.eq_ignore_ascii_case(":help") {
371 return Ok(ExecutionOutcome {
372 language: "haskell".to_string(),
373 exit_code: None,
374 stdout: "Haskell commands:\n :reset - clear session state\n :help - show this message\n"
375 .to_string(),
376 stderr: String::new(),
377 duration: Duration::default(),
378 });
379 }
380
381 match classify_snippet(trimmed) {
382 HaskellSnippet::Import => {
383 let (outcome, _) = self.apply_import(code)?;
384 Ok(outcome)
385 }
386 HaskellSnippet::Declaration => {
387 let (outcome, _) = self.apply_declaration(code)?;
388 Ok(outcome)
389 }
390 HaskellSnippet::Expression => {
391 let (outcome, _) = self.apply_expression(trimmed)?;
392 Ok(outcome)
393 }
394 HaskellSnippet::Statement => {
395 let (outcome, _) = self.apply_statement(code)?;
396 Ok(outcome)
397 }
398 }
399 }
400
401 fn shutdown(&mut self) -> Result<()> {
402 Ok(())
403 }
404}
405
/// How a session snippet is spliced into the generated program (chosen by `classify_snippet`).
enum HaskellSnippet {
    /// One or more `import` lines, added to the program's import list.
    Import,
    /// Top-level definition (signature, binding, `data`, `class`, ...) emitted before `main`.
    Declaration,
    /// Code emitted, indented, inside `main`'s `do` block.
    Statement,
    /// Single-line expression whose value is shown via `print`.
    Expression,
}
412
413fn classify_snippet(code: &str) -> HaskellSnippet {
414 if is_import(code) {
415 return HaskellSnippet::Import;
416 }
417
418 if is_declaration(code) {
419 return HaskellSnippet::Declaration;
420 }
421
422 if should_wrap_expression(code) {
423 return HaskellSnippet::Expression;
424 }
425
426 HaskellSnippet::Statement
427}
428
/// Returns `true` when every non-blank line of `code` is an `import` declaration.
///
/// Blank lines between imports are ignored, matching how the session's import handling
/// skips them. A snippet with no non-blank lines is not an import.
fn is_import(code: &str) -> bool {
    let mut lines = code
        .lines()
        .map(str::trim_start)
        .filter(|line| !line.is_empty())
        .peekable();
    lines.peek().is_some() && lines.all(|line| line.starts_with("import "))
}
433
/// Returns `true` when a snippet is a top-level declaration rather than code for `main`.
///
/// Leading blank lines and `--` comment lines are ignored. A snippet is a declaration when:
/// - it starts with a declaration keyword (`data`, `class`, `instance`, ...);
/// - it is a type signature (`name :: T`, `a, b :: T`, `(<+>) :: T`); or
/// - its first binding `=` follows an identifier-led left-hand side (`f x = ...`,
///   `f x | guard = ...`).
///
/// `let` bindings are statements, so they return `false`.
///
/// Operators are compared as whole symbol runs outside string/char literals and comments.
/// As a result, comparisons (`x == y`, `a >= b`), annotated expressions (`read s :: Int`)
/// and text inside literals (`putStrLn "a = b"`) are not mistaken for declarations.
/// Keywords are matched case-sensitively, as in Haskell.
fn is_declaration(code: &str) -> bool {
    const PREFIXES: [&str; 8] = [
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
        "foreign ",
        "default ",
    ];

    let trimmed = skip_leading_comments(code).trim_start();
    if trimmed.starts_with("let ") {
        return false;
    }
    if PREFIXES.iter().any(|prefix| trimmed.starts_with(prefix)) {
        return true;
    }

    if find_standalone_operator(trimmed, "::").is_some_and(|pos| is_signature_lhs(&trimmed[..pos]))
    {
        return true;
    }

    let Some(pos) = find_standalone_operator(trimmed, "=") else {
        return false;
    };
    let Some(first_token) = trimmed[..pos].split_whitespace().next() else {
        return false;
    };
    if first_token.eq_ignore_ascii_case("let") {
        return false;
    }
    first_token
        .chars()
        .next()
        .is_some_and(|c| c.is_alphabetic() || c == '_')
}

/// Haskell operator symbol characters (Haskell 2010 report, lexical `symbol` set, ASCII part).
fn is_haskell_symbol(c: char) -> bool {
    "!#$%&*+./<=>?@\\^|-~:".contains(c)
}

/// Skips leading blank lines and `--` line comments; returns the rest unchanged.
fn skip_leading_comments(code: &str) -> &str {
    let mut rest = code;
    while !rest.is_empty() {
        let (line, tail) = rest.split_once('\n').unwrap_or((rest, ""));
        let line = line.trim();
        // Two or more dashes not followed by another symbol start a comment (`-->` does not).
        let dashes = line.len() - line.trim_start_matches('-').len();
        let is_comment = dashes >= 2 && !line[dashes..].starts_with(is_haskell_symbol);
        if !line.is_empty() && !is_comment {
            break;
        }
        rest = tail;
    }
    rest
}

/// Byte offset of the first symbol run equal to `op` (so `=` does not match `==`), ignoring
/// string literals, character literals and `--` line comments.
fn find_standalone_operator(code: &str, op: &str) -> Option<usize> {
    let mut chars = code.char_indices().peekable();
    let mut prev: Option<char> = None;
    while let Some((idx, ch)) = chars.next() {
        match ch {
            '"' => skip_literal(&mut chars, '"'),
            // A quote right after an identifier character is a prime (`x'`), not a literal.
            '\'' if !prev.is_some_and(|p| p.is_alphanumeric() || p == '_' || p == '\'') => {
                skip_literal(&mut chars, '\'')
            }
            c if is_haskell_symbol(c) => {
                let mut end = idx + c.len_utf8();
                while let Some(&(next_idx, next)) = chars.peek() {
                    if !is_haskell_symbol(next) {
                        break;
                    }
                    end = next_idx + next.len_utf8();
                    chars.next();
                }
                let run = &code[idx..end];
                if run == op {
                    return Some(idx);
                }
                if run.len() >= 2 && run.bytes().all(|b| b == b'-') {
                    // A run made only of dashes starts a line comment: skip to end of line.
                    for (_, c) in chars.by_ref() {
                        if c == '\n' {
                            break;
                        }
                    }
                }
            }
            _ => {}
        }
        prev = Some(ch);
    }
    None
}

/// Consumes characters up to and including the closing `quote`, honouring `\` escapes.
fn skip_literal(chars: &mut impl Iterator<Item = (usize, char)>, quote: char) {
    while let Some((_, c)) = chars.next() {
        if c == '\\' {
            chars.next();
        } else if c == quote {
            break;
        }
    }
}

/// True when `lhs` is a valid type-signature left-hand side: comma-separated variable
/// names (`foo`, `_bar'`) or parenthesized operators (`(<+>)`).
fn is_signature_lhs(lhs: &str) -> bool {
    lhs.split(',').all(|name| {
        let name = name.trim();
        match name.strip_prefix('(').and_then(|rest| rest.strip_suffix(')')) {
            Some(operator) => {
                let operator = operator.trim();
                !operator.is_empty() && operator.chars().all(is_haskell_symbol)
            }
            None => {
                let mut chars = name.chars();
                chars.next().is_some_and(|c| c.is_lowercase() || c == '_')
                    && chars.all(|c| c.is_alphanumeric() || c == '_' || c == '\'')
            }
        }
    })
}
480
/// Returns `true` when a snippet should be evaluated as `print ((expr))`.
///
/// Only non-empty single-line snippets qualify. The following are left as statements:
/// - keyword-led lines (`let`, `case`, `if`, `do`, `import`, ...);
/// - bindings (`x = ...`);
/// - anything containing `->` (lambdas, case arms);
/// - monadic binds (`<-`).
///
/// Comparison operators such as `==`, `/=`, `<=` and `>=` are not bindings, so
/// `1 + 1 == 2` is still printed. Keywords are matched case-sensitively, as in Haskell, so
/// a constructor application like `Type 3` is not mistaken for a `type` declaration.
fn should_wrap_expression(code: &str) -> bool {
    /// True when `text` contains an `=` that is not part of a longer operator like `==`.
    fn has_binding_equals(text: &str) -> bool {
        const SYMBOLS: &str = "!#$%&*+./<=>?@\\^|-~:";
        let is_symbol = |c: Option<&char>| c.is_some_and(|&c| SYMBOLS.contains(c));
        let chars: Vec<char> = text.chars().collect();
        chars.iter().enumerate().any(|(i, &c)| {
            c == '='
                && !is_symbol(i.checked_sub(1).and_then(|p| chars.get(p)))
                && !is_symbol(chars.get(i + 1))
        })
    }

    if code.contains('\n') {
        return false;
    }

    let trimmed = code.trim();
    if trimmed.is_empty() {
        return false;
    }

    const STATEMENT_PREFIXES: [&str; 11] = [
        "let ",
        "case ",
        "if ",
        "do ",
        "import ",
        "module ",
        "data ",
        "type ",
        "newtype ",
        "class ",
        "instance ",
    ];
    if STATEMENT_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(prefix))
    {
        return false;
    }

    !(has_binding_equals(trimmed) || trimmed.contains("->") || trimmed.contains("<-"))
}
519
/// Returns an owned copy of `code` that is guaranteed to end with `\n`.
fn ensure_trailing_newline(code: &str) -> String {
    if code.ends_with('\n') {
        code.to_owned()
    } else {
        format!("{code}\n")
    }
}
527
/// Prefixes every line of `code` with the `do`-block statement indent and ensures each line,
/// including the last, ends with `\n`. Existing line endings (including `\r\n`) are kept.
fn indent_block(code: &str) -> String {
    const INDENT: &str = " ";
    let mut indented = String::with_capacity(code.len() + 2);
    for line in code.split_inclusive('\n') {
        indented.push_str(INDENT);
        indented.push_str(line.strip_suffix('\n').unwrap_or(line));
        indented.push('\n');
    }
    indented
}
542
543fn wrap_expression(code: &str) -> String {
544 indent_block(&format!("print (({}))\n", code.trim()))
545}
546
/// Returns the part of `current` that follows `previous`. When `current` does not extend
/// `previous` (e.g. the program's output changed), the whole of `current` is returned.
fn diff_output(previous: &str, current: &str) -> String {
    current.strip_prefix(previous).unwrap_or(current).to_owned()
}
554
/// Decodes process output lossily as UTF-8, converts `\r\n` to `\n`, and drops any
/// remaining lone `\r` characters.
fn normalize_output(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    let unix_newlines = decoded.replace("\r\n", "\n");
    unix_newlines.chars().filter(|&c| c != '\r').collect()
}