Skip to main content

ward/cli/
rollback.rs

1use anyhow::Result;
2use clap::Args;
3use console::style;
4
5use crate::engine::audit_log::{self, AuditEntry};
6use crate::github::Client;
7
8#[derive(Args)]
9pub struct RollbackCommand {
10    /// Show last N audit entries
11    #[arg(long, default_value_t = 10)]
12    last: usize,
13
14    /// Filter to a specific repository
15    #[arg(long)]
16    repo: Option<String>,
17
18    /// Show what would be reversed without applying
19    #[arg(long)]
20    dry_run: bool,
21
22    /// Skip confirmation prompt
23    #[arg(long)]
24    yes: bool,
25}
26
27impl RollbackCommand {
28    pub async fn run(&self, client: &Client) -> Result<()> {
29        let log_path = audit_log::default_log_path()?;
30
31        if !log_path.exists() {
32            println!(
33                "\n  {} No audit log found at {}",
34                style("⚠️").yellow(),
35                log_path.display()
36            );
37            return Ok(());
38        }
39
40        let entries = audit_log::read_entries(&log_path)?;
41        let successful: Vec<&AuditEntry> = entries
42            .iter()
43            .filter(|e| e.status == "success")
44            .filter(|e| self.repo.as_ref().is_none_or(|r| e.repo == *r))
45            .collect();
46
47        let to_process: Vec<&AuditEntry> =
48            successful.iter().rev().take(self.last).copied().collect();
49
50        if to_process.is_empty() {
51            println!(
52                "\n  {} No matching audit entries found.",
53                style("ℹ️").blue()
54            );
55            return Ok(());
56        }
57
58        println!();
59        println!(
60            "  {} Rollback candidates ({} entries):",
61            style("🔄").bold(),
62            to_process.len()
63        );
64        println!();
65
66        let mut reversible = Vec::new();
67        let mut skipped = Vec::new();
68
69        for entry in &to_process {
70            match classify_rollback(entry) {
71                RollbackAction::Reverse(desc) => {
72                    println!(
73                        "  {} {} / {} / {}",
74                        style("⚡").yellow(),
75                        entry.repo,
76                        entry.action,
77                        desc
78                    );
79                    reversible.push((*entry, desc));
80                }
81                RollbackAction::Skip(reason) => {
82                    println!(
83                        "  {} {} / {} / {}",
84                        style("⏭").dim(),
85                        style(&entry.repo).dim(),
86                        style(&entry.action).dim(),
87                        style(&reason).dim()
88                    );
89                    skipped.push((*entry, reason));
90                }
91            }
92        }
93
94        println!();
95        println!(
96            "  {} reversible, {} skipped",
97            style(reversible.len()).yellow().bold(),
98            style(skipped.len()).dim()
99        );
100
101        if reversible.is_empty() {
102            println!("\n  {} Nothing to rollback.", style("ℹ️").blue());
103            return Ok(());
104        }
105
106        if self.dry_run {
107            println!("\n  {} Dry run - no changes applied.", style("ℹ️").blue());
108            return Ok(());
109        }
110
111        if !self.yes {
112            let proceed = dialoguer::Confirm::new()
113                .with_prompt(format!("  Rollback {} entries?", reversible.len()))
114                .default(false)
115                .interact()?;
116
117            if !proceed {
118                println!("  Aborted.");
119                return Ok(());
120            }
121        }
122
123        println!();
124        println!("  {} Rolling back...", style("⚡").bold());
125
126        let mut succeeded = 0usize;
127        let mut failed: Vec<(String, String)> = Vec::new();
128
129        for (entry, _desc) in &reversible {
130            match execute_rollback(client, entry).await {
131                Ok(()) => {
132                    println!(
133                        "  {} {}/{}: ✅ rolled back",
134                        style("▶").magenta(),
135                        entry.repo,
136                        entry.action
137                    );
138                    succeeded += 1;
139                }
140                Err(e) => {
141                    println!(
142                        "  {} {}/{}: ❌ {e}",
143                        style("▶").magenta(),
144                        entry.repo,
145                        entry.action
146                    );
147                    failed.push((entry.repo.clone(), e.to_string()));
148                }
149            }
150        }
151
152        println!();
153        if failed.is_empty() {
154            println!(
155                "  {} All {} entries rolled back successfully.",
156                style("✅").green(),
157                succeeded
158            );
159        } else {
160            println!(
161                "  {} {} succeeded, {} failed:",
162                style("⚠️").yellow(),
163                succeeded,
164                failed.len()
165            );
166            for (repo, err) in &failed {
167                println!("    {} {}: {}", style("❌").red(), repo, err);
168            }
169        }
170
171        Ok(())
172    }
173}
174
/// Result of classifying an audit entry for rollback.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RollbackAction {
    /// The entry can be reversed; the payload is a short description of the reversal.
    Reverse(String),
    /// The entry is left untouched; the payload is the reason shown to the user.
    Skip(String),
}
179
180impl RollbackAction {
181    #[cfg(test)]
182    fn is_reverse(&self) -> bool {
183        matches!(self, RollbackAction::Reverse(_))
184    }
185
186    #[cfg(test)]
187    fn is_skip(&self) -> bool {
188        matches!(self, RollbackAction::Skip(_))
189    }
190}
191
192fn classify_rollback(entry: &AuditEntry) -> RollbackAction {
193    match entry.action.as_str() {
194        "set_secret_scanning" if entry.after == serde_json::Value::Bool(true) => {
195            RollbackAction::Reverse("disable secret scanning".to_string())
196        }
197        "set_push_protection" if entry.after == serde_json::Value::Bool(true) => {
198            RollbackAction::Reverse("disable push protection".to_string())
199        }
200        "set_secret_scanning_ai_detection" if entry.after == serde_json::Value::Bool(true) => {
201            RollbackAction::Reverse("disable secret scanning AI detection".to_string())
202        }
203        "enable_dependabot_alerts" => {
204            RollbackAction::Skip("disabling Dependabot alerts not supported via API".to_string())
205        }
206        "enable_dependabot_security_updates" => RollbackAction::Skip(
207            "disabling Dependabot security updates not supported via API".to_string(),
208        ),
209        "create_copilot_review_ruleset" => RollbackAction::Skip(
210            "ruleset deletion requires ruleset ID - manual removal needed".to_string(),
211        ),
212        "deploy_copilot_instructions" => {
213            RollbackAction::Skip("file deletion not supported - remove via PR".to_string())
214        }
215        "update_branch_protection" => RollbackAction::Skip(
216            "branch protection rollback not supported - re-run with desired config".to_string(),
217        ),
218        _ => RollbackAction::Skip(format!("unknown action: {}", entry.action)),
219    }
220}
221
222async fn execute_rollback(client: &Client, entry: &AuditEntry) -> Result<()> {
223    match entry.action.as_str() {
224        "set_secret_scanning" => {
225            client
226                .set_security_features(&entry.repo, false, true, true)
227                .await
228        }
229        "set_push_protection" => {
230            client
231                .set_security_features(&entry.repo, true, true, false)
232                .await
233        }
234        "set_secret_scanning_ai_detection" => {
235            client
236                .set_security_features(&entry.repo, true, false, true)
237                .await
238        }
239        _ => anyhow::bail!("Cannot rollback action: {}", entry.action),
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful audit entry for `test-repo` with the given action
    /// and `after` value; the remaining fields are left empty or null.
    fn entry(action: &str, after: serde_json::Value) -> AuditEntry {
        AuditEntry {
            timestamp: String::new(),
            repo: "test-repo".to_string(),
            action: action.to_string(),
            status: "success".to_string(),
            before: serde_json::Value::Null,
            after,
        }
    }

    #[test]
    fn secret_scanning_enabled_is_reverse() {
        let e = entry("set_secret_scanning", serde_json::Value::Bool(true));
        assert!(classify_rollback(&e).is_reverse());
    }

    #[test]
    fn secret_scanning_not_enabled_is_skip() {
        let e = entry("set_secret_scanning", serde_json::Value::Bool(false));
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn push_protection_enabled_is_reverse() {
        let e = entry("set_push_protection", serde_json::Value::Bool(true));
        assert!(classify_rollback(&e).is_reverse());
    }

    #[test]
    fn push_protection_not_enabled_is_skip() {
        let e = entry("set_push_protection", serde_json::Value::Bool(false));
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn ai_detection_enabled_is_reverse() {
        let e = entry(
            "set_secret_scanning_ai_detection",
            serde_json::Value::Bool(true),
        );
        assert!(classify_rollback(&e).is_reverse());
    }

    #[test]
    fn ai_detection_not_enabled_is_skip() {
        let e = entry(
            "set_secret_scanning_ai_detection",
            serde_json::Value::Bool(false),
        );
        assert!(classify_rollback(&e).is_skip());
    }

    // A toggle entry whose `after` is not a boolean at all must not be reversed.
    #[test]
    fn toggle_with_non_boolean_after_is_skip() {
        let e = entry("set_secret_scanning", serde_json::Value::Null);
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn dependabot_alerts_is_skip() {
        let e = entry("enable_dependabot_alerts", serde_json::Value::Null);
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn dependabot_security_updates_is_skip() {
        let e = entry(
            "enable_dependabot_security_updates",
            serde_json::Value::Null,
        );
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn create_copilot_review_ruleset_is_skip() {
        let e = entry("create_copilot_review_ruleset", serde_json::Value::Null);
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn deploy_copilot_instructions_is_skip() {
        let e = entry("deploy_copilot_instructions", serde_json::Value::Null);
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn update_branch_protection_is_skip() {
        let e = entry("update_branch_protection", serde_json::Value::Null);
        assert!(classify_rollback(&e).is_skip());
    }

    #[test]
    fn unknown_action_is_skip() {
        let e = entry("some_future_action", serde_json::Value::Null);
        assert!(classify_rollback(&e).is_skip());
    }
}