use anyhow::{Context, Result};
use git2::{BranchType, Repository};
use serde::{Deserialize, Serialize};
/// A configured Git remote paired with the branch detected as its default.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RemoteInfo {
    /// Remote name as configured in the repository (for example `origin`).
    pub name: String,
    /// URL of the remote; `get_all_remotes` stores an empty string when git2
    /// reports no URL.
    pub uri: String,
    /// Name of the detected default branch (without the `<remote>/` prefix),
    /// or `"unknown"` when detection found nothing.
    pub main_branch: String,
}
impl RemoteInfo {
pub fn get_all_remotes(repo: &Repository) -> Result<Vec<Self>> {
let mut remotes = Vec::new();
let remote_names = repo.remotes().context("Failed to get remote names")?;
for name in remote_names.iter().flatten() {
if let Ok(remote) = repo.find_remote(name) {
let uri = remote.url().unwrap_or("").to_string();
let main_branch = Self::detect_main_branch(repo, name)?;
remotes.push(Self {
name: name.to_string(),
uri,
main_branch,
});
}
}
Ok(remotes)
}
fn detect_main_branch(repo: &Repository, remote_name: &str) -> Result<String> {
let head_ref_name = format!("refs/remotes/{remote_name}/HEAD");
if let Ok(head_ref) = repo.find_reference(&head_ref_name) {
if let Some(target) = head_ref.symbolic_target() {
if let Some(branch_name) =
target.strip_prefix(&format!("refs/remotes/{remote_name}/"))
{
return Ok(branch_name.to_string());
}
}
}
if let Ok(remote) = repo.find_remote(remote_name) {
if let Some(uri) = remote.url() {
if uri.contains("github.com") {
if let Ok(main_branch) = Self::get_github_default_branch(uri) {
return Ok(main_branch);
}
}
}
}
let common_branches = ["main", "master", "develop"];
if remote_name == "origin" {
for branch_name in &common_branches {
let reference_name = format!("refs/remotes/origin/{branch_name}");
if repo.find_reference(&reference_name).is_ok() {
return Ok((*branch_name).to_string());
}
}
} else {
for branch_name in &common_branches {
let origin_reference = format!("refs/remotes/origin/{branch_name}");
if repo.find_reference(&origin_reference).is_ok() {
return Ok((*branch_name).to_string());
}
}
for branch_name in &common_branches {
let reference_name = format!("refs/remotes/{remote_name}/{branch_name}");
if repo.find_reference(&reference_name).is_ok() {
return Ok((*branch_name).to_string());
}
}
}
let branch_iter = repo.branches(Some(BranchType::Remote))?;
for branch_result in branch_iter {
let (branch, _) = branch_result?;
if let Some(name) = branch.name()? {
if name.starts_with(&format!("{remote_name}/")) {
let branch_name = name
.strip_prefix(&format!("{remote_name}/"))
.unwrap_or(name);
return Ok(branch_name.to_string());
}
}
}
Ok("unknown".to_string())
}
fn get_github_default_branch(uri: &str) -> Result<String> {
use std::process::Command;
let repo_name = Self::extract_github_repo_name(uri)?;
let output = Command::new("gh")
.args([
"repo",
"view",
&repo_name,
"--json",
"defaultBranchRef",
"--jq",
".defaultBranchRef.name",
])
.output();
match output {
Ok(output) if output.status.success() => {
let branch_name = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !branch_name.is_empty() && branch_name != "null" {
Ok(branch_name)
} else {
anyhow::bail!("GitHub CLI returned empty or null branch name")
}
}
_ => anyhow::bail!("Failed to get default branch from GitHub CLI"),
}
}
fn extract_github_repo_name(uri: &str) -> Result<String> {
let repo_name = if uri.starts_with("git@github.com:") {
uri.strip_prefix("git@github.com:")
.and_then(|s| s.strip_suffix(".git"))
.unwrap_or(uri.strip_prefix("git@github.com:").unwrap_or(uri))
} else if uri.contains("github.com") {
uri.split("github.com/")
.nth(1)
.and_then(|s| s.strip_suffix(".git"))
.unwrap_or(uri.split("github.com/").nth(1).unwrap_or(uri))
} else {
anyhow::bail!("Not a GitHub URI: {uri}");
};
if repo_name.split('/').count() != 2 {
anyhow::bail!("Invalid GitHub repository format: {repo_name}");
}
Ok(repo_name.to_string())
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // Shorthand for the parser under test.
    fn extract(uri: &str) -> Result<String> {
        RemoteInfo::extract_github_repo_name(uri)
    }

    #[test]
    fn ssh_url() {
        assert_eq!(extract("git@github.com:owner/repo.git").unwrap(), "owner/repo");
    }

    #[test]
    fn https_url() {
        assert_eq!(extract("https://github.com/owner/repo.git").unwrap(), "owner/repo");
    }

    #[test]
    fn https_url_no_git_suffix() {
        assert_eq!(extract("https://github.com/owner/repo").unwrap(), "owner/repo");
    }

    #[test]
    fn ssh_url_no_git_suffix() {
        assert_eq!(extract("git@github.com:owner/repo").unwrap(), "owner/repo");
    }

    #[test]
    fn non_github_url_fails() {
        // `unwrap_err` panics on `Ok`, covering the "must fail" requirement.
        let message = extract("git@gitlab.com:owner/repo.git")
            .unwrap_err()
            .to_string();
        assert!(
            message.contains("Not a GitHub URI"),
            "unexpected error: {message}"
        );
    }

    #[test]
    fn invalid_format_fails() {
        let message = extract("git@github.com:invalid").unwrap_err().to_string();
        assert!(
            message.contains("Invalid GitHub repository format"),
            "unexpected error: {message}"
        );
    }

    mod prop {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn ssh_url_extracts_repo(
                owner in "[a-z]{3,10}",
                repo in "[a-z]{3,10}",
            ) {
                let expected = format!("{owner}/{repo}");
                let url = format!("git@github.com:{expected}.git");
                prop_assert_eq!(extract(&url).unwrap(), expected);
            }

            #[test]
            fn https_url_extracts_repo(
                owner in "[a-z]{3,10}",
                repo in "[a-z]{3,10}",
            ) {
                let expected = format!("{owner}/{repo}");
                let url = format!("https://github.com/{expected}.git");
                prop_assert_eq!(extract(&url).unwrap(), expected);
            }

            #[test]
            fn non_github_url_errors(
                host in "(gitlab|bitbucket|codeberg)",
                path in "[a-z]{3,10}/[a-z]{3,10}",
            ) {
                let url = format!("git@{host}.com:{path}.git");
                prop_assert!(extract(&url).is_err());
            }
        }
    }
}