use std::collections::VecDeque;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
/// Default wall-clock budget, in seconds, for the post-merge `cargo test` run (ten minutes).
pub const DEFAULT_MAX_TEST_DURATION_SECS: u64 = 10 * 60;
/// Number of trailing output lines the real process runner keeps per stream.
pub const TAIL_LINE_CAP: usize = 500;
/// What a finished command produced: its exit status plus the tail of each stream.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CommandOutput {
    /// Process exit code; `None` when the process reported none (e.g. it was
    /// terminated by a signal on Unix).
    pub exit_code: Option<i32>,
    /// Whether the process exited successfully.
    pub success: bool,
    /// Last lines of stdout, joined with `\n`.
    pub stdout_tail: String,
    /// Last lines of stderr, joined with `\n`.
    pub stderr_tail: String,
}
#[derive(Debug, thiserror::Error)]
pub enum CommandError {
#[error("command timed out after {0:?}")]
Timeout(Duration),
#[error("command io error: {0}")]
Io(String),
}
/// Runs an external command and captures the tail of its output.
///
/// Implemented by [`TokioCommandRunner`] (real processes) and [`ScriptedRunner`]
/// (queued canned responses, for tests); the gate functions are generic over it.
#[async_trait]
pub trait CommandRunner: Send + Sync {
/// Runs `cmd` with `args` in working directory `cwd`.
///
/// `timeout` is the time budget for the command: [`TokioCommandRunner`] kills the
/// child and returns [`CommandError::Timeout`] when it elapses, while
/// [`ScriptedRunner`] ignores it. In [`TokioCommandRunner`] a command that exits
/// unsuccessfully is reported as `Ok` with `success == false`, not as `Err`.
async fn run(
&self,
cmd: &str,
args: &[&str],
cwd: &Path,
timeout: Duration,
) -> Result<CommandOutput, CommandError>;
}
/// Result of one `cargo test` run, as produced by [`run_workspace_tests`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TestRunOutcome {
    /// `true` only if cargo exited successfully.
    pub passed: bool,
    /// Cargo's exit code; `None` on timeout or when no code was reported.
    pub exit_code: Option<i32>,
    /// Tail of cargo's stdout (empty on timeout).
    pub stdout_tail: String,
    /// Tail of cargo's stderr, or a timeout explanation.
    pub stderr_tail: String,
    /// Time spent in the runner call.
    pub wall_time: Duration,
    /// Whether the run hit its time budget and was killed.
    pub timed_out: bool,
}
/// Settings for a post-merge gate run.
#[derive(Clone, Debug)]
pub struct GateConfig {
    /// Working directory for every command the gate runs.
    pub repo_root: PathBuf,
    /// Commit that [`revert_merge`] reverts.
    pub merge_sha: String,
    /// Time budget for the `cargo test` run.
    pub max_test_duration: Duration,
    /// Remote the revert is pushed to; `None` skips the push.
    pub revert_push_remote: Option<String>,
    /// Branch the revert is pushed to, as `HEAD:<branch>`.
    pub revert_push_branch: String,
}

impl GateConfig {
    /// Creates a config for `merge_sha` in `repo_root` with defaults:
    /// a [`DEFAULT_MAX_TEST_DURATION_SECS`]-second test budget and pushing
    /// reverts to `origin`'s `main` branch.
    pub fn new(repo_root: PathBuf, merge_sha: String) -> Self {
        let max_test_duration = Duration::from_secs(DEFAULT_MAX_TEST_DURATION_SECS);
        let revert_push_remote = Some(String::from("origin"));
        let revert_push_branch = String::from("main");
        Self {
            repo_root,
            merge_sha,
            max_test_duration,
            revert_push_remote,
            revert_push_branch,
        }
    }
}
/// Result of a successful [`revert_merge`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RevertOutcome {
    /// `HEAD` after the revert commit was created (trimmed `git rev-parse HEAD`).
    pub revert_sha: String,
    /// Whether the revert was pushed successfully (`false` if no remote is configured).
    pub pushed: bool,
}
/// Coarse category of a test run, as decided by [`classify_failure`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FailureKind {
    /// At least one failing test name was found in the output.
    TestFailure,
    /// No failing test names, but the output contains compiler/cargo error markers.
    HarnessError,
    /// The run exceeded its time budget.
    Timeout,
    /// Nothing recognizable; also used for a run that passed.
    Unknown,
}

/// Output of [`classify_failure`]: the category plus any failing test names.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FailureClassification {
    /// Category of the run.
    pub kind: FailureKind,
    /// Failing test names; `classify_failure` leaves this empty unless
    /// `kind == FailureKind::TestFailure`.
    pub failing_tests: Vec<String>,
}
#[derive(Debug, thiserror::Error)]
pub enum GateError {
#[error("command error: {0}")]
Command(#[from] CommandError),
#[error("revert failed: {0}")]
Revert(String),
}
/// Runs `cargo test --workspace --no-fail-fast` in `cfg.repo_root` with a budget of
/// `cfg.max_test_duration`.
///
/// A timeout is folded into the returned outcome rather than surfaced as an error:
/// `timed_out == true`, `passed == false`, no exit code, empty stdout, and an
/// explanatory `stderr_tail`. `wall_time` measures the whole runner call.
///
/// # Errors
///
/// [`GateError::Command`] for any runner failure other than a timeout.
pub async fn run_workspace_tests<R: CommandRunner + ?Sized>(
    runner: &R,
    cfg: &GateConfig,
) -> Result<TestRunOutcome, GateError> {
    const CARGO_ARGS: [&str; 3] = ["test", "--workspace", "--no-fail-fast"];

    let clock = std::time::Instant::now();
    let result = runner
        .run("cargo", &CARGO_ARGS, &cfg.repo_root, cfg.max_test_duration)
        .await;
    let wall_time = clock.elapsed();

    let outcome = match result {
        Ok(output) => TestRunOutcome {
            passed: output.success,
            exit_code: output.exit_code,
            stdout_tail: output.stdout_tail,
            stderr_tail: output.stderr_tail,
            wall_time,
            timed_out: false,
        },
        // Timeouts become a failed outcome instead of an error.
        Err(CommandError::Timeout(_)) => TestRunOutcome {
            passed: false,
            exit_code: None,
            stdout_tail: String::new(),
            stderr_tail: format!(
                "post_merge_gate: cargo test exceeded {:?}; child killed",
                cfg.max_test_duration
            ),
            wall_time,
            timed_out: true,
        },
        Err(other) => return Err(GateError::Command(other)),
    };
    Ok(outcome)
}
/// Sorts a test run into a [`FailureKind`] and extracts failing test names.
///
/// Precedence:
/// 1. a timed-out run is [`FailureKind::Timeout`];
/// 2. a passed run is [`FailureKind::Unknown`];
/// 3. if failing test names are found in stdout+stderr, [`FailureKind::TestFailure`];
/// 4. otherwise, if a compiler/cargo error marker appears, [`FailureKind::HarnessError`];
/// 5. otherwise [`FailureKind::Unknown`].
///
/// `failing_tests` is empty in every case but `TestFailure`.
pub fn classify_failure(outcome: &TestRunOutcome) -> FailureClassification {
    // Substrings that indicate cargo/rustc failed before tests could report results.
    const HARNESS_MARKERS: [&str; 5] = [
        "error: could not compile",
        "error: no such subcommand",
        "error: failed to compile",
        "error: Command",
        "error[E",
    ];

    if outcome.timed_out || outcome.passed {
        let kind = if outcome.timed_out {
            FailureKind::Timeout
        } else {
            FailureKind::Unknown
        };
        return FailureClassification {
            kind,
            failing_tests: Vec::new(),
        };
    }

    let combined = format!("{}\n{}", outcome.stdout_tail, outcome.stderr_tail);
    let failing_tests = parse_failing_tests(&combined);
    let harness_marker_seen = HARNESS_MARKERS
        .iter()
        .any(|marker| combined.contains(marker));
    let kind = match (failing_tests.is_empty(), harness_marker_seen) {
        (false, _) => FailureKind::TestFailure,
        (true, true) => FailureKind::HarnessError,
        (true, false) => FailureKind::Unknown,
    };
    FailureClassification {
        kind,
        failing_tests,
    }
}
/// Extracts failing test names from libtest (`cargo test`) output.
///
/// For each test binary with failures, libtest prints two `failures:` blocks:
///
/// 1. The captured output of each failed test, each section introduced by a
///    `---- <name> stdout ----` header.
/// 2. The plain list of failing names, closed by the `test result:` line.
///
/// ```text
/// failures:
///
/// ---- m::a stdout ----
/// thread 'm::a' panicked at src/lib.rs:3:5:
///
/// failures:
///     m::a
///
/// test result: FAILED. 1 passed; 1 failed; ...
/// ```
///
/// Only name-list blocks contribute. A block containing a `----` header is
/// captured output and is ignored, so panic messages and test stdout are never
/// reported as names.
///
/// Names are trimmed and de-duplicated in first-seen order; a workspace run
/// concatenates several binaries' output. A name list cut off before its
/// `test result:` line (e.g. a truncated tail) is still returned.
fn parse_failing_tests(output: &str) -> Vec<String> {
    /// Appends each name not already present, preserving order.
    fn push_unique(failing: &mut Vec<String>, names: &[&str]) {
        for name in names {
            if !failing.iter().any(|seen| seen.as_str() == *name) {
                failing.push((*name).to_string());
            }
        }
    }

    let mut failing: Vec<String> = Vec::new();
    // Non-empty lines of the `failures:` block currently being read.
    let mut pending: Vec<&str> = Vec::new();
    let mut in_block = false;
    // Set once the current block shows a `---- … ----` section header, i.e. it holds
    // captured test output rather than test names.
    let mut captured_output = false;

    for line in output.lines() {
        let trimmed = line.trim();
        if trimmed == "failures:" {
            // libtest always closes a name list with `test result:`, so a block that
            // is still open here was not a name list; drop whatever it collected.
            pending.clear();
            in_block = true;
            captured_output = false;
            continue;
        }
        if !in_block {
            continue;
        }
        if trimmed.starts_with("test result:") {
            if !captured_output {
                push_unique(&mut failing, &pending);
            }
            pending.clear();
            in_block = false;
        } else if trimmed.starts_with("----") {
            captured_output = true;
        } else if !trimmed.is_empty() && !captured_output {
            pending.push(trimmed);
        }
    }

    // Output ended mid-list (e.g. truncated): keep what was read if it was a name list.
    if in_block && !captured_output {
        push_unique(&mut failing, &pending);
    }
    failing
}
/// Reverts `cfg.merge_sha` in `cfg.repo_root` and, if `cfg.revert_push_remote` is
/// set, pushes the result as `HEAD:<cfg.revert_push_branch>`.
///
/// The commit is first reverted as a merge against its first parent
/// (`git revert -m 1`). If that fails, a plain `git revert` is tried, which covers
/// non-merge commits. Every git command gets a 60-second timeout. A failed push
/// is not an error; it is reported as `pushed == false`.
///
/// # Errors
///
/// * [`GateError::Command`] if a revert, `rev-parse`, or push invocation cannot run
///   or times out.
/// * [`GateError::Revert`] if both revert attempts fail, or `git rev-parse HEAD`
///   fails. The message includes the exit code and stderr of both attempts.
///   Before returning it, a best-effort `git revert --abort` is run so a revert
///   stopped on conflicts is not left in the working tree.
pub async fn revert_merge<R: CommandRunner + ?Sized>(
    runner: &R,
    cfg: &GateConfig,
) -> Result<RevertOutcome, GateError> {
    let short_timeout = Duration::from_secs(60);
    let sha = cfg.merge_sha.as_str();

    let revert_out = runner
        .run(
            "git",
            &["revert", "--no-edit", "-m", "1", sha],
            &cfg.repo_root,
            short_timeout,
        )
        .await?;
    if !revert_out.success {
        // `-m 1` is rejected for non-merge commits, so retry as an ordinary revert.
        let plain = runner
            .run(
                "git",
                &["revert", "--no-edit", sha],
                &cfg.repo_root,
                short_timeout,
            )
            .await?;
        if !plain.success {
            // A revert that stopped on conflicts leaves the repository mid-revert
            // (conflict markers, REVERT_HEAD), breaking later operations on it.
            // Best-effort: when no revert is in progress the abort itself fails, and
            // the revert failure below is what the caller needs to see.
            let _ = runner
                .run("git", &["revert", "--abort"], &cfg.repo_root, short_timeout)
                .await;
            return Err(GateError::Revert(format!(
                "git revert exited with code {:?}: {} (first attempt with -m 1 exited with code {:?}: {})",
                plain.exit_code.or(revert_out.exit_code),
                plain.stderr_tail,
                revert_out.exit_code,
                revert_out.stderr_tail
            )));
        }
    }

    let rev_parse = runner
        .run("git", &["rev-parse", "HEAD"], &cfg.repo_root, short_timeout)
        .await?;
    if !rev_parse.success {
        return Err(GateError::Revert(format!(
            "git rev-parse HEAD failed: {}",
            rev_parse.stderr_tail
        )));
    }
    let revert_sha = rev_parse.stdout_tail.trim().to_string();

    let pushed = match &cfg.revert_push_remote {
        Some(remote) => {
            let refspec = format!("HEAD:{}", cfg.revert_push_branch);
            let push_out = runner
                .run(
                    "git",
                    &["push", remote.as_str(), refspec.as_str()],
                    &cfg.repo_root,
                    short_timeout,
                )
                .await?;
            // Push diagnostics are not surfaced; callers only see the boolean.
            push_out.success
        }
        None => false,
    };
    Ok(RevertOutcome { revert_sha, pushed })
}
/// [`CommandRunner`] that spawns real child processes with `tokio::process`.
pub struct TokioCommandRunner;

#[async_trait]
impl CommandRunner for TokioCommandRunner {
    /// Spawns `cmd args…` in `cwd` with stdin closed and stdout/stderr piped, and
    /// waits up to `timeout` for it to exit.
    ///
    /// Both streams are drained concurrently by [`tail_stream`], which keeps the last
    /// [`TAIL_LINE_CAP`] lines of each. Any exit status is reported as `Ok`; inspect
    /// [`CommandOutput::success`].
    ///
    /// # Errors
    ///
    /// * [`CommandError::Io`] if the process cannot be spawned, a pipe is missing, or
    ///   waiting on the child fails.
    /// * [`CommandError::Timeout`] if the child has not exited within `timeout`. The
    ///   child is killed and its captured output is discarded.
    async fn run(
        &self,
        cmd: &str,
        args: &[&str],
        cwd: &Path,
        timeout: Duration,
    ) -> Result<CommandOutput, CommandError> {
        let mut child = tokio::process::Command::new(cmd)
            .args(args)
            .current_dir(cwd)
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            // If the caller drops this future before the child is reaped (e.g. an outer
            // timeout or `select!` cancels it), kill the child instead of leaking it.
            .kill_on_drop(true)
            .spawn()
            .map_err(|e| CommandError::Io(format!("failed to spawn {cmd}: {e}")))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| CommandError::Io("child stdout pipe missing".to_string()))?;
        let stderr = child
            .stderr
            .take()
            .ok_or_else(|| CommandError::Io("child stderr pipe missing".to_string()))?;

        // Drain the pipes while waiting so a chatty child cannot stall on a full pipe.
        let stdout_task = tokio::spawn(tail_stream(stdout, TAIL_LINE_CAP));
        let stderr_task = tokio::spawn(tail_stream(stderr, TAIL_LINE_CAP));

        let status_res = tokio::time::timeout(timeout, child.wait()).await;
        match status_res {
            Ok(Ok(status)) => {
                // NOTE(review): this waits for EOF on both pipes, so it can still block
                // if `cmd` left behind a descendant that keeps them open.
                let stdout_tail = stdout_task.await.unwrap_or_default();
                let stderr_tail = stderr_task.await.unwrap_or_default();
                Ok(CommandOutput {
                    exit_code: status.code(),
                    success: status.success(),
                    stdout_tail,
                    stderr_tail,
                })
            }
            Ok(Err(e)) => {
                // The child's state is unknown: kill it and stop the readers rather
                // than leaving them running detached.
                let _ = child.start_kill();
                stdout_task.abort();
                stderr_task.abort();
                Err(CommandError::Io(format!("child wait failed: {e}")))
            }
            Err(_) => {
                let _ = child.start_kill();
                let _ = child.wait().await;
                // Killing `cmd` does not kill processes it spawned, and those may hold
                // inherited copies of our pipes (e.g. test binaries under `cargo test`).
                // Waiting for EOF would then block until they exit — possibly forever
                // if one of them is what hung. The tails are discarded on this path
                // anyway, so cancel the readers; awaiting the handles lets the
                // cancellation complete before returning.
                stdout_task.abort();
                stderr_task.abort();
                let _ = stdout_task.await;
                let _ = stderr_task.await;
                Err(CommandError::Timeout(timeout))
            }
        }
    }
}
/// Reads `reader` to EOF and returns its last `max_lines` lines joined with `\n`.
///
/// Lines are split on `\n`, with a trailing `\r\n` or `\n` stripped, the same as
/// `AsyncBufReadExt::lines`. Bytes are decoded lossily: invalid UTF-8 becomes
/// U+FFFD instead of ending the read.
///
/// The stream is always drained to EOF (or a real I/O error). Stopping early
/// would drop the read end of the pipe while the child is still writing, making
/// its next write fail with a broken pipe.
///
/// With `max_lines == 0` the stream is still drained and an empty string is
/// returned.
async fn tail_stream<R: AsyncRead + Unpin>(reader: R, max_lines: usize) -> String {
    let mut reader = BufReader::new(reader);
    // Clamp the preallocation so a huge `max_lines` cannot overflow or over-allocate.
    let mut ring: VecDeque<String> = VecDeque::with_capacity(max_lines.min(1024));
    // One buffer reused for every line.
    let mut buf: Vec<u8> = Vec::new();
    loop {
        buf.clear();
        match reader.read_until(b'\n', &mut buf).await {
            Ok(0) | Err(_) => break,
            Ok(_) => {}
        }
        if max_lines == 0 {
            continue;
        }
        if buf.last() == Some(&b'\n') {
            buf.pop();
            if buf.last() == Some(&b'\r') {
                buf.pop();
            }
        }
        if ring.len() >= max_lines {
            ring.pop_front();
        }
        ring.push_back(String::from_utf8_lossy(&buf).into_owned());
    }
    Vec::from(ring).join("\n")
}
/// One invocation observed by [`ScriptedRunner`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CallRecord {
    /// Program that was requested, e.g. `"git"`.
    pub cmd: String,
    /// Arguments, in order.
    pub args: Vec<String>,
    /// Working directory the command was asked to run in.
    pub cwd: PathBuf,
}
/// In-memory [`CommandRunner`] for tests: hands out queued responses in FIFO order
/// and records every invocation.
///
/// Clones share the same queue and call log (both live behind `Arc<Mutex<_>>`).
#[derive(Default, Clone)]
pub struct ScriptedRunner {
    /// Responses for successive `run` calls, front first.
    responses: Arc<Mutex<VecDeque<Result<CommandOutput, CommandError>>>>,
    /// Every call received, in order.
    calls: Arc<Mutex<Vec<CallRecord>>>,
}

impl ScriptedRunner {
    /// Creates a runner with an empty script and no recorded calls.
    pub fn new() -> Self {
        Default::default()
    }

    /// Queues a finished-process response; `success` is derived as `code == 0`.
    pub fn push_ok(&self, code: i32, stdout: &str, stderr: &str) {
        let output = CommandOutput {
            exit_code: Some(code),
            success: code == 0,
            stdout_tail: stdout.to_owned(),
            stderr_tail: stderr.to_owned(),
        };
        let mut queue = self.responses.lock().unwrap();
        queue.push_back(Ok(output));
    }

    /// Queues an error response (e.g. a simulated timeout).
    pub fn push_err(&self, err: CommandError) {
        let mut queue = self.responses.lock().unwrap();
        queue.push_back(Err(err));
    }

    /// Returns a snapshot of every call recorded so far.
    pub fn calls(&self) -> Vec<CallRecord> {
        let log = self.calls.lock().unwrap();
        log.to_vec()
    }
}
#[async_trait]
impl CommandRunner for ScriptedRunner {
    /// Records the call, then returns the next queued response. Once the queue is
    /// exhausted, every call succeeds with exit code 0 and empty output. The
    /// timeout is ignored.
    async fn run(
        &self,
        cmd: &str,
        args: &[&str],
        cwd: &Path,
        _timeout: Duration,
    ) -> Result<CommandOutput, CommandError> {
        let record = CallRecord {
            cmd: cmd.to_owned(),
            args: args.iter().map(|arg| String::from(*arg)).collect(),
            cwd: cwd.to_owned(),
        };
        self.calls.lock().unwrap().push(record);

        let next = self.responses.lock().unwrap().pop_front();
        match next {
            Some(response) => response,
            None => Ok(CommandOutput {
                exit_code: Some(0),
                success: true,
                stdout_tail: String::new(),
                stderr_tail: String::new(),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the gate functions. They are driven by `ScriptedRunner`, whose
// queued responses are consumed strictly in call order, so the order of the
// `push_ok`/`push_err` calls in each test mirrors the expected command sequence.
use super::*;
// Config for a fake repo with pushing disabled, so revert tests see no
// `git push` call unless they opt in by setting `revert_push_remote`.
fn base_cfg() -> GateConfig {
let mut cfg = GateConfig::new(PathBuf::from("/tmp/fake"), "deadbeef".to_string());
cfg.revert_push_remote = None;
cfg
}
// Exit code 0 => passed; also pins the exact `cargo test` invocation.
#[tokio::test]
async fn run_workspace_tests_reports_green_when_exit_zero() {
let runner = ScriptedRunner::new();
runner.push_ok(0, "", "");
let cfg = base_cfg();
let out = run_workspace_tests(&runner, &cfg).await.unwrap();
assert!(out.passed);
assert!(!out.timed_out);
assert_eq!(out.exit_code, Some(0));
let calls = runner.calls();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].cmd, "cargo");
assert_eq!(calls[0].args, vec!["test", "--workspace", "--no-fail-fast"]);
}
// A runner timeout becomes a failed, timed-out outcome rather than an error.
#[tokio::test]
async fn run_workspace_tests_marks_timeout() {
let runner = ScriptedRunner::new();
runner.push_err(CommandError::Timeout(Duration::from_secs(600)));
let out = run_workspace_tests(&runner, &base_cfg()).await.unwrap();
assert!(out.timed_out);
assert!(!out.passed);
}
// Non-timeout runner errors propagate as `GateError::Command`.
#[tokio::test]
async fn run_workspace_tests_propagates_io_error() {
let runner = ScriptedRunner::new();
runner.push_err(CommandError::Io("no such file".to_string()));
let res = run_workspace_tests(&runner, &base_cfg()).await;
assert!(matches!(res, Err(GateError::Command(_))));
}
// Names listed under `failures:` (up to `test result:`) are reported in order.
#[test]
fn classify_failure_parses_test_failures() {
let outcome = TestRunOutcome {
passed: false,
exit_code: Some(101),
stdout_tail: "\
running 3 tests
test foo ... ok
test bar ... FAILED
failures:
mod::bar
mod::baz
test result: FAILED. 1 passed; 2 failed
"
.to_string(),
stderr_tail: String::new(),
wall_time: Duration::from_secs(1),
timed_out: false,
};
let c = classify_failure(&outcome);
assert_eq!(c.kind, FailureKind::TestFailure);
assert_eq!(c.failing_tests, vec!["mod::bar", "mod::baz"]);
}
// A compile error with no failing test names is a harness error.
#[test]
fn classify_failure_detects_harness_error() {
let outcome = TestRunOutcome {
passed: false,
exit_code: Some(101),
stdout_tail: String::new(),
stderr_tail: "error: could not compile `foo` due to previous error".to_string(),
wall_time: Duration::from_secs(1),
timed_out: false,
};
let c = classify_failure(&outcome);
assert_eq!(c.kind, FailureKind::HarnessError);
assert!(c.failing_tests.is_empty());
}
// `timed_out` takes precedence over any output parsing.
#[test]
fn classify_failure_detects_timeout() {
let outcome = TestRunOutcome {
passed: false,
exit_code: None,
stdout_tail: String::new(),
stderr_tail: String::new(),
wall_time: Duration::from_secs(600),
timed_out: true,
};
let c = classify_failure(&outcome);
assert_eq!(c.kind, FailureKind::Timeout);
}
// Scripted sequence: revert (ok), rev-parse (sha with trailing newline).
// The sha must come back trimmed, and no push happens without a remote.
#[tokio::test]
async fn revert_merge_captures_new_sha_no_push() {
let runner = ScriptedRunner::new();
runner.push_ok(0, "", "");
runner.push_ok(0, "abc123def456\n", "");
let out = revert_merge(&runner, &base_cfg()).await.unwrap();
assert_eq!(out.revert_sha, "abc123def456");
assert!(!out.pushed);
let calls = runner.calls();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].args[0], "revert");
assert_eq!(calls[1].args[0], "rev-parse");
}
// Scripted sequence: revert, rev-parse, push — all succeed; the push uses the
// `HEAD:<branch>` refspec.
#[tokio::test]
async fn revert_merge_pushes_when_remote_configured() {
let runner = ScriptedRunner::new();
runner.push_ok(0, "", ""); runner.push_ok(0, "cafef00d\n", ""); runner.push_ok(0, "", "");
let mut cfg = base_cfg();
cfg.revert_push_remote = Some("origin".to_string());
cfg.revert_push_branch = "main".to_string();
let out = revert_merge(&runner, &cfg).await.unwrap();
assert_eq!(out.revert_sha, "cafef00d");
assert!(out.pushed);
let calls = runner.calls();
assert_eq!(calls.len(), 3);
assert_eq!(calls[2].cmd, "git");
assert_eq!(calls[2].args, vec!["push", "origin", "HEAD:main"]);
}
// Both the `-m 1` revert and the plain-revert fallback fail => `GateError::Revert`.
#[tokio::test]
async fn revert_merge_fails_when_all_revert_paths_fail() {
let runner = ScriptedRunner::new();
runner.push_ok(1, "", "fatal: not a merge");
runner.push_ok(1, "", "fatal: commit not found");
let res = revert_merge(&runner, &base_cfg()).await;
assert!(matches!(res, Err(GateError::Revert(_))));
}
}