#![deny(missing_docs)]
#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![doc(html_root_url = "https://docs.rs/ghci/0.2.0")]
use core::time::Duration;
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use nonblock::NonBlockingReader;
use std::io::{ErrorKind, LineWriter, Read, Write};
use std::os::fd::{AsRawFd, BorrowedFd, RawFd};
use std::path::{Path, PathBuf};
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::{Mutex, MutexGuard, OnceLock};
use std::time::Instant;
pub mod haskell;
pub use haskell::{FromHaskell, HaskellParseError, ToHaskell};
/// A live `ghci` child process, driven through its standard streams.
///
/// Construct one with [`Ghci::new`] or [`GhciBuilder::build`]. When dropped,
/// the child process is killed if it is still running.
pub struct Ghci {
    /// The spawned `ghci` process.
    child: Child,
    /// `ghci`'s stdin; `LineWriter` flushes after each newline written.
    stdin: LineWriter<ChildStdin>,
    /// Non-blocking reader over `ghci`'s stdout.
    stdout: NonBlockingReader<ChildStdout>,
    /// Raw descriptor of `stdout`, kept separately so it can be handed to
    /// `poll` (the reader itself owns the handle).
    stdout_fd: RawFd,
    /// Non-blocking reader over `ghci`'s stderr.
    stderr: NonBlockingReader<ChildStderr>,
    /// Raw descriptor of `stderr`, kept for `poll` like `stdout_fd`.
    stderr_fd: RawFd,
    /// Per-evaluation timeout; `None` waits indefinitely.
    timeout: Option<Duration>,
}
/// Everything `ghci` printed while evaluating a single input.
///
/// Returned by [`Ghci::eval_raw`], which does not treat stderr output as a
/// failure.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct EvalOutput {
    /// Text written to stdout, with the session's internal prompt marker
    /// removed from the end.
    pub stdout: String,
    /// Text written to stderr (for example ghci's error messages).
    pub stderr: String,
}
/// Errors returned by this crate.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum GhciError {
    /// The timeout configured with [`Ghci::set_timeout`] elapsed before the
    /// evaluation finished producing output.
    #[error("ghci session timed out waiting on output")]
    Timeout,
    /// An I/O operation on the `ghci` process or its pipes failed.
    #[error("IO error: {0}")]
    IOError(#[from] std::io::Error),
    /// The `poll` call used to wait for `ghci` output failed.
    #[error("Poll error: {0}")]
    PollError(#[from] nix::errno::Errno),
    /// [`Ghci::eval`] saw output on stderr.
    #[error("ghci eval error:\n{stderr}")]
    EvalError {
        /// Everything the evaluation wrote to stdout.
        stdout: String,
        /// Everything the evaluation wrote to stderr.
        stderr: String,
    },
    /// The evaluation's output could not be parsed into the requested type.
    #[error("Haskell parse error: {0}")]
    HaskellParse(#[from] haskell::HaskellParseError),
    /// The input was rejected before being sent to `ghci`; the payload
    /// explains why.
    #[error("disallowed input: {0}")]
    DisallowedInput(&'static str),
}
/// Shorthand for `std::result::Result<T, GhciError>`.
pub type Result<T> = std::result::Result<T, GhciError>;
/// Prompt installed in ghci at startup; seeing it at the end of stdout marks
/// the end of an evaluation's output. When it is installed, the trailing
/// newline is sent as the Haskell string escape `\n`.
const PROMPT: &str = "__ghci_rust_prompt__>\n";
/// Configures and launches a [`Ghci`] session.
#[derive(Debug, Clone)]
pub struct GhciBuilder {
    /// Explicit `ghci` executable; `None` defers to `GHCI_PATH`, then `ghci`.
    ghci_path: Option<String>,
    /// Extra arguments passed after `-ignore-dot-ghci`.
    args: Vec<String>,
    /// Working directory for the child process; `None` inherits ours.
    working_dir: Option<PathBuf>,
}
impl Default for GhciBuilder {
    /// Same as [`GhciBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl GhciBuilder {
    /// Creates a builder with no explicit executable, no extra arguments and
    /// no working-directory override.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            ghci_path: None,
            args: Vec::new(),
            working_dir: None,
        }
    }

    /// Sets the `ghci` executable to launch.
    ///
    /// If this is never called, [`GhciBuilder::build`] uses the `GHCI_PATH`
    /// environment variable when it is set, and `ghci` otherwise.
    #[must_use]
    pub fn ghci_path(mut self, path: impl Into<String>) -> Self {
        self.ghci_path = Some(path.into());
        self
    }

    /// Appends one command-line argument for `ghci`.
    #[must_use]
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Appends several command-line arguments for `ghci`, preserving order.
    #[must_use]
    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.args.extend(args.into_iter().map(Into::into));
        self
    }

    /// Sets the working directory `ghci` is started in.
    ///
    /// If this is never called, the child inherits this process's current
    /// directory.
    #[must_use]
    pub fn working_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.working_dir = Some(path.into());
        self
    }

    /// Spawns `ghci` and prepares the session for [`Ghci::eval`].
    ///
    /// `ghci` is started with `-ignore-dot-ghci` followed by the builder's
    /// arguments. Once its first prompt (ending in `"> "`) appears, the
    /// prompt is replaced with an internal marker and the continuation prompt
    /// is cleared, so the end of each evaluation can be detected. Anything
    /// `ghci` wrote to stderr during startup is discarded, so it is not
    /// attributed to the first evaluation.
    ///
    /// # Errors
    ///
    /// Returns [`GhciError::IOError`] if spawning the process, writing to its
    /// stdin, reading its output, or switching its pipes to non-blocking mode
    /// fails.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: the only `expect`s guard pipe handles that
    /// were requested with `Stdio::piped()` just before spawning.
    pub fn build(self) -> Result<Ghci> {
        const PIPE_ERR: &str = "pipe should be present";
        let ghci_path = self
            .ghci_path
            .or_else(|| std::env::var("GHCI_PATH").ok())
            .unwrap_or_else(|| "ghci".to_string());

        let mut cmd = Command::new(ghci_path);
        cmd.stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .arg("-ignore-dot-ghci")
            // Adding an empty argument list is a no-op, so no emptiness check.
            .args(&self.args);
        if let Some(dir) = self.working_dir {
            cmd.current_dir(dir);
        }

        let mut child = cmd.spawn()?;
        let mut stdin = LineWriter::new(child.stdin.take().expect(PIPE_ERR));
        let mut stdout = child.stdout.take().expect(PIPE_ERR);
        let stderr = child.stderr.take().expect(PIPE_ERR);

        // Wait for ghci's default prompt.
        clear_blocking_reader_until(&mut stdout, b"> ")?;
        // Install PROMPT: its trailing newline is omitted from the raw bytes
        // and spelled as the Haskell escape `\n` inside the string literal.
        stdin.write_all(b":set prompt \"")?;
        stdin.write_all(&PROMPT.as_bytes()[..PROMPT.len() - 1])?;
        stdin.write_all(b"\\n\"\n")?;
        clear_blocking_reader_until(&mut stdout, PROMPT.as_bytes())?;
        // Clear the continuation prompt used for lines of multi-line input.
        stdin.write_all(b":set prompt-cont \"\"\n")?;
        clear_blocking_reader_until(&mut stdout, PROMPT.as_bytes())?;

        // Capture the descriptors before the readers take ownership of the
        // handles; `Ghci::eval_raw` polls them directly.
        let stdout_fd = stdout.as_raw_fd();
        let stderr_fd = stderr.as_raw_fd();
        let stdout = NonBlockingReader::from_fd(stdout)?;
        let mut stderr = NonBlockingReader::from_fd(stderr)?;

        // ghci has already printed its final startup prompt, so any startup
        // diagnostics it sent to stderr are already sitting in the pipe. Drop
        // them now; otherwise the first `eval` would return them as an
        // `EvalError`.
        stderr.read_available(&mut Vec::new())?;

        Ok(Ghci {
            child,
            stdin,
            stdout,
            stdout_fd,
            stderr,
            stderr_fd,
            timeout: None,
        })
    }
}
impl Ghci {
pub fn new() -> Result<Self> {
GhciBuilder::new().build()
}
pub fn eval(&mut self, input: &str) -> Result<String> {
let out = self.eval_raw(input)?;
if out.stderr.is_empty() {
Ok(out.stdout)
} else {
Err(GhciError::EvalError {
stdout: out.stdout,
stderr: out.stderr,
})
}
}
pub fn eval_as<T: FromHaskell>(&mut self, input: &str) -> Result<T> {
let output = self.eval(input)?;
Ok(T::from_haskell(output.trim_end_matches('\n'))?)
}
pub fn eval_raw(&mut self, input: &str) -> Result<EvalOutput> {
if input.trim_start().starts_with(":set prompt") {
return Err(GhciError::DisallowedInput(
":set prompt and :set prompt-cont are managed by ghci-rs and cannot be changed",
));
}
self.stdin.write_all(b":{\n")?;
self.stdin.write_all(input.as_bytes())?;
self.stdin.write_all(b"\n:}\n")?;
let mut stdout = String::new();
let mut stderr = String::new();
let deadline = self.timeout.map(|d| Instant::now() + d);
loop {
let stderr_fd = unsafe { BorrowedFd::borrow_raw(self.stderr_fd) };
let stdout_fd = unsafe { BorrowedFd::borrow_raw(self.stdout_fd) };
let mut poll_fds = [
PollFd::new(stderr_fd, PollFlags::POLLIN),
PollFd::new(stdout_fd, PollFlags::POLLIN),
];
let poll_timeout = match deadline {
None => PollTimeout::NONE,
Some(dl) => {
let remaining = dl.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(GhciError::Timeout);
}
remaining
.as_millis()
.try_into()
.ok()
.and_then(|ms: i32| PollTimeout::try_from(ms).ok())
.unwrap_or(PollTimeout::NONE)
}
};
let ret = poll(&mut poll_fds, poll_timeout)?;
if ret == 0 {
return Err(GhciError::Timeout);
}
if poll_fds[0].any() == Some(true) {
self.stderr.read_available_to_string(&mut stderr)?;
}
if poll_fds[1].any() == Some(true) {
self.stdout.read_available_to_string(&mut stdout)?;
if stdout.ends_with(PROMPT) {
stdout.truncate(stdout.len() - PROMPT.len());
break;
}
}
}
Ok(EvalOutput { stdout, stderr })
}
#[inline]
pub const fn set_timeout(&mut self, timeout: Option<Duration>) {
self.timeout = timeout;
}
#[inline]
pub fn import(&mut self, modules: &[&str]) -> Result<()> {
let mut line = String::from(":module ");
line.push_str(&modules.join(" "));
self.eval(&line)?;
Ok(())
}
#[inline]
pub fn load(&mut self, paths: &[&Path]) -> Result<()> {
let mut line = String::from(":load");
for path in paths {
use std::fmt::Write as _;
let _ = write!(line, " {}", path.display());
}
self.eval(&line)?;
Ok(())
}
#[inline]
pub fn close(mut self) -> Result<()> {
Ok(self.child.kill()?)
}
}
impl Drop for Ghci {
    /// Kills the `ghci` process if it is still running, then reaps it.
    ///
    /// All failures are ignored on purpose. Panicking in `drop` would abort the
    /// program if it happened while already unwinding from another panic.
    fn drop(&mut self) {
        if let Ok(None) = self.child.try_wait() {
            let _ = self.child.kill();
            // Wait so the killed process is reaped instead of left a zombie.
            let _ = self.child.wait();
        }
    }
}
/// A `ghci` session that is started lazily on first use and then handed out
/// behind a mutex.
///
/// [`SharedGhci::new`] is `const`, so a `SharedGhci` can be stored in a
/// `static`.
pub struct SharedGhci {
    /// Constructor invoked by `lock` the first time the session is needed.
    init: fn() -> Result<Ghci>,
    /// The session, populated on first use.
    inner: OnceLock<Mutex<Ghci>>,
}
impl SharedGhci {
    /// Creates a handle that will start its session by calling `init` the
    /// first time [`SharedGhci::lock`] is called.
    #[must_use]
    pub const fn new(init: fn() -> Result<Ghci>) -> Self {
        Self {
            init,
            inner: OnceLock::new(),
        }
    }

    /// Locks the shared session, starting it first if it has not been started
    /// yet.
    ///
    /// # Panics
    ///
    /// Panics if the initialiser returns an error, or if the mutex is poisoned
    /// because a previous holder of the guard panicked.
    pub fn lock(&self) -> MutexGuard<'_, Ghci> {
        let start = || Mutex::new((self.init)().expect("SharedGhci initialization failed"));
        let session = self.inner.get_or_init(start);
        session.lock().expect("SharedGhci mutex poisoned")
    }
}
/// Reads and discards output from a blocking reader until everything read so
/// far ends with `expected_end`.
///
/// The match is checked after each successful `read`, so `expected_end` must
/// be the last thing the peer writes before it waits (such as an interactive
/// prompt).
///
/// Only the trailing `expected_end.len()` bytes are retained between reads.
/// Output of any length, such as a long startup banner, is therefore consumed
/// completely.
///
/// # Errors
///
/// Returns an [`ErrorKind::UnexpectedEof`] error if the stream ends before
/// `expected_end` is seen. Propagates any other I/O error. Reads interrupted by
/// a signal are retried.
fn clear_blocking_reader_until(mut r: impl Read, expected_end: &[u8]) -> std::io::Result<()> {
    let mut chunk = [0; 1024];
    // Rolling window over the most recent bytes of the stream.
    let mut tail: Vec<u8> = Vec::with_capacity(expected_end.len() + chunk.len());
    loop {
        match r.read(&mut chunk) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    ErrorKind::UnexpectedEof,
                    "stream ended before the expected marker was read",
                ));
            }
            Ok(n) => {
                tail.extend_from_slice(&chunk[..n]);
                if tail.ends_with(expected_end) {
                    return Ok(());
                }
                // Keep only what could still be the start of the marker.
                let excess = tail.len().saturating_sub(expected_end.len());
                tail.drain(..excess);
            }
            Err(err) if err.kind() == ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Integration tests that drive a real `ghci`, located through `GHCI_PATH`
    //! or as `ghci`.
    use super::*;

    #[test]
    fn parse_error() {
        let mut ghci = Ghci::new().unwrap();
        let res = ghci.eval("x ::");
        match res {
            Err(GhciError::EvalError { stderr, .. }) => {
                assert!(stderr.contains("parse error"));
            }
            other => panic!("expected EvalError, got {other:?}"),
        }
    }

    #[test]
    fn parse_error_raw() {
        let mut ghci = Ghci::new().unwrap();
        let res = ghci.eval_raw("x ::").unwrap();
        assert!(res.stderr.contains("parse error"));
    }

    #[test]
    fn eval_as_integer() -> Result<()> {
        let mut ghci = Ghci::new()?;
        let x: i32 = ghci.eval_as("1 + 1")?;
        assert_eq!(x, 2);
        Ok(())
    }

    #[test]
    fn eval_as_boolean() -> Result<()> {
        let mut ghci = Ghci::new()?;
        let b: bool = ghci.eval_as("True")?;
        assert!(b);
        Ok(())
    }

    #[test]
    fn eval_as_string() -> Result<()> {
        let mut ghci = Ghci::new()?;
        let s: String = ghci.eval_as(r#""hello" ++ " world""#)?;
        assert_eq!(s, "hello world");
        Ok(())
    }

    #[test]
    fn eval_as_option() -> Result<()> {
        let mut ghci = Ghci::new()?;
        let opt: Option<i32> = ghci.eval_as("(Just 42)")?;
        assert_eq!(opt, Some(42));
        Ok(())
    }

    #[test]
    fn eval_as_vec() -> Result<()> {
        let mut ghci = Ghci::new()?;
        let vec: Vec<i32> = ghci.eval_as("[1, 2, 3]")?;
        assert_eq!(vec, vec![1, 2, 3]);
        Ok(())
    }

    #[test]
    fn disallow_set_prompt() {
        let mut ghci = Ghci::new().unwrap();
        let res = ghci.eval(":set prompt \"foo> \"");
        assert!(
            matches!(res, Err(GhciError::DisallowedInput(_))),
            "expected DisallowedInput, got {res:?}"
        );
    }

    #[test]
    fn timeout_on_infinite_output() {
        let mut ghci = Ghci::new().unwrap();
        ghci.set_timeout(Some(Duration::from_millis(200)));
        let res = ghci.eval("mapM_ print [1..]");
        assert!(
            matches!(res, Err(GhciError::Timeout)),
            "expected Timeout, got {res:?}"
        );
    }
}