git_prole/git/remote.rs

1use std::fmt::Debug;
2use std::str::FromStr;
3
4use camino::Utf8Path;
5use command_error::CommandExt;
6use command_error::OutputContext;
7use miette::miette;
8use miette::Context;
9use rustc_hash::FxHashSet;
10use tap::TryConv;
11use tracing::instrument;
12use utf8_command::Utf8Output;
13use winnow::combinator::rest;
14use winnow::token::take_till;
15use winnow::PResult;
16use winnow::Parser;
17
18use crate::AppGit;
19
20use super::GitLike;
21use super::LocalBranchRef;
22use super::Ref;
23use super::RemoteBranchRef;
24
/// Git methods for dealing with remotes.
///
/// A thin, borrowed view over a Git handle `G`; most methods build commands with the wrapped
/// handle's `command()` (and `get_default` reads its `config()`).
#[repr(transparent)]
pub struct GitRemote<'a, G>(&'a G);
28
29impl<G> Debug for GitRemote<'_, G>
30where
31    G: GitLike,
32{
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_tuple("GitRemote")
35            .field(&self.0.get_current_dir().as_ref())
36            .finish()
37    }
38}
39
40impl<'a, G> GitRemote<'a, G>
41where
42    G: GitLike,
43{
44    pub fn new(git: &'a G) -> Self {
45        Self(git)
46    }
47
48    /// Get a list of all `git remote`s.
49    #[instrument(level = "trace")]
50    pub fn list(&self) -> miette::Result<Vec<String>> {
51        Ok(self
52            .0
53            .command()
54            .arg("remote")
55            .output_checked_utf8()
56            .wrap_err("Failed to list Git remotes")?
57            .stdout
58            .lines()
59            .map(|line| line.to_owned())
60            .collect())
61    }
62
63    /// Get the (push) URL for the given remote.
64    #[expect(dead_code)] // #[instrument(level = "trace")]
65    pub(crate) fn get_push_url(&self, remote: &str) -> miette::Result<String> {
66        Ok(self
67            .0
68            .command()
69            .args(["remote", "get-url", "--push", remote])
70            .output_checked_utf8()
71            .wrap_err("Failed to get Git remote URL")?
72            .stdout
73            .trim()
74            .to_owned())
75    }
76
77    #[instrument(level = "trace")]
78    fn default_branch_symbolic_ref(&self, remote: &str) -> miette::Result<RemoteBranchRef> {
79        Ok(self
80            .0
81            .command()
82            .args(["symbolic-ref", &format!("refs/remotes/{remote}/HEAD")])
83            .output_checked_as(|context: OutputContext<Utf8Output>| {
84                if !context.status().success() {
85                    Err(context.error())
86                } else {
87                    let output = context.output().stdout.trim_end();
88                    match Ref::from_str(output) {
89                        Err(err) => Err(context.error_msg(err)),
90                        Ok(ref_name) => match ref_name.try_conv::<RemoteBranchRef>() {
91                            Ok(remote_branch) => Ok(remote_branch),
92                            Err(err) => Err(context.error_msg(format!("{err}"))),
93                        },
94                    }
95                }
96            })?)
97    }
98
99    #[instrument(level = "trace")]
100    fn default_branch_ls_remote(&self, remote: &str) -> miette::Result<RemoteBranchRef> {
101        let branch = self
102            .0
103            .command()
104            .args(["ls-remote", "--symref", remote, "HEAD"])
105            .output_checked_as(|context: OutputContext<Utf8Output>| {
106                if !context.status().success() {
107                    Err(context.error())
108                } else {
109                    let output = &context.output().stdout;
110                    match parse_ls_remote_symref.parse(output) {
111                        Err(err) => {
112                            let err = miette!("{err}");
113                            Err(context.error_msg(err))
114                        }
115                        Ok(ref_name) => match ref_name.try_conv::<LocalBranchRef>() {
116                            Ok(local_branch) => Ok(local_branch.on_remote(remote)),
117                            Err(err) => Err(context.error_msg(format!("{err}"))),
118                        },
119                    }
120                }
121            })?;
122
123        // To avoid talking to the remote next time, write a symbolic-ref.
124        self.0
125            .command()
126            .args([
127                "symbolic-ref",
128                &format!("refs/remotes/{remote}/HEAD"),
129                &format!("refs/remotes/{remote}/{branch}"),
130            ])
131            .output_checked_utf8()
132            .wrap_err_with(|| {
133                format!("Failed to store symbolic ref for default branch for remote {remote}")
134            })?;
135
136        Ok(branch)
137    }
138
139    /// Get the default branch for the given remote.
140    #[instrument(level = "trace")]
141    pub fn default_branch(&self, remote: &str) -> miette::Result<RemoteBranchRef> {
142        self.default_branch_symbolic_ref(remote).or_else(|err| {
143            tracing::debug!("Failed to get default branch: {err}");
144            self.default_branch_ls_remote(remote)
145        })
146    }
147
148    /// Get the `checkout.defaultRemote` setting.
149    #[instrument(level = "trace")]
150    pub fn get_default(&self) -> miette::Result<Option<String>> {
151        self.0.config().get("checkout.defaultRemote")
152    }
153
154    /// Find a unique remote branch by name.
155    ///
156    /// The discovered remote, if any, is returned.
157    ///
158    /// This is (hopefully!) how Git determines which remote-tracking branch you want when you do a
159    /// `git switch` or `git worktree add`.
160    #[instrument(level = "trace")]
161    pub fn for_branch(&self, branch: &str) -> miette::Result<Option<RemoteBranchRef>> {
162        let mut exists_on_remotes = self
163            .0
164            .refs()
165            .for_each_ref(Some(&[&format!("refs/remotes/*/{branch}")]))?;
166
167        if exists_on_remotes.is_empty() {
168            Ok(None)
169        } else if exists_on_remotes.len() == 1 {
170            Ok(exists_on_remotes.pop().map(|ref_name| {
171                RemoteBranchRef::try_from(ref_name)
172                    .expect("`for-each-ref` restricted to `refs/remotes/*` refs")
173            }))
174        } else if let Some(default_remote) = self.get_default()? {
175            // if-let chains when?
176            match exists_on_remotes
177                .into_iter()
178                .map(|ref_name| {
179                    RemoteBranchRef::try_from(ref_name)
180                        .expect("`for-each-ref` restricted to `refs/remotes/*` refs")
181                })
182                .find(|branch| branch.remote() == default_remote)
183            {
184                Some(remote) => Ok(Some(remote)),
185                _ => Ok(None),
186            }
187        } else {
188            Ok(None)
189        }
190    }
191
192    /// Fetch a refspec from a remote.
193    #[instrument(level = "trace")]
194    pub fn fetch(&self, remote: &str, refspec: Option<&str>) -> miette::Result<()> {
195        let mut command = self.0.command();
196        command.args(["fetch", remote]);
197        if let Some(refspec) = refspec {
198            command.arg(refspec);
199        }
200        command.status_checked()?;
201        Ok(())
202    }
203}
204
205impl<'a, C> GitRemote<'a, AppGit<'a, C>>
206where
207    C: AsRef<Utf8Path>,
208{
209    /// Get a list of remotes in the user's preference order.
210    #[instrument(level = "trace")]
211    pub fn list_preferred(&self) -> miette::Result<Vec<String>> {
212        let mut all_remotes = self.list()?.into_iter().collect::<FxHashSet<_>>();
213
214        let mut sorted = Vec::with_capacity(all_remotes.len());
215
216        if let Some(default_remote) = self.get_default()? {
217            if let Some(remote) = all_remotes.take(&default_remote) {
218                sorted.push(remote);
219            }
220        }
221
222        let preferred_remotes = self.0.config.file.remote_names();
223        for remote in preferred_remotes {
224            if let Some(remote) = all_remotes.take(&remote) {
225                sorted.push(remote);
226            }
227        }
228
229        Ok(sorted)
230    }
231
232    /// Get the user's preferred remote, if any.
233    #[instrument(level = "trace")]
234    pub fn preferred(&self) -> miette::Result<Option<String>> {
235        Ok(self.list_preferred()?.first().cloned())
236    }
237}
238
239/// Parse a symbolic ref from the start of `git ls-remote --symref` output.
240fn parse_ls_remote_symref(input: &mut &str) -> PResult<Ref> {
241    let _ = "ref: ".parse_next(input)?;
242    let ref_name = take_till(1.., '\t')
243        .and_then(Ref::parser)
244        .parse_next(input)?;
245    let _ = '\t'.parse_next(input)?;
246    // Don't care about the rest!
247    let _ = rest.parse_next(input)?;
248    Ok(ref_name)
249}
250
#[cfg(test)]
mod tests {
    use indoc::indoc;
    use pretty_assertions::assert_eq;

    use super::*;

    #[test]
    fn test_parse_ls_remote_symref() {
        // `git ls-remote --symref <remote> HEAD` prints the symref line first, then the object
        // line; only the symref target should come out of the parser.
        let ls_remote_output = indoc!(
            "
            ref: refs/heads/main\tHEAD
            9afc843b4288394fe3a2680b13070cfd53164b92\tHEAD
            "
        );
        let expected = Ref::from_str("refs/heads/main").unwrap();

        let parsed = parse_ls_remote_symref.parse(ls_remote_output).unwrap();

        assert_eq!(parsed, expected);
    }
}