guts_auth/
branch_protection.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct BranchProtection {
12 pub id: u64,
14 pub repo_key: String,
16 pub pattern: String,
18 pub require_pr: bool,
20 pub required_reviews: u32,
22 pub required_status_checks: HashSet<String>,
24 pub dismiss_stale_reviews: bool,
26 pub require_code_owner_review: bool,
28 pub restrict_pushes: bool,
30 pub allow_force_push: bool,
32 pub allow_deletion: bool,
34 pub created_at: u64,
36 pub updated_at: u64,
38}
39
40impl BranchProtection {
41 pub fn new(id: u64, repo_key: String, pattern: String) -> Self {
43 let now = Self::now();
44 Self {
45 id,
46 repo_key,
47 pattern,
48 require_pr: true,
49 required_reviews: 1,
50 required_status_checks: HashSet::new(),
51 dismiss_stale_reviews: false,
52 require_code_owner_review: false,
53 restrict_pushes: false,
54 allow_force_push: false,
55 allow_deletion: false,
56 created_at: now,
57 updated_at: now,
58 }
59 }
60
61 pub fn matches(&self, branch: &str) -> bool {
63 if self.pattern.contains('*') {
64 let parts: Vec<&str> = self.pattern.split('*').collect();
66 if parts.len() == 1 {
67 branch == self.pattern
69 } else if parts.len() == 2 {
70 let prefix = parts[0];
72 let suffix = parts[1];
73 branch.starts_with(prefix) && branch.ends_with(suffix)
74 } else {
75 branch.starts_with(parts[0])
77 }
78 } else {
79 branch == self.pattern
80 }
81 }
82
83 pub fn allows_direct_push(&self, is_admin: bool) -> bool {
85 if !self.require_pr {
86 return true;
87 }
88 if self.restrict_pushes {
89 return is_admin;
90 }
91 false
92 }
93
94 pub fn allows_force_push(&self) -> bool {
96 self.allow_force_push
97 }
98
99 pub fn allows_deletion(&self) -> bool {
101 self.allow_deletion
102 }
103
104 pub fn check_reviews(&self, approving_reviews: u32, has_code_owner_review: bool) -> bool {
106 if approving_reviews < self.required_reviews {
107 return false;
108 }
109 if self.require_code_owner_review && !has_code_owner_review {
110 return false;
111 }
112 true
113 }
114
115 pub fn check_status(&self, passed_checks: &HashSet<String>) -> bool {
117 self.required_status_checks.is_subset(passed_checks)
118 }
119
120 pub fn add_required_check(&mut self, check: String) {
122 self.required_status_checks.insert(check);
123 self.updated_at = Self::now();
124 }
125
126 pub fn remove_required_check(&mut self, check: &str) -> bool {
128 let removed = self.required_status_checks.remove(check);
129 if removed {
130 self.updated_at = Self::now();
131 }
132 removed
133 }
134
135 fn now() -> u64 {
136 std::time::SystemTime::now()
137 .duration_since(std::time::UNIX_EPOCH)
138 .unwrap_or_default()
139 .as_secs()
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct BranchProtectionRequest {
146 #[serde(default)]
148 pub require_pr: bool,
149 #[serde(default)]
151 pub required_reviews: u32,
152 #[serde(default)]
154 pub required_status_checks: Vec<String>,
155 #[serde(default)]
157 pub dismiss_stale_reviews: bool,
158 #[serde(default)]
160 pub require_code_owner_review: bool,
161 #[serde(default)]
163 pub restrict_pushes: bool,
164 #[serde(default)]
166 pub allow_force_push: bool,
167 #[serde(default)]
169 pub allow_deletion: bool,
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_matching() {
        let rule = BranchProtection::new(1, "acme/api".into(), "main".into());
        assert!(rule.matches("main"));
        assert!(!rule.matches("master"));
        assert!(!rule.matches("main-backup"));

        let rule = BranchProtection::new(2, "acme/api".into(), "release/*".into());
        assert!(rule.matches("release/1.0"));
        assert!(rule.matches("release/2.0.1"));
        assert!(!rule.matches("releases/1.0"));
        assert!(!rule.matches("release"));

        let rule = BranchProtection::new(3, "acme/api".into(), "feature-*-test".into());
        assert!(rule.matches("feature-foo-test"));
        assert!(rule.matches("feature-bar-test"));
        assert!(!rule.matches("feature-foo-tests"));
    }

    #[test]
    fn test_new_defaults() {
        let rule = BranchProtection::new(7, "acme/api".into(), "main".into());
        assert_eq!(rule.id, 7);
        assert_eq!(rule.repo_key, "acme/api");
        assert_eq!(rule.pattern, "main");
        assert!(rule.require_pr);
        assert_eq!(rule.required_reviews, 1);
        assert!(rule.required_status_checks.is_empty());
        assert!(!rule.allows_force_push());
        assert!(!rule.allows_deletion());
        assert_eq!(rule.created_at, rule.updated_at);
    }

    #[test]
    fn test_force_push_and_deletion_flags() {
        let mut rule = BranchProtection::new(1, "acme/api".into(), "main".into());
        rule.allow_force_push = true;
        rule.allow_deletion = true;
        assert!(rule.allows_force_push());
        assert!(rule.allows_deletion());
    }

    #[test]
    fn test_direct_push() {
        let mut rule = BranchProtection::new(1, "acme/api".into(), "main".into());

        assert!(!rule.allows_direct_push(false));
        assert!(!rule.allows_direct_push(true));

        rule.restrict_pushes = true;
        assert!(!rule.allows_direct_push(false));
        assert!(rule.allows_direct_push(true));

        rule.require_pr = false;
        assert!(rule.allows_direct_push(false));
        assert!(rule.allows_direct_push(true));
    }

    #[test]
    fn test_review_requirements() {
        let mut rule = BranchProtection::new(1, "acme/api".into(), "main".into());
        rule.required_reviews = 2;

        assert!(!rule.check_reviews(0, false));
        assert!(!rule.check_reviews(1, false));
        assert!(rule.check_reviews(2, false));
        assert!(rule.check_reviews(3, false));

        rule.require_code_owner_review = true;
        assert!(!rule.check_reviews(2, false));
        assert!(rule.check_reviews(2, true));
    }

    #[test]
    fn test_status_checks() {
        let mut rule = BranchProtection::new(1, "acme/api".into(), "main".into());
        rule.add_required_check("ci/build".into());
        rule.add_required_check("ci/test".into());

        let mut passed = HashSet::new();
        assert!(!rule.check_status(&passed));

        passed.insert("ci/build".into());
        assert!(!rule.check_status(&passed));

        passed.insert("ci/test".into());
        assert!(rule.check_status(&passed));

        // Extra passing checks beyond the required set are fine.
        passed.insert("ci/lint".into());
        assert!(rule.check_status(&passed));
    }

    #[test]
    fn test_remove_required_check() {
        let mut rule = BranchProtection::new(1, "acme/api".into(), "main".into());
        rule.add_required_check("ci/build".into());

        assert!(rule.remove_required_check("ci/build"));
        assert!(!rule.remove_required_check("ci/build"));

        // With nothing required, an empty set of passed checks is enough.
        assert!(rule.check_status(&HashSet::new()));
    }
}