Skip to main content

workon/
default_branch.rs

1use git2::{Direction, Remote, RemoteCallbacks, Repository};
2use log::debug;
3
4use crate::error::{DefaultBranchError, Result};
5use crate::get_remote_callbacks;
6
7/// Builder for resolving the default branch name of a repository or remote.
8pub struct DefaultBranch<'repo, 'cb> {
9    repo: &'repo Repository,
10    remote: Option<Remote<'repo>>,
11    callbacks: Option<RemoteCallbacks<'cb>>,
12}
13
14impl<'repo, 'cb> DefaultBranch<'repo, 'cb> {
15    /// Create a new builder for the given repository.
16    pub fn new(repo: &'repo Repository) -> Self {
17        Self {
18            repo,
19            remote: None,
20            callbacks: None,
21        }
22    }
23
24    /// Set the remote to query for its default branch.
25    pub fn remote(&mut self, remote: Remote<'repo>) -> &mut Self {
26        self.remote = Some(remote);
27        self
28    }
29
30    /// Set credential callbacks for the remote connection.
31    pub fn remote_callbacks(&mut self, cbs: RemoteCallbacks<'cb>) -> &mut Self {
32        self.callbacks = Some(cbs);
33        self
34    }
35
36    /// Resolve the default branch name.
37    ///
38    /// If a remote was set, queries the remote for its HEAD ref.
39    /// Otherwise reads `init.defaultbranch` from git config (defaulting to `"main"`).
40    pub fn get_name(self) -> Result<String> {
41        match self.remote {
42            Some(mut remote) => {
43                let mut cxn = remote.connect_auth(Direction::Fetch, self.callbacks, None)?;
44
45                if !cxn.connected() {
46                    return Err(DefaultBranchError::NotConnected.into());
47                }
48
49                match cxn.default_branch()?.as_str() {
50                    Ok(default_branch) => Ok(default_branch
51                        .strip_prefix("refs/heads/")
52                        .unwrap_or(default_branch)
53                        .to_string()),
54                    Err(_) => Err(DefaultBranchError::NoRemoteDefault {
55                        remote: cxn.remote().name().ok().flatten().map(|s| s.to_string()),
56                    }
57                    .into()),
58                }
59            }
60            None => {
61                let config = self.repo.config()?;
62                let defaultbranch = config.get_str("init.defaultbranch").unwrap_or("main");
63                Ok(defaultbranch.to_string())
64            }
65        }
66    }
67}
68
69/// Convenience wrapper around [`DefaultBranch`].
70///
71/// Queries the remote for its default branch if one is provided, otherwise
72/// falls back to `init.defaultbranch` config (defaulting to `"main"`).
73pub fn get_default_branch_name(repo: &Repository, remote: Option<Remote>) -> Result<String> {
74    let auth;
75    let mut default_branch = DefaultBranch::new(repo);
76    if let Some(remote) = remote {
77        let url = remote.url().ok().map(str::to_string);
78        default_branch.remote(remote);
79        auth = get_remote_callbacks(repo, url.as_deref())?;
80        default_branch.remote_callbacks(auth.callbacks());
81    }
82    default_branch.get_name().or_else(|_| {
83        debug!("Failed to read default branch from remote, trying git config");
84        let branch = repo
85            .config()
86            .ok()
87            .and_then(|config| config.get_string("init.defaultbranch").ok())
88            .unwrap_or_else(|| {
89                debug!("No init.defaultbranch config, falling back to 'main'");
90                "main".to_string()
91            });
92        Ok(branch)
93    })
94}
95
96/// Get the default branch name for a repository, validated to exist.
97///
98/// This function:
99/// 1. Checks the `init.defaultBranch` config
100/// 2. Falls back to "main" if it exists
101/// 3. Falls back to "master" if it exists
102/// 4. Returns an error if none exist
103pub fn get_default_branch(repo: &Repository) -> Result<String> {
104    // Check init.defaultBranch config
105    if let Ok(config) = repo.config() {
106        if let Ok(default_branch) = config.get_string("init.defaultBranch") {
107            // Verify the configured branch exists
108            if repo
109                .find_branch(&default_branch, git2::BranchType::Local)
110                .is_ok()
111            {
112                return Ok(default_branch);
113            }
114        }
115    }
116
117    // Fall back to "main" if it exists
118    if repo.find_branch("main", git2::BranchType::Local).is_ok() {
119        return Ok("main".to_string());
120    }
121
122    // Fall back to "master" if it exists
123    if repo.find_branch("master", git2::BranchType::Local).is_ok() {
124        return Ok("master".to_string());
125    }
126
127    Err(DefaultBranchError::NoDefaultBranch.into())
128}