Skip to main content

grit_lib/
branch_tracking.rs

1//! Branch vs remote-tracking comparison for status, checkout, and commit (matches `git/remote.c`).
2
3use std::collections::HashSet;
4use std::fs;
5
6use crate::config::ConfigSet;
7use crate::error::Result;
8use crate::merge_base::count_symmetric_ahead_behind;
9use crate::refs;
10use crate::repo::Repository;
11use crate::rev_parse::{
12    abbreviate_ref_name, resolve_push_full_ref_for_branch, resolve_upstream_symbolic_name,
13};
14
/// Level of detail used when comparing the local tip with a tracking ref
/// (Git's `AHEAD_BEHIND_FULL` vs `AHEAD_BEHIND_QUICK`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AheadBehindMode {
    /// Walk history for exact ahead/behind counts (`rev-list --left-right`).
    Full,
    /// Only report whether the two tips are the same commit; no history walk.
    Quick,
}
23
/// Result of comparing a local branch (`refs/heads/<branch>`) with a tracking ref.
#[derive(Debug, Clone)]
pub enum TrackingStat {
    /// Both refs point at the same commit.
    UpToDate,
    /// The tracking ref does not resolve (the upstream is gone).
    Gone {
        /// Abbreviated name of the tracking ref that could not be resolved.
        display_name: String,
    },
    /// The refs point at different commits.
    ///
    /// `ahead`/`behind` are only counted in [`AheadBehindMode::Full`]. They are zero in
    /// [`AheadBehindMode::Quick`], and also when `stat_branch_pair` cannot resolve the local
    /// branch at all.
    Diverged {
        /// Abbreviated name of the tracking ref.
        display_name: String,
        /// How many commits the local branch is ahead by.
        ahead: usize,
        /// How many commits the local branch is behind by.
        behind: usize,
    },
}
44
/// Short display name for a full ref (`refs/remotes/origin/main` -> `origin/main`).
///
/// Thin wrapper over `abbreviate_ref_name`. In this module it supplies every `display_name`
/// stored in a [`TrackingStat`] and the ref name in the "up to date" status message.
#[must_use]
pub fn shorten_tracking_ref(full_ref: &str) -> String {
    abbreviate_ref_name(full_ref)
}
50
/// Full `refs/heads/...` ref name for a short local branch name (no validation is done).
fn branch_head_ref(short_name: &str) -> String {
    ["refs/heads/", short_name].concat()
}
54
55/// Full ref for the configured upstream of `branch_short` (`refs/remotes/...` or `refs/heads/...`).
56#[must_use]
57pub fn upstream_tracking_full_ref(repo: &Repository, branch_short: &str) -> Option<String> {
58    let config = ConfigSet::load(Some(&repo.git_dir), true).ok()?;
59    let remote = config.get(&format!("branch.{branch_short}.remote"))?;
60    let merge = config.get(&format!("branch.{branch_short}.merge"))?;
61    if remote == "." {
62        let m = merge.trim();
63        if m.starts_with("refs/") {
64            Some(m.to_owned())
65        } else {
66            Some(format!("refs/heads/{m}"))
67        }
68    } else {
69        let mb = merge.strip_prefix("refs/heads/").unwrap_or(&merge);
70        Some(format!("refs/remotes/{remote}/{mb}"))
71    }
72}
73
74/// Whether `tracking_ref` (a full `refs/remotes/<remote>/<branch>` ref) is the destination of some
75/// remote's fetch refspec — i.e. a genuine remote-tracking branch (Git `check_tracking_branch` /
76/// `validate_remote_tracking_branch`). A `refs/remotes/...` ref left over after its remote's fetch
77/// refspec was narrowed is NOT a valid tracking branch and must not be used for `--track`.
78#[must_use]
79pub fn remote_tracking_ref_is_mapped(repo: &Repository, tracking_ref: &str) -> bool {
80    let Ok(config) = ConfigSet::load(Some(&repo.git_dir), true) else {
81        return false;
82    };
83    for entry in config.entries() {
84        // Match keys of the form `remote.<name>.fetch`.
85        let Some(rest) = entry.key.strip_prefix("remote.") else {
86            continue;
87        };
88        if !rest.ends_with(".fetch") {
89            continue;
90        }
91        let Some(spec) = entry.value.as_deref() else {
92            continue;
93        };
94        if refspec_dst_matches(spec, tracking_ref) {
95            return true;
96        }
97    }
98    false
99}
100
/// Does the destination half of a fetch refspec (`[+]<src>:<dst>`) match `target`?
///
/// An exact `<dst>` must equal `target`. A pattern `<dst>` may contain a single `*` anywhere, as
/// Git allows (e.g. `refs/remotes/origin/*` or `refs/remotes/origin/*-mirror`). It matches when
/// `target` starts with the text before the `*`, ends with the text after it, and the `*` covers
/// at least one character. Negative (`^`) refspecs and refspecs without `:<dst>` never match.
fn refspec_dst_matches(spec: &str, target: &str) -> bool {
    let spec = spec.strip_prefix('+').unwrap_or(spec);
    if spec.starts_with('^') {
        return false;
    }
    let Some((_src, dst)) = spec.split_once(':') else {
        return false;
    };
    let dst = dst.trim();
    match dst.split_once('*') {
        // The length check keeps prefix and suffix from overlapping and rejects an empty match.
        Some((prefix, suffix)) => {
            target.len() > prefix.len() + suffix.len()
                && target.starts_with(prefix)
                && target.ends_with(suffix)
        }
        None => dst == target,
    }
}
116
117/// Compare local branch tip to `base_ref` (full ref like `refs/remotes/origin/main`).
118pub fn stat_branch_pair(
119    repo: &Repository,
120    branch_short: &str,
121    base_ref: &str,
122    mode: AheadBehindMode,
123) -> Result<TrackingStat> {
124    let branch_ref = branch_head_ref(branch_short);
125    let local_oid = match refs::resolve_ref(&repo.git_dir, &branch_ref) {
126        Ok(o) => o,
127        Err(_) => {
128            return Ok(TrackingStat::Diverged {
129                display_name: shorten_tracking_ref(base_ref),
130                ahead: 0,
131                behind: 0,
132            });
133        }
134    };
135    let upstream_oid = match refs::resolve_ref(&repo.git_dir, base_ref) {
136        Ok(o) => o,
137        Err(_) => {
138            return Ok(TrackingStat::Gone {
139                display_name: shorten_tracking_ref(base_ref),
140            });
141        }
142    };
143    if local_oid == upstream_oid {
144        return Ok(TrackingStat::UpToDate);
145    }
146    if mode == AheadBehindMode::Quick {
147        return Ok(TrackingStat::Diverged {
148            display_name: shorten_tracking_ref(base_ref),
149            ahead: 0,
150            behind: 0,
151        });
152    }
153    let (ahead, behind) = count_symmetric_ahead_behind(repo, local_oid, upstream_oid)?;
154    Ok(TrackingStat::Diverged {
155        display_name: shorten_tracking_ref(base_ref),
156        ahead,
157        behind,
158    })
159}
160
/// Read `status.compareBranches` from the text of a repository `config` file.
///
/// The key is recognised inside a `[status]` section (`compareBranches = ...`) and in the dotted
/// form `status.compareBranches = ...`.
///
/// * Section and key names compare case-insensitively, but the key must match exactly
///   (`compareBranchesX` is ignored).
/// * When the key is assigned more than once, the last assignment wins.
/// * Double quotes are removed, and an unquoted `#` or `;` starts a trailing comment.
/// * Lines without `=` are ignored.
fn parse_status_compare_branches(config_content: &str) -> Option<String> {
    let mut in_status = false;
    let mut value = None;
    for line in config_content.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with('[') {
            in_status = trimmed.eq_ignore_ascii_case("[status]");
            continue;
        }
        let Some((key, raw_value)) = trimmed.split_once('=') else {
            continue;
        };
        let key = key.trim();
        if key.eq_ignore_ascii_case("status.comparebranches")
            || (in_status && key.eq_ignore_ascii_case("comparebranches"))
        {
            // Keep scanning: a later assignment overrides this one.
            value = Some(unquote_config_value(raw_value));
        }
    }
    value
}

/// Strip double quotes and any unquoted trailing `#`/`;` comment from a raw config value, then
/// trim surrounding whitespace. Backslash escapes are not interpreted.
fn unquote_config_value(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut in_quotes = false;
    for c in raw.chars() {
        match c {
            '"' => in_quotes = !in_quotes,
            '#' | ';' if !in_quotes => break,
            _ => out.push(c),
        }
    }
    out.trim().to_owned()
}
180
/// Split a `status.compareBranches` value into its whitespace-separated tokens.
///
/// `split_whitespace` never yields empty pieces, so no extra filtering is needed. An empty or
/// all-blank value produces an empty list.
fn parse_compare_branch_specs(raw: &str) -> Vec<String> {
    raw.split_whitespace().map(str::to_owned).collect()
}
187
188fn resolve_compare_full_ref(repo: &Repository, branch_short: &str, token: &str) -> Option<String> {
189    let t = token.trim();
190    if t.eq_ignore_ascii_case("@{upstream}") || t.eq_ignore_ascii_case("@{u}") {
191        let spec = if branch_short.is_empty() {
192            "@{u}".to_string()
193        } else {
194            format!("{branch_short}@{{u}}")
195        };
196        resolve_upstream_symbolic_name(repo, &spec).ok()
197    } else if t.eq_ignore_ascii_case("@{push}") {
198        resolve_push_full_ref_for_branch(repo, branch_short).ok()
199    } else {
200        None
201    }
202}
203
204/// Multi-branch tracking lines for porcelain long status and checkout (Git `format_tracking_info`).
205pub fn format_tracking_info(
206    repo: &Repository,
207    branch_short: &str,
208    mode: AheadBehindMode,
209    show_divergence_advice: bool,
210) -> Result<String> {
211    let config_path = repo.git_dir.join("config");
212    let config_raw = fs::read_to_string(&config_path).unwrap_or_default();
213    let compare_raw =
214        parse_status_compare_branches(&config_raw).unwrap_or_else(|| "@{upstream}".to_string());
215
216    let tokens = parse_compare_branch_specs(&compare_raw);
217    if tokens.is_empty() {
218        return Ok(String::new());
219    }
220
221    let upstream_full = resolve_compare_full_ref(repo, branch_short, "@{upstream}");
222    let push_full = resolve_compare_full_ref(repo, branch_short, "@{push}");
223
224    let mut seen: HashSet<String> = HashSet::new();
225    let mut out = String::new();
226    let mut reported = false;
227
228    for tok in tokens {
229        let Some(full_ref) = resolve_compare_full_ref(repo, branch_short, &tok) else {
230            continue;
231        };
232        if !seen.insert(full_ref.clone()) {
233            continue;
234        }
235
236        let is_upstream = upstream_full.as_ref() == Some(&full_ref);
237        let mut is_push = push_full.as_ref() == Some(&full_ref);
238        if is_upstream && push_full.as_ref().is_none_or(|p| p == &full_ref) {
239            is_push = true;
240        }
241
242        let stat = stat_branch_pair(repo, branch_short, &full_ref, mode)?;
243
244        match &stat {
245            TrackingStat::Gone { display_name } if is_upstream => {
246                if reported {
247                    out.push('\n');
248                }
249                out.push_str(&format!(
250                    "Your branch is based on '{display_name}', but the upstream is gone.\n"
251                ));
252                out.push_str("  (use \"git branch --unset-upstream\" to fixup)\n");
253                reported = true;
254            }
255            TrackingStat::Gone { .. } => {}
256            TrackingStat::UpToDate => {
257                if reported {
258                    out.push('\n');
259                }
260                let d = shorten_tracking_ref(&full_ref);
261                out.push_str(&format!("Your branch is up to date with '{d}'.\n"));
262                reported = true;
263            }
264            TrackingStat::Diverged {
265                display_name,
266                ahead,
267                behind,
268            } => {
269                if reported {
270                    out.push('\n');
271                }
272                if mode == AheadBehindMode::Quick {
273                    out.push_str(&format!(
274                        "Your branch and '{display_name}' refer to different commits.\n"
275                    ));
276                    if is_push {
277                        out.push_str("  (use \"git status --ahead-behind\" for details)\n");
278                    }
279                } else if *ahead > 0 && *behind > 0 {
280                    out.push_str(&format!(
281                        "Your branch and '{display_name}' have diverged,\n\
282and have {ahead} and {behind} different commits each, respectively.\n"
283                    ));
284                    if show_divergence_advice && is_upstream {
285                        out.push_str(
286                            "  (use \"git pull\" if you want to integrate the remote branch with yours)\n",
287                        );
288                    }
289                } else if *ahead > 0 {
290                    out.push_str(&format!(
291                        "Your branch is ahead of '{display_name}' by {ahead} commit{}.\n",
292                        if *ahead == 1 { "" } else { "s" }
293                    ));
294                    if is_push {
295                        out.push_str("  (use \"git push\" to publish your local commits)\n");
296                    }
297                } else {
298                    out.push_str(&format!(
299                        "Your branch is behind '{display_name}' by {behind} commit{}, and can be fast-forwarded.\n",
300                        if *behind == 1 { "" } else { "s" }
301                    ));
302                    if is_upstream {
303                        out.push_str("  (use \"git pull\" to update your local branch)\n");
304                    }
305                }
306                reported = true;
307            }
308        }
309    }
310
311    Ok(out)
312}