Skip to main content

grit_lib/
branch_tracking.rs

1//! Branch vs remote-tracking comparison for status, checkout, and commit (matches `git/remote.c`).
2
3use std::collections::HashSet;
4use std::fs;
5
6use crate::config::ConfigSet;
7use crate::error::Result;
8use crate::merge_base::count_symmetric_ahead_behind;
9use crate::refs;
10use crate::repo::Repository;
11use crate::rev_parse::{
12    abbreviate_ref_name, resolve_push_full_ref_for_branch, resolve_upstream_symbolic_name,
13};
14
/// Strategy for comparing a local branch tip with its tracking ref (Git's
/// `AHEAD_BEHIND_FULL` / `AHEAD_BEHIND_QUICK`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AheadBehindMode {
    /// Count the exact number of commits on each side, like `rev-list --left-right`.
    Full,
    /// Only report whether the two tips are equal; no commit counting is done.
    Quick,
}
23
/// Outcome of comparing `refs/heads/<branch>` with a tracking ref (produced by
/// [`stat_branch_pair`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TrackingStat {
    /// Both refs resolve to the same commit.
    UpToDate,
    /// The tracking ref does not resolve (the upstream is gone).
    Gone {
        /// Abbreviated display name of the missing tracking ref.
        display_name: String,
    },
    /// The tips differ.
    ///
    /// Both counts are zero in [`AheadBehindMode::Quick`] mode, and also when
    /// [`stat_branch_pair`] cannot resolve the local branch itself.
    Diverged {
        /// Abbreviated display name of the tracking ref.
        display_name: String,
        /// Number of commits the local branch is ahead of the tracking ref.
        ahead: usize,
        /// Number of commits the local branch is behind the tracking ref.
        behind: usize,
    },
}
44
45/// Short display name for a full ref (`refs/remotes/origin/main` -> `origin/main`).
46#[must_use]
47pub fn shorten_tracking_ref(full_ref: &str) -> String {
48    abbreviate_ref_name(full_ref)
49}
50
/// Full ref name of the local branch `short_name` (`main` becomes `refs/heads/main`).
fn branch_head_ref(short_name: &str) -> String {
    ["refs/heads/", short_name].concat()
}
54
55/// Full ref for the configured upstream of `branch_short` (`refs/remotes/...` or `refs/heads/...`).
56#[must_use]
57pub fn upstream_tracking_full_ref(repo: &Repository, branch_short: &str) -> Option<String> {
58    let config = ConfigSet::load(Some(&repo.git_dir), true).ok()?;
59    let remote = config.get(&format!("branch.{branch_short}.remote"))?;
60    let merge = config.get(&format!("branch.{branch_short}.merge"))?;
61    if remote == "." {
62        let m = merge.trim();
63        if m.starts_with("refs/") {
64            Some(m.to_owned())
65        } else {
66            Some(format!("refs/heads/{m}"))
67        }
68    } else {
69        let mb = merge.strip_prefix("refs/heads/").unwrap_or(&merge);
70        Some(format!("refs/remotes/{remote}/{mb}"))
71    }
72}
73
74/// Compare local branch tip to `base_ref` (full ref like `refs/remotes/origin/main`).
75pub fn stat_branch_pair(
76    repo: &Repository,
77    branch_short: &str,
78    base_ref: &str,
79    mode: AheadBehindMode,
80) -> Result<TrackingStat> {
81    let branch_ref = branch_head_ref(branch_short);
82    let local_oid = match refs::resolve_ref(&repo.git_dir, &branch_ref) {
83        Ok(o) => o,
84        Err(_) => {
85            return Ok(TrackingStat::Diverged {
86                display_name: shorten_tracking_ref(base_ref),
87                ahead: 0,
88                behind: 0,
89            });
90        }
91    };
92    let upstream_oid = match refs::resolve_ref(&repo.git_dir, base_ref) {
93        Ok(o) => o,
94        Err(_) => {
95            return Ok(TrackingStat::Gone {
96                display_name: shorten_tracking_ref(base_ref),
97            });
98        }
99    };
100    if local_oid == upstream_oid {
101        return Ok(TrackingStat::UpToDate);
102    }
103    if mode == AheadBehindMode::Quick {
104        return Ok(TrackingStat::Diverged {
105            display_name: shorten_tracking_ref(base_ref),
106            ahead: 0,
107            behind: 0,
108        });
109    }
110    let (ahead, behind) = count_symmetric_ahead_behind(repo, local_oid, upstream_oid)?;
111    Ok(TrackingStat::Diverged {
112        display_name: shorten_tracking_ref(base_ref),
113        ahead,
114        behind,
115    })
116}
117
/// Read `status.compareBranches` from the text of a Git config file.
///
/// Two spellings are accepted:
/// * `compareBranches = …` inside a `[status]` section (section and key names are matched
///   ASCII case-insensitively);
/// * a dotted `status.compareBranches = …` line anywhere.
///
/// The whole key must match, so names that merely start with `compareBranches` are ignored.
/// When the key appears more than once, the last value wins, as with Git's single-valued config
/// lookups. Double quotes are removed from the value and a trailing unquoted `#`/`;` comment is
/// stripped.
///
/// Returns `None` when the key is absent.
fn parse_status_compare_branches(config_content: &str) -> Option<String> {
    let mut in_status = false;
    let mut found = None;
    for line in config_content.lines() {
        let trimmed = line.trim();
        if let Some(header) = trimmed.strip_prefix('[') {
            // The section name runs up to `]`; anything after it (e.g. a comment) is ignored.
            let section = header.split(']').next().unwrap_or_default();
            in_status = section.eq_ignore_ascii_case("status");
            continue;
        }
        let Some((key, raw_value)) = trimmed.split_once('=') else {
            continue;
        };
        let key = key.trim();
        let is_target = key.eq_ignore_ascii_case("status.compareBranches")
            || (in_status && key.eq_ignore_ascii_case("compareBranches"));
        if is_target {
            // Keep scanning: a later definition overrides an earlier one.
            found = Some(unquote_config_value(raw_value));
        }
    }
    found
}

/// Strip Git config value syntax from `raw`.
///
/// Double quotes are removed, and an unquoted `#` or `;` starts a comment that runs to the end
/// of the line. Backslash escapes are not interpreted. The result is trimmed.
fn unquote_config_value(raw: &str) -> String {
    let mut value = String::with_capacity(raw.len());
    let mut in_quotes = false;
    for c in raw.chars() {
        match c {
            '"' => in_quotes = !in_quotes,
            '#' | ';' if !in_quotes => break,
            _ => value.push(c),
        }
    }
    value.trim().to_owned()
}
137
/// Split a `status.compareBranches` value into its whitespace-separated tokens.
fn parse_compare_branch_specs(raw: &str) -> Vec<String> {
    // `split_whitespace` never yields empty pieces, so no extra filtering is needed.
    raw.split_whitespace().map(str::to_owned).collect()
}
144
145fn resolve_compare_full_ref(repo: &Repository, branch_short: &str, token: &str) -> Option<String> {
146    let t = token.trim();
147    if t.eq_ignore_ascii_case("@{upstream}") || t.eq_ignore_ascii_case("@{u}") {
148        let spec = if branch_short.is_empty() {
149            "@{u}".to_string()
150        } else {
151            format!("{branch_short}@{{u}}")
152        };
153        resolve_upstream_symbolic_name(repo, &spec).ok()
154    } else if t.eq_ignore_ascii_case("@{push}") {
155        resolve_push_full_ref_for_branch(repo, branch_short).ok()
156    } else {
157        None
158    }
159}
160
161/// Multi-branch tracking lines for porcelain long status and checkout (Git `format_tracking_info`).
162pub fn format_tracking_info(
163    repo: &Repository,
164    branch_short: &str,
165    mode: AheadBehindMode,
166    show_divergence_advice: bool,
167) -> Result<String> {
168    let config_path = repo.git_dir.join("config");
169    let config_raw = fs::read_to_string(&config_path).unwrap_or_default();
170    let compare_raw =
171        parse_status_compare_branches(&config_raw).unwrap_or_else(|| "@{upstream}".to_string());
172
173    let tokens = parse_compare_branch_specs(&compare_raw);
174    if tokens.is_empty() {
175        return Ok(String::new());
176    }
177
178    let upstream_full = resolve_compare_full_ref(repo, branch_short, "@{upstream}");
179    let push_full = resolve_compare_full_ref(repo, branch_short, "@{push}");
180
181    let mut seen: HashSet<String> = HashSet::new();
182    let mut out = String::new();
183    let mut reported = false;
184
185    for tok in tokens {
186        let Some(full_ref) = resolve_compare_full_ref(repo, branch_short, &tok) else {
187            continue;
188        };
189        if !seen.insert(full_ref.clone()) {
190            continue;
191        }
192
193        let is_upstream = upstream_full.as_ref() == Some(&full_ref);
194        let mut is_push = push_full.as_ref() == Some(&full_ref);
195        if is_upstream && push_full.as_ref().is_none_or(|p| p == &full_ref) {
196            is_push = true;
197        }
198
199        let stat = stat_branch_pair(repo, branch_short, &full_ref, mode)?;
200
201        match &stat {
202            TrackingStat::Gone { display_name } if is_upstream => {
203                if reported {
204                    out.push('\n');
205                }
206                out.push_str(&format!(
207                    "Your branch is based on '{display_name}', but the upstream is gone.\n"
208                ));
209                out.push_str("  (use \"git branch --unset-upstream\" to fixup)\n");
210                reported = true;
211            }
212            TrackingStat::Gone { .. } => {}
213            TrackingStat::UpToDate => {
214                if reported {
215                    out.push('\n');
216                }
217                let d = shorten_tracking_ref(&full_ref);
218                out.push_str(&format!("Your branch is up to date with '{d}'.\n"));
219                reported = true;
220            }
221            TrackingStat::Diverged {
222                display_name,
223                ahead,
224                behind,
225            } => {
226                if reported {
227                    out.push('\n');
228                }
229                if mode == AheadBehindMode::Quick {
230                    out.push_str(&format!(
231                        "Your branch and '{display_name}' refer to different commits.\n"
232                    ));
233                    if is_push {
234                        out.push_str("  (use \"git status --ahead-behind\" for details)\n");
235                    }
236                } else if *ahead > 0 && *behind > 0 {
237                    out.push_str(&format!(
238                        "Your branch and '{display_name}' have diverged,\n\
239and have {ahead} and {behind} different commits each, respectively.\n"
240                    ));
241                    if show_divergence_advice && is_upstream {
242                        out.push_str(
243                            "  (use \"git pull\" if you want to integrate the remote branch with yours)\n",
244                        );
245                    }
246                } else if *ahead > 0 {
247                    out.push_str(&format!(
248                        "Your branch is ahead of '{display_name}' by {ahead} commit{}.\n",
249                        if *ahead == 1 { "" } else { "s" }
250                    ));
251                    if is_push {
252                        out.push_str("  (use \"git push\" to publish your local commits)\n");
253                    }
254                } else {
255                    out.push_str(&format!(
256                        "Your branch is behind of '{display_name}' by {behind} commit{}, and can be fast-forwarded.\n",
257                        if *behind == 1 { "" } else { "s" }
258                    ));
259                    if is_upstream {
260                        out.push_str("  (use \"git pull\" to update your local branch)\n");
261                    }
262                }
263                reported = true;
264            }
265        }
266    }
267
268    Ok(out)
269}