use anyhow::{anyhow, Context, Result};
use std::io::BufRead;
use std::path::Path;
use std::process::Command;
/// Git's all-zeros object name in its 40-hex-digit form. In this module it marks
/// a ref that is being deleted (local side) or does not exist yet (remote side).
const NULL_SHA: &str = "0000000000000000000000000000000000000000";

/// Branch patterns protected when `SHIELD_PROTECTED_BRANCHES` is not set.
///
/// An entry ending in `/*` covers every branch underneath that prefix; any
/// other entry must equal the branch name exactly.
const DEFAULT_PROTECTED_BRANCHES: &[&str] = &[
    // Exact branch names.
    "main",
    "master",
    "prod",
    "production",
    "release",
    // Whole namespaces.
    "release/*",
    "prod/*",
    "hotfix/*",
];
/// One ref update read from the hook's standard input, in the field order
/// `<local ref> <local sha> <remote ref> <remote sha>`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct RefUpdate {
    /// Local ref being pushed (first input field).
    pub local_ref: String,
    /// Object name the local ref points at; the null id signals a deletion.
    pub local_sha: String,
    /// Ref on the remote that will be updated (third input field).
    pub remote_ref: String,
    /// Object name the remote ref currently has; the null id means it does not exist yet.
    pub remote_sha: String,
}
/// Decision reached for a single ref update.
#[derive(Clone)]
#[derive(Debug)]
pub enum PushVerdict {
    /// The update is allowed.
    Ok,
    /// The update deletes a protected branch.
    Deletion {
        /// Short name of the protected branch being deleted.
        protected_branch: String,
    },
    /// The update replaces a protected branch's commit with one that does not
    /// descend from it.
    ForcePush {
        /// Short name of the protected branch being rewritten.
        protected_branch: String,
        /// Commit the remote branch currently points at.
        remote_sha: String,
        /// Commit the push would move the branch to.
        local_sha: String,
    },
}
/// Summary of one check: how many ref updates were examined and which were flagged.
#[derive(Default)]
#[derive(Debug)]
pub struct CheckPushedReport {
    /// Number of input lines that parsed into a [`RefUpdate`].
    pub refs_inspected: usize,
    /// Every inspected update whose verdict was not [`PushVerdict::Ok`], paired
    /// with that verdict.
    pub violations: Vec<(RefUpdate, PushVerdict)>,
}

impl CheckPushedReport {
    /// Exit status to report: `0` when no violations were recorded, `1` otherwise.
    pub fn exit_code(&self) -> u8 {
        let has_violations = !self.violations.is_empty();
        u8::from(has_violations)
    }
}
/// Parses one whitespace-separated input line of the form
/// `<local ref> <local sha> <remote ref> <remote sha>`.
///
/// Returns `None` when fewer than four fields are present (including blank
/// lines); any fields after the fourth are ignored.
pub fn parse_line(line: &str) -> Option<RefUpdate> {
    let mut fields = line.split_whitespace().map(str::to_owned);
    // Struct-literal fields are evaluated in the order written, so this
    // consumes the fields left to right.
    Some(RefUpdate {
        local_ref: fields.next()?,
        local_sha: fields.next()?,
        remote_ref: fields.next()?,
        remote_sha: fields.next()?,
    })
}
/// Returns the branch patterns to protect.
///
/// When `SHIELD_PROTECTED_BRANCHES` is set (and valid Unicode), it is read as
/// a comma-separated list; entries are trimmed and empty ones dropped.
/// Otherwise [`DEFAULT_PROTECTED_BRANCHES`] is used.
pub fn protected_patterns() -> Vec<String> {
    match std::env::var("SHIELD_PROTECTED_BRANCHES") {
        Ok(raw) => raw
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(String::from)
            .collect(),
        Err(_) => DEFAULT_PROTECTED_BRANCHES
            .iter()
            .map(|&entry| String::from(entry))
            .collect(),
    }
}
/// Tests a branch short name against one protection pattern.
///
/// A pattern ending in `/*` matches any name that starts with the part before
/// `/*` followed by `/` (at any depth). Every other pattern must equal the
/// name exactly.
pub fn pattern_matches(pattern: &str, short_name: &str) -> bool {
    match pattern.strip_suffix("/*") {
        // `<prefix>/...`: what remains after the prefix must begin with '/'.
        Some(prefix) => matches!(
            short_name.strip_prefix(prefix),
            Some(rest) if rest.starts_with('/')
        ),
        None => pattern == short_name,
    }
}
/// Strips a leading `refs/heads/`; any other ref is returned untouched.
fn short_name(full_ref: &str) -> &str {
    match full_ref.strip_prefix("refs/heads/") {
        Some(branch) => branch,
        None => full_ref,
    }
}
/// Returns the branch's short name when `remote_ref` matches any of
/// `patterns`, or `None` if it is not protected.
pub fn is_protected(remote_ref: &str, patterns: &[String]) -> Option<String> {
    let branch = short_name(remote_ref);
    patterns
        .iter()
        .any(|pattern| pattern_matches(pattern, branch))
        .then(|| branch.to_string())
}
fn is_ancestor(
repo_root: &Path,
ancestor_sha: &str,
descendant_sha: &str,
) -> Result<bool> {
if ancestor_sha == NULL_SHA {
return Ok(true);
}
let status = Command::new("git")
.args([
"merge-base",
"--is-ancestor",
ancestor_sha,
descendant_sha,
])
.current_dir(repo_root)
.status()
.with_context(|| {
"git merge-base --is-ancestor failed (is git installed?)"
})?;
match status.code() {
Some(0) => Ok(true),
Some(1) => Ok(false),
Some(code) => Err(anyhow!(
"git merge-base exited unexpectedly with code {} for {}..{}",
code,
ancestor_sha,
descendant_sha
)),
None => Err(anyhow!(
"git merge-base was killed by signal during {}..{}",
ancestor_sha,
descendant_sha
)),
}
}
pub fn verdict(repo_root: &Path, upd: &RefUpdate, patterns: &[String]) -> Result<PushVerdict> {
let protected = match is_protected(&upd.remote_ref, patterns) {
Some(name) => name,
None => return Ok(PushVerdict::Ok),
};
if upd.local_sha == NULL_SHA {
return Ok(PushVerdict::Deletion {
protected_branch: protected,
});
}
if upd.remote_sha == NULL_SHA {
return Ok(PushVerdict::Ok);
}
if !is_ancestor(repo_root, &upd.remote_sha, &upd.local_sha)? {
return Ok(PushVerdict::ForcePush {
protected_branch: protected,
remote_sha: upd.remote_sha.clone(),
local_sha: upd.local_sha.clone(),
});
}
Ok(PushVerdict::Ok)
}
/// Reads ref-update lines from `stdin` and checks each one against the
/// protected-branch patterns.
///
/// Lines that do not parse (including blank ones) are skipped and not counted.
///
/// # Errors
/// Returns the first I/O error from reading `stdin`, or the first error from
/// [`verdict`].
pub fn run(repo_root: &Path, stdin: impl BufRead) -> Result<CheckPushedReport> {
    let patterns = protected_patterns();
    let mut report = CheckPushedReport::default();
    for line in stdin.lines() {
        // `parse_line` returns `None` for blank/whitespace-only input, so no
        // separate emptiness check is needed.
        if let Some(update) = parse_line(&line?) {
            report.refs_inspected += 1;
            let outcome = verdict(repo_root, &update, &patterns)?;
            if !matches!(outcome, PushVerdict::Ok) {
                report.violations.push((update, outcome));
            }
        }
    }
    Ok(report)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises tests that read or write `SHIELD_PROTECTED_BRANCHES`,
    /// because the process environment is shared by every test thread.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering it if a previous holder panicked, so a
    /// single failing test does not make every later env test fail as well.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Builds a `RefUpdate` from its four fields.
    fn ref_update(local_ref: &str, local_sha: &str, remote_ref: &str, remote_sha: &str) -> RefUpdate {
        RefUpdate {
            local_ref: local_ref.to_string(),
            local_sha: local_sha.to_string(),
            remote_ref: remote_ref.to_string(),
            remote_sha: remote_sha.to_string(),
        }
    }

    #[test]
    fn parses_well_formed_stdin_line() {
        let l = "refs/heads/feat/foo 1111 refs/heads/main 2222";
        let u = parse_line(l).unwrap();
        assert_eq!(u.local_ref, "refs/heads/feat/foo");
        assert_eq!(u.local_sha, "1111");
        assert_eq!(u.remote_ref, "refs/heads/main");
        assert_eq!(u.remote_sha, "2222");
    }

    #[test]
    fn parse_line_handles_short_input() {
        assert!(parse_line("").is_none());
        assert!(parse_line("   ").is_none());
        assert!(parse_line("only one field").is_none());
    }

    #[test]
    fn pattern_matches_exact_and_globbed() {
        assert!(pattern_matches("main", "main"));
        assert!(!pattern_matches("main", "develop"));
        assert!(pattern_matches("release/*", "release/2026-05"));
        assert!(pattern_matches("release/*", "release/foo/bar"));
        assert!(!pattern_matches("release/*", "release"));
        assert!(!pattern_matches("release/*", "feature/release/x"));
    }

    #[test]
    fn short_name_strips_only_branch_prefix() {
        assert_eq!(short_name("refs/heads/main"), "main");
        assert_eq!(short_name("refs/tags/v1.0"), "refs/tags/v1.0");
    }

    #[test]
    fn is_protected_recognises_default_set() {
        let _guard = env_lock();
        std::env::remove_var("SHIELD_PROTECTED_BRANCHES");
        let pats = protected_patterns();
        assert_eq!(is_protected("refs/heads/main", &pats).as_deref(), Some("main"));
        assert_eq!(is_protected("refs/heads/master", &pats).as_deref(), Some("master"));
        assert_eq!(
            is_protected("refs/heads/release/2026-05", &pats).as_deref(),
            Some("release/2026-05")
        );
        assert_eq!(is_protected("refs/heads/develop", &pats), None);
    }

    #[test]
    fn is_protected_ignores_non_branch_refs() {
        let pats = vec!["main".to_string()];
        assert_eq!(is_protected("refs/tags/main", &pats), None);
    }

    #[test]
    fn env_override_protected_branches() {
        let _guard = env_lock();
        std::env::set_var("SHIELD_PROTECTED_BRANCHES", "trunk, deploy/*");
        let pats = protected_patterns();
        // Restore before asserting so a failed assertion cannot leak the override.
        std::env::remove_var("SHIELD_PROTECTED_BRANCHES");
        assert!(is_protected("refs/heads/trunk", &pats).is_some());
        assert!(is_protected("refs/heads/deploy/prod", &pats).is_some());
        assert!(is_protected("refs/heads/main", &pats).is_none());
    }

    #[test]
    fn verdict_decides_without_git_when_no_ancestry_check_is_needed() {
        let pats = vec!["main".to_string()];
        // These cases return before `git merge-base` would be run, so the
        // repository root is never used.
        let root = std::path::Path::new(".");

        let unprotected = ref_update("refs/heads/x", "abc", "refs/heads/feature", "def");
        let v = verdict(root, &unprotected, &pats).unwrap();
        assert!(matches!(v, PushVerdict::Ok));

        let deletion = ref_update("(delete)", NULL_SHA, "refs/heads/main", "abc");
        let v = verdict(root, &deletion, &pats).unwrap();
        assert!(matches!(
            &v,
            PushVerdict::Deletion { protected_branch } if protected_branch == "main"
        ));

        let creation = ref_update("refs/heads/main", "abc", "refs/heads/main", NULL_SHA);
        let v = verdict(root, &creation, &pats).unwrap();
        assert!(matches!(v, PushVerdict::Ok));
    }

    #[test]
    fn empty_stdin_yields_clean_report() {
        // `run` reads SHIELD_PROTECTED_BRANCHES, so hold the env lock.
        let _guard = env_lock();
        let tmp = tempfile::tempdir().unwrap();
        let report = run(tmp.path(), std::io::Cursor::new(b"")).expect("run");
        assert_eq!(report.refs_inspected, 0);
        assert!(report.violations.is_empty());
        assert_eq!(report.exit_code(), 0);
    }
}