1use anyhow::{anyhow, Context, Result};
55use std::io::BufRead;
56use std::path::Path;
57use std::process::Command;
58
/// Object name that stands for "no object" in push input (SHA-1 width, 40
/// hex digits): a remote side equal to this means the ref is being created,
/// a local side equal to this means the ref is being deleted.
const NULL_SHA: &str = "0000000000000000000000000000000000000000";
60
/// Branch patterns guarded when `SHIELD_PROTECTED_BRANCHES` is unset (or not
/// valid Unicode).
///
/// Entries ending in `/*` cover every branch beneath that prefix; all other
/// entries must equal the branch name exactly.
const DEFAULT_PROTECTED_BRANCHES: &[&str] = &[
    // Individual long-lived branches, matched exactly.
    "main", "master", "prod", "production", "release",
    // Branch families, matched by `<prefix>/` at any depth.
    "release/*", "prod/*", "hotfix/*",
];
71
/// One `<local ref> <local sha> <remote ref> <remote sha>` input line
/// (presumably git's pre-push hook stdin), split into its four fields.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RefUpdate {
    /// Local side of the update, exactly as it appeared in the input.
    pub local_ref: String,
    /// Object name being pushed; all zeroes when the update is a deletion.
    pub local_sha: String,
    /// Ref on the remote that the update targets.
    pub remote_ref: String,
    /// Object name the remote ref currently has; all zeroes when it does
    /// not exist yet.
    pub remote_sha: String,
}
80
/// Outcome of checking a single ref update against the protected-branch list.
///
/// Derives `PartialEq`/`Eq` (like `RefUpdate`) so verdicts and the reports
/// that contain them can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PushVerdict {
    /// The update is allowed.
    Ok,
    /// The update would delete a protected branch.
    Deletion {
        /// Short name of the protected branch (the ref without `refs/heads/`).
        protected_branch: String,
    },
    /// The update would move a protected branch to a commit that does not
    /// descend from its current remote commit.
    ForcePush {
        /// Short name of the protected branch (the ref without `refs/heads/`).
        protected_branch: String,
        /// Object name the remote branch currently points at.
        remote_sha: String,
        /// Object name that would replace it.
        local_sha: String,
    },
}
93
94#[derive(Debug, Default)]
95pub struct CheckPushedReport {
96 pub refs_inspected: usize,
97 pub violations: Vec<(RefUpdate, PushVerdict)>,
98}
99
100impl CheckPushedReport {
101 pub fn exit_code(&self) -> u8 {
102 if self.violations.is_empty() {
103 0
104 } else {
105 1
106 }
107 }
108}
109
110pub fn parse_line(line: &str) -> Option<RefUpdate> {
113 let mut iter = line.split_whitespace();
114 let local_ref = iter.next()?.to_string();
115 let local_sha = iter.next()?.to_string();
116 let remote_ref = iter.next()?.to_string();
117 let remote_sha = iter.next()?.to_string();
118 Some(RefUpdate {
119 local_ref,
120 local_sha,
121 remote_ref,
122 remote_sha,
123 })
124}
125
126pub fn protected_patterns() -> Vec<String> {
128 if let Ok(raw) = std::env::var("SHIELD_PROTECTED_BRANCHES") {
129 raw.split(',')
130 .map(|s| s.trim().to_string())
131 .filter(|s| !s.is_empty())
132 .collect()
133 } else {
134 DEFAULT_PROTECTED_BRANCHES
135 .iter()
136 .map(|s| (*s).to_string())
137 .collect()
138 }
139}
140
/// Tests a branch short name against one protected-branch pattern.
///
/// A pattern ending in `/*` matches any name that starts with the text
/// before `/*` immediately followed by `/`, at any depth. For example,
/// `release/*` matches `release/a` and `release/a/b`, but not `release`
/// itself. Any other pattern must equal `short_name` exactly.
pub fn pattern_matches(pattern: &str, short_name: &str) -> bool {
    match pattern.strip_suffix("/*") {
        Some(prefix) => matches!(
            short_name.strip_prefix(prefix),
            Some(rest) if rest.starts_with('/')
        ),
        None => pattern == short_name,
    }
}
149
/// Removes a leading `refs/heads/` so branch refs can be compared with
/// patterns; any other ref is returned unchanged.
fn short_name(full_ref: &str) -> &str {
    const BRANCH_NAMESPACE: &str = "refs/heads/";
    full_ref.strip_prefix(BRANCH_NAMESPACE).unwrap_or(full_ref)
}
154
155pub fn is_protected(remote_ref: &str, patterns: &[String]) -> Option<String> {
157 let s = short_name(remote_ref);
158 for p in patterns {
159 if pattern_matches(p, s) {
160 return Some(s.to_string());
161 }
162 }
163 None
164}
165
166fn is_ancestor(
170 repo_root: &Path,
171 ancestor_sha: &str,
172 descendant_sha: &str,
173) -> Result<bool> {
174 if ancestor_sha == NULL_SHA {
175 return Ok(true);
178 }
179 let status = Command::new("git")
180 .args([
181 "merge-base",
182 "--is-ancestor",
183 ancestor_sha,
184 descendant_sha,
185 ])
186 .current_dir(repo_root)
187 .status()
188 .with_context(|| {
189 "git merge-base --is-ancestor failed (is git installed?)"
190 })?;
191 match status.code() {
195 Some(0) => Ok(true),
196 Some(1) => Ok(false),
197 Some(code) => Err(anyhow!(
198 "git merge-base exited unexpectedly with code {} for {}..{}",
199 code,
200 ancestor_sha,
201 descendant_sha
202 )),
203 None => Err(anyhow!(
204 "git merge-base was killed by signal during {}..{}",
205 ancestor_sha,
206 descendant_sha
207 )),
208 }
209}
210
211pub fn verdict(repo_root: &Path, upd: &RefUpdate, patterns: &[String]) -> Result<PushVerdict> {
213 let protected = match is_protected(&upd.remote_ref, patterns) {
214 Some(name) => name,
215 None => return Ok(PushVerdict::Ok),
216 };
217
218 if upd.local_sha == NULL_SHA {
220 return Ok(PushVerdict::Deletion {
221 protected_branch: protected,
222 });
223 }
224
225 if upd.remote_sha == NULL_SHA {
228 return Ok(PushVerdict::Ok);
229 }
230
231 if !is_ancestor(repo_root, &upd.remote_sha, &upd.local_sha)? {
233 return Ok(PushVerdict::ForcePush {
234 protected_branch: protected,
235 remote_sha: upd.remote_sha.clone(),
236 local_sha: upd.local_sha.clone(),
237 });
238 }
239
240 Ok(PushVerdict::Ok)
241}
242
243pub fn run(repo_root: &Path, stdin: impl BufRead) -> Result<CheckPushedReport> {
245 let patterns = protected_patterns();
246 let mut report = CheckPushedReport::default();
247
248 for line in stdin.lines() {
249 let line = line?;
250 if line.trim().is_empty() {
251 continue;
252 }
253 let upd = match parse_line(&line) {
254 Some(u) => u,
255 None => continue,
256 };
257 report.refs_inspected += 1;
258 let v = verdict(repo_root, &upd, &patterns)?;
259 if !matches!(v, PushVerdict::Ok) {
260 report.violations.push((upd, v));
261 }
262 }
263 Ok(report)
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::path::Path;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises every test that reads or writes `SHIELD_PROTECTED_BRANCHES`,
    /// because the process environment is shared by all test threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering the guard if a previous holder
    /// panicked, so one failing test does not make every other env test fail.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Builds a `RefUpdate` that pushes a topic branch to `remote_ref`.
    fn ref_update(remote_ref: &str, local_sha: &str, remote_sha: &str) -> RefUpdate {
        RefUpdate {
            local_ref: "refs/heads/topic".to_string(),
            local_sha: local_sha.to_string(),
            remote_ref: remote_ref.to_string(),
            remote_sha: remote_sha.to_string(),
        }
    }

    #[test]
    fn parses_well_formed_stdin_line() {
        let l = "refs/heads/feat/foo 1111 refs/heads/main 2222";
        let u = parse_line(l).unwrap();
        assert_eq!(u.local_ref, "refs/heads/feat/foo");
        assert_eq!(u.local_sha, "1111");
        assert_eq!(u.remote_ref, "refs/heads/main");
        assert_eq!(u.remote_sha, "2222");
    }

    #[test]
    fn parse_line_handles_short_input() {
        assert!(parse_line("").is_none());
        assert!(parse_line("only one field").is_none());
    }

    #[test]
    fn pattern_matches_exact_and_globbed() {
        assert!(pattern_matches("main", "main"));
        assert!(!pattern_matches("main", "develop"));
        assert!(pattern_matches("release/*", "release/2026-05"));
        assert!(pattern_matches("release/*", "release/foo/bar"));
        assert!(!pattern_matches("release/*", "release"));
        assert!(!pattern_matches("release/*", "feature/release/x"));
    }

    #[test]
    fn is_protected_recognises_default_set() {
        let _guard = env_lock();
        std::env::remove_var("SHIELD_PROTECTED_BRANCHES");
        let pats = protected_patterns();
        assert_eq!(is_protected("refs/heads/main", &pats).as_deref(), Some("main"));
        assert_eq!(is_protected("refs/heads/master", &pats).as_deref(), Some("master"));
        assert_eq!(
            is_protected("refs/heads/release/2026-05", &pats).as_deref(),
            Some("release/2026-05")
        );
        assert_eq!(is_protected("refs/heads/develop", &pats), None);
    }

    #[test]
    fn env_override_protected_branches() {
        let _guard = env_lock();
        std::env::set_var("SHIELD_PROTECTED_BRANCHES", "trunk, deploy/*");
        let pats = protected_patterns();
        assert!(is_protected("refs/heads/trunk", &pats).is_some());
        assert!(is_protected("refs/heads/deploy/prod", &pats).is_some());
        assert!(is_protected("refs/heads/main", &pats).is_none());
        std::env::remove_var("SHIELD_PROTECTED_BRANCHES");
    }

    #[test]
    fn empty_stdin_yields_clean_report() {
        // `run` reads SHIELD_PROTECTED_BRANCHES, so it must not race the
        // tests above that modify it.
        let _guard = env_lock();
        let tmp = tempfile::tempdir().unwrap();
        let report = run(tmp.path(), Cursor::new(b"")).expect("run");
        assert_eq!(report.refs_inspected, 0);
        assert!(report.violations.is_empty());
        assert_eq!(report.exit_code(), 0);
    }

    #[test]
    fn verdict_ignores_unprotected_branch() {
        let pats = vec!["main".to_string()];
        let upd = ref_update("refs/heads/develop", "1111", "2222");
        let v = verdict(Path::new("."), &upd, &pats).expect("verdict");
        assert!(matches!(v, PushVerdict::Ok));
    }

    #[test]
    fn verdict_flags_deletion_of_protected_branch() {
        let pats = vec!["main".to_string()];
        let upd = ref_update("refs/heads/main", NULL_SHA, "2222");
        let v = verdict(Path::new("."), &upd, &pats).expect("verdict");
        match &v {
            PushVerdict::Deletion { protected_branch } => assert_eq!(protected_branch, "main"),
            other => panic!("expected Deletion, got {:?}", other),
        }
    }

    #[test]
    fn verdict_allows_creating_protected_branch() {
        let pats = vec!["main".to_string()];
        let upd = ref_update("refs/heads/main", "1111", NULL_SHA);
        let v = verdict(Path::new("."), &upd, &pats).expect("verdict");
        assert!(matches!(v, PushVerdict::Ok));
    }

    #[test]
    fn run_reports_only_protected_violations() {
        let _guard = env_lock();
        std::env::remove_var("SHIELD_PROTECTED_BRANCHES");
        // A protected deletion, a blank line, and an unprotected branch creation.
        let input = format!(
            "(delete) {0} refs/heads/main 2222\n\nrefs/heads/feat 1111 refs/heads/feat {0}\n",
            NULL_SHA
        );
        let report = run(Path::new("."), Cursor::new(input)).expect("run");
        assert_eq!(report.refs_inspected, 2);
        assert_eq!(report.violations.len(), 1);
        assert_eq!(report.violations[0].0.remote_ref, "refs/heads/main");
        assert!(matches!(report.violations[0].1, PushVerdict::Deletion { .. }));
        assert_eq!(report.exit_code(), 1);
    }
}