// ward/cli/protection.rs

1use anyhow::Result;
2use clap::Args;
3use console::style;
4use dialoguer::Confirm;
5
6use crate::config::Manifest;
7use crate::engine::audit_log::AuditLog;
8use crate::github::Client;
9use crate::github::branch_protection::BranchProtectionState;
10
11#[derive(Args)]
12pub struct ProtectionCommand {
13    #[command(subcommand)]
14    action: ProtectionAction,
15}
16
17#[derive(clap::Subcommand)]
18enum ProtectionAction {
19    /// Show what branch protection changes would be made (dry-run)
20    Plan,
21
22    /// Apply branch protection to default branches
23    Apply {
24        /// Skip confirmation prompt
25        #[arg(long)]
26        yes: bool,
27    },
28
29    /// Show current branch protection state
30    Audit,
31}
32
33impl ProtectionCommand {
34    pub async fn run(
35        &self,
36        client: &Client,
37        manifest: &Manifest,
38        system: Option<&str>,
39        repo: Option<&str>,
40    ) -> Result<()> {
41        match &self.action {
42            ProtectionAction::Plan => plan(client, manifest, system, repo).await,
43            ProtectionAction::Apply { yes } => apply(client, manifest, system, repo, *yes).await,
44            ProtectionAction::Audit => audit(client, manifest, system, repo).await,
45        }
46    }
47}
48
49async fn resolve_repos_with_branches(
50    client: &Client,
51    manifest: &Manifest,
52    system: Option<&str>,
53    repo: Option<&str>,
54) -> Result<Vec<(String, String)>> {
55    if let Some(repo_name) = repo {
56        let r = client.get_repo(repo_name).await?;
57        return Ok(vec![(r.name, r.default_branch)]);
58    }
59
60    let sys = system.ok_or_else(|| {
61        anyhow::anyhow!("Either --system or --repo is required for protection commands")
62    })?;
63
64    let excludes = manifest.exclude_patterns_for_system(sys);
65    let explicit = manifest.explicit_repos_for_system(sys);
66    let repos = client
67        .list_repos_for_system(sys, &excludes, &explicit)
68        .await?;
69    Ok(repos
70        .into_iter()
71        .map(|r| (r.name, r.default_branch))
72        .collect())
73}
74
/// Pending branch-protection changes for one repository's default branch.
struct ProtectionDiff {
    /// Repository name.
    repo: String,
    /// Branch the protection applies to (the repository's default branch).
    branch: String,
    /// Settings whose live value differs from the desired one; empty when compliant.
    changes: Vec<ProtectionChange>,
}

/// A single setting whose current value differs from the desired value.
///
/// Both values are stored as display strings (e.g. `"true"`, `"2"`), which is
/// what the plan output prints.
struct ProtectionChange {
    /// Setting name as shown in the plan.
    field: String,
    /// Value currently configured on the branch.
    current: String,
    /// Value requested by the manifest.
    desired: String,
}

impl ProtectionDiff {
    /// Returns `true` if at least one setting must be changed.
    fn has_changes(&self) -> bool {
        !self.changes.is_empty()
    }
}
92
93fn diff_protection(
94    repo: &str,
95    branch: &str,
96    current: &BranchProtectionState,
97    config: &crate::config::manifest::BranchProtectionConfig,
98) -> ProtectionDiff {
99    let mut changes = Vec::new();
100
101    let checks: Vec<(&str, String, String)> = vec![
102        (
103            "required_pull_request_reviews",
104            current.required_pull_request_reviews.to_string(),
105            config.enabled.to_string(),
106        ),
107        (
108            "required_approvals",
109            current.required_approving_review_count.to_string(),
110            config.required_approvals.to_string(),
111        ),
112        (
113            "dismiss_stale_reviews",
114            current.dismiss_stale_reviews.to_string(),
115            config.dismiss_stale_reviews.to_string(),
116        ),
117        (
118            "require_code_owner_reviews",
119            current.require_code_owner_reviews.to_string(),
120            config.require_code_owner_reviews.to_string(),
121        ),
122        (
123            "require_status_checks",
124            current.required_status_checks.to_string(),
125            config.require_status_checks.to_string(),
126        ),
127        (
128            "strict_status_checks",
129            current.strict_status_checks.to_string(),
130            config.strict_status_checks.to_string(),
131        ),
132        (
133            "enforce_admins",
134            current.enforce_admins.to_string(),
135            config.enforce_admins.to_string(),
136        ),
137        (
138            "required_linear_history",
139            current.required_linear_history.to_string(),
140            config.required_linear_history.to_string(),
141        ),
142        (
143            "allow_force_pushes",
144            current.allow_force_pushes.to_string(),
145            config.allow_force_pushes.to_string(),
146        ),
147        (
148            "allow_deletions",
149            current.allow_deletions.to_string(),
150            config.allow_deletions.to_string(),
151        ),
152    ];
153
154    for (field, current_val, desired_val) in checks {
155        if current_val != desired_val {
156            changes.push(ProtectionChange {
157                field: field.to_string(),
158                current: current_val,
159                desired: desired_val,
160            });
161        }
162    }
163
164    ProtectionDiff {
165        repo: repo.to_string(),
166        branch: branch.to_string(),
167        changes,
168    }
169}
170
171async fn build_diffs(
172    client: &Client,
173    manifest: &Manifest,
174    system: Option<&str>,
175    repo: Option<&str>,
176) -> Result<Vec<ProtectionDiff>> {
177    let repos = resolve_repos_with_branches(client, manifest, system, repo).await?;
178    let config = &manifest.branch_protection;
179
180    println!();
181    println!(
182        "  {} Scanning {} repositories...",
183        style("🔍").bold(),
184        repos.len()
185    );
186
187    let mut diffs = Vec::new();
188    for (repo_name, default_branch) in &repos {
189        let current = client
190            .get_branch_protection(repo_name, default_branch)
191            .await?
192            .unwrap_or_default();
193
194        diffs.push(diff_protection(repo_name, default_branch, &current, config));
195    }
196
197    Ok(diffs)
198}
199
200async fn plan(
201    client: &Client,
202    manifest: &Manifest,
203    system: Option<&str>,
204    repo: Option<&str>,
205) -> Result<()> {
206    let diffs = build_diffs(client, manifest, system, repo).await?;
207
208    print_diff_table(&diffs);
209
210    let needs_changes = diffs.iter().filter(|d| d.has_changes()).count();
211    if needs_changes > 0 {
212        println!(
213            "\n  Run {} to apply these changes.",
214            style("ward protection apply").cyan().bold()
215        );
216    }
217
218    Ok(())
219}
220
221async fn apply(
222    client: &Client,
223    manifest: &Manifest,
224    system: Option<&str>,
225    repo: Option<&str>,
226    yes: bool,
227) -> Result<()> {
228    let diffs = build_diffs(client, manifest, system, repo).await?;
229
230    let needs_changes = diffs.iter().filter(|d| d.has_changes()).count();
231    if needs_changes == 0 {
232        println!(
233            "\n  {} All repositories are up to date.",
234            style("✅").green()
235        );
236        return Ok(());
237    }
238
239    print_diff_table(&diffs);
240
241    if !yes {
242        let proceed = Confirm::new()
243            .with_prompt(format!(
244                "  Apply branch protection to {needs_changes} repositories?"
245            ))
246            .default(false)
247            .interact()?;
248
249        if !proceed {
250            println!("  Aborted.");
251            return Ok(());
252        }
253    }
254
255    println!();
256    println!("  {} Applying changes...", style("⚡").bold());
257
258    let audit_log = AuditLog::new()?;
259    let config = &manifest.branch_protection;
260    let mut succeeded = 0usize;
261    let mut failed: Vec<(String, String)> = Vec::new();
262
263    for diff in diffs.iter().filter(|d| d.has_changes()) {
264        match client
265            .update_branch_protection(&diff.repo, &diff.branch, config)
266            .await
267        {
268            Ok(()) => {
269                println!(
270                    "  {} {}/{}: ✅ done",
271                    style("▶").magenta(),
272                    diff.repo,
273                    diff.branch
274                );
275                audit_log.log(
276                    &diff.repo,
277                    "update_branch_protection",
278                    "success",
279                    false,
280                    true,
281                )?;
282                succeeded += 1;
283            }
284            Err(e) => {
285                println!(
286                    "  {} {}/{}: ❌ {e}",
287                    style("▶").magenta(),
288                    diff.repo,
289                    diff.branch
290                );
291                failed.push((diff.repo.clone(), e.to_string()));
292            }
293        }
294    }
295
296    println!();
297    if failed.is_empty() {
298        println!(
299            "  {} All {} repositories updated successfully.",
300            style("✅").green(),
301            succeeded
302        );
303    } else {
304        println!(
305            "  {} {} succeeded, {} failed:",
306            style("⚠️").yellow(),
307            succeeded,
308            failed.len()
309        );
310        for (repo, err) in &failed {
311            println!("    {} {}: {}", style("❌").red(), repo, err);
312        }
313    }
314
315    println!(
316        "\n  {} Audit log: {}",
317        style("📋").bold(),
318        audit_log.path().display()
319    );
320
321    Ok(())
322}
323
324async fn audit(
325    client: &Client,
326    manifest: &Manifest,
327    system: Option<&str>,
328    repo: Option<&str>,
329) -> Result<()> {
330    let repos = resolve_repos_with_branches(client, manifest, system, repo).await?;
331
332    println!();
333    println!(
334        "  {} Auditing branch protection for {} repositories...",
335        style("🔍").bold(),
336        repos.len()
337    );
338
339    println!();
340    println!(
341        "  {:40} {:8} {:10} {:10} {:10} {:10} {:10} {:10}",
342        style("Repository").bold().underlined(),
343        style("Branch").bold().underlined(),
344        style("PR Rev").bold().underlined(),
345        style("Approvals").bold().underlined(),
346        style("Stale").bold().underlined(),
347        style("Admins").bold().underlined(),
348        style("Linear").bold().underlined(),
349        style("Force").bold().underlined(),
350    );
351
352    let mut total_ok = 0;
353    let mut total_issues = 0;
354
355    for (repo_name, default_branch) in &repos {
356        let state = client
357            .get_branch_protection(repo_name, default_branch)
358            .await?
359            .unwrap_or_default();
360
361        let protected = state.required_pull_request_reviews;
362        if protected {
363            total_ok += 1;
364        } else {
365            total_issues += 1;
366        }
367
368        let icon = |v: bool| {
369            if v {
370                format!("{}", style("✅").green())
371            } else {
372                format!("{}", style("❌").red())
373            }
374        };
375
376        println!(
377            "  {:40} {:8} {:10} {:10} {:10} {:10} {:10} {:10}",
378            repo_name,
379            default_branch,
380            icon(state.required_pull_request_reviews),
381            state.required_approving_review_count,
382            icon(state.dismiss_stale_reviews),
383            icon(state.enforce_admins),
384            icon(state.required_linear_history),
385            icon(state.allow_force_pushes),
386        );
387    }
388
389    println!();
390    println!(
391        "  Summary: {} protected, {} unprotected",
392        style(total_ok).green().bold(),
393        if total_issues > 0 {
394            style(total_issues).red().bold()
395        } else {
396            style(total_issues).green().bold()
397        }
398    );
399
400    Ok(())
401}
402
403fn print_diff_table(diffs: &[ProtectionDiff]) {
404    println!();
405    println!("  {}", style("Branch Protection Plan").bold().cyan());
406    println!("  {}", style("─".repeat(60)).dim());
407
408    for diff in diffs {
409        if diff.has_changes() {
410            println!(
411                "  {} {} ({})",
412                style("⚡").yellow(),
413                style(&diff.repo).bold(),
414                diff.branch
415            );
416            for change in &diff.changes {
417                println!(
418                    "     {}: {} → {}",
419                    change.field,
420                    style(&change.current).red(),
421                    style(&change.desired).green().bold()
422                );
423            }
424        } else {
425            println!("  {} {}", style("✓").green(), style(&diff.repo).dim());
426        }
427    }
428
429    let needs_changes = diffs.iter().filter(|d| d.has_changes()).count();
430    let up_to_date = diffs.len() - needs_changes;
431
432    println!();
433    println!(
434        "  Summary: {} need changes, {} up to date",
435        if needs_changes > 0 {
436            style(needs_changes).yellow().bold()
437        } else {
438            style(needs_changes).green().bold()
439        },
440        style(up_to_date).green()
441    );
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::manifest::BranchProtectionConfig;

    /// Live state with every switch off and one required approval.
    fn default_state() -> BranchProtectionState {
        BranchProtectionState {
            required_pull_request_reviews: false,
            required_approving_review_count: 1,
            dismiss_stale_reviews: false,
            require_code_owner_reviews: false,
            required_status_checks: false,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: false,
            allow_deletions: false,
        }
    }

    /// Manifest config that matches `default_state()` exactly.
    fn default_config() -> BranchProtectionConfig {
        BranchProtectionConfig {
            enabled: false,
            required_approvals: 1,
            dismiss_stale_reviews: false,
            require_code_owner_reviews: false,
            require_status_checks: false,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: false,
            allow_deletions: false,
        }
    }

    #[test]
    fn no_changes_when_state_matches_config() {
        let diff = diff_protection("my-repo", "main", &default_state(), &default_config());
        assert!(!diff.has_changes());
    }

    #[test]
    fn all_fields_produce_changes_when_they_differ() {
        // Every field is flipped relative to `default_state()`.
        let config = BranchProtectionConfig {
            enabled: true,
            required_approvals: 2,
            dismiss_stale_reviews: true,
            require_code_owner_reviews: true,
            require_status_checks: true,
            strict_status_checks: true,
            enforce_admins: true,
            required_linear_history: true,
            allow_force_pushes: true,
            allow_deletions: true,
        };
        let diff = diff_protection("my-repo", "main", &default_state(), &config);
        assert_eq!(diff.changes.len(), 10);
    }

    #[test]
    fn partial_changes_detected() {
        let config = BranchProtectionConfig {
            enforce_admins: true,
            required_approvals: 3,
            ..default_config()
        };

        let diff = diff_protection("my-repo", "main", &default_state(), &config);
        assert_eq!(diff.changes.len(), 2);

        let changed = |name: &str| diff.changes.iter().any(|c| c.field == name);
        assert!(changed("enforce_admins"));
        assert!(changed("required_approvals"));
    }

    #[test]
    fn repo_and_branch_preserved() {
        let diff = diff_protection(
            "acme-service",
            "develop",
            &default_state(),
            &default_config(),
        );
        assert_eq!(diff.repo, "acme-service");
        assert_eq!(diff.branch, "develop");
    }
}