grit_lib/
branch_tracking.rs1use std::collections::HashSet;
4use std::fs;
5
6use crate::config::ConfigSet;
7use crate::error::Result;
8use crate::merge_base::count_symmetric_ahead_behind;
9use crate::refs;
10use crate::repo::Repository;
11use crate::rev_parse::{
12 abbreviate_ref_name, resolve_push_full_ref_for_branch, resolve_upstream_symbolic_name,
13};
14
/// How much work to spend when comparing a branch with a tracking ref.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AheadBehindMode {
    /// Walk history and count the commits unique to each side.
    Full,
    /// Only check whether both refs point at the same object; counts are not computed.
    Quick,
}
23
/// Outcome of comparing a local branch with one tracking ref.
///
/// `PartialEq`/`Eq` are derived so results can be compared directly (e.g. in tests).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TrackingStat {
    /// Both refs resolve to the same object.
    UpToDate,
    /// A ref involved in the comparison could not be resolved
    /// (e.g. the tracking ref was deleted).
    Gone {
        /// Abbreviated name of the tracking ref, suitable for messages.
        display_name: String,
    },
    /// The refs point at different objects.
    Diverged {
        /// Abbreviated name of the tracking ref, suitable for messages.
        display_name: String,
        /// Commits reachable only from the local branch (0 when counting was skipped,
        /// e.g. in `AheadBehindMode::Quick`).
        ahead: usize,
        /// Commits reachable only from the tracking ref (0 when counting was skipped).
        behind: usize,
    },
}
44
45#[must_use]
47pub fn shorten_tracking_ref(full_ref: &str) -> String {
48 abbreviate_ref_name(full_ref)
49}
50
/// Expands a local branch short name into its full `refs/heads/...` ref name.
fn branch_head_ref(short_name: &str) -> String {
    let mut full_name = String::from("refs/heads/");
    full_name.push_str(short_name);
    full_name
}
54
55#[must_use]
57pub fn upstream_tracking_full_ref(repo: &Repository, branch_short: &str) -> Option<String> {
58 let config = ConfigSet::load(Some(&repo.git_dir), true).ok()?;
59 let remote = config.get(&format!("branch.{branch_short}.remote"))?;
60 let merge = config.get(&format!("branch.{branch_short}.merge"))?;
61 if remote == "." {
62 let m = merge.trim();
63 if m.starts_with("refs/") {
64 Some(m.to_owned())
65 } else {
66 Some(format!("refs/heads/{m}"))
67 }
68 } else {
69 let mb = merge.strip_prefix("refs/heads/").unwrap_or(&merge);
70 Some(format!("refs/remotes/{remote}/{mb}"))
71 }
72}
73
74#[must_use]
79pub fn remote_tracking_ref_is_mapped(repo: &Repository, tracking_ref: &str) -> bool {
80 let Ok(config) = ConfigSet::load(Some(&repo.git_dir), true) else {
81 return false;
82 };
83 for entry in config.entries() {
84 let Some(rest) = entry.key.strip_prefix("remote.") else {
86 continue;
87 };
88 if !rest.ends_with(".fetch") {
89 continue;
90 }
91 let Some(spec) = entry.value.as_deref() else {
92 continue;
93 };
94 if refspec_dst_matches(spec, tracking_ref) {
95 return true;
96 }
97 }
98 false
99}
100
/// Returns `true` when the destination side of fetch refspec `spec` covers `target`.
///
/// An optional leading `+` (force flag) is ignored. The destination is either an
/// exact ref name or a pattern with a single `*`, which may appear anywhere
/// (e.g. `refs/remotes/origin/*` or `refs/remotes/origin/*/head`) and must match
/// at least one character. Refspecs without a `:` (including negative `^`
/// refspecs) have no destination and never match.
fn refspec_dst_matches(spec: &str, target: &str) -> bool {
    let spec = spec.strip_prefix('+').unwrap_or(spec);
    let Some((_src, dst)) = spec.split_once(':') else {
        return false;
    };
    let dst = dst.trim();
    match dst.split_once('*') {
        // The length check keeps prefix and suffix from overlapping and
        // requires the wildcard to match a non-empty part.
        Some((prefix, suffix)) => {
            target.len() > prefix.len() + suffix.len()
                && target.starts_with(prefix)
                && target.ends_with(suffix)
        }
        None => dst == target,
    }
}
116
117pub fn stat_branch_pair(
119 repo: &Repository,
120 branch_short: &str,
121 base_ref: &str,
122 mode: AheadBehindMode,
123) -> Result<TrackingStat> {
124 let branch_ref = branch_head_ref(branch_short);
125 let local_oid = match refs::resolve_ref(&repo.git_dir, &branch_ref) {
126 Ok(o) => o,
127 Err(_) => {
128 return Ok(TrackingStat::Diverged {
129 display_name: shorten_tracking_ref(base_ref),
130 ahead: 0,
131 behind: 0,
132 });
133 }
134 };
135 let upstream_oid = match refs::resolve_ref(&repo.git_dir, base_ref) {
136 Ok(o) => o,
137 Err(_) => {
138 return Ok(TrackingStat::Gone {
139 display_name: shorten_tracking_ref(base_ref),
140 });
141 }
142 };
143 if local_oid == upstream_oid {
144 return Ok(TrackingStat::UpToDate);
145 }
146 if mode == AheadBehindMode::Quick {
147 return Ok(TrackingStat::Diverged {
148 display_name: shorten_tracking_ref(base_ref),
149 ahead: 0,
150 behind: 0,
151 });
152 }
153 let (ahead, behind) = count_symmetric_ahead_behind(repo, local_oid, upstream_oid)?;
154 Ok(TrackingStat::Diverged {
155 display_name: shorten_tracking_ref(base_ref),
156 ahead,
157 behind,
158 })
159}
160
/// Extracts the raw value of `status.compareBranches` from the text of a git config file.
///
/// This is a small line-based scan, not a full config parser. It recognises the key:
/// - inside a plain `[status]` section, matched case-insensitively (text after `]`
///   is ignored, and `[status "sub"]` subsections do not count);
/// - as a fully qualified `status.compareBranches = ...` line.
///
/// Keys are compared whole, so e.g. `compareBranchesFoo` does not match.
/// Includes, quoting and inline comments inside values are not interpreted.
///
/// As with git's single-valued variables, the last assignment wins.
/// Returns `None` when the key is never assigned a value.
fn parse_status_compare_branches(config_content: &str) -> Option<String> {
    let mut in_status = false;
    let mut value = None;
    for line in config_content.lines() {
        let trimmed = line.trim();
        if let Some(header) = trimmed.strip_prefix('[') {
            let section = header.split(']').next().unwrap_or_default().trim();
            in_status = section.eq_ignore_ascii_case("status");
            continue;
        }
        let Some((key, raw_value)) = trimmed.split_once('=') else {
            continue;
        };
        let key = key.trim();
        let is_compare_key = key.eq_ignore_ascii_case("status.comparebranches")
            || (in_status && key.eq_ignore_ascii_case("comparebranches"));
        if is_compare_key {
            // Keep scanning: a later assignment overrides this one.
            value = Some(raw_value.trim().to_owned());
        }
    }
    value
}
180
/// Splits a raw `status.compareBranches` value into its whitespace-separated tokens.
///
/// `split_whitespace` already collapses runs of whitespace, so no token is ever empty.
fn parse_compare_branch_specs(raw: &str) -> Vec<String> {
    raw.split_whitespace().map(String::from).collect()
}
187
188fn resolve_compare_full_ref(repo: &Repository, branch_short: &str, token: &str) -> Option<String> {
189 let t = token.trim();
190 if t.eq_ignore_ascii_case("@{upstream}") || t.eq_ignore_ascii_case("@{u}") {
191 let spec = if branch_short.is_empty() {
192 "@{u}".to_string()
193 } else {
194 format!("{branch_short}@{{u}}")
195 };
196 resolve_upstream_symbolic_name(repo, &spec).ok()
197 } else if t.eq_ignore_ascii_case("@{push}") {
198 resolve_push_full_ref_for_branch(repo, branch_short).ok()
199 } else {
200 None
201 }
202}
203
204pub fn format_tracking_info(
206 repo: &Repository,
207 branch_short: &str,
208 mode: AheadBehindMode,
209 show_divergence_advice: bool,
210) -> Result<String> {
211 let config_path = repo.git_dir.join("config");
212 let config_raw = fs::read_to_string(&config_path).unwrap_or_default();
213 let compare_raw =
214 parse_status_compare_branches(&config_raw).unwrap_or_else(|| "@{upstream}".to_string());
215
216 let tokens = parse_compare_branch_specs(&compare_raw);
217 if tokens.is_empty() {
218 return Ok(String::new());
219 }
220
221 let upstream_full = resolve_compare_full_ref(repo, branch_short, "@{upstream}");
222 let push_full = resolve_compare_full_ref(repo, branch_short, "@{push}");
223
224 let mut seen: HashSet<String> = HashSet::new();
225 let mut out = String::new();
226 let mut reported = false;
227
228 for tok in tokens {
229 let Some(full_ref) = resolve_compare_full_ref(repo, branch_short, &tok) else {
230 continue;
231 };
232 if !seen.insert(full_ref.clone()) {
233 continue;
234 }
235
236 let is_upstream = upstream_full.as_ref() == Some(&full_ref);
237 let mut is_push = push_full.as_ref() == Some(&full_ref);
238 if is_upstream && push_full.as_ref().is_none_or(|p| p == &full_ref) {
239 is_push = true;
240 }
241
242 let stat = stat_branch_pair(repo, branch_short, &full_ref, mode)?;
243
244 match &stat {
245 TrackingStat::Gone { display_name } if is_upstream => {
246 if reported {
247 out.push('\n');
248 }
249 out.push_str(&format!(
250 "Your branch is based on '{display_name}', but the upstream is gone.\n"
251 ));
252 out.push_str(" (use \"git branch --unset-upstream\" to fixup)\n");
253 reported = true;
254 }
255 TrackingStat::Gone { .. } => {}
256 TrackingStat::UpToDate => {
257 if reported {
258 out.push('\n');
259 }
260 let d = shorten_tracking_ref(&full_ref);
261 out.push_str(&format!("Your branch is up to date with '{d}'.\n"));
262 reported = true;
263 }
264 TrackingStat::Diverged {
265 display_name,
266 ahead,
267 behind,
268 } => {
269 if reported {
270 out.push('\n');
271 }
272 if mode == AheadBehindMode::Quick {
273 out.push_str(&format!(
274 "Your branch and '{display_name}' refer to different commits.\n"
275 ));
276 if is_push {
277 out.push_str(" (use \"git status --ahead-behind\" for details)\n");
278 }
279 } else if *ahead > 0 && *behind > 0 {
280 out.push_str(&format!(
281 "Your branch and '{display_name}' have diverged,\n\
282and have {ahead} and {behind} different commits each, respectively.\n"
283 ));
284 if show_divergence_advice && is_upstream {
285 out.push_str(
286 " (use \"git pull\" if you want to integrate the remote branch with yours)\n",
287 );
288 }
289 } else if *ahead > 0 {
290 out.push_str(&format!(
291 "Your branch is ahead of '{display_name}' by {ahead} commit{}.\n",
292 if *ahead == 1 { "" } else { "s" }
293 ));
294 if is_push {
295 out.push_str(" (use \"git push\" to publish your local commits)\n");
296 }
297 } else {
298 out.push_str(&format!(
299 "Your branch is behind '{display_name}' by {behind} commit{}, and can be fast-forwarded.\n",
300 if *behind == 1 { "" } else { "s" }
301 ));
302 if is_upstream {
303 out.push_str(" (use \"git pull\" to update your local branch)\n");
304 }
305 }
306 reported = true;
307 }
308 }
309 }
310
311 Ok(out)
312}