1use anyhow::{Context, Result};
4use git2::{BranchType, Repository};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct RemoteInfo {
10 pub name: String,
12 pub uri: String,
14 pub main_branch: String,
16}
17
18impl RemoteInfo {
19 pub fn get_all_remotes(repo: &Repository) -> Result<Vec<Self>> {
21 let mut remotes = Vec::new();
22 let remote_names = repo.remotes().context("Failed to get remote names")?;
23
24 for name in remote_names.iter().flatten() {
25 if let Ok(remote) = repo.find_remote(name) {
26 let uri = remote.url().unwrap_or("").to_string();
27 let main_branch = Self::detect_main_branch(repo, name)?;
28
29 remotes.push(RemoteInfo {
30 name: name.to_string(),
31 uri,
32 main_branch,
33 });
34 }
35 }
36
37 Ok(remotes)
38 }
39
40 fn detect_main_branch(repo: &Repository, remote_name: &str) -> Result<String> {
42 let head_ref_name = format!("refs/remotes/{}/HEAD", remote_name);
44 if let Ok(head_ref) = repo.find_reference(&head_ref_name) {
45 if let Some(target) = head_ref.symbolic_target() {
46 if let Some(branch_name) =
48 target.strip_prefix(&format!("refs/remotes/{}/", remote_name))
49 {
50 return Ok(branch_name.to_string());
51 }
52 }
53 }
54
55 if let Ok(remote) = repo.find_remote(remote_name) {
57 if let Some(uri) = remote.url() {
58 if uri.contains("github.com") {
59 if let Ok(main_branch) = Self::get_github_default_branch(uri) {
60 return Ok(main_branch);
61 }
62 }
63 }
64 }
65
66 let common_branches = ["main", "master", "develop"];
68
69 if remote_name == "origin" {
71 for branch_name in &common_branches {
72 let reference_name = format!("refs/remotes/origin/{}", branch_name);
73 if repo.find_reference(&reference_name).is_ok() {
74 return Ok(branch_name.to_string());
75 }
76 }
77 } else {
78 for branch_name in &common_branches {
80 let origin_reference = format!("refs/remotes/origin/{}", branch_name);
81 if repo.find_reference(&origin_reference).is_ok() {
82 return Ok(branch_name.to_string());
83 }
84 }
85
86 for branch_name in &common_branches {
88 let reference_name = format!("refs/remotes/{}/{}", remote_name, branch_name);
89 if repo.find_reference(&reference_name).is_ok() {
90 return Ok(branch_name.to_string());
91 }
92 }
93 }
94
95 let branch_iter = repo.branches(Some(BranchType::Remote))?;
97 for branch_result in branch_iter {
98 let (branch, _) = branch_result?;
99 if let Some(name) = branch.name()? {
100 if name.starts_with(&format!("{}/", remote_name)) {
101 let branch_name = name
102 .strip_prefix(&format!("{}/", remote_name))
103 .unwrap_or(name);
104 return Ok(branch_name.to_string());
105 }
106 }
107 }
108
109 Ok("unknown".to_string())
111 }
112
113 fn get_github_default_branch(uri: &str) -> Result<String> {
115 use std::process::Command;
116
117 let repo_name = Self::extract_github_repo_name(uri)?;
119
120 let output = Command::new("gh")
122 .args([
123 "repo",
124 "view",
125 &repo_name,
126 "--json",
127 "defaultBranchRef",
128 "--jq",
129 ".defaultBranchRef.name",
130 ])
131 .output();
132
133 match output {
134 Ok(output) if output.status.success() => {
135 let branch_name = String::from_utf8_lossy(&output.stdout).trim().to_string();
136 if !branch_name.is_empty() && branch_name != "null" {
137 Ok(branch_name)
138 } else {
139 anyhow::bail!("GitHub CLI returned empty or null branch name")
140 }
141 }
142 _ => anyhow::bail!("Failed to get default branch from GitHub CLI"),
143 }
144 }
145
146 fn extract_github_repo_name(uri: &str) -> Result<String> {
148 let repo_name = if uri.starts_with("git@github.com:") {
150 uri.strip_prefix("git@github.com:")
152 .and_then(|s| s.strip_suffix(".git"))
153 .unwrap_or(uri.strip_prefix("git@github.com:").unwrap_or(uri))
154 } else if uri.contains("github.com") {
155 uri.split("github.com/")
157 .nth(1)
158 .and_then(|s| s.strip_suffix(".git"))
159 .unwrap_or(uri.split("github.com/").nth(1).unwrap_or(uri))
160 } else {
161 anyhow::bail!("Not a GitHub URI: {}", uri);
162 };
163
164 if repo_name.split('/').count() != 2 {
165 anyhow::bail!("Invalid GitHub repository format: {}", repo_name);
166 }
167
168 Ok(repo_name.to_string())
169 }
170}