use std::collections::HashMap;
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use chrono::Utc;
use tokio::fs::OpenOptions;
pub const DEFAULT_HOOK_ENV_PASSTHROUGH: &[&str] =
&["HOME", "USER", "LANG", "SHELL", "SSH_AUTH_SOCK"];
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tracing::warn;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookKind {
PreSession,
PostSession,
OnFailure,
}
impl HookKind {
pub const fn label(self) -> &'static str {
match self {
HookKind::PreSession => "pre_session",
HookKind::PostSession => "post_session",
HookKind::OnFailure => "on_failure",
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookOutcome {
Success,
Failed {
exit_code: i32,
},
Timeout {
secs: u64,
},
SpawnError(String),
}
impl HookOutcome {
pub fn is_success(&self) -> bool {
matches!(self, HookOutcome::Success)
}
pub fn description(&self) -> String {
match self {
HookOutcome::Success => "ok".to_string(),
HookOutcome::Failed { exit_code } => format!("non-zero exit {exit_code}"),
HookOutcome::Timeout { secs } => format!("timed out after {secs}s"),
HookOutcome::SpawnError(msg) => format!("spawn failed: {msg}"),
}
}
}
pub async fn run_hook(
kind: HookKind,
cmd: &str,
env: &HashMap<String, String>,
timeout: Duration,
transcript_path: &Path,
passthrough_extras: &[String],
) -> HookOutcome {
let label = kind.label();
let log = open_transcript(transcript_path).await;
write_banner_open(&log, label, cmd).await;
let mut command = Command::new("sh");
command
.arg("-c")
.arg(cmd)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.env_clear()
.env(
"PATH",
std::env::var("PATH").unwrap_or_else(|_| "/usr/bin:/bin".to_string()),
);
for key in DEFAULT_HOOK_ENV_PASSTHROUGH {
if let Ok(val) = std::env::var(key) {
command.env(key, val);
}
}
for key in passthrough_extras {
if let Ok(val) = std::env::var(key) {
command.env(key, val);
}
}
for (k, v) in env {
command.env(k, v);
}
let mut child: Child = match command.spawn() {
Ok(c) => c,
Err(e) => {
let msg = format!("{e:#}");
warn!(hook = %label, error = %msg, "grind: hook spawn failed");
write_banner_close(&log, label, &HookOutcome::SpawnError(msg.clone())).await;
return HookOutcome::SpawnError(msg);
}
};
let stdout = child.stdout.take().expect("piped stdout");
let stderr = child.stderr.take().expect("piped stderr");
let stdout_task = tokio::spawn(forward(stdout, label, false, log.clone()));
let stderr_task = tokio::spawn(forward(stderr, label, true, log.clone()));
let outcome = tokio::select! {
_ = tokio::time::sleep(timeout) => {
let _ = child.start_kill();
let _ = child.wait().await;
stdout_task.abort();
stderr_task.abort();
HookOutcome::Timeout { secs: timeout.as_secs() }
}
status = child.wait() => match status {
Ok(s) if s.success() => HookOutcome::Success,
Ok(s) => HookOutcome::Failed { exit_code: s.code().unwrap_or(-1) },
Err(e) => HookOutcome::SpawnError(format!("wait failed: {e:#}")),
}
};
let _ = stdout_task.await;
let _ = stderr_task.await;
match &outcome {
HookOutcome::Timeout { secs } => {
warn!(
hook = %label,
timeout_secs = %secs,
"grind: hook timed out"
);
}
HookOutcome::Failed { exit_code } => {
warn!(
hook = %label,
exit_code = %exit_code,
"grind: hook exited non-zero"
);
}
HookOutcome::SpawnError(msg) => {
warn!(hook = %label, error = %msg, "grind: hook errored");
}
HookOutcome::Success => {}
}
write_banner_close(&log, label, &outcome).await;
outcome
}
type SharedLog = Option<Arc<Mutex<tokio::fs::File>>>;
async fn open_transcript(path: &Path) -> SharedLog {
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
if let Err(e) = tokio::fs::create_dir_all(parent).await {
warn!(
path = %path.display(),
error = %format!("{e:#}"),
"grind: failed to create hook transcript directory"
);
return None;
}
}
}
match OpenOptions::new()
.create(true)
.append(true)
.open(path)
.await
{
Ok(f) => Some(Arc::new(Mutex::new(f))),
Err(e) => {
warn!(
path = %path.display(),
error = %format!("{e:#}"),
"grind: failed to open hook transcript"
);
None
}
}
}
async fn write_banner_open(log: &SharedLog, label: &str, cmd: &str) {
let Some(log) = log.as_ref() else { return };
let ts = Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true);
let banner = format!("\n=== pitboss hook: {label} (cmd: {cmd}) [start {ts}] ===\n");
let mut f = log.lock().await;
let _ = f.write_all(banner.as_bytes()).await;
let _ = f.flush().await;
}
async fn write_banner_close(log: &SharedLog, label: &str, outcome: &HookOutcome) {
let Some(log) = log.as_ref() else { return };
let ts = Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true);
let banner = format!(
"=== pitboss hook: {label} {res} [end {ts}] ===\n",
res = outcome.description()
);
let mut f = log.lock().await;
let _ = f.write_all(banner.as_bytes()).await;
let _ = f.flush().await;
}
async fn forward<R>(reader: R, label: &'static str, is_stderr: bool, log: SharedLog)
where
R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
let mut lines = BufReader::new(reader).lines();
while let Ok(Some(line)) = lines.next_line().await {
let prefix = if is_stderr {
format!("[hook:{label}:stderr] ")
} else {
format!("[hook:{label}] ")
};
if let Some(log) = log.as_ref() {
let mut f = log.lock().await;
let _ = f.write_all(prefix.as_bytes()).await;
let _ = f.write_all(line.as_bytes()).await;
let _ = f.write_all(b"\n").await;
}
}
}
#[cfg(all(test, unix))]
mod tests {
use super::*;
use std::collections::HashMap;
use tempfile::tempdir;
fn env(pairs: &[(&str, &str)]) -> HashMap<String, String> {
pairs
.iter()
.map(|(k, v)| ((*k).to_string(), (*v).to_string()))
.collect()
}
#[tokio::test]
async fn success_records_stdout_to_transcript_with_banner() {
let dir = tempdir().unwrap();
let log = dir.path().join("transcripts").join("session-0001.log");
let outcome = run_hook(
HookKind::PreSession,
"echo hello-from-hook",
&env(&[("PITBOSS_SESSION_PROMPT", "alpha")]),
Duration::from_secs(5),
&log,
&[],
)
.await;
assert_eq!(outcome, HookOutcome::Success);
let body = std::fs::read_to_string(&log).unwrap();
assert!(
body.contains("=== pitboss hook: pre_session"),
"missing open banner: {body}"
);
assert!(
body.contains("[hook:pre_session] hello-from-hook"),
"missing forwarded stdout: {body}"
);
assert!(
body.contains("=== pitboss hook: pre_session ok"),
"missing close banner: {body}"
);
}
#[tokio::test]
async fn non_zero_exit_yields_failed_outcome() {
let dir = tempdir().unwrap();
let log = dir.path().join("transcripts").join("session-0001.log");
let outcome = run_hook(
HookKind::PostSession,
"echo about to fail; exit 9",
&env(&[]),
Duration::from_secs(5),
&log,
&[],
)
.await;
assert_eq!(outcome, HookOutcome::Failed { exit_code: 9 });
let body = std::fs::read_to_string(&log).unwrap();
assert!(body.contains("about to fail"));
assert!(body.contains("non-zero exit 9"));
}
#[tokio::test]
async fn stderr_is_forwarded_with_dedicated_prefix() {
let dir = tempdir().unwrap();
let log = dir.path().join("session.log");
let outcome = run_hook(
HookKind::OnFailure,
"echo oh-no 1>&2",
&env(&[]),
Duration::from_secs(5),
&log,
&[],
)
.await;
assert!(outcome.is_success());
let body = std::fs::read_to_string(&log).unwrap();
assert!(
body.contains("[hook:on_failure:stderr] oh-no"),
"expected stderr prefix in transcript: {body}"
);
}
#[tokio::test]
async fn timeout_kills_long_running_hook() {
let dir = tempdir().unwrap();
let log = dir.path().join("session.log");
let start = std::time::Instant::now();
let outcome = run_hook(
HookKind::PreSession,
"sleep 5",
&env(&[]),
Duration::from_secs(1),
&log,
&[],
)
.await;
let elapsed = start.elapsed();
assert_eq!(outcome, HookOutcome::Timeout { secs: 1 });
assert!(
elapsed < Duration::from_secs(4),
"timeout should kill the child quickly, got elapsed={:?}",
elapsed
);
let body = std::fs::read_to_string(&log).unwrap();
assert!(
body.contains("timed out after 1s"),
"missing timeout banner: {body}"
);
}
#[tokio::test]
async fn env_vars_reach_the_child_process() {
let dir = tempdir().unwrap();
let log = dir.path().join("session.log");
let outcome = run_hook(
HookKind::PostSession,
"printf 'prompt=%s status=%s summary=%s' \"$PITBOSS_SESSION_PROMPT\" \"$PITBOSS_SESSION_STATUS\" \"$PITBOSS_SESSION_SUMMARY\"",
&env(&[
("PITBOSS_SESSION_PROMPT", "fp-hunter"),
("PITBOSS_SESSION_STATUS", "ok"),
("PITBOSS_SESSION_SUMMARY", "did the thing"),
]),
Duration::from_secs(5),
&log,
&[],
)
.await;
assert!(outcome.is_success());
let body = std::fs::read_to_string(&log).unwrap();
assert!(
body.contains("prompt=fp-hunter status=ok summary=did the thing"),
"env not reaching child: {body}"
);
}
#[tokio::test]
async fn passthrough_extras_forward_named_parent_env_to_child() {
let key = "PITBOSS_HOOK_PASSTHROUGH_TEST_OK";
std::env::set_var(key, "from-parent");
let dir = tempdir().unwrap();
let log = dir.path().join("session.log");
let outcome = run_hook(
HookKind::PostSession,
"printf 'val=%s' \"$PITBOSS_HOOK_PASSTHROUGH_TEST_OK\"",
&env(&[]),
Duration::from_secs(5),
&log,
&[key.to_string()],
)
.await;
std::env::remove_var(key);
assert!(outcome.is_success(), "hook outcome: {outcome:?}");
let body = std::fs::read_to_string(&log).unwrap();
assert!(
body.contains("val=from-parent"),
"passthrough did not reach child: {body}"
);
}
#[tokio::test]
async fn passthrough_does_not_forward_unlisted_parent_env() {
let key = "PITBOSS_HOOK_PASSTHROUGH_TEST_LEAK";
std::env::set_var(key, "should-not-leak");
let dir = tempdir().unwrap();
let log = dir.path().join("session.log");
let outcome = run_hook(
HookKind::PostSession,
"printf 'val=[%s]' \"$PITBOSS_HOOK_PASSTHROUGH_TEST_LEAK\"",
&env(&[]),
Duration::from_secs(5),
&log,
&[],
)
.await;
std::env::remove_var(key);
assert!(outcome.is_success(), "hook outcome: {outcome:?}");
let body = std::fs::read_to_string(&log).unwrap();
assert!(
body.contains("val=[]"),
"unlisted parent env leaked: {body}"
);
}
#[tokio::test]
async fn hook_outcome_descriptions_are_human_readable() {
assert_eq!(HookOutcome::Success.description(), "ok");
assert_eq!(
HookOutcome::Failed { exit_code: 2 }.description(),
"non-zero exit 2"
);
assert_eq!(
HookOutcome::Timeout { secs: 30 }.description(),
"timed out after 30s"
);
assert_eq!(
HookOutcome::SpawnError("boom".into()).description(),
"spawn failed: boom"
);
}
}