1use std::collections::HashMap;
2
3use anyhow::Result;
4use clap::Args;
5use console::style;
6
7use crate::github::Client;
8use crate::github::branch_protection::BranchProtectionState;
9use crate::github::repos::Repository;
10use crate::github::security::SecurityState;
11
12#[derive(Args)]
13pub struct ImportCommand {
14 #[arg(long, required = true)]
16 org: String,
17
18 #[arg(long)]
20 stdout: bool,
21
22 #[arg(long, default_value_t = 2)]
24 min_group_size: usize,
25
26 #[arg(long, default_value_t = 5)]
28 parallelism: usize,
29}
30
/// A group of repositories whose names share a prefix, as produced by `detect_systems`.
#[derive(Debug, Clone)]
struct DetectedSystem {
    /// The shared prefix: the part of each member repo name before its first `-`.
    id: String,
    /// Names of the member repositories (`detect_systems` returns them sorted).
    repos: Vec<String>,
}
36
/// Representative security settings reduced by majority vote from sampled
/// repositories (`majority_vote_security`, then `merge_security_samples`).
///
/// Each flag mirrors the same-named field of `SecurityState` and is emitted in
/// the `[security]` section of the generated TOML. `Default` (all `false`) is
/// used when nothing could be sampled.
#[derive(Debug, Clone, Default)]
struct SampledSecurity {
    secret_scanning: bool,
    push_protection: bool,
    dependabot_alerts: bool,
    dependabot_security_updates: bool,
    secret_scanning_ai_detection: bool,
}
45
/// Representative branch-protection settings reduced from sampled repositories
/// by `majority_vote_protection`; emitted as the `[branch_protection]` section.
#[derive(Debug, Clone, Default)]
struct SampledProtection {
    /// Majority of sampled states with `required_pull_request_reviews` set.
    enabled: bool,
    /// Median of the sampled `required_approving_review_count` values.
    required_approvals: u32,
    /// Majority vote of the sampled `dismiss_stale_reviews` flags.
    dismiss_stale_reviews: bool,
    /// Majority vote of the sampled `require_code_owner_reviews` flags.
    require_code_owner_reviews: bool,
    /// Majority vote of the sampled `required_status_checks` flags.
    require_status_checks: bool,
    /// Majority vote of the sampled `strict_status_checks` flags.
    strict_status_checks: bool,
    /// Majority vote of the sampled `enforce_admins` flags.
    enforce_admins: bool,
    /// Majority vote of the sampled `required_linear_history` flags.
    required_linear_history: bool,
    /// Majority vote of the sampled `allow_force_pushes` flags.
    allow_force_pushes: bool,
    /// Majority vote of the sampled `allow_deletions` flags.
    allow_deletions: bool,
}
59
60impl ImportCommand {
61 pub async fn run(self) -> Result<()> {
62 let client = Client::new(&self.org, self.parallelism).await?;
63
64 println!(
65 "\n {} Fetching repositories for {}...",
66 style("[..]").dim(),
67 style(&self.org).cyan().bold()
68 );
69
70 let repos = client.list_repos().await?;
71 let active: Vec<&Repository> = repos.iter().filter(|r| !r.archived).collect();
72
73 println!(
74 " {} Found {} repositories ({} active)",
75 style("[ok]").green(),
76 repos.len(),
77 active.len()
78 );
79
80 let active_names: Vec<String> = active.iter().map(|r| r.name.clone()).collect();
81 let systems = detect_systems(&active_names, self.min_group_size);
82
83 println!(
84 " {} Detected {} systems",
85 style("[ok]").green(),
86 systems.len()
87 );
88 for sys in &systems {
89 println!(
90 " - {} ({} repos)",
91 style(&sys.id).bold(),
92 sys.repos.len()
93 );
94 }
95
96 let grouped: Vec<&str> = systems
97 .iter()
98 .flat_map(|s| s.repos.iter().map(String::as_str))
99 .collect();
100 let ungrouped: Vec<&str> = active_names
101 .iter()
102 .filter(|n| !grouped.contains(&n.as_str()))
103 .map(String::as_str)
104 .collect();
105
106 println!(
107 "\n {} Sampling security and branch protection...",
108 style("[..]").dim()
109 );
110
111 let repo_map: HashMap<&str, &Repository> =
112 active.iter().map(|r| (r.name.as_str(), *r)).collect();
113
114 let mut system_security: HashMap<String, SampledSecurity> = HashMap::new();
115 let mut global_protection = SampledProtection::default();
116 let mut sampled_any_protection = false;
117
118 for sys in &systems {
119 let sample: Vec<&str> = sys.repos.iter().take(5).map(String::as_str).collect();
120 let mut sec_states = Vec::new();
121 let mut prot_states = Vec::new();
122
123 for repo_name in &sample {
124 if let Ok(sec) = client.get_security_state(repo_name).await {
125 sec_states.push(sec);
126 }
127 if let Some(repo) = repo_map.get(repo_name) {
128 if let Ok(Some(prot)) = client
129 .get_branch_protection(repo_name, &repo.default_branch)
130 .await
131 {
132 prot_states.push(prot);
133 }
134 }
135 }
136
137 if !sec_states.is_empty() {
138 system_security.insert(sys.id.clone(), majority_vote_security(&sec_states));
139 }
140
141 if !prot_states.is_empty() && !sampled_any_protection {
142 global_protection = majority_vote_protection(&prot_states);
143 sampled_any_protection = true;
144 }
145 }
146
147 let global_sec = if system_security.is_empty() {
148 SampledSecurity::default()
149 } else {
150 merge_security_samples(system_security.values())
151 };
152
153 let team_map = sample_teams(&client, &systems).await;
154
155 let toml_output = generate_toml(
156 &self.org,
157 &systems,
158 &ungrouped,
159 &global_sec,
160 &global_protection,
161 sampled_any_protection,
162 &team_map,
163 );
164
165 if self.stdout {
166 println!("{toml_output}");
167 } else {
168 let path = "ward.toml";
169 if std::path::Path::new(path).exists() {
170 anyhow::bail!(
171 "ward.toml already exists. Use --stdout to print instead, or remove the file first."
172 );
173 }
174 std::fs::write(path, &toml_output)?;
175 println!("\n {} Wrote {}", style("[ok]").green(), style(path).bold());
176 }
177
178 println!(
179 "\n {} Import complete. Review the generated config and adjust as needed.",
180 style("[ok]").green()
181 );
182
183 Ok(())
184 }
185}
186
187fn detect_systems(repo_names: &[String], min_group_size: usize) -> Vec<DetectedSystem> {
188 let mut groups: HashMap<String, Vec<String>> = HashMap::new();
189
190 for name in repo_names {
191 if let Some(prefix) = name.split('-').next() {
192 if !prefix.is_empty() && prefix != name {
193 groups
194 .entry(prefix.to_string())
195 .or_default()
196 .push(name.clone());
197 }
198 }
199 }
200
201 let mut systems: Vec<DetectedSystem> = groups
202 .into_iter()
203 .filter(|(_, repos)| repos.len() >= min_group_size)
204 .map(|(id, mut repos)| {
205 repos.sort();
206 DetectedSystem { id, repos }
207 })
208 .collect();
209
210 systems.sort_by(|a, b| a.id.cmp(&b.id));
211 systems
212}
213
214fn majority_vote_security(states: &[SecurityState]) -> SampledSecurity {
215 let n = states.len();
216 let threshold = n / 2 + 1;
217
218 SampledSecurity {
219 secret_scanning: states.iter().filter(|s| s.secret_scanning).count() >= threshold,
220 push_protection: states.iter().filter(|s| s.push_protection).count() >= threshold,
221 dependabot_alerts: states.iter().filter(|s| s.dependabot_alerts).count() >= threshold,
222 dependabot_security_updates: states
223 .iter()
224 .filter(|s| s.dependabot_security_updates)
225 .count()
226 >= threshold,
227 secret_scanning_ai_detection: states
228 .iter()
229 .filter(|s| s.secret_scanning_ai_detection)
230 .count()
231 >= threshold,
232 }
233}
234
235fn majority_vote_protection(states: &[BranchProtectionState]) -> SampledProtection {
236 let n = states.len();
237 let threshold = n / 2 + 1;
238
239 let approvals: Vec<u32> = states
240 .iter()
241 .map(|s| s.required_approving_review_count)
242 .collect();
243 let median_approvals = {
244 let mut sorted = approvals.clone();
245 sorted.sort();
246 sorted[sorted.len() / 2]
247 };
248
249 SampledProtection {
250 enabled: states
251 .iter()
252 .filter(|s| s.required_pull_request_reviews)
253 .count()
254 >= threshold,
255 required_approvals: median_approvals,
256 dismiss_stale_reviews: states.iter().filter(|s| s.dismiss_stale_reviews).count()
257 >= threshold,
258 require_code_owner_reviews: states
259 .iter()
260 .filter(|s| s.require_code_owner_reviews)
261 .count()
262 >= threshold,
263 require_status_checks: states.iter().filter(|s| s.required_status_checks).count()
264 >= threshold,
265 strict_status_checks: states.iter().filter(|s| s.strict_status_checks).count() >= threshold,
266 enforce_admins: states.iter().filter(|s| s.enforce_admins).count() >= threshold,
267 required_linear_history: states.iter().filter(|s| s.required_linear_history).count()
268 >= threshold,
269 allow_force_pushes: states.iter().filter(|s| s.allow_force_pushes).count() >= threshold,
270 allow_deletions: states.iter().filter(|s| s.allow_deletions).count() >= threshold,
271 }
272}
273
274fn merge_security_samples<'a>(
275 samples: impl Iterator<Item = &'a SampledSecurity>,
276) -> SampledSecurity {
277 let all: Vec<&SampledSecurity> = samples.collect();
278 let n = all.len();
279 let threshold = n / 2 + 1;
280
281 SampledSecurity {
282 secret_scanning: all.iter().filter(|s| s.secret_scanning).count() >= threshold,
283 push_protection: all.iter().filter(|s| s.push_protection).count() >= threshold,
284 dependabot_alerts: all.iter().filter(|s| s.dependabot_alerts).count() >= threshold,
285 dependabot_security_updates: all.iter().filter(|s| s.dependabot_security_updates).count()
286 >= threshold,
287 secret_scanning_ai_detection: all
288 .iter()
289 .filter(|s| s.secret_scanning_ai_detection)
290 .count()
291 >= threshold,
292 }
293}
294
295async fn sample_teams(
296 client: &Client,
297 systems: &[DetectedSystem],
298) -> HashMap<String, Vec<(String, String)>> {
299 let mut team_map: HashMap<String, Vec<(String, String)>> = HashMap::new();
300
301 for sys in systems {
302 if let Some(repo_name) = sys.repos.first() {
303 if let Ok(teams) = client.list_repo_teams(repo_name).await {
304 let entries: Vec<(String, String)> =
305 teams.into_iter().map(|t| (t.slug, t.permission)).collect();
306 if !entries.is_empty() {
307 team_map.insert(sys.id.clone(), entries);
308 }
309 }
310 }
311 }
312
313 team_map
314}
315
316fn generate_toml(
317 org: &str,
318 systems: &[DetectedSystem],
319 ungrouped: &[&str],
320 security: &SampledSecurity,
321 protection: &SampledProtection,
322 has_protection: bool,
323 team_map: &HashMap<String, Vec<(String, String)>>,
324) -> String {
325 let mut out = String::new();
326
327 out.push_str(&format!("# Ward configuration -- imported from {org}\n\n"));
328
329 out.push_str(&format!("[org]\nname = \"{org}\"\n\n"));
330
331 out.push_str("# Security settings (sampled from existing repos)\n");
332 out.push_str("[security]\n");
333 out.push_str(&format!("secret_scanning = {}\n", security.secret_scanning));
334 out.push_str(&format!(
335 "secret_scanning_ai_detection = {}\n",
336 security.secret_scanning_ai_detection
337 ));
338 out.push_str(&format!("push_protection = {}\n", security.push_protection));
339 out.push_str(&format!(
340 "dependabot_alerts = {}\n",
341 security.dependabot_alerts
342 ));
343 out.push_str(&format!(
344 "dependabot_security_updates = {}\n",
345 security.dependabot_security_updates
346 ));
347 out.push('\n');
348
349 if has_protection {
350 out.push_str("# Branch protection (sampled from existing repos)\n");
351 out.push_str("[branch_protection]\n");
352 out.push_str(&format!("enabled = {}\n", protection.enabled));
353 out.push_str(&format!(
354 "required_approvals = {}\n",
355 protection.required_approvals
356 ));
357 out.push_str(&format!(
358 "dismiss_stale_reviews = {}\n",
359 protection.dismiss_stale_reviews
360 ));
361 out.push_str(&format!(
362 "require_code_owner_reviews = {}\n",
363 protection.require_code_owner_reviews
364 ));
365 out.push_str(&format!(
366 "require_status_checks = {}\n",
367 protection.require_status_checks
368 ));
369 out.push_str(&format!(
370 "strict_status_checks = {}\n",
371 protection.strict_status_checks
372 ));
373 out.push_str(&format!("enforce_admins = {}\n", protection.enforce_admins));
374 out.push_str(&format!(
375 "required_linear_history = {}\n",
376 protection.required_linear_history
377 ));
378 out.push_str(&format!(
379 "allow_force_pushes = {}\n",
380 protection.allow_force_pushes
381 ));
382 out.push_str(&format!(
383 "allow_deletions = {}\n",
384 protection.allow_deletions
385 ));
386 out.push('\n');
387 }
388
389 for sys in systems {
390 out.push_str(&format!("# Detected system: {} repos\n", sys.repos.len()));
391 out.push_str("[[systems]]\n");
392 out.push_str(&format!("id = \"{}\"\n", sys.id));
393 out.push_str(&format!("name = \"{}\"\n", titlecase(&sys.id)));
394
395 if let Some(teams) = team_map.get(&sys.id) {
396 out.push_str("teams = [\n");
397 for (slug, perm) in teams {
398 out.push_str(&format!(
399 " {{ slug = \"{slug}\", permission = \"{perm}\" }},\n"
400 ));
401 }
402 out.push_str("]\n");
403 }
404
405 out.push('\n');
406 }
407
408 if !ungrouped.is_empty() {
409 out.push_str("# Ungrouped repositories (did not match any system prefix)\n");
410 for name in ungrouped {
411 out.push_str(&format!("# - {name}\n"));
412 }
413 out.push('\n');
414 }
415
416 out
417}
418
/// Returns `s` with its first character upper-cased; the remainder is untouched.
///
/// Uses Unicode case mapping, so one character may expand (e.g. `ß` → `SS`).
/// An empty input yields an empty string.
fn titlecase(s: &str) -> String {
    s.chars()
        .next()
        .map(|head| {
            let tail = &s[head.len_utf8()..];
            head.to_uppercase().chain(tail.chars()).collect()
        })
        .unwrap_or_default()
}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned repo-name list `detect_systems` takes.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Looks up a detected system by id; panics if it is missing.
    fn system<'a>(systems: &'a [DetectedSystem], id: &str) -> &'a DetectedSystem {
        systems.iter().find(|s| s.id == id).unwrap()
    }

    #[test]
    fn test_detect_systems_groups_by_prefix() {
        let repos = owned(&[
            "backend-api",
            "backend-auth",
            "backend-common",
            "frontend-web",
            "frontend-mobile",
            "standalone",
        ]);

        let systems = detect_systems(&repos, 2);
        assert_eq!(systems.len(), 2);

        let backend = system(&systems, "backend");
        assert_eq!(backend.repos.len(), 3);
        for expected in ["backend-api", "backend-auth", "backend-common"] {
            assert!(backend.repos.iter().any(|r| r == expected));
        }

        let frontend = system(&systems, "frontend");
        assert_eq!(frontend.repos.len(), 2);
    }

    #[test]
    fn test_detect_systems_respects_min_group_size() {
        let repos = owned(&["backend-api", "backend-auth", "frontend-web"]);

        let at_least_two = detect_systems(&repos, 2);
        assert_eq!(at_least_two.len(), 1);
        assert_eq!(at_least_two[0].id, "backend");

        assert!(detect_systems(&repos, 3).is_empty());
    }

    #[test]
    fn test_majority_vote_security() {
        let states = [
            SecurityState {
                secret_scanning: true,
                push_protection: true,
                dependabot_alerts: true,
                dependabot_security_updates: false,
                secret_scanning_ai_detection: false,
            },
            SecurityState {
                secret_scanning: true,
                push_protection: false,
                dependabot_alerts: true,
                dependabot_security_updates: false,
                secret_scanning_ai_detection: true,
            },
            SecurityState {
                secret_scanning: true,
                push_protection: true,
                dependabot_alerts: false,
                dependabot_security_updates: false,
                secret_scanning_ai_detection: false,
            },
        ];

        let result = majority_vote_security(&states);
        assert!(result.secret_scanning);
        assert!(result.push_protection);
        assert!(result.dependabot_alerts);
        assert!(!result.dependabot_security_updates);
        assert!(!result.secret_scanning_ai_detection);
    }

    #[test]
    fn test_generate_toml_output() {
        let systems = [DetectedSystem {
            id: "backend".to_string(),
            repos: owned(&["backend-api", "backend-auth"]),
        }];
        let security = SampledSecurity {
            secret_scanning: true,
            push_protection: true,
            dependabot_alerts: true,
            dependabot_security_updates: false,
            secret_scanning_ai_detection: false,
        };
        let protection = SampledProtection {
            enabled: true,
            required_approvals: 1,
            ..Default::default()
        };
        let team_map = HashMap::new();

        let toml = generate_toml(
            "my-org",
            &systems,
            &["standalone"],
            &security,
            &protection,
            true,
            &team_map,
        );

        for needle in [
            "[org]",
            "name = \"my-org\"",
            "secret_scanning = true",
            "dependabot_security_updates = false",
            "[[systems]]",
            "id = \"backend\"",
            "enabled = true",
            "required_approvals = 1",
            "# - standalone",
        ] {
            assert!(toml.contains(needle), "missing {needle:?} in:\n{toml}");
        }
    }

    #[test]
    fn test_detect_systems_excludes_single_segment_names() {
        let repos = owned(&["standalone", "another", "third"]);
        assert!(detect_systems(&repos, 2).is_empty());
    }

    #[test]
    fn test_majority_vote_protection() {
        let states = [
            BranchProtectionState {
                required_pull_request_reviews: true,
                required_approving_review_count: 2,
                dismiss_stale_reviews: true,
                ..Default::default()
            },
            BranchProtectionState {
                required_pull_request_reviews: true,
                required_approving_review_count: 1,
                dismiss_stale_reviews: false,
                ..Default::default()
            },
            BranchProtectionState {
                required_pull_request_reviews: false,
                required_approving_review_count: 1,
                dismiss_stale_reviews: true,
                ..Default::default()
            },
        ];

        let result = majority_vote_protection(&states);
        assert!(result.enabled);
        assert_eq!(result.required_approvals, 1);
        assert!(result.dismiss_stale_reviews);
    }
}