Skip to main content

ward/cli/
import.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use clap::Args;
5use console::style;
6
7use crate::github::Client;
8use crate::github::branch_protection::BranchProtectionState;
9use crate::github::repos::Repository;
10use crate::github::security::SecurityState;
11
12#[derive(Args)]
13pub struct ImportCommand {
14    /// GitHub organization to import from
15    #[arg(long, required = true)]
16    org: String,
17
18    /// Output to stdout instead of ward.toml
19    #[arg(long)]
20    stdout: bool,
21
22    /// Minimum repos to form a system (default: 2)
23    #[arg(long, default_value_t = 2)]
24    min_group_size: usize,
25
26    /// Max concurrent API calls
27    #[arg(long, default_value_t = 5)]
28    parallelism: usize,
29}
30
/// A set of repositories that share a name prefix and are therefore reported
/// as one system in the generated configuration.
#[derive(Clone, Debug)]
struct DetectedSystem {
    /// The shared prefix; written out as the system's `id`.
    id: String,
    /// Full names of the repositories that belong to this system.
    repos: Vec<String>,
}
36
/// Security feature flags aggregated from a sample of repositories.
///
/// Each field is emitted under the `[security]` table of the generated
/// configuration; `Default` yields every flag switched off.
#[derive(Clone, Debug, Default)]
struct SampledSecurity {
    /// Written as `secret_scanning`.
    secret_scanning: bool,
    /// Written as `push_protection`.
    push_protection: bool,
    /// Written as `dependabot_alerts`.
    dependabot_alerts: bool,
    /// Written as `dependabot_security_updates`.
    dependabot_security_updates: bool,
    /// Written as `secret_scanning_ai_detection`.
    secret_scanning_ai_detection: bool,
}
45
/// Branch-protection settings aggregated from a sample of repositories.
///
/// Each field is emitted under the `[branch_protection]` table of the
/// generated configuration; `Default` yields all flags off and zero approvals.
#[derive(Clone, Debug, Default)]
struct SampledProtection {
    /// Written as `enabled`.
    enabled: bool,
    /// Written as `required_approvals`.
    required_approvals: u32,
    /// Written as `dismiss_stale_reviews`.
    dismiss_stale_reviews: bool,
    /// Written as `require_code_owner_reviews`.
    require_code_owner_reviews: bool,
    /// Written as `require_status_checks`.
    require_status_checks: bool,
    /// Written as `strict_status_checks`.
    strict_status_checks: bool,
    /// Written as `enforce_admins`.
    enforce_admins: bool,
    /// Written as `required_linear_history`.
    required_linear_history: bool,
    /// Written as `allow_force_pushes`.
    allow_force_pushes: bool,
    /// Written as `allow_deletions`.
    allow_deletions: bool,
}
59
60impl ImportCommand {
61    pub async fn run(self) -> Result<()> {
62        let client = Client::new(&self.org, self.parallelism).await?;
63
64        println!(
65            "\n  {} Fetching repositories for {}...",
66            style("[..]").dim(),
67            style(&self.org).cyan().bold()
68        );
69
70        let repos = client.list_repos().await?;
71        let active: Vec<&Repository> = repos.iter().filter(|r| !r.archived).collect();
72
73        println!(
74            "  {} Found {} repositories ({} active)",
75            style("[ok]").green(),
76            repos.len(),
77            active.len()
78        );
79
80        let active_names: Vec<String> = active.iter().map(|r| r.name.clone()).collect();
81        let systems = detect_systems(&active_names, self.min_group_size);
82
83        println!(
84            "  {} Detected {} systems",
85            style("[ok]").green(),
86            systems.len()
87        );
88        for sys in &systems {
89            println!(
90                "    - {} ({} repos)",
91                style(&sys.id).bold(),
92                sys.repos.len()
93            );
94        }
95
96        let grouped: Vec<&str> = systems
97            .iter()
98            .flat_map(|s| s.repos.iter().map(String::as_str))
99            .collect();
100        let ungrouped: Vec<&str> = active_names
101            .iter()
102            .filter(|n| !grouped.contains(&n.as_str()))
103            .map(String::as_str)
104            .collect();
105
106        println!(
107            "\n  {} Sampling security and branch protection...",
108            style("[..]").dim()
109        );
110
111        let repo_map: HashMap<&str, &Repository> =
112            active.iter().map(|r| (r.name.as_str(), *r)).collect();
113
114        let mut system_security: HashMap<String, SampledSecurity> = HashMap::new();
115        let mut global_protection = SampledProtection::default();
116        let mut sampled_any_protection = false;
117
118        for sys in &systems {
119            let sample: Vec<&str> = sys.repos.iter().take(5).map(String::as_str).collect();
120            let mut sec_states = Vec::new();
121            let mut prot_states = Vec::new();
122
123            for repo_name in &sample {
124                if let Ok(sec) = client.get_security_state(repo_name).await {
125                    sec_states.push(sec);
126                }
127                if let Some(repo) = repo_map.get(repo_name) {
128                    if let Ok(Some(prot)) = client
129                        .get_branch_protection(repo_name, &repo.default_branch)
130                        .await
131                    {
132                        prot_states.push(prot);
133                    }
134                }
135            }
136
137            if !sec_states.is_empty() {
138                system_security.insert(sys.id.clone(), majority_vote_security(&sec_states));
139            }
140
141            if !prot_states.is_empty() && !sampled_any_protection {
142                global_protection = majority_vote_protection(&prot_states);
143                sampled_any_protection = true;
144            }
145        }
146
147        let global_sec = if system_security.is_empty() {
148            SampledSecurity::default()
149        } else {
150            merge_security_samples(system_security.values())
151        };
152
153        let team_map = sample_teams(&client, &systems).await;
154
155        let toml_output = generate_toml(
156            &self.org,
157            &systems,
158            &ungrouped,
159            &global_sec,
160            &global_protection,
161            sampled_any_protection,
162            &team_map,
163        );
164
165        if self.stdout {
166            println!("{toml_output}");
167        } else {
168            let path = "ward.toml";
169            if std::path::Path::new(path).exists() {
170                anyhow::bail!(
171                    "ward.toml already exists. Use --stdout to print instead, or remove the file first."
172                );
173            }
174            std::fs::write(path, &toml_output)?;
175            println!("\n  {} Wrote {}", style("[ok]").green(), style(path).bold());
176        }
177
178        println!(
179            "\n  {} Import complete. Review the generated config and adjust as needed.",
180            style("[ok]").green()
181        );
182
183        Ok(())
184    }
185}
186
187fn detect_systems(repo_names: &[String], min_group_size: usize) -> Vec<DetectedSystem> {
188    let mut groups: HashMap<String, Vec<String>> = HashMap::new();
189
190    for name in repo_names {
191        if let Some(prefix) = name.split('-').next() {
192            if !prefix.is_empty() && prefix != name {
193                groups
194                    .entry(prefix.to_string())
195                    .or_default()
196                    .push(name.clone());
197            }
198        }
199    }
200
201    let mut systems: Vec<DetectedSystem> = groups
202        .into_iter()
203        .filter(|(_, repos)| repos.len() >= min_group_size)
204        .map(|(id, mut repos)| {
205            repos.sort();
206            DetectedSystem { id, repos }
207        })
208        .collect();
209
210    systems.sort_by(|a, b| a.id.cmp(&b.id));
211    systems
212}
213
214fn majority_vote_security(states: &[SecurityState]) -> SampledSecurity {
215    let n = states.len();
216    let threshold = n / 2 + 1;
217
218    SampledSecurity {
219        secret_scanning: states.iter().filter(|s| s.secret_scanning).count() >= threshold,
220        push_protection: states.iter().filter(|s| s.push_protection).count() >= threshold,
221        dependabot_alerts: states.iter().filter(|s| s.dependabot_alerts).count() >= threshold,
222        dependabot_security_updates: states
223            .iter()
224            .filter(|s| s.dependabot_security_updates)
225            .count()
226            >= threshold,
227        secret_scanning_ai_detection: states
228            .iter()
229            .filter(|s| s.secret_scanning_ai_detection)
230            .count()
231            >= threshold,
232    }
233}
234
235fn majority_vote_protection(states: &[BranchProtectionState]) -> SampledProtection {
236    let n = states.len();
237    let threshold = n / 2 + 1;
238
239    let approvals: Vec<u32> = states
240        .iter()
241        .map(|s| s.required_approving_review_count)
242        .collect();
243    let median_approvals = {
244        let mut sorted = approvals.clone();
245        sorted.sort();
246        sorted[sorted.len() / 2]
247    };
248
249    SampledProtection {
250        enabled: states
251            .iter()
252            .filter(|s| s.required_pull_request_reviews)
253            .count()
254            >= threshold,
255        required_approvals: median_approvals,
256        dismiss_stale_reviews: states.iter().filter(|s| s.dismiss_stale_reviews).count()
257            >= threshold,
258        require_code_owner_reviews: states
259            .iter()
260            .filter(|s| s.require_code_owner_reviews)
261            .count()
262            >= threshold,
263        require_status_checks: states.iter().filter(|s| s.required_status_checks).count()
264            >= threshold,
265        strict_status_checks: states.iter().filter(|s| s.strict_status_checks).count() >= threshold,
266        enforce_admins: states.iter().filter(|s| s.enforce_admins).count() >= threshold,
267        required_linear_history: states.iter().filter(|s| s.required_linear_history).count()
268            >= threshold,
269        allow_force_pushes: states.iter().filter(|s| s.allow_force_pushes).count() >= threshold,
270        allow_deletions: states.iter().filter(|s| s.allow_deletions).count() >= threshold,
271    }
272}
273
274fn merge_security_samples<'a>(
275    samples: impl Iterator<Item = &'a SampledSecurity>,
276) -> SampledSecurity {
277    let all: Vec<&SampledSecurity> = samples.collect();
278    let n = all.len();
279    let threshold = n / 2 + 1;
280
281    SampledSecurity {
282        secret_scanning: all.iter().filter(|s| s.secret_scanning).count() >= threshold,
283        push_protection: all.iter().filter(|s| s.push_protection).count() >= threshold,
284        dependabot_alerts: all.iter().filter(|s| s.dependabot_alerts).count() >= threshold,
285        dependabot_security_updates: all.iter().filter(|s| s.dependabot_security_updates).count()
286            >= threshold,
287        secret_scanning_ai_detection: all
288            .iter()
289            .filter(|s| s.secret_scanning_ai_detection)
290            .count()
291            >= threshold,
292    }
293}
294
295async fn sample_teams(
296    client: &Client,
297    systems: &[DetectedSystem],
298) -> HashMap<String, Vec<(String, String)>> {
299    let mut team_map: HashMap<String, Vec<(String, String)>> = HashMap::new();
300
301    for sys in systems {
302        if let Some(repo_name) = sys.repos.first() {
303            if let Ok(teams) = client.list_repo_teams(repo_name).await {
304                let entries: Vec<(String, String)> =
305                    teams.into_iter().map(|t| (t.slug, t.permission)).collect();
306                if !entries.is_empty() {
307                    team_map.insert(sys.id.clone(), entries);
308                }
309            }
310        }
311    }
312
313    team_map
314}
315
316fn generate_toml(
317    org: &str,
318    systems: &[DetectedSystem],
319    ungrouped: &[&str],
320    security: &SampledSecurity,
321    protection: &SampledProtection,
322    has_protection: bool,
323    team_map: &HashMap<String, Vec<(String, String)>>,
324) -> String {
325    let mut out = String::new();
326
327    out.push_str(&format!("# Ward configuration -- imported from {org}\n\n"));
328
329    out.push_str(&format!("[org]\nname = \"{org}\"\n\n"));
330
331    out.push_str("# Security settings (sampled from existing repos)\n");
332    out.push_str("[security]\n");
333    out.push_str(&format!("secret_scanning = {}\n", security.secret_scanning));
334    out.push_str(&format!(
335        "secret_scanning_ai_detection = {}\n",
336        security.secret_scanning_ai_detection
337    ));
338    out.push_str(&format!("push_protection = {}\n", security.push_protection));
339    out.push_str(&format!(
340        "dependabot_alerts = {}\n",
341        security.dependabot_alerts
342    ));
343    out.push_str(&format!(
344        "dependabot_security_updates = {}\n",
345        security.dependabot_security_updates
346    ));
347    out.push('\n');
348
349    if has_protection {
350        out.push_str("# Branch protection (sampled from existing repos)\n");
351        out.push_str("[branch_protection]\n");
352        out.push_str(&format!("enabled = {}\n", protection.enabled));
353        out.push_str(&format!(
354            "required_approvals = {}\n",
355            protection.required_approvals
356        ));
357        out.push_str(&format!(
358            "dismiss_stale_reviews = {}\n",
359            protection.dismiss_stale_reviews
360        ));
361        out.push_str(&format!(
362            "require_code_owner_reviews = {}\n",
363            protection.require_code_owner_reviews
364        ));
365        out.push_str(&format!(
366            "require_status_checks = {}\n",
367            protection.require_status_checks
368        ));
369        out.push_str(&format!(
370            "strict_status_checks = {}\n",
371            protection.strict_status_checks
372        ));
373        out.push_str(&format!("enforce_admins = {}\n", protection.enforce_admins));
374        out.push_str(&format!(
375            "required_linear_history = {}\n",
376            protection.required_linear_history
377        ));
378        out.push_str(&format!(
379            "allow_force_pushes = {}\n",
380            protection.allow_force_pushes
381        ));
382        out.push_str(&format!(
383            "allow_deletions = {}\n",
384            protection.allow_deletions
385        ));
386        out.push('\n');
387    }
388
389    for sys in systems {
390        out.push_str(&format!("# Detected system: {} repos\n", sys.repos.len()));
391        out.push_str("[[systems]]\n");
392        out.push_str(&format!("id = \"{}\"\n", sys.id));
393        out.push_str(&format!("name = \"{}\"\n", titlecase(&sys.id)));
394
395        if let Some(teams) = team_map.get(&sys.id) {
396            out.push_str("teams = [\n");
397            for (slug, perm) in teams {
398                out.push_str(&format!(
399                    "    {{ slug = \"{slug}\", permission = \"{perm}\" }},\n"
400                ));
401            }
402            out.push_str("]\n");
403        }
404
405        out.push('\n');
406    }
407
408    if !ungrouped.is_empty() {
409        out.push_str("# Ungrouped repositories (did not match any system prefix)\n");
410        for name in ungrouped {
411            out.push_str(&format!("# - {name}\n"));
412        }
413        out.push('\n');
414    }
415
416    out
417}
418
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// An empty input yields an empty string. Upper-casing follows
/// `char::to_uppercase`, so a single character may expand to several.
fn titlecase(s: &str) -> String {
    if let Some(first) = s.chars().next() {
        let mut out: String = first.to_uppercase().collect();
        out.push_str(&s[first.len_utf8()..]);
        out
    } else {
        String::new()
    }
}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned repository names from string literals.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| name.to_string()).collect()
    }

    /// Finds a detected system by id; panics when it is missing.
    fn system<'a>(systems: &'a [DetectedSystem], id: &str) -> &'a DetectedSystem {
        systems.iter().find(|s| s.id == id).unwrap()
    }

    #[test]
    fn test_detect_systems_groups_by_prefix() {
        let repos = names(&[
            "backend-api",
            "backend-auth",
            "backend-common",
            "frontend-web",
            "frontend-mobile",
            "standalone",
        ]);

        let systems = detect_systems(&repos, 2);
        assert_eq!(systems.len(), 2);

        let be = system(&systems, "backend");
        assert_eq!(be.repos.len(), 3);
        for expected in names(&["backend-api", "backend-auth", "backend-common"]) {
            assert!(be.repos.contains(&expected));
        }

        let fe = system(&systems, "frontend");
        assert_eq!(fe.repos.len(), 2);
    }

    #[test]
    fn test_detect_systems_respects_min_group_size() {
        let repos = names(&["backend-api", "backend-auth", "frontend-web"]);

        // Only `backend` has two members.
        let systems_min2 = detect_systems(&repos, 2);
        assert_eq!(systems_min2.len(), 1);
        assert_eq!(systems_min2[0].id, "backend");

        // No prefix reaches three members.
        let systems_min3 = detect_systems(&repos, 3);
        assert!(systems_min3.is_empty());
    }

    #[test]
    fn test_majority_vote_security() {
        // Columns: secret_scanning, push_protection, dependabot_alerts,
        // dependabot_security_updates, secret_scanning_ai_detection.
        let rows = [
            (true, true, true, false, false),
            (true, false, true, false, true),
            (true, true, false, false, false),
        ];
        let states: Vec<SecurityState> = rows
            .iter()
            .map(
                |&(
                    secret_scanning,
                    push_protection,
                    dependabot_alerts,
                    dependabot_security_updates,
                    secret_scanning_ai_detection,
                )| SecurityState {
                    secret_scanning,
                    push_protection,
                    dependabot_alerts,
                    dependabot_security_updates,
                    secret_scanning_ai_detection,
                },
            )
            .collect();

        let result = majority_vote_security(&states);
        assert!(result.secret_scanning); // 3/3
        assert!(result.push_protection); // 2/3
        assert!(result.dependabot_alerts); // 2/3
        assert!(!result.dependabot_security_updates); // 0/3
        assert!(!result.secret_scanning_ai_detection); // 1/3
    }

    #[test]
    fn test_generate_toml_output() {
        let systems = vec![DetectedSystem {
            id: "backend".to_string(),
            repos: names(&["backend-api", "backend-auth"]),
        }];
        let ungrouped: Vec<&str> = vec!["standalone"];
        let security = SampledSecurity {
            secret_scanning: true,
            push_protection: true,
            dependabot_alerts: true,
            dependabot_security_updates: false,
            secret_scanning_ai_detection: false,
        };
        let protection = SampledProtection {
            enabled: true,
            required_approvals: 1,
            ..Default::default()
        };
        let team_map = HashMap::new();

        let toml = generate_toml(
            "my-org",
            &systems,
            &ungrouped,
            &security,
            &protection,
            true,
            &team_map,
        );

        let expected_fragments = [
            "[org]",
            "name = \"my-org\"",
            "secret_scanning = true",
            "dependabot_security_updates = false",
            "[[systems]]",
            "id = \"backend\"",
            "enabled = true",
            "required_approvals = 1",
            "# - standalone",
        ];
        for fragment in expected_fragments.iter() {
            assert!(toml.contains(fragment));
        }
    }

    #[test]
    fn test_detect_systems_excludes_single_segment_names() {
        let repos = names(&["standalone", "another", "third"]);
        let systems = detect_systems(&repos, 2);
        assert!(systems.is_empty());
    }

    #[test]
    fn test_majority_vote_protection() {
        // Columns: required_pull_request_reviews,
        // required_approving_review_count, dismiss_stale_reviews.
        let rows = [(true, 2, true), (true, 1, false), (false, 1, true)];
        let states: Vec<BranchProtectionState> = rows
            .iter()
            .map(
                |&(
                    required_pull_request_reviews,
                    required_approving_review_count,
                    dismiss_stale_reviews,
                )| BranchProtectionState {
                    required_pull_request_reviews,
                    required_approving_review_count,
                    dismiss_stale_reviews,
                    ..Default::default()
                },
            )
            .collect();

        let result = majority_vote_protection(&states);
        assert!(result.enabled); // 2/3
        assert_eq!(result.required_approvals, 1); // median
        assert!(result.dismiss_stale_reviews); // 2/3
    }
}