Skip to main content

grit_lib/
merge_base.rs

1//! Merge-base and reachability primitives.
2//!
3//! This module implements the subset needed by `grit merge-base`:
4//! default merge-base selection, `--all`, `--octopus`, `--independent`,
5//! and `--is-ancestor`.
6
7use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
8
9use crate::config::ConfigSet;
10use crate::error::{Error, Result};
11use crate::objects::{parse_commit, ObjectId, ObjectKind};
12use crate::promisor::{promisor_pack_object_ids, repo_treats_promisor_packs};
13use crate::reflog::read_reflog;
14use crate::repo::Repository;
15use crate::rev_parse::{
16    peel_to_commit_for_merge_base, resolve_revision, resolve_upstream_symbolic_name,
17    upstream_suffix_info,
18};
19
20/// Resolve commit-ish command arguments to commit object IDs.
21///
22/// # Parameters
23///
24/// - `repo` - repository used for revision lookup and object reads.
25/// - `specs` - revision arguments such as `HEAD`, ref names, or object IDs.
26///
27/// # Errors
28///
29/// Returns [`Error::ObjectNotFound`] when a revision does not resolve and
30/// [`Error::CorruptObject`] when the resolved object is not a commit.
31pub fn resolve_commit_specs(repo: &Repository, specs: &[String]) -> Result<Vec<ObjectId>> {
32    let mut out = Vec::with_capacity(specs.len());
33    for spec in specs {
34        let oid = resolve_revision(repo, spec)?;
35        ensure_is_commit(repo, oid)?;
36        out.push(oid);
37    }
38    Ok(out)
39}
40
41/// Compute merge bases for one commit vs one or more others.
42///
43/// Semantics match Git's default mode: for `<a> <b>...`, this computes merge
44/// bases between `a` and a hypothetical merge of all remaining commits.
45///
46/// # Parameters
47///
48/// - `repo` - repository used to walk commit parents.
49/// - `first` - first commit argument.
50/// - `others` - remaining commit arguments.
51///
52/// # Errors
53///
54/// Returns parse and object read errors from commit traversal.
55pub fn merge_bases_first_vs_rest(
56    repo: &Repository,
57    first: ObjectId,
58    others: &[ObjectId],
59) -> Result<Vec<ObjectId>> {
60    let mut cache = CommitGraphCache::new(repo);
61    let first_anc = cache.ancestor_closure(first)?;
62    let mut others_union = HashSet::new();
63    for &other in others {
64        others_union.extend(cache.ancestor_closure(other)?);
65    }
66    let candidates: HashSet<ObjectId> = first_anc.intersection(&others_union).copied().collect();
67    reduce_to_best(candidates, &mut cache)
68}
69
70/// Merge base of `HEAD` and one other commit, matching `git diff --merge-base <commit>`.
71///
72/// Returns an error when there is no merge base or more than one.
73#[must_use]
74pub fn merge_base_for_diff_index(
75    repo: &Repository,
76    head: ObjectId,
77    other: ObjectId,
78) -> std::result::Result<ObjectId, MergeBaseForDiffError> {
79    let bases = merge_bases_first_vs_rest(repo, other, &[head])
80        .map_err(|e| MergeBaseForDiffError::Other(e.to_string()))?;
81    match bases.len() {
82        0 => Err(MergeBaseForDiffError::None),
83        1 => Ok(bases[0]),
84        _ => Err(MergeBaseForDiffError::Multiple),
85    }
86}
87
88/// Merge base of two commits, matching `git diff --merge-base <a> <b>` / `diff-tree --merge-base`.
89///
90/// Returns an error when there is no merge base or more than one.
91#[must_use]
92pub fn merge_base_for_diff_two_commits(
93    repo: &Repository,
94    a: ObjectId,
95    b: ObjectId,
96) -> std::result::Result<ObjectId, MergeBaseForDiffError> {
97    let bases = merge_bases_first_vs_rest(repo, a, &[b])
98        .map_err(|e| MergeBaseForDiffError::Other(e.to_string()))?;
99    match bases.len() {
100        0 => Err(MergeBaseForDiffError::None),
101        1 => Ok(bases[0]),
102        _ => Err(MergeBaseForDiffError::Multiple),
103    }
104}
105
/// Reasons why [`merge_base_for_diff_index`] or [`merge_base_for_diff_two_commits`]
/// could not produce a single merge base.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MergeBaseForDiffError {
    /// The commits have no common ancestor at all.
    None,
    /// The merge-base computation returned more than one commit
    /// (for example in a criss-cross history).
    Multiple,
    /// Revision resolution or an object read failed; the payload is the
    /// error's display text, suitable for printing on stderr.
    Other(String),
}
116
117/// Compute merge bases common to all supplied commits (`--octopus` mode).
118///
119/// # Parameters
120///
121/// - `repo` - repository used to walk commit parents.
122/// - `commits` - commits to intersect.
123///
124/// # Errors
125///
126/// Returns parse and object read errors from commit traversal.
127pub fn merge_bases_octopus(repo: &Repository, commits: &[ObjectId]) -> Result<Vec<ObjectId>> {
128    let mut cache = CommitGraphCache::new(repo);
129    let mut iter = commits.iter();
130    let Some(&first) = iter.next() else {
131        return Ok(Vec::new());
132    };
133    let mut common = cache.ancestor_closure(first)?;
134    for &oid in iter {
135        let set = cache.ancestor_closure(oid)?;
136        common.retain(|item| set.contains(item));
137    }
138    reduce_to_best(common, &mut cache)
139}
140
141/// All merge bases common to every supplied commit (intersection of ancestor sets,
142/// reduced to minimal bases). Matches `git merge-base` with two or more tips.
143///
144/// This is the same intersection-and-reduction as [`merge_bases_octopus`]; the name
145/// documents the `git merge-base A B C ...` calling convention.
146pub fn merge_bases_all(repo: &Repository, commits: &[ObjectId]) -> Result<Vec<ObjectId>> {
147    merge_bases_octopus(repo, commits)
148}
149
150/// Check whether `ancestor` is reachable from `descendant`.
151///
152/// # Errors
153///
154/// Returns parse and object read errors from commit traversal.
155pub fn is_ancestor(repo: &Repository, ancestor: ObjectId, descendant: ObjectId) -> Result<bool> {
156    let mut cache = CommitGraphCache::new(repo);
157    if ancestor == descendant {
158        return Ok(true);
159    }
160    Ok(cache.ancestor_closure(descendant)?.contains(&ancestor))
161}
162
163/// Returns the ref path under `logs/` used for fork-point reflog scanning for `merge-base --fork-point`
164/// and `rebase --fork-point`, matching Git's resolution order.
165///
166/// # Parameters
167///
168/// - `spec` - upstream argument as given on the command line (`main`, `refs/heads/main`, `HEAD`, …).
169pub fn resolve_fork_point_reflog_ref(repo: &Repository, spec: &str) -> String {
170    if spec == "HEAD" || spec.starts_with("refs/") {
171        return spec.to_string();
172    }
173
174    let logs_dir = repo.git_dir.join("logs");
175    let candidates = [
176        spec.to_string(),
177        format!("refs/heads/{spec}"),
178        format!("refs/remotes/{spec}"),
179    ];
180
181    for candidate in candidates {
182        if logs_dir.join(&candidate).is_file() {
183            return candidate;
184        }
185    }
186
187    format!("refs/heads/{spec}")
188}
189
190/// Picks the fork-point candidate that is not strictly dominated by another candidate in the list.
191fn select_best_fork_point(repo: &Repository, candidates: &[ObjectId]) -> Result<Option<ObjectId>> {
192    if candidates.is_empty() {
193        return Ok(None);
194    }
195
196    let mut best = HashSet::new();
197    for &candidate in candidates {
198        let mut dominated = false;
199        for &other in candidates {
200            if candidate == other {
201                continue;
202            }
203            if is_ancestor(repo, candidate, other)? {
204                dominated = true;
205                break;
206            }
207        }
208        if !dominated {
209            best.insert(candidate);
210        }
211    }
212
213    Ok(candidates.iter().copied().find(|oid| best.contains(oid)))
214}
215
216/// Computes the fork-point commit between `upstream_tip` and `head`, using the upstream ref's reflog.
217///
218/// This matches `git merge-base --fork-point` / the merge base `git rebase --fork-point` uses for
219/// selecting commits to replay.
220///
221/// # Parameters
222///
223/// - `upstream_spec` - upstream revision string (used to locate the reflog; e.g. `main`,
224///   `refs/heads/main`, or `topic@{{upstream}}`).
225/// - `upstream_tip` - resolved commit of the upstream branch tip.
226/// - `head` - commit to rebase (usually `HEAD`).
227///
228/// # Errors
229///
230/// Propagates object read, reflog, and graph walk errors.
231pub fn fork_point(
232    repo: &Repository,
233    upstream_spec: &str,
234    upstream_tip: ObjectId,
235    head: ObjectId,
236) -> Result<ObjectId> {
237    let reflog_ref = if upstream_suffix_info(upstream_spec).is_some() {
238        resolve_upstream_symbolic_name(repo, upstream_spec)?
239    } else {
240        resolve_fork_point_reflog_ref(repo, upstream_spec)
241    };
242
243    let entries = read_reflog(&repo.git_dir, &reflog_ref)
244        .map_err(|e| Error::Message(format!("failed to read reflog for '{reflog_ref}': {e}")))?;
245
246    let mut candidates = Vec::new();
247    let mut seen = HashSet::new();
248
249    for entry in entries.iter().rev() {
250        let oid = if entry.message.starts_with("checkout:") {
251            entry.old_oid
252        } else {
253            entry.new_oid
254        };
255        if !seen.insert(oid) {
256            continue;
257        }
258        if is_ancestor(repo, oid, head)? {
259            candidates.push(oid);
260        }
261    }
262
263    if let Some(fp) = select_best_fork_point(repo, &candidates)? {
264        return Ok(fp);
265    }
266
267    let mut bases = merge_bases_first_vs_rest(repo, upstream_tip, &[head])?;
268    if bases.is_empty() {
269        return Err(Error::Message(
270            "no merge base found between upstream and HEAD".to_owned(),
271        ));
272    }
273    bases.sort();
274    Ok(bases[0])
275}
276
277/// Returns every commit reachable from `tip` by walking parent links (including `tip`).
278///
279/// # Errors
280///
281/// Returns [`Error::CorruptObject`] if an encountered object is not a commit.
282pub fn ancestor_closure(repo: &Repository, tip: ObjectId) -> Result<HashSet<ObjectId>> {
283    let mut cache = CommitGraphCache::new(repo);
284    cache.ancestor_closure(tip)
285}
286
287/// Count symmetric-diff commits between two tips, matching `git rev-list --left-right A...B`.
288///
289/// Returns `(ahead, behind)` where `ahead` counts commits reachable from `local` but not from
290/// `other`, and `behind` the converse. Shared history is excluded from both counts.
291///
292/// # Errors
293///
294/// Propagates errors from commit graph walks.
295pub fn count_symmetric_ahead_behind(
296    repo: &Repository,
297    local: ObjectId,
298    other: ObjectId,
299) -> Result<(usize, usize)> {
300    let left = ancestor_closure(repo, local)?;
301    let right = ancestor_closure(repo, other)?;
302    let ahead = left.difference(&right).count();
303    let behind = right.difference(&left).count();
304    Ok((ahead, behind))
305}
306
307/// Return commits that are not reachable from any other input commit.
308///
309/// The output order follows input order, dropping any commit reachable from
310/// another supplied commit.
311///
312/// # Errors
313///
314/// Returns parse and object read errors from commit traversal.
315pub fn independent_commits(repo: &Repository, commits: &[ObjectId]) -> Result<Vec<ObjectId>> {
316    let mut cache = CommitGraphCache::new(repo);
317    let mut out = Vec::new();
318    for (i, &candidate) in commits.iter().enumerate() {
319        let mut reachable = false;
320        for (j, &other) in commits.iter().enumerate() {
321            if i == j {
322                continue;
323            }
324            if cache.ancestor_closure(other)?.contains(&candidate) {
325                reachable = true;
326                break;
327            }
328        }
329        if !reachable {
330            out.push(candidate);
331        }
332    }
333    Ok(out)
334}
335
336fn ensure_is_commit(repo: &Repository, oid: ObjectId) -> Result<()> {
337    let object = repo.odb.read(&oid)?;
338    if object.kind != ObjectKind::Commit {
339        return Err(Error::CorruptObject(format!(
340            "object {oid} is not a commit"
341        )));
342    }
343    Ok(())
344}
345
346fn reduce_to_best(
347    candidates: HashSet<ObjectId>,
348    cache: &mut CommitGraphCache<'_>,
349) -> Result<Vec<ObjectId>> {
350    if candidates.is_empty() {
351        return Ok(Vec::new());
352    }
353    let mut best = BTreeSet::new();
354    for &candidate in &candidates {
355        let mut better_found = false;
356        for &other in &candidates {
357            if candidate == other {
358                continue;
359            }
360            if cache.ancestor_closure(other)?.contains(&candidate) {
361                better_found = true;
362                break;
363            }
364        }
365        if !better_found {
366            best.insert(candidate);
367        }
368    }
369    Ok(best.into_iter().collect())
370}
371
372struct CommitGraphCache<'r> {
373    repo: &'r Repository,
374    parents: HashMap<ObjectId, Vec<ObjectId>>,
375    closures: HashMap<ObjectId, HashSet<ObjectId>>,
376    promisor_stop: std::collections::HashSet<ObjectId>,
377}
378
379impl<'r> CommitGraphCache<'r> {
380    fn new(repo: &'r Repository) -> Self {
381        let cfg = ConfigSet::load(Some(&repo.git_dir), true).unwrap_or_default();
382        let promisor_stop = if repo_treats_promisor_packs(&repo.git_dir, &cfg) {
383            promisor_pack_object_ids(&repo.git_dir.join("objects"))
384        } else {
385            HashSet::new()
386        };
387        Self {
388            repo,
389            parents: HashMap::new(),
390            closures: HashMap::new(),
391            promisor_stop,
392        }
393    }
394
395    fn ancestor_closure(&mut self, start: ObjectId) -> Result<HashSet<ObjectId>> {
396        if let Some(existing) = self.closures.get(&start) {
397            return Ok(existing.clone());
398        }
399
400        let mut visited = HashSet::new();
401        let mut queue = VecDeque::new();
402        queue.push_back(start);
403        while let Some(oid) = queue.pop_front() {
404            if !visited.insert(oid) {
405                continue;
406            }
407            for parent in self.parents_of(oid)? {
408                queue.push_back(parent);
409            }
410        }
411        self.closures.insert(start, visited.clone());
412        Ok(visited)
413    }
414
415    fn parents_of(&mut self, oid: ObjectId) -> Result<Vec<ObjectId>> {
416        if let Some(parents) = self.parents.get(&oid) {
417            return Ok(parents.clone());
418        }
419        let commit_oid = peel_to_commit_for_merge_base(self.repo, oid).map_err(|e| match e {
420            Error::InvalidRef(msg) => Error::CorruptObject(msg),
421            other => other,
422        })?;
423        let object = match self.repo.odb.read(&commit_oid) {
424            Ok(o) => o,
425            Err(Error::ObjectNotFound(_)) => {
426                self.parents.insert(oid, Vec::new());
427                return Ok(Vec::new());
428            }
429            Err(e) => return Err(e),
430        };
431        if object.kind != ObjectKind::Commit {
432            return Err(Error::CorruptObject(format!(
433                "object {commit_oid} is not a commit"
434            )));
435        }
436        let commit = parse_commit(&object.data)?;
437        let parents: Vec<ObjectId> = commit
438            .parents
439            .iter()
440            .copied()
441            .filter(|p| !self.promisor_stop.contains(p))
442            .collect();
443        self.parents.insert(oid, parents.clone());
444        Ok(parents)
445    }
446}