//! `workon/default_branch.rs`: resolving the default branch name of a repository or remote.

1use git2::{Direction, Remote, RemoteCallbacks, Repository};
2
3use crate::error::{DefaultBranchError, Result};
4use crate::get_remote_callbacks;
5
6/// Builder for resolving the default branch name of a repository or remote.
7pub struct DefaultBranch<'repo, 'cb> {
8    repo: &'repo Repository,
9    remote: Option<Remote<'repo>>,
10    callbacks: Option<RemoteCallbacks<'cb>>,
11}
12
13impl<'repo, 'cb> DefaultBranch<'repo, 'cb> {
14    /// Create a new builder for the given repository.
15    pub fn new(repo: &'repo Repository) -> Self {
16        Self {
17            repo,
18            remote: None,
19            callbacks: None,
20        }
21    }
22
23    /// Set the remote to query for its default branch.
24    pub fn remote(&mut self, remote: Remote<'repo>) -> &mut Self {
25        self.remote = Some(remote);
26        self
27    }
28
29    /// Set credential callbacks for the remote connection.
30    pub fn remote_callbacks(&mut self, cbs: RemoteCallbacks<'cb>) -> &mut Self {
31        self.callbacks = Some(cbs);
32        self
33    }
34
35    /// Resolve the default branch name.
36    ///
37    /// If a remote was set, queries the remote for its HEAD ref.
38    /// Otherwise reads `init.defaultbranch` from git config (defaulting to `"main"`).
39    pub fn get_name(self) -> Result<String> {
40        match self.remote {
41            Some(mut remote) => {
42                let mut cxn = remote.connect_auth(Direction::Fetch, self.callbacks, None)?;
43
44                if !cxn.connected() {
45                    return Err(DefaultBranchError::NotConnected.into());
46                }
47
48                match cxn.default_branch()?.as_str() {
49                    Some(default_branch) => Ok(default_branch
50                        .strip_prefix("refs/heads/")
51                        .unwrap_or(default_branch)
52                        .to_string()),
53                    None => Err(DefaultBranchError::NoRemoteDefault {
54                        remote: cxn.remote().name().map(|s| s.to_string()),
55                    }
56                    .into()),
57                }
58            }
59            None => {
60                let config = self.repo.config()?;
61                let defaultbranch = config.get_str("init.defaultbranch").unwrap_or("main");
62                Ok(defaultbranch.to_string())
63            }
64        }
65    }
66}
67
68/// Convenience wrapper around [`DefaultBranch`].
69///
70/// Queries the remote for its default branch if one is provided, otherwise
71/// falls back to `init.defaultbranch` config (defaulting to `"main"`).
72pub fn get_default_branch_name(repo: &Repository, remote: Option<Remote>) -> Result<String> {
73    let mut default_branch = DefaultBranch::new(repo);
74    if let Some(remote) = remote {
75        default_branch.remote(remote);
76        default_branch.remote_callbacks(get_remote_callbacks().unwrap());
77    }
78    default_branch.get_name()
79}
80
81/// Get the default branch name for a repository, validated to exist.
82///
83/// This function:
84/// 1. Checks the `init.defaultBranch` config
85/// 2. Falls back to "main" if it exists
86/// 3. Falls back to "master" if it exists
87/// 4. Returns an error if none exist
88pub fn get_default_branch(repo: &Repository) -> Result<String> {
89    // Check init.defaultBranch config
90    if let Ok(config) = repo.config() {
91        if let Ok(default_branch) = config.get_string("init.defaultBranch") {
92            // Verify the configured branch exists
93            if repo
94                .find_branch(&default_branch, git2::BranchType::Local)
95                .is_ok()
96            {
97                return Ok(default_branch);
98            }
99        }
100    }
101
102    // Fall back to "main" if it exists
103    if repo.find_branch("main", git2::BranchType::Local).is_ok() {
104        return Ok("main".to_string());
105    }
106
107    // Fall back to "master" if it exists
108    if repo.find_branch("master", git2::BranchType::Local).is_ok() {
109        return Ok("master".to_string());
110    }
111
112    Err(DefaultBranchError::NoDefaultBranch.into())
113}