use std::path::Path;
use std::time::{Duration, Instant};
use serde::Serialize;
use thiserror::Error;
use tokio::process::Command;
use tracing::Instrument;
use crate::context::{IssueRun, IssueStage};
use crate::logging::Phase;
use crate::shell::{CommandExecError, CommandExt};
use crate::template::{JinjaRenderer, TemplateError};
/// Default per-hook process timeout installed by `HookRunner::new` and handed to
/// `CommandExt::timeout` for every hook command.
const HOOK_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Maximum number of trailing stderr bytes captured in `HookError::NonZeroExit`.
const STDERR_TAIL_BYTES: usize = 2 * 1024;
/// The lifecycle points at which a configured shell hook can be run by `HookRunner`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookKind {
    /// Selected by `HookRunner::after_issue_workdir_created`.
    AfterIssueWorkdirCreate,
    /// Selected by `HookRunner::before_issue_stage_run`.
    BeforeIssueStageRun,
    /// Selected by `HookRunner::after_issue_stage_run`.
    AfterIssueStageRun,
}

impl HookKind {
    /// Returns the snake_case name of this hook, as recorded in the `hook` tracing span
    /// field and in every `HookError` variant.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::AfterIssueWorkdirCreate => "after_issue_workdir_create",
            Self::BeforeIssueStageRun => "before_issue_stage_run",
            Self::AfterIssueStageRun => "after_issue_stage_run",
        }
    }
}
/// Failure modes of a single hook invocation; each variant names the hook via
/// `HookKind::as_str`.
#[derive(Debug, Error)]
pub enum HookError {
    /// The hook's command template could not be rendered.
    #[error("hook `{hook}` template render failed: {source}")]
    Render {
        hook: &'static str,
        // A field named `source` is picked up by thiserror as `Error::source()`.
        source: TemplateError,
    },
    /// The rendered command could not be executed by the shell helper.
    #[error("hook `{hook}` shell execution failed: {source}")]
    Exec {
        hook: &'static str,
        // A field named `source` is picked up by thiserror as `Error::source()`.
        source: CommandExecError,
    },
    /// The command ran but did not report success.
    #[error("hook `{hook}` exited with non-zero status {code}: {stderr_tail}")]
    NonZeroExit {
        hook: &'static str,
        /// Process exit code, or `-1` when the exit status carries no code.
        code: i32,
        /// Lossily decoded tail of the process's stderr (bounded by `STDERR_TAIL_BYTES`).
        stderr_tail: String,
    },
}
/// Renders hook command templates and runs the result through the platform shell.
#[derive(Clone, Debug)]
pub struct HookRunner {
    /// Expands each hook's command template against the serialized issue/stage context.
    renderer: JinjaRenderer,
    /// Timeout applied to every hook process (`HOOK_TIMEOUT` unless overridden).
    timeout: Duration,
}
impl Default for HookRunner {
fn default() -> Self {
Self::new()
}
}
impl HookRunner {
pub fn new() -> Self {
Self {
renderer: JinjaRenderer::new(),
timeout: HOOK_TIMEOUT,
}
}
#[allow(dead_code)]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
#[inline]
pub async fn after_issue_workdir_created(&self, issue: &IssueRun, hook: &Option<String>) -> Result<(), HookError> {
self
.schedule_inner(HookKind::AfterIssueWorkdirCreate, issue.workdir(), hook, issue)
.await
}
#[inline]
pub async fn before_issue_stage_run(&self, stage: &IssueStage, hook: &Option<String>) -> Result<(), HookError> {
self
.schedule_inner(HookKind::BeforeIssueStageRun, stage.workdir(), hook, stage)
.await
}
#[inline]
pub async fn after_issue_stage_run(&self, stage: &IssueStage, hook: &Option<String>) -> Result<(), HookError> {
self
.schedule_inner(HookKind::AfterIssueStageRun, stage.workdir(), hook, stage)
.await
}
async fn schedule_inner<Context: Serialize>(
&self,
kind: HookKind,
cwd: &Path,
hook: &Option<String>,
context: Context,
) -> Result<(), HookError> {
let hook_name = kind.as_str();
let _span = tracing::info_span!(
"hook",
phase = %Phase::Hook,
hook = %hook_name,
);
let command = match hook {
Some(body) => body,
None => {
tracing::debug!("hook not configured; skipping execution");
return Ok(());
},
};
let command = self.render_hook_command(kind, command, context)?;
self.run_command(kind, cwd, command).in_current_span().await
}
fn render_hook_command<Context: Serialize>(
&self,
kind: HookKind,
command: &str,
context: Context,
) -> Result<String, HookError> {
self.renderer.render(command, context).map_err(|e| HookError::Render {
hook: kind.as_str(),
source: e,
})
}
async fn run_command(&self, kind: HookKind, cwd: &Path, command: String) -> Result<(), HookError> {
let started = Instant::now();
tracing::debug!(cwd = %cwd.display(), "hook shell starting");
let output = match shell_command(&command).current_dir(cwd).timeout(self.timeout).output().await {
Ok(output) => output,
Err(source) => {
let duration = started.elapsed().as_millis();
tracing::error!(duration, error = %source, "hook shell exec errored");
return Err(HookError::Exec {
hook: kind.as_str(),
source,
});
},
};
let duration = started.elapsed().as_millis();
if output.status.success() {
tracing::info!(duration, "hook completed");
return Ok(());
}
let code = output.status.code().unwrap_or(-1);
let stderr_tail = tail_utf8(&output.stderr, STDERR_TAIL_BYTES);
let error = HookError::NonZeroExit {
hook: kind.as_str(),
code,
stderr_tail,
};
tracing::error!(duration, error = %error, "hook exited non-zero");
Err(error)
}
}
/// Returns the last (at most) `limit` bytes of `bytes`, decoded lossily as UTF-8.
///
/// If the cut point lands inside a multi-byte UTF-8 sequence, the orphaned continuation
/// bytes at the start of the tail are dropped. The result then begins on a character
/// boundary instead of with a U+FFFD replacement character, so it may be up to three
/// bytes shorter than `limit`. Invalid UTF-8 elsewhere is still replaced lossily.
fn tail_utf8(bytes: &[u8], limit: usize) -> String {
    let mut start = bytes.len().saturating_sub(limit);
    if start > 0 {
        // UTF-8 continuation bytes match 0b10xx_xxxx; a sequence has at most three, so
        // never skip more than that (keeps non-UTF-8 output from being trimmed further).
        start += bytes[start..]
            .iter()
            .take(3)
            .take_while(|&&b| (b & 0xC0) == 0x80)
            .count();
    }
    String::from_utf8_lossy(&bytes[start..]).into_owned()
}
/// Builds a `cmd /C <body>` invocation for the hook script on Windows.
///
/// `cmd.exe` does not undo `CommandLineToArgvW`-style quoting. If `body` went through
/// `Command::arg(s)`, embedded `"` would be escaped as `\"` and reach the shell mangled,
/// so the body is appended verbatim with `raw_arg` instead.
#[cfg(windows)]
fn shell_command(body: &str) -> Command {
    use std::os::windows::process::CommandExt as _;

    let mut std_cmd = std::process::Command::new("cmd");
    std_cmd.arg("/C").raw_arg(body);
    Command::from(std_cmd)
}
/// Builds an `sh -c <body>` invocation so the rendered hook runs as a shell script.
#[cfg(not(windows))]
fn shell_command(body: &str) -> Command {
    let mut shell = Command::new("sh");
    shell.arg("-c").arg(body);
    shell
}