git_stack/graph/
branch.rs

1use std::collections::BTreeMap;
2
3#[derive(Clone, Debug)]
4pub struct BranchSet {
5    branches: BTreeMap<git2::Oid, Vec<Branch>>,
6}
7
8impl BranchSet {
9    pub fn from_repo(
10        repo: &dyn crate::git::Repo,
11        protected: &crate::git::ProtectedBranches,
12    ) -> crate::git::Result<Self> {
13        let mut branches = Self::new();
14        for mut branch in repo.local_branches().map(Branch::from) {
15            if protected.is_protected(branch.base_name()) {
16                log::trace!("Branch `{}` is protected", branch.display_name());
17                if let Some(remote) =
18                    repo.find_remote_branch(repo.pull_remote(), branch.base_name())
19                {
20                    branch.set_kind(BranchKind::Mixed);
21                    branch.set_pull_id(remote.id);
22                    let mut remote: Branch = remote.into();
23                    remote.set_kind(BranchKind::Protected);
24                    branches.insert(remote);
25                } else {
26                    branch.set_kind(BranchKind::Protected);
27                }
28            } else {
29                if let Some(remote) =
30                    repo.find_remote_branch(repo.push_remote(), branch.base_name())
31                {
32                    branch.set_push_id(remote.id);
33                }
34                branch.set_kind(BranchKind::Mutable);
35            }
36            branches.insert(branch);
37        }
38        Ok(branches)
39    }
40
41    pub fn update(&mut self, repo: &dyn crate::git::Repo) -> crate::git::Result<()> {
42        let mut branches = Self::new();
43        for old_branch in self.branches.values().flatten() {
44            let new_branch = if let Some(remote) = old_branch.remote() {
45                repo.find_remote_branch(remote, old_branch.base_name())
46            } else {
47                repo.find_local_branch(old_branch.base_name())
48            };
49            let new_branch = if let Some(mut new_branch) = new_branch.map(Branch::from) {
50                new_branch.kind = old_branch.kind;
51                new_branch.pull_id = old_branch.pull_id.and_then(|_| {
52                    repo.find_remote_branch(repo.pull_remote(), old_branch.base_name())
53                        .map(|b| b.id)
54                });
55                new_branch.push_id = old_branch.push_id.and_then(|_| {
56                    repo.find_remote_branch(repo.push_remote(), old_branch.base_name())
57                        .map(|b| b.id)
58                });
59                if new_branch.id() != old_branch.id() {
60                    log::debug!(
61                        "{} moved from {} to {}",
62                        new_branch.display_name(),
63                        old_branch.id(),
64                        new_branch.id()
65                    );
66                }
67                new_branch
68            } else {
69                log::debug!("{} no longer exists", old_branch.display_name());
70                let mut old_branch = old_branch.clone();
71                old_branch.kind = BranchKind::Deleted;
72                old_branch.pull_id = None;
73                old_branch.push_id = None;
74                old_branch
75            };
76            branches.insert(new_branch);
77        }
78        *self = branches;
79        Ok(())
80    }
81}
82
83impl BranchSet {
84    pub fn new() -> Self {
85        Self {
86            branches: Default::default(),
87        }
88    }
89
90    pub fn insert(&mut self, mut branch: Branch) -> Option<Branch> {
91        let id = branch.id();
92        let branches = self.branches.entry(id).or_default();
93
94        let mut existing_index = None;
95        for (i, current) in branches.iter().enumerate() {
96            if current.core == branch.core {
97                existing_index = Some(i);
98                break;
99            }
100        }
101
102        if let Some(existing_index) = existing_index {
103            std::mem::swap(&mut branch, &mut branches[existing_index]);
104            Some(branch)
105        } else {
106            branches.push(branch);
107            None
108        }
109    }
110
111    pub fn remove(&mut self, oid: git2::Oid) -> Option<Vec<Branch>> {
112        self.branches.remove(&oid)
113    }
114
115    pub fn contains_oid(&self, oid: git2::Oid) -> bool {
116        self.branches.contains_key(&oid)
117    }
118
119    pub fn get(&self, oid: git2::Oid) -> Option<&[Branch]> {
120        self.branches.get(&oid).map(|v| v.as_slice())
121    }
122
123    pub fn get_mut(&mut self, oid: git2::Oid) -> Option<&mut [Branch]> {
124        self.branches.get_mut(&oid).map(|v| v.as_mut_slice())
125    }
126
127    pub fn is_empty(&self) -> bool {
128        self.branches.is_empty()
129    }
130
131    pub fn len(&self) -> usize {
132        self.branches.len()
133    }
134
135    pub fn iter(&self) -> impl Iterator<Item = (git2::Oid, &[Branch])> + '_ {
136        self.branches
137            .iter()
138            .map(|(oid, branch)| (*oid, branch.as_slice()))
139    }
140
141    pub fn oids(&self) -> impl Iterator<Item = git2::Oid> + '_ {
142        self.branches.keys().copied()
143    }
144}
145
146impl BranchSet {
147    pub fn all(&self) -> Self {
148        self.clone()
149    }
150
151    pub fn descendants(&self, repo: &dyn crate::git::Repo, base_oid: git2::Oid) -> Self {
152        let mut branches = Self::new();
153        for (branch_oid, branch) in &self.branches {
154            let is_base_descendant = repo
155                .merge_base(*branch_oid, base_oid)
156                .map(|merge_oid| merge_oid == base_oid)
157                .unwrap_or(false);
158            if is_base_descendant {
159                branches.insert_entry(self, *branch_oid, branch);
160            } else {
161                let first_branch = &branch.first().expect("we always have at least one branch");
162                log::trace!(
163                    "Branch {} is not on the branch of {}",
164                    first_branch.display_name(),
165                    base_oid
166                );
167            }
168        }
169        branches
170    }
171
172    pub fn dependents(
173        &self,
174        repo: &dyn crate::git::Repo,
175        base_oid: git2::Oid,
176        head_oid: git2::Oid,
177    ) -> Self {
178        let mut branches = Self::new();
179        for (branch_oid, branch) in &self.branches {
180            let is_shared_base = repo
181                .merge_base(*branch_oid, head_oid)
182                .map(|merge_oid| merge_oid == base_oid && *branch_oid != base_oid)
183                .unwrap_or(false);
184            let is_base_descendant = repo
185                .merge_base(*branch_oid, base_oid)
186                .map(|merge_oid| merge_oid == base_oid)
187                .unwrap_or(false);
188            if is_shared_base {
189                let first_branch = &branch.first().expect("we always have at least one branch");
190                log::trace!(
191                    "Branch {} is not on the branch of HEAD ({})",
192                    first_branch.display_name(),
193                    head_oid
194                );
195            } else if !is_base_descendant {
196                let first_branch = &branch.first().expect("we always have at least one branch");
197                log::trace!(
198                    "Branch {} is not on the branch of {}",
199                    first_branch.display_name(),
200                    base_oid
201                );
202            } else {
203                branches.insert_entry(self, *branch_oid, branch);
204            }
205        }
206        branches
207    }
208
209    pub fn branch(
210        &self,
211        repo: &dyn crate::git::Repo,
212        base_oid: git2::Oid,
213        head_oid: git2::Oid,
214    ) -> Self {
215        let mut branches = Self::new();
216        for (branch_oid, branch) in &self.branches {
217            let is_head_ancestor = repo
218                .merge_base(*branch_oid, head_oid)
219                .map(|merge_oid| *branch_oid == merge_oid)
220                .unwrap_or(false);
221            let is_base_descendant = repo
222                .merge_base(*branch_oid, base_oid)
223                .map(|merge_oid| merge_oid == base_oid)
224                .unwrap_or(false);
225            if !is_head_ancestor {
226                let first_branch = &branch.first().expect("we always have at least one branch");
227                log::trace!(
228                    "Branch {} is not on the branch of HEAD ({})",
229                    first_branch.display_name(),
230                    head_oid
231                );
232            } else if !is_base_descendant {
233                let first_branch = &branch.first().expect("we always have at least one branch");
234                log::trace!(
235                    "Branch {} is not on the branch of {}",
236                    first_branch.display_name(),
237                    base_oid
238                );
239            } else {
240                branches.insert_entry(self, *branch_oid, branch);
241            }
242        }
243        branches
244    }
245
246    fn insert_entry(&mut self, old: &Self, branch_oid: git2::Oid, branch: &[Branch]) {
247        self.branches.insert(branch_oid, branch.to_vec());
248        for mixed_branch in branch.iter().filter(|b| b.kind() == BranchKind::Mixed) {
249            let Some(remote_branch_oid) = mixed_branch.pull_id() else {
250                continue;
251            };
252            let Some(potential_remotes) = old.get(remote_branch_oid).map(|b| b.to_vec()) else {
253                continue;
254            };
255            for potential_remote in potential_remotes {
256                if potential_remote.kind() == BranchKind::Protected
257                    && potential_remote.base_name() == mixed_branch.base_name()
258                {
259                    self.insert(potential_remote);
260                }
261            }
262        }
263    }
264}
265
266impl Default for BranchSet {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272impl IntoIterator for BranchSet {
273    type Item = (git2::Oid, Vec<Branch>);
274    type IntoIter = std::collections::btree_map::IntoIter<git2::Oid, Vec<Branch>>;
275
276    fn into_iter(self) -> Self::IntoIter {
277        self.branches.into_iter()
278    }
279}
280
281impl Extend<Branch> for BranchSet {
282    fn extend<T: IntoIterator<Item = Branch>>(&mut self, iter: T) {
283        for branch in iter {
284            self.insert(branch);
285        }
286    }
287}
288
289#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
290pub struct Branch {
291    core: crate::git::Branch,
292    kind: BranchKind,
293    pull_id: Option<git2::Oid>,
294    push_id: Option<git2::Oid>,
295}
296
297impl Branch {
298    pub fn set_kind(&mut self, kind: BranchKind) -> &mut Self {
299        self.kind = kind;
300        self
301    }
302
303    pub fn set_id(&mut self, id: git2::Oid) -> &mut Self {
304        self.core.id = id;
305        self
306    }
307
308    pub fn set_pull_id(&mut self, pull_id: git2::Oid) -> &mut Self {
309        self.pull_id = Some(pull_id);
310        self
311    }
312
313    pub fn set_push_id(&mut self, push_id: git2::Oid) -> &mut Self {
314        self.push_id = Some(push_id);
315        self
316    }
317}
318
319impl Branch {
320    pub fn git(&self) -> &crate::git::Branch {
321        &self.core
322    }
323
324    pub fn name(&self) -> String {
325        self.core.to_string()
326    }
327
328    pub fn display_name(&self) -> impl std::fmt::Display + '_ {
329        &self.core
330    }
331
332    pub fn remote(&self) -> Option<&str> {
333        self.core.remote.as_deref()
334    }
335
336    pub fn base_name(&self) -> &str {
337        &self.core.name
338    }
339
340    pub fn local_name(&self) -> Option<&str> {
341        self.core.local_name()
342    }
343
344    pub fn kind(&self) -> BranchKind {
345        self.kind
346    }
347
348    pub fn id(&self) -> git2::Oid {
349        self.core.id
350    }
351
352    pub fn pull_id(&self) -> Option<git2::Oid> {
353        self.pull_id
354    }
355
356    pub fn push_id(&self) -> Option<git2::Oid> {
357        self.push_id
358    }
359}
360
361impl From<crate::git::Branch> for Branch {
362    fn from(core: crate::git::Branch) -> Self {
363        Self {
364            core,
365            kind: BranchKind::Deleted,
366            pull_id: None,
367            push_id: None,
368        }
369    }
370}
371
372impl PartialEq<crate::git::Branch> for Branch {
373    fn eq(&self, other: &crate::git::Branch) -> bool {
374        self.core == *other
375    }
376}
377
/// How a branch may be treated.
///
/// The derived `Ord` follows declaration order: `Deleted < Mutable < Mixed < Protected`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum BranchKind {
    /// Of no interest
    Deleted,
    /// Completely mutable
    Mutable,
    /// Local is mutable, remote is protected
    Mixed,
    /// Must not touch
    Protected,
}

impl BranchKind {
    /// Whether a branch of this kind may carry the user's own commits.
    pub fn has_user_commits(self) -> bool {
        match self {
            Self::Mutable | Self::Mixed => true,
            Self::Deleted | Self::Protected => false,
        }
    }
}
400
401pub fn find_protected_base<'b>(
402    repo: &dyn crate::git::Repo,
403    branches: &'b BranchSet,
404    head_oid: git2::Oid,
405) -> Option<&'b Branch> {
406    // We're being asked about a protected branch
407    if let Some(head_branches) = branches.get(head_oid) {
408        if let Some(head_branch) = head_branches
409            .iter()
410            .find(|b| b.kind() == BranchKind::Protected)
411        {
412            return Some(head_branch);
413        }
414    }
415
416    let protected_base_oids = branches
417        .iter()
418        .filter_map(|(id, b)| {
419            b.iter()
420                .find(|b| b.kind() == BranchKind::Protected)
421                .map(|_| id)
422        })
423        .filter_map(|oid| {
424            let merge_oid = repo.merge_base(head_oid, oid)?;
425            Some((merge_oid, oid))
426        })
427        .collect::<Vec<_>>();
428
429    // Not much choice for applicable base
430    match protected_base_oids.len() {
431        0 => {
432            return None;
433        }
434        1 => {
435            let (_, protected_oid) = protected_base_oids[0];
436            let protected_branch = branches
437                .get(protected_oid)
438                .expect("protected_oid came from protected_branches")
439                .iter()
440                .find(|b| b.kind() == BranchKind::Protected)
441                .expect("protected_branches has at least one protected branch");
442            return Some(protected_branch);
443        }
444        _ => {}
445    }
446
447    // Prefer protected branch from first parent
448    let mut next_oid = Some(head_oid);
449    while let Some(parent_oid) = next_oid {
450        if let Some((_, closest_common_oid)) = protected_base_oids
451            .iter()
452            .filter(|(base, _)| *base == parent_oid)
453            .min_by_key(|(base, branch)| {
454                (
455                    repo.commit_count(*base, head_oid),
456                    repo.commit_count(*base, *branch),
457                )
458            })
459        {
460            let protected_branch = branches
461                .get(*closest_common_oid)
462                .expect("protected_oid came from protected_branches")
463                .iter()
464                .find(|b| b.kind() == BranchKind::Protected)
465                .expect("protected_branches has at least one protected branch");
466            return Some(protected_branch);
467        }
468        next_oid = repo
469            .parent_ids(parent_oid)
470            .expect("child_oid came from verified source")
471            .first()
472            .copied();
473    }
474
475    // Prefer most direct ancestors
476    if let Some((_, closest_common_oid)) =
477        protected_base_oids.iter().min_by_key(|(base, protected)| {
478            let to_protected = repo.commit_count(*base, *protected);
479            let to_head = repo.commit_count(*base, head_oid);
480            (to_protected, to_head)
481        })
482    {
483        let protected_branch = branches
484            .get(*closest_common_oid)
485            .expect("protected_oid came from protected_branches")
486            .iter()
487            .find(|b| b.kind() == BranchKind::Protected)
488            .expect("protected_branches has at least one protected branch");
489        return Some(protected_branch);
490    }
491
492    None
493}
494
495pub fn infer_base(repo: &dyn crate::git::Repo, head_oid: git2::Oid) -> Option<git2::Oid> {
496    let head_commit = repo.find_commit(head_oid)?;
497    let head_committer = head_commit.committer.clone();
498
499    let mut next_oid = head_oid;
500    loop {
501        let next_commit = repo.find_commit(next_oid)?;
502        if next_commit.committer != head_committer {
503            return Some(next_oid);
504        }
505        let parent_ids = repo.parent_ids(next_oid).ok()?;
506        match parent_ids.len() {
507            1 => {
508                next_oid = parent_ids[0];
509            }
510            _ => {
511                // Assume merge-commits are topic branches being merged into the upstream
512                return Some(next_oid);
513            }
514        }
515    }
516}