grit_lib/
branch_tracking.rs1use std::collections::HashSet;
4use std::fs;
5
6use crate::config::ConfigSet;
7use crate::error::Result;
8use crate::merge_base::count_symmetric_ahead_behind;
9use crate::refs;
10use crate::repo::Repository;
11use crate::rev_parse::{
12 abbreviate_ref_name, resolve_push_full_ref_for_branch, resolve_upstream_symbolic_name,
13};
14
/// How much work to spend when comparing a branch with a ref it tracks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AheadBehindMode {
    /// Walk history and count the commits unique to each side.
    Full,
    /// Only check whether both tips are the same commit; counts are left at zero.
    Quick,
}
23
/// Outcome of comparing a local branch with one tracking ref.
#[derive(Debug, Clone)]
pub enum TrackingStat {
    /// Both refs resolve to the same commit.
    UpToDate,
    /// A ref needed for the comparison (normally the tracking ref) could not be resolved.
    Gone {
        /// Abbreviated name of the tracking ref, for messages.
        display_name: String,
    },
    /// The refs were not found to be identical.
    Diverged {
        /// Abbreviated name of the tracking ref, for messages.
        display_name: String,
        /// Commits only on the local branch (0 when no count was taken).
        ahead: usize,
        /// Commits only on the tracking ref (0 when no count was taken).
        behind: usize,
    },
}
44
45#[must_use]
47pub fn shorten_tracking_ref(full_ref: &str) -> String {
48 abbreviate_ref_name(full_ref)
49}
50
/// Expands a local branch short name into its full `refs/heads/<name>` ref.
fn branch_head_ref(short_name: &str) -> String {
    let mut full = String::from("refs/heads/");
    full.push_str(short_name);
    full
}
54
55#[must_use]
57pub fn upstream_tracking_full_ref(repo: &Repository, branch_short: &str) -> Option<String> {
58 let config = ConfigSet::load(Some(&repo.git_dir), true).ok()?;
59 let remote = config.get(&format!("branch.{branch_short}.remote"))?;
60 let merge = config.get(&format!("branch.{branch_short}.merge"))?;
61 if remote == "." {
62 let m = merge.trim();
63 if m.starts_with("refs/") {
64 Some(m.to_owned())
65 } else {
66 Some(format!("refs/heads/{m}"))
67 }
68 } else {
69 let mb = merge.strip_prefix("refs/heads/").unwrap_or(&merge);
70 Some(format!("refs/remotes/{remote}/{mb}"))
71 }
72}
73
74pub fn stat_branch_pair(
76 repo: &Repository,
77 branch_short: &str,
78 base_ref: &str,
79 mode: AheadBehindMode,
80) -> Result<TrackingStat> {
81 let branch_ref = branch_head_ref(branch_short);
82 let local_oid = match refs::resolve_ref(&repo.git_dir, &branch_ref) {
83 Ok(o) => o,
84 Err(_) => {
85 return Ok(TrackingStat::Diverged {
86 display_name: shorten_tracking_ref(base_ref),
87 ahead: 0,
88 behind: 0,
89 });
90 }
91 };
92 let upstream_oid = match refs::resolve_ref(&repo.git_dir, base_ref) {
93 Ok(o) => o,
94 Err(_) => {
95 return Ok(TrackingStat::Gone {
96 display_name: shorten_tracking_ref(base_ref),
97 });
98 }
99 };
100 if local_oid == upstream_oid {
101 return Ok(TrackingStat::UpToDate);
102 }
103 if mode == AheadBehindMode::Quick {
104 return Ok(TrackingStat::Diverged {
105 display_name: shorten_tracking_ref(base_ref),
106 ahead: 0,
107 behind: 0,
108 });
109 }
110 let (ahead, behind) = count_symmetric_ahead_behind(repo, local_oid, upstream_oid)?;
111 Ok(TrackingStat::Diverged {
112 display_name: shorten_tracking_ref(base_ref),
113 ahead,
114 behind,
115 })
116}
117
/// Extracts the raw `status.compareBranches` value from the text of a git config file.
///
/// The key is recognised inside a `[status]` section, with section and key names
/// compared ASCII case-insensitively. A `status.compareBranches = …` line in dotted form
/// is also recognised. Following git's rule for single-valued settings, the *last*
/// occurrence wins.
///
/// The value is cleaned up as follows: an unquoted inline `#`/`;` comment is removed,
/// double quotes are stripped (backslash escapes are honoured), and the result is
/// trimmed. Lines without `=` are ignored.
///
/// Returns `None` when the key never appears.
///
/// This is a line-oriented reader, not a full config parser. It does not handle
/// `include` directives, subsections, or backslash-continued multi-line values.
fn parse_status_compare_branches(config_content: &str) -> Option<String> {
    let mut in_status = false;
    let mut found = None;
    for line in config_content.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with('[') {
            in_status = trimmed.eq_ignore_ascii_case("[status]");
            continue;
        }
        let Some((key, value)) = trimmed.split_once('=') else {
            continue;
        };
        let key = key.trim();
        // Exact key comparison: a prefix test would also accept e.g. `compareBranchesX`.
        let is_compare_key = key.eq_ignore_ascii_case("status.comparebranches")
            || (in_status && key.eq_ignore_ascii_case("comparebranches"));
        if is_compare_key {
            // Keep scanning: a later definition overrides an earlier one.
            found = Some(decode_config_value(value));
        }
    }
    found
}

/// Decodes a raw config value: stops at an unquoted `#`/`;` comment, removes double
/// quotes, resolves backslash escapes (`\n`, `\t`, and `\<c>` as `<c>`), then trims.
fn decode_config_value(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut quoted = false;
    let mut chars = raw.chars();
    while let Some(c) = chars.next() {
        match c {
            '"' => quoted = !quoted,
            '\\' => match chars.next() {
                Some('n') => out.push('\n'),
                Some('t') => out.push('\t'),
                Some(other) => out.push(other),
                None => {}
            },
            '#' | ';' if !quoted => break,
            _ => out.push(c),
        }
    }
    out.trim().to_owned()
}
137
/// Splits a `status.compareBranches` value into its whitespace-separated tokens.
fn parse_compare_branch_specs(raw: &str) -> Vec<String> {
    // `split_whitespace` never yields empty pieces, so no extra filtering is needed.
    raw.split_whitespace().map(str::to_owned).collect()
}
144
145fn resolve_compare_full_ref(repo: &Repository, branch_short: &str, token: &str) -> Option<String> {
146 let t = token.trim();
147 if t.eq_ignore_ascii_case("@{upstream}") || t.eq_ignore_ascii_case("@{u}") {
148 let spec = if branch_short.is_empty() {
149 "@{u}".to_string()
150 } else {
151 format!("{branch_short}@{{u}}")
152 };
153 resolve_upstream_symbolic_name(repo, &spec).ok()
154 } else if t.eq_ignore_ascii_case("@{push}") {
155 resolve_push_full_ref_for_branch(repo, branch_short).ok()
156 } else {
157 None
158 }
159}
160
161pub fn format_tracking_info(
163 repo: &Repository,
164 branch_short: &str,
165 mode: AheadBehindMode,
166 show_divergence_advice: bool,
167) -> Result<String> {
168 let config_path = repo.git_dir.join("config");
169 let config_raw = fs::read_to_string(&config_path).unwrap_or_default();
170 let compare_raw =
171 parse_status_compare_branches(&config_raw).unwrap_or_else(|| "@{upstream}".to_string());
172
173 let tokens = parse_compare_branch_specs(&compare_raw);
174 if tokens.is_empty() {
175 return Ok(String::new());
176 }
177
178 let upstream_full = resolve_compare_full_ref(repo, branch_short, "@{upstream}");
179 let push_full = resolve_compare_full_ref(repo, branch_short, "@{push}");
180
181 let mut seen: HashSet<String> = HashSet::new();
182 let mut out = String::new();
183 let mut reported = false;
184
185 for tok in tokens {
186 let Some(full_ref) = resolve_compare_full_ref(repo, branch_short, &tok) else {
187 continue;
188 };
189 if !seen.insert(full_ref.clone()) {
190 continue;
191 }
192
193 let is_upstream = upstream_full.as_ref() == Some(&full_ref);
194 let mut is_push = push_full.as_ref() == Some(&full_ref);
195 if is_upstream && push_full.as_ref().is_none_or(|p| p == &full_ref) {
196 is_push = true;
197 }
198
199 let stat = stat_branch_pair(repo, branch_short, &full_ref, mode)?;
200
201 match &stat {
202 TrackingStat::Gone { display_name } if is_upstream => {
203 if reported {
204 out.push('\n');
205 }
206 out.push_str(&format!(
207 "Your branch is based on '{display_name}', but the upstream is gone.\n"
208 ));
209 out.push_str(" (use \"git branch --unset-upstream\" to fixup)\n");
210 reported = true;
211 }
212 TrackingStat::Gone { .. } => {}
213 TrackingStat::UpToDate => {
214 if reported {
215 out.push('\n');
216 }
217 let d = shorten_tracking_ref(&full_ref);
218 out.push_str(&format!("Your branch is up to date with '{d}'.\n"));
219 reported = true;
220 }
221 TrackingStat::Diverged {
222 display_name,
223 ahead,
224 behind,
225 } => {
226 if reported {
227 out.push('\n');
228 }
229 if mode == AheadBehindMode::Quick {
230 out.push_str(&format!(
231 "Your branch and '{display_name}' refer to different commits.\n"
232 ));
233 if is_push {
234 out.push_str(" (use \"git status --ahead-behind\" for details)\n");
235 }
236 } else if *ahead > 0 && *behind > 0 {
237 out.push_str(&format!(
238 "Your branch and '{display_name}' have diverged,\n\
239and have {ahead} and {behind} different commits each, respectively.\n"
240 ));
241 if show_divergence_advice && is_upstream {
242 out.push_str(
243 " (use \"git pull\" if you want to integrate the remote branch with yours)\n",
244 );
245 }
246 } else if *ahead > 0 {
247 out.push_str(&format!(
248 "Your branch is ahead of '{display_name}' by {ahead} commit{}.\n",
249 if *ahead == 1 { "" } else { "s" }
250 ));
251 if is_push {
252 out.push_str(" (use \"git push\" to publish your local commits)\n");
253 }
254 } else {
255 out.push_str(&format!(
256 "Your branch is behind of '{display_name}' by {behind} commit{}, and can be fast-forwarded.\n",
257 if *behind == 1 { "" } else { "s" }
258 ));
259 if is_upstream {
260 out.push_str(" (use \"git pull\" to update your local branch)\n");
261 }
262 }
263 reported = true;
264 }
265 }
266 }
267
268 Ok(out)
269}